Repository: brentyi/egoallo
Branch: main
Commit: 0c20ae8c8d48
Files: 56
Total size: 387.7 KB
Directory structure:
gitextract_96b1_mr0/
├── .github/
│ └── workflows/
│ └── pyright.yml
├── .gitignore
├── 0a_preprocess_training_data.py
├── 0b_preprocess_training_data.py
├── 1_train_motion_prior.py
├── 2_run_hamer_on_vrs.py
├── 3_aria_inference.py
├── 4_visualize_outputs.py
├── 5_eval_body_metrics.py
├── LICENSE
├── README.md
├── download_checkpoint_and_data.sh
├── pyproject.toml
└── src/
└── egoallo/
├── __init__.py
├── fncsmpl.py
├── fncsmpl_extensions.py
├── fncsmpl_jax.py
├── guidance_optimizer_jax.py
├── hand_detection_structs.py
├── inference_utils.py
├── metrics_helpers.py
├── network.py
├── preprocessing/
│ ├── __init__.py
│ ├── body_model/
│ │ ├── __init__.py
│ │ ├── body_model.py
│ │ ├── skeleton.py
│ │ ├── specs.py
│ │ └── utils.py
│ ├── geometry/
│ │ ├── __init__.py
│ │ ├── camera.py
│ │ ├── helpers.py
│ │ ├── plane.py
│ │ ├── rotation.py
│ │ └── transforms/
│ │ ├── __init__.py
│ │ ├── _base.py
│ │ ├── _se2.py
│ │ ├── _se3.py
│ │ ├── _so2.py
│ │ ├── _so3.py
│ │ ├── hints/
│ │ │ └── __init__.py
│ │ └── utils/
│ │ ├── __init__.py
│ │ └── _utils.py
│ └── util/
│ ├── __init__.py
│ └── tensor.py
├── py.typed
├── sampling.py
├── tensor_dataclass.py
├── training_loss.py
├── training_utils.py
├── transforms/
│ ├── __init__.py
│ ├── _base.py
│ ├── _se3.py
│ ├── _so3.py
│ └── utils/
│ ├── __init__.py
│ └── _utils.py
└── vis_helpers.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/workflows/pyright.yml
================================================
name: pyright
on:
push:
branches: [main]
pull_request:
branches: [main]
jobs:
pyright:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.12"]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install uv
uv pip install --system -e .
uv pip install --system jax
uv pip install --system git+https://github.com/brentyi/jaxls.git
uv pip install --system git+https://github.com/brentyi/hamer_helper.git
uv pip install --system pyright
- name: Run pyright
run: |
pyright .
================================================
FILE: .gitignore
================================================
*.swp
*.swo
*.pyc
*.egg-info
*.ipynb_checkpoints
__pycache__
.coverage
htmlcov
.mypy_cache
.dmypy.json
.hypothesis
.envrc
.lvimrc
.DS_Store
.envrc
lightning_logs/
outputs/
data/
egoallo_checkpoint_*
egoallo_example_*
================================================
FILE: 0a_preprocess_training_data.py
================================================
"""Convert raw AMASS data to HuMoR-style npz format.
Mostly taken from
https://github.com/davrempe/humor/blob/main/humor/scripts/process_amass_data.py,
with gender-neutral beta conversion and other utilities added.
"""
import dataclasses
import os
import time
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from typing import Dict, Tuple
import matplotlib.pyplot as plt
import numpy as np
import torch
import tyro
from loguru import logger as guru
from sklearn.cluster import DBSCAN
from tqdm import tqdm
from egoallo.preprocessing.body_model import (
KEYPT_VERTS,
SMPL_JOINTS,
BodyModel,
reflect_pose_aa,
reflect_root_trajectory,
run_smpl,
)
from egoallo.preprocessing.geometry import convert_rotation, joints_global_to_local
from egoallo.preprocessing.util import move_to
AMASS_SPLITS = {
"train": [
"ACCAD",
"BMLhandball",
"BMLmovi",
"BioMotionLab_NTroje",
"CMU",
"DFaust_67",
"DanceDB",
"EKUT",
"Eyes_Japan_Dataset",
"KIT",
"MPI_Limits",
"TCD_handMocap",
"TotalCapture",
],
"val": [
"HumanEva",
"MPI_HDM05",
"SFU",
"MPI_mosh",
],
"test": [
"Transitions_mocap",
"SSM_synced",
],
}
AMASS_SPLITS["all"] = AMASS_SPLITS["train"] + AMASS_SPLITS["val"] + AMASS_SPLITS["test"]
def load_neutral_beta_conversion(gender: str) -> Tuple[np.ndarray, np.ndarray]:
assert gender in ["female", "male"]
data = np.load(f"./data/smplh_gender_conversion/{gender}_to_neutral.npz")
return data["A"], data["b"]
def convert_gender_neutral_beta(
beta: np.ndarray, A: np.ndarray, b: np.ndarray
) -> np.ndarray:
"""
:param beta (*, B)
:param A (B, B)
:param b (B)
beta_neutral = A @ beta_gender + b
"""
*dims, B = beta.shape
A = A.reshape((*(1,) * len(dims), B, B))
b = b.reshape((*(1,) * len(dims), B))
return np.einsum("...ij,...j->...i", A, beta) + b
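# Shape example (illustrative, matching the 16-beta setup used elsewhere in this
# script): for an AMASS sequence with N frames, `beta` is (N, 16), `A` is
# (16, 16), and `b` is (16,); the reshapes above let A and b broadcast over the
# leading N dimension.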
def determine_floor_height_and_contacts(
body_joint_seq,
fps,
vis=False,
floor_vel_thresh=0.005,
floor_height_offset=0.01,
contact_vel_thresh=0.005, # 0.015
contact_toe_height_thresh=0.04, # if static toe above this height
contact_ankle_height_thresh=0.08,
terrain_height_thresh=0.04,
root_height_thresh=0.04,
cluster_size_thresh=0.25,
discard_terrain_seqs=False, # throw away sequences where the person steps onto objects (determined by a heuristic)
):
"""
Taken from
https://github.com/davrempe/humor/blob/main/humor/scripts/process_amass_data.py
Input: body_joint_seq, an N x J x 3 numpy array (J = len(SMPL_JOINTS))
Returns contacts as an N x J array, where N is the number of frames; only the heel/toe, hand, and knee joint columns are populated
"""
num_frames = body_joint_seq.shape[0]
# compute toe velocities
root_seq = body_joint_seq[:, SMPL_JOINTS["hips"], :]
left_toe_seq = body_joint_seq[:, SMPL_JOINTS["leftToeBase"], :]
right_toe_seq = body_joint_seq[:, SMPL_JOINTS["rightToeBase"], :]
left_toe_vel = np.linalg.norm(left_toe_seq[1:] - left_toe_seq[:-1], axis=1)
left_toe_vel = np.append(left_toe_vel, left_toe_vel[-1])
right_toe_vel = np.linalg.norm(right_toe_seq[1:] - right_toe_seq[:-1], axis=1)
right_toe_vel = np.append(right_toe_vel, right_toe_vel[-1])
if vis:
plt.figure()
steps = np.arange(num_frames)
plt.plot(steps, left_toe_vel, "-r", label="left vel")
plt.plot(steps, right_toe_vel, "-b", label="right vel")
plt.legend()
plt.show()
plt.close()
# now foot heights (z is up)
left_toe_heights = left_toe_seq[:, 2]
right_toe_heights = right_toe_seq[:, 2]
root_heights = root_seq[:, 2]
if vis:
plt.figure()
steps = np.arange(num_frames)
plt.plot(steps, left_toe_heights, "-r", label="left toe height")
plt.plot(steps, right_toe_heights, "-b", label="right toe height")
plt.plot(steps, root_heights, "-g", label="root height")
plt.legend()
plt.show()
plt.close()
# filter out heights when velocity is greater than some threshold (not in contact)
all_inds = np.arange(left_toe_heights.shape[0])
left_static_foot_heights = left_toe_heights[left_toe_vel < floor_vel_thresh]
left_static_inds = all_inds[left_toe_vel < floor_vel_thresh]
right_static_foot_heights = right_toe_heights[right_toe_vel < floor_vel_thresh]
right_static_inds = all_inds[right_toe_vel < floor_vel_thresh]
all_static_foot_heights = np.append(
left_static_foot_heights, right_static_foot_heights
)
all_static_inds = np.append(left_static_inds, right_static_inds)
if vis:
plt.figure()
steps = np.arange(left_static_foot_heights.shape[0])
plt.plot(steps, left_static_foot_heights, "-r", label="left static height")
plt.legend()
plt.show()
plt.close()
discard_seq = False
if all_static_foot_heights.shape[0] > 0:
cluster_heights = []
cluster_root_heights = []
cluster_sizes = []
# cluster foot heights and find one with smallest median
clustering = DBSCAN(eps=0.005, min_samples=3).fit(
all_static_foot_heights.reshape(-1, 1)
)
all_labels = np.unique(clustering.labels_)
# print(all_labels)
if vis:
plt.figure()
min_median = min_root_median = float("inf")
for cur_label in all_labels:
cur_clust = all_static_foot_heights[clustering.labels_ == cur_label]
cur_clust_inds = np.unique(
all_static_inds[clustering.labels_ == cur_label]
) # inds in the original sequence that correspond to this cluster
if vis:
plt.scatter(
cur_clust, np.zeros_like(cur_clust), label="foot %d" % (cur_label)
)
# get median foot height and use this as height
cur_median = np.median(cur_clust)
cluster_heights.append(cur_median)
cluster_sizes.append(cur_clust.shape[0])
# get root information
cur_root_clust = root_heights[cur_clust_inds]
cur_root_median = np.median(cur_root_clust)
cluster_root_heights.append(cur_root_median)
if vis:
plt.scatter(
cur_root_clust,
np.zeros_like(cur_root_clust),
label="root %d" % (cur_label),
)
# update min info
if cur_median < min_median:
min_median = cur_median
min_root_median = cur_root_median
# print(cluster_heights)
# print(cluster_root_heights)
# print(cluster_sizes)
if vis:
plt.show()
plt.close()
floor_height = min_median
offset_floor_height = (
floor_height - floor_height_offset
) # toe joint is actually inside foot mesh a bit
if discard_terrain_seqs:
# print(min_median + TERRAIN_HEIGHT_THRESH)
# print(min_root_median + ROOT_HEIGHT_THRESH)
for cluster_root_height, cluster_height, cluster_size in zip(
cluster_root_heights, cluster_heights, cluster_sizes
):
root_above_thresh = cluster_root_height > (
min_root_median + root_height_thresh
)
toe_above_thresh = cluster_height > (min_median + terrain_height_thresh)
cluster_size_above_thresh = cluster_size > int(
cluster_size_thresh * fps
)
if root_above_thresh and toe_above_thresh and cluster_size_above_thresh:
discard_seq = True
print("DISCARDING sequence based on terrain interaction!")
break
else:
floor_height = offset_floor_height = 0.0
# now find contacts (feet are below certain velocity and within certain range of floor)
# compute heel velocities
left_heel_seq = body_joint_seq[:, SMPL_JOINTS["leftFoot"], :]
right_heel_seq = body_joint_seq[:, SMPL_JOINTS["rightFoot"], :]
left_heel_vel = np.linalg.norm(left_heel_seq[1:] - left_heel_seq[:-1], axis=1)
left_heel_vel = np.append(left_heel_vel, left_heel_vel[-1])
right_heel_vel = np.linalg.norm(right_heel_seq[1:] - right_heel_seq[:-1], axis=1)
right_heel_vel = np.append(right_heel_vel, right_heel_vel[-1])
left_heel_contact = left_heel_vel < contact_vel_thresh
right_heel_contact = right_heel_vel < contact_vel_thresh
left_toe_contact = left_toe_vel < contact_vel_thresh
right_toe_contact = right_toe_vel < contact_vel_thresh
# compute heel heights
left_heel_heights = left_heel_seq[:, 2] - floor_height
right_heel_heights = right_heel_seq[:, 2] - floor_height
left_toe_heights = left_toe_heights - floor_height
right_toe_heights = right_toe_heights - floor_height
left_heel_contact = np.logical_and(
left_heel_contact, left_heel_heights < contact_ankle_height_thresh
)
right_heel_contact = np.logical_and(
right_heel_contact, right_heel_heights < contact_ankle_height_thresh
)
left_toe_contact = np.logical_and(
left_toe_contact, left_toe_heights < contact_toe_height_thresh
)
right_toe_contact = np.logical_and(
right_toe_contact, right_toe_heights < contact_toe_height_thresh
)
contacts = np.zeros((num_frames, len(SMPL_JOINTS)))
contacts[:, SMPL_JOINTS["leftFoot"]] = left_heel_contact
contacts[:, SMPL_JOINTS["leftToeBase"]] = left_toe_contact
contacts[:, SMPL_JOINTS["rightFoot"]] = right_heel_contact
contacts[:, SMPL_JOINTS["rightToeBase"]] = right_toe_contact
# hand contacts
left_hand_contact = detect_joint_contact(
body_joint_seq,
"leftHand",
floor_height,
contact_vel_thresh,
contact_ankle_height_thresh,
)
right_hand_contact = detect_joint_contact(
body_joint_seq,
"rightHand",
floor_height,
contact_vel_thresh,
contact_ankle_height_thresh,
)
contacts[:, SMPL_JOINTS["leftHand"]] = left_hand_contact
contacts[:, SMPL_JOINTS["rightHand"]] = right_hand_contact
# knee contacts
left_knee_contact = detect_joint_contact(
body_joint_seq,
"leftLeg",
floor_height,
contact_vel_thresh,
contact_ankle_height_thresh,
)
right_knee_contact = detect_joint_contact(
body_joint_seq,
"rightLeg",
floor_height,
contact_vel_thresh,
contact_ankle_height_thresh,
)
contacts[:, SMPL_JOINTS["leftLeg"]] = left_knee_contact
contacts[:, SMPL_JOINTS["rightLeg"]] = right_knee_contact
return offset_floor_height, contacts, discard_seq
def detect_joint_contact(
body_joint_seq, joint_name, floor_height, vel_thresh, height_thresh
):
"""
Taken from
https://github.com/davrempe/humor/blob/main/humor/scripts/process_amass_data.py
"""
# calc velocity
joint_seq = body_joint_seq[:, SMPL_JOINTS[joint_name], :]
joint_vel = np.linalg.norm(joint_seq[1:] - joint_seq[:-1], axis=1)
joint_vel = np.append(joint_vel, joint_vel[-1])
# determine contact by velocity
joint_contact = joint_vel < vel_thresh
# compute heights
joint_heights = joint_seq[:, 2] - floor_height
# compute contact by vel + height
joint_contact = np.logical_and(joint_contact, joint_heights < height_thresh)
return joint_contact
def compute_root_align_mats(root_orient):
"""
Taken from
https://github.com/davrempe/humor/blob/main/humor/scripts/process_amass_data.py
compute world to canonical frame for each timestep (rotation around up axis)
"""
root_orient = torch.as_tensor(root_orient).reshape(-1, 3)
# convert aa to matrices
root_orient_mat = convert_rotation(root_orient, "aa", "mat").numpy()
# rotate the root so the local body right vector (-x) aligns with the world
# right vector (+x), using a rotation around the up axis (+z)
# (in body coordinates, the x-axis points to the left)
body_right = -root_orient_mat[:, :, 0]
world2aligned_mat, world2aligned_aa = compute_align_from_body_right(body_right)
return world2aligned_mat
def compute_joint_align_mats(joint_seq):
"""
Taken from
https://github.com/davrempe/humor/blob/main/humor/scripts/process_amass_data.py
Compute world to canonical frame for each timestep (rotation around up axis)
from the given joint seq (T x J x 3)
"""
left_idx = SMPL_JOINTS["leftUpLeg"]
right_idx = SMPL_JOINTS["rightUpLeg"]
body_right = joint_seq[:, right_idx] - joint_seq[:, left_idx]
body_right = body_right / np.linalg.norm(body_right, axis=1)[:, np.newaxis]
world2aligned_mat, world2aligned_aa = compute_align_from_body_right(body_right)
return world2aligned_mat
def compute_align_from_body_right(body_right):
"""
Taken from
https://github.com/davrempe/humor/blob/main/humor/scripts/process_amass_data.py
"""
world2aligned_angle = np.arccos(
body_right[:, 0] / (np.linalg.norm(body_right[:, :2], axis=1) + 1e-8)
) # project to world x axis, and compute angle
body_right[:, 2] = 0.0
world2aligned_axis = np.cross(body_right, np.array([[1.0, 0.0, 0.0]]))
world2aligned_aa = (
world2aligned_axis
/ (np.linalg.norm(world2aligned_axis, axis=1)[:, np.newaxis] + 1e-8)
) * world2aligned_angle[:, np.newaxis]
world2aligned_mat = convert_rotation(
torch.as_tensor(world2aligned_aa).reshape(-1, 3), "aa", "mat"
).numpy()
return world2aligned_mat, world2aligned_aa
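# Worked example (hypothetical input): if body_right = [0, 1, 0], the angle is
# arccos(0) = pi/2 and the axis is cross([0, 1, 0], [1, 0, 0]) = [0, 0, -1], so
# the returned axis-angle is a -90 degree rotation about +z, which maps the
# body right direction onto world +x as intended.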
def estimate_velocity(data_seq, h):
"""
Taken from
https://github.com/davrempe/humor/blob/main/humor/scripts/process_amass_data.py
Given some data sequence of T timesteps in the shape (T, ...), estimates
the velocity for the middle T-2 steps using a second order central difference scheme.
- h : step size
"""
data_tp1 = data_seq[2:]
data_tm1 = data_seq[0:-2]
data_vel_seq = (data_tp1 - data_tm1) / (2 * h)
return data_vel_seq
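# Example: for a position sequence sampled at fps frames per second (h = 1 / fps),
# vel[t] = (x[t + 1] - x[t - 1]) / (2 * h); the first and last frames are dropped,
# so the output has T - 2 timesteps.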
def estimate_angular_velocity(rot_seq, h):
"""
Taken from
https://github.com/davrempe/humor/blob/main/humor/scripts/process_amass_data.py
Given a sequence of T rotation matrices, estimates angular velocity at T-2 steps.
Input sequence should be of shape (T, ..., 3, 3)
"""
# see https://en.wikipedia.org/wiki/Angular_velocity#Calculation_from_the_orientation_matrix
dRdt = estimate_velocity(rot_seq, h)
R = rot_seq[1:-1]
RT = np.swapaxes(R, -1, -2)
# compute skew-symmetric angular velocity tensor
w_mat = np.matmul(dRdt, RT)
# pull out angular velocity vector
# average symmetric entries
w_x = (-w_mat[..., 1, 2] + w_mat[..., 2, 1]) / 2.0
w_y = (w_mat[..., 0, 2] - w_mat[..., 2, 0]) / 2.0
w_z = (-w_mat[..., 0, 1] + w_mat[..., 1, 0]) / 2.0
w = np.stack([w_x, w_y, w_z], axis=-1)
return w
def load_seq_smpl_params(input_path: str, num_betas: int = 16):
guru.info(f"Loading from {input_path}")
# load in input data
# we leave out "dmpls" and "marker_data"/"marker_label" which are not present in all datasets
bdata = np.load(input_path)
gender = np.array(bdata["gender"], ndmin=1)[0]
gender = str(gender, "utf-8") if isinstance(gender, bytes) else str(gender)
fps = bdata["mocap_framerate"]
trans = bdata["trans"][:] # global translation
num_frames = len(trans)
root_orient = bdata["poses"][:, :3] # global root orientation (1 joint)
pose_body = bdata["poses"][:, 3:66] # body joint rotations (21 joints)
pose_hand = bdata["poses"][:, 66:] # finger articulation joint rotations
betas = np.tile(
bdata["betas"][None, :num_betas], [num_frames, 1]
) # body shape parameters
# correct mislabeled data
if input_path.find("BMLhandball") >= 0:
fps = 240
if input_path.find("20160930_50032") >= 0 or input_path.find("20161014_50033") >= 0:
fps = 59
model_vars = {
"trans": trans,
"root_orient": root_orient,
"pose_body": pose_body,
"pose_hand": pose_hand,
"betas": betas,
}
meta = {"fps": fps, "gender": gender, "num_frames": num_frames}
guru.info(f"meta {meta}")
guru.info(f"model var shapes {str({k: v.shape for k, v in model_vars.items()})}")
return model_vars, meta
def run_batch_smpl(
body_model: BodyModel,
device: torch.device,
num_total: int,
batch_size: int,
return_verts: bool = True,
**kwargs,
):
var_dims = body_model.var_dims
var_names = [name for name in kwargs if name in var_dims]
model_vars = {
name: torch.as_tensor(kwargs[name], dtype=torch.float32).reshape(
-1, var_dims[name]
)
for name in var_names
}
fopts = {k: v for k, v in kwargs.items() if k not in var_names}
batch_joints, batch_verts = [], []
for sidx in range(0, num_total, batch_size):
eidx = min(sidx + batch_size, num_total)
batch_model_vars = move_to(
{name: x[sidx:eidx].contiguous() for name, x in model_vars.items()}, device
)
with torch.no_grad():
joints, verts, _ = run_smpl(
body_model, return_verts=return_verts, **batch_model_vars, **fopts
)
batch_joints.append(joints.detach().cpu())
if return_verts and verts is not None:
batch_verts.append(verts.detach().cpu())
joints_all = torch.cat(batch_joints, dim=0)
verts_all = torch.cat(batch_verts, dim=0) if len(batch_verts) > 0 else None
return joints_all, verts_all
def process_seq(
input_path: str,
out_path: str,
smplh_root: str,
dev_id: int,
beta_neutral: bool,
reflect: bool = False,
overwrite: bool = False,
**kwargs,
):
if not overwrite and os.path.isfile(out_path):
guru.info(f"{out_path} already exists, skipping.")
return
guru.info(f"process {input_path} to {out_path}")
model_vars, meta = load_seq_smpl_params(input_path)
if beta_neutral: # get the gender neutral beta
guru.info("converting betas to gender neutral")
A_beta, b_beta = load_neutral_beta_conversion(meta["gender"])
model_vars["betas"] = convert_gender_neutral_beta(
model_vars["betas"], A_beta, b_beta
)
meta["gender"] = "neutral"
process_seq_data(
model_vars, meta, out_path, dev_id, smplh_root, reflect=reflect, **kwargs
)
def process_seq_data(
model_vars: Dict,
meta: Dict,
out_path: str,
dev_id: int,
smplh_root: str,
reflect: bool = False,
split_frame_limit: int = 2000,
discard_shorter_than: float = 1.0, # seconds
out_fps: int = 30,
save_verts: bool = False,
save_velocities: bool = True, # save all parameter velocities available
):
guru.info(f"Processing seq with meta {meta}")
start_t = time.time()
gender = meta["gender"]
src_fps = meta["fps"]
num_frames = meta["num_frames"]
# only keep the middle 80% of each sequence to avoid redundant static poses
sidx, eidx = int(0.1 * num_frames), int(0.9 * num_frames)
num_frames = eidx - sidx
for name, x in model_vars.items():
model_vars[name] = x[sidx:eidx]
guru.info(str({k: v.shape for k, v in model_vars.items()}))
# discard if shorter than threshold
if num_frames < discard_shorter_than * src_fps:
guru.info(f"Sequence shorter than {discard_shorter_than} s, discarding...")
return
# must do SMPL forward pass to get joints
# split into manageable chunks to avoid running out of GPU memory for SMPL
device = (
torch.device(f"cuda:{dev_id}")
if torch.cuda.is_available()
else torch.device("cpu")
)
# <HACKS>
# smplx tries to read shape properties, even when use_pca=False
from smplx.utils import Struct
Struct.hands_componentsl = np.zeros(100) # type: ignore
Struct.hands_componentsr = np.zeros(100) # type: ignore
Struct.hands_meanl = np.zeros(100) # type: ignore
Struct.hands_meanr = np.zeros(100) # type: ignore
# This defaults to 300, but we have 16 beta parameters. When
# 16<300 the SMPL class will set num_betas to 10...
from smplx import SMPLH
assert SMPLH.SHAPE_SPACE_DIM in (300, 16)
SMPLH.SHAPE_SPACE_DIM = 16
# </HACKS>
body_model = BodyModel(f"{smplh_root}/{gender}/model.npz", use_pca=False).to(device)
model_vars = {k: torch.as_tensor(v).float() for k, v in model_vars.items()}
if reflect:
rot_og = model_vars["root_orient"]
rot_re, model_vars["pose_body"] = reflect_pose_aa(
rot_og, model_vars["pose_body"]
)
out = body_model.forward(betas=model_vars["betas"][:1].to(device))
root_loc = out.Jtr[:, 0].cpu() # type: ignore
model_vars["root_orient"], model_vars["trans"] = reflect_root_trajectory(
rot_og, model_vars["trans"], rot_re, root_loc
)
body_joint_seq, body_vtx_seq = run_batch_smpl(
body_model,
device,
num_frames,
split_frame_limit,
return_verts=save_verts,
**model_vars,
)
joints_glob = body_joint_seq[:, : len(SMPL_JOINTS), :]
joint_seq = joints_glob.numpy()
guru.info(f"Recovered joints and verts {joint_seq.shape}")
out_dict = model_vars.copy()
out_dict["joints"] = joint_seq
out_dict["joints_loc"], _ = joints_global_to_local(
convert_rotation(model_vars["root_orient"], "aa", "mat"),
model_vars["trans"],
joints_glob,
)
if save_verts and body_vtx_seq is not None:
out_dict["mojo_verts"] = body_vtx_seq[:, KEYPT_VERTS, :].numpy()
# determine floor height and foot contacts
floor_height, contacts, discard_seq = determine_floor_height_and_contacts(
joint_seq, src_fps
)
if discard_seq:
guru.info("Terrain interaction detected, discarding...")
return
guru.info(f"Floor height: {floor_height}")
# translate so floor is at z=0
for name in ["trans", "joints", "mojo_verts"]:
if name not in out_dict:
continue
out_dict[name][..., 2] -= floor_height
# compute rotation to canonical frame (forward facing +y) for every frame
world2aligned_rot = compute_root_align_mats(model_vars["root_orient"])
out_dict.update(
{
"contacts": contacts,
"floor_height": floor_height,
"world2aligned_rot": world2aligned_rot,
}
)
# estimate various velocities based on full frame rate
# with second order central differences before downsampling
if save_velocities:
h = 1.0 / src_fps
lin_names = ["trans", "joints", "mojo_verts"]
ang_names = ["root_orient", "pose_body"]
cur_keys = lin_names + ang_names + ["contacts"]
for name in lin_names:
if name not in out_dict:
continue
out_dict[f"{name}_vel"] = estimate_velocity(out_dict[name], h)
# root orient
for name in ang_names:
if name not in out_dict:
continue
rot_aa = (
torch.as_tensor(out_dict[name]).reshape(num_frames, -1, 3).squeeze()
)
rot_mat = convert_rotation(rot_aa, "aa", "mat").numpy()
out_dict[f"{name}_vel"] = estimate_angular_velocity(rot_mat, h)
# joint up-axis angular velocity (need to compute joint frames first...)
# need the joint transform at all steps to find the angular velocity
joints_world2aligned_rot = compute_joint_align_mats(joint_seq)
joint_orient_vel = -estimate_angular_velocity(joints_world2aligned_rot, h)
# only need around z
out_dict["joint_orient_vel"] = joint_orient_vel[:, 2]
# throw out edge frames for other data so velocities are accurate
for name in cur_keys:
if name not in out_dict:
continue
out_dict[name] = out_dict[name][1:-1]
num_frames = num_frames - 2
# downsample frames
fps_ratio = float(out_fps) / src_fps
guru.info(f"Downsamp ratio: {fps_ratio}")
new_num_frames = int(fps_ratio * num_frames)
guru.info(f"Downsamp num frames: {new_num_frames}")
downsamp_inds = np.linspace(0, num_frames - 1, num=new_num_frames, dtype=int)
for k, v in out_dict.items():
# print(k, type(v))
if not isinstance(v, (torch.Tensor, np.ndarray)):
continue
if v.ndim >= 1:
# print("downsampling", k)
out_dict[k] = v[downsamp_inds]
meta = {
"fps": out_fps,
"num_frames": new_num_frames,
"gender": str(gender),
}
guru.info(f"Seq process time: {time.time() - start_t} s")
guru.info(f"Saving data to {out_path}")
os.makedirs(os.path.dirname(out_path), exist_ok=True)
np.savez(out_path, **meta, **out_dict)
@dataclasses.dataclass
class Config:
data_root: str
"""Where the AMASS dataset is stored."""
smplh_root: str = "./data/smplh"
out_root: str = "./data/processed_30fps_no_skating/"
devices: tuple[int, ...] = (0,)
"""CUDA devices. We use CPU if not available."""
overwrite: bool = False
def check_skip(path_name: str) -> bool:
"""Copied conditions from https://github.com/davrempe/humor/blob/main/humor/scripts/cleanup_amass_data.py"""
if "BioMotionLab_NTroje" in path_name and (
"treadmill" in path_name or "normal_" in path_name
):
return True
if "MPI_HDM05" in path_name and "dg/HDM_dg_07-01" in path_name:
return True
return False
def main(cfg: Config):
dsets = AMASS_SPLITS["all"]
paths_to_process = []
for dset in dsets:
paths_to_process.extend(
map(str, Path(f"{cfg.data_root}/{dset}").glob("**/*_poses.npz"))
)
dev_ids = cfg.devices
guru.info(f"devices {dev_ids}")
if len(dev_ids) <= 1:
guru.info("processing in sequence")
for i, path in tqdm(enumerate(paths_to_process)):
if check_skip(path):
guru.info(f"skipping {path}")
continue
fname = path.split(cfg.data_root)[-1].rstrip("/")
name, ext = os.path.splitext(fname)
out_path = f"{cfg.out_root}/neutral/{name}{ext}"
r_out_path = f"{cfg.out_root}/neutral/{name}_reflect{ext}"
process_seq(
path,
out_path,
cfg.smplh_root,
dev_ids[i % len(dev_ids)],
beta_neutral=True,
reflect=False,
overwrite=cfg.overwrite,
)
process_seq(
path,
r_out_path,
cfg.smplh_root,
dev_ids[i % len(dev_ids)],
beta_neutral=True,
reflect=True,
overwrite=cfg.overwrite,
)
return
with ProcessPoolExecutor(max_workers=len(dev_ids)) as exe:
for i, path in tqdm(enumerate(paths_to_process)):
if check_skip(path):
guru.info(f"skipping {path}")
continue
fname = path.split(cfg.data_root)[-1].rstrip("/")
name, ext = os.path.splitext(fname)
out_path = f"{cfg.out_root}/neutral/{name}{ext}"
r_out_path = f"{cfg.out_root}/neutral/{name}_reflect{ext}"
exe.submit(
process_seq,
path,
out_path,
cfg.smplh_root,
dev_ids[i % len(dev_ids)],
beta_neutral=True,
reflect=False,
overwrite=cfg.overwrite,
)
exe.submit(
process_seq,
path,
r_out_path,
cfg.smplh_root,
dev_ids[i % len(dev_ids)],
beta_neutral=True,
reflect=True,
overwrite=cfg.overwrite,
)
if __name__ == "__main__":
main(tyro.cli(Config))
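# Example invocation (paths are illustrative; flag names assume tyro's default
# kebab-case conversion of the Config fields):
#
#   python 0a_preprocess_training_data.py \
#       --data-root /path/to/amass_raw \
#       --smplh-root ./data/smplh \
#       --out-root ./data/processed_30fps_no_skating/ \
#       --devices 0 1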
================================================
FILE: 0b_preprocess_training_data.py
================================================
"""Translate data from HuMoR-style npz format to an hdf5-based one.
Due to AMASS licensing, we unfortunately can't re-distribute our preprocessed dataset. If you have questions
or run into issues, please reach out.
"""
import queue
import threading
import time
from pathlib import Path
import h5py
import torch
import torch.cuda
import tyro
from egoallo import fncsmpl
from egoallo.data.amass import EgoTrainingData
def main(
smplh_npz_path: Path = Path("./data/smplh/neutral/model.npz"),
data_npz_dir: Path = Path("./data/processed_30fps_no_skating/"),
output_file: Path = Path("./data/egoalgo_no_skating_dataset.hdf5"),
output_list_file: Path = Path("./data/egoalgo_no_skating_dataset_files.txt"),
include_hands: bool = True,
) -> None:
body_model = fncsmpl.SmplhModel.load(smplh_npz_path)
assert torch.cuda.is_available()
task_queue = queue.Queue[Path]()
for path in list(data_npz_dir.glob("**/*.npz")):
task_queue.put_nowait(path)
total_count = task_queue.qsize()
start_time = time.time()
output_hdf5 = h5py.File(output_file, "w")
file_list: list[str] = []
def worker(device_idx: int) -> None:
device_body_model = body_model.to("cuda:" + str(device_idx))
while True:
try:
npz_path = task_queue.get_nowait()
except queue.Empty:
break
print(f"Processing {npz_path} on device {device_idx}...")
train_data = EgoTrainingData.load_from_npz(
device_body_model, npz_path, include_hands=include_hands
)
assert "neutral" in str(npz_path)
group_name = str(npz_path).rpartition("neutral/")[2]
print(f"Writing to group {group_name} on {device_idx}...")
group = output_hdf5.create_group(group_name)
file_list.append(group_name)
for k, v in vars(train_data).items():
# No need to write the mask, which will always be ones when we
# load from the npz file!
if k == "mask":
continue
# Chunk into 32 timesteps at a time.
assert v.dtype == torch.float32
if v.shape[0] == train_data.T_world_cpf.shape[0]:
chunks = (min(32, v.shape[0]),) + v.shape[1:]
else:
assert v.shape[0] == 1
chunks = v.shape
group.create_dataset(k, data=v.numpy(force=True), chunks=chunks)
print(
f"Finished ~{total_count - task_queue.qsize()}/{total_count},",
f"{(total_count - task_queue.qsize()) / total_count * 100:.2f}% in",
f"{time.time() - start_time} seconds",
)
workers = [
threading.Thread(target=worker, args=(i,))
for i in range(torch.cuda.device_count())
]
for w in workers:
w.start()
for w in workers:
w.join()
output_list_file.write_text("\n".join(file_list))
if __name__ == "__main__":
tyro.cli(main)
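# Example invocation (paths are illustrative; every argument has a default, so
# only overrides are needed):
#
#   python 0b_preprocess_training_data.py \
#       --data-npz-dir ./data/processed_30fps_no_skating/ \
#       --output-file ./data/egoalgo_no_skating_dataset.hdf5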
================================================
FILE: 1_train_motion_prior.py
================================================
"""Training script for EgoAllo diffusion model using HuggingFace accelerate."""
import dataclasses
import shutil
from pathlib import Path
from typing import Literal
import tensorboardX
import torch.optim.lr_scheduler
import torch.utils.data
import tyro
import yaml
from accelerate import Accelerator, DataLoaderConfiguration
from accelerate.utils import ProjectConfiguration
from loguru import logger
from egoallo import network, training_loss, training_utils
from egoallo.data.amass import EgoAmassHdf5Dataset
from egoallo.data.dataclass import collate_dataclass
@dataclasses.dataclass(frozen=True)
class EgoAlloTrainConfig:
experiment_name: str
dataset_hdf5_path: Path
dataset_files_path: Path
model: network.EgoDenoiserConfig = network.EgoDenoiserConfig()
loss: training_loss.TrainingLossConfig = training_loss.TrainingLossConfig()
# Dataset arguments.
batch_size: int = 256
"""Effective batch size."""
num_workers: int = 2
subseq_len: int = 128
dataset_slice_strategy: Literal[
"deterministic", "random_uniform_len", "random_variable_len"
] = "random_uniform_len"
dataset_slice_random_variable_len_proportion: float = 0.3
"""Only used if dataset_slice_strategy == 'random_variable_len'."""
train_splits: tuple[Literal["train", "val", "test", "just_humaneva"], ...] = (
"train",
"val",
)
# Optimizer options.
learning_rate: float = 1e-4
weight_decay: float = 1e-4
warmup_steps: int = 1000
max_grad_norm: float = 1.0
def get_experiment_dir(experiment_name: str, version: int = 0) -> Path:
"""Creates a directory to put experiment files in, suffixed with a version
number. Similar to PyTorch lightning."""
experiment_dir = (
Path(__file__).absolute().parent
/ "experiments"
/ experiment_name
/ f"v{version}"
)
if experiment_dir.exists():
return get_experiment_dir(experiment_name, version + 1)
else:
return experiment_dir
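# For example, the first run of experiment "foo" resolves to experiments/foo/v0;
# if that directory already exists, the recursion falls through to v1, v2, etc.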
def run_training(
config: EgoAlloTrainConfig,
restore_checkpoint_dir: Path | None = None,
) -> None:
# Set up experiment directory + HF accelerate.
# We're going to manage logging, checkpoint directories, etc. manually,
# and just use `accelerate` for distributed training.
experiment_dir = get_experiment_dir(config.experiment_name)
assert not experiment_dir.exists()
accelerator = Accelerator(
project_config=ProjectConfiguration(project_dir=str(experiment_dir)),
dataloader_config=DataLoaderConfiguration(split_batches=True),
)
writer = (
tensorboardX.SummaryWriter(logdir=str(experiment_dir), flush_secs=10)
if accelerator.is_main_process
else None
)
device = accelerator.device
# Initialize experiment.
if accelerator.is_main_process:
training_utils.pdb_safety_net()
# Save various things that might be useful.
experiment_dir.mkdir(exist_ok=True, parents=True)
(experiment_dir / "git_commit.txt").write_text(
training_utils.get_git_commit_hash()
)
(experiment_dir / "git_diff.txt").write_text(training_utils.get_git_diff())
(experiment_dir / "run_config.yaml").write_text(yaml.dump(config))
(experiment_dir / "model_config.yaml").write_text(yaml.dump(config.model))
# Add hyperparameters to TensorBoard.
assert writer is not None
writer.add_hparams(
hparam_dict=training_utils.flattened_hparam_dict_from_dataclass(config),
metric_dict={},
name=".", # Hack to avoid timestamped subdirectory.
)
# Write logs to file.
logger.add(experiment_dir / "trainlog.log", rotation="100 MB")
# Setup.
model = network.EgoDenoiser(config.model)
train_loader = torch.utils.data.DataLoader(
dataset=EgoAmassHdf5Dataset(
config.dataset_hdf5_path,
config.dataset_files_path,
splits=config.train_splits,
subseq_len=config.subseq_len,
cache_files=True,
slice_strategy=config.dataset_slice_strategy,
random_variable_len_proportion=config.dataset_slice_random_variable_len_proportion,
),
batch_size=config.batch_size,
shuffle=True,
num_workers=config.num_workers,
persistent_workers=config.num_workers > 0,
pin_memory=True,
collate_fn=collate_dataclass,
drop_last=True,
)
optim = torch.optim.AdamW( # type: ignore
model.parameters(),
lr=config.learning_rate,
weight_decay=config.weight_decay,
)
scheduler = torch.optim.lr_scheduler.LambdaLR(
optim, lr_lambda=lambda step: min(1.0, step / config.warmup_steps)
)
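# Linear warmup: the multiplier ramps from 0 to 1 over `warmup_steps` optimizer
# steps, after which the learning rate stays constant at `learning_rate`.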
# HF accelerate setup. We use this for parallelism, etc!
model, train_loader, optim, scheduler = accelerator.prepare(
model, train_loader, optim, scheduler
)
accelerator.register_for_checkpointing(scheduler)
# Restore an existing model checkpoint.
if restore_checkpoint_dir is not None:
accelerator.load_state(str(restore_checkpoint_dir))
# Get the initial step count.
if restore_checkpoint_dir is not None and restore_checkpoint_dir.name.startswith(
"checkpoint_"
):
step = int(restore_checkpoint_dir.name.partition("_")[2])
else:
step = int(scheduler.state_dict()["last_epoch"])
assert step == 0 or restore_checkpoint_dir is not None, step
# Save an initial checkpoint. Not a big deal but currently this has an
# off-by-one error, in that `step` means something different in this
# checkpoint vs the others.
accelerator.save_state(str(experiment_dir / f"checkpoints_{step}"))
# Run training loop!
loss_helper = training_loss.TrainingLossComputer(config.loss, device=device)
loop_metrics_gen = training_utils.loop_metric_generator(counter_init=step)
prev_checkpoint_path: Path | None = None
while True:
for train_batch in train_loader:
loop_metrics = next(loop_metrics_gen)
step = loop_metrics.counter
loss, log_outputs = loss_helper.compute_denoising_loss(
model,
unwrapped_model=accelerator.unwrap_model(model),
train_batch=train_batch,
)
log_outputs["learning_rate"] = scheduler.get_last_lr()[0]
accelerator.log(log_outputs, step=step)
accelerator.backward(loss)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(model.parameters(), config.max_grad_norm)
optim.step()
scheduler.step()
optim.zero_grad(set_to_none=True)
# The rest of the loop will only be executed by the main process.
if not accelerator.is_main_process:
continue
# Logging.
if step % 10 == 0:
assert writer is not None
for k, v in log_outputs.items():
writer.add_scalar(k, v, step)
# Print status update to terminal.
if step % 20 == 0:
mem_free, mem_total = torch.cuda.mem_get_info()
logger.info(
f"step: {step} ({loop_metrics.iterations_per_sec:.2f} it/sec)"
f" mem: {(mem_total - mem_free) / 1024**3:.2f}/{mem_total / 1024**3:.2f}G"
f" lr: {scheduler.get_last_lr()[0]:.7f}"
f" loss: {loss.item():.6f}"
)
# Checkpointing.
if step % 5000 == 0:
# Save checkpoint.
checkpoint_path = experiment_dir / f"checkpoints_{step}"
accelerator.save_state(str(checkpoint_path))
logger.info(f"Saved checkpoint to {checkpoint_path}")
# Keep only checkpoints saved at multiples of 100k steps; delete the previous intermediate one.
if prev_checkpoint_path is not None:
shutil.rmtree(prev_checkpoint_path)
prev_checkpoint_path = None if step % 100_000 == 0 else checkpoint_path
del checkpoint_path
if __name__ == "__main__":
tyro.cli(run_training)
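# Example invocation (paths are illustrative; since the config is passed as the
# `config` parameter, tyro's default behavior nests its fields under a
# `config.` prefix):
#
#   python 1_train_motion_prior.py \
#       --config.experiment-name my_experiment \
#       --config.dataset-hdf5-path ./data/egoalgo_no_skating_dataset.hdf5 \
#       --config.dataset-files-path ./data/egoalgo_no_skating_dataset_files.txt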
================================================
FILE: 2_run_hamer_on_vrs.py
================================================
"""Script to run HaMeR on VRS data and save outputs to a pickle file."""
import pickle
import shutil
from pathlib import Path
import cv2
import imageio.v3 as iio
import numpy as np
import tyro
from egoallo.hand_detection_structs import (
SavedHamerOutputs,
SingleHandHamerOutputWrtCamera,
)
from hamer_helper import HamerHelper
from projectaria_tools.core import calibration
from projectaria_tools.core.data_provider import (
VrsDataProvider,
create_vrs_data_provider,
)
from tqdm.auto import tqdm
from egoallo.inference_utils import InferenceTrajectoryPaths
def main(traj_root: Path, overwrite: bool = False) -> None:
"""Run HaMeR for on trajectory. We'll save outputs to
`traj_root/hamer_outputs.pkl` and `traj_root/hamer_outputs_render".
Arguments:
traj_root: The root directory of the trajectory. We assume that there's
a VRS file in this directory.
overwrite: If True, overwrite any existing HaMeR outputs.
"""
paths = InferenceTrajectoryPaths.find(traj_root)
vrs_path = paths.vrs_file
assert vrs_path.exists()
pickle_out = traj_root / "hamer_outputs.pkl"
hamer_render_out = traj_root / "hamer_outputs_render" # This is just for debugging.
run_hamer_and_save(vrs_path, pickle_out, hamer_render_out, overwrite)
def run_hamer_and_save(
vrs_path: Path, pickle_out: Path, hamer_render_out: Path, overwrite: bool
) -> None:
if not overwrite:
assert not pickle_out.exists()
assert not hamer_render_out.exists()
else:
pickle_out.unlink(missing_ok=True)
shutil.rmtree(hamer_render_out, ignore_errors=True)
hamer_render_out.mkdir(exist_ok=True)
hamer_helper = HamerHelper()
# VRS data provider setup.
provider = create_vrs_data_provider(str(vrs_path.absolute()))
assert isinstance(provider, VrsDataProvider)
rgb_stream_id = provider.get_stream_id_from_label("camera-rgb")
assert rgb_stream_id is not None
num_images = provider.get_num_data(rgb_stream_id)
print(f"Found {num_images=}")
# Get calibrations.
device_calib = provider.get_device_calibration()
assert device_calib is not None
camera_calib = device_calib.get_camera_calib("camera-rgb")
assert camera_calib is not None
pinhole = calibration.get_linear_camera_calibration(1408, 1408, 450)
# Compute camera extrinsics!
sophus_T_device_camera = device_calib.get_transform_device_sensor("camera-rgb")
sophus_T_cpf_camera = device_calib.get_transform_cpf_sensor("camera-rgb")
assert sophus_T_device_camera is not None
assert sophus_T_cpf_camera is not None
T_device_cam = np.concatenate(
[
sophus_T_device_camera.rotation().to_quat().squeeze(axis=0),
sophus_T_device_camera.translation().squeeze(axis=0),
]
)
T_cpf_cam = np.concatenate(
[
sophus_T_cpf_camera.rotation().to_quat().squeeze(axis=0),
sophus_T_cpf_camera.translation().squeeze(axis=0),
]
)
assert T_device_cam.shape == T_cpf_cam.shape == (7,)
# Dict from capture timestamp in nanoseconds to fields we care about.
detections_left_wrt_cam: dict[int, SingleHandHamerOutputWrtCamera | None] = {}
detections_right_wrt_cam: dict[int, SingleHandHamerOutputWrtCamera | None] = {}
pbar = tqdm(range(num_images))
for i in pbar:
image_data, image_data_record = provider.get_image_data_by_index(
rgb_stream_id, i
)
undistorted_image = calibration.distort_by_calibration(
image_data.to_numpy_array(), pinhole, camera_calib
)
hamer_out_left, hamer_out_right = hamer_helper.look_for_hands(
undistorted_image,
focal_length=450,
)
timestamp_ns = image_data_record.capture_timestamp_ns
if hamer_out_left is None:
detections_left_wrt_cam[timestamp_ns] = None
else:
detections_left_wrt_cam[timestamp_ns] = {
"verts": hamer_out_left["verts"],
"keypoints_3d": hamer_out_left["keypoints_3d"],
"mano_hand_pose": hamer_out_left["mano_hand_pose"],
"mano_hand_betas": hamer_out_left["mano_hand_betas"],
"mano_hand_global_orient": hamer_out_left["mano_hand_global_orient"],
}
if hamer_out_right is None:
detections_right_wrt_cam[timestamp_ns] = None
else:
detections_right_wrt_cam[timestamp_ns] = {
"verts": hamer_out_right["verts"],
"keypoints_3d": hamer_out_right["keypoints_3d"],
"mano_hand_pose": hamer_out_right["mano_hand_pose"],
"mano_hand_betas": hamer_out_right["mano_hand_betas"],
"mano_hand_global_orient": hamer_out_right["mano_hand_global_orient"],
}
composited = undistorted_image
composited = hamer_helper.composite_detections(
composited,
hamer_out_left,
border_color=(255, 100, 100),
focal_length=450,
)
composited = hamer_helper.composite_detections(
composited,
hamer_out_right,
border_color=(100, 100, 255),
focal_length=450,
)
composited = put_text(
composited,
"L detections: "
+ (
"0" if hamer_out_left is None else str(hamer_out_left["verts"].shape[0])
),
0,
color=(255, 100, 100),
font_scale=10.0 / 2880.0 * undistorted_image.shape[0],
)
composited = put_text(
composited,
"R detections: "
+ (
"0"
if hamer_out_right is None
else str(hamer_out_right["verts"].shape[0])
),
1,
color=(100, 100, 255),
font_scale=10.0 / 2880.0 * undistorted_image.shape[0],
)
composited = put_text(
composited,
f"ns={timestamp_ns}",
2,
color=(255, 255, 255),
font_scale=10.0 / 2880.0 * undistorted_image.shape[0],
)
print(f"Saving image {i:06d} to {hamer_render_out / f'{i:06d}.jpeg'}")
iio.imwrite(
str(hamer_render_out / f"{i:06d}.jpeg"),
np.concatenate(
[
# Darken input image, just for contrast...
(undistorted_image * 0.6).astype(np.uint8),
composited,
],
axis=1,
),
quality=90,
)
outputs = SavedHamerOutputs(
mano_faces_right=hamer_helper.get_mano_faces("right"),
mano_faces_left=hamer_helper.get_mano_faces("left"),
detections_right_wrt_cam=detections_right_wrt_cam,
detections_left_wrt_cam=detections_left_wrt_cam,
T_device_cam=T_device_cam,
T_cpf_cam=T_cpf_cam,
)
with open(pickle_out, "wb") as f:
pickle.dump(outputs, f)
def put_text(
image: np.ndarray,
text: str,
line_number: int,
color: tuple[int, int, int],
font_scale: float,
) -> np.ndarray:
"""Put some text on the top-left corner of an image."""
image = image.copy()
font = cv2.FONT_HERSHEY_PLAIN
cv2.putText(
image,
text=text,
org=(2, 1 + int(15 * font_scale * (line_number + 1))),
fontFace=font,
fontScale=font_scale,
color=(0, 0, 0),
thickness=max(int(font_scale), 1),
lineType=cv2.LINE_AA,
)
cv2.putText(
image,
text=text,
org=(2, 1 + int(15 * font_scale * (line_number + 1))),
fontFace=font,
fontScale=font_scale,
color=color,
thickness=max(int(font_scale), 1),
lineType=cv2.LINE_AA,
)
return image
if __name__ == "__main__":
tyro.cli(main)
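# Example invocation (trajectory path is illustrative; the directory should
# contain the VRS recording that InferenceTrajectoryPaths.find expects):
#
#   python 2_run_hamer_on_vrs.py --traj-root ./egoallo_example_trajectories/some_recording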
================================================
FILE: 3_aria_inference.py
================================================
from __future__ import annotations
import dataclasses
import time
from pathlib import Path
import numpy as np
import torch
import viser
import yaml
from egoallo import fncsmpl, fncsmpl_extensions
from egoallo.data.aria_mps import load_point_cloud_and_find_ground
from egoallo.guidance_optimizer_jax import GuidanceMode
from egoallo.hand_detection_structs import (
CorrespondedAriaHandWristPoseDetections,
CorrespondedHamerDetections,
)
from egoallo.inference_utils import (
InferenceInputTransforms,
InferenceTrajectoryPaths,
load_denoiser,
)
from egoallo.sampling import run_sampling_with_stitching
from egoallo.transforms import SE3, SO3
from egoallo.vis_helpers import visualize_traj_and_hand_detections
@dataclasses.dataclass
class Args:
traj_root: Path
"""Search directory for trajectories. This should generally be laid out as something like:
traj_dir/
video.vrs
egoallo_outputs/
{date}_{start_index}-{end_index}.npz
...
...
"""
checkpoint_dir: Path = Path("./egoallo_checkpoint_april13/checkpoints_3000000/")
smplh_npz_path: Path = Path("./data/smplh/neutral/model.npz")
glasses_x_angle_offset: float = 0.0
"""Rotate the CPF poses by some X angle."""
start_index: int = 0
"""Index within the downsampled trajectory to start inference at."""
traj_length: int = 128
"""How many timesteps to estimate body motion for."""
num_samples: int = 1
"""Number of samples to take."""
guidance_mode: GuidanceMode = "aria_hamer"
"""Which guidance mode to use."""
guidance_inner: bool = True
"""Whether to apply guidance optimizer between denoising steps. This is
important if we're doing anything with hands. It can be turned off to speed
up debugging/experiments, or if we only care about foot skating losses."""
guidance_post: bool = True
"""Whether to apply guidance optimizer after diffusion sampling."""
save_traj: bool = True
"""Whether to save the output trajectory, which will be placed under `traj_dir/egoallo_outputs/some_name.npz`."""
visualize_traj: bool = False
"""Whether to visualize the trajectory after sampling."""
def main(args: Args) -> None:
device = torch.device("cuda")
traj_paths = InferenceTrajectoryPaths.find(args.traj_root)
if traj_paths.splat_path is not None:
print("Found splat at", traj_paths.splat_path)
else:
print("No scene splat found.")
# Get point cloud + floor.
points_data, floor_z = load_point_cloud_and_find_ground(traj_paths.points_path)
# Read transforms from VRS / MPS, downsampled.
transforms = InferenceInputTransforms.load(
traj_paths.vrs_file, traj_paths.slam_root_dir, fps=30
).to(device=device)
# Note the off-by-one for Ts_world_cpf, which we need for relative transform computation.
Ts_world_cpf = (
SE3(
transforms.Ts_world_cpf[
args.start_index : args.start_index + args.traj_length + 1
]
)
@ SE3.from_rotation(
SO3.from_x_radians(
transforms.Ts_world_cpf.new_tensor(args.glasses_x_angle_offset)
)
)
).parameters()
pose_timestamps_sec = transforms.pose_timesteps[
args.start_index + 1 : args.start_index + args.traj_length + 1
]
Ts_world_device = transforms.Ts_world_device[
args.start_index + 1 : args.start_index + args.traj_length + 1
]
del transforms
# Get temporally corresponded HaMeR detections.
if traj_paths.hamer_outputs is not None:
hamer_detections = CorrespondedHamerDetections.load(
traj_paths.hamer_outputs,
pose_timestamps_sec,
).to(device)
else:
print("No hand detections found.")
hamer_detections = None
# Get temporally corresponded Aria wrist and palm estimates.
if traj_paths.wrist_and_palm_poses_csv is not None:
aria_detections = CorrespondedAriaHandWristPoseDetections.load(
traj_paths.wrist_and_palm_poses_csv,
pose_timestamps_sec,
Ts_world_device=Ts_world_device.numpy(force=True),
).to(device)
else:
print("No Aria hand detections found.")
aria_detections = None
print(f"{Ts_world_cpf.shape=}")
server = None
if args.visualize_traj:
server = viser.ViserServer()
server.gui.configure_theme(dark_mode=True)
denoiser_network = load_denoiser(args.checkpoint_dir).to(device)
body_model = fncsmpl.SmplhModel.load(args.smplh_npz_path).to(device)
traj = run_sampling_with_stitching(
denoiser_network,
body_model=body_model,
guidance_mode=args.guidance_mode,
guidance_inner=args.guidance_inner,
guidance_post=args.guidance_post,
Ts_world_cpf=Ts_world_cpf,
hamer_detections=hamer_detections,
aria_detections=aria_detections,
num_samples=args.num_samples,
device=device,
floor_z=floor_z,
)
# Save outputs in case we want to visualize later.
if args.save_traj:
save_name = (
time.strftime("%Y%m%d-%H%M%S")
+ f"_{args.start_index}-{args.start_index + args.traj_length}"
)
out_path = args.traj_root / "egoallo_outputs" / (save_name + ".npz")
out_path.parent.mkdir(parents=True, exist_ok=True)
assert not out_path.exists()
(args.traj_root / "egoallo_outputs" / (save_name + "_args.yaml")).write_text(
yaml.dump(dataclasses.asdict(args))
)
posed = traj.apply_to_body(body_model)
Ts_world_root = fncsmpl_extensions.get_T_world_root_from_cpf_pose(
posed, Ts_world_cpf[..., 1:, :]
)
print(f"Saving to {out_path}...", end="")
np.savez(
out_path,
Ts_world_cpf=Ts_world_cpf[1:, :].numpy(force=True),
Ts_world_root=Ts_world_root.numpy(force=True),
body_quats=posed.local_quats[..., :21, :].numpy(force=True),
left_hand_quats=posed.local_quats[..., 21:36, :].numpy(force=True),
right_hand_quats=posed.local_quats[..., 36:51, :].numpy(force=True),
contacts=traj.contacts.numpy(force=True), # Sometimes we forgot this...
betas=traj.betas.numpy(force=True),
frame_nums=np.arange(args.start_index, args.start_index + args.traj_length),
timestamps_ns=(np.array(pose_timestamps_sec) * 1e9).astype(np.int64),
)
print("saved!")
# Visualize.
if args.visualize_traj:
assert server is not None
loop_cb = visualize_traj_and_hand_detections(
server,
Ts_world_cpf[1:],
traj,
body_model,
hamer_detections,
aria_detections,
points_data=points_data,
splat_path=traj_paths.splat_path,
floor_z=floor_z,
)
while True:
loop_cb()
if __name__ == "__main__":
import tyro
main(tyro.cli(Args))
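# Example invocation (paths are illustrative; boolean flags assume tyro's
# default flag pairs, e.g. --visualize-traj / --no-visualize-traj):
#
#   python 3_aria_inference.py \
#       --traj-root ./egoallo_example_trajectories/some_recording \
#       --start-index 0 --traj-length 128 \
#       --visualize-traj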
================================================
FILE: 4_visualize_outputs.py
================================================
from __future__ import annotations
import io
from pathlib import Path
from typing import Callable
import cv2
import imageio.v3 as iio
import numpy as np
import torch
import tyro
import viser
from projectaria_tools.core.data_provider import (
VrsDataProvider,
create_vrs_data_provider,
)
from projectaria_tools.core.sensor_data import TimeDomain
from tqdm import tqdm
from egoallo import fncsmpl
from egoallo.data.aria_mps import load_point_cloud_and_find_ground
from egoallo.hand_detection_structs import (
CorrespondedAriaHandWristPoseDetections,
CorrespondedHamerDetections,
)
from egoallo.inference_utils import InferenceTrajectoryPaths
from egoallo.network import EgoDenoiseTraj
from egoallo.transforms import SE3, SO3
from egoallo.vis_helpers import visualize_traj_and_hand_detections
def main(
search_root_dir: Path,
smplh_npz_path: Path = Path("./data/smplh/neutral/model.npz"),
) -> None:
"""Visualization script for outputs from EgoAllo.
Arguments:
search_root_dir: Root directory where inputs/outputs are stored. All
NPZ files in this directory will be assumed to be outputs from EgoAllo.
smplh_npz_path: Path to the SMPLH model NPZ file.
"""
device = torch.device("cuda")
body_model = fncsmpl.SmplhModel.load(smplh_npz_path).to(device)
server = viser.ViserServer()
server.gui.configure_theme(dark_mode=True)
def get_file_list():
return ["None"] + sorted(
str(p.relative_to(search_root_dir))
for p in search_root_dir.glob("**/egoallo_outputs/*.npz")
)
options = get_file_list()
file_dropdown = server.gui.add_dropdown("File", options=options)
refresh_file_list = server.gui.add_button("Refresh File List")
@refresh_file_list.on_click
def _(_) -> None:
file_dropdown.options = get_file_list()
trajectory_folder = server.gui.add_folder("Trajectory")
current_file = "None"
loop_cb = lambda: None
while True:
loop_cb()
if current_file != file_dropdown.value:
current_file = file_dropdown.value
# Clear the scene.
server.scene.reset()
if current_file != "None":
# Clear the folder by removing then re-adding it.
# Perhaps we should expose some API for looping through children?
trajectory_folder.remove()
trajectory_folder = server.gui.add_folder("Trajectory")
with trajectory_folder:
npz_path = Path(search_root_dir / current_file).resolve()
loop_cb = load_and_visualize(
server,
npz_path,
body_model,
device=device,
)
args = npz_path.parent / (npz_path.stem + "_args.yaml")
if args.exists():
with server.gui.add_folder("Args"):
server.gui.add_markdown(
"```\n" + args.read_text() + "\n```"
)
def load_and_visualize(
server: viser.ViserServer,
npz_path: Path,
body_model: fncsmpl.SmplhModel,
device: torch.device,
) -> Callable[[], int]:
# Here's how we saved:
#
# np.savez(
# out_path,
# Ts_world_cpf=Ts_world_cpf[1:, :].numpy(force=True),
# Ts_world_root=Ts_world_root.numpy(force=True),
# body_quats=posed.local_quats[..., :21, :].numpy(force=True),
# left_hand_quats=posed.local_quats[..., 21:36, :].numpy(force=True),
# right_hand_quats=posed.local_quats[..., 36:51, :].numpy(force=True),
# betas=traj.betas.numpy(force=True),
# frame_nums=np.arange(args.start_index, args.start_index + args.traj_length),
# timestamps_ns=(np.array(pose_timestamps_sec) * 1e9).astype(np.int64),
# )
outputs = np.load(npz_path)
expected_keys = [
"Ts_world_cpf",
"Ts_world_root",
"body_quats",
"left_hand_quats",
"right_hand_quats",
"betas",
"frame_nums",
"timestamps_ns",
]
assert all(key in outputs for key in expected_keys), (
f"Missing keys in NPZ file. Expected: {expected_keys}, Found: {list(outputs.keys())}"
)
(num_samples, timesteps, _, _) = outputs["body_quats"].shape
# We assume the directory structure is:
# - some trajectory root
# - egoallo_outputs/
# - the npz file
traj_dir = npz_path.resolve().parent.parent
paths = InferenceTrajectoryPaths.find(traj_dir)
provider = create_vrs_data_provider(str(paths.vrs_file))
device_calib = provider.get_device_calibration()
T_device_cpf = SE3(
torch.from_numpy(
device_calib.get_transform_device_cpf().to_quat_and_translation()
)
)
assert T_device_cpf.wxyz_xyz.shape == (1, 7)
pose_timestamps_sec = outputs["timestamps_ns"] / 1e9
Ts_world_device = (
SE3(torch.from_numpy(outputs["Ts_world_cpf"])) @ T_device_cpf.inverse()
).wxyz_xyz
# Get temporally corresponded HaMeR detections.
if paths.hamer_outputs is not None:
hamer_detections = CorrespondedHamerDetections.load(
paths.hamer_outputs,
pose_timestamps_sec,
)
else:
print("No hand detections found.")
hamer_detections = None
# Get temporally corresponded Aria wrist and palm estimates.
if paths.wrist_and_palm_poses_csv is not None:
aria_detections = CorrespondedAriaHandWristPoseDetections.load(
paths.wrist_and_palm_poses_csv,
pose_timestamps_sec,
Ts_world_device=Ts_world_device.numpy(force=True),
)
else:
aria_detections = None
if paths.splat_path is not None:
print("Found splat at", paths.splat_path)
else:
print("No scene splat found.")
# Get point cloud + floor.
points_data, floor_z = load_point_cloud_and_find_ground(
paths.points_path, "filtered"
)
traj = EgoDenoiseTraj(
betas=torch.from_numpy(outputs["betas"]).to(device),
body_rotmats=SO3(
torch.from_numpy(outputs["body_quats"]),
)
.as_matrix()
.to(device),
# We weren't saving contacts originally. We added it September 28th.
contacts=torch.zeros((num_samples, timesteps, 21), device=device)
if "contacts" not in outputs
else torch.from_numpy(outputs["contacts"]).to(device),
hand_rotmats=SO3(
torch.from_numpy(
np.concatenate(
[
outputs["left_hand_quats"],
outputs["right_hand_quats"],
],
axis=-2,
)
).to(device)
).as_matrix(),
)
Ts_world_cpf = torch.from_numpy(outputs["Ts_world_cpf"]).to(device)
def get_ego_video(
start_index: int,
end_index: int,
total_duration: float,
) -> bytes:
"""Helper function that returns the egocentric video corresponding to
some start/end pose index."""
assert isinstance(provider, VrsDataProvider)
rgb_stream_id = provider.get_stream_id_from_label("camera-rgb")
assert rgb_stream_id is not None
camera_fps = provider.get_configuration(rgb_stream_id).get_nominal_rate_hz()
print(f"{camera_fps=}")
start_ns = int(outputs["timestamps_ns"][start_index])
first_ns = provider.get_first_time_ns(rgb_stream_id, TimeDomain.RECORD_TIME)
image_start_index = int((start_ns - first_ns) / 1e9 * camera_fps)
image_end_index = min(
int(image_start_index + (end_index - start_index) / 30.0 * camera_fps) + 5,
provider.get_num_data(rgb_stream_id),
)
frames = []
for i in tqdm(range(image_start_index, image_end_index)):
image_data = provider.get_image_data_by_index(rgb_stream_id, i)[0]
image_array = image_data.to_numpy_array().copy()
image_array = cv2.resize(
image_array, (800, 800), interpolation=cv2.INTER_AREA
)
image_array = cv2.rotate(image_array, cv2.ROTATE_90_CLOCKWISE)
frames.append(image_array)
fps = len(frames) / total_duration
output = io.BytesIO()
iio.imwrite(
output,
frames,
fps=fps,
extension=".mp4",
codec="libx264",
pixelformat="yuv420p",
quality=None,
ffmpeg_params=["-crf", "23"],
)
return output.getvalue()
return visualize_traj_and_hand_detections(
server,
Ts_world_cpf,
traj,
body_model,
hamer_detections,
aria_detections,
points_data,
paths.splat_path,
floor_z=floor_z,
get_ego_video=get_ego_video,
)
if __name__ == "__main__":
tyro.cli(main)
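# Example invocation (path is illustrative; any egoallo_outputs/*.npz files
# found under the search root will show up in the dropdown):
#
#   python 4_visualize_outputs.py --search-root-dir ./egoallo_example_trajectories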
================================================
FILE: 5_eval_body_metrics.py
================================================
"""Example script for computing body metrics on the test split of the AMASS dataset.
This is not the exact script we used for the paper metrics, but the details
that matter should match. Below are some metrics from this script when our
released checkpoint is passed in.
For --subseq-len 128:
mpjpe 118.340 +/- 1.350 (in paper: 119.7 +/- 1.3)
pampjpe 100.026 +/- 1.349 (in paper: 101.1 +/- 1.3)
T_head 0.006 +/- 0.000 (in paper: 0.0062 +/- 0.0001)
foot_contact (GND) 1.000 +/- 0.000 (in paper: 1.0 +/- 0.0)
foot_skate 0.417 +/- 0.017 (not reported in paper)
For --subseq-len 32:
mpjpe 129.193 +/- 1.108 (in paper: 129.8 +/- 1.1)
pampjpe 109.489 +/- 1.147 (in paper: 109.8 +/- 1.1)
T_head 0.006 +/- 0.000 (in paper: 0.0064 +/- 0.0001)
foot_contact (GND) 0.985 +/- 0.003 (in paper: 0.98 +/- 0.00)
foot_skate 0.185 +/- 0.005 (not reported in paper)
"""
from pathlib import Path
import jax.tree
import numpy as np
import torch.optim.lr_scheduler
import torch.utils.data
import tyro
from egoallo import fncsmpl
from egoallo.data.amass import EgoAmassHdf5Dataset
from egoallo.fncsmpl_extensions import get_T_world_root_from_cpf_pose
from egoallo.inference_utils import load_denoiser
from egoallo.metrics_helpers import (
compute_foot_contact,
compute_foot_skate,
compute_head_trans,
compute_mpjpe,
)
from egoallo.sampling import run_sampling_with_stitching
from egoallo.transforms import SE3, SO3
def main(
dataset_hdf5_path: Path,
dataset_files_path: Path,
subseq_len: int = 128,
guidance_inner: bool = False,
checkpoint_dir: Path = Path("./egoallo_checkpoint_april13/checkpoints_3000000/"),
smplh_npz_path: Path = Path("./data/smplh/neutral/model.npz"),
num_samples: int = 1,
) -> None:
"""Compute body metrics on the test split of the AMASS dataset."""
device = torch.device("cuda")
# Setup.
denoiser_network = load_denoiser(checkpoint_dir).to(device)
dataset = EgoAmassHdf5Dataset(
dataset_hdf5_path,
dataset_files_path,
splits=("test",),
# We need an extra timestep in order to compute the relative CPF pose. (T_cpf_tm1_cpf_t)
subseq_len=subseq_len + 1,
cache_files=True,
slice_strategy="deterministic",
random_variable_len_proportion=0.0,
)
body_model = fncsmpl.SmplhModel.load(smplh_npz_path).to(device)
metrics = list[dict[str, np.ndarray]]()
for i in range(len(dataset)):
sequence = dataset[i].to(device)
samples = run_sampling_with_stitching(
denoiser_network,
body_model=body_model,
guidance_mode="no_hands",
guidance_inner=guidance_inner,
guidance_post=True,
Ts_world_cpf=sequence.T_world_cpf,
hamer_detections=None,
aria_detections=None,
num_samples=num_samples,
floor_z=0.0,
device=device,
guidance_verbose=False,
)
assert samples.hand_rotmats is not None
assert samples.betas.shape == (num_samples, subseq_len, 16)
assert samples.body_rotmats.shape == (num_samples, subseq_len, 21, 3, 3)
assert samples.hand_rotmats.shape == (num_samples, subseq_len, 30, 3, 3)
assert sequence.hand_quats is not None
# We'll only use the body joint rotations.
pred_posed = body_model.with_shape(samples.betas).with_pose(
T_world_root=SE3.identity(device, torch.float32).wxyz_xyz,
local_quats=SO3.from_matrix(
torch.cat([samples.body_rotmats, samples.hand_rotmats], dim=2)
).wxyz,
)
pred_posed = pred_posed.with_new_T_world_root(
get_T_world_root_from_cpf_pose(pred_posed, sequence.T_world_cpf[1:, ...])
)
label_posed = body_model.with_shape(sequence.betas[1:, ...]).with_pose(
sequence.T_world_root[1:, ...],
torch.cat(
[
sequence.body_quats[1:, ...],
sequence.hand_quats[1:, ...],
],
dim=1,
),
)
metrics.append(
{
"mpjpe": compute_mpjpe(
label_T_world_root=label_posed.T_world_root,
label_Ts_world_joint=label_posed.Ts_world_joint[:, :21, :],
pred_T_world_root=pred_posed.T_world_root,
pred_Ts_world_joint=pred_posed.Ts_world_joint[:, :, :21, :],
per_frame_procrustes_align=False,
),
"pampjpe": compute_mpjpe(
label_T_world_root=label_posed.T_world_root,
label_Ts_world_joint=label_posed.Ts_world_joint[:, :21, :],
pred_T_world_root=pred_posed.T_world_root,
pred_Ts_world_joint=pred_posed.Ts_world_joint[:, :, :21, :],
per_frame_procrustes_align=True,
),
# We didn't report foot skating metrics in the paper. It's not
# really meaningful: since we optimize foot skating in the
# guidance optimizer, it's easy to "cheat" this metric.
"foot_skate": compute_foot_skate(
pred_Ts_world_joint=pred_posed.Ts_world_joint[:, :, :21, :],
),
"foot_contact (GND)": compute_foot_contact(
pred_Ts_world_joint=pred_posed.Ts_world_joint[:, :, :21, :],
),
"T_head": compute_head_trans(
label_Ts_world_joint=label_posed.Ts_world_joint[:, :21, :],
pred_Ts_world_joint=pred_posed.Ts_world_joint[:, :, :21, :],
),
}
)
print("=" * 80)
print("=" * 80)
print("=" * 80)
print(f"Metrics ({i}/{len(dataset)} processed)")
for k, v in jax.tree.map(
lambda *x: f"{np.mean(x):.3f} +/- {np.std(x) / np.sqrt(len(metrics) * num_samples):.3f}",
*metrics,
).items():
print("\t", k, v)
print("=" * 80)
print("=" * 80)
print("=" * 80)
if __name__ == "__main__":
tyro.cli(main)
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2024 Brent Yi
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# egoallo
**[Project page](https://egoallo.github.io/) •
[arXiv](https://arxiv.org/abs/2410.03665)**
Code release for our preprint:
<table><tr><td>
Brent Yi<sup>1</sup>, Vickie Ye<sup>1</sup>, Maya Zheng<sup>1</sup>, Yunqi Li<sup>2</sup>, Lea Müller<sup>1</sup>, Georgios Pavlakos<sup>3</sup>, Yi Ma<sup>1</sup>, Jitendra Malik<sup>1</sup>, and Angjoo Kanazawa<sup>1</sup>.
<strong>Estimating Body and Hand Motion in an Ego-sensed World.</strong>
arXiv, 2024.
</td></tr>
</table>
<sup>1</sup><em>UC Berkeley</em>, <sup>2</sup><em>ShanghaiTech</em>, <sup>3</sup><em>UT Austin</em>
---
## Updates
- **Oct 7, 2024:** Initial release. (training code, core implementation details)
- **Oct 14, 2024:** Added model checkpoint, dataset preprocessing, inference, and visualization scripts.
- **May 6, 2025:** Updated scripts + instructions for dataset preprocessing, which is now self-contained in this repository.
## Overview
**TL;DR:** We use egocentric SLAM poses and images to estimate 3D human body pose, height, and hands.
https://github.com/user-attachments/assets/7d28e07f-ab83-4749-ac6b-abe692d9ba20
This repository is structured as follows:
```
.
├── download_checkpoint_and_data.sh
│ - Download model checkpoint and sample data.
├── 0a_preprocess_training_data.py
│ - Convert raw AMASS data to per-sequence npz files.
├── 0b_preprocess_training_data.py
│ - Pack the processed npz files into an HDF5 training dataset.
├── 1_train_motion_prior.py
│ - Training script for motion diffusion model.
├── 2_run_hamer_on_vrs.py
│ - Run HaMeR on inference data (expects Aria VRS).
├── 3_aria_inference.py
│ - Run full pipeline on inference data.
├── 4_visualize_outputs.py
│ - Visualize outputs from inference.
├── 5_eval_body_metrics.py
│ - Compute and print body estimation accuracy metrics.
│
├── src/egoallo/
│ ├── data/ - Dataset utilities.
│ ├── transforms/ - SO(3) / SE(3) transformation helpers.
│ └── *.py - All core implementation.
│
└── pyproject.toml - Python dependencies/package metadata.
```
## Getting started
EgoAllo requires Python 3.12 or newer.
1. **Clone the repository.**
```bash
git clone https://github.com/brentyi/egoallo.git
```
2. **Install general dependencies.**
```bash
cd egoallo
pip install -e .
```
3. **Download+unzip model checkpoint and sample data.**
```bash
bash download_checkpoint_and_data.sh
```
You can also download the zip files manually: here are links to the [checkpoint](https://drive.google.com/file/d/14bDkWixFgo3U6dgyrCRmLoXSsXkrDA2w/view?usp=drive_link) and [example trajectories](https://drive.google.com/file/d/14zQ95NYxL4XIT7KIlFgAYTPCRITWxQqu/view?usp=drive_link).
4. **Download the SMPL-H model file.**
You can find the "Extended SMPL+H model" (16 shape parameters) from the [MANO project webpage](https://mano.is.tue.mpg.de/).
Our scripts assume an npz file located at `./data/smplh/neutral/model.npz`, but this can be overridden on the command line (`--smplh-npz-path {your path}`).
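For reference, a directory layout matching the default path looks like:
```
data/
└── smplh/
    └── neutral/
        └── model.npz
```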
5. **Visualize model outputs.**
The example trajectories directory includes example outputs from our model. You can visualize them with:
```bash
python 4_visualize_outputs.py --search-root-dir ./egoallo_example_trajectories
```
## Running inference
1. **Installing inference dependencies.**
Our guidance optimization uses a Levenberg-Marquardt optimizer that's implemented in JAX. If you want to run this on an NVIDIA GPU, you'll need to install JAX with CUDA support:
```bash
# Also see: https://jax.readthedocs.io/en/latest/installation.html
pip install "jax[cuda12]==0.6.1"
```
You'll also need [jaxls](https://github.com/brentyi/jaxls):
```bash
pip install git+https://github.com/brentyi/jaxls.git
```
2. **Running inference on example data.**
Here's an example command for running EgoAllo on the "coffeemachine" sequence:
```bash
python 3_aria_inference.py --traj-root ./egoallo_example_trajectories/coffeemachine
```
You can run `python 3_aria_inference.py --help` to see the full list of options.
3. **Running inference on your own data.**
To run inference on your own data, you can copy the structure of the example trajectories; a sketch of the expected layout follows this list. The key files are:
- A VRS file from Project Aria, which contains calibrations and images.
- SLAM outputs from Project Aria's MPS: `closed_loop_trajectory.csv` and `semidense_points.csv.gz`.
- (optional) HaMeR outputs, which we save to a `hamer_outputs.pkl`.
- (optional) Project Aria wrist and palm tracking outputs.
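Here's a rough sketch of one possible layout. The VRS and wrist/palm file names below are placeholders; mirror the example trajectories for the exact structure that `3_aria_inference.py` expects:
```
your_trajectory/
├── recording.vrs                 # Project Aria VRS file (name is a placeholder).
├── closed_loop_trajectory.csv    # SLAM output from Aria MPS.
├── semidense_points.csv.gz       # SLAM output from Aria MPS.
├── hamer_outputs.pkl             # Optional, from 2_run_hamer_on_vrs.py.
└── wrist_and_palm_poses.csv      # Optional Aria wrist/palm tracking (name is a placeholder).
```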
4. **Running HaMeR on your own data.**
To generate the `hamer_outputs.pkl` file, you'll need to install [hamer_helper](https://github.com/brentyi/hamer_helper).
Then, as an example for running on our coffeemachine sequence:
```bash
python 2_run_hamer_on_vrs.py --traj-root ./egoallo_example_trajectories/coffeemachine
```
## Preprocessing Training Data
To train the motion prior model, we use data from the [AMASS dataset](https://amass.is.tue.mpg.de/). Due to licensing constraints, we cannot redistribute the preprocessed data. Instead, we provide two sequential preprocessing scripts:
1. **Download the AMASS dataset.**
Download the AMASS dataset from the [official website](https://amass.is.tue.mpg.de/). We use the following splits:
- **Training**: ACCAD, BioMotionLab_NTroje, BMLhandball, BMLmovi, CMU, DanceDB, DFaust_67, EKUT, Eyes_Japan_Dataset, KIT, MPI_Limits, TCD_handMocap, TotalCapture
- **Validation**: HumanEva, MPI_HDM05, SFU, MPI_mosh
- **Testing**: Transitions_mocap, SSM_synced
2. **Run the first preprocessing script.**
```bash
python 0a_preprocess_training_data.py --help
python 0a_preprocess_training_data.py --data-root /path/to/amass --smplh-root ./data/smplh
```
This script, adapted from HuMoR, processes raw AMASS data by:
- Converting to gender-neutral SMPL-H parameters
- Computing contact labels for feet, hands, and knees
- Filtering out problematic sequences (treadmill walking, sequences with foot skating)
- Downsampling to 30fps
3. **Run the second preprocessing script.**
```bash
python 0b_preprocess_training_data.py --help
python 0b_preprocess_training_data.py --data-npz-dir ./data/processed_30fps_no_skating/
```
This converts the processed NPZ files to a unified HDF5 format for more efficient training, with optimized chunk sizes for reading sequences.
## Status
This repository currently contains:
- `egoallo` package, which contains reference training and sampling implementation details.
- Training script.
- Model checkpoints.
- Dataset preprocessing script.
- Inference script.
- Visualization script.
- Setup instructions.
While we've put effort into cleaning up our code for release, this is research
code and there's room for improvement. If you have questions or comments,
please reach out!
================================================
FILE: download_checkpoint_and_data.sh
================================================
# Script for downloading model checkpoint and example inputs/outputs.
# egoallo_checkpoint_april13.zip (552 MB)
gdown https://drive.google.com/file/d/14bDkWixFgo3U6dgyrCRmLoXSsXkrDA2w/view?usp=drive_link --fuzzy
unzip egoallo_checkpoint_april13.zip
rm egoallo_checkpoint_april13.zip
# egoallo_example_trajectories.zip (8.17 GB)
gdown https://drive.google.com/file/d/14zQ95NYxL4XIT7KIlFgAYTPCRITWxQqu/view?usp=drive_link --fuzzy
unzip egoallo_example_trajectories.zip
rm egoallo_example_trajectories.zip
================================================
FILE: pyproject.toml
================================================
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
[project]
name = "egoallo"
version = "0.0.0"
description = "egoallo"
readme = "README.md"
license = { text="MIT" }
requires-python = ">=3.12"
classifiers = [
"Programming Language :: Python :: 3.12",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent"
]
dependencies = [
"torch==2.7.1",
"viser>=0.2.11",
"plyfile==1.1.2",
"typeguard==4.4.3",
"jaxtyping==0.3.2",
"einops==0.8.1",
"rotary-embedding-torch==0.8.6",
"h5py==3.13.0",
"tensorboard==2.19.0",
"projectaria_tools==1.6.0",
"accelerate==1.7.0",
"tensorboardX==2.6.2.2",
"loguru==0.7.3",
"projectaria-tools[all]==1.6.0",
"opencv-python==4.11.0.86",
"gdown==5.2.0",
"scikit-learn==1.6.1", # Only needed for preprocessing
"smplx==0.1.28", # Only needed for preprocessing
]
[tool.setuptools.package-data]
egoallo = ["py.typed"]
[tool.pyright]
ignore = ["**/preprocessing/**", "./0a_preprocess_training_data.py"]
[tool.ruff.lint]
select = [
"E", # pycodestyle errors.
"F", # Pyflakes rules.
"PLC", # Pylint convention warnings.
"PLE", # Pylint errors.
"PLR", # Pylint refactor recommendations.
"PLW", # Pylint warnings.
]
ignore = [
"E731", # Do not assign a lambda expression, use a def.
"E741", # Ambiguous variable name. (l, O, or I)
"E501", # Line too long.
"E721", # Do not compare types, use `isinstance()`.
"F722", # Forward annotation false positive from jaxtyping. Should be caught by pyright.
"F821", # Forward annotation false positive from jaxtyping. Should be caught by pyright.
"PLR2004", # Magic value used in comparison.
"PLR0915", # Too many statements.
"PLR0913", # Too many arguments.
"PLC0414", # Import alias does not rename variable. (this is used for exporting names)
"PLC1901", # Use falsey strings.
"PLR5501", # Use `elif` instead of `else if`.
"PLR0911", # Too many return statements.
"PLR0912", # Too many branches.
"PLW0603", # Globa statement updates are discouraged.
"PLW2901", # For loop variable overwritten.
]
================================================
FILE: src/egoallo/__init__.py
================================================
================================================
FILE: src/egoallo/fncsmpl.py
================================================
"""Somewhat opinionated wrapper for the SMPL-H body model.
Very little of it is specific to SMPL-H. This could very easily be adapted for other models in the SMPL family.
We break down the SMPL-H into four stages, each with a corresponding data structure:
- Loading the model itself:
`model = SmplhModel.load(path to npz)`
- Applying a body shape to the model:
`shaped = model.with_shape(betas)`
- Posing the body shape:
`posed = shaped.with_pose(root pose, local joint poses)`
- Recovering the mesh with LBS:
`mesh = posed.lbs()`
In contrast to other SMPL wrappers:
- Everything is stateless, so we can support arbitrary batch axes.
- The root is no longer ever called a joint.
- The `trans` and `root_orient` inputs are replaced by a single SE(3) root transformation.
- We're using (4,) wxyz quaternion vectors for all rotations, (7,) wxyz_xyz vectors for all
rigid transforms.
"""
from __future__ import annotations
from pathlib import Path
import numpy as np
import torch
from einops import einsum
from jaxtyping import Float, Int
from torch import Tensor
from .tensor_dataclass import TensorDataclass
from .transforms import SE3, SO3
class SmplhModel(TensorDataclass):
"""A human body model from the SMPL family."""
faces: Int[Tensor, "faces 3"]
"""Vertex indices for mesh faces."""
J_regressor: Float[Tensor, "joints+1 verts"]
"""Linear map from vertex to joint positions.
For SMPL-H, 1 root + 21 body joints + 2 * 15 hand joints."""
parent_indices: tuple[int, ...]
"""Defines kinematic tree. Index of -1 signifies that a joint is defined
relative to the root."""
weights: Float[Tensor, "verts joints+1"]
"""LBS weights."""
posedirs: Float[Tensor, "verts 3 joints*9"]
"""Pose blend shape bases."""
v_template: Float[Tensor, "verts 3"]
"""Canonical mesh verts."""
shapedirs: Float[Tensor, "verts 3 n_betas"]
"""Shape bases."""
@staticmethod
def load(model_path: Path) -> SmplhModel:
"""Load a body model from an NPZ file."""
params_numpy: dict[str, np.ndarray] = {
k: _normalize_dtype(v)
for k, v in np.load(model_path, allow_pickle=True).items()
}
assert (
"bs_style" not in params_numpy
or params_numpy.pop("bs_style").item() == b"lbs"
)
assert (
"bs_type" not in params_numpy
or params_numpy.pop("bs_type").item() == b"lrotmin"
)
parent_indices = tuple(
int(index) for index in params_numpy.pop("kintree_table")[0][1:] - 1
)
params = {
k: torch.from_numpy(v)
for k, v in params_numpy.items()
if v.dtype in (np.int32, np.float32)
}
return SmplhModel(
faces=params["f"],
J_regressor=params["J_regressor"],
parent_indices=parent_indices,
weights=params["weights"],
posedirs=params["posedirs"],
v_template=params["v_template"],
shapedirs=params["shapedirs"],
)
def get_num_joints(self) -> int:
"""Get the number of joints in this model."""
return len(self.parent_indices)
def with_shape(self, betas: Float[Tensor, "*#batch n_betas"]) -> SmplhShaped:
"""Compute a new body model, with betas applied."""
num_betas = betas.shape[-1]
assert num_betas <= self.shapedirs.shape[-1]
verts_with_shape = self.v_template + einsum(
self.shapedirs[:, :, :num_betas],
betas,
"verts xyz beta, ... beta -> ... verts xyz",
)
root_and_joints_pred = einsum(
self.J_regressor,
verts_with_shape,
"jointsp1 verts, ... verts xyz -> ... jointsp1 xyz",
)
root_offset = root_and_joints_pred[..., 0:1, :]
return SmplhShaped(
body_model=self,
root_offset=root_offset.unsqueeze(-2),
verts_zero=verts_with_shape - root_offset,
joints_zero=root_and_joints_pred[..., 1:, :] - root_offset,
t_parent_joint=root_and_joints_pred[..., 1:, :]
- root_and_joints_pred[..., np.array(self.parent_indices) + 1, :],
)
class SmplhShaped(TensorDataclass):
"""The SMPL-H body model with a body shape applied."""
body_model: SmplhModel
"""The underlying body model."""
root_offset: Float[Tensor, "*#batch 3"]
verts_zero: Float[Tensor, "*#batch verts 3"]
"""Vertices of shaped body _relative to the root joint_ at the zero
configuration."""
joints_zero: Float[Tensor, "*#batch joints 3"]
"""Joints of shaped body _relative to the root joint_ at the zero
configuration."""
t_parent_joint: Float[Tensor, "*#batch joints 3"]
"""Position of each shaped body joint relative to its parent. Does not
include root."""
def with_pose_decomposed(
self,
T_world_root: Float[Tensor, "*#batch 7"],
body_quats: Float[Tensor, "*#batch 21 4"],
left_hand_quats: Float[Tensor, "*#batch 15 4"] | None = None,
right_hand_quats: Float[Tensor, "*#batch 15 4"] | None = None,
) -> SmplhShapedAndPosed:
"""Pose our SMPL-H body model. Returns a set of joint and vertex outputs."""
num_joints = self.body_model.get_num_joints()
batch_axes = body_quats.shape[:-2]
if left_hand_quats is None:
left_hand_quats = body_quats.new_zeros((*batch_axes, 15, 4))
left_hand_quats[..., 0] = 1.0
if right_hand_quats is None:
right_hand_quats = body_quats.new_zeros((*batch_axes, 15, 4))
right_hand_quats[..., 0] = 1.0
local_quats = broadcasting_cat(
[body_quats, left_hand_quats, right_hand_quats], dim=-2
)
assert local_quats.shape[-2:] == (num_joints, 4)
return self.with_pose(T_world_root, local_quats)
def with_pose(
self,
T_world_root: Float[Tensor, "*#batch 7"],
local_quats: Float[Tensor, "*#batch joints 4"],
) -> SmplhShapedAndPosed:
"""Pose our SMPL-H body model. Returns a set of joint and vertex outputs."""
# Forward kinematics.
num_joints = self.body_model.get_num_joints()
assert local_quats.shape[-2:] == (num_joints, 4)
Ts_world_joint = forward_kinematics(
T_world_root=T_world_root,
Rs_parent_joint=local_quats,
t_parent_joint=self.t_parent_joint,
parent_indices=self.body_model.parent_indices,
)
assert Ts_world_joint.shape[-2:] == (num_joints, 7)
return SmplhShapedAndPosed(
shaped_model=self,
T_world_root=T_world_root,
local_quats=local_quats,
Ts_world_joint=Ts_world_joint,
)
class SmplhShapedAndPosed(TensorDataclass):
shaped_model: SmplhShaped
"""Underlying shaped body model."""
T_world_root: Float[Tensor, "*#batch 7"]
"""Root coordinate frame."""
local_quats: Float[Tensor, "*#batch joints 4"]
"""Local joint orientations."""
Ts_world_joint: Float[Tensor, "*#batch joints 7"]
"""Absolute transform for each joint. Does not include the root."""
def with_new_T_world_root(
self, T_world_root: Float[Tensor, "*#batch 7"]
) -> SmplhShapedAndPosed:
return SmplhShapedAndPosed(
shaped_model=self.shaped_model,
T_world_root=T_world_root,
local_quats=self.local_quats,
Ts_world_joint=(
SE3(T_world_root[..., None, :])
@ SE3(self.T_world_root[..., None, :]).inverse()
@ SE3(self.Ts_world_joint)
).parameters(),
)
def lbs(self) -> SmplMesh:
"""Compute a mesh with LBS."""
num_joints = self.local_quats.shape[-2]
verts_with_blend = self.shaped_model.verts_zero + einsum(
self.shaped_model.body_model.posedirs,
(
SO3(self.local_quats).as_matrix()
- torch.eye(
3, dtype=self.local_quats.dtype, device=self.local_quats.device
)
).reshape((*self.local_quats.shape[:-2], num_joints * 9)),
"... verts j joints_times_9, ... joints_times_9 -> ... verts j",
)
verts_transformed = einsum(
broadcasting_cat(
[
SE3(self.T_world_root).as_matrix()[..., None, :3, :],
SE3(self.Ts_world_joint).as_matrix()[..., :, :3, :],
],
dim=-3,
),
self.shaped_model.body_model.weights,
broadcasting_cat(
[
verts_with_blend[..., :, None, :]
- broadcasting_cat( # Prepend root to joints zeros.
[
self.shaped_model.joints_zero.new_zeros(3),
self.shaped_model.joints_zero[..., None, :, :],
],
dim=-2,
),
verts_with_blend.new_ones(
(
*verts_with_blend.shape[:-1],
1 + self.shaped_model.joints_zero.shape[-2],
1,
)
),
],
dim=-1,
),
"... joints_p1 i j, verts joints_p1, ... verts joints_p1 j -> ... verts i",
)
assert (
verts_transformed.shape[-2:]
== self.shaped_model.body_model.v_template.shape
)
return SmplMesh(
posed_model=self,
verts=verts_transformed,
faces=self.shaped_model.body_model.faces,
)
class SmplMesh(TensorDataclass):
"""Outputs from the SMPL-H model."""
posed_model: SmplhShapedAndPosed
"""Posed model that this mesh was computed for."""
verts: Float[Tensor, "*#batch verts 3"]
"""Vertices for mesh."""
faces: Int[Tensor, "verts 3"]
"""Faces for mesh."""
def forward_kinematics(
T_world_root: Float[Tensor, "*#batch 7"],
Rs_parent_joint: Float[Tensor, "*#batch joints 4"],
t_parent_joint: Float[Tensor, "*#batch joints 3"],
parent_indices: tuple[int, ...],
) -> Float[Tensor, "*#batch joints 7"]:
"""Run forward kinematics to compute absolute poses (T_world_joint) for
each joint. The output array contains pose parameters
(w, x, y, z, tx, ty, tz) for each joint. (this does not include the root!)
Args:
T_world_root: Transformation to world frame from root frame.
Rs_parent_joint: Local orientation of each joint.
t_parent_joint: Position of each joint with respect to its parent frame. (this does not
depend on local joint orientations)
parent_indices: Parent index for each joint. Index of -1 signifies that
a joint is defined relative to the root. We assume that this array is
sorted: parent joints should always precede child joints.
Returns:
Transformations to world frame from each joint frame.
"""
# Check shapes.
num_joints = len(parent_indices)
assert Rs_parent_joint.shape[-2:] == (num_joints, 4)
assert t_parent_joint.shape[-2:] == (num_joints, 3)
# Get relative transforms.
Ts_parent_child = broadcasting_cat([Rs_parent_joint, t_parent_joint], dim=-1)
assert Ts_parent_child.shape[-2:] == (num_joints, 7)
# Compute one joint at a time.
list_Ts_world_joint: list[Tensor] = []
for i in range(num_joints):
if parent_indices[i] == -1:
T_world_parent = T_world_root
else:
T_world_parent = list_Ts_world_joint[parent_indices[i]]
list_Ts_world_joint.append(
(SE3(T_world_parent) @ SE3(Ts_parent_child[..., i, :])).wxyz_xyz
)
Ts_world_joint = torch.stack(list_Ts_world_joint, dim=-2)
assert Ts_world_joint.shape[-2:] == (num_joints, 7)
return Ts_world_joint
def broadcasting_cat(tensors: list[Tensor], dim: int) -> Tensor:
"""Like torch.cat, but broadcasts."""
assert len(tensors) > 0
output_dims = max(map(lambda t: len(t.shape), tensors))
tensors = [
t.reshape((1,) * (output_dims - len(t.shape)) + t.shape) for t in tensors
]
max_sizes = [max(t.shape[i] for t in tensors) for i in range(output_dims)]
expanded_tensors = [
tensor.expand(
*(
tensor.shape[i] if i == dim % len(tensor.shape) else max_size
for i, max_size in enumerate(max_sizes)
)
)
for tensor in tensors
]
return torch.cat(expanded_tensors, dim=dim)
def _normalize_dtype(v: np.ndarray) -> np.ndarray:
"""Normalize datatypes; all arrays should be either int32 or float32."""
if "int" in str(v.dtype):
return v.astype(np.int32)
elif "float" in str(v.dtype):
return v.astype(np.float32)
else:
return v
================================================
FILE: src/egoallo/fncsmpl_extensions.py
================================================
"""EgoAllo-specific SMPL utilities."""
from __future__ import annotations
import numpy as np
import torch
from jaxtyping import Float
from torch import Tensor
from . import fncsmpl, transforms
def get_T_world_cpf(mesh: fncsmpl.SmplMesh) -> Float[Tensor, "*#batch 7"]:
"""Get the central pupil frame from a mesh. This assumes that we're using the SMPL-H model."""
assert mesh.verts.shape[-2:] == (6890, 3), "Not using SMPL-H model!"
right_eye = (mesh.verts[..., 6260, :] + mesh.verts[..., 6262, :]) / 2.0
left_eye = (mesh.verts[..., 2800, :] + mesh.verts[..., 2802, :]) / 2.0
# CPF is between the two eyes.
cpf_pos = (right_eye + left_eye) / 2.0
# Get orientation from head.
cpf_orientation = mesh.posed_model.Ts_world_joint[..., 14, :4]
return torch.cat([cpf_orientation, cpf_pos], dim=-1)
def get_T_head_cpf(shaped: fncsmpl.SmplhShaped) -> Float[Tensor, "*#batch 7"]:
"""Get the central pupil frame with respect to the head (joint 14). This
assumes that we're using the SMPL-H model."""
verts_zero = shaped.verts_zero
assert verts_zero.shape[-2:] == (6890, 3), "Not using SMPL-H model!"
right_eye = (verts_zero[..., 6260, :] + verts_zero[..., 6262, :]) / 2.0
left_eye = (verts_zero[..., 2800, :] + verts_zero[..., 2802, :]) / 2.0
# CPF is between the two eyes.
cpf_pos_wrt_head = (right_eye + left_eye) / 2.0 - shaped.joints_zero[..., 14, :]
return fncsmpl.broadcasting_cat(
[
transforms.SO3.identity(
device=cpf_pos_wrt_head.device, dtype=cpf_pos_wrt_head.dtype
).wxyz,
cpf_pos_wrt_head,
],
dim=-1,
)
def get_T_world_root_from_cpf_pose(
posed: fncsmpl.SmplhShapedAndPosed,
Ts_world_cpf: Float[Tensor | np.ndarray, "... 7"],
) -> Float[Tensor, "... 7"]:
"""Get the root transform that would align the CPF frame of `posed` to `Ts_world_cpf`."""
device = posed.Ts_world_joint.device
dtype = posed.Ts_world_joint.dtype
if isinstance(Ts_world_cpf, np.ndarray):
Ts_world_cpf = torch.from_numpy(Ts_world_cpf).to(device=device, dtype=dtype)
assert Ts_world_cpf.shape[-1] == 7
T_world_root = (
# T_world_cpf
transforms.SE3(Ts_world_cpf)
# T_cpf_head
@ transforms.SE3(get_T_head_cpf(posed.shaped_model)).inverse()
# T_head_world
@ transforms.SE3(posed.Ts_world_joint[..., 14, :]).inverse()
# T_world_root
@ transforms.SE3(posed.T_world_root)
)
return T_world_root.wxyz_xyz
================================================
FILE: src/egoallo/fncsmpl_jax.py
================================================
"""SMPL-H model, implemented in JAX.
Very little of it is specific to SMPL-H. This could very easily be adapted for other models in the SMPL family.
"""
from __future__ import annotations
from pathlib import Path
from typing import Sequence, cast
import jax
import jax_dataclasses as jdc
import jaxlie
import numpy as onp
from einops import einsum
from jax import Array
from jax import numpy as jnp
from jaxtyping import Float, Int
@jdc.pytree_dataclass
class SmplhModel:
"""The SMPL-H human body model."""
faces: Int[Array, "faces 3"]
"""Vertex indices for mesh faces."""
J_regressor: Float[Array, "joints+1 verts"]
"""Linear map from vertex to joint positions.
For SMPL-H, 1 root + 21 body joints + 2 * 15 hand joints."""
parent_indices: Int[Array, "joints"]
"""Defines kinematic tree. Index of -1 signifies that a joint is defined
relative to the root."""
weights: Float[Array, "verts joints+1"]
"""LBS weights."""
posedirs: Float[Array, "verts 3 joints*9"]
"""Pose blend shape bases."""
v_template: Float[Array, "verts 3"]
"""Canonical mesh verts."""
shapedirs: Float[Array, "verts 3 n_betas"]
"""Shape bases."""
@staticmethod
def load(npz_path: Path) -> SmplhModel:
smplh_params: dict[str, onp.ndarray] = onp.load(npz_path, allow_pickle=True)
# assert smplh_params["bs_style"].item() == b"lbs"
# assert smplh_params["bs_type"].item() == b"lrotmin"
smplh_params = {k: _normalize_dtype(v) for k, v in smplh_params.items()}
return SmplhModel(
faces=jnp.array(smplh_params["f"]),
J_regressor=jnp.array(smplh_params["J_regressor"]),
parent_indices=jnp.array(smplh_params["kintree_table"][0][1:] - 1),
weights=jnp.array(smplh_params["weights"]),
posedirs=jnp.array(smplh_params["posedirs"]),
v_template=jnp.array(smplh_params["v_template"]),
shapedirs=jnp.array(smplh_params["shapedirs"]),
)
def with_shape(
self, betas: Float[Array | onp.ndarray, "... n_betas"]
) -> SmplhShaped:
"""Compute a new body model, with betas applied. betas vector should
have shape up to (16,)."""
num_betas = betas.shape[-1]
assert num_betas <= 16
verts_with_shape = self.v_template + einsum(
self.shapedirs[:, :, :num_betas],
betas,
"verts xyz beta, ... beta -> ... verts xyz",
)
root_and_joints_pred = einsum(
self.J_regressor,
verts_with_shape,
"joints verts, ... verts xyz -> ... joints xyz",
)
root_offset = root_and_joints_pred[..., 0:1, :]
return SmplhShaped(
body_model=self,
verts_zero=verts_with_shape - root_offset,
joints_zero=root_and_joints_pred[..., 1:, :] - root_offset,
t_parent_joint=root_and_joints_pred[..., 1:, :]
- root_and_joints_pred[..., self.parent_indices + 1, :],
)
@jdc.pytree_dataclass
class SmplhShaped:
"""The SMPL-H body model with a body shape applied."""
body_model: SmplhModel
verts_zero: Float[Array, "verts 3"]
"""Vertices of shaped body _relative to the root joint_ at the zero
configuration."""
joints_zero: Float[Array, "joints 3"]
"""Joints of shaped body _relative to the root joint_ at the zero
configuration."""
t_parent_joint: Float[Array, "joints 3"]
"""Position of each shaped body joint relative to its parent. Does not
include root."""
def with_pose_decomposed(
self,
T_world_root: Float[Array | onp.ndarray, "7"],
body_quats: Float[Array | onp.ndarray, "21 4"],
left_hand_quats: Float[Array | onp.ndarray, "15 4"] | None = None,
right_hand_quats: Float[Array | onp.ndarray, "15 4"] | None = None,
) -> SmplhShapedAndPosed:
"""Pose our SMPL-H body model. Returns a set of joint and vertex outputs."""
if left_hand_quats is None:
left_hand_quats = jnp.zeros((15, 4)).at[:, 0].set(1.0)
if right_hand_quats is None:
right_hand_quats = jnp.zeros((15, 4)).at[:, 0].set(1.0)
local_quats = broadcasting_cat(
cast(list[jax.Array], [body_quats, left_hand_quats, right_hand_quats]),
axis=0,
)
assert local_quats.shape[-2:] == (51, 4)
return self.with_pose(T_world_root, local_quats)
def with_pose(
self,
T_world_root: Float[Array | onp.ndarray, "... 7"],
local_quats: Float[Array | onp.ndarray, "... num_joints 4"],
) -> SmplhShapedAndPosed:
"""Pose our SMPL-H body model. Returns a set of joint and vertex outputs."""
# Forward kinematics.
# assert local_quats.shape == (51, 4), local_quats.shape
parent_indices = self.body_model.parent_indices
(num_joints,) = parent_indices.shape[-1:]
num_active_joints, _ = local_quats.shape[-2:]
assert local_quats.shape[-1] == 4
assert num_active_joints <= num_joints
assert self.t_parent_joint.shape[-2:] == (num_joints, 3)
# Get relative transforms.
Ts_parent_child = broadcasting_cat(
[local_quats, self.t_parent_joint[..., :num_active_joints, :]], axis=-1
)
assert Ts_parent_child.shape[-2:] == (num_active_joints, 7)
# Compute one joint at a time.
def compute_joint(i: int, Ts_world_joint: Array) -> Array:
T_world_parent = jnp.where(
parent_indices[i] == -1,
T_world_root,
Ts_world_joint[..., parent_indices[i], :],
)
return Ts_world_joint.at[..., i, :].set(
(
jaxlie.SE3(T_world_parent) @ jaxlie.SE3(Ts_parent_child[..., i, :])
).wxyz_xyz
)
Ts_world_joint = jax.lax.fori_loop(
lower=0,
upper=num_joints,
body_fun=compute_joint,
init_val=jnp.zeros_like(Ts_parent_child),
)
assert Ts_world_joint.shape[-2:] == (num_active_joints, 7)
return SmplhShapedAndPosed(
shaped_model=self,
T_world_root=T_world_root, # type: ignore
local_quats=local_quats, # type: ignore
Ts_world_joint=Ts_world_joint,
)
def get_T_head_cpf(self) -> Float[Array, "7"]:
"""Get the central pupil frame with respect to the head (joint 14). This
assumes that we're using the SMPL-H model."""
assert self.verts_zero.shape[-2:] == (6890, 3), "Not using SMPL-H model!"
right_eye = (
self.verts_zero[..., 6260, :] + self.verts_zero[..., 6262, :]
) / 2.0
left_eye = (self.verts_zero[..., 2800, :] + self.verts_zero[..., 2802, :]) / 2.0
# CPF is between the two eyes.
cpf_pos_wrt_head = (right_eye + left_eye) / 2.0 - self.joints_zero[..., 14, :]
return broadcasting_cat([jaxlie.SO3.identity().wxyz, cpf_pos_wrt_head], axis=-1)
@jdc.pytree_dataclass
class SmplhShapedAndPosed:
shaped_model: SmplhShaped
"""Underlying shaped body model."""
T_world_root: Float[Array, "*#batch 7"]
"""Root coordinate frame."""
local_quats: Float[Array, "*#batch joints 4"]
"""Local joint orientations."""
Ts_world_joint: Float[Array, "joints 7"]
"""Absolute transform for each joint. Does not include the root."""
def with_new_T_world_root(
self, T_world_root: Float[Array, "*#batch 7"]
) -> SmplhShapedAndPosed:
return SmplhShapedAndPosed(
shaped_model=self.shaped_model,
T_world_root=T_world_root,
local_quats=self.local_quats,
Ts_world_joint=(
jaxlie.SE3(T_world_root[..., None, :])
@ jaxlie.SE3(self.T_world_root[..., None, :]).inverse()
@ jaxlie.SE3(self.Ts_world_joint)
).parameters(),
)
def lbs(self) -> SmplhMesh:
assert (
self.local_quats.shape[0]
== self.shaped_model.body_model.parent_indices.shape[0]
), (
"It looks like only a partial set of joint rotations was passed into `with_pose()`. We need all of them for LBS."
)
# Linear blend skinning with a pose blend shape.
verts_with_blend = self.shaped_model.verts_zero + einsum(
self.shaped_model.body_model.posedirs,
(jaxlie.SO3(self.local_quats).as_matrix() - jnp.eye(3)).flatten(),
"verts j joints_times_9, ... joints_times_9 -> ... verts j",
)
verts_transformed = einsum(
broadcasting_cat(
[
# (*, 1, 3, 4)
jaxlie.SE3(self.T_world_root).as_matrix()[..., None, :3, :],
# (*, 51, 3, 4)
jaxlie.SE3(self.Ts_world_joint).as_matrix()[..., :3, :],
],
axis=0,
),
self.shaped_model.body_model.weights,
jnp.pad(
verts_with_blend[:, None, :]
- jnp.concatenate(
[
jnp.zeros((1, 1, 3)), # Root joint.
self.shaped_model.joints_zero[None, :, :],
],
axis=1,
),
((0, 0), (0, 0), (0, 1)),
constant_values=1.0,
),
"joints_p1 i j, ... verts joints_p1, ... verts joints_p1 j -> ... verts i",
)
return SmplhMesh(
posed_model=self,
verts=verts_transformed,
faces=self.shaped_model.body_model.faces,
)
@jdc.pytree_dataclass
class SmplhMesh:
posed_model: SmplhShapedAndPosed
verts: Float[Array, "verts 3"]
"""Vertices for mesh."""
faces: Int[Array, "13776 3"]
"""Faces for mesh."""
def broadcasting_cat(arrays: Sequence[jax.Array | onp.ndarray], axis: int) -> jax.Array:
"""Like jnp.concatenate, but broadcasts leading axes."""
assert len(arrays) > 0
output_dims = max(map(lambda t: len(t.shape), arrays))
arrays = [t.reshape((1,) * (output_dims - len(t.shape)) + t.shape) for t in arrays]
max_sizes = [max(t.shape[i] for t in arrays) for i in range(output_dims)]
expanded_arrays = [
jnp.broadcast_to(
array,
tuple(
array.shape[i] if i == axis % len(array.shape) else max_size
for i, max_size in enumerate(max_sizes)
),
)
for array in arrays
]
return jnp.concatenate(expanded_arrays, axis=axis)
def _normalize_dtype(v: onp.ndarray) -> onp.ndarray:
"""Normalize datatypes; all arrays should be either int32 or float32."""
if "int" in str(v.dtype):
return v.astype(onp.int32)
elif "float" in str(v.dtype):
return v.astype(onp.float32)
else:
return v
================================================
FILE: src/egoallo/guidance_optimizer_jax.py
================================================
"""Optimize constraints using Levenberg-Marquardt."""
from __future__ import annotations
import os
from .hand_detection_structs import (
CorrespondedAriaHandWristPoseDetections,
CorrespondedHamerDetections,
)
# Need to play nice with PyTorch!
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
import dataclasses
import time
from functools import partial
from typing import Callable, Literal, Unpack, assert_never, cast
import jax
import jax_dataclasses as jdc
import jaxlie
import jaxls
import numpy as onp
import torch
from jax import numpy as jnp
from jaxtyping import Float, Int
from torch import Tensor
from . import fncsmpl, fncsmpl_jax, network
from .transforms._so3 import SO3
def do_guidance_optimization(
Ts_world_cpf: Float[Tensor, "time 7"],
traj: network.EgoDenoiseTraj,
body_model: fncsmpl.SmplhModel,
guidance_mode: GuidanceMode,
phase: Literal["inner", "post"],
hamer_detections: None | CorrespondedHamerDetections,
aria_detections: None | CorrespondedAriaHandWristPoseDetections,
verbose: bool,
) -> tuple[network.EgoDenoiseTraj, dict]:
"""Run an optimizer to apply foot contact constraints."""
assert traj.hand_rotmats is not None
guidance_params = JaxGuidanceParams.defaults(guidance_mode, phase)
start_time = time.time()
quats, debug_info = _optimize_vmapped(
body=fncsmpl_jax.SmplhModel(
faces=cast(jax.Array, body_model.faces.numpy(force=True)),
J_regressor=cast(jax.Array, body_model.J_regressor.numpy(force=True)),
parent_indices=cast(jax.Array, onp.array(body_model.parent_indices)),
weights=cast(jax.Array, body_model.weights.numpy(force=True)),
posedirs=cast(jax.Array, body_model.posedirs.numpy(force=True)),
v_template=cast(jax.Array, body_model.v_template.numpy(force=True)),
shapedirs=cast(jax.Array, body_model.shapedirs.numpy(force=True)),
),
Ts_world_cpf=cast(jax.Array, Ts_world_cpf.numpy(force=True)),
betas=cast(jax.Array, traj.betas.numpy(force=True)),
body_rotmats=cast(jax.Array, traj.body_rotmats.numpy(force=True)),
hand_rotmats=cast(jax.Array, traj.hand_rotmats.numpy(force=True)),
contacts=cast(jax.Array, traj.contacts.numpy(force=True)),
guidance_params=guidance_params,
# The hand detections are torch tensors in TensorDataclass form. We
# use dictionaries to convert them to pytrees.
hamer_detections=None
if hamer_detections is None
else hamer_detections.as_nested_dict(numpy=True),
aria_detections=None
if aria_detections is None
else aria_detections.as_nested_dict(numpy=True),
verbose=verbose,
)
rotmats = SO3(
torch.from_numpy(onp.array(quats))
.to(traj.body_rotmats.dtype)
.to(traj.body_rotmats.device)
).as_matrix()
print(f"Constraint optimization finished in {time.time() - start_time}sec")
return dataclasses.replace(
traj,
body_rotmats=rotmats[:, :, :21, :],
hand_rotmats=rotmats[:, :, 21:, :],
), debug_info
class _SmplhBodyPosesVar(
jaxls.Var[jax.Array],
default_factory=lambda: jnp.concatenate(
[jnp.ones((21, 1)), jnp.zeros((21, 3))], axis=-1
),
retract_fn=lambda val, delta: (
jaxlie.SO3(val) @ jaxlie.SO3.exp(delta.reshape(21, 3))
).wxyz,
tangent_dim=21 * 3,
):
"""Variable containing local joint poses for a SMPL-H human."""
class _SmplhSingleHandPosesVar(
jaxls.Var[jax.Array],
default_factory=lambda: jnp.concatenate(
[jnp.ones((15, 1)), jnp.zeros((15, 3))], axis=-1
),
retract_fn=lambda val, delta: (
jaxlie.SO3(val) @ jaxlie.SO3.exp(delta.reshape(15, 3))
).wxyz,
tangent_dim=15 * 3,
):
"""Variable containing local joint poses for one hand of a SMPL-H human."""
@jdc.jit
def _optimize_vmapped(
Ts_world_cpf: jax.Array,
body: fncsmpl_jax.SmplhModel,
betas: jax.Array,
body_rotmats: jax.Array,
hand_rotmats: jax.Array,
contacts: jax.Array,
guidance_params: JaxGuidanceParams,
hamer_detections: dict | None,
aria_detections: dict | None,
verbose: jdc.Static[bool],
) -> tuple[jax.Array, dict]:
return jax.vmap(
partial(
_optimize,
Ts_world_cpf=Ts_world_cpf,
body=body,
guidance_params=guidance_params,
hamer_detections=hamer_detections,
aria_detections=aria_detections,
verbose=verbose,
)
)(
betas=betas,
body_rotmats=body_rotmats,
hand_rotmats=hand_rotmats,
contacts=contacts,
)
# Modes for guidance.
GuidanceMode = Literal[
# Foot skating only.
"no_hands",
# Only use Aria wrist pose.
"aria_wrist_only",
# Use Aria wrist pose + HaMeR 3D estimates.
"aria_hamer",
# Use only HaMeR 3D estimates.
"hamer_wrist",
# Use HaMeR 3D estimates + reprojection.
"hamer_reproj2",
]
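# Example call (a rough sketch; the surrounding variables are placeholders that
# would normally come from the diffusion sampling loop):
#
#   traj, _debug = do_guidance_optimization(
#       Ts_world_cpf=Ts_world_cpf,
#       traj=traj,
#       body_model=body_model,
#       guidance_mode="aria_hamer",
#       phase="post",
#       hamer_detections=hamer_detections,
#       aria_detections=aria_detections,
#       verbose=False,
#   )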
@jdc.pytree_dataclass
class JaxGuidanceParams:
prior_quat_weight: float = 1.0
prior_pos_weight: float = 5.0
body_quat_vel_smoothness_weight: float = 5.0
body_quat_smoothness_weight: float = 1.0
body_quat_delta_smoothness_weight: float = 10.0
skate_weight: float = 30.0
# Note: this should be quite high. If the hand quaternions aren't
# constrained enough, the reprojection loss can get wild.
hand_quats: jdc.Static[bool] = True
hand_quat_weight = 5.0
hand_quat_priors: jdc.Static[bool] = True
hand_quat_prior_weight = 0.1
hand_quat_smoothness_weight = 1.0
hamer_reproj: jdc.Static[bool] = True
hand_reproj_weight: float = 1.0
hamer_wrist_pose: jdc.Static[bool] = True
hamer_abspos_weight: float = 20.0
hamer_ori_weight: float = 5.0
aria_wrists: jdc.Static[bool] = True
aria_wrist_pos_weight: float = 50.0
aria_wrist_ori_weight: float = 10.0
# Optimization parameters.
lambda_initial: float = 0.1
max_iters: jdc.Static[int] = 20
@staticmethod
def defaults(
mode: GuidanceMode,
phase: Literal["inner", "post"],
) -> JaxGuidanceParams:
if mode == "no_hands":
return {
"inner": JaxGuidanceParams(
hand_quats=False,
hand_quat_priors=False,
hamer_reproj=False,
hamer_wrist_pose=False,
aria_wrists=False,
max_iters=5,
),
"post": JaxGuidanceParams(
hand_quats=False,
hand_quat_priors=False,
hamer_reproj=False,
hamer_wrist_pose=False,
aria_wrists=False,
max_iters=20,
),
}[phase]
elif mode == "aria_wrist_only":
return {
"inner": JaxGuidanceParams(
hand_quats=False,
hand_quat_priors=True,
hamer_reproj=False,
hamer_wrist_pose=False,
aria_wrists=True,
max_iters=5,
),
"post": JaxGuidanceParams(
hand_quats=False,
hand_quat_priors=True,
hamer_reproj=False,
hamer_wrist_pose=False,
aria_wrists=True,
max_iters=20,
),
}[phase]
elif mode == "aria_hamer":
return {
"inner": JaxGuidanceParams(
hand_quats=True,
hand_quat_priors=True,
hamer_reproj=False,
hamer_wrist_pose=False,
aria_wrists=True,
max_iters=5,
),
"post": JaxGuidanceParams(
hand_quats=True,
hand_quat_priors=True,
hamer_reproj=False,
hamer_wrist_pose=False,
aria_wrists=True,
max_iters=20,
),
}[phase]
elif mode == "hamer_wrist":
return {
"inner": JaxGuidanceParams(
hand_quats=True,
hand_quat_priors=True,
# NOTE: we turn off reprojection during the inner loop optimization.
hamer_reproj=False,
hamer_wrist_pose=True,
aria_wrists=False,
max_iters=5,
),
"post": JaxGuidanceParams(
hand_quats=True,
hand_quat_priors=True,
# Reprojection stays off for this wrist-only mode.
hamer_reproj=False,
hamer_wrist_pose=True,
aria_wrists=False,
max_iters=20,
),
}[phase]
elif mode == "hamer_reproj2":
return {
"inner": JaxGuidanceParams(
hand_quats=True,
hand_quat_priors=True,
# NOTE: we turn off reprojection during the inner loop optimization.
hamer_reproj=False,
hamer_wrist_pose=True,
aria_wrists=False,
max_iters=5,
),
"post": JaxGuidanceParams(
hand_quats=True,
hand_quat_priors=True,
# Turn on reprojection.
hamer_reproj=True,
hamer_wrist_pose=True,
aria_wrists=False,
max_iters=20,
),
}[phase]
else:
assert_never(mode)
def _optimize(
Ts_world_cpf: jax.Array,
body: fncsmpl_jax.SmplhModel,
betas: jax.Array,
body_rotmats: jax.Array,
hand_rotmats: jax.Array,
contacts: jax.Array,
guidance_params: JaxGuidanceParams,
hamer_detections: dict | None,
aria_detections: dict | None,
verbose: bool,
) -> tuple[jax.Array, dict]:
"""Apply constraints using Levenberg-Marquardt optimizer. Returns updated
body_rotmats and hand_rotmats matrices."""
timesteps = body_rotmats.shape[0]
assert Ts_world_cpf.shape == (timesteps, 7)
assert body_rotmats.shape == (timesteps, 21, 3, 3)
assert hand_rotmats.shape == (timesteps, 30, 3, 3)
assert contacts.shape == (timesteps, 21)
assert betas.shape == (timesteps, 16)
init_quats = jaxlie.SO3.from_matrix(
# body_rotmats
jnp.concatenate([body_rotmats, hand_rotmats], axis=1)
).wxyz
assert init_quats.shape == (timesteps, 51, 4)
# Assume body shape is time-invariant.
shaped_body = body.with_shape(jnp.mean(betas, axis=0))
T_head_cpf = shaped_body.get_T_head_cpf()
T_cpf_head = jaxlie.SE3(T_head_cpf).inverse().parameters()
assert T_cpf_head.shape == (7,)
init_posed = shaped_body.with_pose(
jaxlie.SE3.identity(batch_axes=(timesteps,)).wxyz_xyz, init_quats
)
T_world_head = jaxlie.SE3(Ts_world_cpf) @ jaxlie.SE3(T_cpf_head)
T_root_head = jaxlie.SE3(init_posed.Ts_world_joint[:, 14])
init_posed = init_posed.with_new_T_world_root(
(T_world_head @ T_root_head.inverse()).wxyz_xyz
)
del T_world_head
del T_root_head
foot_joint_indices = jnp.array([6, 7, 9, 10])
num_foot_joints = foot_joint_indices.shape[0]
contacts = contacts[..., foot_joint_indices]
pairwise_contacts = (contacts[:-1, :] + contacts[1:, :]) / 2.0
assert pairwise_contacts.shape == (timesteps - 1, num_foot_joints)
del contacts
# We'll populate a list of factors (cost terms).
factors = list[jaxls.Cost]()
def cost_with_args[*CostArgs](
*args: Unpack[tuple[*CostArgs]],
) -> Callable[
[Callable[[jaxls.VarValues, *CostArgs], jax.Array]],
Callable[[jaxls.VarValues, *CostArgs], jax.Array],
]:
"""Decorator for appending to the factor list."""
def inner(
cost_func: Callable[[jaxls.VarValues, *CostArgs], jax.Array],
) -> Callable[[jaxls.VarValues, *CostArgs], jax.Array]:
factors.append(jaxls.Cost(cost_func, args))
return cost_func
return inner
def do_forward_kinematics(
vals: jaxls.VarValues,
var: _SmplhBodyPosesVar,
left_hand: _SmplhSingleHandPosesVar | None = None,
right_hand: _SmplhSingleHandPosesVar | None = None,
output_frame: Literal["world", "root"] = "world",
) -> fncsmpl_jax.SmplhShapedAndPosed:
"""Helper for computing forward kinematics from variables."""
assert (left_hand is None) == (right_hand is None)
if left_hand is None and right_hand is None:
posed = shaped_body.with_pose(
T_world_root=jaxlie.SE3.identity().wxyz_xyz,
local_quats=vals[var],
)
elif left_hand is not None and right_hand is None:
posed = shaped_body.with_pose(
T_world_root=jaxlie.SE3.identity().wxyz_xyz,
local_quats=jnp.concatenate([vals[var], vals[left_hand]], axis=-2),
)
elif left_hand is not None and right_hand is not None:
posed = shaped_body.with_pose(
T_world_root=jaxlie.SE3.identity().wxyz_xyz,
local_quats=jnp.concatenate(
[vals[var], vals[left_hand], vals[right_hand]], axis=-2
),
)
else:
assert False
if output_frame == "world":
T_world_root = (
# T_world_cpf
jaxlie.SE3(Ts_world_cpf[var.id, :])
# T_cpf_head
@ jaxlie.SE3(T_cpf_head)
# T_head_root
@ jaxlie.SE3(posed.Ts_world_joint[14]).inverse()
)
return posed.with_new_T_world_root(T_world_root.wxyz_xyz)
elif output_frame == "root":
return posed
# HaMeR pose cost.
if hamer_detections is not None and guidance_params.hand_quat_priors:
hamer_left = hamer_detections["detections_left_concat"]
hamer_right = hamer_detections["detections_right_concat"]
# HaMeR local quaternion smoothness.
@(
cost_with_args(
_SmplhSingleHandPosesVar(jnp.arange(timesteps * 2 - 2)),
_SmplhSingleHandPosesVar(jnp.arange(2, timesteps * 2)),
)
)
def hand_smoothness(
vals: jaxls.VarValues,
hand_pose: _SmplhSingleHandPosesVar,
hand_pose_next: _SmplhSingleHandPosesVar,
) -> jax.Array:
return (
guidance_params.hand_quat_smoothness_weight
* (
jaxlie.SO3(vals[hand_pose]).inverse()
@ jaxlie.SO3(vals[hand_pose_next])
)
.log()
.flatten()
)
# Hand prior loss.
@cost_with_args(
_SmplhSingleHandPosesVar(jnp.arange(timesteps * 2)),
init_quats[:, 21:51, :].reshape((timesteps * 2, 15, 4)),
)
def hand_prior(
vals: jaxls.VarValues,
hand_pose: _SmplhSingleHandPosesVar,
init_hand_quats: jax.Array,
) -> jax.Array:
return (
guidance_params.hand_quat_prior_weight
* (jaxlie.SO3(vals[hand_pose]).inverse() @ jaxlie.SO3(init_hand_quats))
.log()
.flatten()
)
if hamer_detections is not None and guidance_params.hand_quats:
hamer_left = hamer_detections["detections_left_concat"]
hamer_right = hamer_detections["detections_right_concat"]
# HaMeR local pose matching.
@(
cost_with_args(
_SmplhSingleHandPosesVar(hamer_left["indices"] * 2),
hamer_left["single_hand_quats"],
)
if hamer_left is not None
else lambda x: x
)
@(
cost_with_args(
_SmplhSingleHandPosesVar(hamer_right["indices"] * 2 + 1),
hamer_right["single_hand_quats"],
)
if hamer_right is not None
else lambda x: x
)
def hamer_local_pose_cost(
vals: jaxls.VarValues,
hand_pose: _SmplhSingleHandPosesVar,
estimated_hand_quats: jax.Array,
) -> jax.Array:
hand_quats = vals[hand_pose]
assert hand_quats.shape == estimated_hand_quats.shape
return guidance_params.hand_quat_weight * (
(jaxlie.SO3(hand_quats).inverse() @ jaxlie.SO3(estimated_hand_quats))
.log()
.flatten()
)
if hamer_detections is not None and (
guidance_params.hamer_reproj and guidance_params.hamer_wrist_pose
):
hamer_left = hamer_detections["detections_left_concat"]
hamer_right = hamer_detections["detections_right_concat"]
# HaMeR reprojection.
mano_from_openpose_indices = _get_mano_from_openpose_indices(include_tips=False)
@(
cost_with_args(
_SmplhBodyPosesVar(hamer_left["indices"]),
_SmplhSingleHandPosesVar(hamer_left["indices"] * 2),
_SmplhSingleHandPosesVar(hamer_left["indices"] * 2 + 1),
jnp.full_like(hamer_left["indices"], fill_value=0),
hamer_left["keypoints_3d"],
hamer_left["mano_hand_global_orient"],
)
if hamer_left is not None
else lambda x: x
)
@(
cost_with_args(
_SmplhBodyPosesVar(hamer_right["indices"]),
_SmplhSingleHandPosesVar(hamer_right["indices"] * 2),
_SmplhSingleHandPosesVar(hamer_right["indices"] * 2 + 1),
jnp.full_like(hamer_right["indices"], fill_value=1),
hamer_right["keypoints_3d"],
hamer_right["mano_hand_global_orient"],
)
if hamer_right is not None
else lambda x: x
)
def hamer_wrist_and_reproj(
vals: jaxls.VarValues,
body_pose: _SmplhBodyPosesVar,
left_hand_pose: _SmplhSingleHandPosesVar,
right_hand_pose: _SmplhSingleHandPosesVar,
left0_right1: jax.Array, # Set to 0 for left, 1 for right.
keypoints3d_wrt_cam: jax.Array, # These are in OpenPose order!!
Rmat_cam_wrist: jax.Array,
) -> jax.Array:
posed = do_forward_kinematics(
# The right hand comes _after_ the left hand, we can exclude it.
vals,
body_pose,
left_hand_pose,
right_hand_pose,
output_frame="root",
)
Ts_root_joint = posed.Ts_world_joint # Sorry for the naming...
del posed
# 19 for left wrist, 20 for right wrist.
wrist_index = 19 + left0_right1
hand_start_index = 21 + 15 * left0_right1
assert Ts_root_joint.shape == (51, 7)
joint_positions_wrt_root = Ts_root_joint[:, 4:7]
mano_joints_wrt_root = jnp.concatenate(
[
jax.lax.dynamic_slice_in_dim(
joint_positions_wrt_root,
start_index=wrist_index,
slice_size=1,
axis=-2,
),
jax.lax.dynamic_slice_in_dim(
joint_positions_wrt_root,
start_index=hand_start_index,
slice_size=15,
axis=-2,
),
],
axis=0,
)
assert mano_joints_wrt_root.shape == (16, 3)
assert keypoints3d_wrt_cam.shape == (21, 3) # In OpenPose.
T_cam_root = (
# T_cam_cpf (7,)
jaxlie.SE3(hamer_detections["T_cpf_cam"]).inverse()
# T_cpf_head (7,)
@ jaxlie.SE3(T_cpf_head)
# T_head_root (7,)
@ jaxlie.SE3(Ts_root_joint[14, :]).inverse()
)
assert T_cam_root.parameters().shape == (7,)
mano_joints_wrt_cam = T_cam_root @ mano_joints_wrt_root
obs_joints_wrt_cam = keypoints3d_wrt_cam[mano_from_openpose_indices, :]
mano_uv_wrt_cam = mano_joints_wrt_cam[:, :2] / mano_joints_wrt_cam[:, 2:3]
obs_uv_wrt_cam = obs_joints_wrt_cam[:, :2] / obs_joints_wrt_cam[:, 2:3]
T_cam_wrist = jaxlie.SE3.from_rotation_and_translation(
T_cam_root.rotation() @ jaxlie.SO3(Ts_root_joint[wrist_index, :4]),
mano_joints_wrt_cam[0, :],
)
obs_T_cam_wrist = jaxlie.SE3.from_rotation_and_translation(
jaxlie.SO3.from_matrix(Rmat_cam_wrist),
obs_joints_wrt_cam[0, :],
)
return jnp.concatenate(
[
(T_cam_wrist.inverse() @ obs_T_cam_wrist).log()
* jnp.array(
[guidance_params.hamer_abspos_weight] * 3
+ [guidance_params.hamer_ori_weight] * 3
),
guidance_params.hand_reproj_weight
* (mano_uv_wrt_cam - obs_uv_wrt_cam).flatten(),
]
)
elif (
hamer_detections is not None
and not guidance_params.hamer_reproj
and guidance_params.hamer_wrist_pose
):
hamer_left = hamer_detections["detections_left_concat"]
hamer_right = hamer_detections["detections_right_concat"]
@(
cost_with_args(
_SmplhBodyPosesVar(hamer_left["indices"]),
jnp.full_like(hamer_left["indices"], fill_value=0),
hamer_left["keypoints_3d"],
hamer_left["mano_hand_global_orient"],
)
if hamer_left is not None
else lambda x: x
)
@(
cost_with_args(
_SmplhBodyPosesVar(hamer_right["indices"]),
jnp.full_like(hamer_right["indices"], fill_value=1),
hamer_right["keypoints_3d"],
hamer_right["mano_hand_global_orient"],
)
if hamer_right is not None
else lambda x: x
)
def hamer_wrist_only(
vals: jaxls.VarValues,
body_pose: _SmplhBodyPosesVar,
left0_right1: jax.Array, # Set to 0 for left, 1 for right.
keypoints3d_wrt_cam: jax.Array, # These are in OpenPose order!!
Rmat_cam_wrist: jax.Array,
) -> jax.Array:
posed = do_forward_kinematics(vals, body_pose, output_frame="root")
Ts_root_joint = posed.Ts_world_joint # Sorry for the naming...
del posed
# 19 for left wrist, 20 for right wrist.
wrist_index = 19 + left0_right1
assert Ts_root_joint.shape == (21, 7)
wrist_position_wrt_root = Ts_root_joint[wrist_index, 4:7]
T_cam_root = (
# T_cam_cpf (7,)
jaxlie.SE3(hamer_detections["T_cpf_cam"]).inverse()
# T_cpf_head (7,)
@ jaxlie.SE3(T_cpf_head)
# T_head_root (7,)
@ jaxlie.SE3(Ts_root_joint[14, :]).inverse()
)
assert T_cam_root.parameters().shape == (7,)
wrist_position_wrt_cam = T_cam_root @ wrist_position_wrt_root
# Assumes OpenPose root is same as Mano root!!
wrist_pos_wrt_cam = keypoints3d_wrt_cam[0, :]
T_cam_wrist = jaxlie.SE3.from_rotation_and_translation(
T_cam_root.rotation() @ jaxlie.SO3(Ts_root_joint[wrist_index, :4]),
wrist_position_wrt_cam,
)
obs_T_cam_wrist = jaxlie.SE3.from_rotation_and_translation(
jaxlie.SO3.from_matrix(Rmat_cam_wrist),
wrist_pos_wrt_cam,
)
return (T_cam_wrist.inverse() @ obs_T_cam_wrist).log() * jnp.array(
[guidance_params.hamer_abspos_weight] * 3
+ [guidance_params.hamer_ori_weight] * 3
)
# Wrist pose cost.
if aria_detections is not None and guidance_params.aria_wrists:
aria_left = aria_detections["detections_left_concat"]
aria_right = aria_detections["detections_right_concat"]
@(
cost_with_args(
_SmplhBodyPosesVar(aria_left["indices"]),
aria_left["confidence"],
aria_left["wrist_position"],
aria_left["palm_position"],
aria_left["palm_normal"],
jnp.full_like(aria_left["indices"], fill_value=0),
)
if aria_left is not None
else lambda x: x
)
@(
cost_with_args(
_SmplhBodyPosesVar(aria_right["indices"]),
aria_right["confidence"],
aria_right["wrist_position"],
aria_right["palm_position"],
aria_right["palm_normal"],
jnp.full_like(aria_right["indices"], fill_value=1),
)
if aria_right is not None
else lambda x: x
)
def wrist_pose_cost(
vals: jaxls.VarValues,
pose: _SmplhBodyPosesVar,
confidence: jax.Array,
wrist_position: jax.Array,
palm_position: jax.Array,
palm_normal: jax.Array,
left0_right1: jax.Array, # Set to 0 for left, 1 for right.
) -> jax.Array:
assert wrist_position.shape == (3,)
assert left0_right1.shape == ()
posed = do_forward_kinematics(vals, pose)
T_world_wrist = posed.Ts_world_joint[19 + left0_right1]
pos_cost = (
# Left wrist is joint 19, right is joint 20.
T_world_wrist[4:7] - wrist_position
)
# Estimate wrist orientation from forward + normal directions.
palm_forward = palm_position - wrist_position
palm_forward = palm_forward / jnp.linalg.norm(palm_forward)
palm_normal = palm_normal / jnp.linalg.norm(palm_normal)
palm_forward = ( # Flip palm forward if right hand.
palm_forward * jnp.array([1, -1])[left0_right1]
)
palm_forward = ( # Gram-schmidt for forward direction.
palm_forward - jnp.dot(palm_forward, palm_normal) * palm_normal
)
estimatedR_world_wrist = jaxlie.SO3.from_matrix(
jnp.stack(
[
palm_forward,
-palm_normal,
jnp.cross(palm_normal, palm_forward),
],
axis=1,
)
)
R_world_wrist = jaxlie.SO3(T_world_wrist[:4])
ori_cost = (estimatedR_world_wrist.inverse() @ R_world_wrist).log()
return confidence * jnp.concatenate(
[
guidance_params.aria_wrist_pos_weight * pos_cost,
guidance_params.aria_wrist_ori_weight * ori_cost,
]
)
# Per-frame regularization cost.
@cost_with_args(
_SmplhBodyPosesVar(jnp.arange(timesteps)),
)
def reg_cost(
vals: jaxls.VarValues,
pose: _SmplhBodyPosesVar,
) -> jax.Array:
posed = do_forward_kinematics(vals, pose)
torso_indices = jnp.array([0, 1, 2, 5, 8])
return jnp.concatenate(
[
guidance_params.prior_quat_weight
* (
jaxlie.SO3(vals[pose]).inverse()
@ jaxlie.SO3(init_quats[pose.id, :21, :])
)
.log()
.flatten(),
# Only include some torso joints.
guidance_params.prior_pos_weight
* (
posed.Ts_world_joint[torso_indices, 4:7]
- init_posed.Ts_world_joint[pose.id, torso_indices, 4:7]
).flatten(),
]
)
@cost_with_args(
_SmplhBodyPosesVar(jnp.arange(timesteps - 1)),
_SmplhBodyPosesVar(jnp.arange(1, timesteps)),
)
def delta_smoothness_cost(
vals: jaxls.VarValues,
current: _SmplhBodyPosesVar,
next: _SmplhBodyPosesVar,
) -> jax.Array:
curdelt = jaxlie.SO3(vals[current]).inverse() @ jaxlie.SO3(
init_quats[current.id, :21, :]
)
nexdelt = jaxlie.SO3(vals[next]).inverse() @ jaxlie.SO3(
init_quats[next.id, :21, :]
)
return jnp.concatenate(
[
guidance_params.body_quat_delta_smoothness_weight
* (curdelt.inverse() @ nexdelt).log().flatten(),
guidance_params.body_quat_smoothness_weight
* (jaxlie.SO3(vals[current]).inverse() @ jaxlie.SO3(vals[next]))
.log()
.flatten(),
]
)
@cost_with_args(
_SmplhBodyPosesVar(jnp.arange(timesteps - 2)),
_SmplhBodyPosesVar(jnp.arange(1, timesteps - 1)),
_SmplhBodyPosesVar(jnp.arange(2, timesteps)),
)
def vel_smoothness_cost(
vals: jaxls.VarValues,
t0: _SmplhBodyPosesVar,
t1: _SmplhBodyPosesVar,
t2: _SmplhBodyPosesVar,
) -> jax.Array:
curdelt = jaxlie.SO3(vals[t0]).inverse() @ jaxlie.SO3(vals[t1])
nexdelt = jaxlie.SO3(vals[t1]).inverse() @ jaxlie.SO3(vals[t2])
return (
guidance_params.body_quat_vel_smoothness_weight
* (curdelt.inverse() @ nexdelt).log().flatten()
)
@cost_with_args(
_SmplhBodyPosesVar(jnp.arange(timesteps - 1)),
_SmplhBodyPosesVar(jnp.arange(1, timesteps)),
pairwise_contacts,
)
def skating_cost(
vals: jaxls.VarValues,
current: _SmplhBodyPosesVar,
next: _SmplhBodyPosesVar,
foot_contacts: jax.Array,
) -> jax.Array:
# Do forward kinematics.
posed_current = do_forward_kinematics(vals, current)
posed_next = do_forward_kinematics(vals, next)
footpos_current = posed_current.Ts_world_joint[foot_joint_indices, 4:7]
footpos_next = posed_next.Ts_world_joint[foot_joint_indices, 4:7]
assert footpos_current.shape == footpos_next.shape == (num_foot_joints, 3)
assert foot_contacts.shape == (num_foot_joints,)
return (
guidance_params.skate_weight
* (foot_contacts[:, None] * (footpos_current - footpos_next)).flatten()
)
vars_body_pose = _SmplhBodyPosesVar(jnp.arange(timesteps))
vars_hand_pose = _SmplhSingleHandPosesVar(jnp.arange(timesteps * 2))
graph = jaxls.LeastSquaresProblem(
costs=factors, variables=[vars_body_pose, vars_hand_pose]
).analyze()
solutions = graph.solve(
initial_vals=jaxls.VarValues.make(
[
vars_body_pose.with_value(init_quats[:, :21, :]),
vars_hand_pose.with_value(
init_quats[:, 21:51, :].reshape((timesteps * 2, 15, 4))
),
]
),
linear_solver="conjugate_gradient",
trust_region=jaxls.TrustRegionConfig(
lambda_initial=guidance_params.lambda_initial
),
termination=jaxls.TerminationConfig(max_iterations=guidance_params.max_iters),
verbose=verbose,
)
out_body_quats = solutions[_SmplhBodyPosesVar]
assert out_body_quats.shape == (timesteps, 21, 4)
out_hand_quats = solutions[_SmplhSingleHandPosesVar].reshape((timesteps, 30, 4))
assert out_hand_quats.shape == (timesteps, 30, 4)
return (
jnp.concatenate([out_body_quats, out_hand_quats], axis=-2),
{}, # Metadata dict that we use for debugging.
)
def _get_mano_from_openpose_indices(include_tips: bool) -> Int[onp.ndarray, "21"]:
# https://github.com/geopavlakos/hamer/blob/272d68f176e0ea8a506f761663dd3dca4a03ced0/hamer/models/mano_wrapper.py#L20
# fmt: off
mano_to_openpose = [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20]
# fmt: on
openpose_from_mano_idx = {
mano_idx: openpose_idx for openpose_idx, mano_idx in enumerate(mano_to_openpose)
}
return onp.array(
[openpose_from_mano_idx[i] for i in range(21 if include_tips else 16)]
)
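# Minimal sketch of how the index array above is meant to be used (illustrative only;
# `_example_mano_from_openpose_reindex` is not part of the original API). Indexing
# OpenPose-ordered keypoints with the returned array reorders them to MANO order,
# which is what the wrist costs above assume.
def _example_mano_from_openpose_reindex() -> None:
    # Hypothetical data: values 0..20 laid out in MANO order.
    mano_ordered = onp.arange(21)
    # Scatter into OpenPose order using the same mapping as in the helper above.
    mano_to_openpose = [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20]
    openpose_ordered = onp.empty(21, dtype=onp.int64)
    for openpose_idx, mano_idx in enumerate(mano_to_openpose):
        openpose_ordered[openpose_idx] = mano_ordered[mano_idx]
    # Gathering with the helper recovers MANO order, fingertips included.
    indices = _get_mano_from_openpose_indices(include_tips=True)
    assert onp.array_equal(openpose_ordered[indices], mano_ordered)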
================================================
FILE: src/egoallo/hand_detection_structs.py
================================================
"""Data structure definition that we use for hand detections.
We'll run HaMeR, produce the dictionary defined by `SavedHamerOutputs`, then
pickle this dictionary.
"""
from __future__ import annotations
import pickle
from pathlib import Path
from typing import Protocol, TypedDict, cast
import numpy as np
import torch
from jaxtyping import Float, Int
from projectaria_tools.core import mps
from projectaria_tools.core.mps.utils import get_nearest_wrist_and_palm_pose
from torch import Tensor
from .tensor_dataclass import TensorDataclass
from .transforms import SE3, SO3
class SingleHandHamerOutputWrtCamera(TypedDict):
"""Hand outputs with respect to the camera frame. For use in pickle files."""
verts: np.ndarray
keypoints_3d: np.ndarray
mano_hand_pose: np.ndarray
mano_hand_betas: np.ndarray
mano_hand_global_orient: np.ndarray
class SavedHamerOutputs(TypedDict):
"""Outputs from the HAMeR hand detection algorithm. This is the structure
to pickle.
`detections_left_wrt_cam` and `detections_right_wrt_cam` use nanosecond
timestamps as keys.
"""
mano_faces_right: np.ndarray
mano_faces_left: np.ndarray
detections_left_wrt_cam: dict[int, SingleHandHamerOutputWrtCamera | None]
detections_right_wrt_cam: dict[int, SingleHandHamerOutputWrtCamera | None]
T_device_cam: np.ndarray # wxyz_xyz
T_cpf_cam: np.ndarray # wxyz_xyz
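# Minimal sketch (illustrative only; array shapes are placeholders) of how the pickled
# structure above might be assembled.
def _example_saved_hamer_outputs(out_path: Path) -> None:
    single_detection: SingleHandHamerOutputWrtCamera = {
        "verts": np.zeros((1, 778, 3)),
        "keypoints_3d": np.zeros((1, 21, 3)),
        "mano_hand_pose": np.zeros((1, 15, 3, 3)),
        "mano_hand_betas": np.zeros((1, 10)),
        "mano_hand_global_orient": np.zeros((1, 3, 3)),
    }
    identity_wxyz_xyz = np.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
    outputs: SavedHamerOutputs = {
        "mano_faces_right": np.zeros((1538, 3), dtype=np.int64),
        "mano_faces_left": np.zeros((1538, 3), dtype=np.int64),
        # Keys are nanosecond timestamps; values are per-frame detections or None.
        "detections_left_wrt_cam": {1_000_000_000: single_detection, 2_000_000_000: None},
        "detections_right_wrt_cam": {1_000_000_000: None, 2_000_000_000: single_detection},
        "T_device_cam": identity_wxyz_xyz,  # wxyz_xyz
        "T_cpf_cam": identity_wxyz_xyz,  # wxyz_xyz
    }
    with open(out_path, "wb") as f:
        pickle.dump(outputs, f)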
class AriaHandWristPoseWrtWorld(TensorDataclass):
confidence: Float[Tensor, "n_detections"]
wrist_position: Float[Tensor, "n_detections 3"]
wrist_normal: Float[Tensor, "n_detections 3"]
palm_position: Float[Tensor, "n_detections 3"]
palm_normal: Float[Tensor, "n_detections 3"]
indices: Int[Tensor, "n_detections"]
class CorrespondedAriaHandWristPoseDetections(TensorDataclass):
detections_left_concat: AriaHandWristPoseWrtWorld | None
detections_right_concat: AriaHandWristPoseWrtWorld | None
@staticmethod
def load(
wrist_and_palm_poses_csv_path: Path,
target_timestamps_sec: tuple[float, ...],
Ts_world_device: Float[np.ndarray, "timesteps 7"],
) -> CorrespondedAriaHandWristPoseDetections:
# API from runtime inspection of `projectaria_tools` outputs.
class WristAndPalmNormals(Protocol):
wrist_normal_device: np.ndarray
palm_normal_device: np.ndarray
class OneSide(Protocol):
confidence: float
wrist_position_device: np.ndarray
palm_position_device: np.ndarray
wrist_and_palm_normal_device: WristAndPalmNormals
wp_poses = mps.hand_tracking.read_wrist_and_palm_poses(
str(wrist_and_palm_poses_csv_path)
)
detections_left = list[OneSide]()
detections_right = list[OneSide]()
indices_left = list[int]()
indices_right = list[int]()
for i, time_sec in enumerate(target_timestamps_sec):
wp_pose = get_nearest_wrist_and_palm_pose(wp_poses, int(time_sec * 1e9))
if (
wp_pose is None
or abs(wp_pose.tracking_timestamp.total_seconds() - time_sec)
>= 1.0 / 30.0
):
continue
if wp_pose.left_hand is not None and wp_pose.left_hand.confidence > 0.7:
indices_left.append(i)
detections_left.append(wp_pose.left_hand)
if wp_pose.right_hand is not None and wp_pose.right_hand.confidence > 0.7:
indices_right.append(i)
detections_right.append(wp_pose.right_hand)
def form_detections_concat(
detections: list[OneSide], indices: list[int]
) -> AriaHandWristPoseWrtWorld | None:
assert len(detections) == len(indices)
if len(indices) == 0:
return None
Tslice_world_device = SE3(
torch.from_numpy(Ts_world_device[np.array(indices), :]).to(
dtype=torch.float32
)
)
Rslice_world_device = SO3(
torch.from_numpy(Ts_world_device[np.array(indices), :4]).to(
dtype=torch.float32
)
)
return AriaHandWristPoseWrtWorld(
confidence=torch.from_numpy(
np.array([d.confidence for d in detections])
),
wrist_position=Tslice_world_device
@ torch.from_numpy(
np.array(
[d.wrist_position_device for d in detections], dtype=np.float32
)
),
wrist_normal=Rslice_world_device
@ torch.from_numpy(
np.array(
[
d.wrist_and_palm_normal_device.wrist_normal_device
for d in detections
],
dtype=np.float32,
)
),
palm_position=Tslice_world_device
@ torch.from_numpy(
np.array(
[d.palm_position_device for d in detections], dtype=np.float32
)
),
palm_normal=Rslice_world_device
@ torch.from_numpy(
np.array(
[
d.wrist_and_palm_normal_device.palm_normal_device
for d in detections
],
dtype=np.float32,
)
),
indices=torch.from_numpy(np.array(indices, dtype=np.int64)),
)
return CorrespondedAriaHandWristPoseDetections(
detections_left_concat=form_detections_concat(
detections_left, indices_left
),
detections_right_concat=form_detections_concat(
detections_right, indices_right
),
)
class SingleHandHamerOutputWrtCameraConcatenated(TensorDataclass):
verts: Float[Tensor, "n_detections n_verts 3"]
keypoints_3d: Float[Tensor, "n_detections n_keypoints 3"]
mano_hand_global_orient: Float[Tensor, "n_detections 3 3"]
single_hand_quats: Float[Tensor, "n_detections 15 4"]
indices: Int[Tensor, "n_detections"]
class CorrespondedHamerDetections(TensorDataclass):
mano_faces_right: Tensor
mano_faces_left: Tensor
detections_left_tuple: tuple[None | SingleHandHamerOutputWrtCamera, ...]
detections_right_tuple: tuple[None | SingleHandHamerOutputWrtCamera, ...]
T_cpf_cam: Tensor
focal_length: float
# Concatenated detections will be None if there are no detections at all.
detections_left_concat: None | SingleHandHamerOutputWrtCameraConcatenated
detections_right_concat: None | SingleHandHamerOutputWrtCameraConcatenated
def get_length(self) -> int:
assert len(self.detections_left_tuple) == len(self.detections_right_tuple)
return len(self.detections_left_tuple)
def slice(self, start_index: int, end_index: int) -> CorrespondedHamerDetections:
"""Slice the hand detections. Removes unused hand detections, and
shifts indices as necessary."""
assert start_index < end_index
def _get_detections_in_window(
detections_side_concat: None | SingleHandHamerOutputWrtCameraConcatenated,
) -> None | SingleHandHamerOutputWrtCameraConcatenated:
if detections_side_concat is None:
return None
else:
indices = detections_side_concat.indices
indices_mask = (indices >= start_index) & (indices < end_index)
out = detections_side_concat.map(lambda x: x[indices_mask].clone())
out.indices -= start_index
return out
return CorrespondedHamerDetections(
self.mano_faces_right,
self.mano_faces_left,
self.detections_left_tuple[start_index:end_index],
self.detections_right_tuple[start_index:end_index],
T_cpf_cam=self.T_cpf_cam,
focal_length=self.focal_length,
detections_left_concat=_get_detections_in_window(
self.detections_left_concat
),
detections_right_concat=_get_detections_in_window(
self.detections_right_concat
),
)
@staticmethod
def load(
hand_pkl_path: Path,
target_timestamps_sec: tuple[float, ...],
) -> CorrespondedHamerDetections:
"""Helper which takes as input:
(1) A path to a pickle file containing hand detections through time.
See feb25_hamer_outputs_from_vrs.py for how this is generated.
(2) A set of target timestamps, sorted, in seconds.
We then output a data structure that has hand detections (or `None`) for each target timestamp.
"""
with open(hand_pkl_path, "rb") as f:
hamer_out = cast(SavedHamerOutputs, pickle.load(f))
def match_detections_to_targets(
detections_wrt_cam: dict[int, None | SingleHandHamerOutputWrtCamera],
) -> list[None | SingleHandHamerOutputWrtCamera]:
# Approximate the frame rate of the detections.
est_fps = len(detections_wrt_cam) / (
(max(detections_wrt_cam.keys()) - min(detections_wrt_cam.keys())) / 1e9
)
# Usually the framerate is either 10 FPS or 30 FPS. We might want to
# run on 1 FPS video in the future; we can tweak this assert if we
# run into that...
assert 5 < est_fps < 40
# Get nanosecond timestamps within our target timestamp window.
# Note that input dictionary keys are nanosecond timestamps!
detect_ns = sorted(
[
time_ns
for time_ns in detections_wrt_cam.keys()
if time_ns / 1e9 >= target_timestamps_sec[0] - 1 / est_fps
and time_ns / 1e9 <= target_timestamps_sec[-1] + 1 / est_fps
]
)
delta_matrix = np.abs(
np.array(target_timestamps_sec)[:, None]
- np.array(detect_ns)[None, :] / 1e9
)
# For each target, which is the closest detection?
best_det_from_target = np.argmin(delta_matrix, axis=-1)
# For each detection, which is the closest target?
best_target_from_det = np.argmin(delta_matrix, axis=0)
# Get detection list; we do a cycle-consistency check to make sure
# we get a 1-to-1 mapping.
out: list[None | SingleHandHamerOutputWrtCamera] = []
for i in range(len(target_timestamps_sec)):
if best_target_from_det[best_det_from_target[i]] == i:
out.append(detections_wrt_cam[detect_ns[best_det_from_target[i]]])
else:
out.append(None)
return out
detections_left = match_detections_to_targets(
hamer_out["detections_left_wrt_cam"]
)
detections_right = match_detections_to_targets(
hamer_out["detections_right_wrt_cam"]
)
assert (
len(detections_left) == len(detections_right) == len(target_timestamps_sec)
)
def make_concat_detections(
detections_side: list[None | SingleHandHamerOutputWrtCamera],
detections_other_side: list[None | SingleHandHamerOutputWrtCamera],
) -> None | SingleHandHamerOutputWrtCameraConcatenated:
detections_side_concat = None
# Filter out HaMeR detections that are in the same location as each
# other.
# Sometimes we have a left detection and a right detection both in
# the same location. This filters both out.
detections_side_filtered: list[None | SingleHandHamerOutputWrtCamera] = []
for i, d in enumerate(detections_side):
if d is None:
detections_side_filtered.append(None)
continue
num_d = d["verts"].shape[0]
keep_mask = np.ones(d["keypoints_3d"].shape[0], dtype=bool)
for offset in range(-15, 15):
i_offset = i + offset
if i_offset < 0 or i_offset >= len(detections_other_side):
continue
d_other = detections_other_side[i_offset]
if d_other is None:
# detections_side_filtered.append(d)
continue
num_d_other = d_other["verts"].shape[0]
dist_matrix = np.linalg.norm(
d["keypoints_3d"][:, None, 0, :]
- d_other["keypoints_3d"][None, :, 0, :],
axis=-1,
)
assert dist_matrix.shape == (num_d, num_d_other)
keep_mask = np.logical_and(
keep_mask, np.all(dist_matrix > 0.1, axis=-1)
)
if keep_mask.sum() == 0:
detections_side_filtered.append(None)
else:
detections_side_filtered.append(
cast(
SingleHandHamerOutputWrtCamera,
{k: cast(np.ndarray, v)[keep_mask] for k, v in d.items()},
)
)
del detections_side
detections_side_not_none = [d is not None for d in detections_side_filtered]
if not any(detections_side_not_none):
return None
(valid_detection_indices,) = np.where(detections_side_not_none)
# We should be done with these.
del detections_side_not_none
del detections_other_side
detections_side_concat = SingleHandHamerOutputWrtCameraConcatenated(
verts=torch.from_numpy(
np.stack(
# Currently: we always just take the first hand detection.
[
d["verts"][0]
for d in detections_side_filtered
if d is not None
]
)
).to(torch.float32),
keypoints_3d=torch.from_numpy(
np.stack(
[
# Currently: we always just take the first hand detection.
d["keypoints_3d"][0]
for d in detections_side_filtered
if d is not None
]
)
).to(torch.float32),
mano_hand_global_orient=torch.from_numpy(
np.stack(
[
# Currently: we always just take the first hand detection.
d["mano_hand_global_orient"][0]
for d in detections_side_filtered
if d is not None
]
)
).to(torch.float32),
single_hand_quats=SO3.from_matrix(
torch.from_numpy(
np.stack(
[
# Currently: we always just take the first hand detection.
d["mano_hand_pose"][0]
for d in detections_side_filtered
if d is not None
]
)
).to(torch.float32)
).wxyz,
indices=torch.from_numpy(valid_detection_indices),
)
return detections_side_concat
return CorrespondedHamerDetections(
mano_faces_right=torch.from_numpy(
hamer_out["mano_faces_right"].astype(np.int64)
),
mano_faces_left=torch.from_numpy(
hamer_out["mano_faces_left"].astype(np.int64)
),
detections_left_tuple=tuple(detections_left),
detections_right_tuple=tuple(detections_right),
T_cpf_cam=torch.from_numpy(hamer_out["T_cpf_cam"]).to(torch.float32),
focal_length=450,
detections_left_concat=make_concat_detections(
detections_left, detections_right
),
detections_right_concat=make_concat_detections(
detections_right, detections_left
),
)
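# Minimal standalone sketch (illustrative only) of the cycle-consistency matching used
# in `match_detections_to_targets` above: a detection is assigned to a target timestamp
# only if each is the other's nearest neighbor, which yields a 1-to-1 mapping.
def _example_cycle_consistent_matching() -> None:
    target_sec = np.array([0.0, 0.1, 0.2, 0.3])
    detect_sec = np.array([0.01, 0.11, 0.32])  # No detection near t=0.2.
    delta = np.abs(target_sec[:, None] - detect_sec[None, :])
    best_det_from_target = np.argmin(delta, axis=-1)
    best_target_from_det = np.argmin(delta, axis=0)
    matches = [
        int(best_det_from_target[i])
        if best_target_from_det[best_det_from_target[i]] == i
        else None
        for i in range(len(target_sec))
    ]
    # Target t=0.2 has no mutual nearest neighbor, so it gets None.
    assert matches == [0, 1, None, 2]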
================================================
FILE: src/egoallo/inference_utils.py
================================================
"""Functions that are useful for inference scripts."""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import torch
import yaml
from jaxtyping import Float
from projectaria_tools.core import mps # type: ignore
from projectaria_tools.core.data_provider import create_vrs_data_provider
from safetensors import safe_open
from torch import Tensor
from .network import EgoDenoiser, EgoDenoiserConfig
from .tensor_dataclass import TensorDataclass
from .transforms import SE3
def load_denoiser(checkpoint_dir: Path) -> EgoDenoiser:
"""Load a denoiser model."""
checkpoint_dir = checkpoint_dir.absolute()
experiment_dir = checkpoint_dir.parent
config = yaml.load(
(experiment_dir / "model_config.yaml").read_text(), Loader=yaml.Loader
)
assert isinstance(config, EgoDenoiserConfig)
model = EgoDenoiser(config)
with safe_open(checkpoint_dir / "model.safetensors", framework="pt") as f: # type: ignore
state_dict = {k: f.get_tensor(k) for k in f.keys()}
model.load_state_dict(state_dict)
return model
@dataclass(frozen=True)
class InferenceTrajectoryPaths:
"""Paths for running EgoAllo on a single sequence from Project Aria.
Our basic assumptions here are:
1. VRS file for images: there is exactly one VRS file in the trajectory root directory.
2. Aria MPS point cloud: there is either one semidense_points.csv.gz file or one global_points.csv.gz file.
- Its parent directory should contain other Aria MPS artifacts (like poses).
- This is optionally used for guidance.
3. HaMeR outputs: The hamer_outputs.pkl file may or may not exist in the trajectory root directory.
- This is optionally used for guidance.
4. Aria MPS wrist/palm poses: There may be zero or one wrist_and_palm_poses.csv file.
- This is optionally used for guidance.
5. Scene splat/ply file: There may be a splat.ply or scene.splat file.
- This is only used for visualization.
"""
vrs_file: Path
slam_root_dir: Path
points_path: Path
hamer_outputs: Path | None
wrist_and_palm_poses_csv: Path | None
splat_path: Path | None
@staticmethod
def find(traj_root: Path) -> InferenceTrajectoryPaths:
vrs_files = tuple(traj_root.glob("**/*.vrs"))
assert len(vrs_files) == 1, f"Found {len(vrs_files)} VRS files!"
points_paths = tuple(traj_root.glob("**/semidense_points.csv.gz"))
assert len(points_paths) <= 1, f"Found multiple points files! {points_paths}"
if len(points_paths) == 0:
points_paths = tuple(traj_root.glob("**/global_points.csv.gz"))
assert len(points_paths) == 1, f"Found {len(points_paths)} files!"
hamer_outputs = traj_root / "hamer_outputs.pkl"
if not hamer_outputs.exists():
hamer_outputs = None
wrist_and_palm_poses_csv = tuple(traj_root.glob("**/wrist_and_palm_poses.csv"))
if len(wrist_and_palm_poses_csv) == 0:
wrist_and_palm_poses_csv = None
else:
assert len(wrist_and_palm_poses_csv) == 1, (
"Found multiple wrist and palm poses files!"
)
splat_path = traj_root / "splat.ply"
if not splat_path.exists():
splat_path = traj_root / "scene.splat"
if not splat_path.exists():
print("No scene splat found.")
splat_path = None
else:
print("Found splat at", splat_path)
return InferenceTrajectoryPaths(
vrs_file=vrs_files[0],
slam_root_dir=points_paths[0].parent,
points_path=points_paths[0],
hamer_outputs=hamer_outputs,
wrist_and_palm_poses_csv=wrist_and_palm_poses_csv[0]
if wrist_and_palm_poses_csv
else None,
splat_path=splat_path,
)
class InferenceInputTransforms(TensorDataclass):
"""Some relevant transforms for inference."""
Ts_world_cpf: Float[Tensor, "timesteps 7"]
Ts_world_device: Float[Tensor, "timesteps 7"]
pose_timesteps: tuple[float, ...]
@staticmethod
def load(
vrs_path: Path,
slam_root_dir: Path,
fps: int = 30,
) -> InferenceInputTransforms:
"""Read some useful transforms via MPS + the VRS calibration."""
# Read device poses.
closed_loop_path = slam_root_dir / "closed_loop_trajectory.csv"
if not closed_loop_path.exists():
# Aria digital twins.
closed_loop_path = slam_root_dir / "aria_trajectory.csv"
closed_loop_traj = mps.read_closed_loop_trajectory(str(closed_loop_path)) # type: ignore
provider = create_vrs_data_provider(str(vrs_path))
device_calib = provider.get_device_calibration()
T_device_cpf = device_calib.get_transform_device_cpf().to_matrix()
# Get downsampled CPF frames.
aria_fps = len(closed_loop_traj) / (
closed_loop_traj[-1].tracking_timestamp.total_seconds()
- closed_loop_traj[0].tracking_timestamp.total_seconds()
)
num_poses = len(closed_loop_traj)
print(f"Loaded {num_poses=} with {aria_fps=}, visualizing at {fps=}")
Ts_world_device = []
Ts_world_cpf = []
out_timestamps_secs = []
for i in range(0, num_poses, int(aria_fps // fps)):
T_world_device = closed_loop_traj[i].transform_world_device.to_matrix()
assert T_world_device.shape == (4, 4)
Ts_world_device.append(T_world_device)
Ts_world_cpf.append(T_world_device @ T_device_cpf)
out_timestamps_secs.append(
closed_loop_traj[i].tracking_timestamp.total_seconds()
)
return InferenceInputTransforms(
Ts_world_device=SE3.from_matrix(torch.from_numpy(np.array(Ts_world_device)))
.parameters()
.to(torch.float32),
Ts_world_cpf=SE3.from_matrix(torch.from_numpy(np.array(Ts_world_cpf)))
.parameters()
.to(torch.float32),
pose_timesteps=tuple(out_timestamps_secs),
)
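# Minimal end-to-end sketch (illustrative only; the trajectory directory is a
# hypothetical placeholder) of how the two helpers above are typically combined.
def _example_load_inference_inputs() -> None:
    traj_root = Path("/path/to/aria_recording")  # Hypothetical directory.
    traj_paths = InferenceTrajectoryPaths.find(traj_root)
    transforms = InferenceInputTransforms.load(
        traj_paths.vrs_file, traj_paths.slam_root_dir, fps=30
    )
    # One CPF pose (wxyz_xyz) per downsampled timestep.
    assert transforms.Ts_world_cpf.shape == (len(transforms.pose_timesteps), 7)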
================================================
FILE: src/egoallo/metrics_helpers.py
================================================
from typing import Literal, overload
import numpy as np
import torch
from jaxtyping import Float
from torch import Tensor
from typing_extensions import assert_never
from .transforms import SO3
def compute_foot_skate(
pred_Ts_world_joint: Float[Tensor, "num_samples time 21 7"],
) -> np.ndarray:
(num_samples, time) = pred_Ts_world_joint.shape[:2]
# Drop the person to the floor.
# This is necessary for the foot skating metric to make sense for floating people...!
pred_Ts_world_joint = pred_Ts_world_joint.clone()
pred_Ts_world_joint[..., 6] -= torch.min(pred_Ts_world_joint[..., 6])
foot_indices = torch.tensor([6, 7, 9, 10], device=pred_Ts_world_joint.device)
foot_positions = pred_Ts_world_joint[:, :, foot_indices, 4:7]
foot_positions_diff = foot_positions[:, 1:, :, :2] - foot_positions[:, :-1, :, :2]
assert foot_positions_diff.shape == (num_samples, time - 1, 4, 2)
foot_positions_diff_norm = torch.sum(torch.abs(foot_positions_diff), dim=-1)
assert foot_positions_diff_norm.shape == (num_samples, time - 1, 4)
# From EgoEgo / kinpoly.
H_thresh = torch.tensor(
# To match indices above: (ankle, ankle, toe, toe)
[0.08, 0.08, 0.04, 0.04],
device=pred_Ts_world_joint.device,
dtype=torch.float32,
)
# Threshold.
foot_positions_diff_norm = foot_positions_diff_norm * (
foot_positions[..., 1:, :, 2] < H_thresh
)
fs_per_sample = torch.sum(
torch.sum(
foot_positions_diff_norm
* (2 - 2 ** (foot_positions[..., 1:, :, 2] / H_thresh)),
dim=-1,
),
dim=-1,
)
assert fs_per_sample.shape == (num_samples,)
return fs_per_sample.numpy(force=True)
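# Worked example of the height weighting above (illustrative): a sliding foot at
# ground level (h = 0) keeps its full weight, 2 - 2**(0 / H_thresh) = 1, while one
# exactly at the threshold height is zeroed out, 2 - 2**(H_thresh / H_thresh) = 0;
# heights above the threshold are already removed by the `< H_thresh` mask.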
def compute_foot_contact(
pred_Ts_world_joint: Float[Tensor, "num_samples time 21 7"],
) -> np.ndarray:
(num_samples, time) = pred_Ts_world_joint.shape[:2]
foot_indices = torch.tensor([6, 7, 9, 10], device=pred_Ts_world_joint.device)
# From EgoEgo / kinpoly.
H_thresh = torch.tensor(
# To match indices above: (ankle, ankle, toe, toe)
[0.08, 0.08, 0.04, 0.04],
device=pred_Ts_world_joint.device,
dtype=torch.float32,
)
foot_positions = pred_Ts_world_joint[:, :, foot_indices, 4:7]
any_contact = torch.any(
torch.any(foot_positions[..., 2] < H_thresh, dim=-1), dim=-1
).to(torch.float32)
assert any_contact.shape == (num_samples,)
return any_contact.numpy(force=True)
def compute_head_ori(
label_Ts_world_joint: Float[Tensor, "time 21 7"],
pred_Ts_world_joint: Float[Tensor, "num_samples time 21 7"],
) -> np.ndarray:
(num_samples, time) = pred_Ts_world_joint.shape[:2]
matrix_errors = (
SO3(pred_Ts_world_joint[:, :, 14, :4]).as_matrix()
@ SO3(label_Ts_world_joint[:, 14, :4]).inverse().as_matrix()
) - torch.eye(3, device=label_Ts_world_joint.device)
assert matrix_errors.shape == (num_samples, time, 3, 3)
return torch.mean(
torch.linalg.norm(matrix_errors.reshape((num_samples, time, 9)), dim=-1),
dim=-1,
).numpy(force=True)
def compute_head_trans(
label_Ts_world_joint: Float[Tensor, "time 21 7"],
pred_Ts_world_joint: Float[Tensor, "num_samples time 21 7"],
) -> np.ndarray:
(num_samples, time) = pred_Ts_world_joint.shape[:2]
errors = pred_Ts_world_joint[:, :, 14, 4:7] - label_Ts_world_joint[:, 14, 4:7]
assert errors.shape == (num_samples, time, 3)
return torch.mean(
torch.linalg.norm(errors, dim=-1),
dim=-1,
).numpy(force=True)
def compute_mpjpe(
label_T_world_root: Float[Tensor, "time 7"],
label_Ts_world_joint: Float[Tensor, "time 21 7"],
pred_T_world_root: Float[Tensor, "num_samples time 7"],
pred_Ts_world_joint: Float[Tensor, "num_samples time 21 7"],
per_frame_procrustes_align: bool,
) -> np.ndarray:
num_samples, time, _, _ = pred_Ts_world_joint.shape
# Concatenate the world root to the joints.
label_Ts_world_joint = torch.cat(
[label_T_world_root[..., None, :], label_Ts_world_joint], dim=-2
)
pred_Ts_world_joint = torch.cat(
[pred_T_world_root[..., None, :], pred_Ts_world_joint], dim=-2
)
del label_T_world_root, pred_T_world_root
pred_joint_positions = pred_Ts_world_joint[:, :, :, 4:7]
label_joint_positions = label_Ts_world_joint[None, :, :, 4:7].repeat(
num_samples, 1, 1, 1
)
if per_frame_procrustes_align:
pred_joint_positions = procrustes_align(
points_y=pred_joint_positions,
points_x=label_joint_positions,
output="aligned_x",
)
position_differences = pred_joint_positions - label_joint_positions
assert position_differences.shape == (num_samples, time, 22, 3)
# Per-joint position errors, in millimeters.
pjpe = torch.linalg.norm(position_differences, dim=-1) * 1000.0
assert pjpe.shape == (num_samples, time, 22)
# Mean per-joint position errors.
mpjpe = torch.mean(pjpe.reshape((num_samples, -1)), dim=-1)
assert mpjpe.shape == (num_samples,)
return mpjpe.cpu().numpy()
@overload
def procrustes_align(
points_y: Float[Tensor, "*#batch N 3"],
points_x: Float[Tensor, "*#batch N 3"],
output: Literal["transforms"],
fix_scale: bool = False,
) -> tuple[Tensor, Tensor, Tensor]: ...
@overload
def procrustes_align(
points_y: Float[Tensor, "*#batch N 3"],
points_x: Float[Tensor, "*#batch N 3"],
output: Literal["aligned_x"],
fix_scale: bool = False,
) -> Tensor: ...
def procrustes_align(
points_y: Float[Tensor, "*#batch N 3"],
points_x: Float[Tensor, "*#batch N 3"],
output: Literal["transforms", "aligned_x"],
fix_scale: bool = False,
) -> tuple[Tensor, Tensor, Tensor] | Tensor:
"""Similarity transform alignment using the Umeyama method. Adapted from
SLAHMR: https://github.com/vye16/slahmr/blob/main/slahmr/geometry/pcl.py
Minimizes:
mean( || Y - (s * R @ X + t) ||^2 )
with respect to s, R, and t.
Returns an (s, R, t) tuple when `output` is "transforms", or the aligned X points
when `output` is "aligned_x".
"""
*dims, N, _ = points_y.shape
device = points_y.device
N = torch.ones((*dims, 1, 1), device=device) * N
# subtract mean
my = points_y.sum(dim=-2) / N[..., 0] # (*, 3)
mx = points_x.sum(dim=-2) / N[..., 0]
y0 = points_y - my[..., None, :] # (*, N, 3)
x0 = points_x - mx[..., None, :]
# correlation
C = torch.matmul(y0.transpose(-1, -2), x0) / N # (*, 3, 3)
U, D, Vh = torch.linalg.svd(C) # (*, 3, 3), (*, 3), (*, 3, 3)
S = (
torch.eye(3, device=device)
.reshape(*(1,) * (len(dims)), 3, 3)
.repeat(*dims, 1, 1)
)
neg = torch.det(U) * torch.det(Vh.transpose(-1, -2)) < 0
S = torch.where(
neg.reshape(*dims, 1, 1),
S * torch.diag(torch.tensor([1, 1, -1], device=device)),
S,
)
R = torch.matmul(U, torch.matmul(S, Vh)) # (*, 3, 3)
D = torch.diag_embed(D) # (*, 3, 3)
if fix_scale:
s = torch.ones(*dims, 1, device=device, dtype=torch.float32)
else:
var = torch.sum(torch.square(x0), dim=(-1, -2), keepdim=True) / N # (*, 1, 1)
s = (
torch.diagonal(torch.matmul(D, S), dim1=-2, dim2=-1).sum(
dim=-1, keepdim=True
)
/ var[..., 0]
) # (*, 1)
t = my - s * torch.matmul(R, mx[..., None])[..., 0] # (*, 3)
assert s.shape == (*dims, 1)
assert R.shape == (*dims, 3, 3)
assert t.shape == (*dims, 3)
if output == "transforms":
return s, R, t
elif output == "aligned_x":
aligned_x = (
s[..., None, :] * torch.einsum("...ij,...nj->...ni", R, points_x)
+ t[..., None, :]
)
assert aligned_x.shape == points_x.shape
return aligned_x
else:
assert_never(output)
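# Minimal self-check sketch (illustrative only): with noiseless correspondences,
# procrustes_align should recover the similarity transform that generated them.
def _example_procrustes_align_roundtrip() -> None:
    import math

    torch.manual_seed(0)
    cos_t, sin_t = math.cos(0.3), math.sin(0.3)
    R = torch.tensor([[cos_t, -sin_t, 0.0], [sin_t, cos_t, 0.0], [0.0, 0.0, 1.0]])
    scale = 1.5
    t = torch.tensor([0.1, -0.2, 0.3])
    points_x = torch.randn(100, 3)
    points_y = scale * points_x @ R.T + t  # y_i = scale * R @ x_i + t
    s_hat, R_hat, t_hat = procrustes_align(points_y, points_x, output="transforms")
    assert torch.allclose(s_hat, torch.tensor([scale]), atol=1e-4)
    assert torch.allclose(R_hat, R, atol=1e-4)
    assert torch.allclose(t_hat, t, atol=1e-4)
    # The aligned points should land exactly on points_y.
    aligned = procrustes_align(points_y, points_x, output="aligned_x")
    assert torch.allclose(aligned, points_y, atol=1e-4)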
================================================
FILE: src/egoallo/network.py
================================================
from __future__ import annotations
from dataclasses import dataclass
from functools import cache, cached_property
from typing import Literal, assert_never
import numpy as np
import torch
from einops import rearrange
from jaxtyping import Bool, Float
from loguru import logger
from rotary_embedding_torch import RotaryEmbedding
from torch import Tensor, nn
from .fncsmpl import SmplhModel, SmplhShapedAndPosed
from .tensor_dataclass import TensorDataclass
from .transforms import SE3, SO3
def project_rotmats_via_svd(
rotmats: Float[Tensor, "*batch 3 3"],
) -> Float[Tensor, "*batch 3 3"]:
u, s, vh = torch.linalg.svd(rotmats)
del s
return torch.einsum("...ij,...jk->...ik", u, vh)
class EgoDenoiseTraj(TensorDataclass):
"""Data structure for denoising. Contains tensors that we are denoising, as
well as utilities for packing + unpacking them."""
betas: Float[Tensor, "*#batch timesteps 16"]
"""Body shape parameters. We don't really need the timesteps axis here,
it's just for convenience."""
body_rotmats: Float[Tensor, "*#batch timesteps 21 3 3"]
"""Local orientations for each body joint."""
contacts: Float[Tensor, "*#batch timesteps 21"]
"""Contact boolean for each joint."""
hand_rotmats: Float[Tensor, "*#batch timesteps 30 3 3"] | None
"""Local orientations for each body joint."""
@staticmethod
def get_packed_dim(include_hands: bool) -> int:
packed_dim = 16 + 21 * 9 + 21
if include_hands:
packed_dim += 30 * 9
return packed_dim
def apply_to_body(self, body_model: SmplhModel) -> SmplhShapedAndPosed:
device = self.betas.device
dtype = self.betas.dtype
assert self.hand_rotmats is not None
shaped = body_model.with_shape(self.betas)
posed = shaped.with_pose(
T_world_root=SE3.identity(device=device, dtype=dtype).parameters(),
local_quats=SO3.from_matrix(
torch.cat([self.body_rotmats, self.hand_rotmats], dim=-3)
).wxyz,
)
return posed
def pack(self) -> Float[Tensor, "*#batch timesteps d_state"]:
"""Pack trajectory into a single flattened vector."""
(*batch, time, num_joints, _, _) = self.body_rotmats.shape
assert num_joints == 21
return torch.cat(
[
x.reshape((*batch, time, -1))
for x in vars(self).values()
if x is not None
],
dim=-1,
)
@classmethod
def unpack(
cls,
x: Float[Tensor, "*#batch timesteps d_state"],
include_hands: bool,
project_rotmats: bool = False,
) -> EgoDenoiseTraj:
"""Unpack trajectory from a single flattened vector.
Args:
x: Packed trajectory.
project_rotmats: If True, project the rotation matrices to SO(3) via SVD.
"""
(*batch, time, d_state) = x.shape
assert d_state == cls.get_packed_dim(include_hands)
if include_hands:
betas, body_rotmats_flat, contacts, hand_rotmats_flat = torch.split(
x, [16, 21 * 9, 21, 30 * 9], dim=-1
)
body_rotmats = body_rotmats_flat.reshape((*batch, time, 21, 3, 3))
hand_rotmats = hand_rotmats_flat.reshape((*batch, time, 30, 3, 3))
assert betas.shape == (*batch, time, 16)
else:
betas, body_rotmats_flat, contacts = torch.split(
x, [16, 21 * 9, 21], dim=-1
)
body_rotmats = body_rotmats_flat.reshape((*batch, time, 21, 3, 3))
hand_rotmats = None
assert betas.shape == (*batch, time, 16)
if project_rotmats:
# We might want to handle the -1 determinant case as well.
body_rotmats = project_rotmats_via_svd(body_rotmats)
return EgoDenoiseTraj(
betas=betas,
body_rotmats=body_rotmats,
contacts=contacts,
hand_rotmats=hand_rotmats,
)
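# Minimal round-trip sketch (illustrative only): packing and then unpacking a
# trajectory recovers the same fields, and the packed width matches get_packed_dim.
def _example_pack_unpack_roundtrip() -> None:
    time = 4
    traj = EgoDenoiseTraj(
        betas=torch.randn(time, 16),
        body_rotmats=torch.randn(time, 21, 3, 3),
        contacts=torch.randn(time, 21),
        hand_rotmats=torch.randn(time, 30, 3, 3),
    )
    packed = traj.pack()
    # 16 + 21 * 9 + 21 + 30 * 9 == 496 when hands are included.
    assert packed.shape == (time, EgoDenoiseTraj.get_packed_dim(include_hands=True))
    recovered = EgoDenoiseTraj.unpack(packed, include_hands=True)
    assert torch.allclose(recovered.body_rotmats, traj.body_rotmats)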
@dataclass(frozen=True)
class EgoDenoiserConfig:
max_t: int = 1000
fourier_enc_freqs: int = 3
d_latent: int = 512
d_feedforward: int = 2048
d_noise_emb: int = 1024
num_heads: int = 4
encoder_layers: int = 6
decoder_layers: int = 6
dropout_p: float = 0.0
activation: Literal["gelu", "relu"] = "gelu"
positional_encoding: Literal["transformer", "rope"] = "rope"
noise_conditioning: Literal["token", "film"] = "token"
xattn_mode: Literal["kv_from_cond_q_from_x", "kv_from_x_q_from_cond"] = (
"kv_from_cond_q_from_x"
)
include_canonicalized_cpf_rotation_in_cond: bool = True
include_hands: bool = True
"""Whether to include hand joints (+15 per hand) in the denoised state."""
cond_param: Literal[
"ours", "canonicalized", "absolute", "absrel", "absrel_global_deltas"
] = "ours"
"""Which conditioning parameterization to use.
"ours" is the default, we try to be clever and design something with nice
equivariance properties.
"canonicalized" contains a transformation that's canonicalized to aligned
to the first frame.
"absolute" is the naive case, where we just pass in transformations
directly.
"""
include_hand_positions_cond: bool = False
"""Whether to include hand positions in the conditioning information."""
@cached_property
def d_cond(self) -> int:
"""Dimensionality of conditioning vector."""
if self.cond_param == "ours":
d_cond = 0
d_cond += 12 # Relative CPF pose, flattened 3x4 matrix.
d_cond += 1 # Floor height.
if self.include_canonicalized_cpf_rotation_in_cond:
d_cond += 9 # Canonicalized CPF rotation, flattened 3x3 matrix.
elif self.cond_param == "canonicalized":
d_cond = 12
elif self.cond_param == "absolute":
d_cond = 12
elif self.cond_param == "absrel":
# Both absolute and relative!
d_cond = 24
elif self.cond_param == "absrel_global_deltas":
# Both absolute and relative!
d_cond = 24
else:
assert_never(self.cond_param)
# Add two 3D positions to the conditioning dimension if we're including
# hand conditioning.
if self.include_hand_positions_cond:
d_cond = d_cond + 6
d_cond = d_cond + d_cond * self.fourier_enc_freqs * 2 # Fourier encoding.
return d_cond
def make_cond(
self,
T_cpf_tm1_cpf_t: Float[Tensor, "batch time 7"],
T_world_cpf: Float[Tensor, "batch time 7"],
hand_positions_wrt_cpf: Float[Tensor, "batch time 6"] | None,
) -> Float[Tensor, "batch time d_cond"]:
"""Construct conditioning information from CPF pose."""
(batch, time, _) = T_cpf_tm1_cpf_t.shape
# Construct device pose conditioning.
if self.cond_param == "ours":
# Compute conditioning terms. +Z is up in the world frame. We want
# the translation to be invariant to translations in the world X/Y
# directions.
height_from_floor = T_world_cpf[..., 6:7]
cond_parts = [
SE3(T_cpf_tm1_cpf_t).as_matrix()[..., :3, :].reshape((batch, time, 12)),
height_from_floor,
]
if self.include_canonicalized_cpf_rotation_in_cond:
# We want the rotation to be invariant to rotations around the
# world Z axis. Visualization of what's happening here:
#
# https://gist.github.com/brentyi/9226d082d2707132af39dea92b8609f6
#
# (The coordinate frame may differ by some axis-swapping
# compared to the exact equations in the paper. But to the
# network these will all look the same.)
R_world_cpf = SE3(T_world_cpf).rotation().wxyz
forward_cpf = R_world_cpf.new_tensor([0.0, 0.0, 1.0])
forward_world = SO3(R_world_cpf) @ forward_cpf
assert forward_world.shape == (batch, time, 3)
R_canonical_world = SO3.from_z_radians(
-torch.arctan2(forward_world[..., 1], forward_world[..., 0])
).wxyz
assert R_canonical_world.shape == (batch, time, 4)
cond_parts.append(
(SO3(R_canonical_world) @ SO3(R_world_cpf))
.as_matrix()
.reshape((batch, time, 9)),
)
cond = torch.cat(cond_parts, dim=-1)
elif self.cond_param == "canonicalized":
# Align the first timestep.
# Put poses so start is at origin, facing forward.
R_world_cpf = SE3(T_world_cpf[:, 0:1, :]).rotation().wxyz
forward_cpf = R_world_cpf.new_tensor([0.0, 0.0, 1.0])
forward_world = SO3(R_world_cpf) @ forward_cpf
assert forward_world.shape == (batch, 1, 3)
R_canonical_world = SO3.from_z_radians(
-torch.arctan2(forward_world[..., 1], forward_world[..., 0])
).wxyz
assert R_canonical_world.shape == (batch, 1, 4)
R_canonical_cpf = SO3(R_canonical_world) @ SE3(T_world_cpf).rotation()
t_canonical_cpf = SO3(R_canonical_world) @ SE3(T_world_cpf).translation()
t_canonical_cpf = t_canonical_cpf - t_canonical_cpf[:, 0:1, :]
cond = (
SE3.from_rotation_and_translation(R_canonical_cpf, t_canonical_cpf)
.as_matrix()[..., :3, :4]
.reshape((batch, time, 12))
)
elif self.cond_param == "absolute":
cond = SE3(T_world_cpf).as_matrix()[..., :3, :4].reshape((batch, time, 12))
elif self.cond_param == "absrel":
cond = torch.concatenate(
[
SE3(T_world_cpf)
.as_matrix()[..., :3, :4]
.reshape((batch, time, 12)),
SE3(T_cpf_tm1_cpf_t)
.as_matrix()[..., :3, :4]
.reshape((batch, time, 12)),
],
dim=-1,
)
elif self.cond_param == "absrel_global_deltas":
cond = torch.concatenate(
[
SE3(T_world_cpf)
.as_matrix()[..., :3, :4]
.reshape((batch, time, 12)),
SE3(T_cpf_tm1_cpf_t)
.rotation()
.as_matrix()
.reshape((batch, time, 9)),
(
SE3(T_world_cpf).rotation()
@ SE3(T_cpf_tm1_cpf_t).inverse().translation()
).reshape((batch, time, 3)),
],
dim=-1,
)
else:
assert_never(self.cond_param)
# Condition on hand poses as well.
# We didn't use this for the paper.
if self.include_hand_positions_cond:
if hand_positions_wrt_cpf is None:
logger.warning(
"Model is looking for hand conditioning but none was provided. Passing in zeros."
)
hand_positions_wrt_cpf = torch.zeros(
(batch, time, 6), device=T_world_cpf.device
)
assert hand_positions_wrt_cpf.shape == (batch, time, 6)
cond = torch.cat([cond, hand_positions_wrt_cpf], dim=-1)
cond = fourier_encode(cond, freqs=self.fourier_enc_freqs)
assert cond.shape == (batch, time, self.d_cond)
return cond
class EgoDenoiser(nn.Module):
"""Denoising network for human motion.
Inputs are a noisy trajectory, conditioning information, and the diffusion
noise level; the output is the denoised trajectory.
"""
def __init__(self, config: EgoDenoiserConfig):
super().__init__()
self.config = config
Activation = {"gelu": nn.GELU, "relu": nn.ReLU}[config.activation]
# MLP encoders and decoders for each modality we want to denoise.
modality_dims: dict[str, int] = {
"betas": 16,
"body_rotmats": 21 * 9,
"contacts": 21,
}
if config.include_hands:
modality_dims["hand_rotmats"] = 30 * 9
assert sum(modality_dims.values()) == self.get_d_state()
self.encoders = nn.ModuleDict(
{
k: nn.Sequential(
nn.Linear(modality_dim, config.d_latent),
Activation(),
nn.Linear(config.d_latent, config.d_latent),
Activation(),
nn.Linear(config.d_latent, config.d_latent),
)
for k, modality_dim in modality_dims.items()
}
)
self.decoders = nn.ModuleDict(
{
k: nn.Sequential(
nn.Linear(config.d_latent, config.d_latent),
nn.LayerNorm(normalized_shape=config.d_latent),
Activation(),
nn.Linear(config.d_latent, config.d_latent),
Activation(),
nn.Linear(config.d_latent, modality_dim),
)
for k, modality_dim in modality_dims.items()
}
)
# Helpers for converting between input dimensionality and latent dimensionality.
self.latent_from_cond = nn.Linear(config.d_cond, config.d_latent)
# Noise embedder.
self.noise_emb = nn.Embedding(
# index 0 will be t=1
# index 999 will be t=1000
num_embeddings=config.max_t,
embedding_dim=config.d_noise_emb,
)
self.noise_emb_token_proj = (
nn.Linear(config.d_noise_emb, config.d_latent, bias=False)
if config.noise_conditioning == "token"
else None
)
# Encoder / decoder layers.
# Inputs are conditioning (current noise level, observations); output
# is encoded conditioning information.
self.encoder_layers = nn.ModuleList(
[
TransformerBlock(
TransformerBlockConfig(
d_latent=config.d_latent,
d_noise_emb=config.d_noise_emb,
d_feedforward=config.d_feedforward,
n_heads=config.num_heads,
dropout_p=config.dropout_p,
activation=config.activation,
include_xattn=False, # No conditioning for encoder.
use_rope_embedding=config.positional_encoding == "rope",
use_film_noise_conditioning=config.noise_conditioning == "film",
xattn_mode=config.xattn_mode,
)
)
for _ in range(config.encoder_layers)
]
)
self.decoder_layers = nn.ModuleList(
[
TransformerBlock(
TransformerBlockConfig(
d_latent=config.d_latent,
d_noise_emb=config.d_noise_emb,
d_feedforward=config.d_feedforward,
n_heads=config.num_heads,
dropout_p=config.dropout_p,
activation=config.activation,
include_xattn=True, # Include conditioning for the decoder.
use_rope_embedding=config.positional_encoding == "rope",
use_film_noise_conditioning=config.noise_conditioning == "film",
xattn_mode=config.xattn_mode,
)
)
for _ in range(config.decoder_layers)
]
)
def get_d_state(self) -> int:
return EgoDenoiseTraj.get_packed_dim(self.config.include_hands)
def forward(
self,
x_t_packed: Float[Tensor, "batch time state_dim"],
t: Float[Tensor, "batch"],
*,
T_world_cpf: Float[Tensor, "batch time 7"],
T_cpf_tm1_cpf_t: Float[Tensor, "batch time 7"],
project_output_rotmats: bool,
# Observed hand positions, relative to the CPF.
hand_positions_wrt_cpf: Float[Tensor, "batch time 6"] | None,
# Attention mask for using shorter sequences.
mask: Bool[Tensor, "batch time"] | None,
# Mask for when to drop out / keep conditioning information.
cond_dropout_keep_mask: Bool[Tensor, "batch"] | None = None,
) -> Float[Tensor, "batch time state_dim"]:
"""Predict a denoised trajectory. Note that `t` refers to a noise
level, not a timestep."""
config = self.config
x_t = EgoDenoiseTraj.unpack(x_t_packed, include_hands=self.config.include_hands)
(batch, time, num_body_joints, _, _) = x_t.body_rotmats.shape
assert num_body_joints == 21
# Encode the trajectory into a single vector per timestep.
x_t_encoded = (
self.encoders["betas"](x_t.betas.reshape((batch, time, -1)))
+ self.encoders["body_rotmats"](x_t.body_rotmats.reshape((batch, time, -1)))
+ self.encoders["contacts"](x_t.contacts)
)
if self.config.include_hands:
assert x_t.hand_rotmats is not None
x_t_encoded = x_t_encoded + self.encoders["hand_rotmats"](
x_t.hand_rotmats.reshape((batch, time, -1))
)
assert x_t_encoded.shape == (batch, time, config.d_latent)
# Embed the diffusion noise level.
assert t.shape == (batch,)
noise_emb = self.noise_emb(t - 1)
assert noise_emb.shape == (batch, config.d_noise_emb)
# Prepare conditioning information.
cond = config.make_cond(
T_cpf_tm1_cpf_t,
T_world_cpf=T_world_cpf,
hand_positions_wrt_cpf=hand_positions_wrt_cpf,
)
# Randomly drop out conditioning information; this serves as a
# regularizer that aims to improve sample diversity.
if cond_dropout_keep_mask is not None:
assert cond_dropout_keep_mask.shape == (batch,)
cond = cond * cond_dropout_keep_mask[:, None, None]
# Prepare encoder and decoder inputs.
if config.positional_encoding == "rope":
pos_enc = 0
elif config.positional_encoding == "transformer":
pos_enc = make_positional_encoding(
d_latent=config.d_latent,
length=time,
dtype=cond.dtype,
)[None, ...].to(x_t_encoded.device)
assert pos_enc.shape == (1, time, config.d_latent)
else:
assert_never(config.positional_encoding)
encoder_out = self.latent_from_cond(cond) + pos_enc
decoder_out = x_t_encoded + pos_enc
# Append the noise embedding to the encoder and decoder inputs.
# This is weird if we're using rotary embeddings!
if self.noise_emb_token_proj is not None:
noise_emb_token = self.noise_emb_token_proj(noise_emb)
assert noise_emb_token.shape == (batch, config.d_latent)
encoder_out = torch.cat([noise_emb_token[:, None, :], encoder_out], dim=1)
decoder_out = torch.cat([noise_emb_token[:, None, :], decoder_out], dim=1)
assert (
encoder_out.shape
== decoder_out.shape
== (batch, time + 1, config.d_latent)
)
num_tokens = time + 1
else:
num_tokens = time
# Compute attention mask. This needs to be a fl
if mask is None:
attn_mask = None
else:
assert mask.shape == (batch, time)
assert mask.dtype == torch.bool
if self.noise_emb_token_proj is not None: # Account for noise token.
mask = torch.cat([mask.new_ones((batch, 1)), mask], dim=1)
# Last two dimensions of mask are (query, key). We're masking out only keys;
# masking out entire query rows would make the softmax produce NaNs.
attn_mask = mask[:, None, None, :].repeat(1, 1, num_tokens, 1)
assert attn_mask.shape == (batch, 1, num_tokens, num_tokens)
assert attn_mask.dtype == torch.bool
# Forward pass through transformer.
for layer in self.encoder_layers:
encoder_out = layer(encoder_out, attn_mask, noise_emb=noise_emb)
for layer in self.decoder_layers:
decoder_out = layer(
decoder_out, attn_mask, noise_emb=noise_emb, cond=encoder_out
)
# Remove the extra token corresponding to the noise embedding.
if self.noise_emb_token_proj is not None:
decoder_out = decoder_out[:, 1:, :]
assert isinstance(decoder_out, Tensor)
assert decoder_out.shape == (batch, time, config.d_latent)
packed_output = torch.cat(
[
# Project rotation matrices for body_rotmats via SVD,
(
project_rotmats_via_svd(
modality_decoder(decoder_out).reshape((-1, 3, 3))
).reshape(
(batch, time, {"body_rotmats": 21, "hand_rotmats": 30}[key] * 9)
)
# if enabled,
if project_output_rotmats
and key in ("body_rotmats", "hand_rotmats")
# otherwise, just decode normally.
else modality_decoder(decoder_out)
)
for key, modality_decoder in self.decoders.items()
],
dim=-1,
)
assert packed_output.shape == (batch, time, self.get_d_state())
# Return packed output.
return packed_output
@cache
def make_positional_encoding(
d_latent: int, length: int, dtype: torch.dtype
) -> Float[Tensor, "length d_latent"]:
"""Computes standard Transformer positional encoding."""
pe = torch.zeros(length, d_latent, dtype=dtype)
position = torch.arange(0, length, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, d_latent, 2).float() * (-np.log(10000.0) / d_latent)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
assert pe.shape == (length, d_latent)
return pe
def fourier_encode(
x: Float[Tensor, "*#batch channels"], freqs: int
) -> Float[Tensor, "*#batch channels+2*freqs*channels"]:
"""Apply Fourier encoding to a tensor."""
*batch_axes, x_dim = x.shape
coeffs = 2.0 ** torch.arange(freqs, device=x.device)
scaled = (x[..., None] * coeffs).reshape((*batch_axes, x_dim * freqs))
return torch.cat(
[
x,
torch.sin(torch.cat([scaled, scaled + torch.pi / 2.0], dim=-1)),
],
dim=-1,
)
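# Minimal shape-check sketch (illustrative only): with the default fourier_enc_freqs=3,
# the encoding maps d channels to d * (1 + 2 * 3) channels. For the default "ours"
# conditioning (12 + 1 + 9 = 22 channels), this gives the 154-dimensional d_cond.
def _example_fourier_encode_shape() -> None:
    x = torch.randn(2, 8, 22)  # (batch, time, channels)
    encoded = fourier_encode(x, freqs=3)
    assert encoded.shape == (2, 8, 22 + 2 * 3 * 22)  # == (2, 8, 154)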
@dataclass(frozen=True)
class TransformerBlockConfig:
d_latent: int
d_noise_emb: int
d_feedforward: int
n_heads: int
dropout_p: float
activation: Literal["gelu", "relu"]
include_xattn: bool
use_rope_embedding: bool
use_film_noise_conditioning: bool
xattn_mode: Literal["kv_from_cond_q_from_x", "kv_from_x_q_from_cond"]
class TransformerBlock(nn.Module):
"""An even-tempered Transformer block."""
def __init__(self, config: TransformerBlockConfig) -> None:
super().__init__()
self.sattn_qkv_proj = nn.Linear(
config.d_latent, config.d_latent * 3, bias=False
)
self.sattn_out_proj = nn.Linear(config.d_latent, config.d_latent, bias=False)
self.layernorm1 = nn.LayerNorm(config.d_latent)
self.layernorm2 = nn.LayerNorm(config.d_latent)
assert config.d_latent % config.n_heads == 0
self.rotary_emb = (
RotaryEmbedding(config.d_latent // config.n_heads)
if config.use_rope_embedding
else None
)
if config.include_xattn:
self.xattn_kv_proj = nn.Linear(
config.d_latent, config.d_latent * 2, bias=False
)
self.xattn_q_proj = nn.Linear(config.d_latent, config.d_latent, bias=False)
self.xattn_layernorm = nn.LayerNorm(config.d_latent)
self.xattn_out_proj = nn.Linear(
config.d_latent, config.d_latent, bias=False
)
self.norm_no_learnable = nn.LayerNorm(
config.d_feedforward, elementwise_affine=False, bias=False
)
self.activatio
SYMBOL INDEX (426 symbols across 40 files)
FILE: 0a_preprocess_training_data.py
function load_neutral_beta_conversion (line 64) | def load_neutral_beta_conversion(gender: str) -> Tuple[np.ndarray, np.nd...
function convert_gender_neutral_beta (line 70) | def convert_gender_neutral_beta(
function determine_floor_height_and_contacts (line 85) | def determine_floor_height_and_contacts(
function detect_joint_contact (line 315) | def detect_joint_contact(
function compute_root_align_mats (line 336) | def compute_root_align_mats(root_orient):
function compute_joint_align_mats (line 356) | def compute_joint_align_mats(joint_seq):
function compute_align_from_body_right (line 375) | def compute_align_from_body_right(body_right):
function estimate_velocity (line 398) | def estimate_velocity(data_seq, h):
function estimate_angular_velocity (line 413) | def estimate_angular_velocity(rot_seq, h):
function load_seq_smpl_params (line 438) | def load_seq_smpl_params(input_path: str, num_betas: int = 16):
function run_batch_smpl (line 475) | def run_batch_smpl(
function process_seq (line 512) | def process_seq(
function process_seq_data (line 543) | def process_seq_data(
class Config (line 729) | class Config:
function check_skip (line 740) | def check_skip(path_name: str) -> bool:
function main (line 751) | def main(cfg: Config):
FILE: 0b_preprocess_training_data.py
function main (line 21) | def main(
FILE: 1_train_motion_prior.py
class EgoAlloTrainConfig (line 23) | class EgoAlloTrainConfig:
function get_experiment_dir (line 53) | def get_experiment_dir(experiment_name: str, version: int = 0) -> Path:
function run_training (line 68) | def run_training(
FILE: 2_run_hamer_on_vrs.py
function main (line 26) | def main(traj_root: Path, overwrite: bool = False) -> None:
function run_hamer_and_save (line 45) | def run_hamer_and_save(
function put_text (line 203) | def put_text(
FILE: 3_aria_inference.py
class Args (line 30) | class Args:
function main (line 66) | def main(args: Args) -> None:
FILE: 4_visualize_outputs.py
function main (line 32) | def main(
function load_and_visualize (line 100) | def load_and_visualize(
FILE: 5_eval_body_metrics.py
function main (line 47) | def main(
FILE: src/egoallo/fncsmpl.py
class SmplhModel (line 37) | class SmplhModel(TensorDataclass):
method load (line 58) | def load(model_path: Path) -> SmplhModel:
method get_num_joints (line 90) | def get_num_joints(self) -> int:
method with_shape (line 94) | def with_shape(self, betas: Float[Tensor, "*#batch n_betas"]) -> Smplh...
class SmplhShaped (line 119) | class SmplhShaped(TensorDataclass):
method with_pose_decomposed (line 135) | def with_pose_decomposed(
method with_pose (line 158) | def with_pose(
class SmplhShapedAndPosed (line 183) | class SmplhShapedAndPosed(TensorDataclass):
method with_new_T_world_root (line 196) | def with_new_T_world_root(
method lbs (line 210) | def lbs(self) -> SmplMesh:
class SmplMesh (line 265) | class SmplMesh(TensorDataclass):
function forward_kinematics (line 278) | def forward_kinematics(
function broadcasting_cat (line 326) | def broadcasting_cat(tensors: list[Tensor], dim: int) -> Tensor:
function _normalize_dtype (line 346) | def _normalize_dtype(v: np.ndarray) -> np.ndarray:
FILE: src/egoallo/fncsmpl_extensions.py
function get_T_world_cpf (line 14) | def get_T_world_cpf(mesh: fncsmpl.SmplMesh) -> Float[Tensor, "*#batch 7"]:
function get_T_head_cpf (line 29) | def get_T_head_cpf(shaped: fncsmpl.SmplhShaped) -> Float[Tensor, "*#batc...
function get_T_world_root_from_cpf_pose (line 53) | def get_T_world_root_from_cpf_pose(
FILE: src/egoallo/fncsmpl_jax.py
class SmplhModel (line 22) | class SmplhModel:
method load (line 43) | def load(npz_path: Path) -> SmplhModel:
method with_shape (line 58) | def with_shape(
class SmplhShaped (line 86) | class SmplhShaped:
method with_pose_decomposed (line 100) | def with_pose_decomposed(
method with_pose (line 120) | def with_pose(
method get_T_head_cpf (line 170) | def get_T_head_cpf(self) -> Float[Array, "7"]:
class SmplhShapedAndPosed (line 187) | class SmplhShapedAndPosed:
method with_new_T_world_root (line 200) | def with_new_T_world_root(
method lbs (line 214) | def lbs(self) -> SmplhMesh:
class SmplhMesh (line 262) | class SmplhMesh:
function broadcasting_cat (line 272) | def broadcasting_cat(arrays: Sequence[jax.Array | onp.ndarray], axis: in...
function _normalize_dtype (line 291) | def _normalize_dtype(v: onp.ndarray) -> onp.ndarray:
FILE: src/egoallo/guidance_optimizer_jax.py
function do_guidance_optimization (line 34) | def do_guidance_optimization(
class _SmplhBodyPosesVar (line 90) | class _SmplhBodyPosesVar(
class _SmplhSingleHandPosesVar (line 103) | class _SmplhSingleHandPosesVar(
function _optimize_vmapped (line 117) | def _optimize_vmapped(
class JaxGuidanceParams (line 163) | class JaxGuidanceParams:
method defaults (line 196) | def defaults(
function _optimize (line 303) | def _optimize(
function _get_mano_from_openpose_indices (line 887) | def _get_mano_from_openpose_indices(include_tips: bool) -> Int[onp.ndarr...
FILE: src/egoallo/hand_detection_structs.py
class SingleHandHamerOutputWrtCamera (line 24) | class SingleHandHamerOutputWrtCamera(TypedDict):
class SavedHamerOutputs (line 34) | class SavedHamerOutputs(TypedDict):
class AriaHandWristPoseWrtWorld (line 52) | class AriaHandWristPoseWrtWorld(TensorDataclass):
class CorrespondedAriaHandWristPoseDetections (line 63) | class CorrespondedAriaHandWristPoseDetections(TensorDataclass):
method load (line 68) | def load(
class SingleHandHamerOutputWrtCameraConcatenated (line 175) | class SingleHandHamerOutputWrtCameraConcatenated(TensorDataclass):
class CorrespondedHamerDetections (line 183) | class CorrespondedHamerDetections(TensorDataclass):
method get_length (line 195) | def get_length(self) -> int:
method slice (line 199) | def slice(self, start_index: int, end_index: int) -> CorrespondedHamer...
method load (line 233) | def load(
FILE: src/egoallo/inference_utils.py
function load_denoiser (line 22) | def load_denoiser(checkpoint_dir: Path) -> EgoDenoiser:
class InferenceTrajectoryPaths (line 41) | class InferenceTrajectoryPaths:
method find (line 65) | def find(traj_root: Path) -> InferenceTrajectoryPaths:
class InferenceInputTransforms (line 108) | class InferenceInputTransforms(TensorDataclass):
method load (line 116) | def load(
FILE: src/egoallo/metrics_helpers.py
function compute_foot_skate (line 12) | def compute_foot_skate(
function compute_foot_contact (line 59) | def compute_foot_contact(
function compute_head_ori (line 84) | def compute_head_ori(
function compute_head_trans (line 101) | def compute_head_trans(
function compute_mpjpe (line 115) | def compute_mpjpe(
function procrustes_align (line 160) | def procrustes_align(
function procrustes_align (line 169) | def procrustes_align(
function procrustes_align (line 177) | def procrustes_align(
FILE: src/egoallo/network.py
function project_rotmats_via_svd (line 20) | def project_rotmats_via_svd(
class EgoDenoiseTraj (line 28) | class EgoDenoiseTraj(TensorDataclass):
method get_packed_dim (line 46) | def get_packed_dim(include_hands: bool) -> int:
method apply_to_body (line 52) | def apply_to_body(self, body_model: SmplhModel) -> SmplhShapedAndPosed:
method pack (line 65) | def pack(self) -> Float[Tensor, "*#batch timesteps d_state"]:
method unpack (line 79) | def unpack(
class EgoDenoiserConfig (line 122) | class EgoDenoiserConfig:
method d_cond (line 162) | def d_cond(self) -> int:
method make_cond (line 192) | def make_cond(
class EgoDenoiser (line 309) | class EgoDenoiser(nn.Module):
method __init__ (line 316) | def __init__(self, config: EgoDenoiserConfig):
method get_d_state (line 416) | def get_d_state(self) -> int:
method forward (line 419) | def forward(
function make_positional_encoding (line 559) | def make_positional_encoding(
function fourier_encode (line 574) | def fourier_encode(
class TransformerBlockConfig (line 591) | class TransformerBlockConfig:
class TransformerBlock (line 604) | class TransformerBlock(nn.Module):
method __init__ (line 607) | def __init__(self, config: TransformerBlockConfig) -> None:
method forward (line 651) | def forward(
method _sattn (line 692) | def _sattn(self, x: Tensor, attn_mask: Tensor | None) -> Tensor:
method _xattn (line 717) | def _xattn(self, x: Tensor, attn_mask: Tensor | None, cond: Tensor) ->...
function zero_module (line 753) | def zero_module(module):
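
The pack()/unpack() pair on EgoDenoiseTraj above suggests the denoiser operates on a flattened (*batch, timesteps, d_state) tensor whose last dimension matches get_packed_dim(include_hands). The check below is an illustrative sketch only: EgoDenoiseTraj's field names and unpack()'s exact keyword arguments are not shown in the index, so they are treated as assumptions.

def check_pack_roundtrip(traj, include_hands: bool) -> None:
    # `traj` is assumed to be an EgoDenoiseTraj instance built elsewhere.
    packed = traj.pack()
    assert packed.shape[-1] == type(traj).get_packed_dim(include_hands)
    # unpack() is expected to invert pack(); its keyword arguments here are an assumption.
    restored = type(traj).unpack(packed, include_hands=include_hands)
    assert restored.pack().shape == packed.shape
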
FILE: src/egoallo/preprocessing/body_model/body_model.py
class BodyModel (line 25) | class BodyModel(nn.Module):
method __init__ (line 30) | def __init__(
method _fill_default_vars (line 126) | def _fill_default_vars(self, model_args) -> Tuple[Dict, Dict]:
method _get_batch_size (line 140) | def _get_batch_size(self, **model_args) -> int:
method _get_default_model_var (line 151) | def _get_default_model_var(self, name: str, batch_size: int):
method get_full_pose_mats (line 158) | def get_full_pose_mats(self, model_args, add_mean: bool = True):
method get_hand_pose_mat (line 183) | def get_hand_pose_mat(
method forward_joints (line 216) | def forward_joints(self, **kwargs):
method get_extra_joints (line 242) | def get_extra_joints(self, betas, pose_mats, rel_transforms):
method inverse_joints (line 261) | def inverse_joints(self, joints: torch.Tensor, **kwargs):
method joints_to_beta (line 270) | def joints_to_beta(self, joint_unposed: torch.Tensor) -> torch.Tensor:
method forward (line 290) | def forward(self, **kwargs):
FILE: src/egoallo/preprocessing/body_model/skeleton.py
function smpl_kinematic_tree (line 18) | def smpl_kinematic_tree():
function joint_angles_rel_to_glob (line 39) | def joint_angles_rel_to_glob(rel_mats):
function joint_angles_glob_to_rel (line 64) | def joint_angles_glob_to_rel(glob_mats):
FILE: src/egoallo/preprocessing/body_model/specs.py
function smpl_to_openpose (line 219) | def smpl_to_openpose(
FILE: src/egoallo/preprocessing/body_model/utils.py
function run_smpl (line 37) | def run_smpl(
function reflect_pose_aa (line 68) | def reflect_pose_aa(root_orient: Tensor, pose_body: Tensor):
function reflect_root_trajectory (line 81) | def reflect_root_trajectory(
function forward_kinematics (line 109) | def forward_kinematics(
function get_pose_offsets (line 145) | def get_pose_offsets(
function select_vert_params (line 155) | def select_vert_params(
function get_verts_with_transforms (line 168) | def get_verts_with_transforms(
function inverse_kinematics (line 185) | def inverse_kinematics(rot_mats: Tensor, joints: Tensor, parents: Tensor...
function smpl_local_to_global (line 216) | def smpl_local_to_global(
function select_smpl_joints (line 228) | def select_smpl_joints(joints_full):
function get_openpose_from_smpl (line 236) | def get_openpose_from_smpl(joints_smpl, model_type="smplh"):
function convert_local_pose_to_aa (line 259) | def convert_local_pose_to_aa(pose_body: Tensor, rot_rep: str):
function convert_global_pose_to_aa (line 276) | def convert_global_pose_to_aa(pose_glob: Tensor, rot_rep: str):
function load_beta_conversion (line 291) | def load_beta_conversion(path: str) -> Tuple[Tensor, Tensor]:
function convert_model_betas (line 298) | def convert_model_betas(beta: Tensor, A: Tensor, b: Tensor) -> Tensor:
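
A hedged usage sketch of the beta-conversion helpers listed above. The index gives load_beta_conversion(path: str) -> Tuple[Tensor, Tensor] and convert_model_betas(beta, A, b) -> Tensor; the file path and the assumed (batch, num_betas) layout below are illustrative only.

import torch

from egoallo.preprocessing.body_model.utils import (
    convert_model_betas,
    load_beta_conversion,
)

A, b = load_beta_conversion("path/to/beta_conversion.npz")  # hypothetical file
betas = torch.zeros(1, 16)                                  # assumed (batch, num_betas) layout
neutral_betas = convert_model_betas(betas, A, b)            # betas mapped to the gender-neutral model
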
FILE: src/egoallo/preprocessing/geometry/camera.py
function project_from_world (line 10) | def project_from_world(X_w, R_cw, t_cw, intrins):
function proj_2d (line 20) | def proj_2d(xyz, intrins, eps=1e-4):
function proj_h (line 34) | def proj_h(xyzw):
function iproj_depth (line 42) | def iproj_depth(uv, z, intrins):
function iproj (line 55) | def iproj(uv, disp, intrins):
function normalize_coords (line 68) | def normalize_coords(uv, intrins):
function iproj_to_world (line 74) | def iproj_to_world(uv, disp, intrins, extrins, ret_3d=True):
function reproject (line 92) | def reproject(pose_params, intrins, disps, uv, ii, jj):
function proj_2d_jac (line 108) | def proj_2d_jac(X, intrins):
function actp_jac (line 124) | def actp_jac(X1):
function iproj_jac (line 136) | def iproj_jac(X):
function make_homogeneous (line 145) | def make_homogeneous(x):
function make_transform (line 153) | def make_transform(R, t):
function focal2fov (line 168) | def focal2fov(focal, R):
function fov2focal (line 176) | def fov2focal(fov, R):
function lookat_matrix (line 184) | def lookat_matrix(source_pos, target_pos, up):
function normalize (line 201) | def normalize(x):
function view_matrix (line 205) | def view_matrix(z, up, pos):
function average_pose (line 222) | def average_pose(poses):
function make_translation (line 233) | def make_translation(t):
function make_rotation (line 237) | def make_rotation(rx=0, ry=0, rz=0, order="xyz"):
function rotx (line 258) | def rotx(theta):
function roty (line 269) | def roty(theta):
function rotz (line 280) | def rotz(theta):
function identity (line 291) | def identity(shape: Tuple, d=4, **kwargs):
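
lookat_matrix() and view_matrix() above follow the common camera-utility pattern of building a pose from a position, a target point, and an up vector. The snippet below is a generic numpy reference for that construction under the usual camera-to-world convention; egoallo's axis and handedness conventions may differ, so treat this as a sketch rather than the module's exact behavior.

import numpy as np

def lookat_reference(source_pos, target_pos, up):
    # Camera-to-world transform whose rotation columns are the camera's
    # x/y/z axes and whose translation is the camera position.
    z = target_pos - source_pos
    z = z / np.linalg.norm(z)          # forward axis
    x = np.cross(up, z)
    x = x / np.linalg.norm(x)          # right axis
    y = np.cross(z, x)                 # recomputed up axis
    T = np.eye(4)
    T[:3, :3] = np.stack([x, y, z], axis=1)
    T[:3, 3] = source_pos
    return T
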
FILE: src/egoallo/preprocessing/geometry/helpers.py
function make_transform (line 7) | def make_transform(R, t):
function transform_points (line 22) | def transform_points(T, x):
function batch_apply_Rt (line 32) | def batch_apply_Rt(R, t, x):
function transform_global_to_rel (line 41) | def transform_global_to_rel(T_glob):
function transform_rel_to_global (line 53) | def transform_rel_to_global(T_rel):
function RT_global_to_rel (line 68) | def RT_global_to_rel(R_glob, t_glob):
function RT_rel_to_global (line 79) | def RT_rel_to_global(R_rel, t_rel):
function joints_local_to_global (line 90) | def joints_local_to_global(
function joints_global_to_local (line 112) | def joints_global_to_local(root_orient_mat, trans, joints_glob, joints_v...
function align_pcl (line 134) | def align_pcl(Y, X, weight=None, fixed_scale=False):
function get_translation_scale (line 187) | def get_translation_scale(fps=30):
function estimate_velocity (line 195) | def estimate_velocity(data_seq, h=1 / 30):
function estimate_angular_velocity (line 207) | def estimate_angular_velocity(rot_seq, h=1 / 30):
FILE: src/egoallo/preprocessing/geometry/plane.py
function transform_align_body_right (line 10) | def transform_align_body_right(root_orient_mat, trans, **kwargs):
function rotation_align_body_right (line 22) | def rotation_align_body_right(
function compute_world2aligned (line 47) | def compute_world2aligned(T_w0, **kwargs):
function rotation_align_vecs (line 60) | def rotation_align_vecs(src, target):
function compute_point_height (line 74) | def compute_point_height(point, floor_plane):
function compute_world2floor (line 87) | def compute_world2floor(
function compute_plane_transform (line 113) | def compute_plane_transform(
function fit_plane (line 137) | def fit_plane(
function parse_floor_plane (line 165) | def parse_floor_plane(floor_plane: Tensor, force_sign: int = -1) -> Tensor:
function force_plane_direction (line 187) | def force_plane_direction(
function compute_plane_intersection (line 205) | def compute_plane_intersection(point, direction, plane, eps=1e-5):
function project_vector (line 223) | def project_vector(x, d):
function bdot (line 232) | def bdot(A1, A2, keepdim=True, **kwargs):
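
compute_plane_intersection(point, direction, plane, eps) above matches the standard ray-plane intersection problem. The sketch below is a generic reference for that computation; egoallo's floor-plane parameterization (handled by parse_floor_plane) is not shown in this index, so the (unit normal n, offset d) form with the plane defined by n . x + d = 0 is an assumption.

import torch

def ray_plane_intersection(point, direction, n, d, eps=1e-5):
    # Solve for t in n . (point + t * direction) + d = 0 and return the hit point.
    denom = torch.sum(direction * n, dim=-1)
    # Avoid division by (near-)zero; rays parallel to the plane have no well-defined hit.
    denom = torch.where(denom.abs() < eps, torch.full_like(denom, eps), denom)
    t = -(torch.sum(point * n, dim=-1) + d) / denom
    return point + t[..., None] * direction
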
FILE: src/egoallo/preprocessing/geometry/rotation.py
function get_rot_rep_shape (line 6) | def get_rot_rep_shape(rot_rep:str) -> Tuple:
function convert_rotation (line 17) | def convert_rotation(rot, src_rep, tgt_rep):
function rodrigues_vec_to_matrix (line 57) | def rodrigues_vec_to_matrix(rot_vecs, dtype=torch.float32):
function matrix_to_axis_angle (line 83) | def matrix_to_axis_angle(matrix):
function axis_angle_to_matrix (line 93) | def axis_angle_to_matrix(rot_vec):
function axis_angle_to_cont_6d (line 98) | def axis_angle_to_cont_6d(rot_vec):
function matrix_to_cont_6d (line 107) | def matrix_to_cont_6d(matrix):
function cont_6d_to_matrix (line 115) | def cont_6d_to_matrix(cont_6d):
function cont_6d_to_axis_angle (line 130) | def cont_6d_to_axis_angle(cont_6d):
function quaternion_to_axis_angle (line 135) | def quaternion_to_axis_angle(quaternion, eps=1e-5):
function quaternion_to_matrix (line 170) | def quaternion_to_matrix(quaternion):
function axis_angle_to_quaternion (line 218) | def axis_angle_to_quaternion(axis_angle, eps=1e-5):
function matrix_to_quaternion (line 240) | def matrix_to_quaternion(matrix, eps=1e-6):
function quaternion_mul (line 325) | def quaternion_mul(q0, q1):
function quaternion_inverse (line 338) | def quaternion_inverse(q, eps=1e-5):
function quaternion_slerp (line 348) | def quaternion_slerp(t, q0, q1, eps=1e-5):
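
A round-trip sketch through the rotation-representation converters listed above. The function names and single-argument signatures come from the index; the tensor shapes assumed below (axis-angle as (..., 3), matrices as (..., 3, 3)) are assumptions.

import torch

from egoallo.preprocessing.geometry.rotation import (
    axis_angle_to_matrix,
    cont_6d_to_matrix,
    matrix_to_axis_angle,
    matrix_to_cont_6d,
)

aa = torch.randn(8, 3) * 0.1      # small axis-angle rotations
R = axis_angle_to_matrix(aa)      # rotation matrices
six_d = matrix_to_cont_6d(R)      # continuous 6D representation (Zhou et al. style)
R_back = cont_6d_to_matrix(six_d)
aa_back = matrix_to_axis_angle(R_back)
err = (aa - aa_back).abs().max()  # expected to be near zero for small rotations
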
FILE: src/egoallo/preprocessing/geometry/transforms/_base.py
class MatrixLieGroup (line 13) | class MatrixLieGroup(abc.ABC):
method __init__ (line 31) | def __init__(self, parameters: torch.Tensor):
method __mul__ (line 45) | def __mul__(self: GroupType, other: GroupType) -> GroupType:
method __mul__ (line 49) | def __mul__(self, other: hints.Array) -> torch.Tensor:
method __mul__ (line 52) | def __mul__(
method Identity (line 72) | def Identity(
method from_matrix (line 83) | def from_matrix(cls: Type[GroupType], matrix: hints.Array) -> GroupType:
method matrix (line 96) | def matrix(self) -> torch.Tensor:
method parameters (line 100) | def parameters(self) -> torch.Tensor:
method data (line 104) | def data(self) -> torch.Tensor:
method __getitem__ (line 107) | def __getitem__(self, index):
method shape (line 111) | def shape(self):
method act (line 117) | def act(self, target: hints.Array) -> torch.Tensor:
method mul (line 128) | def mul(self: GroupType, other: GroupType) -> GroupType:
method exp (line 137) | def exp(cls: Type[GroupType], tangent: hints.Array) -> GroupType:
method log (line 148) | def log(self) -> torch.Tensor:
method adjoint (line 156) | def adjoint(self, **kwargs) -> torch.Tensor:
method inv (line 172) | def inv(self: GroupType) -> GroupType:
method normalize (line 180) | def normalize(self: GroupType) -> GroupType:
class SOBase (line 188) | class SOBase(MatrixLieGroup):
class SEBase (line 195) | class SEBase(Generic[ContainedSOType], MatrixLieGroup):
method from_rotation_and_translation (line 206) | def from_rotation_and_translation(
method from_rotation (line 223) | def from_rotation(cls: Type[SEGroupType], rotation: ContainedSOType) -...
method from_translation (line 234) | def from_translation(
method rotation (line 240) | def rotation(self) -> ContainedSOType:
method translation (line 244) | def translation(self) -> torch.Tensor:
method act (line 251) | def act(self, target: hints.Array) -> torch.Tensor:
method mul (line 267) | def mul(self: SEGroupType, other: SEGroupType) -> SEGroupType:
method inv (line 275) | def inv(self: SEGroupType) -> SEGroupType:
method normalize (line 284) | def normalize(self: SEGroupType) -> SEGroupType:
FILE: src/egoallo/preprocessing/geometry/transforms/_se2.py
class SE2 (line 20) | class SE2(_base.SEBase[SO2]):
method __repr__ (line 35) | def __repr__(self) -> str:
method from_xy_theta (line 41) | def from_xy_theta(x: hints.Scalar, y: hints.Scalar, theta: hints.Scala...
method from_rotation_and_translation (line 55) | def from_rotation_and_translation(
method from_translation (line 66) | def from_translation(cls, translation: torch.Tensor) -> "SE2":
method rotation (line 77) | def rotation(self) -> SO2:
method translation (line 81) | def translation(self) -> torch.Tensor:
method Identity (line 88) | def Identity(shape: Optional[Tuple] = (), **kwargs) -> "SE2":
method from_matrix (line 98) | def from_matrix(matrix: hints.Array) -> "SE2":
method parameters (line 109) | def parameters(self) -> torch.Tensor:
method matrix (line 113) | def matrix(self) -> torch.Tensor:
method exp (line 125) | def exp(tangent: hints.Array) -> "SE2":
method log (line 171) | def log(self) -> torch.Tensor:
method adjoint (line 217) | def adjoint(self, **kwargs) -> torch.Tensor:
FILE: src/egoallo/preprocessing/geometry/transforms/_se3.py
function _skew (line 14) | def _skew(omega: torch.Tensor) -> torch.Tensor:
class SE3 (line 36) | class SE3(_base.SEBase[SO3]):
method __repr__ (line 51) | def __repr__(self) -> str:
method from_rotation_and_translation (line 60) | def from_rotation_and_translation(
method from_translation (line 69) | def from_translation(cls, translation: torch.Tensor) -> "SE3":
method rotation (line 80) | def rotation(self) -> SO3:
method translation (line 84) | def translation(self) -> torch.Tensor:
method Identity (line 91) | def Identity(shape: Optional[Tuple] = (), **kwargs) -> "SE3":
method from_matrix (line 101) | def from_matrix(matrix: torch.Tensor) -> "SE3":
method matrix (line 112) | def matrix(self) -> torch.Tensor:
method parameters (line 124) | def parameters(self) -> torch.Tensor:
method exp (line 131) | def exp(tangent: torch.Tensor) -> "SE3":
method log (line 177) | def log(self) -> torch.Tensor:
method adjoint (line 213) | def adjoint(self) -> torch.Tensor:
FILE: src/egoallo/preprocessing/geometry/transforms/_so2.py
class SO2 (line 20) | class SO2(_base.SOBase):
method __repr__ (line 34) | def __repr__(self) -> str:
method from_radians (line 39) | def from_radians(theta: hints.Scalar) -> SO2:
method as_radians (line 46) | def as_radians(self) -> torch.Tensor:
method Identity (line 55) | def Identity(shape: Optional[Tuple] = (), **kwargs) -> SO2:
method from_matrix (line 65) | def from_matrix(matrix: torch.Tensor) -> SO2:
method matrix (line 72) | def matrix(self) -> torch.Tensor:
method parameters (line 81) | def parameters(self) -> torch.Tensor:
method act (line 87) | def act(self, target: torch.Tensor) -> torch.Tensor:
method mul (line 92) | def mul(self, other: SO2) -> SO2:
method exp (line 101) | def exp(tangent: torch.Tensor) -> SO2:
method log (line 107) | def log(self) -> torch.Tensor:
method adjoint (line 113) | def adjoint(self, **kwargs) -> torch.Tensor:
method inv (line 117) | def inv(self) -> SO2:
method normalize (line 122) | def normalize(self) -> SO2:
FILE: src/egoallo/preprocessing/geometry/transforms/_so3.py
class SO3 (line 21) | class SO3(_base.SOBase):
method __repr__ (line 36) | def __repr__(self) -> str:
method from_x_radians (line 41) | def from_x_radians(theta: torch.Tensor) -> SO3:
method from_y_radians (line 51) | def from_y_radians(theta: torch.Tensor) -> SO3:
method from_z_radians (line 61) | def from_z_radians(theta: torch.Tensor) -> SO3:
method from_rpy_radians (line 71) | def from_rpy_radians(
method from_quaternion_xyzw (line 89) | def from_quaternion_xyzw(xyzw: torch.Tensor) -> SO3:
method as_quaternion_xyzw (line 100) | def as_quaternion_xyzw(self) -> torch.Tensor:
method as_rpy_radians (line 104) | def as_rpy_radians(self) -> hints.RollPitchYaw:
method compute_roll_radians (line 116) | def compute_roll_radians(self) -> torch.Tensor:
method compute_pitch_radians (line 125) | def compute_pitch_radians(self) -> torch.Tensor:
method compute_yaw_radians (line 134) | def compute_yaw_radians(self) -> torch.Tensor:
method Identity (line 147) | def Identity(shape: Optional[Tuple] = (), **kwargs) -> SO3:
method from_matrix (line 157) | def from_matrix(matrix: torch.Tensor) -> SO3:
method matrix (line 243) | def matrix(self) -> torch.Tensor:
method parameters (line 263) | def parameters(self) -> torch.Tensor:
method act (line 269) | def act(self, target: torch.Tensor) -> torch.Tensor:
method mul (line 277) | def mul(self, other: SO3) -> SO3:
method exp (line 294) | def exp(tangent: torch.Tensor) -> SO3:
method log (line 340) | def log(self) -> torch.Tensor:
method adjoint (line 377) | def adjoint(self) -> torch.Tensor:
method inv (line 381) | def inv(self) -> SO3:
method normalize (line 387) | def normalize(self) -> SO3:
FILE: src/egoallo/preprocessing/geometry/transforms/hints/__init__.py
class RollPitchYaw (line 14) | class RollPitchYaw(NamedTuple):
FILE: src/egoallo/preprocessing/geometry/transforms/utils/_utils.py
function get_epsilon (line 12) | def get_epsilon(dtype: torch.dtype) -> float:
function register_lie_group (line 27) | def register_lie_group(
FILE: src/egoallo/preprocessing/util/tensor.py
function batch_sum (line 10) | def batch_sum(x, nldims=1):
function batch_mean (line 21) | def batch_mean(x, nldims=1):
function pad_dim (line 32) | def pad_dim(x, max_len, dim=0, start=0, **kwargs):
function pad_back (line 55) | def pad_back(x, max_len, dim=0, **kwargs):
function pad_front (line 59) | def pad_front(x, max_len, dim=0, **kwargs):
function read_image (line 64) | def read_image(path, scale=1):
function move_to (line 76) | def move_to(obj: T, device) -> T:
function detach_all (line 86) | def detach_all(obj: T) -> T:
function to_torch (line 96) | def to_torch(obj):
function to_np (line 106) | def to_np(obj):
function load_npz_as_dict (line 116) | def load_npz_as_dict(path, **kwargs):
function get_device (line 121) | def get_device(i=0):
function invert_nested_dict (line 126) | def invert_nested_dict(d):
function batchify_dicts (line 137) | def batchify_dicts(dict_list: List[Dict]) -> Dict:
function batchify_recursive (line 148) | def batchify_recursive(dict_list: List[Dict], levels: int = -1):
function unbatch_dict (line 173) | def unbatch_dict(batched_dict, batch_size):
function get_batch_element (line 185) | def get_batch_element(batch, idx, batch_size):
function narrow_dict (line 201) | def narrow_dict(input_dict, tdim, start, length):
function narrow_list (line 213) | def narrow_list(input_list, tdim, start, length):
function narrow_obj (line 217) | def narrow_obj(v, tdim, start, length):
function scatter_intervals (line 229) | def scatter_intervals(tensor, start, end, T: int = -1):
function get_scatter_mask (line 261) | def get_scatter_mask(start, end, T):
function select_intervals (line 274) | def select_intervals(series, start, end, pad_len: int = -1):
function get_select_mask (line 291) | def get_select_mask(start, end):
function time_segment_idcs (line 302) | def time_segment_idcs(start, end, min_len: int = -1, clip: bool = True):
FILE: src/egoallo/sampling.py
function quadratic_ts (line 21) | def quadratic_ts() -> np.ndarray:
class CosineNoiseScheduleConstants (line 30) | class CosineNoiseScheduleConstants(TensorDataclass):
method compute (line 38) | def compute(timesteps: int, s: float = 0.008) -> CosineNoiseScheduleCo...
function run_sampling_with_stitching (line 59) | def run_sampling_with_stitching(
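
The s=0.008 default on CosineNoiseScheduleConstants.compute() above matches the cosine noise schedule from Nichol & Dhariwal's improved DDPM. A minimal reference implementation of that schedule is sketched below; the exact constants stored by CosineNoiseScheduleConstants (and any clipping egoallo applies) may differ.

import numpy as np

def cosine_alpha_bar(timesteps: int, s: float = 0.008) -> np.ndarray:
    # alpha_bar(t) = f(t) / f(0) with f(t) = cos(((t/T + s) / (1 + s)) * pi/2)^2.
    t = np.arange(timesteps + 1) / timesteps
    f = np.cos((t + s) / (1 + s) * np.pi / 2) ** 2
    return f / f[0]

alpha_bar = cosine_alpha_bar(1000)
# Per-step betas, conventionally clipped to avoid singularities near t = T.
betas = np.clip(1.0 - alpha_bar[1:] / alpha_bar[:-1], 0.0, 0.999)
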
FILE: src/egoallo/tensor_dataclass.py
class TensorDataclass (line 8) | class TensorDataclass:
method __init_subclass__ (line 13) | def __init_subclass__(cls) -> None:
method to (line 16) | def to(self, device: torch.device | str) -> Self:
method as_nested_dict (line 27) | def as_nested_dict(self, numpy: bool) -> dict[str, Any]:
method map (line 47) | def map(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Self:
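
A hedged usage sketch for TensorDataclass: the __init_subclass__ hook above suggests subclasses are declared like ordinary dataclasses with tensor fields, with to(), map(), and as_nested_dict() applied field-wise. The field names below are made up for illustration; only the method signatures come from the index.

import torch

from egoallo.tensor_dataclass import TensorDataclass

class Poses(TensorDataclass):
    wxyz: torch.Tensor      # (..., 4) quaternions (illustrative field)
    position: torch.Tensor  # (..., 3) translations (illustrative field)

poses = Poses(wxyz=torch.randn(16, 4), position=torch.zeros(16, 3))
poses_cpu = poses.to("cpu")                            # moves every tensor field
poses_f16 = poses.map(lambda t: t.to(torch.float16))   # applies fn to every tensor field
as_dict = poses.as_nested_dict(numpy=True)             # nested dict of numpy arrays
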
FILE: src/egoallo/training_loss.py
class TrainingLossConfig (line 19) | class TrainingLossConfig:
class TrainingLossComputer (line 35) | class TrainingLossComputer:
method __init__ (line 39) | def __init__(self, config: TrainingLossConfig, device: torch.device) -...
method compute_denoising_loss (line 59) | def compute_denoising_loss(
FILE: src/egoallo/training_utils.py
function flattened_hparam_dict_from_dataclass (line 25) | def flattened_hparam_dict_from_dataclass(
function pdb_safety_net (line 54) | def pdb_safety_net():
class SizedIterable (line 76) | class SizedIterable[ContainedType](Iterable[ContainedType], Sized, Proto...
class LoopMetrics (line 85) | class LoopMetrics:
function range_with_metrics (line 92) | def range_with_metrics(stop: int, /) -> SizedIterable[LoopMetrics]: ...
function range_with_metrics (line 96) | def range_with_metrics(start: int, stop: int, /) -> SizedIterable[LoopMe...
function range_with_metrics (line 100) | def range_with_metrics(
function range_with_metrics (line 105) | def range_with_metrics(*args: int) -> SizedIterable[LoopMetrics]:
class _RangeWithMetrics (line 112) | class _RangeWithMetrics:
method __iter__ (line 115) | def __iter__(self):
method __len__ (line 120) | def __len__(self) -> int:
function loop_metric_generator (line 124) | def loop_metric_generator(counter_init: int = 0) -> Generator[LoopMetric...
function get_git_commit_hash (line 161) | def get_git_commit_hash(cwd: Path | None = None) -> str:
function get_git_diff (line 172) | def get_git_diff(cwd: Path | None = None) -> str:
FILE: src/egoallo/transforms/_base.py
class MatrixLieGroup (line 23) | class MatrixLieGroup(abc.ABC):
method __init__ (line 41) | def __init__(
method __matmul__ (line 56) | def __matmul__(self: GroupType, other: GroupType) -> GroupType: ...
method __matmul__ (line 59) | def __matmul__(self, other: Tensor) -> Tensor: ...
method __matmul__ (line 61) | def __matmul__(
method identity (line 81) | def identity(
method from_matrix (line 92) | def from_matrix(cls: Type[GroupType], matrix: Tensor) -> GroupType:
method as_matrix (line 105) | def as_matrix(self) -> Tensor:
method parameters (line 109) | def parameters(self) -> Tensor:
method apply (line 115) | def apply(self, target: Tensor) -> Tensor:
method multiply (line 126) | def multiply(self: Self, other: Self) -> Self:
method exp (line 135) | def exp(cls: Type[GroupType], tangent: Tensor) -> GroupType:
method log (line 146) | def log(self) -> Tensor:
method adjoint (line 154) | def adjoint(self) -> Tensor:
method inverse (line 171) | def inverse(self: GroupType) -> GroupType:
method normalize (line 179) | def normalize(self: GroupType) -> GroupType:
method get_batch_axes (line 199) | def get_batch_axes(self) -> Tuple[int, ...]:
class SOBase (line 206) | class SOBase(MatrixLieGroup):
class SEBase (line 213) | class SEBase(Generic[ContainedSOType], MatrixLieGroup):
method from_rotation_and_translation (line 224) | def from_rotation_and_translation(
method from_rotation (line 241) | def from_rotation(cls: Type[SEGroupType], rotation: ContainedSOType) -...
method rotation (line 251) | def rotation(self) -> ContainedSOType:
method translation (line 255) | def translation(self) -> Tensor:
method apply (line 262) | def apply(self, target: Tensor) -> Tensor:
method multiply (line 267) | def multiply(self: SEGroupType, other: SEGroupType) -> SEGroupType: #...
method inverse (line 275) | def inverse(self: SEGroupType) -> SEGroupType:
method normalize (line 284) | def normalize(self: SEGroupType) -> SEGroupType:
FILE: src/egoallo/transforms/_se3.py
function _skew (line 15) | def _skew(omega: Tensor) -> Tensor:
class SE3 (line 37) | class SE3(_base.SEBase[SO3]):
method __repr__ (line 50) | def __repr__(self) -> str:
method from_rotation_and_translation (line 59) | def from_rotation_and_translation(
method rotation (line 68) | def rotation(self) -> SO3:
method translation (line 72) | def translation(self) -> Tensor:
method identity (line 79) | def identity(cls, device: Union[torch.device, str], dtype: torch.dtype...
method from_matrix (line 88) | def from_matrix(cls, matrix: Tensor) -> SE3:
method as_matrix (line 99) | def as_matrix(self) -> Tensor:
method parameters (line 111) | def parameters(self) -> Tensor:
method exp (line 118) | def exp(cls, tangent: Tensor) -> SE3:
method log (line 167) | def log(self) -> Tensor:
method adjoint (line 212) | def adjoint(self) -> Tensor:
FILE: src/egoallo/transforms/_so3.py
class SO3 (line 21) | class SO3(_base.SOBase):
method __repr__ (line 34) | def __repr__(self) -> str:
method from_x_radians (line 39) | def from_x_radians(theta: Tensor) -> SO3:
method from_y_radians (line 52) | def from_y_radians(theta: Tensor) -> SO3:
method from_z_radians (line 65) | def from_z_radians(theta: Tensor) -> SO3:
method from_rpy_radians (line 78) | def from_rpy_radians(
method from_quaternion_xyzw (line 101) | def from_quaternion_xyzw(xyzw: Tensor) -> SO3:
method as_quaternion_xyzw (line 116) | def as_quaternion_xyzw(self) -> Tensor:
method identity (line 124) | def identity(cls, device: Union[torch.device, str], dtype: torch.dtype...
method from_matrix (line 129) | def from_matrix(cls, matrix: Tensor) -> SO3:
method as_matrix (line 214) | def as_matrix(self) -> Tensor:
method parameters (line 234) | def parameters(self) -> Tensor:
method apply (line 240) | def apply(self, target: Tensor) -> Tensor:
method multiply (line 249) | def multiply(self, other: SO3) -> SO3: # type: ignore
method exp (line 266) | def exp(cls, tangent: Tensor) -> SO3:
method log (line 308) | def log(self) -> Tensor:
method adjoint (line 341) | def adjoint(self) -> Tensor:
method inverse (line 345) | def inverse(self) -> SO3:
method normalize (line 351) | def normalize(self) -> SO3:
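
A hedged composition sketch for the torch SO3/SE3 types above (ported from jaxlie). exp(), log(), from_rotation_and_translation(), apply(), inverse(), and the @ operator all appear in the index; the package-level re-export of SE3 and SO3 from egoallo.transforms and the broadcasting behavior are assumptions, so double-check against the source.

import torch

from egoallo.transforms import SE3, SO3

R = SO3.exp(torch.tensor([0.0, 0.0, 0.3]))  # rotation about z via the exponential map
T = SE3.from_rotation_and_translation(R, torch.tensor([1.0, 0.0, 0.0]))

point = torch.tensor([0.5, 0.0, 0.0])
point_world = T.apply(point)   # rigid transform applied to a point
T_id = T @ T.inverse()         # composition with the inverse; approximately identity
tangent = T_id.log()           # 6-DoF tangent vector, approximately zero
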
FILE: src/egoallo/transforms/utils/_utils.py
function get_epsilon (line 12) | def get_epsilon(dtype: torch.dtype) -> float:
function register_lie_group (line 27) | def register_lie_group(
FILE: src/egoallo/vis_helpers.py
class SplatArgs (line 23) | class SplatArgs(TypedDict):
function load_splat_file (line 34) | def load_splat_file(splat_path: Path, center: bool = False) -> SplatArgs:
function load_ply_file (line 78) | def load_ply_file(ply_file_path: Path, center: bool = False) -> SplatArgs:
function add_splat_to_viser (line 111) | def add_splat_to_viser(
function visualize_traj_and_hand_detections (line 131) | def visualize_traj_and_hand_detections(