Repository: Physical-Intelligence/real-time-chunking-kinetix
Branch: main
Commit: 9296f31d62d5
Files: 23
Total size: 905.0 KB
Directory structure:
gitextract_sgcc59ek/
├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── pyproject.toml
├── src/
│ ├── eval_flow.py
│ ├── generate_data.py
│ ├── model.py
│ ├── render_levels.py
│ ├── train_expert.py
│ └── train_flow.py
└── worlds/
└── l/
├── car_launch.json
├── cartpole_thrust.json
├── catapult.json
├── catcher_v3.json
├── chain_lander.json
├── grasp_easy.json
├── h17_unicycle.json
├── hard_lunar_lander.json
├── mjc_half_cheetah.json
├── mjc_swimmer.json
├── mjc_walker.json
└── trampoline.json
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,vim
# Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode,python,vim
### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
### Python Patch ###
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
poetry.toml
# ruff
.ruff_cache/
# LSP config files
pyrightconfig.json
### Vim ###
# Swap
[._]*.s[a-v][a-z]
!*.svg # comment out if you don't need vector files
[._]*.sw[a-p]
[._]s[a-rt-v][a-z]
[._]ss[a-gi-z]
[._]sw[a-p]
# Session
Session.vim
Sessionx.vim
# Temporary
.netrwhist
*~
# Auto-generated tag files
tags
# Persistent undo
[._]*.un~
### VisualStudioCode ###
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
!.vscode/*.code-snippets
# Local History for Visual Studio Code
.history/
# Built Visual Studio Code Extensions
*.vsix
### VisualStudioCode Patch ###
# Ignore all local history of files
.history
.ionide
# End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,vim
wandb/
================================================
FILE: .gitmodules
================================================
[submodule "third_party/kinetix"]
path = third_party/kinetix
url = https://github.com/FlairOx/Kinetix.git
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2025 Physical Intelligence
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
Simulated experiments for the papers [Real-Time Execution of Action Chunking Flow Policies](https://arxiv.org/abs/2506.07339) and [Training-Time Action Conditioning for Efficient Real-Time Chunking](https://arxiv.org/abs/2512.05964).
## Installation
```bash
# Clone Kinetix submodule
git submodule update --init
# Install uv
curl -LsSf https://astral.sh/uv/install.sh | sh
# Install dependencies
uv sync
```
## Pre-trained checkpoints and data
`gs://rtc-assets/expert/` contains expert checkpoints generated by `src/train_expert.py`, and `gs://rtc-assets/expert/data/` contains million-transition datasets for each level (generated by `src/generate_data.py`). Be aware that the `expert/` directory is about 60GiB in total.
`gs://rtc-assets/bc/` contains imitation learning policies for each level trained on the aforementioned data (generated by `src/train_flow.py`). These are directly usable with `src/eval_flow.py`.
## Reproduce results
Note that, for all scripts, the number of GPUs must evenly divide the number of levels (12 by default), because computation is sharded across levels.
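As a quick sanity check (illustrative only, not part of the repo), the admissible GPU counts for the default 12 levels are exactly the divisors of 12:

```python
# Levels are sharded evenly over GPUs, so the GPU count must divide the level count.
num_levels = 12  # default
valid_gpu_counts = [d for d in range(1, num_levels + 1) if num_levels % d == 0]
print(valid_gpu_counts)  # [1, 2, 3, 4, 6, 12]
```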
1. Train expert policies: `uv run src/train_expert.py`
- By default, this will train 8 seeds per level for 65 million environment steps each.
- Checkpoints, videos, and stats are written to a wandb project called `rtc-kinetix-expert` and the local directory `./logs-expert/<wandb-run-name>`. It is recommended to control other wandb options, like the run name, using environment variables.
2. Generate data: `uv run src/generate_data.py --config.run-path ./logs-expert/<wandb-run-name>`
- For each level, this will automatically load the best-performing checkpoint for each seed (discarding seeds that didn't reach a certain success threshold).
- By default, 1 million environment steps are collected for each level using a mixture of expert policies.
- Data is written back to `./logs-expert/<wandb-run-name>/data/`.
3. Train imitation learning policies: `uv run src/train_flow.py --config.run-path ./logs-expert/<wandb-run-name>`
- This will load the data from step 2 and train flow matching policies for each level.
- Checkpoints, videos, and stats are written to a wandb project called `rtc-kinetix-bc` and the local directory `./logs-bc/<wandb-run-name>`. It is recommended to control other wandb options, like the run name, using environment variables.
4. Evaluate imitation learning policies: `uv run src/eval_flow.py --config.run-path ./logs-bc/<wandb-run-name> --output-dir <output-dir>`
- This will load the checkpoints from step 3 and evaluate them for 2048 trials per level by default.
- Currently, the script performs an exhaustive sweep over inference delay and execution horizon for all methods.
## Training-Time RTC
To reproduce the training-time RTC results, follow these steps:
1. Change `simulated_delay` in the model config to 5.
2. Fine-tune the pre-trained checkpoint with simulated delay for 8 epochs: `uv run src/train_flow.py --config.run-path <run_path> --config.load-dir bc/24 --config.num-epochs 8` where `bc` is the contents of `gs://rtc-assets/bc/`.
3. Evaluate as above.
================================================
FILE: pyproject.toml
================================================
[build-system]
requires = ["setuptools>=64.0"]
build-backend = "setuptools.build_meta"
[project]
name = "real-time-chunking-kinetix"
version = "0.0.1"
authors = [
{ name="Kevin Black", email="kevin@physicalintelligence.company" }
]
description = ""
requires-python = ">=3.11"
dependencies = [
"jax[cuda12]==0.4.35",
"numpy==1.26.4",
"tyro",
"einops",
"pandas",
"tqdm-loggable",
"kinetix",
]
[tool.ruff]
line-length = 120
[tool.uv.sources]
kinetix = { path = "third_party/kinetix", editable = true }
[tool.ruff.lint.isort]
force-single-line = true
force-sort-within-sections = true
single-line-exclusions = ["collections.abc", "typing", "typing_extensions"]
known-third-party = ["wandb"]
================================================
FILE: src/eval_flow.py
================================================
import collections
import dataclasses
import functools
import math
import pathlib
import pickle
from typing import Sequence
import flax.nnx as nnx
import jax
from jax.experimental import shard_map
import jax.numpy as jnp
import kinetix.environment.env as kenv
import kinetix.environment.env_state as kenv_state
import kinetix.environment.wrappers as wrappers
import kinetix.render.renderer_pixels as renderer_pixels
import pandas as pd
import tyro
import model as _model
import train_expert
@dataclasses.dataclass(frozen=True)
class NaiveMethodConfig:
pass
@dataclasses.dataclass(frozen=True)
class RealtimeMethodConfig:
prefix_attention_schedule: _model.PrefixAttentionSchedule = "exp"
max_guidance_weight: float = 5.0
@dataclasses.dataclass(frozen=True)
class BIDMethodConfig:
n_samples: int = 16
bid_k: int | None = None
@dataclasses.dataclass(frozen=True)
class EvalConfig:
step: int = -1
weak_step: int | None = None
num_evals: int = 2048
num_flow_steps: int = 5
inference_delay: int = 0
execute_horizon: int = 1
method: NaiveMethodConfig | RealtimeMethodConfig | BIDMethodConfig = NaiveMethodConfig()
model: _model.ModelConfig = _model.ModelConfig()
def eval(
config: EvalConfig,
env: kenv.environment.Environment,
rng: jax.Array,
level: kenv_state.EnvState,
policy: _model.FlowPolicy,
env_params: kenv_state.EnvParams,
static_env_params: kenv_state.StaticEnvParams,
weak_policy: _model.FlowPolicy | None = None,
):
env = train_expert.BatchEnvWrapper(
wrappers.LogWrapper(wrappers.AutoReplayWrapper(train_expert.NoisyActionWrapper(env))), config.num_evals
)
render_video = train_expert.make_render_video(renderer_pixels.make_render_pixels(env_params, static_env_params))
assert config.execute_horizon >= config.inference_delay, f"{config.execute_horizon=} {config.inference_delay=}"
def execute_chunk(carry, _):
def step(carry, action):
rng, obs, env_state = carry
rng, key = jax.random.split(rng)
next_obs, next_env_state, reward, done, info = env.step(key, env_state, action, env_params)
return (rng, next_obs, next_env_state), (done, env_state, info)
rng, obs, env_state, action_chunk, n = carry
rng, key = jax.random.split(rng)
if isinstance(config.method, NaiveMethodConfig):
next_action_chunk = policy.action(key, obs, config.num_flow_steps)
elif isinstance(config.method, RealtimeMethodConfig):
prefix_attention_horizon = policy.action_chunk_size - config.execute_horizon
assert (
config.inference_delay <= policy.action_chunk_size
and prefix_attention_horizon <= policy.action_chunk_size
), f"{config.inference_delay=} {prefix_attention_horizon=} {policy.action_chunk_size=}"
print(
f"{config.execute_horizon=} {config.inference_delay=} {prefix_attention_horizon=} {policy.action_chunk_size=}"
)
next_action_chunk = policy.realtime_action(
key,
obs,
config.num_flow_steps,
action_chunk,
config.inference_delay,
prefix_attention_horizon,
config.method.prefix_attention_schedule,
config.method.max_guidance_weight,
)
elif isinstance(config.method, BIDMethodConfig):
prefix_attention_horizon = policy.action_chunk_size - config.execute_horizon
if config.method.bid_k is not None:
assert weak_policy is not None, "weak_policy is required for BID"
next_action_chunk = policy.bid_action(
key,
obs,
config.num_flow_steps,
action_chunk,
config.inference_delay,
prefix_attention_horizon,
config.method.n_samples,
bid_k=config.method.bid_k,
bid_weak_policy=weak_policy if config.method.bid_k is not None else None,
)
else:
raise ValueError(f"Unknown method: {config.method}")
# we execute `inference_delay` actions from the *previously generated* action chunk, and then the remaining
# `execute_horizon - inference_delay` actions from the newly generated action chunk
action_chunk_to_execute = jnp.concatenate(
[
action_chunk[:, : config.inference_delay],
next_action_chunk[:, config.inference_delay : config.execute_horizon],
],
axis=1,
)
# throw away the first `execute_horizon` actions from the newly generated action chunk, to align it with the
# correct frame of reference for the next scan iteration
next_action_chunk = jnp.concatenate(
[
next_action_chunk[:, config.execute_horizon :],
jnp.zeros((obs.shape[0], config.execute_horizon, policy.action_dim)),
],
axis=1,
)
next_n = jnp.concatenate([n[config.execute_horizon :], jnp.zeros(config.execute_horizon, dtype=jnp.int32)])
(rng, next_obs, next_env_state), (dones, env_states, infos) = jax.lax.scan(
step, (rng, obs, env_state), action_chunk_to_execute.transpose(1, 0, 2)
)
# if config.inference_delay > 0:
# infos["match"] = jnp.mean(jnp.abs(fixed_prefix - action_chunk_to_execute))
return (rng, next_obs, next_env_state, next_action_chunk, next_n), (dones, env_states, infos)
rng, key = jax.random.split(rng)
obs, env_state = env.reset_to_level(key, level, env_params)
rng, key = jax.random.split(rng)
action_chunk = policy.action(key, obs, config.num_flow_steps) # [batch, horizon, action_dim]
n = jnp.ones(action_chunk.shape[1], dtype=jnp.int32)
scan_length = math.ceil(env_params.max_timesteps / config.execute_horizon)
_, (dones, env_states, infos) = jax.lax.scan(
execute_chunk,
(rng, obs, env_state, action_chunk, n),
None,
length=scan_length,
)
dones, env_states, infos = jax.tree.map(lambda x: x.reshape(-1, *x.shape[2:]), (dones, env_states, infos))
assert dones.shape[0] >= env_params.max_timesteps, f"{dones.shape=}"
return_info = {}
for key in ["returned_episode_returns", "returned_episode_lengths", "returned_episode_solved"]:
# only consider the first episode of each rollout
first_done_idx = jnp.argmax(dones, axis=0)
return_info[key] = infos[key][first_done_idx, jnp.arange(config.num_evals)].mean()
for key in ["match"]:
if key in infos:
return_info[key] = jnp.mean(infos[key])
video = render_video(jax.tree.map(lambda x: x[:, 0], env_states))
return return_info, video
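# --- Illustration (hypothetical helper, not part of the original file): the chunk
# stitching in `execute_chunk` above runs `inference_delay` actions from the previous
# chunk, then `execute_horizon - inference_delay` actions from the new chunk.
def _demo_chunk_stitching(delay: int = 2, horizon: int = 4) -> list[int]:
    import numpy as np
    old_chunk = np.arange(8)         # stand-in for the previously generated chunk
    new_chunk = np.arange(100, 108)  # stand-in for the newly generated chunk
    executed = np.concatenate([old_chunk[:delay], new_chunk[delay:horizon]])
    return executed.tolist()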
def main(
run_path: str,
config: EvalConfig = EvalConfig(),
level_paths: Sequence[str] = (
"worlds/l/grasp_easy.json",
"worlds/l/catapult.json",
"worlds/l/cartpole_thrust.json",
"worlds/l/hard_lunar_lander.json",
"worlds/l/mjc_half_cheetah.json",
"worlds/l/mjc_swimmer.json",
"worlds/l/mjc_walker.json",
"worlds/l/h17_unicycle.json",
"worlds/l/chain_lander.json",
"worlds/l/catcher_v3.json",
"worlds/l/trampoline.json",
"worlds/l/car_launch.json",
),
seed: int = 0,
output_dir: str | None = "eval_output",
):
static_env_params = kenv_state.StaticEnvParams(**train_expert.LARGE_ENV_PARAMS, frame_skip=train_expert.FRAME_SKIP)
env_params = kenv_state.EnvParams()
levels = train_expert.load_levels(level_paths, static_env_params, env_params)
static_env_params = static_env_params.replace(screen_dim=train_expert.SCREEN_DIM)
env = kenv.make_kinetix_env_from_name("Kinetix-Symbolic-Continuous-v1", static_env_params=static_env_params)
# load policies from best checkpoints by solve rate
state_dicts = []
weak_state_dicts = []
for level_path in level_paths:
level_name = level_path.replace("/", "_").replace(".json", "")
log_dirs = list(filter(lambda p: p.is_dir() and p.name.isdigit(), pathlib.Path(run_path).iterdir()))
log_dirs = sorted(log_dirs, key=lambda p: int(p.name))
# load policy
with (log_dirs[config.step] / "policies" / f"{level_name}.pkl").open("rb") as f:
state_dicts.append(pickle.load(f))
if config.weak_step is not None:
with (log_dirs[config.weak_step] / "policies" / f"{level_name}.pkl").open("rb") as f:
weak_state_dicts.append(pickle.load(f))
state_dicts = jax.device_put(jax.tree.map(lambda *x: jnp.array(x), *state_dicts))
if config.weak_step is not None:
weak_state_dicts = jax.device_put(jax.tree.map(lambda *x: jnp.array(x), *weak_state_dicts))
else:
weak_state_dicts = None
obs_dim = jax.eval_shape(env.reset_to_level, jax.random.key(0), jax.tree.map(lambda x: x[0], levels), env_params)[
0
].shape[-1]
action_dim = env.action_space(env_params).shape[0]
mesh = jax.make_mesh((jax.local_device_count(),), ("x",))
pspec = jax.sharding.PartitionSpec("x")
sharding = jax.sharding.NamedSharding(mesh, pspec)
@functools.partial(jax.jit, static_argnums=(0,), in_shardings=sharding, out_shardings=sharding)
@functools.partial(shard_map.shard_map, mesh=mesh, in_specs=(None, pspec, pspec, pspec, pspec), out_specs=pspec)
@functools.partial(jax.vmap, in_axes=(None, 0, 0, 0, 0))
def _eval(config: EvalConfig, rng: jax.Array, level: kenv_state.EnvState, state_dict, weak_state_dict):
policy = _model.FlowPolicy(
obs_dim=obs_dim,
action_dim=action_dim,
config=config.model,
rngs=nnx.Rngs(rng),
)
graphdef, state = nnx.split(policy)
state.replace_by_pure_dict(state_dict)
policy = nnx.merge(graphdef, state)
if weak_state_dict is not None:
graphdef, state = nnx.split(policy)
state.replace_by_pure_dict(weak_state_dict)
weak_policy = nnx.merge(graphdef, state)
else:
weak_policy = None
eval_info, _ = eval(config, env, rng, level, policy, env_params, static_env_params, weak_policy)
return eval_info
rngs = jax.random.split(jax.random.key(seed), len(level_paths))
results = collections.defaultdict(list)
for inference_delay in [0, 1, 2, 3, 4]:
for execute_horizon in range(max(1, inference_delay), 8 - inference_delay + 1):
print(f"{inference_delay=} {execute_horizon=}")
c = dataclasses.replace(
config, inference_delay=inference_delay, execute_horizon=execute_horizon, method=NaiveMethodConfig()
)
out = jax.device_get(_eval(c, rngs, levels, state_dicts, weak_state_dicts))
for i in range(len(level_paths)):
for k, v in out.items():
results[k].append(v[i])
results["delay"].append(inference_delay)
results["method"].append("naive")
results["level"].append(level_paths[i])
results["execute_horizon"].append(execute_horizon)
c = dataclasses.replace(
config, inference_delay=inference_delay, execute_horizon=execute_horizon, method=RealtimeMethodConfig()
)
out = jax.device_get(_eval(c, rngs, levels, state_dicts, weak_state_dicts))
for i in range(len(level_paths)):
for k, v in out.items():
results[k].append(v[i])
results["delay"].append(inference_delay)
results["method"].append("realtime")
results["level"].append(level_paths[i])
results["execute_horizon"].append(execute_horizon)
c = dataclasses.replace(
config, inference_delay=inference_delay, execute_horizon=execute_horizon, method=BIDMethodConfig()
)
out = jax.device_get(_eval(c, rngs, levels, state_dicts, weak_state_dicts))
for i in range(len(level_paths)):
for k, v in out.items():
results[k].append(v[i])
results["delay"].append(inference_delay)
results["method"].append("bid")
results["level"].append(level_paths[i])
results["execute_horizon"].append(execute_horizon)
c = dataclasses.replace(
config,
inference_delay=inference_delay,
execute_horizon=execute_horizon,
method=RealtimeMethodConfig(prefix_attention_schedule="zeros"),
)
out = jax.device_get(_eval(c, rngs, levels, state_dicts, weak_state_dicts))
for i in range(len(level_paths)):
for k, v in out.items():
results[k].append(v[i])
results["delay"].append(inference_delay)
results["method"].append("hard_masking")
results["level"].append(level_paths[i])
results["execute_horizon"].append(execute_horizon)
pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
df = pd.DataFrame(results)
df.to_csv(pathlib.Path(output_dir) / "results.csv", index=False)
if __name__ == "__main__":
tyro.cli(main)
================================================
FILE: src/generate_data.py
================================================
import dataclasses
import functools
import json
import pathlib
import pickle
from typing import Sequence
import einops
from flax import struct
import flax.nnx as nnx
import flax.serialization
import jax
import jax.numpy as jnp
import kinetix.environment.env as kenv
import kinetix.environment.env_state as kenv_state
import kinetix.environment.wrappers as wrappers
import numpy as np
import tqdm_loggable.auto as tqdm
import tyro
import train_expert
@dataclasses.dataclass
class Config:
run_path: str
level_paths: Sequence[str] = (
"worlds/l/grasp_easy.json",
"worlds/l/catapult.json",
"worlds/l/cartpole_thrust.json",
"worlds/l/hard_lunar_lander.json",
"worlds/l/mjc_half_cheetah.json",
"worlds/l/mjc_swimmer.json",
"worlds/l/mjc_walker.json",
"worlds/l/h17_unicycle.json",
"worlds/l/chain_lander.json",
"worlds/l/catcher_v3.json",
"worlds/l/trampoline.json",
"worlds/l/car_launch.json",
)
seed: int = 0
# Number of environments to run in parallel.
num_envs: int = 128
# Batch size for scan in number of steps *per environment*.
batch_size: int = 256
# Number of *total* steps to collect (lower bound -- rounded up to nearest multiple of batch size * num_envs).
num_steps: int = 1_000_000
solve_rate_threshold: float = 0.65
action_sample_std: float | None = None
@struct.dataclass
class Data:
obs: jax.Array
action: jax.Array
done: jax.Array
solved: jax.Array
return_: jax.Array
length: jax.Array
@struct.dataclass
class StepCarry:
rng: jax.Array
obs: jax.Array
env_state: kenv_state.EnvState
policy_idxs: jax.Array
def main(config: Config):
num_steps_per_env = (
(config.num_steps // config.num_envs + config.batch_size - 1) // config.batch_size
) * config.batch_size
print(
f"Generating {num_steps_per_env * config.num_envs:_} steps with {config.num_envs} environments ({num_steps_per_env} steps per env)"
)
static_env_params = kenv_state.StaticEnvParams(**train_expert.LARGE_ENV_PARAMS, frame_skip=train_expert.FRAME_SKIP)
env_params = kenv_state.EnvParams()
levels = train_expert.load_levels(config.level_paths, static_env_params, env_params)
env = kenv.make_kinetix_env_from_name("Kinetix-Symbolic-Continuous-v1", static_env_params=static_env_params)
env = train_expert.BatchEnvWrapper(
wrappers.LogWrapper(
wrappers.AutoReplayWrapper(
train_expert.ActionHistoryWrapper(
train_expert.ObsHistoryWrapper(train_expert.NoisyActionWrapper(env), 4)
)
)
),
config.num_envs,
)
# load policies from best checkpoints by solve rate
gen = np.random.default_rng(config.seed)
state_dicts, good_policy_masks = [], []
for level_path in config.level_paths:
level_name = level_path.replace("/", "_").replace(".json", "")
print(level_name)
level_state_dicts, level_good_policy_mask = [], []
for seed_dir in pathlib.Path(config.run_path).glob("seed_*"):
# load stats
log_dirs = list(filter(lambda p: p.is_dir() and p.name.isdigit(), seed_dir.iterdir()))
level_stats = [json.load((p / "stats" / f"{level_name}.json").open("r")) for p in log_dirs]
level_stats = jax.tree.map(lambda *x: jnp.stack(x), *level_stats)
# pick a random policy with solve rate >= threshold
solved_idxs = np.nonzero(level_stats["returned_episode_solved"] >= config.solve_rate_threshold)[0]
if len(solved_idxs) == 0:
chosen_idx = np.argmax(level_stats["returned_episode_solved"])
level_good_policy_mask.append(False)
else:
chosen_idx = gen.choice(solved_idxs)
level_good_policy_mask.append(True)
# load policy
chosen_log_dir = log_dirs[chosen_idx]
with open(chosen_log_dir / "policies" / f"{level_name}.pkl", "rb") as f:
level_state_dicts.append(pickle.load(f))
print(
f"\t{seed_dir.name}: {level_stats['returned_episode_solved'][chosen_idx]:.3f} {'[MASKED]' if not level_good_policy_mask[-1] else ''}"
)
state_dicts.append(jax.tree.map(lambda *x: jnp.array(x), *level_state_dicts))
good_policy_masks.append(level_good_policy_mask)
state_dicts = jax.tree.map(lambda *x: jnp.array(x), *state_dicts)
good_policy_masks = jnp.array(good_policy_masks)
state_dicts, good_policy_masks = jax.device_put((state_dicts, good_policy_masks))
def new_policy_idxs(rng: jax.Array, good_policy_mask: jax.Array) -> jax.Array:
# select a random policy for each environment
rng, key = jax.random.split(rng)
randint = jax.random.randint(key, (config.num_envs,), 0, good_policy_mask.sum())
return jnp.nonzero(good_policy_mask, size=good_policy_mask.shape[0])[0][randint]
@jax.jit
@jax.vmap
def init(rng: jax.Array, level: kenv_state.EnvState, good_policy_mask: jax.Array) -> StepCarry:
rng, key = jax.random.split(rng)
obs, env_state = env.reset_to_level(key, level, env_params)
rng, key = jax.random.split(rng)
policy_idxs = new_policy_idxs(key, good_policy_mask)
return StepCarry(rng, obs, env_state, policy_idxs)
@functools.partial(jax.jit, static_argnums=(3,), donate_argnums=(0,))
@functools.partial(jax.vmap, in_axes=(0, 0, 0, None))
def step_n(carry: StepCarry, state_dict: dict, good_policy_mask: jax.Array, n: int):
def step(carry: StepCarry, _):
# create agent
action_dim = env.action_space(env_params).shape[0]
assert len(carry.obs.shape) == 2
obs_dim = carry.obs.shape[1]
@jax.vmap # over environments
def get_action(key, obs, policy_idx):
agent = train_expert.Agent(obs_dim, action_dim, 1, rngs=nnx.Rngs(0))
graphdef, state = nnx.split(agent)
state.replace_by_pure_dict(jax.tree.map(lambda x: x[policy_idx], state_dict))
agent = nnx.merge(graphdef, state)
mean, std = agent.action(obs)
if config.action_sample_std is not None:
std = jnp.full_like(mean, config.action_sample_std)
action_dist = train_expert.make_squashed_normal_diag(mean, std, static_env_params.num_motor_bindings)
return action_dist.sample(seed=key)
# step
rng, key = jax.random.split(carry.rng)
action = get_action(jax.random.split(key, config.num_envs), carry.obs, carry.policy_idxs)
rng, key = jax.random.split(rng)
next_obs, next_env_state, reward, done, info = env.step(key, carry.env_state, action, env_params)
# select new policies only at episode boundaries
rng, key = jax.random.split(rng)
next_policy_idxs = jnp.where(done, new_policy_idxs(key, good_policy_mask), carry.policy_idxs)
# only retain important info
info = {
k: v
for k, v in info.items()
if k in ["returned_episode_returns", "returned_episode_lengths", "returned_episode_solved"]
}
return StepCarry(rng, next_obs, next_env_state, next_policy_idxs), Data(
train_expert.ObsHistoryWrapper.get_original_obs(carry.env_state),
action,
done,
info["returned_episode_solved"],
info["returned_episode_returns"],
info["returned_episode_lengths"],
)
return jax.lax.scan(step, carry, None, length=n)
rng = jax.random.key(config.seed)
carry = init(jax.random.split(rng, len(config.level_paths)), levels, good_policy_masks)
pbar = tqdm.tqdm(total=num_steps_per_env * config.num_envs, dynamic_ncols=True)
data = []
for _ in range(0, num_steps_per_env, config.batch_size):
carry, result = step_n(carry, state_dicts, good_policy_masks, config.batch_size)
data.append(jax.device_get(result))
pbar.update(config.batch_size * config.num_envs)
pbar.close()
with jax.default_device(jax.devices("cpu")[0]):
data: Data = jax.tree.map(
lambda *x: einops.rearrange(
jnp.stack(x),
"num_batch level batch_size num_env ... -> level (num_batch batch_size) num_env ...",
),
*data,
)
for i, level_path in enumerate(config.level_paths):
level_name = level_path.replace("/", "_").replace(".json", "")
print_info = {"num_episodes": data.done[i].sum()}
for key in ["return_", "length", "solved"]:
print_info[key] = (getattr(data, key)[i] * data.done[i]).sum() / print_info["num_episodes"]
print(f"{level_name}:")
for k, v in print_info.items():
print(f"\t{k}: {v:.3f}")
data_path = pathlib.Path(config.run_path) / "data"
data_path.mkdir(parents=True, exist_ok=True)
level_data = flax.serialization.to_state_dict(jax.tree.map(lambda x: x[i], data))
np.savez(data_path / f"{level_name}.npz", **level_data)
if __name__ == "__main__":
tyro.cli(main)
================================================
FILE: src/model.py
================================================
import dataclasses
import functools
from typing import Literal, TypeAlias, Self
import einops
import flax.nnx as nnx
import jax
import jax.numpy as jnp
@dataclasses.dataclass(frozen=True)
class ModelConfig:
channel_dim: int = 256
channel_hidden_dim: int = 512
token_hidden_dim: int = 64
num_layers: int = 4
action_chunk_size: int = 8
simulated_delay: int | None = None
def posemb_sincos(pos: jax.Array, embedding_dim: int, min_period: float, max_period: float) -> jax.Array:
"""Computes sine-cosine positional embedding vectors for scalar positions."""
if embedding_dim % 2 != 0:
raise ValueError(f"embedding_dim ({embedding_dim}) must be divisible by 2")
fraction = jnp.linspace(0.0, 1.0, embedding_dim // 2)
period = min_period * (max_period / min_period) ** fraction
sinusoid_input = jnp.einsum(
"i,j->ij",
pos,
1.0 / period * 2 * jnp.pi,
precision=jax.lax.Precision.HIGHEST,
)
return jnp.concatenate([jnp.sin(sinusoid_input), jnp.cos(sinusoid_input)], axis=-1)
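# --- Illustration (hypothetical helper, not part of the original file): a NumPy shape
# check of the sine-cosine embedding above. For pos=0 every sinusoid input is 0, so the
# sin half is all zeros and the cos half is all ones. The period bounds here are
# assumed example values, not values used by this repo.
def _demo_posemb(embedding_dim: int = 8, min_period: float = 0.01, max_period: float = 100.0):
    import numpy as np
    pos = np.array([0.0, 1.0])
    fraction = np.linspace(0.0, 1.0, embedding_dim // 2)
    period = min_period * (max_period / min_period) ** fraction
    x = np.einsum("i,j->ij", pos, 2 * np.pi / period)
    emb = np.concatenate([np.sin(x), np.cos(x)], axis=-1)
    return emb.shape, np.round(emb[0], 6).tolist()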
PrefixAttentionSchedule: TypeAlias = Literal["linear", "exp", "ones", "zeros"]
def get_prefix_weights(start: int, end: int, total: int, schedule: PrefixAttentionSchedule) -> jax.Array:
"""With start=2, end=6, total=10, the output will be:
1  1  4/5  3/5  2/5  1/5  0  0  0  0
      ^                   ^
      start               end
`start` (inclusive) is where the chunk starts being allowed to change. `end` (exclusive) is where the chunk stops
paying attention to the prefix. if start == 0, then the entire chunk is allowed to change. if end == total, then the
entire prefix is attended to.
`end` takes precedence over `start` in the sense that, if `end < start`, then `start` is pushed down to `end`. Thus,
if `end` is 0, then the entire prefix will always be ignored.
"""
start = jnp.minimum(start, end)
if schedule == "ones":
w = jnp.ones(total)
elif schedule == "zeros":
w = (jnp.arange(total) < start).astype(jnp.float32)
elif schedule == "linear" or schedule == "exp":
w = jnp.clip((start - 1 - jnp.arange(total)) / (end - start + 1) + 1, 0, 1)
if schedule == "exp":
w = w * jnp.expm1(w) / (jnp.e - 1)
else:
raise ValueError(f"Invalid schedule: {schedule}")
return jnp.where(jnp.arange(total) >= end, 0, w)
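# --- Illustration (hypothetical helper, not part of the original file): a pure-NumPy
# re-derivation of the "linear" schedule for the docstring example start=2, end=6,
# total=10, which should reproduce 1 1 4/5 3/5 2/5 1/5 0 0 0 0.
def _demo_linear_prefix_weights(start: int = 2, end: int = 6, total: int = 10) -> list[float]:
    import numpy as np
    idx = np.arange(total)
    w = np.clip((start - 1 - idx) / (end - start + 1) + 1, 0, 1)
    w = np.where(idx >= end, 0.0, w)  # everything past `end` ignores the prefix
    return np.round(w, 3).tolist()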
class MLPMixerBlock(nnx.Module):
def __init__(
self, token_dim: int, token_hidden_dim: int, channel_dim: int, channel_hidden_dim: int, *, rngs: nnx.Rngs
):
self.token_mix_in = nnx.Linear(token_dim, token_hidden_dim, use_bias=False, rngs=rngs)
self.token_mix_out = nnx.Linear(token_hidden_dim, token_dim, use_bias=False, rngs=rngs)
self.channel_mix_in = nnx.Linear(channel_dim, channel_hidden_dim, use_bias=False, rngs=rngs)
self.channel_mix_out = nnx.Linear(channel_hidden_dim, channel_dim, use_bias=False, rngs=rngs)
self.norm_1 = nnx.LayerNorm(channel_dim, use_scale=False, use_bias=False, rngs=rngs)
self.norm_2 = nnx.LayerNorm(channel_dim, use_scale=False, use_bias=False, rngs=rngs)
self.adaln_1 = nnx.Linear(channel_dim, 3 * channel_dim, kernel_init=nnx.initializers.zeros_init(), rngs=rngs)
self.adaln_2 = nnx.Linear(channel_dim, 3 * channel_dim, kernel_init=nnx.initializers.zeros_init(), rngs=rngs)
def __call__(self, x: jax.Array, adaln_cond: jax.Array) -> jax.Array:
scale_1, shift_1, gate_1 = jnp.split(self.adaln_1(adaln_cond), 3, axis=-1)
scale_2, shift_2, gate_2 = jnp.split(self.adaln_2(adaln_cond), 3, axis=-1)
# token mix
residual = x
x = self.norm_1(x) * (1 + scale_1) + shift_1
x = x.transpose(0, 2, 1)
x = self.token_mix_in(x)
x = nnx.gelu(x)
x = self.token_mix_out(x)
x = x.transpose(0, 2, 1)
x = residual + gate_1 * x
# channel mix
residual = x
x = self.norm_2(x) * (1 + scale_2) + shift_2
x = self.channel_mix_in(x)
x = nnx.gelu(x)
x = self.channel_mix_out(x)
x = residual + gate_2 * x
return x
class FlowPolicy(nnx.Module):
def __init__(
self,
*,
obs_dim: int,
action_dim: int,
config: ModelConfig,
rngs: nnx.Rngs,
):
self.channel_dim = config.channel_dim
self.action_dim = action_dim
self.action_chunk_size = config.action_chunk_size
self.simulated_delay = config.simulated_delay
self.in_proj = nnx.Linear(action_dim + obs_dim, config.channel_dim, rngs=rngs)
self.mlp_stack = [
MLPMixerBlock(
config.action_chunk_size,
config.token_hidden_dim,
config.channel_dim,
config.channel_hidden_dim,
rngs=rngs,
)
for _ in range(config.num_layers)
]
self.time_mlp = nnx.Sequential(
nnx.Linear(config.channel_dim, config.channel_dim, rngs=rngs),
nnx.swish,
nnx.Linear(config.channel_dim, config.channel_dim, rngs=rngs),
nnx.swish,
)
self.final_norm = nnx.LayerNorm(config.channel_dim, use_scale=False, use_bias=False, rngs=rngs)
self.final_adaln = nnx.Linear(
config.channel_dim, 2 * config.channel_dim, kernel_init=nnx.initializers.zeros_init(), rngs=rngs
)
self.out_proj = nnx.Linear(config.channel_dim, action_dim, rngs=rngs)
def __call__(self, obs: jax.Array, x_t: jax.Array, time: jax.Array) -> jax.Array:
assert x_t.shape == (obs.shape[0], self.action_chunk_size, self.action_dim), x_t.shape
if time.ndim == 1:
time = time[:, None]
time = jnp.broadcast_to(time, (obs.shape[0], self.action_chunk_size))
time_emb = jax.vmap(
functools.partial(posemb_sincos, embedding_dim=self.channel_dim, min_period=4e-3, max_period=4.0)
)(time)
time_emb = self.time_mlp(time_emb)
obs = einops.repeat(obs, "b e -> b c e", c=self.action_chunk_size)
x = jnp.concatenate([x_t, obs], axis=-1)
x = self.in_proj(x)
for mlp in self.mlp_stack:
x = mlp(x, time_emb)
assert x.shape == (obs.shape[0], self.action_chunk_size, self.channel_dim), x.shape
scale, shift = jnp.split(self.final_adaln(time_emb), 2, axis=-1)
x = self.final_norm(x) * (1 + scale) + shift
x = self.out_proj(x)
return x
def action(self, rng: jax.Array, obs: jax.Array, num_steps: int) -> jax.Array:
dt = 1 / num_steps
def step(carry, _):
x_t, time = carry
v_t = self(obs, x_t, time)
return (x_t + dt * v_t, time + dt), None
noise = jax.random.normal(rng, shape=(obs.shape[0], self.action_chunk_size, self.action_dim))
(x_1, _), _ = jax.lax.scan(step, (noise, 0.0), length=num_steps)
assert x_1.shape == (obs.shape[0], self.action_chunk_size, self.action_dim), x_1.shape
return x_1
def bid_action(
self,
rng: jax.Array,
obs: jax.Array,
num_steps: int,
prev_action_chunk: jax.Array, # [batch, horizon, action_dim]
inference_delay: int,
prefix_attention_horizon: int,
n_samples: int,
# When the two arguments below are both None, only the backward loss is used (i.e., pure rejection sampling).
bid_weak_policy: Self | None = None,
bid_k: int | None = None,
) -> jax.Array:
obs = einops.repeat(obs, "b ... -> (n b) ...", n=n_samples)
weights = get_prefix_weights(inference_delay, prefix_attention_horizon, self.action_chunk_size, "exp")
def backward_loss(action_chunks: jax.Array):
error = jnp.linalg.norm(action_chunks - prev_action_chunk, axis=-1) # [n, b, h]
return jnp.sum(error * weights[None, None, :], axis=-1) # [n, b]
strong_actions = einops.rearrange(self.action(rng, obs, num_steps), "(n b) h d -> n b h d", n=n_samples)
loss = backward_loss(strong_actions) # [n, b]
if bid_weak_policy is not None or bid_k is not None:
assert bid_weak_policy is not None and bid_k is not None, (bid_weak_policy, bid_k)
weak_actions = einops.rearrange(
bid_weak_policy.action(rng, obs, num_steps), "(n b) h d -> n b h d", n=n_samples
)
weak_loss = backward_loss(weak_actions) # [n, b]
weak_idxs = jax.lax.top_k(-weak_loss.T, bid_k)[1].T # [k, b]
strong_idxs = jax.lax.top_k(-loss.T, bid_k)[1].T # [k, b]
a_plus = jnp.take_along_axis(strong_actions, strong_idxs[:, :, None, None], axis=0) # [k, b, h, d]
a_minus = jnp.take_along_axis(weak_actions, weak_idxs[:, :, None, None], axis=0) # [k, b, h, d]
# compute forward loss for each action in strong_actions
forward_loss = jnp.sum(
jnp.linalg.norm(strong_actions[:, None] - a_plus[None, :], axis=-1), # [n, k, b, h]
axis=(1, 3), # [n, b]
) - jnp.sum(
jnp.linalg.norm(strong_actions[:, None] - a_minus[None, :], axis=-1), # [n, k, b, h]
axis=(1, 3), # [n, b]
)
loss += forward_loss / n_samples
best_idxs = jnp.argmin(loss, axis=0) # [b]
return jnp.take_along_axis(strong_actions, best_idxs[None, :, None, None], axis=0).squeeze(0) # [b, h, d]
def realtime_action(
self,
rng: jax.Array,
obs: jax.Array,
num_steps: int,
prev_action_chunk: jax.Array, # [batch, horizon, action_dim]
inference_delay: int,
prefix_attention_horizon: int,
prefix_attention_schedule: PrefixAttentionSchedule,
max_guidance_weight: float,
) -> jax.Array:
dt = 1 / num_steps
def step(carry, _):
x_t, time = carry
@functools.partial(jax.vmap, in_axes=(0, 0, 0, None)) # over batch
def pinv_corrected_velocity(obs, x_t, y, t):
def denoiser(x_t):
v_t = self(obs[None], x_t[None], t)[0]
return x_t + v_t * (1 - t), v_t
x_1, vjp_fun, v_t = jax.vjp(denoiser, x_t, has_aux=True)
weights = get_prefix_weights(
inference_delay, prefix_attention_horizon, self.action_chunk_size, prefix_attention_schedule
)
error = (y - x_1) * weights[:, None]
pinv_correction = vjp_fun(error)[0]
# guidance-weight constants from the paper
inv_r2 = (t**2 + (1 - t) ** 2) / ((1 - t) ** 2)
c = jnp.nan_to_num((1 - t) / t, posinf=max_guidance_weight)
guidance_weight = jnp.minimum(c * inv_r2, max_guidance_weight)
return v_t + guidance_weight * pinv_correction
if self.simulated_delay is not None:
mask = jnp.arange(self.action_chunk_size)[None, :] < inference_delay
x_t = jnp.where(mask[:, :, None], prev_action_chunk, x_t)
time_chunk = jnp.where(mask, 1.0, time)
v_t = self(obs, x_t, time_chunk)
else:
v_t = pinv_corrected_velocity(obs, x_t, prev_action_chunk, time)
return (x_t + dt * v_t, time + dt), None
noise = jax.random.normal(rng, shape=(obs.shape[0], self.action_chunk_size, self.action_dim))
(x_1, _), _ = jax.lax.scan(step, (noise, 0.0), length=num_steps)
assert x_1.shape == (obs.shape[0], self.action_chunk_size, self.action_dim), x_1.shape
return x_1
def loss(self, rng: jax.Array, obs: jax.Array, action: jax.Array):
assert action.dtype == jnp.float32
assert action.shape == (obs.shape[0], self.action_chunk_size, self.action_dim), action.shape
noise_rng, time_rng, delay_rng = jax.random.split(rng, 3)
time = jax.random.uniform(time_rng, (obs.shape[0],))
noise = jax.random.normal(noise_rng, shape=action.shape)
u_t = action - noise
if self.simulated_delay is None:
x_t = (1 - time[:, None, None]) * noise + time[:, None, None] * action
pred = self(obs, x_t, time)
return jnp.mean(jnp.square(pred - u_t))
w = jnp.exp(jnp.arange(0, self.simulated_delay)[::-1])
w = w / jnp.sum(w)
delay = jax.random.choice(delay_rng, self.simulated_delay, (obs.shape[0],), p=w)
mask = jnp.arange(self.action_chunk_size)[None, :] < delay[:, None]
time = jnp.where(mask, 1.0, time[:, None])
x_t = (1 - time[:, :, None]) * noise + time[:, :, None] * action
pred = self(obs, x_t, time)
loss = jnp.square(pred - u_t)
loss_mask = jnp.logical_not(mask)[:, :, None]
return jnp.sum(loss * loss_mask) / (jnp.sum(loss_mask) + 1e-8)
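A toy illustration of the Euler integration used in `action` (standalone sketch with hypothetical scalar values, not part of the module): under the linear interpolation x_t = (1 - t) * noise + t * action used in `loss`, the target velocity u_t = action - noise is constant, so a perfectly trained model integrated from the noise sample lands exactly on the action at t = 1.

```python
# Euler integration of dx/dt = v from t=0 to t=1, as in FlowPolicy.action,
# with the ideal (constant) velocity field for a single noise/action pair.
noise, action, num_steps = -1.0, 3.0, 10  # hypothetical scalar values
dt = 1 / num_steps
x = noise
for _ in range(num_steps):
    v = action - noise  # what a perfectly trained model would predict
    x = x + dt * v
assert abs(x - action) < 1e-9  # Euler is exact for a constant velocity field
```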
================================================
FILE: src/render_levels.py
================================================
import pathlib
import jax
import jax.numpy as jnp
import kinetix.environment.env as kenv
import kinetix.environment.env_state as kenv_state
import kinetix.render.renderer_pixels as renderer_pixels
import kinetix.util.saving as saving
import imageio
# Constants from train_expert.py
LARGE_ENV_PARAMS = {
"num_polygons": 12,
"num_circles": 4,
"num_joints": 6,
"num_thrusters": 2,
"num_motor_bindings": 4,
"num_thruster_bindings": 2,
}
FRAME_SKIP = 2
SCREEN_DIM = (512, 512)
def load_levels(paths):
static_env_params = kenv_state.StaticEnvParams(**LARGE_ENV_PARAMS, frame_skip=FRAME_SKIP)
env_params = kenv_state.EnvParams()
levels = []
for level_path in paths:
level, level_static_env_params, level_env_params = saving.load_from_json_file(level_path)
# assert level_static_env_params == static_env_params, (
# f"Expected {static_env_params} got {level_static_env_params} for {level_path}"
# )
# assert level_env_params == env_params, f"Expected {env_params} got {level_env_params} for {level_path}"
levels.append(level)
return levels, static_env_params, env_params
def main():
# Define level paths
level_paths = [
"worlds/l/grasp_easy.json",
"worlds/l/catapult.json",
"worlds/l/cartpole_thrust.json",
"worlds/l/hard_lunar_lander.json",
"worlds/l/mjc_half_cheetah.json",
"worlds/l/mjc_swimmer.json",
"worlds/l/mjc_walker.json",
"worlds/l/h17_unicycle.json",
"worlds/l/chain_lander.json",
"worlds/l/catcher_v3.json",
"worlds/l/trampoline.json",
"worlds/l/car_launch.json",
]
# Load levels
levels, static_env_params, env_params = load_levels(level_paths)
# Update screen dimensions
static_env_params = static_env_params.replace(screen_dim=SCREEN_DIM, downscale=2)
# Create environment and renderer
env = kenv.make_kinetix_env_from_name("Kinetix-Symbolic-Continuous-v1", static_env_params=static_env_params)
render_pixels = renderer_pixels.make_render_pixels(env_params, static_env_params)
# Create output directory
output_dir = pathlib.Path("rendered_levels")
output_dir.mkdir(exist_ok=True)
# Render each level
for i, level in enumerate(levels):
# Reset environment to level
_, env_state = env.reset_to_level(jax.random.key(0), level, env_params)
# Render the state
image = render_pixels(env_state).round().astype(jnp.uint8).transpose(1, 0, 2)[::-1]
# Save image
level_name = level_paths[i].split("/")[-1].replace(".json", "")
imageio.imwrite(output_dir / f"{level_name}.jpg", image)
print(f"Saved {level_name}.jpg")
if __name__ == "__main__":
main()
================================================
FILE: src/train_expert.py
================================================
import dataclasses
import functools
import json
import pathlib
import pickle
from typing import Sequence
from flax import struct
import flax.nnx as nnx
import imageio
import jax
import jax.numpy as jnp
import kinetix.environment.env as kenv
import kinetix.environment.env_state as kenv_state
import kinetix.environment.wrappers as wrappers
import kinetix.render.renderer_pixels as renderer_pixels
import kinetix.util.saving as saving
import optax
from tensorflow_probability.substrates import jax as tfp
import tqdm_loggable.auto as tqdm
import tyro
import wandb
@dataclasses.dataclass
class Config:
level_paths: Sequence[str] = (
"worlds/l/grasp_easy.json",
"worlds/l/catapult.json",
"worlds/l/cartpole_thrust.json",
"worlds/l/hard_lunar_lander.json",
"worlds/l/mjc_half_cheetah.json",
"worlds/l/mjc_swimmer.json",
"worlds/l/mjc_walker.json",
"worlds/l/h17_unicycle.json",
"worlds/l/chain_lander.json",
"worlds/l/catcher_v3.json",
"worlds/l/trampoline.json",
"worlds/l/car_launch.json",
)
seed: int = 32
num_seeds: int = 8
log_interval: int = 20
num_updates: int = 1000
num_steps: int = 256
num_envs: int = 256
num_minibatches: int = 8
num_epochs: int = 4
gamma: float = 0.995
gae_lambda: float = 0.9
clip_eps: float = 0.2
v_loss_coef: float = 0.5
rpo_alpha: float = 0.3
layer_width: int = 256
grad_norm_clip: float = 1.0
lr: float = 3e-4
LOG_DIR = pathlib.Path("logs-expert")
WANDB_PROJECT = "rtc-kinetix-expert"
LARGE_ENV_PARAMS = {
"num_polygons": 12,
"num_circles": 4,
"num_joints": 6,
"num_thrusters": 2,
"num_motor_bindings": 4,
"num_thruster_bindings": 2,
}
FRAME_SKIP = 2
SCREEN_DIM = (512, 512)
ACTION_NOISE_STD = 0.1
LOG_STD_MIN = -10.0
LOG_STD_MAX = 2.0
MEAN_MAX_MAGNITUDE = 5
MAX_ACTION = 0.99999
class BatchEnvWrapper(wrappers.GymnaxWrapper):
"""Define our own BatchEnvWrapper (we don't need different levels)"""
def __init__(self, env, num: int):
super().__init__(env)
self.num = num
def reset(self, rng, params):
return jax.vmap(self._env.reset, in_axes=(0, None))(jax.random.split(rng, self.num), params)
def reset_to_level(self, rng, level, params):
return jax.vmap(self._env.reset_to_level, in_axes=(0, None, None))(
jax.random.split(rng, self.num), level, params
)
def step(self, rng, state, action, params):
return jax.vmap(self._env.step, in_axes=(0, 0, 0, None))(jax.random.split(rng, self.num), state, action, params)
@struct.dataclass
class DenseRewardState:
env_state: kenv_state.EnvState
timestep: int
action: jax.Array
class DenseRewardWrapper(wrappers.GymnaxWrapper):
def __init__(self, env):
super().__init__(env)
def step(self, key, state, action, params=None):
obs, env_state, reward, done, info = self._env.step_env(key, state.env_state, action, params)
dist_penalty = jax.lax.select(reward > 0, 0.0, info["distance"] / 6.0 / params.max_timesteps)
new_reward = reward - jax.lax.select(done, (params.max_timesteps - state.timestep) * dist_penalty, dist_penalty)
next_timestep = jax.lax.select(done, 0, state.timestep + 1)
return obs, DenseRewardState(env_state, next_timestep, action), new_reward, done, info
def reset(self, rng, params=None):
obs, env_state = self._env.reset(rng, params)
return obs, DenseRewardState(env_state, 0, jnp.zeros(self._env.action_space(params).shape[0]))
def reset_to_level(self, rng, level, params=None):
obs, env_state = self._env.reset_to_level(rng, level, params)
return obs, DenseRewardState(env_state, 0, jnp.zeros(self._env.action_space(params).shape[0]))
class ActionHistoryWrapper(wrappers.UnderspecifiedEnvWrapper):
def __init__(self, env):
super().__init__(env)
def step_env(self, key, state, action, params):
obs, env_state, reward, done, info = self._env.step_env(key, state, action, params)
obs = jnp.concatenate([obs, action])
return obs, env_state, reward, done, info
def reset_to_level(self, rng, level, params):
obs, env_state = self._env.reset_to_level(rng, level, params)
obs = jnp.concatenate([obs, jnp.zeros(self._env.action_space(params).shape[0])])
return obs, env_state
def action_space(self, params):
return self._env.action_space(params)
class NoisyActionWrapper(wrappers.UnderspecifiedEnvWrapper):
def __init__(self, env):
super().__init__(env)
def step_env(self, key, state, action, params):
key1, key2 = jax.random.split(key)
action = action + jax.random.normal(key1, action.shape) * ACTION_NOISE_STD
return self._env.step_env(key2, state, action, params)
def reset_to_level(self, rng, level, params):
return self._env.reset_to_level(rng, level, params)
def action_space(self, params):
return self._env.action_space(params)
@struct.dataclass
class StickyActionState:
env_state: kenv_state.EnvState
action: jax.Array
class StickyActionWrapper(wrappers.UnderspecifiedEnvWrapper):
def __init__(self, env, stickiness: float):
super().__init__(env)
self.stickiness = stickiness
def step_env(self, key, state, action, params):
key1, key2 = jax.random.split(key)
actual_action = jax.lax.select(jax.random.bernoulli(key1, self.stickiness), state.action, action)
obs, env_state, reward, done, info = self._env.step_env(key2, state.env_state, actual_action, params)
return obs, StickyActionState(env_state, actual_action), reward, done, info
def reset_to_level(self, rng, level, params):
obs, env_state = self._env.reset_to_level(rng, level, params)
return obs, StickyActionState(
env_state,
jnp.zeros(
len(self._env.action_space(params).number_of_dims_per_distribution),
dtype=jnp.int32,
),
)
def action_space(self, params):
return self._env.action_space(params)
@struct.dataclass
class ObsHistoryState:
env_state: kenv_state.EnvState
stacked_obs: jax.Array
original_obs: jax.Array
class ObsHistoryWrapper(wrappers.UnderspecifiedEnvWrapper):
def __init__(self, env, history_length: int):
super().__init__(env)
self.history_length = history_length
def step_env(self, key, state, action, params):
obs, env_state, reward, done, info = self._env.step_env(key, state.env_state, action, params)
stacked_obs = jnp.roll(state.stacked_obs, -1, axis=0).at[-1].set(obs)
return stacked_obs.flatten(), ObsHistoryState(env_state, stacked_obs, obs), reward, done, info
def reset_to_level(self, rng, level, params):
obs, env_state = self._env.reset_to_level(rng, level, params)
stacked_obs = jnp.repeat(obs[None], self.history_length, axis=0)
return stacked_obs.flatten(), ObsHistoryState(env_state, stacked_obs, obs)
def action_space(self, params):
return self._env.action_space(params)
@staticmethod
def get_original_obs(env_state) -> jax.Array:
while not isinstance(env_state, ObsHistoryState):
env_state = env_state.env_state
return env_state.original_obs
def make_squashed_normal_diag(mean, std, num_motor_bindings: int):
bijector = tfp.bijectors.Blockwise(
[tfp.bijectors.Tanh(), tfp.bijectors.Sigmoid()],
block_sizes=[num_motor_bindings, mean.shape[-1] - num_motor_bindings],
maybe_changes_size=False,
validate_args=True,
)
return tfp.distributions.TransformedDistribution(tfp.distributions.MultivariateNormalDiag(mean, std), bijector)
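The Blockwise bijector above squashes the first `num_motor_bindings` action dims through tanh into (-1, 1) and the remaining (thruster) dims through sigmoid into (0, 1). A plain-Python sketch of just that squashing step (the `squash` helper and its inputs are hypothetical, not repo code):

```python
import math

# Squash raw (pre-bijector) values the way the Blockwise bijector does:
# tanh on the motor block, sigmoid on the thruster block.
def squash(raw: list[float], num_motor_bindings: int) -> list[float]:
    motors = [math.tanh(v) for v in raw[:num_motor_bindings]]
    thrusters = [1 / (1 + math.exp(-v)) for v in raw[num_motor_bindings:]]
    return motors + thrusters

out = squash([2.0, -2.0, 0.0, 5.0], num_motor_bindings=2)  # hypothetical values
assert all(-1 < v < 1 for v in out[:2])  # motor dims land in (-1, 1)
assert all(0 < v < 1 for v in out[2:])   # thruster dims land in (0, 1)
```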
class Agent(nnx.Module):
def __init__(self, obs_dim: int, action_dim: int, layer_width: int, *, rngs: nnx.Rngs):
self.critic = nnx.Sequential(
nnx.Linear(obs_dim, layer_width, kernel_init=nnx.initializers.orthogonal(jnp.sqrt(2)), rngs=rngs),
nnx.tanh,
nnx.Linear(layer_width, layer_width, kernel_init=nnx.initializers.orthogonal(jnp.sqrt(2)), rngs=rngs),
nnx.tanh,
nnx.Linear(layer_width, 1, kernel_init=nnx.initializers.orthogonal(1), rngs=rngs),
)
self.actor = nnx.Sequential(
nnx.Linear(obs_dim, layer_width, kernel_init=nnx.initializers.orthogonal(jnp.sqrt(2)), rngs=rngs),
nnx.tanh,
nnx.Linear(layer_width, layer_width, kernel_init=nnx.initializers.orthogonal(jnp.sqrt(2)), rngs=rngs),
nnx.tanh,
nnx.Linear(layer_width, action_dim, kernel_init=nnx.initializers.orthogonal(0.01), rngs=rngs),
)
self.logstd = nnx.Param(jnp.zeros(action_dim))
def value(self, obs: jax.Array) -> jax.Array:
return self.critic(obs)[..., 0]
def action(self, obs: jax.Array):
mean = jnp.clip(self.actor(obs), -MEAN_MAX_MAGNITUDE, MEAN_MAX_MAGNITUDE)
std = jnp.exp(jnp.clip(self.logstd.value, LOG_STD_MIN, LOG_STD_MAX))
return mean, std
@struct.dataclass
class Transition:
obs: jax.Array
action: jax.Array
done: jax.Array
reward: jax.Array
value: jax.Array
log_prob: jax.Array
info: jax.Array
env_state: kenv_state.EnvState
@struct.dataclass
class StepCarry:
"""Environment-related information that must be carried through the rollout loop."""
rng: jax.Array
env_state: kenv_state.EnvState
obs: jax.Array
done: jax.Array
@struct.dataclass
class UpdateCarry:
"""Information that must be carried through the outermost update loop."""
rng: jax.Array
step_carry: StepCarry
train_state: nnx.State
graphdef: nnx.GraphDef[tuple[Agent, nnx.Optimizer]] = struct.field(pytree_node=False)
@struct.dataclass
class TrainCarry:
rng: jax.Array
train_state: nnx.State
def make_render_video(render_pixels):
@jax.vmap
def render_video(env_state):
while not isinstance(env_state, kenv_state.EnvState):
env_state = env_state.env_state
return render_pixels(env_state).round().astype(jnp.uint8).transpose(1, 0, 2)[::-1]
return render_video
def load_levels(paths: Sequence[str], static_env_params: kenv_state.StaticEnvParams, env_params: kenv_state.EnvParams):
levels = []
for level_path in paths:
level, level_static_env_params, level_env_params = saving.load_from_json_file(level_path)
assert level_static_env_params == static_env_params, (
f"Expected {static_env_params} got {level_static_env_params} for {level_path}"
)
assert level_env_params == env_params, f"Expected {env_params} got {level_env_params} for {level_path}"
levels.append(level)
return jax.tree.map(lambda *x: jnp.stack(x), *levels)
def main(config: Config):
static_env_params = kenv_state.StaticEnvParams(**LARGE_ENV_PARAMS, frame_skip=FRAME_SKIP)
env_params = kenv_state.EnvParams()
env = kenv.make_kinetix_env_from_name("Kinetix-Symbolic-Continuous-v1", static_env_params=static_env_params)
env = BatchEnvWrapper(
wrappers.LogWrapper(
DenseRewardWrapper(
wrappers.AutoReplayWrapper(ActionHistoryWrapper(ObsHistoryWrapper(NoisyActionWrapper(env), 4)))
)
),
config.num_envs,
)
levels = load_levels(config.level_paths, static_env_params, env_params)
static_env_params = static_env_params.replace(screen_dim=SCREEN_DIM)
batch_size = config.num_envs * config.num_steps
assert batch_size % config.num_minibatches == 0, "Batch size must be divisible by number of minibatches"
minibatch_size = batch_size // config.num_minibatches
print(f"Batch size: {batch_size}, minibatch size: {minibatch_size}")
# create rendering function
render_pixels = renderer_pixels.make_render_pixels(env_params, static_env_params)
render_video = make_render_video(render_pixels)
mesh = jax.make_mesh((jax.local_device_count(),), ("x",))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("x"))
@functools.partial(jax.jit, out_shardings=sharding)
@functools.partial(jax.vmap, in_axes=(0, None)) # over seeds
@jax.vmap # over levels
def init(rng: jax.Array, level: kenv_state.EnvState) -> UpdateCarry:
# initial reset
rng, key = jax.random.split(rng)
obs, env_state = env.reset_to_level(key, level, env_params)
# initialize agent
action_dim = env.action_space(env_params).shape[0]
assert len(obs.shape) == 2
obs_dim = obs.shape[1]
rng, key = jax.random.split(rng)
agent = Agent(obs_dim, action_dim, config.layer_width, rngs=nnx.Rngs(key))
optimizer = nnx.Optimizer(
agent, optax.chain(optax.clip_by_global_norm(config.grad_norm_clip), optax.adam(config.lr))
)
graphdef, initial_train_state = nnx.split((agent, optimizer))
update_rng, step_rng = jax.random.split(rng)
return UpdateCarry(
rng=update_rng,
step_carry=StepCarry(
rng=step_rng, env_state=env_state, obs=obs, done=jnp.zeros(config.num_envs, dtype=bool)
),
train_state=initial_train_state,
graphdef=graphdef,
)
# outermost PPO update loop
def update(update_carry: UpdateCarry, _):
agent, _ = nnx.merge(update_carry.graphdef, update_carry.train_state)
# environment rollout loop
def step(step_carry: StepCarry, _):
rng, key = jax.random.split(step_carry.rng)
# action = env.action_space(env_params).sample(key)
mean, std = agent.action(step_carry.obs)
action_dist = make_squashed_normal_diag(mean, std, static_env_params.num_motor_bindings)
action = action_dist.sample(seed=key)
action = jnp.clip(action, -MAX_ACTION, MAX_ACTION)
log_prob = action_dist.log_prob(action)
value = agent.value(step_carry.obs)
rng, key = jax.random.split(rng)
next_obs, next_env_state, reward, next_done, info = env.step(key, step_carry.env_state, action, env_params)
return (
StepCarry(rng=rng, env_state=next_env_state, obs=next_obs, done=next_done),
Transition(
obs=step_carry.obs,
action=action,
reward=reward,
value=value,
log_prob=log_prob,
done=step_carry.done,
info=info,
env_state=step_carry.env_state,
),
)
# transitions has shape: (NUM_STEPS, NUM_ENVS, ...)
final_step_carry, transitions = jax.lax.scan(step, update_carry.step_carry, None, length=config.num_steps)
# gae calculation loop
def gae_step(carry, transition: Transition):
gae, next_value, next_done = carry
delta = transition.reward + config.gamma * next_value * (1 - next_done) - transition.value
gae = delta + config.gamma * config.gae_lambda * (1 - next_done) * gae
return (gae, transition.value, transition.done), gae
last_value = agent.value(final_step_carry.obs)
last_done = final_step_carry.done
_, advantages = jax.lax.scan(
gae_step, (jnp.zeros_like(last_value), last_value, last_done), transitions, reverse=True, unroll=16
)
returns = advantages + transitions.value
# update epochs loop
def update_epoch(epoch_carry: TrainCarry, _):
# gradient update loop
def update_minibatch(minibatch_carry: TrainCarry, minibatch: tuple[Transition, jax.Array, jax.Array]):
agent, optimizer = nnx.merge(update_carry.graphdef, minibatch_carry.train_state)
transitions, advantages, returns = minibatch
rng, key = jax.random.split(minibatch_carry.rng)
def loss_fn(agent: Agent):
mean, std = agent.action(transitions.obs)
# RPO LOGIC
z = jax.random.uniform(
key, transitions.action.shape, minval=-config.rpo_alpha, maxval=config.rpo_alpha
)
dist = make_squashed_normal_diag(mean + z, std, static_env_params.num_motor_bindings)
value = agent.value(transitions.obs)
log_prob = dist.log_prob(transitions.action)
log_ratio = log_prob - transitions.log_prob
ratio = jnp.exp(log_ratio)
# actor loss
norm_advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
pg_loss1 = -norm_advantages * ratio
pg_loss2 = -norm_advantages * jnp.clip(ratio, 1.0 - config.clip_eps, 1.0 + config.clip_eps)
pg_loss = jnp.maximum(pg_loss1, pg_loss2).mean()
# value loss
v_loss_unclipped = (value - returns) ** 2
v_clipped = transitions.value + (value - transitions.value).clip(-config.clip_eps, config.clip_eps)
v_loss_clipped = (v_clipped - returns) ** 2
v_loss = 0.5 * jnp.maximum(v_loss_unclipped, v_loss_clipped).mean()
loss = pg_loss + config.v_loss_coef * v_loss
info = {
"pg_loss": pg_loss,
"v_loss": v_loss,
"clipfrac": (jnp.abs(ratio - 1) > config.clip_eps).mean(),
"approx_kl": ((ratio - 1) - log_ratio).mean(),
}
return loss, info
(loss, info), grads = nnx.value_and_grad(loss_fn, has_aux=True)(agent)
info["loss"] = loss
info["grad_norm"] = optax.global_norm(grads)
optimizer.update(grads)
_, train_state = nnx.split((agent, optimizer))
return TrainCarry(rng=rng, train_state=train_state), info
# flatten data in preparation for learning
data = jax.tree.map(
lambda x: x.reshape(config.num_steps * config.num_envs, *x.shape[2:]),
(transitions, advantages, returns),
)
# shuffle
rng, key = jax.random.split(epoch_carry.rng)
permutation = jax.random.permutation(key, config.num_envs * config.num_steps)
data = jax.tree.map(lambda x: x[permutation], data)
# batch
batches = jax.tree.map(lambda x: x.reshape(config.num_minibatches, minibatch_size, *x.shape[1:]), data)
# learn!
final_carry, info = jax.lax.scan(update_minibatch, epoch_carry.replace(rng=rng), batches)
return final_carry, info
final_epoch_carry, info = jax.lax.scan(
update_epoch,
TrainCarry(rng=update_carry.rng, train_state=update_carry.train_state),
None,
length=config.num_epochs,
)
for key in ["returned_episode_returns", "returned_episode_lengths", "returned_episode_solved"]:
info[key] = (transitions.info[key] * transitions.info["returned_episode"]).sum() / transitions.info[
"returned_episode"
].sum()
info["reward"] = transitions.reward.mean()
rollout = jax.tree.map(lambda x: x[:, 0], transitions.env_state)
return UpdateCarry(
final_epoch_carry.rng, final_step_carry, final_epoch_carry.train_state, update_carry.graphdef
), (info, rollout)
@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,), in_shardings=sharding, out_shardings=sharding)
@functools.partial(jax.vmap, in_axes=(0, None)) # over seeds
@functools.partial(jax.vmap, in_axes=(0, None)) # over levels
def update_n(update_carry: UpdateCarry, num: int):
update_carry, (info, rollout) = jax.lax.scan(update, update_carry, length=num)
video = render_video(jax.tree.map(lambda x: x[0], rollout))
return update_carry, (jax.tree.map(jnp.mean, info), video)
wandb.init(project=WANDB_PROJECT)
wandb.define_metric("num_env_steps")
wandb.define_metric("*", step_metric="num_env_steps")
pbar = tqdm.tqdm(total=config.num_updates * config.num_envs * config.num_steps, dynamic_ncols=True)
num_levels = len(config.level_paths)
rngs = jax.random.split(jax.random.key(config.seed), config.num_seeds * num_levels).reshape(
config.num_seeds, num_levels
)
update_carry = init(rngs, levels)
for update_idx in range(0, config.num_updates, config.log_interval):
update_carry, (info, video) = update_n(update_carry, config.log_interval)
if any(jnp.any(jnp.isnan(x)) for x in jax.tree.leaves(info)):
raise ValueError(f"NaN detected at update {update_idx}")
pbar.update(config.log_interval * config.num_envs * config.num_steps)
wandb.log({"num_env_steps": pbar.n}, step=update_idx)
for seed_idx in range(config.num_seeds):
for level_idx in range(num_levels):
level_name = config.level_paths[level_idx].replace("/", "_").replace(".json", "")
level_info = jax.tree.map(lambda x: x[seed_idx, level_idx].item(), info)
wandb.log({f"{level_name}/{seed_idx}/{k}": v for k, v in level_info.items()}, step=update_idx)
log_dir = LOG_DIR / wandb.run.name / f"seed_{seed_idx}" / str(update_idx)
stats_dir = log_dir / "stats"
stats_dir.mkdir(parents=True, exist_ok=True)
with (stats_dir / f"{level_name}.json").open("w") as f:
json.dump(level_info, f, indent=2)
video_dir = log_dir / "videos"
video_dir.mkdir(parents=True, exist_ok=True)
imageio.mimwrite(video_dir / f"{level_name}.mp4", video[seed_idx, level_idx], fps=15)
policy_dir = log_dir / "policies"
policy_dir.mkdir(parents=True, exist_ok=True)
level_train_state = jax.tree.map(lambda x: x[seed_idx, level_idx], update_carry.train_state)
with (policy_dir / f"{level_name}.pkl").open("wb") as f:
agent, _ = nnx.merge(update_carry.graphdef, level_train_state)
state_dict = nnx.split(agent)[1].to_pure_dict()
pickle.dump(state_dict, f)
if __name__ == "__main__":
tyro.cli(main)
================================================
FILE: src/train_flow.py
================================================
import concurrent.futures
import dataclasses
import functools
import pathlib
import pickle
from typing import Sequence

import einops
from flax import struct
import flax.nnx as nnx
import imageio
import jax
import jax.numpy as jnp
import kinetix.environment.env as kenv
import kinetix.environment.env_state as kenv_state
import numpy as np
import optax
import tqdm_loggable.auto as tqdm
import tyro
import wandb

import eval_flow as _eval
import generate_data
import model as _model
import train_expert

WANDB_PROJECT = "rtc-kinetix-bc"
LOG_DIR = pathlib.Path("logs-bc")


@dataclasses.dataclass(frozen=True)
class Config:
    run_path: str
    level_paths: Sequence[str] = (
        "worlds/l/grasp_easy.json",
        "worlds/l/catapult.json",
        "worlds/l/cartpole_thrust.json",
        "worlds/l/hard_lunar_lander.json",
        "worlds/l/mjc_half_cheetah.json",
        "worlds/l/mjc_swimmer.json",
        "worlds/l/mjc_walker.json",
        "worlds/l/h17_unicycle.json",
        "worlds/l/chain_lander.json",
        "worlds/l/catcher_v3.json",
        "worlds/l/trampoline.json",
        "worlds/l/car_launch.json",
    )
    batch_size: int = 512
    num_epochs: int = 32
    seed: int = 0
    eval: _eval.EvalConfig = _eval.EvalConfig()
    learning_rate: float = 3e-4
    grad_norm_clip: float = 10.0
    weight_decay: float = 1e-2
    lr_warmup_steps: int = 1000
    load_dir: str | None = None


@struct.dataclass
class EpochCarry:
    rng: jax.Array
    train_state: nnx.State
    graphdef: nnx.GraphDef[tuple[_model.FlowPolicy, nnx.Optimizer]]


def main(config: Config):
    static_env_params = kenv_state.StaticEnvParams(**train_expert.LARGE_ENV_PARAMS, frame_skip=train_expert.FRAME_SKIP)
    env_params = kenv_state.EnvParams()
    levels = train_expert.load_levels(config.level_paths, static_env_params, env_params)
    static_env_params = static_env_params.replace(screen_dim=train_expert.SCREEN_DIM)
    env = kenv.make_kinetix_env_from_name("Kinetix-Symbolic-Continuous-v1", static_env_params=static_env_params)

    # one shard per level: policies and data are vmapped over the "level" axis
    mesh = jax.make_mesh((jax.local_device_count(),), ("level",))
    sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("level"))
    action_chunk_size = config.eval.model.action_chunk_size

    # load data
    def load_data(level_path: str):
        level_name = level_path.replace("/", "_").replace(".json", "")
        print("Loading data for level:", level_name)
        return dict(np.load(pathlib.Path(config.run_path) / "data" / f"{level_name}.npz"))

    with concurrent.futures.ThreadPoolExecutor() as executor:
        data = list(executor.map(load_data, config.level_paths))
    with jax.default_device(jax.devices("cpu")[0]):
        # data has shape: (num_levels, num_steps, num_envs, ...)
        # flatten envs and steps together for learning
        data = jax.tree.map(lambda *x: einops.rearrange(jnp.stack(x), "l s e ... -> l (e s) ..."), *data)
        # truncate to multiple of batch size
        valid_steps = data["obs"].shape[1] - action_chunk_size + 1
        data = jax.tree.map(
            lambda x: x[:, : (valid_steps // config.batch_size) * config.batch_size + action_chunk_size - 1], data
        )
        # put on device
        data = jax.tree.map(
            lambda x: jax.make_array_from_single_device_arrays(
                x.shape,
                sharding,
                [
                    jax.device_put(y, d)
                    for y, d in zip(jnp.split(x, jax.local_device_count()), jax.local_devices(), strict=True)
                ],
            ),
            data,
        )
        data: generate_data.Data = generate_data.Data(**data)
    print(f"Truncated data to {data.obs.shape[1]:_} steps ({valid_steps // config.batch_size:_} batches)")

    obs_dim = data.obs.shape[-1]
    action_dim = env.action_space(env_params).shape[0]

    if config.load_dir is not None:
        state_dicts = []
        for level_path in config.level_paths:
            level_name = level_path.replace("/", "_").replace(".json", "")
            with (pathlib.Path(config.load_dir) / "policies" / f"{level_name}.pkl").open("rb") as f:
                state_dicts.append(pickle.load(f))
        state_dicts = jax.device_put(jax.tree.map(lambda *x: jnp.array(x), *state_dicts))
    else:
        state_dicts = None

    @functools.partial(jax.jit, in_shardings=sharding, out_shardings=sharding)
    @jax.vmap
    def init(rng: jax.Array, state_dict: dict | None) -> EpochCarry:
        rng, key = jax.random.split(rng)
        policy = _model.FlowPolicy(
            obs_dim=obs_dim,
            action_dim=action_dim,
            config=config.eval.model,
            rngs=nnx.Rngs(key),
        )
        if state_dict is not None:
            graphdef, state = nnx.split(policy)
            state.replace_by_pure_dict(state_dict)
            policy = nnx.merge(graphdef, state)
        total_params = sum(x.size for x in jax.tree.leaves(nnx.state(policy, nnx.Param)))
        print(f"Total params: {total_params:,}")
        optimizer = nnx.Optimizer(
            policy,
            optax.chain(
                optax.clip_by_global_norm(config.grad_norm_clip),
                optax.adamw(
                    optax.warmup_constant_schedule(0, config.learning_rate, config.lr_warmup_steps),
                    weight_decay=config.weight_decay,
                ),
            ),
        )
        graphdef, train_state = nnx.split((policy, optimizer))
        return EpochCarry(rng, train_state, graphdef)

    @functools.partial(jax.jit, donate_argnums=(0,), in_shardings=sharding, out_shardings=sharding)
    @jax.vmap
    def train_epoch(epoch_carry: EpochCarry, level: kenv_state.EnvState, data: generate_data.Data):
        def train_minibatch(carry: tuple[jax.Array, nnx.State], batch_idxs: jax.Array):
            rng, train_state = carry
            policy, optimizer = nnx.merge(epoch_carry.graphdef, train_state)
            rng, key = jax.random.split(rng)

            def loss_fn(policy: _model.FlowPolicy):
                obs = data.obs[batch_idxs]
                action_chunks = data.action[batch_idxs[:, None] + jnp.arange(action_chunk_size)[None, :]]
                # zero actions after done
                done_chunks = data.done[batch_idxs[:, None] + jnp.arange(action_chunk_size)[None, :]]
                done_idxs = jnp.where(
                    jnp.any(done_chunks, axis=-1),
                    jnp.argmax(done_chunks, axis=-1),
                    action_chunk_size,
                )
                action_chunks = jnp.where(
                    jnp.arange(action_chunk_size)[None, :, None] >= done_idxs[:, None, None],
                    0.0,
                    action_chunks,
                )
                return policy.loss(key, obs, action_chunks)

            loss, grads = nnx.value_and_grad(loss_fn)(policy)
            info = {"loss": loss, "grad_norm": optax.global_norm(grads)}
            optimizer.update(grads)
            _, train_state = nnx.split((policy, optimizer))
            return (rng, train_state), info

        # shuffle
        rng, key = jax.random.split(epoch_carry.rng)
        permutation = jax.random.permutation(key, data.obs.shape[0] - action_chunk_size + 1)
        # batch
        permutation = permutation.reshape(-1, config.batch_size)
        # train (carry the post-split rng so minibatch keys don't collide with the permutation key)
        (rng, train_state), train_info = jax.lax.scan(
            train_minibatch, (rng, epoch_carry.train_state), permutation
        )
        train_info = jax.tree.map(lambda x: x.mean(), train_info)

        # eval
        rng, key = jax.random.split(rng)
        eval_policy, _ = nnx.merge(epoch_carry.graphdef, train_state)
        eval_info = {}
        for horizon in range(1, config.eval.model.action_chunk_size + 1):
            eval_config = dataclasses.replace(config.eval, execute_horizon=horizon)
            info, _ = _eval.eval(eval_config, env, key, level, eval_policy, env_params, static_env_params)
            eval_info.update({f"{k}_{horizon}": v for k, v in info.items()})
        video = None
        return EpochCarry(rng, train_state, epoch_carry.graphdef), ({**train_info, **eval_info}, video)

    wandb.init(project=WANDB_PROJECT)
    rng = jax.random.key(config.seed)
    epoch_carry = init(jax.random.split(rng, len(config.level_paths)), state_dicts)
    for epoch_idx in tqdm.tqdm(range(config.num_epochs)):
        epoch_carry, (info, video) = train_epoch(epoch_carry, levels, data)
        for i in range(len(config.level_paths)):
            level_name = config.level_paths[i].replace("/", "_").replace(".json", "")
            wandb.log({f"{level_name}/{k}": v[i] for k, v in info.items()}, step=epoch_idx)
            log_dir = LOG_DIR / wandb.run.name / str(epoch_idx)
            if video is not None:
                video_dir = log_dir / "videos"
                video_dir.mkdir(parents=True, exist_ok=True)
                imageio.mimwrite(video_dir / f"{level_name}.mp4", video[i], fps=15)
            policy_dir = log_dir / "policies"
            policy_dir.mkdir(parents=True, exist_ok=True)
            level_train_state = jax.tree.map(lambda x: x[i], epoch_carry.train_state)
            with (policy_dir / f"{level_name}.pkl").open("wb") as f:
                policy, _ = nnx.merge(epoch_carry.graphdef, level_train_state)
                state_dict = nnx.state(policy).to_pure_dict()
                pickle.dump(state_dict, f)


if __name__ == "__main__":
    tyro.cli(main)
================================================
FILE: worlds/l/car_launch.json
================================================
{
"env_state": {
"polygon": [
{
"density": 1,
"velocity": {
"0": 0,
"1": 0
},
"position": {
"0": 2.5,
"1": -4.800000190734863
},
"active": true,
"friction": 1,
"vertices": {
"0": {
"0": 2.5,
"1": 5.199999809265137
},
"1": {
"0": 2.5,
"1": -5.199999809265137
},
"2": {
"0": -2.5,
"1": -5.199999809265137
},
"3": {
"0": -2.5,
"1": 5.199999809265137
}
},
"role": 3,
"rotation": 0,
"angular_velocity": 0,
"restitution": 0,
"collision_mode": 2,
"radius": 0.1,
"inverse_inertia": 0,
"inverse_mass": 0,
"n_vertices": 4
},
{
"velocity": {
"0": 0,
"1": 0
},
"position": {
"0": 0,
"1": 0
},
"density": 1,
"angular_velocity": 0,
"vertices": {
"0": {
"0": -5,
"1": 5
},
"1": {
"0": -0.05000000074505806,
"1": 5
},
"2": {
"0": -0.05000000074505806,
"1": 0
},
"3": {
"0": -5,
"1": 0
}
},
"inverse_mass": 0,
"rotation": 0,
"inverse_inertia": 0,
"role": 0,
"active": true,
"n_vertices": 4,
"restitution": 0,
"radius": 0,
"friction": 1,
"collision_mode": 2
},
{
"n_vertices": 4,
"friction": 1,
"density": 1,
"active": true,
"restitution": 0,
"rotation": 0,
"radius": 0,
"angular_velocity": 0,
"collision_mode": 2,
"role": 0,
"inverse_inertia": 0,
"vertices": {
"0": {
"0": 5,
"1": 5
},
"1": {
"0": 10,
"1": 5
},
"2": {
"0": 10,
"1": 0
},
"3": {
"0": 5,
"1": 0
}
},
"inverse_mass": 0,
"position": {
"0": 0,
"1": 0
},
"velocity": {
"0": 0,
"1": 0
}
},
{
"inverse_mass": 0,
"inverse_inertia": 0,
"collision_mode": 2,
"role": 0,
"n_vertices": 4,
"restitution": 0,
"velocity": {
"0": 0,
"1": 0
},
"vertices": {
"0": {
"0": 2.5,
"1": 5.199999809265137
},
"1": {
"0": 2.5,
"1": -5.199999809265137
},
"2": {
"0": -2.5,
"1": -5.199999809265137
},
"3": {
"0": -2.5,
"1": 5.199999809265137
}
},
"density": 1,
"angular_velocity": 0,
"position": {
"0": 2.5,
"1": 10.199999809265137
},
"friction": 1,
"rotation": 0,
"radius": 0,
"active": true
},
{
"vertices": {
"0": {
"0": 0.6200000047683716,
"1": 0.1325000524520874
},
"1": {
"0": 0.6200000047683716,
"1": -0.1325000524520874
},
"2": {
"0": -0.6200000047683716,
"1": -0.1325000524520874
},
"3": {
"0": -0.6200000047683716,
"1": 0.1325000524520874
}
},
"rotation": 0,
"velocity": {
"0": 0,
"1": 0
},
"density": 1,
"position": {
"0": 1.1749999523162842,
"1": 2.5774998664855957
},
"inverse_inertia": 0,
"n_vertices": 4,
"restitution": 0,
"friction": 1,
"angular_velocity": 0,
"inverse_mass": 0,
"collision_mode": 1,
"radius": 0,
"active": true,
"role": 0
},
{
"restitution": 0,
"active": true,
"rotation": 0,
"friction": 1,
"density": 1,
"inverse_mass": 0,
"radius": 0.1,
"angular_velocity": 0,
"velocity": {
"0": 0,
"1": 0
},
"vertices": {
"0": {
"0": 0.6200000047683716,
"1": 0.1325000524520874
},
"1": {
"0": 0.6200000047683716,
"1": -0.1325000524520874
},
"2": {
"0": -0.6200000047683716,
"1": -0.1325000524520874
},
"3": {
"0": -0.6200000047683716,
"1": 0.1325000524520874
}
},
"position": {
"0": 3.319484374154329,
"1": 2.1775002479553223
},
"collision_mode": 1,
"n_vertices": 4,
"inverse_inertia": 0,
"role": 0
},
{
"active": true,
"role": 0,
"inverse_inertia": 459.3979187011719,
"radius": 0,
"rotation": 0,
"n_vertices": 4,
"friction": 1,
"density": 1,
"velocity": {
"0": 0,
"1": 0
},
"angular_velocity": 0,
"vertices": {
"0": {
"0": 0.27250000834465027,
"1": 0.07499992847442627
},
"1": {
"0": 0.27250000834465027,
"1": -0.07499992847442627
},
"2": {
"0": -0.27250000834465027,
"1": -0.07499992847442627
},
"3": {
"0": -0.27250000834465027,
"1": 0.07499992847442627
}
},
"restitution": 0,
"position": {
"0": 1.0824999809265137,
"1": 2.8199996948242188
},
"collision_mode": 1,
"inverse_mass": 12.232426643371582
},
{
"rotation": 0,
"velocity": {
"0": 0,
"1": 0
},
"vertices": {
"0": {
"0": 0.04999999701976776,
"1": 0.31749996542930603
},
"1": {
"0": 0.04999999701976776,
"1": -0.31749996542930603
},
"2": {
"0": -0.04999999701976776,
"1": -0.31749996542930603
},
"3": {
"0": -0.04999999701976776,
"1": 0.31749996542930603
}
},
"n_vertices": 4,
"inverse_inertia": 457.3209533691406,
"role": 0,
"density": 1,
"position": {
"0": 1.0674998760223389,
"1": 3.0874996185302734
},
"inverse_mass": 15.748034477233887,
"active": true,
"radius": 0.10000000149011612,
"friction": 1,
"angular_velocity": 0,
"restitution": 0,
"collision_mode": 1
},
{
"role": 2,
"angular_velocity": 0,
"position": {
"0": 3.8059806310186377,
"1": 2.5774998664855957
},
"restitution": 0,
"active": true,
"rotation": 0,
"radius": 0.1,
"inverse_mass": 12.610349745258416,
"density": 1,
"friction": 1,
"n_vertices": 4,
"velocity": {
"0": 0,
"1": 0
},
"collision_mode": 1,
"inverse_inertia": 942.0966851360592,
"vertices": {
"0": {
"0": 0.12999987602233887,
"1": 0.15250003337860107
},
"1": {
"0": 0.12999987602233887,
"1": -0.15250003337860107
},
"2": {
"0": -0.12999987602233887,
"1": -0.15250003337860107
},
"3": {
"0": -0.12999987602233887,
"1": 0.15250003337860107
}
}
},
{
"position": {
"0": 2.8107502790242442,
"1": 1.3421653372278328
},
"velocity": {
"0": 0,
"1": 0
},
"angular_velocity": 0,
"rotation": 0,
"inverse_mass": 0,
"inverse_inertia": 0,
"restitution": 0,
"friction": 1,
"vertices": {
"0": {
"0": 0.1091825324387061,
"1": 0.7071067854336324
},
"1": {
"0": 0.1091825324387061,
"1": -0.7071067769394626
},
"2": {
"0": -0.1091825291514397,
"1": -0.7071067769394626
},
"3": {
"0": -0.1091825291514397,
"1": 0.7071067854336324
}
},
"n_vertices": 4,
"radius": 0.1,
"collision_mode": 1,
"active": true,
"density": 1,
"role": 0
},
{
"friction": 1,
"role": 0,
"position": {
"0": 1.8848819732666016,
"1": 2.6214592456817627
},
"radius": 0.10000000149011612,
"density": 1,
"velocity": {
"0": 0,
"1": 0
},
"inverse_mass": 3.015984535217285,
"active": false,
"inverse_inertia": 35.087406158447266,
"n_vertices": 4,
"rotation": 0,
"restitution": 0,
"angular_velocity": 0,
"vertices": {
"0": {
"0": 0.1737148016691208,
"1": 0.47717100381851196
},
"1": {
"0": 0.1737148016691208,
"1": -0.47717100381851196
},
"2": {
"0": -0.1737148016691208,
"1": -0.47717100381851196
},
"3": {
"0": -0.1737148016691208,
"1": 0.47717100381851196
}
},
"collision_mode": 1
},
{
"angular_velocity": 0,
"inverse_inertia": 0,
"n_vertices": 4,
"inverse_mass": 0,
"vertices": {
"0": {
"0": 0,
"1": 0
},
"1": {
"0": 0,
"1": 0
},
"2": {
"0": 0,
"1": 0
},
"3": {
"0": 0,
"1": 0
}
},
"radius": 0,
"friction": 1,
"role": 0,
"active": false,
"rotation": 0,
"density": 1,
"collision_mode": 1,
"velocity": {
"0": 0,
"1": 0
},
"restitution": 0,
"position": {
"0": 0,
"1": 0
}
}
],
"circle": [
{
"angular_velocity": 0,
"friction": 1,
"role": 1,
"vertices": {
"0": {
"0": 0,
"1": 0
},
"1": {
"0": 0,
"1": 0
},
"2": {
"0": 0,
"1": 0
},
"3": {
"0": 0,
"1": 0
}
},
"radius": 0.14,
"inverse_mass": 16.240300315499525,
"collision_mode": 1,
"active": true,
"inverse_inertia": 3314.347003163167,
"position": {
"0": 1.3199999332427979,
"1": 2.804999828338623
},
"n_vertices": 0,
"density": 1,
"rotation": 0,
"restitution": 0,
"velocity": {
"0": 0,
"1": 0
}
},
{
"n_vertices": 0,
"radius": 0.14,
"velocity": {
"0": 0,
"1": 0
},
"inverse_mass": 16.240300315499525,
"angular_velocity": 0,
"collision_mode": 1,
"density": 1,
"restitution": 0,
"rotation": 0,
"role": 0,
"inverse_inertia": 3314.347003163167,
"position": {
"0": 0.8199998140335083,
"1": 2.804999828338623
},
"vertices": {
"0": {
"0": 0,
"1": 0
},
"1": {
"0": 0,
"1": 0
},
"2": {
"0": 0,
"1": 0
},
"3": {
"0": 0,
"1": 0
}
},
"friction": 1,
"active": true
},
{
"density": 1,
"n_vertices": 4,
"inverse_inertia": 2438.58203125,
"friction": 1,
"vertices": {
"0": {
"0": 0.05000000074505806,
"1": 0.05000000074505806
},
"1": {
"0": 0.05000000074505806,
"1": -0.05000000074505806
},
"2": {
"0": -0.05000000074505806,
"1": -0.05000000074505806
},
"3": {
"0": -0.05000000074505806,
"1": 0.05000000074505806
}
},
"velocity": {
"0": 0,
"1": 0
},
"position": {
"0": 2.554999828338623,
"1": 0.7899999618530273
},
"active": false,
"role": 0,
"angular_velocity": 0,
"rotation": 0.5017720460891724,
"restitution": 0,
"collision_mode": 1,
"inverse_mass": 13.930404663085938,
"radius": 0.15116219222545624
},
{
"active": false,
"radius": 0,
"collision_mode": 1,
"density": 1,
"friction": 1,
"angular_velocity": 0,
"role": 0,
"velocity": {
"0": 0,
"1": 0
},
"vertices": {
"0": {
"0": 0,
"1": 0
},
"1": {
"0": 0,
"1": 0
},
"2": {
"0": 0,
"1": 0
},
"3": {
"0": 0,
"1": 0
}
},
"n_vertices": 0,
"inverse_mass": 0,
"position": {
"0": 0,
"1": 0
},
"inverse_inertia": 0,
"rotation": 0,
"restitution": 0
}
],
"joint": [
{
"min_rotation": 0,
"motor_power": 2.049999952316284,
"active": true,
"a_relative_pos": {
"0": -0.26250016689300537,
"1": -0.014999866485595703
},
"acc_impulse": {
"0": 0,
"1": 0
},
"acc_r_impulse": 0,
"motor_binding": 0,
"is_fixed_joint": false,
"b_relative_pos": {
"0": 0,
"1": 0
},
"global_position": {
"0": 0.8199998140335083,
"1": 2.804999828338623
},
"max_rotation": 0,
"motor_on": true,
"motor_speed": 2.049999952316284,
"rotation": 0,
"motor_has_joint_limits": false,
"a_index": 6,
"b_index": 13
},
{
"acc_r_impulse": 0,
"max_rotation": 0,
"b_index": 12,
"a_index": 6,
"motor_on": true,
"a_relative_pos": {
"0": 0.23749995231628418,
"1": -0.014999866485595703
},
"motor_has_joint_limits": false,
"rotation": 0,
"acc_impulse": {
"0": 0,
"1": 0
},
"is_fixed_joint": false,
"motor_binding": 0,
"active": true,
"motor_speed": 2.049999952316284,
"b_relative_pos": {
"0": 0,
"1": 0
},
"motor_power": 2.0999999046325684,
"global_position": {
"0": 1.3199999332427979,
"1": 2.804999828338623
},
"min_rotation": 0
},
{
"a_relative_pos": {
"0": 0,
"1": 0
},
"global_position": {
"0": 1.0824999809265137,
"1": 2.8199996948242188
},
"acc_r_impulse": 0,
"motor_power": 2.299999952316284,
"a_index": 6,
"motor_on": true,
"b_relative_pos": {
"0": 0.015000104904174805,
"1": -0.2674999237060547
},
"is_fixed_joint": false,
"b_index": 7,
"motor_speed": 2.200000047683716,
"min_rotation": 0,
"motor_has_joint_limits": false,
"max_rotation": 0,
"active": true,
"rotation": 0,
"motor_binding": 1,
"acc_impulse": {
"0": 0,
"1": 0
}
},
{
"b_index": 5,
"is_fixed_joint": false,
"acc_impulse": {
"0": 0,
"1": 0
},
"a_index": 4,
"motor_binding": 1,
"motor_speed": 1,
"motor_has_joint_limits": false,
"global_position": {
"0": 1.6524999141693115,
"1": 2.3524999022483826
},
"active": false,
"b_relative_pos": {
"0": -0.41750001907348633,
"1": 0.02249997854232788
},
"motor_on": true,
"a_relative_pos": {
"0": 0.47749996185302734,
"1": -0.22499996423721313
},
"acc_r_impulse": 0,
"motor_power": 3,
"rotation": 0,
"max_rotation": 0,
"min_rotation": 0
},
{
"rotation": 0,
"active": false,
"acc_r_impulse": 0,
"global_position": {
"0": 2.5,
"1": -4.800000190734863
},
"motor_binding": 0,
"a_index": 0,
"acc_impulse": {
"0": 0,
"1": 0
},
"motor_has_joint_limits": false,
"is_fixed_joint": false,
"b_relative_pos": {
"0": 0,
"1": 0
},
"motor_power": 0,
"a_relative_pos": {
"0": 0,
"1": 0
},
"motor_speed": 0,
"motor_on": false,
"max_rotation": 0,
"b_index": 0,
"min_rotation": 0
},
{
"motor_speed": 0,
"motor_has_joint_limits": false,
"is_fixed_joint": false,
"motor_on": false,
"b_relative_pos": {
"0": 0,
"1": 0
},
"a_relative_pos": {
"0": 0,
"1": 0
},
"rotation": 0,
"motor_power": 0,
"b_index": 0,
"max_rotation": 0,
"a_index": 0,
"min_rotation": 0,
"acc_impulse": {
"0": 0,
"1": 0
},
"global_position": {
"0": 2.5,
"1": -4.800000190734863
},
"active": false,
"acc_r_impulse": 0,
"motor_binding": 0
}
],
"thruster": [
{
"thruster_binding": 0,
"power": 0,
"global_position": {
"0": 2.5,
"1": -4.800000190734863
},
"rotation": 0,
"relative_position": {
"0": 0,
"1": 0
},
"active": false,
"object_index": 0
},
{
"power": 0.21093875169754028,
"rotation": 3.370077133178711,
"thruster_binding": 1,
"object_index": 9,
"global_position": {
"0": 2.7205383883863696,
"1": 1.3828215462198372
},
"active": false,
"relative_position": {
"0": -0.0902118906378746,
"1": 0.040656208992004395
}
}
],
"collision_matrix": {
"0": {
"0": false,
"1": true,
"2": true,
"3": true,
"4": true,
"5": true,
"6": true,
"7": true,
"8": true,
"9": true,
"10": true,
"11": true,
"12": true,
"13": true,
"14": true,
"15": true
},
"1": {
"0": true,
"1": false,
"2": true,
"3": true,
"4": true,
"5": true,
"6": true,
"7": true,
"8": true,
"9": true,
"10": true,
"11": true,
"12": true,
"13": true,
"14": true,
"15": true
},
"2": {
"0": true,
"1": true,
"2": false,
"3": true,
"4": true,
"5": true,
"6": true,
"7": true,
"8": true,
"9": true,
"10": true,
"11": true,
"12": true,
"13": true,
"14": true,
"15": true
},
"3": {
"0": true,
"1": true,
"2": true,
"3": false,
"4": true,
"5": true,
"6": true,
"7": true,
"8": true,
"9": true,
"10": true,
"11": true,
"12": true,
"13": true,
"14": true,
"15": true
},
"4": {
"0": true,
"1": true,
"2": true,
"3": true,
"4": false,
"5": true,
"6": true,
"7": true,
"8": true,
"9": true,
"10": true,
"11": true,
"12": true,
"13": true,
"14": true,
"15": true
},
"5": {
"0": true,
"1": true,
"2": true,
"3": true,
"4": true,
"5": false,
"6": true,
"7": true,
"8": true,
"9": true,
"10": true,
"11": true,
"12": true,
"13": true,
"14": true,
"15": true
},
"6": {
"0": true,
"1": true,
"2": true,
"3": true,
"4": true,
"5": true,
"6": false,
"7": false,
"8": true,
"9": true,
"10": true,
"11": true,
"12": false,
"13": false,
"14": true,
"15": true
},
"7": {
"0": true,
"1": true,
"2": true,
"3": true,
"4": true,
"5": true,
"6": false,
"7": false,
"8": true,
"9": true,
"10": true,
"11": true,
"12": false,
"13": false,
"14": true,
"15": true
},
"8": {
"0": true,
"1": true,
"2": true,
"3": true,
"4": true,
"5": true,
"6": true,
"7": true,
"8": false,
"9": true,
"10": true,
"11": true,
"12": true,
"13": true,
"14": true,
"15": true
},
"9": {
"0": true,
"1": true,
"2": true,
"3": true,
"4": true,
"5": true,
"6": true,
"7": true,
"8": true,
"9": false,
"10": true,
"11": true,
"12": true,
"13": true,
"14": true,
"15": true
},
"10": {
"0": true,
"1": true,
"2": true,
"3": true,
"4": true,
"5": true,
"6": true,
"7": true,
"8": true,
"9": true,
"10": false,
"11": true,
"12": true,
"13": true,
"14": true,
"15": true
},
"11": {
"0": true,
"1": true,
"2": true,
"3": true,
"4": true,
"5": true,
"6": true,
"7": true,
"8": true,
"9": true,
"10": true,
"11": false,
"12": true,
"13": true,
"14": true,
"15": true
},
"12": {
"0": true,
"1": true,
"2": true,
"3": true,
"4": true,
"5": true,
"6": false,
"7": false,
"8": true,
"9": true,
"10": true,
"11": true,
"12": false,
"13": false,
"14": true,
"15": true
},
"13": {
"0": true,
"1": true,
"2": true,
"3": true,
"4": true,
"5": true,
"6": false,
"7": false,
"8": true,
"9": true,
"10": true,
"11": true,
"12": false,
"13": false,
"14": true,
"15": true
},
"14": {
"0": true,
"1": true,
"2": true,
"3": true,
"4": true,
"5": true,
"6": true,
"7": true,
"8": true,
"9": true,
"10": true,
"11": true,
"12": true,
"13": true,
"14": false,
"15": true
},
"15": {
"0": true,
"1": true,
"2": true,
"3": true,
"4": true,
"5": true,
"6": true,
"7": true,
"8": true,
"9": true,
"10": true,
"11": true,
"12": true,
"13": true,
"14": true,
"15": false
}
},
"acc_rr_manifolds": [
{
"normal": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0,
"collision_point": {
"0": 0,
"1": 0
},
"penetration": 0,
"acc_impulse_tangent": 0,
"active": false,
"acc_impulse_normal": 0
},
{
"collision_point": {
"0": 0,
"1": 0
},
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0,
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"restitution_velocity_target": 0,
"active": false
},
{
"penetration": 0,
"acc_impulse_normal": 0,
"active": false,
"normal": {
"0": 0,
"1": 0
},
"acc_impulse_tangent": 0,
"collision_point": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0
},
{
"acc_impulse_tangent": 0,
"restitution_velocity_target": 0,
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"normal": {
"0": 0,
"1": 0
},
"acc_impulse_normal": 0,
"active": false
},
{
"collision_point": {
"0": 0,
"1": 0
},
"acc_impulse_tangent": 0,
"active": false,
"acc_impulse_normal": 0,
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"restitution_velocity_target": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"acc_impulse_tangent": 0,
"penetration": 0,
"acc_impulse_normal": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0
},
{
"active": false,
"normal": {
"0": 0,
"1": 0
},
"collision_point": {
"0": 0,
"1": 0
},
"acc_impulse_tangent": 0,
"acc_impulse_normal": 0,
"restitution_velocity_target": 0,
"penetration": 0
},
{
"collision_point": {
"0": 0,
"1": 0
},
"normal": {
"0": 0,
"1": 0
},
"acc_impulse_tangent": 0,
"penetration": 0,
"acc_impulse_normal": 0,
"restitution_velocity_target": 0,
"active": false
},
{
"active": false,
"collision_point": {
"0": 0,
"1": 0
},
"penetration": 0,
"acc_impulse_normal": 0,
"normal": {
"0": 0,
"1": 0
},
"acc_impulse_tangent": 0,
"restitution_velocity_target": 0
},
{
"acc_impulse_tangent": 0,
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0,
"active": false,
"normal": {
"0": 0,
"1": 0
},
"acc_impulse_normal": 0
},
{
"restitution_velocity_target": 0,
"penetration": 0,
"acc_impulse_tangent": 0,
"acc_impulse_normal": 0,
"active": false,
"normal": {
"0": 0,
"1": 0
},
"collision_point": {
"0": 0,
"1": 0
}
},
{
"collision_point": {
"0": 0,
"1": 0
},
"normal": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"active": false,
"penetration": 0,
"acc_impulse_tangent": 0
},
{
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"acc_impulse_tangent": 0,
"active": false
},
{
"collision_point": {
"0": 0,
"1": 0
},
"acc_impulse_normal": 0,
"restitution_velocity_target": 0,
"penetration": 0,
"acc_impulse_tangent": 0,
"normal": {
"0": 0,
"1": 0
},
"active": false
},
{
"penetration": 0,
"acc_impulse_normal": 0,
"restitution_velocity_target": 0,
"active": false,
"acc_impulse_tangent": 0,
"normal": {
"0": 0,
"1": 0
},
"collision_point": {
"0": 0,
"1": 0
}
},
{
"acc_impulse_tangent": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"penetration": 0,
"normal": {
"0": 0,
"1": 0
}
},
{
"normal": {
"0": 0,
"1": 0
},
"active": false,
"penetration": 0,
"restitution_velocity_target": 0,
"acc_impulse_tangent": 0,
"collision_point": {
"0": 0,
"1": 0
},
"acc_impulse_normal": 0
},
{
"active": false,
"collision_point": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0,
"normal": {
"0": 0,
"1": 0
},
"acc_impulse_tangent": 0,
"penetration": 0,
"acc_impulse_normal": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"active": false,
"restitution_velocity_target": 0,
"collision_point": {
"0": 0,
"1": 0
},
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"restitution_velocity_target": 0,
"penetration": 0,
"acc_impulse_tangent": 0,
"collision_point": {
"0": 0,
"1": 0
},
"normal": {
"0": 0,
"1": 0
},
"acc_impulse_normal": 0,
"active": false
},
{
"acc_impulse_tangent": 0,
"collision_point": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0,
"normal": {
"0": 0,
"1": 0
},
"acc_impulse_normal": 0,
"active": false,
"penetration": 0
},
{
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0,
"normal": {
"0": 0,
"1": 0
},
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"penetration": 0,
"restitution_velocity_target": 0
},
{
"collision_point": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"penetration": 0,
"acc_impulse_tangent": 0,
"normal": {
"0": 0,
"1": 0
},
"active": false
},
{
"collision_point": {
"0": 0,
"1": 0
},
"acc_impulse_tangent": 0,
"normal": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0,
"penetration": 0,
"acc_impulse_normal": 0,
"active": false
},
{
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0,
"active": false,
"normal": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0
},
{
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"acc_impulse_tangent": 0,
"normal": {
"0": 0,
"1": 0
},
"penetration": 0
},
{
"restitution_velocity_target": 0,
"collision_point": {
"0": 0,
"1": 0
},
"acc_impulse_tangent": 0,
"acc_impulse_normal": 0,
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"active": false
},
{
"collision_point": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0,
"active": false,
"acc_impulse_tangent": 0,
"normal": {
"0": 0,
"1": 0
},
"acc_impulse_normal": 0,
"penetration": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"acc_impulse_tangent": 0,
"active": false,
"collision_point": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0,
"penetration": 0,
"acc_impulse_normal": 0
},
{
"acc_impulse_normal": 0,
"collision_point": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0,
"active": false,
"acc_impulse_tangent": 0,
"penetration": 0,
"normal": {
"0": 0,
"1": 0
}
},
{
"normal": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_tangent": 0,
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"acc_impulse_normal": 0
},
{
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0,
"collision_point": {
"0": 0,
"1": 0
},
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"active": false
},
{
"normal": {
"0": 0,
"1": 0
},
"acc_impulse_normal": 0,
"active": false,
"penetration": 0,
"acc_impulse_tangent": 0,
"collision_point": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0
},
{
"acc_impulse_tangent": 0,
"normal": {
"0": 0,
"1": 0
},
"acc_impulse_normal": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"penetration": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"acc_impulse_tangent": 0,
"active": false,
"collision_point": {
"0": 0,
"1": 0
},
"acc_impulse_normal": 0,
"restitution_velocity_target": 0
},
{
"penetration": 0,
"acc_impulse_tangent": 0,
"normal": {
"0": 0,
"1": 0
},
"active": false,
"collision_point": {
"0": 0,
"1": 0
},
"acc_impulse_normal": 0,
"restitution_velocity_target": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"active": false,
"acc_impulse_tangent": 0,
"collision_point": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0,
"acc_impulse_normal": 0
},
{
"acc_impulse_tangent": 0,
"restitution_velocity_target": 0,
"collision_point": {
"0": 0,
"1": 0
},
"acc_impulse_normal": 0,
"penetration": 0,
"normal": {
"0": 0,
"1": 0
},
"active": false
},
{
"acc_impulse_normal": 0,
"restitution_velocity_target": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"acc_impulse_tangent": 0,
"penetration": 0,
"normal": {
"0": 0,
"1": 0
}
},
{
"acc_impulse_tangent": 0,
"acc_impulse_normal": 0,
"active": false,
"normal": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0,
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
}
},
{
"acc_impulse_tangent": 0,
"acc_impulse_normal": 0,
"active": false,
"penetration": 0,
"restitution_velocity_target": 0,
"collision_point": {
"0": 0,
"1": 0
},
"normal": {
"0": 0,
"1": 0
}
},
{
"acc_impulse_normal": 0,
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0,
"active": false,
"normal": {
"0": 0,
"1": 0
},
"acc_impulse_tangent": 0
},
{
"restitution_velocity_target": 0,
"active": false,
"acc_impulse_normal": 0,
"penetration": 0,
"acc_impulse_tangent": 0,
"collision_point": {
"0": 0,
"1": 0
},
"normal": {
"0": 0,
"1": 0
}
},
{
"acc_impulse_normal": 0,
"active": false,
"collision_point": {
"0": 0,
"1": 0
},
"penetration": 0,
"normal": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0,
"acc_impulse_tangent": 0
},
{
"collision_point": {
"0": 0,
"1": 0
},
"acc_impulse_normal": 0,
"restitution_velocity_target": 0,
"acc_impulse_tangent": 0,
"active": false,
"normal": {
"0": 0,
"1": 0
},
"penetration": 0
},
{
"collision_point": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0,
"penetration": 0,
"acc_impulse_tangent": 0,
"normal": {
"0": 0,
"1": 0
},
"active": false,
"acc_impulse_normal": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"active": false,
"acc_impulse_normal": 0,
"restitution_velocity_target": 0,
"acc_impulse_tangent": 0,
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
}
},
{
"acc_impulse_tangent": 0,
"penetration": 0,
"active": false,
"collision_point": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"normal": {
"0": 0,
"1": 0
}
},
{
"collision_point": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0,
"penetration": 0,
"active": false,
"normal": {
"0": 0,
"1": 0
},
"acc_impulse_tangent": 0,
"acc_impulse_normal": 0
},
{
"active": false,
"acc_impulse_tangent": 0,
"restitution_velocity_target": 0,
"penetration": 0,
"acc_impulse_normal": 0,
"collision_point": {
"0": 0,
"1": 0
},
"normal": {
"0": 0,
"1": 0
}
},
{
"collision_point": {
"0": 0,
"1": 0
},
"acc_impulse_tangent": 0,
"restitution_velocity_target": 0,
"active": false,
"penetration": 0,
"normal": {
"0": 0,
"1": 0
},
"acc_impulse_normal": 0
},
{
"acc_impulse_normal": 0,
"active": false,
"collision_point": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0,
"penetration": 0,
"acc_impulse_tangent": 0,
"normal": {
"0": 0,
"1": 0
}
},
{
"acc_impulse_tangent": 0,
"normal": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0,
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"acc_impulse_normal": 0
},
{
"acc_impulse_normal": 0,
"restitution_velocity_target": 0,
"active": false,
"acc_impulse_tangent": 0,
"penetration": 0,
"normal": {
"0": 0,
"1": 0
},
"collision_point": {
"0": 0,
"1": 0
}
},
{
"normal": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"acc_impulse_tangent": 0,
"acc_impulse_normal": 0,
"penetration": 0
},
{
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"penetration": 0,
"acc_impulse_normal": 0,
"restitution_velocity_target": 0,
"normal": {
"0": 0,
"1": 0
},
"acc_impulse_tangent": 0
},
{
"restitution_velocity_target": 0,
"collision_point": {
"0": 0,
"1": 0
},
"penetration": 0,
"acc_impulse_tangent": 0,
"acc_impulse_normal": 0,
"normal": {
"0": 0,
"1": 0
},
"active": false
},
{
"penetration": 0,
"active": false,
"collision_point": {
"0": 0,
"1": 0
},
"acc_impulse_normal": 0,
"normal": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0,
"acc_impulse_tangent": 0
},
{
"acc_impulse_tangent": 0,
"acc_impulse_normal": 0,
"active": false,
"collision_point": {
"0": 0,
"1": 0
},
"normal": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0,
"penetration": 0
},
{
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"acc_impulse_normal": 0,
"restitution_velocity_target": 0,
"acc_impulse_tangent": 0,
"active": false,
"normal": {
"0": 0,
"1": 0
}
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0
}
],
"acc_cr_manifolds": [
{
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0,
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0,
"active": false
},
{
"acc_impulse_tangent": 0,
"normal": {
"0": 0,
"1": 0
},
"acc_impulse_normal": 0,
"collision_point": {
"0": 0,
"1": 0
},
"penetration": 0,
"active": false,
"restitution_velocity_target": 0
},
{
"active": false,
"collision_point": {
"0": 0,
"1": 0
},
"acc_impulse_normal": 0,
"normal": {
"0": 0,
"1": 0
},
"acc_impulse_tangent": 0,
"restitution_velocity_target": 0,
"penetration": 0
},
{
"active": false,
"acc_impulse_tangent": 0,
"acc_impulse_normal": 0,
"restitution_velocity_target": 0,
"penetration": 0,
"normal": {
"0": 0,
"1": 0
},
"collision_point": {
"0": 0,
"1": 0
}
},
{
"acc_impulse_tangent": 0,
"restitution_velocity_target": 0,
"penetration": 0,
"acc_impulse_normal": 0,
"normal": {
"0": 0,
"1": 0
},
"active": false,
"collision_point": {
"0": 0,
"1": 0
}
},
{
"restitution_velocity_target": 0,
"acc_impulse_tangent": 0,
"acc_impulse_normal": 0,
"penetration": 0,
"active": false,
"normal": {
"0": 0,
"1": 0
},
"collision_point": {
"0": 0,
"1": 0
}
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"active": false,
"collision_point": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0,
"acc_impulse_tangent": 0,
"acc_impulse_normal": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"acc_impulse_normal": 0,
"restitution_velocity_target": 0,
"active": false,
"collision_point": {
"0": 0,
"1": 0
},
"acc_impulse_tangent": 0,
"penetration": 0
},
{
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0,
"active": false,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0,
"normal": {
"0": 0,
"1": 0
}
},
{
"acc_impulse_tangent": 0,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"penetration": 0,
"normal": {
"0": 0,
"1": 0
},
"active": false,
"collision_point": {
"0": 0,
"1": 0
}
},
{
"acc_impulse_normal": 0,
"restitution_velocity_target": 0,
"normal": {
"0": 0,
"1": 0
},
"acc_impulse_tangent": 0,
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false
},
{
"active": false,
"normal": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0,
"penetration": 0,
"acc_impulse_tangent": 0,
"collision_point": {
"0": 0,
"1": 0
},
"acc_impulse_normal": 0
},
{
"acc_impulse_tangent": 0,
"active": false,
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"acc_impulse_normal": 0,
"normal": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0,
"active": false
},
{
"restitution_velocity_target": 0,
"active": false,
"acc_impulse_tangent": 0,
"collision_point": {
"0": 0,
"1": 0
},
"acc_impulse_normal": 0,
"normal": {
"0": 0,
"1": 0
},
"penetration": 0
},
{
"active": false,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"collision_point": {
"0": 0,
"1": 0
},
"normal": {
"0": 0,
"1": 0
},
"acc_impulse_tangent": 0,
"penetration": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"collision_point": {
"0": 0,
"1": 0
},
"penetration": 0,
"acc_impulse_tangent": 0,
"acc_impulse_normal": 0,
"restitution_velocity_target": 0,
"active": false
},
{
"normal": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0,
"acc_impulse_tangent": 0,
"active": false,
"collision_point": {
"0": 0,
"1": 0
},
"acc_impulse_normal": 0,
"penetration": 0
},
{
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"active": false,
"collision_point": {
"0": 0,
"1": 0
},
"penetration": 0,
"normal": {
"0": 0,
"1": 0
},
"acc_impulse_tangent": 0
},
{
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"acc_impulse_tangent": 0,
"normal": {
"0": 0,
"1": 0
},
"acc_impulse_normal": 0,
"restitution_velocity_target": 0,
"penetration": 0
},
{
"acc_impulse_tangent": 0,
"normal": {
"0": 0,
"1": 0
},
"active": false,
"collision_point": {
"0": 0,
"1": 0
},
"acc_impulse_normal": 0,
"penetration": 0,
"restitution_velocity_target": 0
},
{
"collision_point": {
"0": 0,
"1": 0
},
"acc_impulse_tangent": 0,
"acc_impulse_normal": 0,
"restitution_velocity_target": 0,
"active": false,
"penetration": 0,
"normal": {
"0": 0,
"1": 0
}
},
{
"normal": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0,
"collision_point": {
"0": 0,
"1": 0
},
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0,
"penetration": 0,
"active": false
},
{
"acc_impulse_normal": 0,
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"restitution_velocity_target": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"acc_impulse_tangent": 0
},
{
"penetration": 0,
"acc_impulse_normal": 0,
"active": false,
"acc_impulse_tangent": 0,
"collision_point": {
"0": 0,
"1": 0
},
"normal": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0
},
{
"active": false,
"acc_impulse_normal": 0,
"collision_point": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0,
"acc_impulse_tangent": 0,
"normal": {
"0": 0,
"1": 0
},
"penetration": 0
},
{
"acc_impulse_normal": 0,
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"acc_impulse_tangent": 0,
"active": false,
"normal": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"acc_impulse_tangent": 0,
"penetration": 0,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"active": false,
"collision_point": {
"0": 0,
"1": 0
}
},
{
"penetration": 0,
"restitution_velocity_target": 0,
"active": false,
"acc_impulse_tangent": 0,
"normal": {
"0": 0,
"1": 0
},
"acc_impulse_normal": 0,
"collision_point": {
"0": 0,
"1": 0
}
},
{
"active": false,
"normal": {
"0": 0,
"1": 0
},
"acc_impulse_tangent": 0,
"restitution_velocity_target": 0,
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"acc_impulse_normal": 0
},
{
"penetration": 0,
"normal": {
"0": 0,
"1": 0
},
"acc_impulse_tangent": 0,
"collision_point": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0,
"active": false,
"acc_impulse_normal": 0
},
{
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"normal": {
"0": 0,
"1": 0
},
"active": false,
"collision_point": {
"0": 0,
"1": 0
},
"acc_impulse_tangent": 0,
"penetration": 0
},
{
"acc_impulse_normal": 0,
"active": false,
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"restitution_velocity_target": 0,
"acc_impulse_tangent": 0,
"collision_point": {
"0": 0,
"1": 0
}
},
{
"acc_impulse_tangent": 0,
"acc_impulse_normal": 0,
"collision_point": {
"0": 0,
"1": 0
},
"penetration": 0,
"normal": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0
},
{
"acc_impulse_tangent": 0,
"penetration": 0,
"normal": {
"0": 0,
"1": 0
},
"acc_impulse_normal": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0
},
{
"active": false,
"normal": {
"0": 0,
"1": 0
},
"acc_impulse_normal": 0,
"collision_point": {
"0": 0,
"1": 0
},
"penetration": 0,
"restitution_velocity_target": 0,
"acc_impulse_tangent": 0
},
{
"acc_impulse_tangent": 0,
"acc_impulse_normal": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"restitution_velocity_target": 0,
"penetration": 0,
"normal": {
"0": 0,
"1": 0
}
},
{
"acc_impulse_normal": 0,
"active": false,
"acc_impulse_tangent": 0,
"normal": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0,
"collision_point": {
"0": 0,
"1": 0
},
"penetration": 0
},
{
"acc_impulse_tangent": 0,
"normal": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0,
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"acc_impulse_normal": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0,
"active": false,
"restitution_velocity_target": 0,
"collision_point": {
"0": 0,
"1": 0
}
},
{
"active": false,
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"acc_impulse_normal": 0,
"normal": {
"0": 0,
"1": 0
},
"acc_impulse_tangent": 0,
"restitution_velocity_target": 0
},
{
"restitution_velocity_target": 0,
"normal": {
"0": 0,
"1": 0
},
"acc_impulse_tangent": 0,
"acc_impulse_normal": 0,
"active": false,
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
}
},
{
"acc_impulse_tangent": 0,
"normal": {
"0": 0,
"1": 0
},
"acc_impulse_normal": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"penetration": 0,
"restitution_velocity_target": 0
},
{
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"acc_impulse_tangent": 0,
"acc_impulse_normal": 0,
"restitution_velocity_target": 0,
"normal": {
"0": 0,
"1": 0
}
},
{
"acc_impulse_normal": 0,
"active": false,
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"acc_impulse_tangent": 0,
"restitution_velocity_target": 0,
"normal": {
"0": 0,
"1": 0
}
},
{
"collision_point": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"penetration": 0,
"normal": {
"0": 0,
"1": 0
},
"active": false,
"acc_impulse_tangent": 0
},
{
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"acc_impulse_tangent": 0,
"acc_impulse_normal": 0,
"restitution_velocity_target": 0,
"normal": {
"0": 0,
"1": 0
},
"active": false
},
{
"restitution_velocity_target": 0,
"penetration": 0,
"active": false,
"collision_point": {
"0": 0,
"1": 0
},
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0,
"normal": {
"0": 0,
"1": 0
}
}
],
"acc_cc_manifolds": [
{
"normal": {
"0": 0,
"1": 0
},
"collision_point": {
"0": 0,
"1": 0
},
"acc_impulse_tangent": 0,
"restitution_velocity_target": 0,
"active": false,
"penetration": 0,
"acc_impulse_normal": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"collision_point": {
"0": 0,
"1": 0
},
"active": false,
"penetration": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0,
"restitution_velocity_target": 0
},
{
"acc_impulse_tangent": 0,
"active": false,
"acc_impulse_normal": 0,
"collision_point": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0,
"normal": {
"0": 0,
"1": 0
},
"penetration": 0
},
{
"active": false,
"acc_impulse_normal": 0,
"restitution_velocity_target": 0,
"penetration": 0,
"collision_point": {
"0": 0,
"1": 0
},
"normal": {
"0": 0,
"1": 0
},
"acc_impulse_tangent": 0
},
{
"penetration": 0,
"acc_impulse_tangent": 0,
"normal": {
"0": 0,
"1": 0
},
"active": false,
"collision_point": {
"0": 0,
"1": 0
},
"restitution_velocity_target": 0,
"acc_impulse_normal": 0
},
{
"normal": {
"0": 0,
"1": 0
},
"penetration": 0,
"restitution_velocity_target": 0,
"acc_impulse_normal": 0,
"acc_impulse_tangent": 0,
"active": false,
"collision_point": {
"0": 0,
"1": 0
}
}
],
"gravity": {
"0": 0,
"1": -9.8100004196167
}
},
"env_params": {
"dt": 0.016666666666666666,
"slop": 0.01,
"baumgarte_coefficient_joints_v": 2,
"baumgarte_coefficient_joints_p": 0.7,
"baumgarte_coefficient_fjoint_av": 2,
"baumgarte_coefficient_rjoint_limit_av": 5,
"baumgarte_coefficient_collision": 0.2,
"joint_stiffness": 0.6,
"clip_position": 15,
"clip_velocity": 100,
"clip_angular_velocity": 50,
"base_motor_speed": 6,
"base_motor_power": 900,
"base_thruster_power": 10,
"motor_decay_coefficient": 0.1,
"motor_joint_limit": 0.1,
"base_friction": 0.4,
"pixels_per_unit": 100,
"max_timesteps": 256,
"dense_reward_scale": 0.1,
"num_shape_roles": 4
},
"static_env_params": {
"num_polygons": 12,
"num_circles": 4,
"num_joints": 6,
"num_thrusters": 2,
"max_polygon_vertices": 4,
"num_solver_iterations": 10,
"solver_batch_size": 16,
"do_warm_starting": true,
"num_static_fixated_polys": 4,
"screen_dim": {
"0": 500,
"1": 500
},
"max_shape_size": 2,
"num_motor_bindings": 4,
"num_thruster_bindings": 2,
"frame_skip": 2
},
"version": "1.0.0"
}
================================================
FILE: worlds/l/cartpole_thrust.json
================================================
{"env_state":{"polygon":[{"active":true,"collision_mode":2,"inverse_mass":0,"friction":1,"rotation":0,"position":{"0":2.5,"1":-4.800000190734863},"vertices":{"0":{"0":2.5,"1":5.199999809265137},"1":{"0":2.5,"1":-5.199999809265137},"2":{"0":-2.5,"1":-5.199999809265137},"3":{"0":-2.5,"1":5.199999809265137}},"restitution":0,"n_vertices":4,"inverse_inertia":0,"velocity":{"0":0,"1":0},"angular_velocity":0,"density":1,"role":3,"radius":0},{"angular_velocity":0,"role":0,"inverse_mass":0,"inverse_inertia":0,"vertices":{"0":{"0":-5,"1":5},"1":{"0":-0.05000000074505806,"1":5},"2":{"0":-0.05000000074505806,"1":0},"3":{"0":-5,"1":0}},"collision_mode":2,"friction":1,"rotation":0,"density":1,"restitution":0,"position":{"0":0,"1":0},"velocity":{"0":0,"1":0},"n_vertices":4,"radius":0,"active":true},{"vertices":{"0":{"0":5,"1":5},"1":{"0":10,"1":5},"2":{"0":10,"1":0},"3":{"0":5,"1":0}},"restitution":0,"active":true,"position":{"0":0,"1":0},"radius":0,"inverse_inertia":0,"role":0,"angular_velocity":0,"inverse_mass":0,"n_vertices":4,"rotation":0,"friction":1,"velocity":{"0":0,"1":0},"collision_mode":2,"density":1},{"vertices":{"0":{"0":2.5,"1":5.199999809265137},"1":{"0":2.5,"1":-5.199999809265137},"2":{"0":-2.5,"1":-5.199999809265137},"3":{"0":-2.5,"1":5.199999809265137}},"friction":1,"angular_velocity":0,"active":true,"radius":0,"rotation":0,"inverse_mass":0,"restitution":0,"position":{"0":2.5,"1":10.199999809265137},"velocity":{"0":0,"1":0},"collision_mode":2,"n_vertices":4,"density":1,"role":0,"inverse_inertia":0},{"friction":1,"restitution":0,"vertices":{"0":{"0":0.4549999237060547,"1":0.16750000417232513},"1":{"0":0.4549999237060547,"1":-0.16750000417232513},"2":{"0":-0.4549999237060547,"1":-0.16750000417232513},"3":{"0":-0.4549999237060547,"1":0.16750000417232513}},"position":{"0":2.1449999809265137,"1":0.5724999904632568},"active":true,"inverse_inertia":41.86173629760742,"angular_velocity":0,"density":1,"n_vertices":4,"collision_mode":1,"velocity":{"0":0,"1":0},"rotation":0,"r
adius":0,"role":0,"inverse_mass":3.280302047729492},{"velocity":{"0":0,"1":0},"angular_velocity":0,"restitution":0,"density":1,"vertices":{"0":{"0":0.06749999523162842,"1":0.6399999856948853},"1":{"0":0.06749999523162842,"1":-0.6399999856948853},"2":{"0":-0.06749999523162842,"1":-0.6399999856948853},"3":{"0":-0.06749999523162842,"1":0.6399999856948853}},"radius":0,"n_vertices":4,"role":1,"active":true,"position":{"0":2.127500057220459,"1":1.2200000286102295},"collision_mode":1,"rotation":0,"inverse_mass":5.7870378494262695,"friction":1,"inverse_inertia":41.91923141479492},{"velocity":{"0":0,"1":0},"rotation":0,"inverse_inertia":0,"friction":1,"inverse_mass":0,"angular_velocity":0,"n_vertices":4,"position":{"0":0,"1":1.1021068096160889},"active":true,"restitution":0,"role":3,"radius":0,"vertices":{"0":{"0":0.05000000074505806,"1":0.7071067690849304},"1":{"0":0.05000000074505806,"1":-0.7071067690849304},"2":{"0":-0.05000000074505806,"1":-0.7071067690849304},"3":{"0":-0.05000000074505806,"1":0.7071067690849304}},"density":1,"collision_mode":1},{"position":{"0":5,"1":1.1021068096160889},"friction":1,"density":1,"angular_velocity":0,"inverse_inertia":0,"collision_mode":1,"restitution":0,"velocity":{"0":0,"1":0},"radius":0,"rotation":0,"role":3,"inverse_mass":0,"active":true,"n_vertices":4,"vertices":{"0":{"0":0.05000000074505806,"1":0.7071067690849304},"1":{"0":0.05000000074505806,"1":-0.7071067690849304},"2":{"0":-0.05000000074505806,"1":-0.7071067690849304},"3":{"0":-0.05000000074505806,"1":0.7071067690849304}}},{"inverse_inertia":0,"inverse_mass":0,"angular_velocity":0,"density":1,"friction":1,"role":2,"velocity":{"0":0,"1":0},"rotation":0,"collision_mode":1,"position":{"0":4.436287887179349,"1":1.7776370309706886},"restitution":0,"active":true,"n_vertices":4,"radius":0.1,"vertices":{"0":{"0":0.3349999785423279,"1":0.1725001335144043},"1":{"0":0.3349999785423279,"1":-0.1725001335144043},"2":{"0":-0.3349999785423279,"1":-0.1725001335144043},"3":{"0":-0.3349999785423279
,"1":0.1725001335144043}}},{"inverse_inertia":0,"restitution":0,"collision_mode":1,"n_vertices":4,"angular_velocity":0,"vertices":{"0":{"0":0,"1":0},"1":{"0":0,"1":0},"2":{"0":0,"1":0},"3":{"0":0,"1":0}},"active":false,"friction":1,"radius":0,"inverse_mass":0,"density":1,"position":{"0":0,"1":0},"velocity":{"0":0,"1":0},"role":0,"rotation":0},{"rotation":0,"radius":0,"velocity":{"0":0,"1":0},"density":1,"n_vertices":4,"vertices":{"0":{"0":0,"1":0},"1":{"0":0,"1":0},"2":{"0":0,"1":0},"3":{"0":0,"1":0}},"inverse_inertia":0,"angular_velocity":0,"position":{"0":0,"1":0},"collision_mode":1,"active":false,"friction":1,"inverse_mass":0,"restitution":0,"role":0},{"radius":0,"n_vertices":4,"restitution":0,"collision_mode":1,"friction":1,"rotation":0,"active":false,"angular_velocity":0,"role":0,"velocity":{"0":0,"1":0},"vertices":{"0":{"0":0,"1":0},"1":{"0":0,"1":0},"2":{"0":0,"1":0},"3":{"0":0,"1":0}},"inverse_mass":0,"inverse_inertia":0,"position":{"0":0,"1":0},"density":1}],"circle":[{"active":false,"velocity":{"0":0,"1":0},"friction":1,"inverse_inertia":0,"role":0,"rotation":0,"collision_mode":1,"radius":0,"inverse_mass":0,"n_vertices":0,"density":1,"restitution":0,"angular_velocity":0,"position":{"0":0,"1":0},"vertices":{"0":{"0":0,"1":0},"1":{"0":0,"1":0},"2":{"0":0,"1":0},"3":{"0":0,"1":0}}},{"rotation":0,"angular_velocity":0,"vertices":{"0":{"0":0,"1":0},"1":{"0":0,"1":0},"2":{"
SYMBOL INDEX (75 symbols across 6 files)
FILE: src/eval_flow.py
class NaiveMethodConfig (line 25) | class NaiveMethodConfig:
class RealtimeMethodConfig (line 30) | class RealtimeMethodConfig:
class BIDMethodConfig (line 36) | class BIDMethodConfig:
class EvalConfig (line 42) | class EvalConfig:
function eval (line 55) | def eval(
function main (line 171) | def main(
FILE: src/generate_data.py
class Config (line 25) | class Config:
class Data (line 54) | class Data:
class StepCarry (line 64) | class StepCarry:
function main (line 71) | def main(config: Config):
FILE: src/model.py
class ModelConfig (line 12) | class ModelConfig:
function posemb_sincos (line 21) | def posemb_sincos(pos: jax.Array, embedding_dim: int, min_period: float,...
function get_prefix_weights (line 40) | def get_prefix_weights(start: int, end: int, total: int, schedule: Prefi...
class MLPMixerBlock (line 66) | class MLPMixerBlock(nnx.Module):
method __init__ (line 67) | def __init__(
method __call__ (line 79) | def __call__(self, x: jax.Array, adaln_cond: jax.Array) -> jax.Array:
class FlowPolicy (line 103) | class FlowPolicy(nnx.Module):
method __init__ (line 104) | def __init__(
method __call__ (line 140) | def __call__(self, obs: jax.Array, x_t: jax.Array, time: jax.Array) ->...
method action (line 160) | def action(self, rng: jax.Array, obs: jax.Array, num_steps: int) -> ja...
method bid_action (line 173) | def bid_action(
method realtime_action (line 219) | def realtime_action(
method loss (line 267) | def loss(self, rng: jax.Array, obs: jax.Array, action: jax.Array):
FILE: src/render_levels.py
function load_levels (line 23) | def load_levels(paths):
function main (line 38) | def main():
FILE: src/train_expert.py
class Config (line 26) | class Config:
class BatchEnvWrapper (line 78) | class BatchEnvWrapper(wrappers.GymnaxWrapper):
method __init__ (line 81) | def __init__(self, env, num: int):
method reset (line 85) | def reset(self, rng, params):
method reset_to_level (line 88) | def reset_to_level(self, rng, level, params):
method step (line 93) | def step(self, rng, state, action, params):
class DenseRewardState (line 98) | class DenseRewardState:
class DenseRewardWrapper (line 104) | class DenseRewardWrapper(wrappers.GymnaxWrapper):
method __init__ (line 105) | def __init__(self, env):
method step (line 108) | def step(self, key, state, action, params=None):
method reset (line 116) | def reset(self, rng, params=None):
method reset_to_level (line 120) | def reset_to_level(self, rng, level, params=None):
class ActionHistoryWrapper (line 125) | class ActionHistoryWrapper(wrappers.UnderspecifiedEnvWrapper):
method __init__ (line 126) | def __init__(self, env):
method step_env (line 129) | def step_env(self, key, state, action, params):
method reset_to_level (line 134) | def reset_to_level(self, rng, level, params):
method action_space (line 139) | def action_space(self, params):
class NoisyActionWrapper (line 143) | class NoisyActionWrapper(wrappers.UnderspecifiedEnvWrapper):
method __init__ (line 144) | def __init__(self, env):
method step_env (line 147) | def step_env(self, key, state, action, params):
method reset_to_level (line 152) | def reset_to_level(self, rng, level, params):
method action_space (line 155) | def action_space(self, params):
class StickyActionState (line 160) | class StickyActionState:
class StickyActionWrapper (line 165) | class StickyActionWrapper(wrappers.UnderspecifiedEnvWrapper):
method __init__ (line 166) | def __init__(self, env, stickiness: float):
method step_env (line 170) | def step_env(self, key, state, action, params):
method reset_to_level (line 176) | def reset_to_level(self, rng, level, params):
method action_space (line 186) | def action_space(self, params):
class ObsHistoryState (line 191) | class ObsHistoryState:
class ObsHistoryWrapper (line 197) | class ObsHistoryWrapper(wrappers.UnderspecifiedEnvWrapper):
method __init__ (line 198) | def __init__(self, env, history_length: int):
method step_env (line 202) | def step_env(self, key, state, action, params):
method reset_to_level (line 207) | def reset_to_level(self, rng, level, params):
method action_space (line 212) | def action_space(self, params):
method get_original_obs (line 216) | def get_original_obs(env_state) -> jax.Array:
function make_squashed_normal_diag (line 222) | def make_squashed_normal_diag(mean, std, num_motor_bindings: int):
class Agent (line 232) | class Agent(nnx.Module):
method __init__ (line 233) | def __init__(self, obs_dim: int, action_dim: int, layer_width: int, *,...
method value (line 250) | def value(self, obs: jax.Array) -> jax.Array:
method action (line 253) | def action(self, obs: jax.Array):
class Transition (line 260) | class Transition:
class StepCarry (line 272) | class StepCarry:
class UpdateCarry (line 282) | class UpdateCarry:
class TrainCarry (line 292) | class TrainCarry:
function make_render_video (line 297) | def make_render_video(render_pixels):
function load_levels (line 307) | def load_levels(paths: Sequence[str], static_env_params: kenv_state.Stat...
function main (line 319) | def main(config: Config):
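The `make_squashed_normal_diag` entry in the `src/train_expert.py` index points at a tanh-squashed diagonal Gaussian, the standard action distribution for bounded continuous control. A self-contained sketch of the underlying math follows; the helper name and the `1e-6` stabilizer are my own, and this is not the repo's implementation (which builds a distribution object rather than sampling directly):

```python
import numpy as np

def squashed_normal_sample_and_logprob(mean, std, eps):
    """Tanh-squashed diagonal Gaussian sketch (hypothetical helper):
    squash a Gaussian sample into (-1, 1) and apply the change-of-variables
    correction to its log-density."""
    pre_tanh = mean + std * eps                      # raw Gaussian sample
    action = np.tanh(pre_tanh)                       # squashed into (-1, 1)
    # Diagonal Gaussian log-density of the pre-squash sample, per dimension.
    log_prob = -0.5 * (((pre_tanh - mean) / std) ** 2
                       + 2.0 * np.log(std) + np.log(2.0 * np.pi))
    # Jacobian of tanh: log|d tanh(x)/dx| = log(1 - tanh(x)^2).
    log_prob -= np.log(1.0 - action ** 2 + 1e-6)
    return action, log_prob.sum(axis=-1)

rng = np.random.default_rng(0)
a, lp = squashed_normal_sample_and_logprob(
    np.zeros(4), 0.5 * np.ones(4), rng.standard_normal(4))
print(a.shape, bool(np.all(np.abs(a) < 1.0)))  # (4,) True
```

The Jacobian correction is what keeps PPO-style log-probability ratios consistent when actions are squashed into the environment's bounds.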
FILE: src/train_flow.py
class Config (line 32) | class Config:
class EpochCarry (line 63) | class EpochCarry:
function main (line 69) | def main(config: Config):
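The index entries `FlowPolicy.loss` (src/model.py) and `main` in src/train_flow.py together imply a conditional flow-matching training objective. The sketch below shows one common form of that objective, under the linear-interpolation (rectified-flow) convention; the path parameterization and sign convention are assumptions and may differ from the repo's actual `loss`:

```python
import numpy as np

def flow_matching_loss(action, noise, t, predict_velocity):
    """Conditional flow-matching loss sketch. x_t interpolates linearly
    between data and noise; the regression target is the constant velocity
    along that straight-line path (assumed convention)."""
    t = t[..., None]                      # broadcast time over action dims
    x_t = (1.0 - t) * action + t * noise  # point on the straight-line path
    target = noise - action               # velocity of the linear path
    pred = predict_velocity(x_t, t)
    return np.mean((pred - target) ** 2)

# Toy check: a perfect oracle predictor drives the loss to zero.
rng = np.random.default_rng(0)
action = rng.standard_normal((8, 3))   # a batch of action chunks
noise = rng.standard_normal((8, 3))
t = rng.uniform(size=(8,))
oracle = lambda x_t, t: noise - action
print(flow_matching_loss(action, noise, t, oracle))  # 0.0
```

At inference, `FlowPolicy.action`'s `num_steps` argument then plausibly controls how many Euler integration steps are taken along the learned velocity field from noise back to an action chunk.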
Condensed preview — 23 files, each showing path, character count, and a content snippet.
[
{
"path": ".gitignore",
"chars": 4169,
"preview": "# Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,vim\n# Edit at https://www.toptal.co"
},
{
"path": ".gitmodules",
"chars": 108,
"preview": "[submodule \"third_party/kinetix\"]\n\tpath = third_party/kinetix\n\turl = https://github.com/FlairOx/Kinetix.git\n"
},
{
"path": "LICENSE",
"chars": 1078,
"preview": "MIT License\n\nCopyright (c) 2025 Physical Intelligence\n\nPermission is hereby granted, free of charge, to any person obtai"
},
{
"path": "README.md",
"chars": 3139,
"preview": "Simulated experiments for the papers [Real-Time Execution of Action Chunking Flow Policies](https://arxiv.org/abs/2506.0"
},
{
"path": "pyproject.toml",
"chars": 722,
"preview": "[build-system]\nrequires = [\"setuptools>=64.0\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"real-time-chun"
},
{
"path": "src/eval_flow.py",
"chars": 13462,
"preview": "import collections\nimport dataclasses\nimport functools\nimport math\nimport pathlib\nimport pickle\nfrom typing import Seque"
},
{
"path": "src/generate_data.py",
"chars": 9372,
"preview": "import dataclasses\nimport functools\nimport json\nimport pathlib\nimport pickle\nfrom typing import Sequence\n\nimport einops\n"
},
{
"path": "src/model.py",
"chars": 12883,
"preview": "import dataclasses\nimport functools\nfrom typing import Literal, TypeAlias, Self\n\nimport einops\nimport flax.nnx as nnx\nim"
},
{
"path": "src/render_levels.py",
"chars": 2782,
"preview": "import pathlib\nimport jax\nimport jax.numpy as jnp\nimport kinetix.environment.env as kenv\nimport kinetix.environment.env_"
},
{
"path": "src/train_expert.py",
"chars": 22682,
"preview": "import dataclasses\nimport functools\nimport json\nimport pathlib\nimport pickle\nfrom typing import Sequence\n\nfrom flax impo"
},
{
"path": "src/train_flow.py",
"chars": 9471,
"preview": "import concurrent.futures\nimport dataclasses\nimport functools\nimport pathlib\nimport pickle\nfrom typing import Sequence\n\n"
},
{
"path": "worlds/l/car_launch.json",
"chars": 113136,
"preview": "{\n \"env_state\": {\n \"polygon\": [\n {\n \"density\": 1,\n \"velocity\": {\n "
},
{
"path": "worlds/l/cartpole_thrust.json",
"chars": 41660,
"preview": "{\"env_state\":{\"polygon\":[{\"active\":true,\"collision_mode\":2,\"inverse_mass\":0,\"friction\":1,\"rotation\":0,\"position\":{\"0\":2."
},
{
"path": "worlds/l/catapult.json",
"chars": 41535,
"preview": "{\"env_state\":{\"polygon\":[{\"density\":1,\"inverse_mass\":0,\"n_vertices\":4,\"position\":{\"0\":2.5,\"1\":-4.800000190734863},\"radiu"
},
{
"path": "worlds/l/catcher_v3.json",
"chars": 112795,
"preview": "{\n \"env_state\": {\n \"polygon\": [\n {\n \"position\": {\n \"0\": 2.5,\n "
},
{
"path": "worlds/l/chain_lander.json",
"chars": 116884,
"preview": "{\n \"env_state\": {\n \"polygon\": [\n {\n \"position\": {\n \"0\": 2.5,\n "
},
{
"path": "worlds/l/grasp_easy.json",
"chars": 52287,
"preview": "{\"env_state\": {\"polygon\": [{\"position\": {\"0\": 2.5, \"1\": -4.800000190734863}, \"rotation\": 0.0, \"velocity\": {\"0\": 0.0, \"1\""
},
{
"path": "worlds/l/h17_unicycle.json",
"chars": 50292,
"preview": "{\"env_state\": {\"polygon\": [{\"position\": {\"0\": 2.5, \"1\": -4.800000190734863}, \"rotation\": 0.0, \"velocity\": {\"0\": 0.0, \"1\""
},
{
"path": "worlds/l/hard_lunar_lander.json",
"chars": 51010,
"preview": "{\"env_state\": {\"polygon\": [{\"position\": {\"0\": 2.5, \"1\": -4.800000190734863}, \"rotation\": 0.0, \"velocity\": {\"0\": 0.0, \"1\""
},
{
"path": "worlds/l/mjc_half_cheetah.json",
"chars": 51978,
"preview": "{\"env_state\": {\"polygon\": [{\"position\": {\"0\": 2.5, \"1\": -4.800000190734863}, \"rotation\": 0.0, \"velocity\": {\"0\": 0.0, \"1\""
},
{
"path": "worlds/l/mjc_swimmer.json",
"chars": 51742,
"preview": "{\"env_state\": {\"polygon\": [{\"position\": {\"0\": 2.5, \"1\": -4.800000190734863}, \"rotation\": 0.0, \"velocity\": {\"0\": 0.0, \"1\""
},
{
"path": "worlds/l/mjc_walker.json",
"chars": 51792,
"preview": "{\"env_state\": {\"polygon\": [{\"position\": {\"0\": 2.5, \"1\": -4.800000190734863}, \"rotation\": 0.0, \"velocity\": {\"0\": 0.0, \"1\""
},
{
"path": "worlds/l/trampoline.json",
"chars": 111763,
"preview": "{\n \"env_state\": {\n \"polygon\": [\n {\n \"position\": {\n \"0\": 2.5,\n "
}
]
About this extraction
This page contains the full source code of the Physical-Intelligence/real-time-chunking-kinetix GitHub repository, extracted and formatted as plain text. The extraction includes 23 files (905.0 KB, approximately 281.6k tokens) and a symbol index with 75 extracted functions, classes, methods, constants, and types.
Extracted by GitExtract, by Nikandr Surkov.