Repository: brentyi/egoallo Branch: main Commit: 0c20ae8c8d48 Files: 56 Total size: 387.7 KB Directory structure: gitextract_96b1_mr0/ ├── .github/ │ └── workflows/ │ └── pyright.yml ├── .gitignore ├── 0a_preprocess_training_data.py ├── 0b_preprocess_training_data.py ├── 1_train_motion_prior.py ├── 2_run_hamer_on_vrs.py ├── 3_aria_inference.py ├── 4_visualize_outputs.py ├── 5_eval_body_metrics.py ├── LICENSE ├── README.md ├── download_checkpoint_and_data.sh ├── pyproject.toml └── src/ └── egoallo/ ├── __init__.py ├── fncsmpl.py ├── fncsmpl_extensions.py ├── fncsmpl_jax.py ├── guidance_optimizer_jax.py ├── hand_detection_structs.py ├── inference_utils.py ├── metrics_helpers.py ├── network.py ├── preprocessing/ │ ├── __init__.py │ ├── body_model/ │ │ ├── __init__.py │ │ ├── body_model.py │ │ ├── skeleton.py │ │ ├── specs.py │ │ └── utils.py │ ├── geometry/ │ │ ├── __init__.py │ │ ├── camera.py │ │ ├── helpers.py │ │ ├── plane.py │ │ ├── rotation.py │ │ └── transforms/ │ │ ├── __init__.py │ │ ├── _base.py │ │ ├── _se2.py │ │ ├── _se3.py │ │ ├── _so2.py │ │ ├── _so3.py │ │ ├── hints/ │ │ │ └── __init__.py │ │ └── utils/ │ │ ├── __init__.py │ │ └── _utils.py │ └── util/ │ ├── __init__.py │ └── tensor.py ├── py.typed ├── sampling.py ├── tensor_dataclass.py ├── training_loss.py ├── training_utils.py ├── transforms/ │ ├── __init__.py │ ├── _base.py │ ├── _se3.py │ ├── _so3.py │ └── utils/ │ ├── __init__.py │ └── _utils.py └── vis_helpers.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/workflows/pyright.yml ================================================ name: pyright on: push: branches: [main] pull_request: branches: [main] jobs: pyright: runs-on: ubuntu-latest strategy: matrix: python-version: ["3.12"] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v1 with: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | pip install uv uv pip install --system -e . uv pip install --system jax uv pip install --system git+https://github.com/brentyi/jaxls.git uv pip install --system git+https://github.com/brentyi/hamer_helper.git uv pip install --system pyright - name: Run pyright run: | pyright . ================================================ FILE: .gitignore ================================================ *.swp *.swo *.pyc *.egg-info *.ipynb_checkpoints __pycache__ .coverage htmlcov .mypy_cache .dmypy.json .hypothesis .envrc .lvimrc .DS_Store .envrc lightning_logs/ outputs/ data/ egoallo_checkpoint_* egoallo_example_* ================================================ FILE: 0a_preprocess_training_data.py ================================================ """Convert raw AMASS data to HuMoR-style npz format. Mostly taken from https://github.com/davrempe/humor/blob/main/humor/scripts/process_amass_data.py, but added gender neutral beta conversion and other utilities. """ import dataclasses import os import time from concurrent.futures import ProcessPoolExecutor from pathlib import Path from typing import Dict, Tuple import matplotlib.pyplot as plt import numpy as np import torch import tyro from loguru import logger as guru from sklearn.cluster import DBSCAN from tqdm import tqdm from egoallo.preprocessing.body_model import ( KEYPT_VERTS, SMPL_JOINTS, BodyModel, reflect_pose_aa, reflect_root_trajectory, run_smpl, ) from egoallo.preprocessing.geometry import convert_rotation, joints_global_to_local from egoallo.preprocessing.util import move_to AMASS_SPLITS = { "train": [ "ACCAD", "BMLhandball", "BMLmovi", "BioMotionLab_NTroje", "CMU", "DFaust_67", "DanceDB", "EKUT", "Eyes_Japan_Dataset", "KIT", "MPI_Limits", "TCD_handMocap", "TotalCapture", ], "val": [ "HumanEva", "MPI_HDM05", "SFU", "MPI_mosh", ], "test": [ "Transitions_mocap", "SSM_synced", ], } AMASS_SPLITS["all"] = AMASS_SPLITS["train"] + AMASS_SPLITS["val"] + AMASS_SPLITS["test"] def load_neutral_beta_conversion(gender: str) -> Tuple[np.ndarray, np.ndarray]: assert gender in ["female", "male"] data = np.load(f"./data/smplh_gender_conversion/{gender}_to_neutral.npz") return data["A"], data["b"] def convert_gender_neutral_beta( beta: np.ndarray, A: np.ndarray, b: np.ndarray ) -> np.ndarray: """ :param beta (*, B) :param A (B, B) :param b (B) beta_neutral = A @ beta_gender + b """ *dims, B = beta.shape A = A.reshape((*(1,) * len(dims), B, B)) b = b.reshape((*(1,) * len(dims), B)) return np.einsum("...ij,...j->...i", A, beta) + b def determine_floor_height_and_contacts( body_joint_seq, fps, vis=False, floor_vel_thresh=0.005, floor_height_offset=0.01, contact_vel_thresh=0.005, # 0.015 contact_toe_height_thresh=0.04, # if static toe above this height contact_ankle_height_thresh=0.08, terrain_height_thresh=0.04, root_height_thresh=0.04, cluster_size_thresh=0.25, discard_terrain_seqs=False, # throw away person steps onto objects (determined by a heuristic) ): """ Taken from https://github.com/davrempe/humor/blob/main/humor/scripts/process_amass_data.py Input: body_joint_seq N x 21 x 3 numpy array Contacts are N x 4 where N is number of frames and each row is left heel/toe, right heel/toe """ num_frames = body_joint_seq.shape[0] # compute toe velocities root_seq = body_joint_seq[:, SMPL_JOINTS["hips"], :] left_toe_seq = body_joint_seq[:, SMPL_JOINTS["leftToeBase"], :] right_toe_seq = body_joint_seq[:, SMPL_JOINTS["rightToeBase"], :] left_toe_vel = np.linalg.norm(left_toe_seq[1:] - left_toe_seq[:-1], axis=1) left_toe_vel = np.append(left_toe_vel, left_toe_vel[-1]) right_toe_vel = np.linalg.norm(right_toe_seq[1:] - right_toe_seq[:-1], axis=1) right_toe_vel = np.append(right_toe_vel, right_toe_vel[-1]) if vis: plt.figure() steps = np.arange(num_frames) plt.plot(steps, left_toe_vel, "-r", label="left vel") plt.plot(steps, right_toe_vel, "-b", label="right vel") plt.legend() plt.show() plt.close() # now foot heights (z is up) left_toe_heights = left_toe_seq[:, 2] right_toe_heights = right_toe_seq[:, 2] root_heights = root_seq[:, 2] if vis: plt.figure() steps = np.arange(num_frames) plt.plot(steps, left_toe_heights, "-r", label="left toe height") plt.plot(steps, right_toe_heights, "-b", label="right toe height") plt.plot(steps, root_heights, "-g", label="root height") plt.legend() plt.show() plt.close() # filter out heights when velocity is greater than some threshold (not in contact) all_inds = np.arange(left_toe_heights.shape[0]) left_static_foot_heights = left_toe_heights[left_toe_vel < floor_vel_thresh] left_static_inds = all_inds[left_toe_vel < floor_vel_thresh] right_static_foot_heights = right_toe_heights[right_toe_vel < floor_vel_thresh] right_static_inds = all_inds[right_toe_vel < floor_vel_thresh] all_static_foot_heights = np.append( left_static_foot_heights, right_static_foot_heights ) all_static_inds = np.append(left_static_inds, right_static_inds) if vis: plt.figure() steps = np.arange(left_static_foot_heights.shape[0]) plt.plot(steps, left_static_foot_heights, "-r", label="left static height") plt.legend() plt.show() plt.close() discard_seq = False if all_static_foot_heights.shape[0] > 0: cluster_heights = [] cluster_root_heights = [] cluster_sizes = [] # cluster foot heights and find one with smallest median clustering = DBSCAN(eps=0.005, min_samples=3).fit( all_static_foot_heights.reshape(-1, 1) ) all_labels = np.unique(clustering.labels_) # print(all_labels) if vis: plt.figure() min_median = min_root_median = float("inf") for cur_label in all_labels: cur_clust = all_static_foot_heights[clustering.labels_ == cur_label] cur_clust_inds = np.unique( all_static_inds[clustering.labels_ == cur_label] ) # inds in the original sequence that correspond to this cluster if vis: plt.scatter( cur_clust, np.zeros_like(cur_clust), label="foot %d" % (cur_label) ) # get median foot height and use this as height cur_median = np.median(cur_clust) cluster_heights.append(cur_median) cluster_sizes.append(cur_clust.shape[0]) # get root information cur_root_clust = root_heights[cur_clust_inds] cur_root_median = np.median(cur_root_clust) cluster_root_heights.append(cur_root_median) if vis: plt.scatter( cur_root_clust, np.zeros_like(cur_root_clust), label="root %d" % (cur_label), ) # update min info if cur_median < min_median: min_median = cur_median min_root_median = cur_root_median # print(cluster_heights) # print(cluster_root_heights) # print(cluster_sizes) if vis: plt.show() plt.close() floor_height = min_median offset_floor_height = ( floor_height - floor_height_offset ) # toe joint is actually inside foot mesh a bit if discard_terrain_seqs: # print(min_median + TERRAIN_HEIGHT_THRESH) # print(min_root_median + ROOT_HEIGHT_THRESH) for cluster_root_height, cluster_height, cluster_size in zip( cluster_root_heights, cluster_heights, cluster_sizes ): root_above_thresh = cluster_root_height > ( min_root_median + root_height_thresh ) toe_above_thresh = cluster_height > (min_median + terrain_height_thresh) cluster_size_above_thresh = cluster_size > int( cluster_size_thresh * fps ) if root_above_thresh and toe_above_thresh and cluster_size_above_thresh: discard_seq = True print("DISCARDING sequence based on terrain interaction!") break else: floor_height = offset_floor_height = 0.0 # now find contacts (feet are below certain velocity and within certain range of floor) # compute heel velocities left_heel_seq = body_joint_seq[:, SMPL_JOINTS["leftFoot"], :] right_heel_seq = body_joint_seq[:, SMPL_JOINTS["rightFoot"], :] left_heel_vel = np.linalg.norm(left_heel_seq[1:] - left_heel_seq[:-1], axis=1) left_heel_vel = np.append(left_heel_vel, left_heel_vel[-1]) right_heel_vel = np.linalg.norm(right_heel_seq[1:] - right_heel_seq[:-1], axis=1) right_heel_vel = np.append(right_heel_vel, right_heel_vel[-1]) left_heel_contact = left_heel_vel < contact_vel_thresh right_heel_contact = right_heel_vel < contact_vel_thresh left_toe_contact = left_toe_vel < contact_vel_thresh right_toe_contact = right_toe_vel < contact_vel_thresh # compute heel heights left_heel_heights = left_heel_seq[:, 2] - floor_height right_heel_heights = right_heel_seq[:, 2] - floor_height left_toe_heights = left_toe_heights - floor_height right_toe_heights = right_toe_heights - floor_height left_heel_contact = np.logical_and( left_heel_contact, left_heel_heights < contact_ankle_height_thresh ) right_heel_contact = np.logical_and( right_heel_contact, right_heel_heights < contact_ankle_height_thresh ) left_toe_contact = np.logical_and( left_toe_contact, left_toe_heights < contact_toe_height_thresh ) right_toe_contact = np.logical_and( right_toe_contact, right_toe_heights < contact_toe_height_thresh ) contacts = np.zeros((num_frames, len(SMPL_JOINTS))) contacts[:, SMPL_JOINTS["leftFoot"]] = left_heel_contact contacts[:, SMPL_JOINTS["leftToeBase"]] = left_toe_contact contacts[:, SMPL_JOINTS["rightFoot"]] = right_heel_contact contacts[:, SMPL_JOINTS["rightToeBase"]] = right_toe_contact # hand contacts left_hand_contact = detect_joint_contact( body_joint_seq, "leftHand", floor_height, contact_vel_thresh, contact_ankle_height_thresh, ) right_hand_contact = detect_joint_contact( body_joint_seq, "rightHand", floor_height, contact_vel_thresh, contact_ankle_height_thresh, ) contacts[:, SMPL_JOINTS["leftHand"]] = left_hand_contact contacts[:, SMPL_JOINTS["rightHand"]] = right_hand_contact # knee contacts left_knee_contact = detect_joint_contact( body_joint_seq, "leftLeg", floor_height, contact_vel_thresh, contact_ankle_height_thresh, ) right_knee_contact = detect_joint_contact( body_joint_seq, "rightLeg", floor_height, contact_vel_thresh, contact_ankle_height_thresh, ) contacts[:, SMPL_JOINTS["leftLeg"]] = left_knee_contact contacts[:, SMPL_JOINTS["rightLeg"]] = right_knee_contact return offset_floor_height, contacts, discard_seq def detect_joint_contact( body_joint_seq, joint_name, floor_height, vel_thresh, height_thresh ): """ Taken from https://github.com/davrempe/humor/blob/main/humor/scripts/process_amass_data.py """ # calc velocity joint_seq = body_joint_seq[:, SMPL_JOINTS[joint_name], :] joint_vel = np.linalg.norm(joint_seq[1:] - joint_seq[:-1], axis=1) joint_vel = np.append(joint_vel, joint_vel[-1]) # determine contact by velocity joint_contact = joint_vel < vel_thresh # compute heights joint_heights = joint_seq[:, 2] - floor_height # compute contact by vel + height joint_contact = np.logical_and(joint_contact, joint_heights < height_thresh) return joint_contact def compute_root_align_mats(root_orient): """ Taken from https://github.com/davrempe/humor/blob/main/humor/scripts/process_amass_data.py compute world to canonical frame for each timestep (rotation around up axis) """ root_orient = torch.as_tensor(root_orient).reshape(-1, 3) # convert aa to matrices root_orient_mat = convert_rotation(root_orient, "aa", "mat").numpy() # rotate root so aligning local body right vector (-x) with world right vector (+x) # with a rotation around the up axis (+z) # in body coordinates body x-axis is to the left body_right = -root_orient_mat[:, :, 0] world2aligned_mat, world2aligned_aa = compute_align_from_body_right(body_right) return world2aligned_mat def compute_joint_align_mats(joint_seq): """ Taken from https://github.com/davrempe/humor/blob/main/humor/scripts/process_amass_data.py Compute world to canonical frame for each timestep (rotation around up axis) from the given joint seq (T x J x 3) """ left_idx = SMPL_JOINTS["leftUpLeg"] right_idx = SMPL_JOINTS["rightUpLeg"] body_right = joint_seq[:, right_idx] - joint_seq[:, left_idx] body_right = body_right / np.linalg.norm(body_right, axis=1)[:, np.newaxis] world2aligned_mat, world2aligned_aa = compute_align_from_body_right(body_right) return world2aligned_mat def compute_align_from_body_right(body_right): """ Taken from https://github.com/davrempe/humor/blob/main/humor/scripts/process_amass_data.py """ world2aligned_angle = np.arccos( body_right[:, 0] / (np.linalg.norm(body_right[:, :2], axis=1) + 1e-8) ) # project to world x axis, and compute angle body_right[:, 2] = 0.0 world2aligned_axis = np.cross(body_right, np.array([[1.0, 0.0, 0.0]])) world2aligned_aa = ( world2aligned_axis / (np.linalg.norm(world2aligned_axis, axis=1)[:, np.newaxis] + 1e-8) ) * world2aligned_angle[:, np.newaxis] world2aligned_mat = convert_rotation( torch.as_tensor(world2aligned_aa).reshape(-1, 3), "aa", "mat" ).numpy() return world2aligned_mat, world2aligned_aa def estimate_velocity(data_seq, h): """ Taken from https://github.com/davrempe/humor/blob/main/humor/scripts/process_amass_data.py Given some data sequence of T timesteps in the shape (T, ...), estimates the velocity for the middle T-2 steps using a second order central difference scheme. - h : step size """ data_tp1 = data_seq[2:] data_tm1 = data_seq[0:-2] data_vel_seq = (data_tp1 - data_tm1) / (2 * h) return data_vel_seq def estimate_angular_velocity(rot_seq, h): """ Taken from https://github.com/davrempe/humor/blob/main/humor/scripts/process_amass_data.py Given a sequence of T rotation matrices, estimates angular velocity at T-2 steps. Input sequence should be of shape (T, ..., 3, 3) """ # see https://en.wikipedia.org/wiki/Angular_velocity#Calculation_from_the_orientation_matrix dRdt = estimate_velocity(rot_seq, h) R = rot_seq[1:-1] RT = np.swapaxes(R, -1, -2) # compute skew-symmetric angular velocity tensor w_mat = np.matmul(dRdt, RT) # pull out angular velocity vector # average symmetric entries w_x = (-w_mat[..., 1, 2] + w_mat[..., 2, 1]) / 2.0 w_y = (w_mat[..., 0, 2] - w_mat[..., 2, 0]) / 2.0 w_z = (-w_mat[..., 0, 1] + w_mat[..., 1, 0]) / 2.0 w = np.stack([w_x, w_y, w_z], axis=-1) return w def load_seq_smpl_params(input_path: str, num_betas: int = 16): guru.info(f"Loading from {input_path}") # load in input data # we leave out "dmpls" and "marker_data"/"marker_label" which are not present in all datasets bdata = np.load(input_path) gender = np.array(bdata["gender"], ndmin=1)[0] gender = str(gender, "utf-8") if isinstance(gender, bytes) else str(gender) fps = bdata["mocap_framerate"] trans = bdata["trans"][:] # global translation num_frames = len(trans) root_orient = bdata["poses"][:, :3] # global root orientation (1 joint) pose_body = bdata["poses"][:, 3:66] # body joint rotations (21 joints) pose_hand = bdata["poses"][:, 66:] # finger articulation joint rotations betas = np.tile( bdata["betas"][None, :num_betas], [num_frames, 1] ) # body shape parameters # correct mislabeled data if input_path.find("BMLhandball") >= 0: fps = 240 if input_path.find("20160930_50032") >= 0 or input_path.find("20161014_50033") >= 0: fps = 59 model_vars = { "trans": trans, "root_orient": root_orient, "pose_body": pose_body, "pose_hand": pose_hand, "betas": betas, } meta = {"fps": fps, "gender": gender, "num_frames": num_frames} guru.info(f"meta {meta}") guru.info(f"model var shapes {str({k: v.shape for k, v in model_vars.items()})}") return model_vars, meta def run_batch_smpl( body_model: BodyModel, device: torch.device, num_total: int, batch_size: int, return_verts: bool = True, **kwargs, ): var_dims = body_model.var_dims var_names = [name for name in kwargs if name in var_dims] model_vars = { name: torch.as_tensor(kwargs[name], dtype=torch.float32).reshape( -1, var_dims[name] ) for name in var_names } fopts = {k: v for k, v in kwargs.items() if k not in var_names} batch_joints, batch_verts = [], [] for sidx in range(0, num_total, batch_size): eidx = min(sidx + batch_size, num_total) batch_model_vars = move_to( {name: x[sidx:eidx].contiguous() for name, x in model_vars.items()}, device ) with torch.no_grad(): joints, verts, _ = run_smpl( body_model, return_verts=return_verts, **batch_model_vars, **fopts ) batch_joints.append(joints.detach().cpu()) if return_verts and verts is not None: batch_verts.append(verts.detach().cpu()) joints_all = torch.cat(batch_joints, dim=0) verts_all = torch.cat(batch_verts, dim=0) if len(batch_verts) > 0 else None return joints_all, verts_all def process_seq( input_path: str, out_path: str, smplh_root: str, dev_id: int, beta_neutral: bool, reflect: bool = False, overwrite: bool = False, **kwargs, ): if not overwrite and os.path.isfile(out_path): guru.info(f"{out_path} already exists, skipping.") return guru.info(f"process {input_path} to {out_path}") model_vars, meta = load_seq_smpl_params(input_path) if beta_neutral: # get the gender neutral beta guru.info("converting betas to gender neutral") A_beta, b_beta = load_neutral_beta_conversion(meta["gender"]) model_vars["betas"] = convert_gender_neutral_beta( model_vars["betas"], A_beta, b_beta ) meta["gender"] = "neutral" process_seq_data( model_vars, meta, out_path, dev_id, smplh_root, reflect=reflect, **kwargs ) def process_seq_data( model_vars: Dict, meta: Dict, out_path: str, dev_id: int, smplh_root: str, reflect: bool = False, split_frame_limit: int = 2000, discard_shorter_than: float = 1.0, # seconds out_fps: int = 30, save_verts: bool = False, save_velocities: bool = True, # save all parameter velocities available ): guru.info(f"Processing seq with meta {meta}") start_t = time.time() gender = meta["gender"] src_fps = meta["fps"] num_frames = meta["num_frames"] # only keep middle 80% of sequences to avoid redundanct static poses sidx, eidx = int(0.1 * num_frames), int(0.9 * num_frames) num_frames = eidx - sidx for name, x in model_vars.items(): model_vars[name] = x[sidx:eidx] guru.info(str({k: v.shape for k, v in model_vars.items()})) # discard if shorter than threshold if num_frames < discard_shorter_than * src_fps: guru.info(f"Sequence shorter than {discard_shorter_than} s, discarding...") return # must do SMPL forward pass to get joints # split into manageable chunks to avoid running out of GPU memory for SMPL device = ( torch.device(f"cuda:{dev_id}") if torch.cuda.is_available() else torch.device("cpu") ) # # smplx tries to read shape properties, even when use_pca=False from smplx.utils import Struct Struct.hands_componentsl = np.zeros(100) # type: ignore Struct.hands_componentsr = np.zeros(100) # type: ignore Struct.hands_meanl = np.zeros(100) # type: ignore Struct.hands_meanr = np.zeros(100) # type: ignore # This defaults to 300, but we have 16 beta parameters. When # 16<300 the SMPL class will set num_betas to 10... from smplx import SMPLH assert SMPLH.SHAPE_SPACE_DIM in (300, 16) SMPLH.SHAPE_SPACE_DIM = 16 # body_model = BodyModel(f"{smplh_root}/{gender}/model.npz", use_pca=False).to(device) model_vars = {k: torch.as_tensor(v).float() for k, v in model_vars.items()} if reflect: rot_og = model_vars["root_orient"] rot_re, model_vars["pose_body"] = reflect_pose_aa( rot_og, model_vars["pose_body"] ) out = body_model.forward(betas=model_vars["betas"][:1].to(device)) root_loc = out.Jtr[:, 0].cpu() # type: ignore model_vars["root_orient"], model_vars["trans"] = reflect_root_trajectory( rot_og, model_vars["trans"], rot_re, root_loc ) body_joint_seq, body_vtx_seq = run_batch_smpl( body_model, device, num_frames, split_frame_limit, return_verts=save_verts, **model_vars, ) joints_glob = body_joint_seq[:, : len(SMPL_JOINTS), :] joint_seq = joints_glob.numpy() guru.info(f"Recovered joints and verts {joint_seq.shape}") out_dict = model_vars.copy() out_dict["joints"] = joint_seq out_dict["joints_loc"], _ = joints_global_to_local( convert_rotation(model_vars["root_orient"], "aa", "mat"), model_vars["trans"], joints_glob, ) if save_verts and body_vtx_seq is not None: out_dict["mojo_verts"] = body_vtx_seq[:, KEYPT_VERTS, :].numpy() # determine floor height and foot contacts floor_height, contacts, discard_seq = determine_floor_height_and_contacts( joint_seq, src_fps ) if discard_seq: guru.info("Terrain interaction detected, discarding...") return guru.info(f"Floor height: {floor_height}") # translate so floor is at z=0 for name in ["trans", "joints", "mojo_verts"]: if name not in out_dict: continue out_dict[name][..., 2] -= floor_height # compute rotation to canonical frame (forward facing +y) for every frame world2aligned_rot = compute_root_align_mats(model_vars["root_orient"]) out_dict.update( { "contacts": contacts, "floor_height": floor_height, "world2aligned_rot": world2aligned_rot, } ) # estimate various velocities based on full frame rate # with second order central differences before downsampling if save_velocities: h = 1.0 / src_fps lin_names = ["trans", "joints", "mojo_verts"] ang_names = ["root_orient", "pose_body"] cur_keys = lin_names + ang_names + ["contacts"] for name in lin_names: if name not in out_dict: continue out_dict[f"{name}_vel"] = estimate_velocity(out_dict[name], h) # root orient for name in ang_names: if name not in out_dict: continue rot_aa = ( torch.as_tensor(out_dict[name]).reshape(num_frames, -1, 3).squeeze() ) rot_mat = convert_rotation(rot_aa, "aa", "mat").numpy() out_dict[f"{name}_vel"] = estimate_angular_velocity(rot_mat, h) # joint up-axis angular velocity (need to compute joint frames first...) # need the joint transform at all steps to find the angular velocity joints_world2aligned_rot = compute_joint_align_mats(joint_seq) joint_orient_vel = -estimate_angular_velocity(joints_world2aligned_rot, h) # only need around z out_dict["joint_orient_vel"] = joint_orient_vel[:, 2] # throw out edge frames for other data so velocities are accurate for name in cur_keys: if name not in out_dict: continue out_dict[name] = out_dict[name][1:-1] num_frames = num_frames - 2 # downsample frames fps_ratio = float(out_fps) / src_fps guru.info(f"Downsamp ratio: {fps_ratio}") new_num_frames = int(fps_ratio * num_frames) guru.info(f"Downsamp num frames: {new_num_frames}") downsamp_inds = np.linspace(0, num_frames - 1, num=new_num_frames, dtype=int) for k, v in out_dict.items(): # print(k, type(v)) if not isinstance(v, (torch.Tensor, np.ndarray)): continue if v.ndim >= 1: # print("downsampling", k) out_dict[k] = v[downsamp_inds] meta = { "fps": out_fps, "num_frames": new_num_frames, "gender": str(gender), } guru.info(f"Seq process time: {time.time() - start_t} s") guru.info(f"Saving data to {out_path}") os.makedirs(os.path.dirname(out_path), exist_ok=True) np.savez(out_path, **meta, **out_dict) @dataclasses.dataclass class Config: data_root: str """Where the AMASS dataset is stored.""" smplh_root: str = "./data/smplh" out_root: str = "./data/processed_30fps_no_skating/" devices: tuple[int, ...] = (0,) """CUDA devices. We use CPU if not available.""" overwrite: bool = False def check_skip(path_name: str) -> bool: """Copied conditions from https://github.com/davrempe/humor/blob/main/humor/scripts/cleanup_amass_data.py""" if "BioMotionLab_NTroje" in path_name and ( "treadmill" in path_name or "normal_" in path_name ): return True if "MPI_HDM05" in path_name and "dg/HDM_dg_07-01" in path_name: return True return False def main(cfg: Config): dsets = AMASS_SPLITS["all"] paths_to_process = [] for dset in dsets: paths_to_process.extend( map(str, Path(f"{cfg.data_root}/{dset}").glob("**/*_poses.npz")) ) dev_ids = cfg.devices guru.info(f"devices {dev_ids}") if len(dev_ids) <= 1: guru.info("processing in sequence") for i, path in tqdm(enumerate(paths_to_process)): if check_skip(path): guru.info(f"skipping {path}") continue fname = path.split(cfg.data_root)[-1].rstrip("/") name, ext = os.path.splitext(fname) out_path = f"{cfg.out_root}/neutral/{name}{ext}" r_out_path = f"{cfg.out_root}/neutral/{name}_reflect{ext}" process_seq( path, out_path, cfg.smplh_root, dev_ids[i % len(dev_ids)], beta_neutral=True, reflect=False, overwrite=cfg.overwrite, ) process_seq( path, r_out_path, cfg.smplh_root, dev_ids[i % len(dev_ids)], beta_neutral=True, reflect=True, overwrite=cfg.overwrite, ) return with ProcessPoolExecutor(max_workers=len(dev_ids)) as exe: for i, path in tqdm(enumerate(paths_to_process)): if check_skip(path): guru.info(f"skipping {path}") continue fname = path.split(cfg.data_root)[-1].rstrip("/") name, ext = os.path.splitext(fname) out_path = f"{cfg.out_root}/neutral/{name}{ext}" r_out_path = f"{cfg.out_root}/neutral/{name}_reflect{ext}" exe.submit( process_seq, path, out_path, cfg.smplh_root, dev_ids[i % len(dev_ids)], beta_neutral=True, reflect=False, overwrite=cfg.overwrite, ) exe.submit( process_seq, path, r_out_path, cfg.smplh_root, dev_ids[i % len(dev_ids)], beta_neutral=True, reflect=True, overwrite=cfg.overwrite, ) if __name__ == "__main__": main(tyro.cli(Config)) ================================================ FILE: 0b_preprocess_training_data.py ================================================ """Translate data from HuMoR-style npz format to an hdf5-based one. Due to AMASS licensing, we unfortunately can't re-distribute our preprocessed dataset. If you have questions or run into issues, please reach out. """ import queue import threading import time from pathlib import Path import h5py import torch import torch.cuda import tyro from egoallo import fncsmpl from egoallo.data.amass import EgoTrainingData def main( smplh_npz_path: Path = Path("./data/smplh/neutral/model.npz"), data_npz_dir: Path = Path("./data/processed_30fps_no_skating/"), output_file: Path = Path("./data/egoalgo_no_skating_dataset.hdf5"), output_list_file: Path = Path("./data/egoalgo_no_skating_dataset_files.txt"), include_hands: bool = True, ) -> None: body_model = fncsmpl.SmplhModel.load(smplh_npz_path) assert torch.cuda.is_available() task_queue = queue.Queue[Path]() for path in list(data_npz_dir.glob("**/*.npz")): task_queue.put_nowait(path) total_count = task_queue.qsize() start_time = time.time() output_hdf5 = h5py.File(output_file, "w") file_list: list[str] = [] def worker(device_idx: int) -> None: device_body_model = body_model.to("cuda:" + str(device_idx)) while True: try: npz_path = task_queue.get_nowait() except queue.Empty: break print(f"Processing {npz_path} on device {device_idx}...") train_data = EgoTrainingData.load_from_npz( device_body_model, npz_path, include_hands=include_hands ) assert "neutral" in str(npz_path) group_name = str(npz_path).rpartition("neutral/")[2] print(f"Writing to group {group_name} on {device_idx}...") group = output_hdf5.create_group(group_name) file_list.append(group_name) for k, v in vars(train_data).items(): # No need to write the mask, which will always be ones when we # load from the npz file! if k == "mask": continue # Chunk into 32 timesteps at a time. assert v.dtype == torch.float32 if v.shape[0] == train_data.T_world_cpf.shape[0]: chunks = (min(32, v.shape[0]),) + v.shape[1:] else: assert v.shape[0] == 1 chunks = v.shape group.create_dataset(k, data=v.numpy(force=True), chunks=chunks) print( f"Finished ~{total_count - task_queue.qsize()}/{total_count},", f"{(total_count - task_queue.qsize()) / total_count * 100:.2f}% in", f"{time.time() - start_time} seconds", ) workers = [ threading.Thread(target=worker, args=(i,)) for i in range(torch.cuda.device_count()) ] for w in workers: w.start() for w in workers: w.join() output_list_file.write_text("\n".join(file_list)) if __name__ == "__main__": tyro.cli(main) ================================================ FILE: 1_train_motion_prior.py ================================================ """Training script for EgoAllo diffusion model using HuggingFace accelerate.""" import dataclasses import shutil from pathlib import Path from typing import Literal import tensorboardX import torch.optim.lr_scheduler import torch.utils.data import tyro import yaml from accelerate import Accelerator, DataLoaderConfiguration from accelerate.utils import ProjectConfiguration from loguru import logger from egoallo import network, training_loss, training_utils from egoallo.data.amass import EgoAmassHdf5Dataset from egoallo.data.dataclass import collate_dataclass @dataclasses.dataclass(frozen=True) class EgoAlloTrainConfig: experiment_name: str dataset_hdf5_path: Path dataset_files_path: Path model: network.EgoDenoiserConfig = network.EgoDenoiserConfig() loss: training_loss.TrainingLossConfig = training_loss.TrainingLossConfig() # Dataset arguments. batch_size: int = 256 """Effective batch size.""" num_workers: int = 2 subseq_len: int = 128 dataset_slice_strategy: Literal[ "deterministic", "random_uniform_len", "random_variable_len" ] = "random_uniform_len" dataset_slice_random_variable_len_proportion: float = 0.3 """Only used if dataset_slice_strategy == 'random_variable_len'.""" train_splits: tuple[Literal["train", "val", "test", "just_humaneva"], ...] = ( "train", "val", ) # Optimizer options. learning_rate: float = 1e-4 weight_decay: float = 1e-4 warmup_steps: int = 1000 max_grad_norm: float = 1.0 def get_experiment_dir(experiment_name: str, version: int = 0) -> Path: """Creates a directory to put experiment files in, suffixed with a version number. Similar to PyTorch lightning.""" experiment_dir = ( Path(__file__).absolute().parent / "experiments" / experiment_name / f"v{version}" ) if experiment_dir.exists(): return get_experiment_dir(experiment_name, version + 1) else: return experiment_dir def run_training( config: EgoAlloTrainConfig, restore_checkpoint_dir: Path | None = None, ) -> None: # Set up experiment directory + HF accelerate. # We're getting to manage logging, checkpoint directories, etc manually, # and just use `accelerate` for distibuted training. experiment_dir = get_experiment_dir(config.experiment_name) assert not experiment_dir.exists() accelerator = Accelerator( project_config=ProjectConfiguration(project_dir=str(experiment_dir)), dataloader_config=DataLoaderConfiguration(split_batches=True), ) writer = ( tensorboardX.SummaryWriter(logdir=str(experiment_dir), flush_secs=10) if accelerator.is_main_process else None ) device = accelerator.device # Initialize experiment. if accelerator.is_main_process: training_utils.pdb_safety_net() # Save various things that might be useful. experiment_dir.mkdir(exist_ok=True, parents=True) (experiment_dir / "git_commit.txt").write_text( training_utils.get_git_commit_hash() ) (experiment_dir / "git_diff.txt").write_text(training_utils.get_git_diff()) (experiment_dir / "run_config.yaml").write_text(yaml.dump(config)) (experiment_dir / "model_config.yaml").write_text(yaml.dump(config.model)) # Add hyperparameters to TensorBoard. assert writer is not None writer.add_hparams( hparam_dict=training_utils.flattened_hparam_dict_from_dataclass(config), metric_dict={}, name=".", # Hack to avoid timestamped subdirectory. ) # Write logs to file. logger.add(experiment_dir / "trainlog.log", rotation="100 MB") # Setup. model = network.EgoDenoiser(config.model) train_loader = torch.utils.data.DataLoader( dataset=EgoAmassHdf5Dataset( config.dataset_hdf5_path, config.dataset_files_path, splits=config.train_splits, subseq_len=config.subseq_len, cache_files=True, slice_strategy=config.dataset_slice_strategy, random_variable_len_proportion=config.dataset_slice_random_variable_len_proportion, ), batch_size=config.batch_size, shuffle=True, num_workers=config.num_workers, persistent_workers=config.num_workers > 0, pin_memory=True, collate_fn=collate_dataclass, drop_last=True, ) optim = torch.optim.AdamW( # type: ignore model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay, ) scheduler = torch.optim.lr_scheduler.LambdaLR( optim, lr_lambda=lambda step: min(1.0, step / config.warmup_steps) ) # HF accelerate setup. We use this for parallelism, etc! model, train_loader, optim, scheduler = accelerator.prepare( model, train_loader, optim, scheduler ) accelerator.register_for_checkpointing(scheduler) # Restore an existing model checkpoint. if restore_checkpoint_dir is not None: accelerator.load_state(str(restore_checkpoint_dir)) # Get the initial step count. if restore_checkpoint_dir is not None and restore_checkpoint_dir.name.startswith( "checkpoint_" ): step = int(restore_checkpoint_dir.name.partition("_")[2]) else: step = int(scheduler.state_dict()["last_epoch"]) assert step == 0 or restore_checkpoint_dir is not None, step # Save an initial checkpoint. Not a big deal but currently this has an # off-by-one error, in that `step` means something different in this # checkpoint vs the others. accelerator.save_state(str(experiment_dir / f"checkpoints_{step}")) # Run training loop! loss_helper = training_loss.TrainingLossComputer(config.loss, device=device) loop_metrics_gen = training_utils.loop_metric_generator(counter_init=step) prev_checkpoint_path: Path | None = None while True: for train_batch in train_loader: loop_metrics = next(loop_metrics_gen) step = loop_metrics.counter loss, log_outputs = loss_helper.compute_denoising_loss( model, unwrapped_model=accelerator.unwrap_model(model), train_batch=train_batch, ) log_outputs["learning_rate"] = scheduler.get_last_lr()[0] accelerator.log(log_outputs, step=step) accelerator.backward(loss) if accelerator.sync_gradients: accelerator.clip_grad_norm_(model.parameters(), config.max_grad_norm) optim.step() scheduler.step() optim.zero_grad(set_to_none=True) # The rest of the loop will only be executed by the main process. if not accelerator.is_main_process: continue # Logging. if step % 10 == 0: assert writer is not None for k, v in log_outputs.items(): writer.add_scalar(k, v, step) # Print status update to terminal. if step % 20 == 0: mem_free, mem_total = torch.cuda.mem_get_info() logger.info( f"step: {step} ({loop_metrics.iterations_per_sec:.2f} it/sec)" f" mem: {(mem_total - mem_free) / 1024**3:.2f}/{mem_total / 1024**3:.2f}G" f" lr: {scheduler.get_last_lr()[0]:.7f}" f" loss: {loss.item():.6f}" ) # Checkpointing. if step % 5000 == 0: # Save checkpoint. checkpoint_path = experiment_dir / f"checkpoints_{step}" accelerator.save_state(str(checkpoint_path)) logger.info(f"Saved checkpoint to {checkpoint_path}") # Keep checkpoints from only every 100k steps. if prev_checkpoint_path is not None: shutil.rmtree(prev_checkpoint_path) prev_checkpoint_path = None if step % 100_000 == 0 else checkpoint_path del checkpoint_path if __name__ == "__main__": tyro.cli(run_training) ================================================ FILE: 2_run_hamer_on_vrs.py ================================================ """Script to run HaMeR on VRS data and save outputs to a pickle file.""" import pickle import shutil from pathlib import Path import cv2 import imageio.v3 as iio import numpy as np import tyro from egoallo.hand_detection_structs import ( SavedHamerOutputs, SingleHandHamerOutputWrtCamera, ) from hamer_helper import HamerHelper from projectaria_tools.core import calibration from projectaria_tools.core.data_provider import ( VrsDataProvider, create_vrs_data_provider, ) from tqdm.auto import tqdm from egoallo.inference_utils import InferenceTrajectoryPaths def main(traj_root: Path, overwrite: bool = False) -> None: """Run HaMeR for on trajectory. We'll save outputs to `traj_root/hamer_outputs.pkl` and `traj_root/hamer_outputs_render". Arguments: traj_root: The root directory of the trajectory. We assume that there's a VRS file in this directory. overwrite: If True, overwrite any existing HaMeR outputs. """ paths = InferenceTrajectoryPaths.find(traj_root) vrs_path = paths.vrs_file assert vrs_path.exists() pickle_out = traj_root / "hamer_outputs.pkl" hamer_render_out = traj_root / "hamer_outputs_render" # This is just for debugging. run_hamer_and_save(vrs_path, pickle_out, hamer_render_out, overwrite) def run_hamer_and_save( vrs_path: Path, pickle_out: Path, hamer_render_out: Path, overwrite: bool ) -> None: if not overwrite: assert not pickle_out.exists() assert not hamer_render_out.exists() else: pickle_out.unlink(missing_ok=True) shutil.rmtree(hamer_render_out, ignore_errors=True) hamer_render_out.mkdir(exist_ok=True) hamer_helper = HamerHelper() # VRS data provider setup. provider = create_vrs_data_provider(str(vrs_path.absolute())) assert isinstance(provider, VrsDataProvider) rgb_stream_id = provider.get_stream_id_from_label("camera-rgb") assert rgb_stream_id is not None num_images = provider.get_num_data(rgb_stream_id) print(f"Found {num_images=}") # Get calibrations. device_calib = provider.get_device_calibration() assert device_calib is not None camera_calib = device_calib.get_camera_calib("camera-rgb") assert camera_calib is not None pinhole = calibration.get_linear_camera_calibration(1408, 1408, 450) # Compute camera extrinsics! sophus_T_device_camera = device_calib.get_transform_device_sensor("camera-rgb") sophus_T_cpf_camera = device_calib.get_transform_cpf_sensor("camera-rgb") assert sophus_T_device_camera is not None assert sophus_T_cpf_camera is not None T_device_cam = np.concatenate( [ sophus_T_device_camera.rotation().to_quat().squeeze(axis=0), sophus_T_device_camera.translation().squeeze(axis=0), ] ) T_cpf_cam = np.concatenate( [ sophus_T_cpf_camera.rotation().to_quat().squeeze(axis=0), sophus_T_cpf_camera.translation().squeeze(axis=0), ] ) assert T_device_cam.shape == T_cpf_cam.shape == (7,) # Dict from capture timestamp in nanoseconds to fields we care about. detections_left_wrt_cam: dict[int, SingleHandHamerOutputWrtCamera | None] = {} detections_right_wrt_cam: dict[int, SingleHandHamerOutputWrtCamera | None] = {} pbar = tqdm(range(num_images)) for i in pbar: image_data, image_data_record = provider.get_image_data_by_index( rgb_stream_id, i ) undistorted_image = calibration.distort_by_calibration( image_data.to_numpy_array(), pinhole, camera_calib ) hamer_out_left, hamer_out_right = hamer_helper.look_for_hands( undistorted_image, focal_length=450, ) timestamp_ns = image_data_record.capture_timestamp_ns if hamer_out_left is None: detections_left_wrt_cam[timestamp_ns] = None else: detections_left_wrt_cam[timestamp_ns] = { "verts": hamer_out_left["verts"], "keypoints_3d": hamer_out_left["keypoints_3d"], "mano_hand_pose": hamer_out_left["mano_hand_pose"], "mano_hand_betas": hamer_out_left["mano_hand_betas"], "mano_hand_global_orient": hamer_out_left["mano_hand_global_orient"], } if hamer_out_right is None: detections_right_wrt_cam[timestamp_ns] = None else: detections_right_wrt_cam[timestamp_ns] = { "verts": hamer_out_right["verts"], "keypoints_3d": hamer_out_right["keypoints_3d"], "mano_hand_pose": hamer_out_right["mano_hand_pose"], "mano_hand_betas": hamer_out_right["mano_hand_betas"], "mano_hand_global_orient": hamer_out_right["mano_hand_global_orient"], } composited = undistorted_image composited = hamer_helper.composite_detections( composited, hamer_out_left, border_color=(255, 100, 100), focal_length=450, ) composited = hamer_helper.composite_detections( composited, hamer_out_right, border_color=(100, 100, 255), focal_length=450, ) composited = put_text( composited, "L detections: " + ( "0" if hamer_out_left is None else str(hamer_out_left["verts"].shape[0]) ), 0, color=(255, 100, 100), font_scale=10.0 / 2880.0 * undistorted_image.shape[0], ) composited = put_text( composited, "R detections: " + ( "0" if hamer_out_right is None else str(hamer_out_right["verts"].shape[0]) ), 1, color=(100, 100, 255), font_scale=10.0 / 2880.0 * undistorted_image.shape[0], ) composited = put_text( composited, f"ns={timestamp_ns}", 2, color=(255, 255, 255), font_scale=10.0 / 2880.0 * undistorted_image.shape[0], ) print(f"Saving image {i:06d} to {hamer_render_out / f'{i:06d}.jpeg'}") iio.imwrite( str(hamer_render_out / f"{i:06d}.jpeg"), np.concatenate( [ # Darken input image, just for contrast... (undistorted_image * 0.6).astype(np.uint8), composited, ], axis=1, ), quality=90, ) outputs = SavedHamerOutputs( mano_faces_right=hamer_helper.get_mano_faces("right"), mano_faces_left=hamer_helper.get_mano_faces("left"), detections_right_wrt_cam=detections_right_wrt_cam, detections_left_wrt_cam=detections_left_wrt_cam, T_device_cam=T_device_cam, T_cpf_cam=T_cpf_cam, ) with open(pickle_out, "wb") as f: pickle.dump(outputs, f) def put_text( image: np.ndarray, text: str, line_number: int, color: tuple[int, int, int], font_scale: float, ) -> np.ndarray: """Put some text on the top-left corner of an image.""" image = image.copy() font = cv2.FONT_HERSHEY_PLAIN cv2.putText( image, text=text, org=(2, 1 + int(15 * font_scale * (line_number + 1))), fontFace=font, fontScale=font_scale, color=(0, 0, 0), thickness=max(int(font_scale), 1), lineType=cv2.LINE_AA, ) cv2.putText( image, text=text, org=(2, 1 + int(15 * font_scale * (line_number + 1))), fontFace=font, fontScale=font_scale, color=color, thickness=max(int(font_scale), 1), lineType=cv2.LINE_AA, ) return image if __name__ == "__main__": tyro.cli(main) ================================================ FILE: 3_aria_inference.py ================================================ from __future__ import annotations import dataclasses import time from pathlib import Path import numpy as np import torch import viser import yaml from egoallo import fncsmpl, fncsmpl_extensions from egoallo.data.aria_mps import load_point_cloud_and_find_ground from egoallo.guidance_optimizer_jax import GuidanceMode from egoallo.hand_detection_structs import ( CorrespondedAriaHandWristPoseDetections, CorrespondedHamerDetections, ) from egoallo.inference_utils import ( InferenceInputTransforms, InferenceTrajectoryPaths, load_denoiser, ) from egoallo.sampling import run_sampling_with_stitching from egoallo.transforms import SE3, SO3 from egoallo.vis_helpers import visualize_traj_and_hand_detections @dataclasses.dataclass class Args: traj_root: Path """Search directory for trajectories. This should generally be laid out as something like: traj_dir/ video.vrs egoallo_outputs/ {date}_{start_index}-{end_index}.npz ... ... """ checkpoint_dir: Path = Path("./egoallo_checkpoint_april13/checkpoints_3000000/") smplh_npz_path: Path = Path("./data/smplh/neutral/model.npz") glasses_x_angle_offset: float = 0.0 """Rotate the CPF poses by some X angle.""" start_index: int = 0 """Index within the downsampled trajectory to start inference at.""" traj_length: int = 128 """How many timesteps to estimate body motion for.""" num_samples: int = 1 """Number of samples to take.""" guidance_mode: GuidanceMode = "aria_hamer" """Which guidance mode to use.""" guidance_inner: bool = True """Whether to apply guidance optimizer between denoising steps. This is important if we're doing anything with hands. It can be turned off to speed up debugging/experiments, or if we only care about foot skating losses.""" guidance_post: bool = True """Whether to apply guidance optimizer after diffusion sampling.""" save_traj: bool = True """Whether to save the output trajectory, which will be placed under `traj_dir/egoallo_outputs/some_name.npz`.""" visualize_traj: bool = False """Whether to visualize the trajectory after sampling.""" def main(args: Args) -> None: device = torch.device("cuda") traj_paths = InferenceTrajectoryPaths.find(args.traj_root) if traj_paths.splat_path is not None: print("Found splat at", traj_paths.splat_path) else: print("No scene splat found.") # Get point cloud + floor. points_data, floor_z = load_point_cloud_and_find_ground(traj_paths.points_path) # Read transforms from VRS / MPS, downsampled. transforms = InferenceInputTransforms.load( traj_paths.vrs_file, traj_paths.slam_root_dir, fps=30 ).to(device=device) # Note the off-by-one for Ts_world_cpf, which we need for relative transform computation. Ts_world_cpf = ( SE3( transforms.Ts_world_cpf[ args.start_index : args.start_index + args.traj_length + 1 ] ) @ SE3.from_rotation( SO3.from_x_radians( transforms.Ts_world_cpf.new_tensor(args.glasses_x_angle_offset) ) ) ).parameters() pose_timestamps_sec = transforms.pose_timesteps[ args.start_index + 1 : args.start_index + args.traj_length + 1 ] Ts_world_device = transforms.Ts_world_device[ args.start_index + 1 : args.start_index + args.traj_length + 1 ] del transforms # Get temporally corresponded HaMeR detections. if traj_paths.hamer_outputs is not None: hamer_detections = CorrespondedHamerDetections.load( traj_paths.hamer_outputs, pose_timestamps_sec, ).to(device) else: print("No hand detections found.") hamer_detections = None # Get temporally corresponded Aria wrist and palm estimates. if traj_paths.wrist_and_palm_poses_csv is not None: aria_detections = CorrespondedAriaHandWristPoseDetections.load( traj_paths.wrist_and_palm_poses_csv, pose_timestamps_sec, Ts_world_device=Ts_world_device.numpy(force=True), ).to(device) else: print("No Aria hand detections found.") aria_detections = None print(f"{Ts_world_cpf.shape=}") server = None if args.visualize_traj: server = viser.ViserServer() server.gui.configure_theme(dark_mode=True) denoiser_network = load_denoiser(args.checkpoint_dir).to(device) body_model = fncsmpl.SmplhModel.load(args.smplh_npz_path).to(device) traj = run_sampling_with_stitching( denoiser_network, body_model=body_model, guidance_mode=args.guidance_mode, guidance_inner=args.guidance_inner, guidance_post=args.guidance_post, Ts_world_cpf=Ts_world_cpf, hamer_detections=hamer_detections, aria_detections=aria_detections, num_samples=args.num_samples, device=device, floor_z=floor_z, ) # Save outputs in case we want to visualize later. if args.save_traj: save_name = ( time.strftime("%Y%m%d-%H%M%S") + f"_{args.start_index}-{args.start_index + args.traj_length}" ) out_path = args.traj_root / "egoallo_outputs" / (save_name + ".npz") out_path.parent.mkdir(parents=True, exist_ok=True) assert not out_path.exists() (args.traj_root / "egoallo_outputs" / (save_name + "_args.yaml")).write_text( yaml.dump(dataclasses.asdict(args)) ) posed = traj.apply_to_body(body_model) Ts_world_root = fncsmpl_extensions.get_T_world_root_from_cpf_pose( posed, Ts_world_cpf[..., 1:, :] ) print(f"Saving to {out_path}...", end="") np.savez( out_path, Ts_world_cpf=Ts_world_cpf[1:, :].numpy(force=True), Ts_world_root=Ts_world_root.numpy(force=True), body_quats=posed.local_quats[..., :21, :].numpy(force=True), left_hand_quats=posed.local_quats[..., 21:36, :].numpy(force=True), right_hand_quats=posed.local_quats[..., 36:51, :].numpy(force=True), contacts=traj.contacts.numpy(force=True), # Sometimes we forgot this... betas=traj.betas.numpy(force=True), frame_nums=np.arange(args.start_index, args.start_index + args.traj_length), timestamps_ns=(np.array(pose_timestamps_sec) * 1e9).astype(np.int64), ) print("saved!") # Visualize. if args.visualize_traj: assert server is not None loop_cb = visualize_traj_and_hand_detections( server, Ts_world_cpf[1:], traj, body_model, hamer_detections, aria_detections, points_data=points_data, splat_path=traj_paths.splat_path, floor_z=floor_z, ) while True: loop_cb() if __name__ == "__main__": import tyro main(tyro.cli(Args)) ================================================ FILE: 4_visualize_outputs.py ================================================ from __future__ import annotations import io from pathlib import Path from typing import Callable import cv2 import imageio.v3 as iio import numpy as np import torch import tyro import viser from projectaria_tools.core.data_provider import ( VrsDataProvider, create_vrs_data_provider, ) from projectaria_tools.core.sensor_data import TimeDomain from tqdm import tqdm from egoallo import fncsmpl from egoallo.data.aria_mps import load_point_cloud_and_find_ground from egoallo.hand_detection_structs import ( CorrespondedAriaHandWristPoseDetections, CorrespondedHamerDetections, ) from egoallo.inference_utils import InferenceTrajectoryPaths from egoallo.network import EgoDenoiseTraj from egoallo.transforms import SE3, SO3 from egoallo.vis_helpers import visualize_traj_and_hand_detections def main( search_root_dir: Path, smplh_npz_path: Path = Path("./data/smplh/neutral/model.npz"), ) -> None: """Visualization script for outputs from EgoAllo. Arguments: search_root_dir: Root directory where inputs/outputs are stored. All NPZ files in this directory will be assumed to be outputs from EgoAllo. smplh_npz_path: Path to the SMPLH model NPZ file. """ device = torch.device("cuda") body_model = fncsmpl.SmplhModel.load(smplh_npz_path).to(device) server = viser.ViserServer() server.gui.configure_theme(dark_mode=True) def get_file_list(): return ["None"] + sorted( str(p.relative_to(search_root_dir)) for p in search_root_dir.glob("**/egoallo_outputs/*.npz") ) options = get_file_list() file_dropdown = server.gui.add_dropdown("File", options=options) refresh_file_list = server.gui.add_button("Refresh File List") @refresh_file_list.on_click def _(_) -> None: file_dropdown.options = get_file_list() trajectory_folder = server.gui.add_folder("Trajectory") current_file = "None" loop_cb = lambda: None while True: loop_cb() if current_file != file_dropdown.value: current_file = file_dropdown.value # Clear the scene. server.scene.reset() if current_file != "None": # Clear the folder by removing then re-adding it. # Perhaps we should expose some API for looping through children? trajectory_folder.remove() trajectory_folder = server.gui.add_folder("Trajectory") with trajectory_folder: npz_path = Path(search_root_dir / current_file).resolve() loop_cb = load_and_visualize( server, npz_path, body_model, device=device, ) args = npz_path.parent / (npz_path.stem + "_args.yaml") if args.exists(): with server.gui.add_folder("Args"): server.gui.add_markdown( "```\n" + args.read_text() + "\n```" ) def load_and_visualize( server: viser.ViserServer, npz_path: Path, body_model: fncsmpl.SmplhModel, device: torch.device, ) -> Callable[[], int]: # Here's how we saved: # # np.savez( # out_path, # Ts_world_cpf=Ts_world_cpf[1:, :].numpy(force=True), # Ts_world_root=Ts_world_root.numpy(force=True), # body_quats=posed.local_quats[..., :21, :].numpy(force=True), # left_hand_quats=posed.local_quats[..., 21:36, :].numpy(force=True), # right_hand_quats=posed.local_quats[..., 36:51, :].numpy(force=True), # betas=traj.betas.numpy(force=True), # frame_nums=np.arange(args.start_index, args.start_index + args.traj_length), # timestamps_ns=(np.array(pose_timestamps_sec) * 1e9).astype(np.int64), # ) outputs = np.load(npz_path) expected_keys = [ "Ts_world_cpf", "Ts_world_root", "body_quats", "left_hand_quats", "right_hand_quats", "betas", "frame_nums", "timestamps_ns", ] assert all(key in outputs for key in expected_keys), ( f"Missing keys in NPZ file. Expected: {expected_keys}, Found: {list(outputs.keys())}" ) (num_samples, timesteps, _, _) = outputs["body_quats"].shape # We assume the directory structure is: # - some trajectory root # - outputs # - the npz file traj_dir = npz_path.resolve().parent.parent paths = InferenceTrajectoryPaths.find(traj_dir) provider = create_vrs_data_provider(str(paths.vrs_file)) device_calib = provider.get_device_calibration() T_device_cpf = SE3( torch.from_numpy( device_calib.get_transform_device_cpf().to_quat_and_translation() ) ) assert T_device_cpf.wxyz_xyz.shape == (1, 7) pose_timestamps_sec = outputs["timestamps_ns"] / 1e9 Ts_world_device = ( SE3(torch.from_numpy(outputs["Ts_world_cpf"])) @ T_device_cpf.inverse() ).wxyz_xyz # Get temporally corresponded HaMeR detections. if paths.hamer_outputs is not None: hamer_detections = CorrespondedHamerDetections.load( paths.hamer_outputs, pose_timestamps_sec, ) else: print("No hand detections found.") hamer_detections = None # Get temporally corresponded Aria wrist and palm estimates. if paths.wrist_and_palm_poses_csv is not None: aria_detections = CorrespondedAriaHandWristPoseDetections.load( paths.wrist_and_palm_poses_csv, pose_timestamps_sec, Ts_world_device=Ts_world_device.numpy(force=True), ) else: aria_detections = None if paths.splat_path is not None: print("Found splat at", paths.splat_path) else: print("No scene splat found.") # Get point cloud + floor. points_data, floor_z = load_point_cloud_and_find_ground( paths.points_path, "filtered" ) traj = EgoDenoiseTraj( betas=torch.from_numpy(outputs["betas"]).to(device), body_rotmats=SO3( torch.from_numpy(outputs["body_quats"]), ) .as_matrix() .to(device), # We weren't saving contacts originally. We added it September 28th. contacts=torch.zeros((num_samples, timesteps, 21), device=device) if "contacts" not in outputs else torch.from_numpy(outputs["contacts"]).to(device), hand_rotmats=SO3( torch.from_numpy( np.concatenate( [ outputs["left_hand_quats"], outputs["right_hand_quats"], ], axis=-2, ) ).to(device) ).as_matrix(), ) Ts_world_cpf = torch.from_numpy(outputs["Ts_world_cpf"]).to(device) def get_ego_video( start_index: int, end_index: int, total_duration: float, ) -> bytes: """Helper function that returns the egocentric video corresponding to some start/end pose index.""" assert isinstance(provider, VrsDataProvider) rgb_stream_id = provider.get_stream_id_from_label("camera-rgb") assert rgb_stream_id is not None camera_fps = provider.get_configuration(rgb_stream_id).get_nominal_rate_hz() print(f"{camera_fps=}") start_ns = int(outputs["timestamps_ns"][start_index]) first_ns = provider.get_first_time_ns(rgb_stream_id, TimeDomain.RECORD_TIME) image_start_index = int((start_ns - first_ns) / 1e9 * camera_fps) image_end_index = min( int(image_start_index + (end_index - start_index) / 30.0 * camera_fps) + 5, provider.get_num_data(rgb_stream_id), ) frames = [] for i in tqdm(range(image_start_index, image_end_index)): image_data = provider.get_image_data_by_index(rgb_stream_id, i)[0] image_array = image_data.to_numpy_array().copy() image_array = cv2.resize( image_array, (800, 800), interpolation=cv2.INTER_AREA ) image_array = cv2.rotate(image_array, cv2.ROTATE_90_CLOCKWISE) frames.append(image_array) fps = len(frames) / total_duration output = io.BytesIO() iio.imwrite( output, frames, fps=fps, extension=".mp4", codec="libx264", pixelformat="yuv420p", quality=None, ffmpeg_params=["-crf", "23"], ) return output.getvalue() return visualize_traj_and_hand_detections( server, Ts_world_cpf, traj, body_model, hamer_detections, aria_detections, points_data, paths.splat_path, floor_z=floor_z, get_ego_video=get_ego_video, ) if __name__ == "__main__": tyro.cli(main) ================================================ FILE: 5_eval_body_metrics.py ================================================ """Example script for computing body metrics on the test split of the AMASS dataset. This is not the exact script we used for the paper metrics, but should have the details that matter matched. Below are some metrics from this script when our released checkpoint is passed in. For --subseq-len 128: mpjpe 118.340 +/- 1.350 (in paper: 119.7 +/- 1.3) pampjpe 100.026 +/- 1.349 (in paper: 101.1 +/- 1.3) T_head 0.006 +/- 0.000 (in paper: 0.0062 +/- 0.0001) foot_contact (GND) 1.000 +/- 0.000 (in paper: 1.0 +/- 0.0) foot_skate 0.417 +/- 0.017 (not reported in paper) For --subseq-len 32: mpjpe 129.193 +/- 1.108 (in paper: 129.8 +/- 1.1) pampjpe 109.489 +/- 1.147 (in paper: 109.8 +/- 1.1) T_head 0.006 +/- 0.000 (in paper: 0.0064 +/- 0.0001) foot_contact (GND) 0.985 +/- 0.003 (in paper: 0.98 +/- 0.00) foot_skate 0.185 +/- 0.005 (not reported in paper) """ from pathlib import Path import jax.tree import numpy as np import torch.optim.lr_scheduler import torch.utils.data import tyro from egoallo import fncsmpl from egoallo.data.amass import EgoAmassHdf5Dataset from egoallo.fncsmpl_extensions import get_T_world_root_from_cpf_pose from egoallo.inference_utils import load_denoiser from egoallo.metrics_helpers import ( compute_foot_contact, compute_foot_skate, compute_head_trans, compute_mpjpe, ) from egoallo.sampling import run_sampling_with_stitching from egoallo.transforms import SE3, SO3 def main( dataset_hdf5_path: Path, dataset_files_path: Path, subseq_len: int = 128, guidance_inner: bool = False, checkpoint_dir: Path = Path("./egoallo_checkpoint_april13/checkpoints_3000000/"), smplh_npz_path: Path = Path("./data/smplh/neutral/model.npz"), num_samples: int = 1, ) -> None: """Compute body metrics on the test split of the AMASS dataset.""" device = torch.device("cuda") # Setup. denoiser_network = load_denoiser(checkpoint_dir).to(device) dataset = EgoAmassHdf5Dataset( dataset_hdf5_path, dataset_files_path, splits=("test",), # We need an extra timestep in order to compute the relative CPF pose. (T_cpf_tm1_cpf_t) subseq_len=subseq_len + 1, cache_files=True, slice_strategy="deterministic", random_variable_len_proportion=0.0, ) body_model = fncsmpl.SmplhModel.load(smplh_npz_path).to(device) metrics = list[dict[str, np.ndarray]]() for i in range(len(dataset)): sequence = dataset[i].to(device) samples = run_sampling_with_stitching( denoiser_network, body_model=body_model, guidance_mode="no_hands", guidance_inner=guidance_inner, guidance_post=True, Ts_world_cpf=sequence.T_world_cpf, hamer_detections=None, aria_detections=None, num_samples=num_samples, floor_z=0.0, device=device, guidance_verbose=False, ) assert samples.hand_rotmats is not None assert samples.betas.shape == (num_samples, subseq_len, 16) assert samples.body_rotmats.shape == (num_samples, subseq_len, 21, 3, 3) assert samples.hand_rotmats.shape == (num_samples, subseq_len, 30, 3, 3) assert sequence.hand_quats is not None # We'll only use the body joint rotations. pred_posed = body_model.with_shape(samples.betas).with_pose( T_world_root=SE3.identity(device, torch.float32).wxyz_xyz, local_quats=SO3.from_matrix( torch.cat([samples.body_rotmats, samples.hand_rotmats], dim=2) ).wxyz, ) pred_posed = pred_posed.with_new_T_world_root( get_T_world_root_from_cpf_pose(pred_posed, sequence.T_world_cpf[1:, ...]) ) label_posed = body_model.with_shape(sequence.betas[1:, ...]).with_pose( sequence.T_world_root[1:, ...], torch.cat( [ sequence.body_quats[1:, ...], sequence.hand_quats[1:, ...], ], dim=1, ), ) metrics.append( { "mpjpe": compute_mpjpe( label_T_world_root=label_posed.T_world_root, label_Ts_world_joint=label_posed.Ts_world_joint[:, :21, :], pred_T_world_root=pred_posed.T_world_root, pred_Ts_world_joint=pred_posed.Ts_world_joint[:, :, :21, :], per_frame_procrustes_align=False, ), "pampjpe": compute_mpjpe( label_T_world_root=label_posed.T_world_root, label_Ts_world_joint=label_posed.Ts_world_joint[:, :21, :], pred_T_world_root=pred_posed.T_world_root, pred_Ts_world_joint=pred_posed.Ts_world_joint[:, :, :21, :], per_frame_procrustes_align=True, ), # We didn't report foot skating metrics in the paper. It's not # really meaningful: since we optimize foot skating in the # guidance optimizer, it's easy to "cheat" this metric. "foot_skate": compute_foot_skate( pred_Ts_world_joint=pred_posed.Ts_world_joint[:, :, :21, :], ), "foot_contact (GND)": compute_foot_contact( pred_Ts_world_joint=pred_posed.Ts_world_joint[:, :, :21, :], ), "T_head": compute_head_trans( label_Ts_world_joint=label_posed.Ts_world_joint[:, :21, :], pred_Ts_world_joint=pred_posed.Ts_world_joint[:, :, :21, :], ), } ) print("=" * 80) print("=" * 80) print("=" * 80) print(f"Metrics ({i}/{len(dataset)} processed)") for k, v in jax.tree.map( lambda *x: f"{np.mean(x):.3f} +/- {np.std(x) / np.sqrt(len(metrics) * num_samples):.3f}", *metrics, ).items(): print("\t", k, v) print("=" * 80) print("=" * 80) print("=" * 80) if __name__ == "__main__": tyro.cli(main) ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2024 Brent Yi Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ # egoallo **[Project page](https://egoallo.github.io/) • [arXiv](https://arxiv.org/abs/2410.03665)** Code release for our preprint:
Brent Yi1, Vickie Ye1, Maya Zheng1, Yunqi Li2, Lea Müller1, Georgios Pavlakos3, Yi Ma1, Jitendra Malik1, and Angjoo Kanazawa1. Estimating Body and Hand Motion in an Ego-sensed World. arXiV, 2024.
1UC Berkeley, 2ShanghaiTech, 3UT Austin --- ## Updates - **Oct 7, 2024:** Initial release. (training code, core implementation details) - **Oct 14, 2024:** Added model checkpoint, dataset preprocessing, inference, and visualization scripts. - **May 6, 2025:** Updated scripts + instructions for dataset preprocessing, which is now self-contained in this repository. ## Overview **TLDR;** We use egocentric SLAM poses and images to estimate 3D human body pose, height, and hands. https://github.com/user-attachments/assets/7d28e07f-ab83-4749-ac6b-abe692d9ba20 This repository is structured as follows: ``` . ├── download_checkpoint_and_data.sh │ - Download model checkpoint and sample data. ├── 0_preprocess_training_data.py │ - Preprocessing script for training datasets. ├── 1_train_motion_prior.py │ - Training script for motion diffusion model. ├── 2_run_hamer_on_vrs.py │ - Run HaMeR on inference data (expects Aria VRS). ├── 3_aria_inference.py │ - Run full pipeline on inference data. ├── 4_visualize_outputs.py │ - Visualize outputs from inference. ├── 5_eval_body_metrics.py │ - Compute and print body estimation accuracy metrics. │ ├── src/egoallo/ │ ├── data/ - Dataset utilities. │ ├── transforms/ - SO(3) / SE(3) transformation helpers. │ └── *.py - All core implementation. │ └── pyproject.toml - Python dependencies/package metadata. ``` ## Getting started EgoAllo requires Python 3.12 or newer. 1. **Clone the repository.** ```bash git clone https://github.com/brentyi/egoallo.git ``` 2. **Install general dependencies.** ```bash cd egoallo pip install -e . ``` 3. **Download+unzip model checkpoint and sample data.** ```bash bash download_checkpoint_and_data.sh ``` You can also download the zip files manually: here are links to the [checkpoint](https://drive.google.com/file/d/14bDkWixFgo3U6dgyrCRmLoXSsXkrDA2w/view?usp=drive_link) and [example trajectories](https://drive.google.com/file/d/14zQ95NYxL4XIT7KIlFgAYTPCRITWxQqu/view?usp=drive_link). 4. **Download the SMPL-H model file.** You can find the "Extended SMPL+H model" (16 shape parameters) from the [MANO project webpage](https://mano.is.tue.mpg.de/). Our scripts assumes an npz file located at `./data/smplh/neutral/model.npz`, but this can be overridden at the command-line (`--smplh-npz-path {your path}`). 5. **Visualize model outputs.** The example trajectories directory includes example outputs from our model. You can visualize them with: ```bash python 4_visualize_outputs.py --search-root-dir ./egoallo_example_trajectories ``` ## Running inference 1. **Installing inference dependencies.** Our guidance optimization uses a Levenberg-Marquardt optimizer that's implemented in JAX. If you want to run this on an NVIDIA GPU, you'll need to install JAX with CUDA support: ```bash # Also see: https://jax.readthedocs.io/en/latest/installation.html pip install "jax[cuda12]==0.6.1" ``` You'll also need [jaxls](https://github.com/brentyi/jaxls): ```bash pip install git+https://github.com/brentyi/jaxls.git ``` 2. **Running inference on example data.** Here's an example command for running EgoAllo on the "coffeemachine" sequence: ```bash python 3_aria_inference.py --traj-root ./egoallo_example_trajectories/coffeemachine ``` You can run `python 3_aria_inference.py --help` to see the full list of options. 3. **Running inference on your own data.** To run inference on your own data, you can copy the structure of the example trajectories. The key files are: - A VRS file from Project Aria, which contains calibrations and images. - SLAM outputs from Project Aria's MPS: `closed_loop_trajectory.csv` and `semidense_points.csv.gz`. - (optional) HaMeR outputs, which we save to a `hamer_outputs.pkl`. - (optional) Project Aria wrist and palm tracking outputs. 4. **Running HaMeR on your own data.** To generate the `hamer_outputs.pkl` file, you'll need to install [hamer_helper](https://github.com/brentyi/hamer_helper). Then, as an example for running on our coffeemachine sequence: ```bash python 2_run_hamer_on_vrs.py --traj-root ./egoallo_example_trajectories/coffeemachine ``` ## Preprocessing Training Data To train the motion prior model, we use data from the [AMASS dataset](https://amass.is.tue.mpg.de/). Due to licensing constraints, we cannot redistribute the preprocessed data. Instead, we provide two sequential preprocessing scripts: 1. **Download the AMASS dataset.** Download the AMASS dataset from the [official website](https://amass.is.tue.mpg.de/). We use the following splits: - **Training**: ACCAD, BioMotionLab_NTroje, BMLhandball, BMLmovi, CMU, DanceDB, DFaust_67EKUT, Eyes_Japan_Dataset, KIT, MPI_Limits, TCD_handMocap, TotalCapture - **Validation**: HumanEva, MPI_HDM05, SFU, MPI_mosh - **Testing**: Transitions_mocap, SSM_synced 2. **Run the first preprocessing script.** ```bash python 0a_preprocess_training_data.py --help python 0a_preprocess_training_data.py --data-root /path/to/amass --smplh-root ./data/smplh ``` This script, adapted from HuMoR, processes raw AMASS data by: - Converting to gender-neutral SMPL-H parameters - Computing contact labels for feet, hands, and knees - Filtering out problematic sequences (treadmill walking, sequences with foot skating) - Downsampling to 30fps 3. **Run the second preprocessing script.** ```bash python 0b_preprocess_training_data.py --help python 0b_preprocess_training_data.py --data-npz-dir ./data/processed_30fps_no_skating/ ``` This converts the processed NPZ files to a unified HDF5 format for more efficient training, with optimized chunk sizes for reading sequences. ## Status This repository currently contains: - `egoallo` package, which contains reference training and sampling implementation details. - Training script. - Model checkpoints. - Dataset preprocessing script. - Inference script. - Visualization script. - Setup instructions. While we've put effort into cleaning up our code for release, this is research code and there's room for improvement. If you have questions or comments, please reach out! ================================================ FILE: download_checkpoint_and_data.sh ================================================ # Script for downloading model checkpoint and example inputs/outputs. # egoallo_checkpoint_april13.zip (552 MB) gdown https://drive.google.com/file/d/14bDkWixFgo3U6dgyrCRmLoXSsXkrDA2w/view?usp=drive_link --fuzzy unzip egoallo_checkpoint_april13.zip rm egoallo_checkpoint_april13.zip # egoallo_example_trajectories.zip (8.17 GB) gdown https://drive.google.com/file/d/14zQ95NYxL4XIT7KIlFgAYTPCRITWxQqu/view?usp=drive_link --fuzzy unzip egoallo_example_trajectories.zip rm egoallo_example_trajectories.zip ================================================ FILE: pyproject.toml ================================================ [build-system] requires = ["setuptools>=61.0"] build-backend = "setuptools.build_meta" [project] name = "egoallo" version = "0.0.0" description = "egoallo" readme = "README.md" license = { text="MIT" } requires-python = ">=3.12" classifiers = [ "Programming Language :: Python :: 3.12", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent" ] dependencies = [ "torch==2.7.1", "viser>=0.2.11", "plyfile==1.1.2", "typeguard==4.4.3", "jaxtyping==0.3.2", "einops==0.8.1", "rotary-embedding-torch==0.8.6", "h5py==3.13.0", "tensorboard==2.19.0", "projectaria_tools==1.6.0", "accelerate==1.7.0", "tensorboardX==2.6.2.2", "loguru==0.7.3", "projectaria-tools[all]==1.6.0", "opencv-python==4.11.0.86", "gdown==5.2.0", "scikit-learn==1.6.1", # Only needed for preprocessing "smplx==0.1.28", # Only needed for preprocessing ] [tool.setuptools.package-data] egoallo = ["py.typed"] [tool.pyright] ignore = ["**/preprocessing/**", "./0a_preprocess_training_data.py"] [tool.ruff.lint] select = [ "E", # pycodestyle errors. "F", # Pyflakes rules. "PLC", # Pylint convention warnings. "PLE", # Pylint errors. "PLR", # Pylint refactor recommendations. "PLW", # Pylint warnings. ] ignore = [ "E731", # Do not assign a lambda expression, use a def. "E741", # Ambiguous variable name. (l, O, or I) "E501", # Line too long. "E721", # Do not compare types, use `isinstance()`. "F722", # Forward annotation false positive from jaxtyping. Should be caught by pyright. "F821", # Forward annotation false positive from jaxtyping. Should be caught by pyright. "PLR2004", # Magic value used in comparison. "PLR0915", # Too many statements. "PLR0913", # Too many arguments. "PLC0414", # Import alias does not rename variable. (this is used for exporting names) "PLC1901", # Use falsey strings. "PLR5501", # Use `elif` instead of `else if`. "PLR0911", # Too many return statements. "PLR0912", # Too many branches. "PLW0603", # Globa statement updates are discouraged. "PLW2901", # For loop variable overwritten. ] ================================================ FILE: src/egoallo/__init__.py ================================================ ================================================ FILE: src/egoallo/fncsmpl.py ================================================ """Somewhat opinionated wrapper for the SMPL-H body model. Very little of it is specific to SMPL-H. This could very easily be adapted for other models in SMPL family. We break down the SMPL-H into four stages, each with a corresponding data structure: - Loading the model itself: `model = SmplhModel.load(path to npz)` - Applying a body shape to the model: `shaped = model.with_shape(betas)` - Posing the body shape: `posed = shaped.with_pose(root pose, local joint poses)` - Recovering the mesh with LBS: `mesh = posed.lbs()` In contrast to other SMPL wrappers: - Everything is stateless, so we can support arbitrary batch axes. - The root is no longer ever called a joint. - The `trans` and `root_orient` inputs are replaced by a single SE(3) root transformation. - We're using (4,) wxyz quaternion vectors for all rotations, (7,) wxyz_xyz vectors for all rigid transforms. """ from __future__ import annotations from pathlib import Path import numpy as np import torch from einops import einsum from jaxtyping import Float, Int from torch import Tensor from .tensor_dataclass import TensorDataclass from .transforms import SE3, SO3 class SmplhModel(TensorDataclass): """A human body model from the SMPL family.""" faces: Int[Tensor, "faces 3"] """Vertex indices for mesh faces.""" J_regressor: Float[Tensor, "joints+1 verts"] """Linear map from vertex to joint positions. For SMPL-H, 1 root + 21 body joints + 2 * 15 hand joints.""" parent_indices: tuple[int, ...] """Defines kinematic tree. Index of -1 signifies that a joint is defined relative to the root.""" weights: Float[Tensor, "verts joints+1"] """LBS weights.""" posedirs: Float[Tensor, "verts 3 joints*9"] """Pose blend shape bases.""" v_template: Float[Tensor, "verts 3"] """Canonical mesh verts.""" shapedirs: Float[Tensor, "verts 3 n_betas"] """Shape bases.""" @staticmethod def load(model_path: Path) -> SmplhModel: """Load a body model from an NPZ file.""" params_numpy: dict[str, np.ndarray] = { k: _normalize_dtype(v) for k, v in np.load(model_path, allow_pickle=True).items() } assert ( "bs_style" not in params_numpy or params_numpy.pop("bs_style").item() == b"lbs" ) assert ( "bs_type" not in params_numpy or params_numpy.pop("bs_type").item() == b"lrotmin" ) parent_indices = tuple( int(index) for index in params_numpy.pop("kintree_table")[0][1:] - 1 ) params = { k: torch.from_numpy(v) for k, v in params_numpy.items() if v.dtype in (np.int32, np.float32) } return SmplhModel( faces=params["f"], J_regressor=params["J_regressor"], parent_indices=parent_indices, weights=params["weights"], posedirs=params["posedirs"], v_template=params["v_template"], shapedirs=params["shapedirs"], ) def get_num_joints(self) -> int: """Get the number of joints in this model.""" return len(self.parent_indices) def with_shape(self, betas: Float[Tensor, "*#batch n_betas"]) -> SmplhShaped: """Compute a new body model, with betas applied.""" num_betas = betas.shape[-1] assert num_betas <= self.shapedirs.shape[-1] verts_with_shape = self.v_template + einsum( self.shapedirs[:, :, :num_betas], betas, "verts xyz beta, ... beta -> ... verts xyz", ) root_and_joints_pred = einsum( self.J_regressor, verts_with_shape, "jointsp1 verts, ... verts xyz -> ... jointsp1 xyz", ) root_offset = root_and_joints_pred[..., 0:1, :] return SmplhShaped( body_model=self, root_offset=root_offset.unsqueeze(-2), verts_zero=verts_with_shape - root_offset, joints_zero=root_and_joints_pred[..., 1:, :] - root_offset, t_parent_joint=root_and_joints_pred[..., 1:, :] - root_and_joints_pred[..., np.array(self.parent_indices) + 1, :], ) class SmplhShaped(TensorDataclass): """The SMPL-H body model with a body shape applied.""" body_model: SmplhModel """The underlying body model.""" root_offset: Float[Tensor, "*#batch 3"] verts_zero: Float[Tensor, "*#batch verts 3"] """Vertices of shaped body _relative to the root joint_ at the zero configuration.""" joints_zero: Float[Tensor, "*#batch joints 3"] """Joints of shaped body _relative to the root joint_ at the zero configuration.""" t_parent_joint: Float[Tensor, "*#batch joints 3"] """Position of each shaped body joint relative to its parent. Does not include root.""" def with_pose_decomposed( self, T_world_root: Float[Tensor, "*#batch 7"], body_quats: Float[Tensor, "*#batch 21 4"], left_hand_quats: Float[Tensor, "*#batch 15 4"] | None = None, right_hand_quats: Float[Tensor, "*#batch 15 4"] | None = None, ) -> SmplhShapedAndPosed: """Pose our SMPL-H body model. Returns a set of joint and vertex outputs.""" num_joints = self.body_model.get_num_joints() batch_axes = body_quats.shape[:-2] if left_hand_quats is None: left_hand_quats = body_quats.new_zeros((*batch_axes, 15, 4)) left_hand_quats[..., 0] = 1.0 if right_hand_quats is None: right_hand_quats = body_quats.new_zeros((*batch_axes, 15, 4)) right_hand_quats[..., 0] = 1.0 local_quats = broadcasting_cat( [body_quats, left_hand_quats, right_hand_quats], dim=-2 ) assert local_quats.shape[-2:] == (num_joints, 4) return self.with_pose(T_world_root, local_quats) def with_pose( self, T_world_root: Float[Tensor, "*#batch 7"], local_quats: Float[Tensor, "*#batch joints 4"], ) -> SmplhShapedAndPosed: """Pose our SMPL-H body model. Returns a set of joint and vertex outputs.""" # Forward kinematics. num_joints = self.body_model.get_num_joints() assert local_quats.shape[-2:] == (num_joints, 4) Ts_world_joint = forward_kinematics( T_world_root=T_world_root, Rs_parent_joint=local_quats, t_parent_joint=self.t_parent_joint, parent_indices=self.body_model.parent_indices, ) assert Ts_world_joint.shape[-2:] == (num_joints, 7) return SmplhShapedAndPosed( shaped_model=self, T_world_root=T_world_root, local_quats=local_quats, Ts_world_joint=Ts_world_joint, ) class SmplhShapedAndPosed(TensorDataclass): shaped_model: SmplhShaped """Underlying shaped body model.""" T_world_root: Float[Tensor, "*#batch 7"] """Root coordinate frame.""" local_quats: Float[Tensor, "*#batch joints 4"] """Local joint orientations.""" Ts_world_joint: Float[Tensor, "*#batch joints 7"] """Absolute transform for each joint. Does not include the root.""" def with_new_T_world_root( self, T_world_root: Float[Tensor, "*#batch 7"] ) -> SmplhShapedAndPosed: return SmplhShapedAndPosed( shaped_model=self.shaped_model, T_world_root=T_world_root, local_quats=self.local_quats, Ts_world_joint=( SE3(T_world_root[..., None, :]) @ SE3(self.T_world_root[..., None, :]).inverse() @ SE3(self.Ts_world_joint) ).parameters(), ) def lbs(self) -> SmplMesh: """Compute a mesh with LBS.""" num_joints = self.local_quats.shape[-2] verts_with_blend = self.shaped_model.verts_zero + einsum( self.shaped_model.body_model.posedirs, ( SO3(self.local_quats).as_matrix() - torch.eye( 3, dtype=self.local_quats.dtype, device=self.local_quats.device ) ).reshape((*self.local_quats.shape[:-2], num_joints * 9)), "... verts j joints_times_9, ... joints_times_9 -> ... verts j", ) verts_transformed = einsum( broadcasting_cat( [ SE3(self.T_world_root).as_matrix()[..., None, :3, :], SE3(self.Ts_world_joint).as_matrix()[..., :, :3, :], ], dim=-3, ), self.shaped_model.body_model.weights, broadcasting_cat( [ verts_with_blend[..., :, None, :] - broadcasting_cat( # Prepend root to joints zeros. [ self.shaped_model.joints_zero.new_zeros(3), self.shaped_model.joints_zero[..., None, :, :], ], dim=-2, ), verts_with_blend.new_ones( ( *verts_with_blend.shape[:-1], 1 + self.shaped_model.joints_zero.shape[-2], 1, ) ), ], dim=-1, ), "... joints_p1 i j, verts joints_p1, ... verts joints_p1 j -> ... verts i", ) assert ( verts_transformed.shape[-2:] == self.shaped_model.body_model.v_template.shape ) return SmplMesh( posed_model=self, verts=verts_transformed, faces=self.shaped_model.body_model.faces, ) class SmplMesh(TensorDataclass): """Outputs from the SMPL-H model.""" posed_model: SmplhShapedAndPosed """Posed model that this mesh was computed for.""" verts: Float[Tensor, "*#batch verts 3"] """Vertices for mesh.""" faces: Int[Tensor, "verts 3"] """Faces for mesh.""" def forward_kinematics( T_world_root: Float[Tensor, "*#batch 7"], Rs_parent_joint: Float[Tensor, "*#batch joints 4"], t_parent_joint: Float[Tensor, "*#batch joints 3"], parent_indices: tuple[int, ...], ) -> Float[Tensor, "*#batch joints 7"]: """Run forward kinematics to compute absolute poses (T_world_joint) for each joint. The output array containts pose parameters (w, x, y, z, tx, ty, tz) for each joint. (this does not include the root!) Args: T_world_root: Transformation to world frame from root frame. Rs_parent_joint: Local orientation of each joint. t_parent_joint: Position of each joint with respect to its parent frame. (this does not depend on local joint orientations) parent_indices: Parent index for each joint. Index of -1 signifies that a joint is defined relative to the root. We assume that this array is sorted: parent joints should always precede child joints. Returns: Transformations to world frame from each joint frame. """ # Check shapes. num_joints = len(parent_indices) assert Rs_parent_joint.shape[-2:] == (num_joints, 4) assert t_parent_joint.shape[-2:] == (num_joints, 3) # Get relative transforms. Ts_parent_child = broadcasting_cat([Rs_parent_joint, t_parent_joint], dim=-1) assert Ts_parent_child.shape[-2:] == (num_joints, 7) # Compute one joint at a time. list_Ts_world_joint: list[Tensor] = [] for i in range(num_joints): if parent_indices[i] == -1: T_world_parent = T_world_root else: T_world_parent = list_Ts_world_joint[parent_indices[i]] list_Ts_world_joint.append( (SE3(T_world_parent) @ SE3(Ts_parent_child[..., i, :])).wxyz_xyz ) Ts_world_joint = torch.stack(list_Ts_world_joint, dim=-2) assert Ts_world_joint.shape[-2:] == (num_joints, 7) return Ts_world_joint def broadcasting_cat(tensors: list[Tensor], dim: int) -> Tensor: """Like torch.cat, but broadcasts.""" assert len(tensors) > 0 output_dims = max(map(lambda t: len(t.shape), tensors)) tensors = [ t.reshape((1,) * (output_dims - len(t.shape)) + t.shape) for t in tensors ] max_sizes = [max(t.shape[i] for t in tensors) for i in range(output_dims)] expanded_tensors = [ tensor.expand( *( tensor.shape[i] if i == dim % len(tensor.shape) else max_size for i, max_size in enumerate(max_sizes) ) ) for tensor in tensors ] return torch.cat(expanded_tensors, dim=dim) def _normalize_dtype(v: np.ndarray) -> np.ndarray: """Normalize datatypes; all arrays should be either int32 or float32.""" if "int" in str(v.dtype): return v.astype(np.int32) elif "float" in str(v.dtype): return v.astype(np.float32) else: return v ================================================ FILE: src/egoallo/fncsmpl_extensions.py ================================================ """EgoAllo-specific SMPL utilities.""" from __future__ import annotations import numpy as np import torch from jaxtyping import Float from torch import Tensor from . import fncsmpl, transforms def get_T_world_cpf(mesh: fncsmpl.SmplMesh) -> Float[Tensor, "*#batch 7"]: """Get the central pupil frame from a mesh. This assumes that we're using the SMPL-H model.""" assert mesh.verts.shape[-2:] == (6890, 3), "Not using SMPL-H model!" right_eye = (mesh.verts[..., 6260, :] + mesh.verts[..., 6262, :]) / 2.0 left_eye = (mesh.verts[..., 2800, :] + mesh.verts[..., 2802, :]) / 2.0 # CPF is between the two eyes. cpf_pos = (right_eye + left_eye) / 2.0 # Get orientation from head. cpf_orientation = mesh.posed_model.Ts_world_joint[..., 14, :4] return torch.cat([cpf_orientation, cpf_pos], dim=-1) def get_T_head_cpf(shaped: fncsmpl.SmplhShaped) -> Float[Tensor, "*#batch 7"]: """Get the central pupil frame with respect to the head (joint 14). This assumes that we're using the SMPL-H model.""" verts_zero = shaped.verts_zero assert verts_zero.shape[-2:] == (6890, 3), "Not using SMPL-H model!" right_eye = (verts_zero[..., 6260, :] + verts_zero[..., 6262, :]) / 2.0 left_eye = (verts_zero[..., 2800, :] + verts_zero[..., 2802, :]) / 2.0 # CPF is between the two eyes. cpf_pos_wrt_head = (right_eye + left_eye) / 2.0 - shaped.joints_zero[..., 14, :] return fncsmpl.broadcasting_cat( [ transforms.SO3.identity( device=cpf_pos_wrt_head.device, dtype=cpf_pos_wrt_head.dtype ).wxyz, cpf_pos_wrt_head, ], dim=-1, ) def get_T_world_root_from_cpf_pose( posed: fncsmpl.SmplhShapedAndPosed, Ts_world_cpf: Float[Tensor | np.ndarray, "... 7"], ) -> Float[Tensor, "... 7"]: """Get the root transform that would align the CPF frame of `posed` to `Ts_world_cpf`.""" device = posed.Ts_world_joint.device dtype = posed.Ts_world_joint.dtype if isinstance(Ts_world_cpf, np.ndarray): Ts_world_cpf = torch.from_numpy(Ts_world_cpf).to(device=device, dtype=dtype) assert Ts_world_cpf.shape[-1] == 7 T_world_root = ( # T_world_cpf transforms.SE3(Ts_world_cpf) # T_cpf_head @ transforms.SE3(get_T_head_cpf(posed.shaped_model)).inverse() # T_head_world @ transforms.SE3(posed.Ts_world_joint[..., 14, :]).inverse() # T_world_root @ transforms.SE3(posed.T_world_root) ) return T_world_root.wxyz_xyz ================================================ FILE: src/egoallo/fncsmpl_jax.py ================================================ """SMPL-H model, implemented in JAX. Very little of it is specific to SMPL-H. This could very easily be adapted for other models in SMPL family. """ from __future__ import annotations from pathlib import Path from typing import Sequence, cast import jax import jax_dataclasses as jdc import jaxlie import numpy as onp from einops import einsum from jax import Array from jax import numpy as jnp from jaxtyping import Float, Int @jdc.pytree_dataclass class SmplhModel: """The SMPL-H human body model.""" faces: Int[Array, "faces 3"] """Vertex indices for mesh faces.""" J_regressor: Float[Array, "joints+1 verts"] """Linear map from vertex to joint positions. 22 body joints + 2 * 22 hand joints.""" parent_indices: Int[Array, "joints"] """Defines kinematic tree. Index of -1 signifies that a joint is defined relative to the root.""" weights: Float[Array, "verts joints+1"] """LBS weights.""" posedirs: Float[Array, "verts 3 joints*9"] """Pose blend shape bases.""" v_template: Float[Array, "verts 3"] """Canonical mesh verts.""" shapedirs: Float[Array, "verts 3 n_betas"] """Shape bases.""" @staticmethod def load(npz_path: Path) -> SmplhModel: smplh_params: dict[str, onp.ndarray] = onp.load(npz_path, allow_pickle=True) # assert smplh_params["bs_style"].item() == b"lbs" # assert smplh_params["bs_type"].item() == b"lrotmin" smplh_params = {k: _normalize_dtype(v) for k, v in smplh_params.items()} return SmplhModel( faces=jnp.array(smplh_params["f"]), J_regressor=jnp.array(smplh_params["J_regressor"]), parent_indices=jnp.array(smplh_params["kintree_table"][0][1:] - 1), weights=jnp.array(smplh_params["weights"]), posedirs=jnp.array(smplh_params["posedirs"]), v_template=jnp.array(smplh_params["v_template"]), shapedirs=jnp.array(smplh_params["shapedirs"]), ) def with_shape( self, betas: Float[Array | onp.ndarray, "... n_betas"] ) -> SmplhShaped: """Compute a new body model, with betas applied. betas vector should have shape up to (16,).""" num_betas = betas.shape[-1] assert num_betas <= 16 verts_with_shape = self.v_template + einsum( self.shapedirs[:, :, :num_betas], betas, "verts xyz beta, ... beta -> ... verts xyz", ) root_and_joints_pred = einsum( self.J_regressor, verts_with_shape, "joints verts, ... verts xyz -> ... joints xyz", ) root_offset = root_and_joints_pred[..., 0:1, :] return SmplhShaped( body_model=self, verts_zero=verts_with_shape - root_offset, joints_zero=root_and_joints_pred[..., 1:, :] - root_offset, t_parent_joint=root_and_joints_pred[..., 1:, :] - root_and_joints_pred[..., self.parent_indices + 1, :], ) @jdc.pytree_dataclass class SmplhShaped: """The SMPL-H body model with a body shape applied.""" body_model: SmplhModel verts_zero: Float[Array, "verts 3"] """Vertices of shaped body _relative to the root joint_ at the zero configuration.""" joints_zero: Float[Array, "joints 3"] """Joints of shaped body _relative to the root joint_ at the zero configuration.""" t_parent_joint: Float[Array, "joints 3"] """Position of each shaped body joint relative to its parent. Does not include root.""" def with_pose_decomposed( self, T_world_root: Float[Array | onp.ndarray, "7"], body_quats: Float[Array | onp.ndarray, "21 4"], left_hand_quats: Float[Array | onp.ndarray, "15 4"] | None = None, right_hand_quats: Float[Array | onp.ndarray, "15 4"] | None = None, ) -> SmplhShapedAndPosed: """Pose our SMPL-H body model. Returns a set of joint and vertex outputs.""" if left_hand_quats is None: left_hand_quats = jnp.zeros((15, 4)).at[:, 0].set(1.0) if right_hand_quats is None: right_hand_quats = jnp.zeros((15, 4)).at[:, 0].set(1.0) local_quats = broadcasting_cat( cast(list[jax.Array], [body_quats, left_hand_quats, right_hand_quats]), axis=0, ) assert local_quats.shape[-2:] == (51, 4) return self.with_pose(T_world_root, local_quats) def with_pose( self, T_world_root: Float[Array | onp.ndarray, "... 7"], local_quats: Float[Array | onp.ndarray, "... num_joints 4"], ) -> SmplhShapedAndPosed: """Pose our SMPL-H body model. Returns a set of joint and vertex outputs.""" # Forward kinematics. # assert local_quats.shape == (51, 4), local_quats.shape parent_indices = self.body_model.parent_indices (num_joints,) = parent_indices.shape[-1:] num_active_joints, _ = local_quats.shape[-2:] assert local_quats.shape[-1] == 4 assert num_active_joints <= num_joints assert self.t_parent_joint.shape[-2:] == (num_joints, 3) # Get relative transforms. Ts_parent_child = broadcasting_cat( [local_quats, self.t_parent_joint[..., :num_active_joints, :]], axis=-1 ) assert Ts_parent_child.shape[-2:] == (num_active_joints, 7) # Compute one joint at a time. def compute_joint(i: int, Ts_world_joint: Array) -> Array: T_world_parent = jnp.where( parent_indices[i] == -1, T_world_root, Ts_world_joint[..., parent_indices[i], :], ) return Ts_world_joint.at[..., i, :].set( ( jaxlie.SE3(T_world_parent) @ jaxlie.SE3(Ts_parent_child[..., i, :]) ).wxyz_xyz ) Ts_world_joint = jax.lax.fori_loop( lower=0, upper=num_joints, body_fun=compute_joint, init_val=jnp.zeros_like(Ts_parent_child), ) assert Ts_world_joint.shape[-2:] == (num_active_joints, 7) return SmplhShapedAndPosed( shaped_model=self, T_world_root=T_world_root, # type: ignore local_quats=local_quats, # type: ignore Ts_world_joint=Ts_world_joint, ) def get_T_head_cpf(self) -> Float[Array, "7"]: """Get the central pupil frame with respect to the head (joint 14). This assumes that we're using the SMPL-H model.""" assert self.verts_zero.shape[-2:] == (6890, 3), "Not using SMPL-H model!" right_eye = ( self.verts_zero[..., 6260, :] + self.verts_zero[..., 6262, :] ) / 2.0 left_eye = (self.verts_zero[..., 2800, :] + self.verts_zero[..., 2802, :]) / 2.0 # CPF is between the two eyes. cpf_pos_wrt_head = (right_eye + left_eye) / 2.0 - self.joints_zero[..., 14, :] return broadcasting_cat([jaxlie.SO3.identity().wxyz, cpf_pos_wrt_head], axis=-1) @jdc.pytree_dataclass class SmplhShapedAndPosed: shaped_model: SmplhShaped """Underlying shaped body model.""" T_world_root: Float[Array, "*#batch 7"] """Root coordinate frame.""" local_quats: Float[Array, "*#batch joints 4"] """Local joint orientations.""" Ts_world_joint: Float[Array, "joints 7"] """Absolute transform for each joint. Does not include the root.""" def with_new_T_world_root( self, T_world_root: Float[Array, "*#batch 7"] ) -> SmplhShapedAndPosed: return SmplhShapedAndPosed( shaped_model=self.shaped_model, T_world_root=T_world_root, local_quats=self.local_quats, Ts_world_joint=( jaxlie.SE3(T_world_root[..., None, :]) @ jaxlie.SE3(self.T_world_root[..., None, :]).inverse() @ jaxlie.SE3(self.Ts_world_joint) ).parameters(), ) def lbs(self) -> SmplhMesh: assert ( self.local_quats.shape[0] == self.shaped_model.body_model.parent_indices.shape[0] ), ( "It looks like only a partial set of joint rotations was passed into `with_pose()`. We need all of them for LBS." ) # Linear blend skinning with a pose blend shape. verts_with_blend = self.shaped_model.verts_zero + einsum( self.shaped_model.body_model.posedirs, (jaxlie.SO3(self.local_quats).as_matrix() - jnp.eye(3)).flatten(), "verts j joints_times_9, ... joints_times_9 -> ... verts j", ) verts_transformed = einsum( broadcasting_cat( [ # (*, 1, 3, 4) jaxlie.SE3(self.T_world_root).as_matrix()[..., None, :3, :], # (*, 51, 3, 4) jaxlie.SE3(self.Ts_world_joint).as_matrix()[..., :3, :], ], axis=0, ), self.shaped_model.body_model.weights, jnp.pad( verts_with_blend[:, None, :] - jnp.concatenate( [ jnp.zeros((1, 1, 3)), # Root joint. self.shaped_model.joints_zero[None, :, :], ], axis=1, ), ((0, 0), (0, 0), (0, 1)), constant_values=1.0, ), "joints_p1 i j, ... verts joints_p1, ... verts joints_p1 j -> ... verts i", ) return SmplhMesh( posed_model=self, verts=verts_transformed, faces=self.shaped_model.body_model.faces, ) @jdc.pytree_dataclass class SmplhMesh: posed_model: SmplhShapedAndPosed verts: Float[Array, "verts 3"] """Vertices for mesh.""" faces: Int[Array, "13776 3"] """Faces for mesh.""" def broadcasting_cat(arrays: Sequence[jax.Array | onp.ndarray], axis: int) -> jax.Array: """Like jnp.concatenate, but broadcasts leading axes.""" assert len(arrays) > 0 output_dims = max(map(lambda t: len(t.shape), arrays)) arrays = [t.reshape((1,) * (output_dims - len(t.shape)) + t.shape) for t in arrays] max_sizes = [max(t.shape[i] for t in arrays) for i in range(output_dims)] expanded_arrays = [ jnp.broadcast_to( array, tuple( array.shape[i] if i == axis % len(array.shape) else max_size for i, max_size in enumerate(max_sizes) ), ) for array in arrays ] return jnp.concatenate(expanded_arrays, axis=axis) def _normalize_dtype(v: onp.ndarray) -> onp.ndarray: """Normalize datatypes; all arrays should be either int32 or float32.""" if "int" in str(v.dtype): return v.astype(onp.int32) elif "float" in str(v.dtype): return v.astype(onp.float32) else: return v ================================================ FILE: src/egoallo/guidance_optimizer_jax.py ================================================ """Optimize constraints using Levenberg-Marquardt.""" from __future__ import annotations import os from .hand_detection_structs import ( CorrespondedAriaHandWristPoseDetections, CorrespondedHamerDetections, ) # Need to play nice with PyTorch! os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" import dataclasses import time from functools import partial from typing import Callable, Literal, Unpack, assert_never, cast import jax import jax_dataclasses as jdc import jaxlie import jaxls import numpy as onp import torch from jax import numpy as jnp from jaxtyping import Float, Int from torch import Tensor from . import fncsmpl, fncsmpl_jax, network from .transforms._so3 import SO3 def do_guidance_optimization( Ts_world_cpf: Float[Tensor, "time 7"], traj: network.EgoDenoiseTraj, body_model: fncsmpl.SmplhModel, guidance_mode: GuidanceMode, phase: Literal["inner", "post"], hamer_detections: None | CorrespondedHamerDetections, aria_detections: None | CorrespondedAriaHandWristPoseDetections, verbose: bool, ) -> tuple[network.EgoDenoiseTraj, dict]: """Run an optimizer to apply foot contact constraints.""" assert traj.hand_rotmats is not None guidance_params = JaxGuidanceParams.defaults(guidance_mode, phase) start_time = time.time() quats, debug_info = _optimize_vmapped( body=fncsmpl_jax.SmplhModel( faces=cast(jax.Array, body_model.faces.numpy(force=True)), J_regressor=cast(jax.Array, body_model.J_regressor.numpy(force=True)), parent_indices=cast(jax.Array, onp.array(body_model.parent_indices)), weights=cast(jax.Array, body_model.weights.numpy(force=True)), posedirs=cast(jax.Array, body_model.posedirs.numpy(force=True)), v_template=cast(jax.Array, body_model.v_template.numpy(force=True)), shapedirs=cast(jax.Array, body_model.shapedirs.numpy(force=True)), ), Ts_world_cpf=cast(jax.Array, Ts_world_cpf.numpy(force=True)), betas=cast(jax.Array, traj.betas.numpy(force=True)), body_rotmats=cast(jax.Array, traj.body_rotmats.numpy(force=True)), hand_rotmats=cast(jax.Array, traj.hand_rotmats.numpy(force=True)), contacts=cast(jax.Array, traj.contacts.numpy(force=True)), guidance_params=guidance_params, # The hand detections are a torch tensors in a TensorDataclass form. We # use dictionaries to convert to pytrees. hamer_detections=None if hamer_detections is None else hamer_detections.as_nested_dict(numpy=True), aria_detections=None if aria_detections is None else aria_detections.as_nested_dict(numpy=True), verbose=verbose, ) rotmats = SO3( torch.from_numpy(onp.array(quats)) .to(traj.body_rotmats.dtype) .to(traj.body_rotmats.device) ).as_matrix() print(f"Constraint optimization finished in {time.time() - start_time}sec") return dataclasses.replace( traj, body_rotmats=rotmats[:, :, :21, :], hand_rotmats=rotmats[:, :, 21:, :], ), debug_info class _SmplhBodyPosesVar( jaxls.Var[jax.Array], default_factory=lambda: jnp.concatenate( [jnp.ones((21, 1)), jnp.zeros((21, 3))], axis=-1 ), retract_fn=lambda val, delta: ( jaxlie.SO3(val) @ jaxlie.SO3.exp(delta.reshape(21, 3)) ).wxyz, tangent_dim=21 * 3, ): """Variable containing local joint poses for a SMPL-H human.""" class _SmplhSingleHandPosesVar( jaxls.Var[jax.Array], default_factory=lambda: jnp.concatenate( [jnp.ones((15, 1)), jnp.zeros((15, 3))], axis=-1 ), retract_fn=lambda val, delta: ( jaxlie.SO3(val) @ jaxlie.SO3.exp(delta.reshape(15, 3)) ).wxyz, tangent_dim=15 * 3, ): """Variable containing local joint poses for one hand of a SMPL-H human.""" @jdc.jit def _optimize_vmapped( Ts_world_cpf: jax.Array, body: fncsmpl_jax.SmplhModel, betas: jax.Array, body_rotmats: jax.Array, hand_rotmats: jax.Array, contacts: jax.Array, guidance_params: JaxGuidanceParams, hamer_detections: dict | None, aria_detections: dict | None, verbose: jdc.Static[bool], ) -> tuple[jax.Array, dict]: return jax.vmap( partial( _optimize, Ts_world_cpf=Ts_world_cpf, body=body, guidance_params=guidance_params, hamer_detections=hamer_detections, aria_detections=aria_detections, verbose=verbose, ) )( betas=betas, body_rotmats=body_rotmats, hand_rotmats=hand_rotmats, contacts=contacts, ) # Modes for guidance. GuidanceMode = Literal[ # Foot skating only. "no_hands", # Only use Aria wrist pose. "aria_wrist_only", # Use Aria wrist pose + HaMeR 3D estimates. "aria_hamer", # Use only HaMeR 3D estimates. "hamer_wrist", # Use HaMeR 3D estimates + reprojection. "hamer_reproj2", ] @jdc.pytree_dataclass class JaxGuidanceParams: prior_quat_weight: float = 1.0 prior_pos_weight: float = 5.0 body_quat_vel_smoothness_weight: float = 5.0 body_quat_smoothness_weight: float = 1.0 body_quat_delta_smoothness_weight: float = 10.0 skate_weight: float = 30.0 # Note: this should be quite high. If the hand quaternions aren't # constrained enough the reprojecction loss can get wild. hand_quats: jdc.Static[bool] = True hand_quat_weight = 5.0 hand_quat_priors: jdc.Static[bool] = True hand_quat_prior_weight = 0.1 hand_quat_smoothness_weight = 1.0 hamer_reproj: jdc.Static[bool] = True hand_reproj_weight: float = 1.0 hamer_wrist_pose: jdc.Static[bool] = True hamer_abspos_weight: float = 20.0 hamer_ori_weight: float = 5.0 aria_wrists: jdc.Static[bool] = True aria_wrist_pos_weight: float = 50.0 aria_wrist_ori_weight: float = 10.0 # Optimization parameters. lambda_initial: float = 0.1 max_iters: jdc.Static[int] = 20 @staticmethod def defaults( mode: GuidanceMode, phase: Literal["inner", "post"], ) -> JaxGuidanceParams: if mode == "no_hands": return { "inner": JaxGuidanceParams( hand_quats=False, hand_quat_priors=False, hamer_reproj=False, hamer_wrist_pose=False, aria_wrists=False, max_iters=5, ), "post": JaxGuidanceParams( hand_quats=False, hand_quat_priors=False, hamer_reproj=False, hamer_wrist_pose=False, aria_wrists=False, max_iters=20, ), }[phase] elif mode == "aria_wrist_only": return { "inner": JaxGuidanceParams( hand_quats=False, hand_quat_priors=True, hamer_reproj=False, hamer_wrist_pose=False, aria_wrists=True, max_iters=5, ), "post": JaxGuidanceParams( hand_quats=False, hand_quat_priors=True, hamer_reproj=False, hamer_wrist_pose=False, aria_wrists=True, max_iters=20, ), }[phase] elif mode == "aria_hamer": return { "inner": JaxGuidanceParams( hand_quats=True, hand_quat_priors=True, hamer_reproj=False, hamer_wrist_pose=False, aria_wrists=True, max_iters=5, ), "post": JaxGuidanceParams( hand_quats=True, hand_quat_priors=True, hamer_reproj=False, hamer_wrist_pose=False, aria_wrists=True, max_iters=20, ), }[phase] elif mode == "hamer_wrist": return { "inner": JaxGuidanceParams( hand_quats=True, hand_quat_priors=True, # NOTE: we turn off reprojection during the inner loop optimization. hamer_reproj=False, hamer_wrist_pose=True, aria_wrists=False, max_iters=5, ), "post": JaxGuidanceParams( hand_quats=True, hand_quat_priors=True, # Turn on reprojection. hamer_reproj=False, hamer_wrist_pose=True, aria_wrists=False, max_iters=20, ), }[phase] elif mode == "hamer_reproj2": return { "inner": JaxGuidanceParams( hand_quats=True, hand_quat_priors=True, # NOTE: we turn off reprojection during the inner loop optimization. hamer_reproj=False, hamer_wrist_pose=True, aria_wrists=False, max_iters=5, ), "post": JaxGuidanceParams( hand_quats=True, hand_quat_priors=True, # Turn on reprojection. hamer_reproj=True, hamer_wrist_pose=True, aria_wrists=False, max_iters=20, ), }[phase] else: assert_never(mode) def _optimize( Ts_world_cpf: jax.Array, body: fncsmpl_jax.SmplhModel, betas: jax.Array, body_rotmats: jax.Array, hand_rotmats: jax.Array, contacts: jax.Array, guidance_params: JaxGuidanceParams, hamer_detections: dict | None, aria_detections: dict | None, verbose: bool, ) -> tuple[jax.Array, dict]: """Apply constraints using Levenberg-Marquardt optimizer. Returns updated body_rotmats and hand_rotmats matrices.""" timesteps = body_rotmats.shape[0] assert Ts_world_cpf.shape == (timesteps, 7) assert body_rotmats.shape == (timesteps, 21, 3, 3) assert hand_rotmats.shape == (timesteps, 30, 3, 3) assert contacts.shape == (timesteps, 21) assert betas.shape == (timesteps, 16) init_quats = jaxlie.SO3.from_matrix( # body_rotmats jnp.concatenate([body_rotmats, hand_rotmats], axis=1) ).wxyz assert init_quats.shape == (timesteps, 51, 4) # Assume body shape is time-invariant. shaped_body = body.with_shape(jnp.mean(betas, axis=0)) T_head_cpf = shaped_body.get_T_head_cpf() T_cpf_head = jaxlie.SE3(T_head_cpf).inverse().parameters() assert T_cpf_head.shape == (7,) init_posed = shaped_body.with_pose( jaxlie.SE3.identity(batch_axes=(timesteps,)).wxyz_xyz, init_quats ) T_world_head = jaxlie.SE3(Ts_world_cpf) @ jaxlie.SE3(T_cpf_head) T_root_head = jaxlie.SE3(init_posed.Ts_world_joint[:, 14]) init_posed = init_posed.with_new_T_world_root( (T_world_head @ T_root_head.inverse()).wxyz_xyz ) del T_world_head del T_root_head foot_joint_indices = jnp.array([6, 7, 9, 10]) num_foot_joints = foot_joint_indices.shape[0] contacts = contacts[..., foot_joint_indices] pairwise_contacts = (contacts[:-1, :] + contacts[1:, :]) / 2.0 assert pairwise_contacts.shape == (timesteps - 1, num_foot_joints) del contacts # We'll populate a list of factors (cost terms). factors = list[jaxls.Cost]() def cost_with_args[*CostArgs]( *args: Unpack[tuple[*CostArgs]], ) -> Callable[ [Callable[[jaxls.VarValues, *CostArgs], jax.Array]], Callable[[jaxls.VarValues, *CostArgs], jax.Array], ]: """Decorator for appending to the factor list.""" def inner( cost_func: Callable[[jaxls.VarValues, *CostArgs], jax.Array], ) -> Callable[[jaxls.VarValues, *CostArgs], jax.Array]: factors.append(jaxls.Cost(cost_func, args)) return cost_func return inner def do_forward_kinematics( vals: jaxls.VarValues, var: _SmplhBodyPosesVar, left_hand: _SmplhSingleHandPosesVar | None = None, right_hand: _SmplhSingleHandPosesVar | None = None, output_frame: Literal["world", "root"] = "world", ) -> fncsmpl_jax.SmplhShapedAndPosed: """Helper for computing forward kinematics from variables.""" assert (left_hand is None) == (right_hand is None) if left_hand is None and right_hand is None: posed = shaped_body.with_pose( T_world_root=jaxlie.SE3.identity().wxyz_xyz, local_quats=vals[var], ) elif left_hand is not None and right_hand is None: posed = shaped_body.with_pose( T_world_root=jaxlie.SE3.identity().wxyz_xyz, local_quats=jnp.concatenate([vals[var], vals[left_hand]], axis=-2), ) elif left_hand is not None and right_hand is not None: posed = shaped_body.with_pose( T_world_root=jaxlie.SE3.identity().wxyz_xyz, local_quats=jnp.concatenate( [vals[var], vals[left_hand], vals[right_hand]], axis=-2 ), ) else: assert False if output_frame == "world": T_world_root = ( # T_world_cpf jaxlie.SE3(Ts_world_cpf[var.id, :]) # T_cpf_head @ jaxlie.SE3(T_cpf_head) # T_head_root @ jaxlie.SE3(posed.Ts_world_joint[14]).inverse() ) return posed.with_new_T_world_root(T_world_root.wxyz_xyz) elif output_frame == "root": return posed # HaMeR pose cost. if hamer_detections is not None and guidance_params.hand_quat_priors: hamer_left = hamer_detections["detections_left_concat"] hamer_right = hamer_detections["detections_right_concat"] # HaMeR local quaternion smoothness. @( cost_with_args( _SmplhSingleHandPosesVar(jnp.arange(timesteps * 2 - 2)), _SmplhSingleHandPosesVar(jnp.arange(2, timesteps * 2)), ) ) def hand_smoothness( vals: jaxls.VarValues, hand_pose: _SmplhSingleHandPosesVar, hand_pose_next: _SmplhSingleHandPosesVar, ) -> jax.Array: return ( guidance_params.hand_quat_smoothness_weight * ( jaxlie.SO3(vals[hand_pose]).inverse() @ jaxlie.SO3(vals[hand_pose_next]) ) .log() .flatten() ) # Hand prior loss. @cost_with_args( _SmplhSingleHandPosesVar(jnp.arange(timesteps * 2)), init_quats[:, 21:51, :].reshape((timesteps * 2, 15, 4)), ) def hand_prior( vals: jaxls.VarValues, hand_pose: _SmplhSingleHandPosesVar, init_hand_quats: jax.Array, ) -> jax.Array: return ( guidance_params.hand_quat_prior_weight * (jaxlie.SO3(vals[hand_pose]).inverse() @ jaxlie.SO3(init_hand_quats)) .log() .flatten() ) if hamer_detections is not None and guidance_params.hand_quats: hamer_left = hamer_detections["detections_left_concat"] hamer_right = hamer_detections["detections_right_concat"] # HaMeR local pose matching. @( cost_with_args( _SmplhSingleHandPosesVar(hamer_left["indices"] * 2), hamer_left["single_hand_quats"], ) if hamer_left is not None else lambda x: x ) @( cost_with_args( _SmplhSingleHandPosesVar(hamer_right["indices"] * 2 + 1), hamer_right["single_hand_quats"], ) if hamer_right is not None else lambda x: x ) def hamer_local_pose_cost( vals: jaxls.VarValues, hand_pose: _SmplhSingleHandPosesVar, estimated_hand_quats: jax.Array, ) -> jax.Array: hand_quats = vals[hand_pose] assert hand_quats.shape == estimated_hand_quats.shape return guidance_params.hand_quat_weight * ( (jaxlie.SO3(hand_quats).inverse() @ jaxlie.SO3(estimated_hand_quats)) .log() .flatten() ) if hamer_detections is not None and ( guidance_params.hamer_reproj and guidance_params.hamer_wrist_pose ): hamer_left = hamer_detections["detections_left_concat"] hamer_right = hamer_detections["detections_right_concat"] # HaMeR reprojection. mano_from_openpose_indices = _get_mano_from_openpose_indices(include_tips=False) @( cost_with_args( _SmplhBodyPosesVar(hamer_left["indices"]), _SmplhSingleHandPosesVar(hamer_left["indices"] * 2), _SmplhSingleHandPosesVar(hamer_left["indices"] * 2 + 1), jnp.full_like(hamer_left["indices"], fill_value=0), hamer_left["keypoints_3d"], hamer_left["mano_hand_global_orient"], ) if hamer_left is not None else lambda x: x ) @( cost_with_args( _SmplhBodyPosesVar(hamer_right["indices"]), _SmplhSingleHandPosesVar(hamer_right["indices"] * 2), _SmplhSingleHandPosesVar(hamer_right["indices"] * 2 + 1), jnp.full_like(hamer_right["indices"], fill_value=1), hamer_right["keypoints_3d"], hamer_right["mano_hand_global_orient"], ) if hamer_right is not None else lambda x: x ) def hamer_wrist_and_reproj( vals: jaxls.VarValues, body_pose: _SmplhBodyPosesVar, left_hand_pose: _SmplhSingleHandPosesVar, right_hand_pose: _SmplhSingleHandPosesVar, left0_right1: jax.Array, # Set to 0 for left, 1 for right. keypoints3d_wrt_cam: jax.Array, # These are in OpenPose order!! Rmat_cam_wrist: jax.Array, ) -> jax.Array: posed = do_forward_kinematics( # The right hand comes _after_ the left hand, we can exclude it. vals, body_pose, left_hand_pose, right_hand_pose, output_frame="root", ) Ts_root_joint = posed.Ts_world_joint # Sorry for the naming... del posed # 19 for left wrist, 20 for right wrist. wrist_index = 19 + left0_right1 hand_start_index = 21 + 15 * left0_right1 assert Ts_root_joint.shape == (51, 7) joint_positions_wrt_root = Ts_root_joint[:, 4:7] mano_joints_wrt_root = jnp.concatenate( [ jax.lax.dynamic_slice_in_dim( joint_positions_wrt_root, start_index=wrist_index, slice_size=1, axis=-2, ), jax.lax.dynamic_slice_in_dim( joint_positions_wrt_root, start_index=hand_start_index, slice_size=15, axis=-2, ), ], axis=0, ) assert mano_joints_wrt_root.shape == (16, 3) assert keypoints3d_wrt_cam.shape == (21, 3) # In OpenPose. T_cam_root = ( # T_cam_cpf (7,) jaxlie.SE3(hamer_detections["T_cpf_cam"]).inverse() # T_cpf_head (7,) @ jaxlie.SE3(T_cpf_head) # T_head_root (7,) @ jaxlie.SE3(Ts_root_joint[14, :]).inverse() ) assert T_cam_root.parameters().shape == (7,) mano_joints_wrt_cam = T_cam_root @ mano_joints_wrt_root obs_joints_wrt_cam = keypoints3d_wrt_cam[mano_from_openpose_indices, :] mano_uv_wrt_cam = mano_joints_wrt_cam[:, :2] / mano_joints_wrt_cam[:, 2:3] obs_uv_wrt_cam = obs_joints_wrt_cam[:, :2] / obs_joints_wrt_cam[:, 2:3] T_cam_wrist = jaxlie.SE3.from_rotation_and_translation( T_cam_root.rotation() @ jaxlie.SO3(Ts_root_joint[wrist_index, :4]), mano_joints_wrt_cam[0, :], ) obs_T_cam_wrist = jaxlie.SE3.from_rotation_and_translation( jaxlie.SO3.from_matrix(Rmat_cam_wrist), obs_joints_wrt_cam[0, :], ) return jnp.concatenate( [ (T_cam_wrist.inverse() @ obs_T_cam_wrist).log() * jnp.array( [guidance_params.hamer_abspos_weight] * 3 + [guidance_params.hamer_ori_weight] * 3 ), guidance_params.hand_reproj_weight * (mano_uv_wrt_cam - obs_uv_wrt_cam).flatten(), ] ) elif ( hamer_detections is not None and not guidance_params.hamer_reproj and guidance_params.hamer_wrist_pose ): hamer_left = hamer_detections["detections_left_concat"] hamer_right = hamer_detections["detections_right_concat"] @( cost_with_args( _SmplhBodyPosesVar(hamer_left["indices"]), jnp.full_like(hamer_left["indices"], fill_value=0), hamer_left["keypoints_3d"], hamer_left["mano_hand_global_orient"], ) if hamer_left is not None else lambda x: x ) @( cost_with_args( _SmplhBodyPosesVar(hamer_right["indices"]), jnp.full_like(hamer_right["indices"], fill_value=1), hamer_right["keypoints_3d"], hamer_right["mano_hand_global_orient"], ) if hamer_right is not None else lambda x: x ) def hamer_wrist_only( vals: jaxls.VarValues, body_pose: _SmplhBodyPosesVar, left0_right1: jax.Array, # Set to 0 for left, 1 for right. keypoints3d_wrt_cam: jax.Array, # These are in OpenPose order!! Rmat_cam_wrist: jax.Array, ) -> jax.Array: posed = do_forward_kinematics(vals, body_pose, output_frame="root") Ts_root_joint = posed.Ts_world_joint # Sorry for the naming... del posed # 19 for left wrist, 20 for right wrist. wrist_index = 19 + left0_right1 assert Ts_root_joint.shape == (21, 7) wrist_position_wrt_root = Ts_root_joint[wrist_index, 4:7] T_cam_root = ( # T_cam_cpf (7,) jaxlie.SE3(hamer_detections["T_cpf_cam"]).inverse() # T_cpf_head (7,) @ jaxlie.SE3(T_cpf_head) # T_head_root (7,) @ jaxlie.SE3(Ts_root_joint[14, :]).inverse() ) assert T_cam_root.parameters().shape == (7,) wrist_position_wrt_cam = T_cam_root @ wrist_position_wrt_root # Assumes OpenPose root is same as Mano root!! wrist_pos_wrt_cam = keypoints3d_wrt_cam[0, :] T_cam_wrist = jaxlie.SE3.from_rotation_and_translation( T_cam_root.rotation() @ jaxlie.SO3(Ts_root_joint[wrist_index, :4]), wrist_position_wrt_cam, ) obs_T_cam_wrist = jaxlie.SE3.from_rotation_and_translation( jaxlie.SO3.from_matrix(Rmat_cam_wrist), wrist_pos_wrt_cam, ) return (T_cam_wrist.inverse() @ obs_T_cam_wrist).log() * jnp.array( [guidance_params.hamer_abspos_weight] * 3 + [guidance_params.hamer_ori_weight] * 3 ) # Wrist pose cost. if aria_detections is not None and guidance_params.aria_wrists: aria_left = aria_detections["detections_left_concat"] aria_right = aria_detections["detections_right_concat"] @( cost_with_args( _SmplhBodyPosesVar(aria_left["indices"]), aria_left["confidence"], aria_left["wrist_position"], aria_left["palm_position"], aria_left["palm_normal"], jnp.full_like(aria_left["indices"], fill_value=0), ) if aria_left is not None else lambda x: x ) @( cost_with_args( _SmplhBodyPosesVar(aria_right["indices"]), aria_right["confidence"], aria_right["wrist_position"], aria_right["palm_position"], aria_right["palm_normal"], jnp.full_like(aria_right["indices"], fill_value=1), ) if aria_right is not None else lambda x: x ) def wrist_pose_cost( vals: jaxls.VarValues, pose: _SmplhBodyPosesVar, confidence: jax.Array, wrist_position: jax.Array, palm_position: jax.Array, palm_normal: jax.Array, left0_right1: jax.Array, # Set to 0 for left, 1 for right. ) -> jax.Array: assert wrist_position.shape == (3,) assert left0_right1.shape == () posed = do_forward_kinematics(vals, pose) T_world_wrist = posed.Ts_world_joint[19 + left0_right1] pos_cost = ( # Left wrist is joint 19, right is joint 20. T_world_wrist[4:7] - wrist_position ) # Estimate wrist orientation from forward + normal directions. palm_forward = palm_position - wrist_position palm_forward = palm_forward / jnp.linalg.norm(palm_forward) palm_normal = palm_normal / jnp.linalg.norm(palm_normal) palm_forward = ( # Flip palm forward if right hand. palm_forward * jnp.array([1, -1])[left0_right1] ) palm_forward = ( # Gram-schmidt for forward direction. palm_forward - jnp.dot(palm_forward, palm_normal) * palm_normal ) estimatedR_world_wrist = jaxlie.SO3.from_matrix( jnp.stack( [ palm_forward, -palm_normal, jnp.cross(palm_normal, palm_forward), ], axis=1, ) ) R_world_wrist = jaxlie.SO3(T_world_wrist[:4]) ori_cost = (estimatedR_world_wrist.inverse() @ R_world_wrist).log() return confidence * jnp.concatenate( [ guidance_params.aria_wrist_pos_weight * pos_cost, guidance_params.aria_wrist_ori_weight * ori_cost, ] ) # Per-frame regularization cost. @cost_with_args( _SmplhBodyPosesVar(jnp.arange(timesteps)), ) def reg_cost( vals: jaxls.VarValues, pose: _SmplhBodyPosesVar, ) -> jax.Array: posed = do_forward_kinematics(vals, pose) torso_indices = jnp.array([0, 1, 2, 5, 8]) return jnp.concatenate( [ guidance_params.prior_quat_weight * ( jaxlie.SO3(vals[pose]).inverse() @ jaxlie.SO3(init_quats[pose.id, :21, :]) ) .log() .flatten(), # Only include some torso joints. guidance_params.prior_pos_weight * ( posed.Ts_world_joint[torso_indices, 4:7] - init_posed.Ts_world_joint[pose.id, torso_indices, 4:7] ).flatten(), ] ) @cost_with_args( _SmplhBodyPosesVar(jnp.arange(timesteps - 1)), _SmplhBodyPosesVar(jnp.arange(1, timesteps)), ) def delta_smoothness_cost( vals: jaxls.VarValues, current: _SmplhBodyPosesVar, next: _SmplhBodyPosesVar, ) -> jax.Array: curdelt = jaxlie.SO3(vals[current]).inverse() @ jaxlie.SO3( init_quats[current.id, :21, :] ) nexdelt = jaxlie.SO3(vals[next]).inverse() @ jaxlie.SO3( init_quats[next.id, :21, :] ) return jnp.concatenate( [ guidance_params.body_quat_delta_smoothness_weight * (curdelt.inverse() @ nexdelt).log().flatten(), guidance_params.body_quat_smoothness_weight * (jaxlie.SO3(vals[current]).inverse() @ jaxlie.SO3(vals[next])) .log() .flatten(), ] ) @cost_with_args( _SmplhBodyPosesVar(jnp.arange(timesteps - 2)), _SmplhBodyPosesVar(jnp.arange(1, timesteps - 1)), _SmplhBodyPosesVar(jnp.arange(2, timesteps)), ) def vel_smoothness_cost( vals: jaxls.VarValues, t0: _SmplhBodyPosesVar, t1: _SmplhBodyPosesVar, t2: _SmplhBodyPosesVar, ) -> jax.Array: curdelt = jaxlie.SO3(vals[t0]).inverse() @ jaxlie.SO3(vals[t1]) nexdelt = jaxlie.SO3(vals[t1]).inverse() @ jaxlie.SO3(vals[t2]) return ( guidance_params.body_quat_vel_smoothness_weight * (curdelt.inverse() @ nexdelt).log().flatten() ) @cost_with_args( _SmplhBodyPosesVar(jnp.arange(timesteps - 1)), _SmplhBodyPosesVar(jnp.arange(1, timesteps)), pairwise_contacts, ) def skating_cost( vals: jaxls.VarValues, current: _SmplhBodyPosesVar, next: _SmplhBodyPosesVar, foot_contacts: jax.Array, ) -> jax.Array: # Do forward kinematics. posed_current = do_forward_kinematics(vals, current) posed_next = do_forward_kinematics(vals, next) footpos_current = posed_current.Ts_world_joint[foot_joint_indices, 4:7] footpos_next = posed_next.Ts_world_joint[foot_joint_indices, 4:7] assert footpos_current.shape == footpos_next.shape == (num_foot_joints, 3) assert foot_contacts.shape == (num_foot_joints,) return ( guidance_params.skate_weight * (foot_contacts[:, None] * (footpos_current - footpos_next)).flatten() ) vars_body_pose = _SmplhBodyPosesVar(jnp.arange(timesteps)) vars_hand_pose = _SmplhSingleHandPosesVar(jnp.arange(timesteps * 2)) graph = jaxls.LeastSquaresProblem( costs=factors, variables=[vars_body_pose, vars_hand_pose] ).analyze() solutions = graph.solve( initial_vals=jaxls.VarValues.make( [ vars_body_pose.with_value(init_quats[:, :21, :]), vars_hand_pose.with_value( init_quats[:, 21:51, :].reshape((timesteps * 2, 15, 4)) ), ] ), linear_solver="conjugate_gradient", trust_region=jaxls.TrustRegionConfig( lambda_initial=guidance_params.lambda_initial ), termination=jaxls.TerminationConfig(max_iterations=guidance_params.max_iters), verbose=verbose, ) out_body_quats = solutions[_SmplhBodyPosesVar] assert out_body_quats.shape == (timesteps, 21, 4) out_hand_quats = solutions[_SmplhSingleHandPosesVar].reshape((timesteps, 30, 4)) assert out_hand_quats.shape == (timesteps, 30, 4) return ( jnp.concatenate([out_body_quats, out_hand_quats], axis=-2), {}, # Metadata dict that we use for debugging. ) def _get_mano_from_openpose_indices(include_tips: bool) -> Int[onp.ndarray, "21"]: # https://github.com/geopavlakos/hamer/blob/272d68f176e0ea8a506f761663dd3dca4a03ced0/hamer/models/mano_wrapper.py#L20 # fmt: off mano_to_openpose = [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20] # fmt: on openpose_from_mano_idx = { mano_idx: openpose_idx for openpose_idx, mano_idx in enumerate(mano_to_openpose) } return onp.array( [openpose_from_mano_idx[i] for i in range(21 if include_tips else 16)] ) ================================================ FILE: src/egoallo/hand_detection_structs.py ================================================ """Data structure definition that we use for hand detections. We'll run HaMeR, produce the dictionary defined by `SavedHamerOutputs`, then pickle this dictionary. """ from __future__ import annotations import pickle from pathlib import Path from typing import Protocol, TypedDict, cast import numpy as np import torch from jaxtyping import Float, Int from projectaria_tools.core import mps from projectaria_tools.core.mps.utils import get_nearest_wrist_and_palm_pose from torch import Tensor from .tensor_dataclass import TensorDataclass from .transforms import SE3, SO3 class SingleHandHamerOutputWrtCamera(TypedDict): """Hand outputs with respect to the camera frame. For use in pickle files.""" verts: np.ndarray keypoints_3d: np.ndarray mano_hand_pose: np.ndarray mano_hand_betas: np.ndarray mano_hand_global_orient: np.ndarray class SavedHamerOutputs(TypedDict): """Outputs from the HAMeR hand detection algorithm. This is the structure to pickle. `detections_left_wrt_cam` and `detections_right_wrt_cam` use nanosecond timestamps as keys. """ mano_faces_right: np.ndarray mano_faces_left: np.ndarray detections_left_wrt_cam: dict[int, SingleHandHamerOutputWrtCamera | None] detections_right_wrt_cam: dict[int, SingleHandHamerOutputWrtCamera | None] T_device_cam: np.ndarray # wxyz_xyz T_cpf_cam: np.ndarray # wxyz_xyz class AriaHandWristPoseWrtWorld(TensorDataclass): confidence: Float[Tensor, "n_detections"] wrist_position: Float[Tensor, "n_detections 3"] wrist_normal: Float[Tensor, "n_detections 3"] palm_position: Float[Tensor, "n_detections 3"] palm_normal: Float[Tensor, "n_detections 3"] indices: Int[Tensor, "n_detections"] class CorrespondedAriaHandWristPoseDetections(TensorDataclass): detections_left_concat: AriaHandWristPoseWrtWorld | None detections_right_concat: AriaHandWristPoseWrtWorld | None @staticmethod def load( wrist_and_palm_poses_csv_path: Path, target_timestamps_sec: tuple[float, ...], Ts_world_device: Float[np.ndarray, "timesteps 7"], ) -> CorrespondedAriaHandWristPoseDetections: # API from runtime inspection of `projectaria_tools` outputs. class WristAndPalmNormals(Protocol): wrist_normal_device: np.ndarray palm_normal_device: np.ndarray class OneSide(Protocol): confidence: float wrist_position_device: np.ndarray palm_position_device: np.ndarray wrist_and_palm_normal_device: WristAndPalmNormals wp_poses = mps.hand_tracking.read_wrist_and_palm_poses( str(wrist_and_palm_poses_csv_path) ) detections_left = list[OneSide]() detections_right = list[OneSide]() indices_left = list[int]() indices_right = list[int]() for i, time_sec in enumerate(target_timestamps_sec): wp_pose = get_nearest_wrist_and_palm_pose(wp_poses, int(time_sec * 1e9)) if ( wp_pose is None or abs(wp_pose.tracking_timestamp.total_seconds() - time_sec) >= 1.0 / 30.0 ): continue if wp_pose.left_hand is not None and wp_pose.left_hand.confidence > 0.7: indices_left.append(i) detections_left.append(wp_pose.left_hand) if wp_pose.right_hand is not None and wp_pose.right_hand.confidence > 0.7: indices_right.append(i) detections_right.append(wp_pose.right_hand) def form_detections_concat( detections: list[OneSide], indices: list[int] ) -> AriaHandWristPoseWrtWorld | None: assert len(detections) == len(indices) if len(indices) == 0: return None Tslice_world_device = SE3( torch.from_numpy(Ts_world_device[np.array(indices), :]).to( dtype=torch.float32 ) ) Rslice_world_device = SO3( torch.from_numpy(Ts_world_device[np.array(indices), :4]).to( dtype=torch.float32 ) ) return AriaHandWristPoseWrtWorld( confidence=torch.from_numpy( np.array([d.confidence for d in detections]) ), wrist_position=Tslice_world_device @ torch.from_numpy( np.array( [d.wrist_position_device for d in detections], dtype=np.float32 ) ), wrist_normal=Rslice_world_device @ torch.from_numpy( np.array( [ d.wrist_and_palm_normal_device.wrist_normal_device for d in detections ], dtype=np.float32, ) ), palm_position=Tslice_world_device @ torch.from_numpy( np.array( [d.palm_position_device for d in detections], dtype=np.float32 ) ), palm_normal=Rslice_world_device @ torch.from_numpy( np.array( [ d.wrist_and_palm_normal_device.palm_normal_device for d in detections ], dtype=np.float32, ) ), indices=torch.from_numpy(np.array(indices, dtype=np.int64)), ) return CorrespondedAriaHandWristPoseDetections( detections_left_concat=form_detections_concat( detections_left, indices_left ), detections_right_concat=form_detections_concat( detections_right, indices_right ), ) class SingleHandHamerOutputWrtCameraConcatenated(TensorDataclass): verts: Float[Tensor, "n_detections n_verts 3"] keypoints_3d: Float[Tensor, "n_detections n_keypoints 3"] mano_hand_global_orient: Float[Tensor, "n_detections 3 3"] single_hand_quats: Float[Tensor, "n_detections 15 3"] indices: Int[Tensor, "n_detections"] class CorrespondedHamerDetections(TensorDataclass): mano_faces_right: Tensor mano_faces_left: Tensor detections_left_tuple: tuple[None | SingleHandHamerOutputWrtCamera, ...] detections_right_tuple: tuple[None | SingleHandHamerOutputWrtCamera, ...] T_cpf_cam: Tensor focal_length: float # Concatenated detections will be None if there are no detections at all. detections_left_concat: None | SingleHandHamerOutputWrtCameraConcatenated detections_right_concat: None | SingleHandHamerOutputWrtCameraConcatenated def get_length(self) -> int: assert len(self.detections_left_tuple) == len(self.detections_right_tuple) return len(self.detections_left_tuple) def slice(self, start_index: int, end_index: int) -> CorrespondedHamerDetections: """Slice the hand detections. Removes unused hand detections, and shifts indices as necessary.""" assert start_index < end_index def _get_detections_in_window( detections_side_concat: None | SingleHandHamerOutputWrtCameraConcatenated, ) -> None | SingleHandHamerOutputWrtCameraConcatenated: if detections_side_concat is None: return None else: indices = detections_side_concat.indices indices_mask = (indices >= start_index) & (indices < end_index) out = detections_side_concat.map(lambda x: x[indices_mask].clone()) out.indices -= start_index return out return CorrespondedHamerDetections( self.mano_faces_right, self.mano_faces_left, self.detections_left_tuple[start_index:end_index], self.detections_right_tuple[start_index:end_index], T_cpf_cam=self.T_cpf_cam, focal_length=self.focal_length, detections_left_concat=_get_detections_in_window( self.detections_left_concat ), detections_right_concat=_get_detections_in_window( self.detections_right_concat ), ) @staticmethod def load( hand_pkl_path: Path, target_timestamps_sec: tuple[float, ...], ) -> CorrespondedHamerDetections: """Helper which takes as input: (1) A path to a pickle file containing hand detections through time. See feb25_hamer_outputs_from_vrs.py for how this is generated. (2) A set of target timestamps, sorted, in seconds. We then output a data structure that has hand detections (or `None`) for each target timestamp. """ with open(hand_pkl_path, "rb") as f: hamer_out = cast(SavedHamerOutputs, pickle.load(f)) def match_detections_to_targets( detections_wrt_cam: dict[int, None | SingleHandHamerOutputWrtCamera], ) -> list[None | SingleHandHamerOutputWrtCamera]: # Approximate the frame rate of the detections. est_fps = len(detections_wrt_cam) / ( (max(detections_wrt_cam.keys()) - min(detections_wrt_cam.keys())) / 1e9 ) # Usually framerate is either 10 FPS or 30 FPS. We might want to # run on 1 FPS video in the future, we can tweak this assert if we # run into that... assert 5 < est_fps < 40 # Get nanosecond timestamps within our target timestamp window. # Note that input dictionary keys are nanosecond timestamps! detect_ns = sorted( [ time_ns for time_ns in detections_wrt_cam.keys() if time_ns / 1e9 >= target_timestamps_sec[0] - 1 / est_fps and time_ns / 1e9 <= target_timestamps_sec[-1] + 1 / est_fps ] ) delta_matrix = np.abs( np.array(target_timestamps_sec)[:, None] - np.array(detect_ns)[None, :] / 1e9 ) # For each target, which is the closest detection? best_det_from_target = np.argmin(delta_matrix, axis=-1) # For each detection, which is the closest target? best_target_from_det = np.argmin(delta_matrix, axis=0) # Get detection list; we do a cycle-consistency check to make sure # we get a 1-to-1 mapping. out: list[None | SingleHandHamerOutputWrtCamera] = [] for i in range(len(target_timestamps_sec)): if best_target_from_det[best_det_from_target[i]] == i: out.append(detections_wrt_cam[detect_ns[best_det_from_target[i]]]) else: out.append(None) return out detections_left = match_detections_to_targets( hamer_out["detections_left_wrt_cam"] ) detections_right = match_detections_to_targets( hamer_out["detections_right_wrt_cam"] ) assert ( len(detections_left) == len(detections_right) == len(target_timestamps_sec) ) def make_concat_detections( detections_side: list[None | SingleHandHamerOutputWrtCamera], detections_other_side: list[None | SingleHandHamerOutputWrtCamera], ) -> None | SingleHandHamerOutputWrtCameraConcatenated: detections_side_concat = None # Filter out HaMeR detections that are in the same location as each # other. # Sometimes we have a left detections and a right detection both in # the same location. This filters both out. detections_side_filtered: list[None | SingleHandHamerOutputWrtCamera] = [] for i, d in enumerate(detections_side): if d is None: detections_side_filtered.append(None) continue num_d = d["verts"].shape[0] keep_mask = np.ones(d["keypoints_3d"].shape[0], dtype=bool) for offset in range(-15, 15): i_offset = i + offset if i_offset < 0 or i_offset >= len(detections_other_side): continue d_other = detections_other_side[i_offset] if d_other is None: # detections_side_filtered.append(d) continue num_d_other = d_other["verts"].shape[0] dist_matrix = np.linalg.norm( d["keypoints_3d"][:, None, 0, :] - d_other["keypoints_3d"][None, :, 0, :], axis=-1, ) assert dist_matrix.shape == (num_d, num_d_other) keep_mask = np.logical_and( keep_mask, np.all(dist_matrix > 0.1, axis=-1) ) if keep_mask.sum() == 0: detections_side_filtered.append(None) else: detections_side_filtered.append( cast( SingleHandHamerOutputWrtCamera, {k: cast(np.ndarray, v)[keep_mask] for k, v in d.items()}, ) ) del detections_side detections_side_not_none = [d is not None for d in detections_side_filtered] if not any(detections_side_not_none): return None (valid_detection_indices,) = np.where(detections_side_not_none) # We should be done with these. del detections_side_not_none del detections_other_side detections_side_concat = SingleHandHamerOutputWrtCameraConcatenated( verts=torch.from_numpy( np.stack( # Currently: we always just take the first hand detection. [ d["verts"][0] for d in detections_side_filtered if d is not None ] ) ).to(torch.float32), keypoints_3d=torch.from_numpy( np.stack( [ # Currently: we always just take the first hand detection. d["keypoints_3d"][0] for d in detections_side_filtered if d is not None ] ) ).to(torch.float32), mano_hand_global_orient=torch.from_numpy( np.stack( [ # Currently: we always just take the first hand detection. d["mano_hand_global_orient"][0] for d in detections_side_filtered if d is not None ] ) ).to(torch.float32), single_hand_quats=SO3.from_matrix( torch.from_numpy( np.stack( [ # Currently: we always just take the first hand detection. d["mano_hand_pose"][0] for d in detections_side_filtered if d is not None ] ) ).to(torch.float32) ).wxyz, indices=torch.from_numpy(valid_detection_indices), ) return detections_side_concat return CorrespondedHamerDetections( mano_faces_right=torch.from_numpy( hamer_out["mano_faces_right"].astype(np.int64) ), mano_faces_left=torch.from_numpy( hamer_out["mano_faces_left"].astype(np.int64) ), detections_left_tuple=tuple(detections_left), detections_right_tuple=tuple(detections_right), T_cpf_cam=torch.from_numpy(hamer_out["T_cpf_cam"]).to(torch.float32), focal_length=450, detections_left_concat=make_concat_detections( detections_left, detections_right ), detections_right_concat=make_concat_detections( detections_right, detections_left ), ) ================================================ FILE: src/egoallo/inference_utils.py ================================================ """Functions that are useful for inference scripts.""" from __future__ import annotations from dataclasses import dataclass from pathlib import Path import numpy as np import torch import yaml from jaxtyping import Float from projectaria_tools.core import mps # type: ignore from projectaria_tools.core.data_provider import create_vrs_data_provider from safetensors import safe_open from torch import Tensor from .network import EgoDenoiser, EgoDenoiserConfig from .tensor_dataclass import TensorDataclass from .transforms import SE3 def load_denoiser(checkpoint_dir: Path) -> EgoDenoiser: """Load a denoiser model.""" checkpoint_dir = checkpoint_dir.absolute() experiment_dir = checkpoint_dir.parent config = yaml.load( (experiment_dir / "model_config.yaml").read_text(), Loader=yaml.Loader ) assert isinstance(config, EgoDenoiserConfig) model = EgoDenoiser(config) with safe_open(checkpoint_dir / "model.safetensors", framework="pt") as f: # type: ignore state_dict = {k: f.get_tensor(k) for k in f.keys()} model.load_state_dict(state_dict) return model @dataclass(frozen=True) class InferenceTrajectoryPaths: """Paths for running EgoAllo on a single sequence from Project Aria. Our basic assumptions here are: 1. VRS file for images: there is exactly one VRS file in the trajectory root directory. 2. Aria MPS point cloud: there is either one semidense_points.csv.gz file or one global_points.csv.gz file. - Its parent directory should contain other Aria MPS artifacts. (like poses) - This is optionally used for guidance. 3. HaMeR outputs: The hamer_outputs.pkl file may or may not exist in the trajectory root directory. - This is optionally used for guidance. 4. Aria MPS wrist/palm poses: There may be zero or one wrist_and_palm_poses.csv file. - This is optionally used for guidance. 5. Scene splat/ply file: There may be a splat.ply or scene.splat file. - This is only used for visualization. """ vrs_file: Path slam_root_dir: Path points_path: Path hamer_outputs: Path | None wrist_and_palm_poses_csv: Path | None splat_path: Path | None @staticmethod def find(traj_root: Path) -> InferenceTrajectoryPaths: vrs_files = tuple(traj_root.glob("**/*.vrs")) assert len(vrs_files) == 1, f"Found {len(vrs_files)} VRS files!" points_paths = tuple(traj_root.glob("**/semidense_points.csv.gz")) assert len(points_paths) <= 1, f"Found multiple points files! {points_paths}" if len(points_paths) == 0: points_paths = tuple(traj_root.glob("**/global_points.csv.gz")) assert len(points_paths) == 1, f"Found {len(points_paths)} files!" hamer_outputs = traj_root / "hamer_outputs.pkl" if not hamer_outputs.exists(): hamer_outputs = None wrist_and_palm_poses_csv = tuple(traj_root.glob("**/wrist_and_palm_poses.csv")) if len(wrist_and_palm_poses_csv) == 0: wrist_and_palm_poses_csv = None else: assert len(wrist_and_palm_poses_csv) == 1, ( "Found multiple wrist and palm poses files!" ) splat_path = traj_root / "splat.ply" if not splat_path.exists(): splat_path = traj_root / "scene.splat" if not splat_path.exists(): print("No scene splat found.") splat_path = None else: print("Found splat at", splat_path) return InferenceTrajectoryPaths( vrs_file=vrs_files[0], slam_root_dir=points_paths[0].parent, points_path=points_paths[0], hamer_outputs=hamer_outputs, wrist_and_palm_poses_csv=wrist_and_palm_poses_csv[0] if wrist_and_palm_poses_csv else None, splat_path=splat_path, ) class InferenceInputTransforms(TensorDataclass): """Some relevant transforms for inference.""" Ts_world_cpf: Float[Tensor, "timesteps 7"] Ts_world_device: Float[Tensor, "timesteps 7"] pose_timesteps: tuple[float, ...] @staticmethod def load( vrs_path: Path, slam_root_dir: Path, fps: int = 30, ) -> InferenceInputTransforms: """Read some useful transforms via MPS + the VRS calibration.""" # Read device poses. closed_loop_path = slam_root_dir / "closed_loop_trajectory.csv" if not closed_loop_path.exists(): # Aria digital twins. closed_loop_path = slam_root_dir / "aria_trajectory.csv" closed_loop_traj = mps.read_closed_loop_trajectory(str(closed_loop_path)) # type: ignore provider = create_vrs_data_provider(str(vrs_path)) device_calib = provider.get_device_calibration() T_device_cpf = device_calib.get_transform_device_cpf().to_matrix() # Get downsampled CPF frames. aria_fps = len(closed_loop_traj) / ( closed_loop_traj[-1].tracking_timestamp.total_seconds() - closed_loop_traj[0].tracking_timestamp.total_seconds() ) num_poses = len(closed_loop_traj) print(f"Loaded {num_poses=} with {aria_fps=}, visualizing at {fps=}") Ts_world_device = [] Ts_world_cpf = [] out_timestamps_secs = [] for i in range(0, num_poses, int(aria_fps // fps)): T_world_device = closed_loop_traj[i].transform_world_device.to_matrix() assert T_world_device.shape == (4, 4) Ts_world_device.append(T_world_device) Ts_world_cpf.append(T_world_device @ T_device_cpf) out_timestamps_secs.append( closed_loop_traj[i].tracking_timestamp.total_seconds() ) return InferenceInputTransforms( Ts_world_device=SE3.from_matrix(torch.from_numpy(np.array(Ts_world_device))) .parameters() .to(torch.float32), Ts_world_cpf=SE3.from_matrix(torch.from_numpy(np.array(Ts_world_cpf))) .parameters() .to(torch.float32), pose_timesteps=tuple(out_timestamps_secs), ) ================================================ FILE: src/egoallo/metrics_helpers.py ================================================ from typing import Literal, overload import numpy as np import torch from jaxtyping import Float from torch import Tensor from typing_extensions import assert_never from .transforms import SO3 def compute_foot_skate( pred_Ts_world_joint: Float[Tensor, "num_samples time 21 7"], ) -> np.ndarray: (num_samples, time) = pred_Ts_world_joint.shape[:2] # Drop the person to the floor. # This is necessary for the foot skating metric to make sense for floating people...! pred_Ts_world_joint = pred_Ts_world_joint.clone() pred_Ts_world_joint[..., 6] -= torch.min(pred_Ts_world_joint[..., 6]) foot_indices = torch.tensor([6, 7, 9, 10], device=pred_Ts_world_joint.device) foot_positions = pred_Ts_world_joint[:, :, foot_indices, 4:7] foot_positions_diff = foot_positions[:, 1:, :, :2] - foot_positions[:, :-1, :, :2] assert foot_positions_diff.shape == (num_samples, time - 1, 4, 2) foot_positions_diff_norm = torch.sum(torch.abs(foot_positions_diff), dim=-1) assert foot_positions_diff_norm.shape == (num_samples, time - 1, 4) # From EgoEgo / kinpoly. H_thresh = torch.tensor( # To match indices above: (ankle, ankle, toe, toe) [0.08, 0.08, 0.04, 0.04], device=pred_Ts_world_joint.device, dtype=torch.float32, ) foot_positions_diff_norm = torch.sum(torch.abs(foot_positions_diff), dim=-1) assert foot_positions_diff_norm.shape == (num_samples, time - 1, 4) # Threshold. foot_positions_diff_norm = foot_positions_diff_norm * ( foot_positions[..., 1:, :, 2] < H_thresh ) fs_per_sample = torch.sum( torch.sum( foot_positions_diff_norm * (2 - 2 ** (foot_positions[..., 1:, :, 2] / H_thresh)), dim=-1, ), dim=-1, ) assert fs_per_sample.shape == (num_samples,) return fs_per_sample.numpy(force=True) def compute_foot_contact( pred_Ts_world_joint: Float[Tensor, "num_samples time 21 7"], ) -> np.ndarray: (num_samples, time) = pred_Ts_world_joint.shape[:2] foot_indices = torch.tensor([6, 7, 9, 10], device=pred_Ts_world_joint.device) # From EgoEgo / kinpoly. H_thresh = torch.tensor( # To match indices above: (ankle, ankle, toe, toe) [0.08, 0.08, 0.04, 0.04], device=pred_Ts_world_joint.device, dtype=torch.float32, ) foot_positions = pred_Ts_world_joint[:, :, foot_indices, 4:7] any_contact = torch.any( torch.any(foot_positions[..., 2] < H_thresh, dim=-1), dim=-1 ).to(torch.float32) assert any_contact.shape == (num_samples,) return any_contact.numpy(force=True) def compute_head_ori( label_Ts_world_joint: Float[Tensor, "time 21 7"], pred_Ts_world_joint: Float[Tensor, "num_samples time 21 7"], ) -> np.ndarray: (num_samples, time) = pred_Ts_world_joint.shape[:2] matrix_errors = ( SO3(pred_Ts_world_joint[:, :, 14, :4]).as_matrix() @ SO3(label_Ts_world_joint[:, 14, :4]).inverse().as_matrix() ) - torch.eye(3, device=label_Ts_world_joint.device) assert matrix_errors.shape == (num_samples, time, 3, 3) return torch.mean( torch.linalg.norm(matrix_errors.reshape((num_samples, time, 9)), dim=-1), dim=-1, ).numpy(force=True) def compute_head_trans( label_Ts_world_joint: Float[Tensor, "time 21 7"], pred_Ts_world_joint: Float[Tensor, "num_samples time 21 7"], ) -> np.ndarray: (num_samples, time) = pred_Ts_world_joint.shape[:2] errors = pred_Ts_world_joint[:, :, 14, 4:7] - label_Ts_world_joint[:, 14, 4:7] assert errors.shape == (num_samples, time, 3) return torch.mean( torch.linalg.norm(errors, dim=-1), dim=-1, ).numpy(force=True) def compute_mpjpe( label_T_world_root: Float[Tensor, "time 7"], label_Ts_world_joint: Float[Tensor, "time 21 7"], pred_T_world_root: Float[Tensor, "num_samples time 7"], pred_Ts_world_joint: Float[Tensor, "num_samples time 21 7"], per_frame_procrustes_align: bool, ) -> np.ndarray: num_samples, time, _, _ = pred_Ts_world_joint.shape # Concatenate the world root to the joints. label_Ts_world_joint = torch.cat( [label_T_world_root[..., None, :], label_Ts_world_joint], dim=-2 ) pred_Ts_world_joint = torch.cat( [pred_T_world_root[..., None, :], pred_Ts_world_joint], dim=-2 ) del label_T_world_root, pred_T_world_root pred_joint_positions = pred_Ts_world_joint[:, :, :, 4:7] label_joint_positions = label_Ts_world_joint[None, :, :, 4:7].repeat( num_samples, 1, 1, 1 ) if per_frame_procrustes_align: pred_joint_positions = procrustes_align( points_y=pred_joint_positions, points_x=label_joint_positions, output="aligned_x", ) position_differences = pred_joint_positions - label_joint_positions assert position_differences.shape == (num_samples, time, 22, 3) # Per-joint position errors, in millimeters. pjpe = torch.linalg.norm(position_differences, dim=-1) * 1000.0 assert pjpe.shape == (num_samples, time, 22) # Mean per-joint position errors. mpjpe = torch.mean(pjpe.reshape((num_samples, -1)), dim=-1) assert mpjpe.shape == (num_samples,) return mpjpe.cpu().numpy() @overload def procrustes_align( points_y: Float[Tensor, "*#batch N 3"], points_x: Float[Tensor, "*#batch N 3"], output: Literal["transforms"], fix_scale: bool = False, ) -> tuple[Tensor, Tensor, Tensor]: ... @overload def procrustes_align( points_y: Float[Tensor, "*#batch N 3"], points_x: Float[Tensor, "*#batch N 3"], output: Literal["aligned_x"], fix_scale: bool = False, ) -> Tensor: ... def procrustes_align( points_y: Float[Tensor, "*#batch N 3"], points_x: Float[Tensor, "*#batch N 3"], output: Literal["transforms", "aligned_x"], fix_scale: bool = False, ) -> tuple[Tensor, Tensor, Tensor] | Tensor: """Similarity transform alignment using the Umeyama method. Adapted from SLAHMR: https://github.com/vye16/slahmr/blob/main/slahmr/geometry/pcl.py Minimizes: mean( || Y - s * (R @ X) + t ||^2 ) with respect to s, R, and t. Returns an (s, R, t) tuple. """ *dims, N, _ = points_y.shape device = points_y.device N = torch.ones((*dims, 1, 1), device=device) * N # subtract mean my = points_y.sum(dim=-2) / N[..., 0] # (*, 3) mx = points_x.sum(dim=-2) / N[..., 0] y0 = points_y - my[..., None, :] # (*, N, 3) x0 = points_x - mx[..., None, :] # correlation C = torch.matmul(y0.transpose(-1, -2), x0) / N # (*, 3, 3) U, D, Vh = torch.linalg.svd(C) # (*, 3, 3), (*, 3), (*, 3, 3) S = ( torch.eye(3, device=device) .reshape(*(1,) * (len(dims)), 3, 3) .repeat(*dims, 1, 1) ) neg = torch.det(U) * torch.det(Vh.transpose(-1, -2)) < 0 S = torch.where( neg.reshape(*dims, 1, 1), S * torch.diag(torch.tensor([1, 1, -1], device=device)), S, ) R = torch.matmul(U, torch.matmul(S, Vh)) # (*, 3, 3) D = torch.diag_embed(D) # (*, 3, 3) if fix_scale: s = torch.ones(*dims, 1, device=device, dtype=torch.float32) else: var = torch.sum(torch.square(x0), dim=(-1, -2), keepdim=True) / N # (*, 1, 1) s = ( torch.diagonal(torch.matmul(D, S), dim1=-2, dim2=-1).sum( dim=-1, keepdim=True ) / var[..., 0] ) # (*, 1) t = my - s * torch.matmul(R, mx[..., None])[..., 0] # (*, 3) assert s.shape == (*dims, 1) assert R.shape == (*dims, 3, 3) assert t.shape == (*dims, 3) if output == "transforms": return s, R, t elif output == "aligned_x": aligned_x = ( s[..., None, :] * torch.einsum("...ij,...nj->...ni", R, points_x) + t[..., None, :] ) assert aligned_x.shape == points_x.shape return aligned_x else: assert_never(output) ================================================ FILE: src/egoallo/network.py ================================================ from __future__ import annotations from dataclasses import dataclass from functools import cache, cached_property from typing import Literal, assert_never import numpy as np import torch from einops import rearrange from jaxtyping import Bool, Float from loguru import logger from rotary_embedding_torch import RotaryEmbedding from torch import Tensor, nn from .fncsmpl import SmplhModel, SmplhShapedAndPosed from .tensor_dataclass import TensorDataclass from .transforms import SE3, SO3 def project_rotmats_via_svd( rotmats: Float[Tensor, "*batch 3 3"], ) -> Float[Tensor, "*batch 3 3"]: u, s, vh = torch.linalg.svd(rotmats) del s return torch.einsum("...ij,...jk->...ik", u, vh) class EgoDenoiseTraj(TensorDataclass): """Data structure for denoising. Contains tensors that we are denoising, as well as utilities for packing + unpacking them.""" betas: Float[Tensor, "*#batch timesteps 16"] """Body shape parameters. We don't really need the timesteps axis here, it's just for convenience.""" body_rotmats: Float[Tensor, "*#batch timesteps 21 3 3"] """Local orientations for each body joint.""" contacts: Float[Tensor, "*#batch timesteps 21"] """Contact boolean for each joint.""" hand_rotmats: Float[Tensor, "*#batch timesteps 30 3 3"] | None """Local orientations for each body joint.""" @staticmethod def get_packed_dim(include_hands: bool) -> int: packed_dim = 16 + 21 * 9 + 21 if include_hands: packed_dim += 30 * 9 return packed_dim def apply_to_body(self, body_model: SmplhModel) -> SmplhShapedAndPosed: device = self.betas.device dtype = self.betas.dtype assert self.hand_rotmats is not None shaped = body_model.with_shape(self.betas) posed = shaped.with_pose( T_world_root=SE3.identity(device=device, dtype=dtype).parameters(), local_quats=SO3.from_matrix( torch.cat([self.body_rotmats, self.hand_rotmats], dim=-3) ).wxyz, ) return posed def pack(self) -> Float[Tensor, "*#batch timesteps d_state"]: """Pack trajectory into a single flattened vector.""" (*batch, time, num_joints, _, _) = self.body_rotmats.shape assert num_joints == 21 return torch.cat( [ x.reshape((*batch, time, -1)) for x in vars(self).values() if x is not None ], dim=-1, ) @classmethod def unpack( cls, x: Float[Tensor, "*#batch timesteps d_state"], include_hands: bool, project_rotmats: bool = False, ) -> EgoDenoiseTraj: """Unpack trajectory from a single flattened vector. Args: x: Packed trajectory. project_rotmats: If True, project the rotation matrices to SO(3) via SVD. """ (*batch, time, d_state) = x.shape assert d_state == cls.get_packed_dim(include_hands) if include_hands: betas, body_rotmats_flat, contacts, hand_rotmats_flat = torch.split( x, [16, 21 * 9, 21, 30 * 9], dim=-1 ) body_rotmats = body_rotmats_flat.reshape((*batch, time, 21, 3, 3)) hand_rotmats = hand_rotmats_flat.reshape((*batch, time, 30, 3, 3)) assert betas.shape == (*batch, time, 16) else: betas, body_rotmats_flat, contacts = torch.split( x, [16, 21 * 9, 21], dim=-1 ) body_rotmats = body_rotmats_flat.reshape((*batch, time, 21, 3, 3)) hand_rotmats = None assert betas.shape == (*batch, time, 16) if project_rotmats: # We might want to handle the -1 determinant case as well. body_rotmats = project_rotmats_via_svd(body_rotmats) return EgoDenoiseTraj( betas=betas, body_rotmats=body_rotmats, contacts=contacts, hand_rotmats=hand_rotmats, ) @dataclass(frozen=True) class EgoDenoiserConfig: max_t: int = 1000 fourier_enc_freqs: int = 3 d_latent: int = 512 d_feedforward: int = 2048 d_noise_emb: int = 1024 num_heads: int = 4 encoder_layers: int = 6 decoder_layers: int = 6 dropout_p: float = 0.0 activation: Literal["gelu", "relu"] = "gelu" positional_encoding: Literal["transformer", "rope"] = "rope" noise_conditioning: Literal["token", "film"] = "token" xattn_mode: Literal["kv_from_cond_q_from_x", "kv_from_x_q_from_cond"] = ( "kv_from_cond_q_from_x" ) include_canonicalized_cpf_rotation_in_cond: bool = True include_hands: bool = True """Whether to include hand joints (+15 per hand) in the denoised state.""" cond_param: Literal[ "ours", "canonicalized", "absolute", "absrel", "absrel_global_deltas" ] = "ours" """Which conditioning parameterization to use. "ours" is the default, we try to be clever and design something with nice equivariance properties. "canonicalized" contains a transformation that's canonicalized to aligned to the first frame. "absolute" is the naive case, where we just pass in transformations directly. """ include_hand_positions_cond: bool = False """Whether to include hand positions in the conditioning information.""" @cached_property def d_cond(self) -> int: """Dimensionality of conditioning vector.""" if self.cond_param == "ours": d_cond = 0 d_cond += 12 # Relative CPF pose, flattened 3x4 matrix. d_cond += 1 # Floor height. if self.include_canonicalized_cpf_rotation_in_cond: d_cond += 9 # Canonicalized CPF rotation, flattened 3x3 matrix. elif self.cond_param == "canonicalized": d_cond = 12 elif self.cond_param == "absolute": d_cond = 12 elif self.cond_param == "absrel": # Both absolute and relative! d_cond = 24 elif self.cond_param == "absrel_global_deltas": # Both absolute and relative! d_cond = 24 else: assert_never(self.cond_param) # Add two 3D positions to the conditioning dimension if we're including # hand conditioning. if self.include_hand_positions_cond: d_cond = d_cond + 6 d_cond = d_cond + d_cond * self.fourier_enc_freqs * 2 # Fourier encoding. return d_cond def make_cond( self, T_cpf_tm1_cpf_t: Float[Tensor, "batch time 7"], T_world_cpf: Float[Tensor, "batch time 7"], hand_positions_wrt_cpf: Float[Tensor, "batch time 6"] | None, ) -> Float[Tensor, "batch time d_cond"]: """Construct conditioning information from CPF pose.""" (batch, time, _) = T_cpf_tm1_cpf_t.shape # Construct device pose conditioning. if self.cond_param == "ours": # Compute conditioning terms. +Z is up in the world frame. We want # the translation to be invariant to translations in the world X/Y # directions. height_from_floor = T_world_cpf[..., 6:7] cond_parts = [ SE3(T_cpf_tm1_cpf_t).as_matrix()[..., :3, :].reshape((batch, time, 12)), height_from_floor, ] if self.include_canonicalized_cpf_rotation_in_cond: # We want the rotation to be invariant to rotations around the # world Z axis. Visualization of what's happening here: # # https://gist.github.com/brentyi/9226d082d2707132af39dea92b8609f6 # # (The coordinate frame may differ by some axis-swapping # compared to the exact equations in the paper. But to the # network these will all look the same.) R_world_cpf = SE3(T_world_cpf).rotation().wxyz forward_cpf = R_world_cpf.new_tensor([0.0, 0.0, 1.0]) forward_world = SO3(R_world_cpf) @ forward_cpf assert forward_world.shape == (batch, time, 3) R_canonical_world = SO3.from_z_radians( -torch.arctan2(forward_world[..., 1], forward_world[..., 0]) ).wxyz assert R_canonical_world.shape == (batch, time, 4) cond_parts.append( (SO3(R_canonical_world) @ SO3(R_world_cpf)) .as_matrix() .reshape((batch, time, 9)), ) cond = torch.cat(cond_parts, dim=-1) elif self.cond_param == "canonicalized": # Align the first timestep. # Put poses so start is at origin, facing forward. R_world_cpf = SE3(T_world_cpf[:, 0:1, :]).rotation().wxyz forward_cpf = R_world_cpf.new_tensor([0.0, 0.0, 1.0]) forward_world = SO3(R_world_cpf) @ forward_cpf assert forward_world.shape == (batch, 1, 3) R_canonical_world = SO3.from_z_radians( -torch.arctan2(forward_world[..., 1], forward_world[..., 0]) ).wxyz assert R_canonical_world.shape == (batch, 1, 4) R_canonical_cpf = SO3(R_canonical_world) @ SE3(T_world_cpf).rotation() t_canonical_cpf = SO3(R_canonical_world) @ SE3(T_world_cpf).translation() t_canonical_cpf = t_canonical_cpf - t_canonical_cpf[:, 0:1, :] cond = ( SE3.from_rotation_and_translation(R_canonical_cpf, t_canonical_cpf) .as_matrix()[..., :3, :4] .reshape((batch, time, 12)) ) elif self.cond_param == "absolute": cond = SE3(T_world_cpf).as_matrix()[..., :3, :4].reshape((batch, time, 12)) elif self.cond_param == "absrel": cond = torch.concatenate( [ SE3(T_world_cpf) .as_matrix()[..., :3, :4] .reshape((batch, time, 12)), SE3(T_cpf_tm1_cpf_t) .as_matrix()[..., :3, :4] .reshape((batch, time, 12)), ], dim=-1, ) elif self.cond_param == "absrel_global_deltas": cond = torch.concatenate( [ SE3(T_world_cpf) .as_matrix()[..., :3, :4] .reshape((batch, time, 12)), SE3(T_cpf_tm1_cpf_t) .rotation() .as_matrix() .reshape((batch, time, 9)), ( SE3(T_world_cpf).rotation() @ SE3(T_cpf_tm1_cpf_t).inverse().translation() ).reshape((batch, time, 3)), ], dim=-1, ) else: assert_never(self.cond_param) # Condition on hand poses as well. # We didn't use this for the paper. if self.include_hand_positions_cond: if hand_positions_wrt_cpf is None: logger.warning( "Model is looking for hand conditioning but none was provided. Passing in zeros." ) hand_positions_wrt_cpf = torch.zeros( (batch, time, 6), device=T_world_cpf.device ) assert hand_positions_wrt_cpf.shape == (batch, time, 6) cond = torch.cat([cond, hand_positions_wrt_cpf], dim=-1) cond = fourier_encode(cond, freqs=self.fourier_enc_freqs) assert cond.shape == (batch, time, self.d_cond) return cond class EgoDenoiser(nn.Module): """Denoising network for human motion. Inputs are noisy trajectory, conditioning information, and timestep. Output is denoised trajectory. """ def __init__(self, config: EgoDenoiserConfig): super().__init__() self.config = config Activation = {"gelu": nn.GELU, "relu": nn.ReLU}[config.activation] # MLP encoders and decoders for each modality we want to denoise. modality_dims: dict[str, int] = { "betas": 16, "body_rotmats": 21 * 9, "contacts": 21, } if config.include_hands: modality_dims["hand_rotmats"] = 30 * 9 assert sum(modality_dims.values()) == self.get_d_state() self.encoders = nn.ModuleDict( { k: nn.Sequential( nn.Linear(modality_dim, config.d_latent), Activation(), nn.Linear(config.d_latent, config.d_latent), Activation(), nn.Linear(config.d_latent, config.d_latent), ) for k, modality_dim in modality_dims.items() } ) self.decoders = nn.ModuleDict( { k: nn.Sequential( nn.Linear(config.d_latent, config.d_latent), nn.LayerNorm(normalized_shape=config.d_latent), Activation(), nn.Linear(config.d_latent, config.d_latent), Activation(), nn.Linear(config.d_latent, modality_dim), ) for k, modality_dim in modality_dims.items() } ) # Helpers for converting between input dimensionality and latent dimensionality. self.latent_from_cond = nn.Linear(config.d_cond, config.d_latent) # Noise embedder. self.noise_emb = nn.Embedding( # index 0 will be t=1 # index 999 will be t=1000 num_embeddings=config.max_t, embedding_dim=config.d_noise_emb, ) self.noise_emb_token_proj = ( nn.Linear(config.d_noise_emb, config.d_latent, bias=False) if config.noise_conditioning == "token" else None ) # Encoder / decoder layers. # Inputs are conditioning (current noise level, observations); output # is encoded conditioning information. self.encoder_layers = nn.ModuleList( [ TransformerBlock( TransformerBlockConfig( d_latent=config.d_latent, d_noise_emb=config.d_noise_emb, d_feedforward=config.d_feedforward, n_heads=config.num_heads, dropout_p=config.dropout_p, activation=config.activation, include_xattn=False, # No conditioning for encoder. use_rope_embedding=config.positional_encoding == "rope", use_film_noise_conditioning=config.noise_conditioning == "film", xattn_mode=config.xattn_mode, ) ) for _ in range(config.encoder_layers) ] ) self.decoder_layers = nn.ModuleList( [ TransformerBlock( TransformerBlockConfig( d_latent=config.d_latent, d_noise_emb=config.d_noise_emb, d_feedforward=config.d_feedforward, n_heads=config.num_heads, dropout_p=config.dropout_p, activation=config.activation, include_xattn=True, # Include conditioning for the decoder. use_rope_embedding=config.positional_encoding == "rope", use_film_noise_conditioning=config.noise_conditioning == "film", xattn_mode=config.xattn_mode, ) ) for _ in range(config.decoder_layers) ] ) def get_d_state(self) -> int: return EgoDenoiseTraj.get_packed_dim(self.config.include_hands) def forward( self, x_t_packed: Float[Tensor, "batch time state_dim"], t: Float[Tensor, "batch"], *, T_world_cpf: Float[Tensor, "batch time 7"], T_cpf_tm1_cpf_t: Float[Tensor, "batch time 7"], project_output_rotmats: bool, # Observed hand positions, relative to the CPF. hand_positions_wrt_cpf: Float[Tensor, "batch time 6"] | None, # Attention mask for using shorter sequences. mask: Bool[Tensor, "batch time"] | None, # Mask for when to drop out / keep conditioning information. cond_dropout_keep_mask: Bool[Tensor, "batch"] | None = None, ) -> Float[Tensor, "batch time state_dim"]: """Predict a denoised trajectory. Note that `t` refers to a noise level, not a timestep.""" config = self.config x_t = EgoDenoiseTraj.unpack(x_t_packed, include_hands=self.config.include_hands) (batch, time, num_body_joints, _, _) = x_t.body_rotmats.shape assert num_body_joints == 21 # Encode the trajectory into a single vector per timestep. x_t_encoded = ( self.encoders["betas"](x_t.betas.reshape((batch, time, -1))) + self.encoders["body_rotmats"](x_t.body_rotmats.reshape((batch, time, -1))) + self.encoders["contacts"](x_t.contacts) ) if self.config.include_hands: assert x_t.hand_rotmats is not None x_t_encoded = x_t_encoded + self.encoders["hand_rotmats"]( x_t.hand_rotmats.reshape((batch, time, -1)) ) assert x_t_encoded.shape == (batch, time, config.d_latent) # Embed the diffusion noise level. assert t.shape == (batch,) noise_emb = self.noise_emb(t - 1) assert noise_emb.shape == (batch, config.d_noise_emb) # Prepare conditioning information. cond = config.make_cond( T_cpf_tm1_cpf_t, T_world_cpf=T_world_cpf, hand_positions_wrt_cpf=hand_positions_wrt_cpf, ) # Randomly drop out conditioning information; this serves as a # regularizer that aims to improve sample diversity. if cond_dropout_keep_mask is not None: assert cond_dropout_keep_mask.shape == (batch,) cond = cond * cond_dropout_keep_mask[:, None, None] # Prepare encoder and decoder inputs. if config.positional_encoding == "rope": pos_enc = 0 elif config.positional_encoding == "transformer": pos_enc = make_positional_encoding( d_latent=config.d_latent, length=time, dtype=cond.dtype, )[None, ...].to(x_t_encoded.device) assert pos_enc.shape == (1, time, config.d_latent) else: assert_never(config.positional_encoding) encoder_out = self.latent_from_cond(cond) + pos_enc decoder_out = x_t_encoded + pos_enc # Append the noise embedding to the encoder and decoder inputs. # This is weird if we're using rotary embeddings! if self.noise_emb_token_proj is not None: noise_emb_token = self.noise_emb_token_proj(noise_emb) assert noise_emb_token.shape == (batch, config.d_latent) encoder_out = torch.cat([noise_emb_token[:, None, :], encoder_out], dim=1) decoder_out = torch.cat([noise_emb_token[:, None, :], decoder_out], dim=1) assert ( encoder_out.shape == decoder_out.shape == (batch, time + 1, config.d_latent) ) num_tokens = time + 1 else: num_tokens = time # Compute attention mask. This needs to be a fl if mask is None: attn_mask = None else: assert mask.shape == (batch, time) assert mask.dtype == torch.bool if self.noise_emb_token_proj is not None: # Account for noise token. mask = torch.cat([mask.new_ones((batch, 1)), mask], dim=1) # Last two dimensions of mask are (query, key). We're masking out only keys; # it's annoying for the softmax to mask out entire rows without getting NaNs. attn_mask = mask[:, None, None, :].repeat(1, 1, num_tokens, 1) assert attn_mask.shape == (batch, 1, num_tokens, num_tokens) assert attn_mask.dtype == torch.bool # Forward pass through transformer. for layer in self.encoder_layers: encoder_out = layer(encoder_out, attn_mask, noise_emb=noise_emb) for layer in self.decoder_layers: decoder_out = layer( decoder_out, attn_mask, noise_emb=noise_emb, cond=encoder_out ) # Remove the extra token corresponding to the noise embedding. if self.noise_emb_token_proj is not None: decoder_out = decoder_out[:, 1:, :] assert isinstance(decoder_out, Tensor) assert decoder_out.shape == (batch, time, config.d_latent) packed_output = torch.cat( [ # Project rotation matrices for body_rotmats via SVD, ( project_rotmats_via_svd( modality_decoder(decoder_out).reshape((-1, 3, 3)) ).reshape( (batch, time, {"body_rotmats": 21, "hand_rotmats": 30}[key] * 9) ) # if enabled, if project_output_rotmats and key in ("body_rotmats", "hand_rotmats") # otherwise, just decode normally. else modality_decoder(decoder_out) ) for key, modality_decoder in self.decoders.items() ], dim=-1, ) assert packed_output.shape == (batch, time, self.get_d_state()) # Return packed output. return packed_output @cache def make_positional_encoding( d_latent: int, length: int, dtype: torch.dtype ) -> Float[Tensor, "length d_latent"]: """Computes standard Transformer positional encoding.""" pe = torch.zeros(length, d_latent, dtype=dtype) position = torch.arange(0, length, dtype=torch.float).unsqueeze(1) div_term = torch.exp( torch.arange(0, d_latent, 2).float() * (-np.log(10000.0) / d_latent) ) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) assert pe.shape == (length, d_latent) return pe def fourier_encode( x: Float[Tensor, "*#batch channels"], freqs: int ) -> Float[Tensor, "*#batch channels+2*freqs*channels"]: """Apply Fourier encoding to a tensor.""" *batch_axes, x_dim = x.shape coeffs = 2.0 ** torch.arange(freqs, device=x.device) scaled = (x[..., None] * coeffs).reshape((*batch_axes, x_dim * freqs)) return torch.cat( [ x, torch.sin(torch.cat([scaled, scaled + torch.pi / 2.0], dim=-1)), ], dim=-1, ) @dataclass(frozen=True) class TransformerBlockConfig: d_latent: int d_noise_emb: int d_feedforward: int n_heads: int dropout_p: float activation: Literal["gelu", "relu"] include_xattn: bool use_rope_embedding: bool use_film_noise_conditioning: bool xattn_mode: Literal["kv_from_cond_q_from_x", "kv_from_x_q_from_cond"] class TransformerBlock(nn.Module): """An even-tempered Transformer block.""" def __init__(self, config: TransformerBlockConfig) -> None: super().__init__() self.sattn_qkv_proj = nn.Linear( config.d_latent, config.d_latent * 3, bias=False ) self.sattn_out_proj = nn.Linear(config.d_latent, config.d_latent, bias=False) self.layernorm1 = nn.LayerNorm(config.d_latent) self.layernorm2 = nn.LayerNorm(config.d_latent) assert config.d_latent % config.n_heads == 0 self.rotary_emb = ( RotaryEmbedding(config.d_latent // config.n_heads) if config.use_rope_embedding else None ) if config.include_xattn: self.xattn_kv_proj = nn.Linear( config.d_latent, config.d_latent * 2, bias=False ) self.xattn_q_proj = nn.Linear(config.d_latent, config.d_latent, bias=False) self.xattn_layernorm = nn.LayerNorm(config.d_latent) self.xattn_out_proj = nn.Linear( config.d_latent, config.d_latent, bias=False ) self.norm_no_learnable = nn.LayerNorm( config.d_feedforward, elementwise_affine=False, bias=False ) self.activation = {"gelu": nn.GELU, "relu": nn.ReLU}[config.activation]() self.dropout = nn.Dropout(config.dropout_p) self.mlp0 = nn.Linear(config.d_latent, config.d_feedforward) self.mlp_film_cond_proj = ( zero_module( nn.Linear(config.d_noise_emb, config.d_feedforward * 2, bias=False) ) if config.use_film_noise_conditioning else None ) self.mlp1 = nn.Linear(config.d_feedforward, config.d_latent) self.config = config def forward( self, x: Float[Tensor, "batch tokens d_latent"], attn_mask: Bool[Tensor, "batch 1 tokens tokens"] | None, noise_emb: Float[Tensor, "batch d_noise_emb"], cond: Float[Tensor, "batch tokens d_latent"] | None = None, ) -> Float[Tensor, "batch tokens d_latent"]: config = self.config (batch, time, d_latent) = x.shape # Self-attention. # We put layer normalization after the residual connection. x = self.layernorm1(x + self._sattn(x, attn_mask)) # Include conditioning. if config.include_xattn: assert cond is not None x = self.xattn_layernorm(x + self._xattn(x, attn_mask, cond=cond)) mlp_out = x mlp_out = self.mlp0(mlp_out) mlp_out = self.activation(mlp_out) # FiLM-style conditioning. if self.mlp_film_cond_proj is not None: scale, shift = torch.chunk( self.mlp_film_cond_proj(noise_emb), chunks=2, dim=-1 ) assert scale.shape == shift.shape == (batch, config.d_feedforward) mlp_out = ( self.norm_no_learnable(mlp_out) * (1.0 + scale[:, None, :]) + shift[:, None, :] ) mlp_out = self.dropout(mlp_out) mlp_out = self.mlp1(mlp_out) x = self.layernorm2(x + mlp_out) assert x.shape == (batch, time, d_latent) return x def _sattn(self, x: Tensor, attn_mask: Tensor | None) -> Tensor: """Multi-head self-attention.""" config = self.config q, k, v = rearrange( self.sattn_qkv_proj(x), "b t (qkv nh dh) -> qkv b nh t dh", qkv=3, nh=config.n_heads, ) if self.rotary_emb is not None: q = self.rotary_emb.rotate_queries_or_keys(q, seq_dim=-2) k = self.rotary_emb.rotate_queries_or_keys(k, seq_dim=-2) x = torch.nn.functional.scaled_dot_product_attention( q, k, v, dropout_p=config.dropout_p, attn_mask=attn_mask ) x = self.dropout(x) x = rearrange(x, "b nh t dh -> b t (nh dh)", nh=config.n_heads) x = torch.nn.functional.scaled_dot_product_attention( q, k, v, dropout_p=config.dropout_p ) x = self.dropout(x) x = rearrange(x, "b nh t dh -> b t (nh dh)", nh=config.n_heads) x = self.sattn_out_proj(x) return x def _xattn(self, x: Tensor, attn_mask: Tensor | None, cond: Tensor) -> Tensor: """Multi-head cross-attention.""" config = self.config k, v = rearrange( self.xattn_kv_proj( { "kv_from_cond_q_from_x": cond, "kv_from_x_q_from_cond": x, }[self.config.xattn_mode] ), "b t (qk nh dh) -> qk b nh t dh", qk=2, nh=config.n_heads, ) q = rearrange( self.xattn_q_proj( { "kv_from_cond_q_from_x": x, "kv_from_x_q_from_cond": cond, }[self.config.xattn_mode] ), "b t (nh dh) -> b nh t dh", nh=config.n_heads, ) if self.rotary_emb is not None: q = self.rotary_emb.rotate_queries_or_keys(q, seq_dim=-2) k = self.rotary_emb.rotate_queries_or_keys(k, seq_dim=-2) x = torch.nn.functional.scaled_dot_product_attention( q, k, v, dropout_p=config.dropout_p, attn_mask=attn_mask ) x = rearrange(x, "b nh t dh -> b t (nh dh)") x = self.xattn_out_proj(x) return x def zero_module(module): """Zero out the parameters of a module and return it.""" for p in module.parameters(): p.detach().zero_() return module ================================================ FILE: src/egoallo/preprocessing/__init__.py ================================================ ================================================ FILE: src/egoallo/preprocessing/body_model/__init__.py ================================================ from .body_model import BodyModel from .skeleton import * from .specs import * from .utils import * ================================================ FILE: src/egoallo/preprocessing/body_model/body_model.py ================================================ from loguru import logger as guru import os from einops import rearrange import torch import torch.nn as nn from typing import Tuple, Dict from smplx import SMPLLayer, SMPLHLayer from smplx.vertex_ids import vertex_ids from smplx.utils import Struct from smplx.lbs import lbs as smpl_lbs from ..geometry import convert_rotation from ..util.tensor import pad_dim from .specs import SMPL_JOINTS from .utils import ( forward_kinematics, inverse_kinematics, select_vert_params, get_verts_with_transforms, ) class BodyModel(nn.Module): """ Wrapper around SMPLX body model class. """ def __init__( self, bm_path, model_type: str = "smplh", use_pca: bool = True, num_pca_comps: int = 6, batch_size: int = 1, use_vtx_selector: bool = True, **kwargs, ): """ Creates the body model object at the given path. :param bm_path: path to the body model file :param model_type: one of [smpl, smplh] :param use_vtx_selector: if true, returns additional vertices as joints that correspond to OpenPose joints """ super().__init__() assert model_type in ["smpl", "smplh"] self.model_type = model_type self.use_pca = use_pca self.num_pca_comps = num_pca_comps self.use_vtx_selector = use_vtx_selector cur_vertex_ids = None if self.use_vtx_selector: cur_vertex_ids = vertex_ids[model_type] kwargs["vertex_ids"] = cur_vertex_ids ext = os.path.splitext(bm_path)[-1][1:] if model_type == "smpl": cls = SMPLLayer self.hand_dim = 0 self.num_betas = 10 else: cls = SMPLHLayer self.hand_dim = cls.NUM_HAND_JOINTS * 3 if not use_pca else num_pca_comps self.num_betas = 16 self.batch_size = batch_size self.num_joints = cls.NUM_JOINTS + 1 # include root self.num_body_joints = cls.NUM_BODY_JOINTS # create body model without default parameters self.bm = cls( bm_path, ext=ext, num_betas=self.num_betas, use_pca=use_pca, num_pca_comps=num_pca_comps, batch_size=batch_size, **kwargs, ) guru.info(f"loading body model from {bm_path}, batch size {batch_size}") # make our own default buffers self.var_dims = { "root_orient": 3, "pose_body": cls.NUM_BODY_JOINTS * 3, "pose_hand": self.hand_dim * 2, "betas": self.num_betas, "trans": 3, } guru.info(f"variable dims {self.var_dims}") for name, sh in self.var_dims.items(): self.register_buffer(name, torch.zeros(batch_size, sh)) # save the template joints v_template = self.bm.v_template # (V, 3) J_regressor = self.bm.J_regressor # (J, V) joint_template = torch.matmul(J_regressor, v_template)[None] # type: ignore # save the extra joints from the vertex template # (1, J, 3) self.register_buffer("joint_template", joint_template) self.parents = self.bm.parents shapedirs = self.bm.shapedirs # (V, 3, B) j_shapedirs = rearrange( torch.einsum("jv,v...->j...", J_regressor, shapedirs), "a b c -> (a b) c" ) # (J * 3, B) self.register_buffer("joint_shapedirs", j_shapedirs) # (B, J * 3) # because overparameterized, use the fewest smpl joints J = len(SMPL_JOINTS) self.register_buffer( "joint_shapedirs_pinv", torch.linalg.pinv(j_shapedirs[: J * 3]) ) self._recompute_inverse_beta_mat = False # self.register_buffer("joint_shapedirs_pinv", torch.linalg.pinv(j_shapedirs)) # self._recompute_inverse_beta_mat = True for p in self.parameters(): p.requires_grad_(False) def _fill_default_vars(self, model_args) -> Tuple[Dict, Dict]: """ fill in the missing variables with defaults padded to correct batch size """ B = self._get_batch_size(**model_args) model_vars = {} for name in self.var_dims: var = model_args.pop(name, None) if var is None: var = self._get_default_model_var(name, B) model_vars[name] = var return model_vars, model_args def _get_batch_size(self, **model_args) -> int: """ get the batch size of the input args """ B = self.batch_size for name in self.var_dims: if name in model_args and type(var := model_args[name]) == torch.Tensor: B = var.shape[0] break return B def _get_default_model_var(self, name: str, batch_size: int): """ if we have the desired variable, return it, otherwise return the default value get model var with desired batch size """ return pad_dim(getattr(self, name), batch_size) def get_full_pose_mats(self, model_args, add_mean: bool = True): """ get the full pose from provided model args """ B = self._get_batch_size(**model_args) names = ["root_orient", "pose_body", "pose_hand"] model_vars = { k: model_args.get(k, self._get_default_model_var(k, B)) for k in names } root_mat = model_vars["root_orient"] if root_mat.ndim == 2: root_mat = convert_rotation( root_mat.unsqueeze(-2), "aa", "mat" ) # (B, 1, 3, 3) body_mat = model_vars["pose_body"] if body_mat.ndim == 2: body_mat = convert_rotation( body_mat.reshape(B, -1, 3), "aa", "mat" ) # (B, J, 3, 3) hand_mat = self.get_hand_pose_mat( model_vars["pose_hand"], add_mean=add_mean ) # (B, H, 3, 3) full_pose = torch.cat([root_mat, body_mat, hand_mat], dim=-3) return full_pose def get_hand_pose_mat( self, pose_hand: torch.Tensor, add_mean: bool = True ) -> torch.Tensor: """ get the hand joint rotations if applicable :param pose_hand (*, D) """ if self.hand_dim == 0: return pose_hand B = pose_hand.shape[0] if self.use_pca: left_hand_pose = torch.einsum( "...i,ij->...j", pose_hand[..., : self.hand_dim], self.bm.left_hand_components, ) right_hand_pose = torch.einsum( "...i,ij->...j", pose_hand[..., self.hand_dim :], self.bm.right_hand_components, ) pose_hand = torch.cat([left_hand_pose, right_hand_pose], dim=-1) if add_mean: J = self.num_body_joints + 1 hand_mean = self.bm.pose_mean[..., 3 * J :] # type: ignore pose_hand += hand_mean if pose_hand.ndim == 2: pose_hand = convert_rotation(pose_hand.reshape(B, -1, 3), "aa", "mat") return pose_hand def forward_joints(self, **kwargs): """ forward on joints only returns (*, J, 3) joints """ model_vars, _ = self._fill_default_vars(kwargs) rot_mats = self.get_full_pose_mats(model_vars) B = rot_mats.shape[0] shape_diffs = torch.einsum( "ij,nj->ni", self.joint_shapedirs, model_vars["betas"] ) shape_diffs = shape_diffs.reshape(B, -1, 3) joints_shaped = self.joint_template + shape_diffs joints_local, rel_transforms = forward_kinematics( rot_mats, joints_shaped, self.parents # type: ignore ) if self.use_vtx_selector: extra_joints = self.get_extra_joints( model_vars["betas"], rot_mats, rel_transforms ) joints_local = torch.cat([joints_local, extra_joints], dim=-2) return joints_local + model_vars["trans"].unsqueeze(-2) def get_extra_joints(self, betas, pose_mats, rel_transforms): vtx_idcs = self.bm.vertex_joint_selector.extra_joints_idxs v_template, shapedirs, posedirs, lbs_weights = select_vert_params( vtx_idcs, # type: ignore self.bm.v_template, # type: ignore self.bm.shapedirs, # type: ignore self.bm.posedirs, # type: ignore self.bm.lbs_weights, # type: ignore ) return get_verts_with_transforms( betas, pose_mats, rel_transforms, v_template, shapedirs, posedirs, lbs_weights, ) def inverse_joints(self, joints: torch.Tensor, **kwargs): """ get the unposed joints (template pose) """ model_vars, _ = self._fill_default_vars(kwargs) rot_mats = self.get_full_pose_mats(model_vars) joints_local = joints - model_vars["trans"].unsqueeze(-2) return inverse_kinematics(rot_mats, joints_local, self.parents) # type: ignore def joints_to_beta(self, joint_unposed: torch.Tensor) -> torch.Tensor: """ get the nearest beta such that joint_unposed = joint_template + A @ beta :param (*, J, 3) joints """ # get the residual with the template J = len(SMPL_JOINTS) if self._recompute_inverse_beta_mat: self.joint_shapedirs_pinv = torch.linalg.pinv(self.joint_shapedirs[: J * 3]) # type: ignore self._recompute_inverse_beta_mat = False dims = joint_unposed.shape[:-2] joint_unposed = joint_unposed[..., :J, :] joint_template = self.joint_template[..., :J, :] # type: ignore joint_delta = (joint_unposed - joint_template).reshape(*dims, J * 3) betas = torch.einsum( "ij,...j->...i", self.joint_shapedirs_pinv, joint_delta ) # (*, B) return betas def forward(self, **kwargs): """ forward pass of smpl model expects kwargs in [root_orient, pose_body, pose_hand, betas, trans] to have same leading dimension if included, otherwise will pad itself """ model_vars, kwargs = self._fill_default_vars(kwargs) rot_mats = self.get_full_pose_mats(model_vars) verts, joints = smpl_lbs( model_vars["betas"], rot_mats, # type: ignore self.bm.v_template, # type: ignore self.bm.shapedirs, # type: ignore self.bm.posedirs, # type: ignore self.bm.J_regressor, # type: ignore self.bm.parents, # type: ignore self.bm.lbs_weights, # type: ignore pose2rot=False, ) joints = self.bm.vertex_joint_selector(verts, joints) trans = model_vars["trans"].unsqueeze(-2) joints += trans verts += trans out = { "v": verts, "f": self.bm.faces_tensor, "Jtr": joints, "full_pose": rot_mats, } if not self.use_vtx_selector: # don't need extra joints out["Jtr"] = out["Jtr"][:, : self.num_joints] return Struct(**out) ================================================ FILE: src/egoallo/preprocessing/body_model/skeleton.py ================================================ import numpy as np import torch from .specs import SMPL_PARENTS, SMPL_JOINTS __all__ = [ "NUM_KINEMATIC_CHAINS", "smpl_kinematic_tree", "joint_angles_rel_to_glob", "joint_angles_glob_to_rel", ] NUM_KINEMATIC_CHAINS = 5 def smpl_kinematic_tree(): """ get the SMPL kinematic tree as a list of chains of joint indices """ joint_idcs = list(range(len(SMPL_JOINTS))) tree = [] chains = {} # key: last vertex so far, value: chain(s)))) for joint in joint_idcs[::-1]: parent = SMPL_PARENTS[joint] if parent in chains or parent < 0: continue chains[parent] = [parent] + chains.pop(joint, [joint]) tree = [] for joint, chain in chains.items(): parent = SMPL_PARENTS[joint] if parent >= 0: chain = [parent] + chain tree.insert(0, np.array(chain)) return tree def joint_angles_rel_to_glob(rel_mats): """ convert joint angles from relative (wrt to previous branch on kinematic chain) to global (wrt to root of skeleton) :param rotation matrices (*, 21, 3, 3) return (*, 21, 3, 3) """ assert rel_mats.shape[-3] == len(SMPL_JOINTS) - 1 glob_mats = torch.zeros_like(rel_mats) # aggregate transforms from parent to children kin_tree = smpl_kinematic_tree() for chain in kin_tree: for pidx, cidx in zip(chain[:-1], chain[1:]): # R_c0 = R_cp * R_p0 if pidx == 0: glob_mats[..., cidx - 1, :, :] = rel_mats[..., cidx - 1, :, :] else: glob_mats[..., cidx - 1, :, :] = torch.matmul( rel_mats[..., cidx - 1, :, :], glob_mats[..., pidx - 1, :, :] ) return glob_mats def joint_angles_glob_to_rel(glob_mats): """ convert joint angles from global (wrt to root of skeleton) to relative (wrt to previous branch on kinematic chain) :param rotation matrices (*, 21, 3, 3) return (*, 21, 3, 3) """ rel_mats = torch.zeros_like(glob_mats) assert glob_mats.shape[-3] == len(SMPL_JOINTS) - 1 # add the root matrix to global rotations dims = glob_mats.shape[:-3] I = ( torch.eye(3, device=glob_mats.device) .reshape(*(1,) * len(dims), 1, 3, 3) .expand(*dims, 1, 3, 3) ) glob_mats = torch.cat([I, glob_mats], dim=-3) # invert transforms from parent to children kin_tree = smpl_kinematic_tree() for chain in kin_tree: pidx, cidx = chain[:-1], chain[1:] # R_cp = R_c0 * R_0p rel_mats[..., cidx - 1, :, :] = torch.matmul( glob_mats[..., cidx, :, :], glob_mats[..., pidx, :, :].transpose(-1, -2) ) return rel_mats ================================================ FILE: src/egoallo/preprocessing/body_model/specs.py ================================================ import numpy as np SMPL_JOINTS = { "hips": 0, "leftUpLeg": 1, "rightUpLeg": 2, "spine": 3, "leftLeg": 4, "rightLeg": 5, "spine1": 6, "leftFoot": 7, "rightFoot": 8, "spine2": 9, "leftToeBase": 10, "rightToeBase": 11, "neck": 12, "leftShoulder": 13, "rightShoulder": 14, "head": 15, "leftArm": 16, "rightArm": 17, "leftForeArm": 18, "rightForeArm": 19, "leftHand": 20, "rightHand": 21, } SMPL_PARENTS = np.array( [ -1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 12, 12, 12, 13, 14, 16, 17, 18, 19, ] ) # reflect joints RIGHT_CHAIN = np.array([2, 5, 8, 11, 14, 17, 19, 21]) LEFT_CHAIN = np.array([1, 4, 7, 10, 13, 16, 18, 20]) REFLECT_PERM = np.array( [ 0, 2, 1, 3, 5, 4, 6, 8, 7, 9, 11, 10, 12, 14, 13, 15, 17, 16, 19, 18, 21, 20, ] ) POSE_REFLECT_PERM = np.concatenate([3 * i + np.arange(3) for i in REFLECT_PERM], axis=0) # root, left knee, right knee, left heel, right heel, # left toe, right toe, left hand, right hand CONTACT_JOINTS = [ "hips", "leftLeg", "rightLeg", "leftFoot", "rightFoot", "leftToeBase", "rightToeBase", "leftHand", "rightHand", ] CONTACT_INDS = [SMPL_JOINTS[joint] for joint in CONTACT_JOINTS] FEET_JOINTS = [ "leftToeBase", "rightToeBase", ] FEET_INDS = [SMPL_JOINTS[joint] for joint in FEET_JOINTS] # chosen virtual mocap markers that are "keypoints" to work with KEYPT_VERTS = [ 4404, 920, 3076, 3169, 823, 4310, 1010, 1085, 4495, 4569, 6615, 3217, 3313, 6713, 6785, 3383, 6607, 3207, 1241, 1508, 4797, 4122, 1618, 1569, 5135, 5040, 5691, 5636, 5404, 2230, 2173, 2108, 134, 3645, 6543, 3123, 3024, 4194, 1306, 182, 3694, 4294, 744, ] """ Openpose """ OP_NUM_JOINTS = 25 # OP_IGNORE_JOINTS = [1, 9, 12] # neck and left/right hip OP_IGNORE_JOINTS = [1] # neck OP_EDGE_LIST = [ [1, 8], [1, 2], [1, 5], [2, 3], [3, 4], [5, 6], [6, 7], [8, 9], [9, 10], [10, 11], [8, 12], [12, 13], [13, 14], [1, 0], [0, 15], [15, 17], [0, 16], [16, 18], [14, 19], [19, 20], [14, 21], [11, 22], [22, 23], [11, 24], ] # indices to map an openpose detection to its flipped version OP_FLIP_MAP = [ 0, 1, 5, 6, 7, 2, 3, 4, 8, 12, 13, 14, 9, 10, 11, 16, 15, 18, 17, 22, 23, 24, 19, 20, 21, ] # From https://github.com/vchoutas/smplify-x/blob/master/smplifyx/utils.py # Please see license for usage restrictions. def smpl_to_openpose( model_type="smplh", use_hands=False, use_face=False, use_face_contour=False, openpose_format="coco25", ): """Returns the indices of the permutation that maps SMPL to OpenPose Parameters ---------- model_type: str, optional The type of SMPL-like model that is used. The default mapping returned is for the SMPLX model use_hands: bool, optional Flag for adding to the returned permutation the mapping for the hand keypoints. Defaults to True use_face: bool, optional Flag for adding to the returned permutation the mapping for the face keypoints. Defaults to True use_face_contour: bool, optional Flag for appending the facial contour keypoints. Defaults to False openpose_format: bool, optional The output format of OpenPose. For now only COCO-25 and COCO-19 is supported. Defaults to 'coco25' """ if openpose_format.lower() == "coco25": if model_type == "smpl": return np.array( [ 24, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4, 7, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, ], dtype=np.int32, ) elif model_type == "smplh": body_mapping = np.array( [ 52, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4, 7, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, ], dtype=np.int32, ) mapping = [body_mapping] if use_hands: lhand_mapping = np.array( [ 20, 34, 35, 36, 63, 22, 23, 24, 64, 25, 26, 27, 65, 31, 32, 33, 66, 28, 29, 30, 67, ], dtype=np.int32, ) rhand_mapping = np.array( [ 21, 49, 50, 51, 68, 37, 38, 39, 69, 40, 41, 42, 70, 46, 47, 48, 71, 43, 44, 45, 72, ], dtype=np.int32, ) mapping += [lhand_mapping, rhand_mapping] return np.concatenate(mapping) # SMPLX elif model_type == "smplx": body_mapping = np.array( [ 55, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4, 7, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, ], dtype=np.int32, ) mapping = [body_mapping] if use_hands: lhand_mapping = np.array( [ 20, 37, 38, 39, 66, 25, 26, 27, 67, 28, 29, 30, 68, 34, 35, 36, 69, 31, 32, 33, 70, ], dtype=np.int32, ) rhand_mapping = np.array( [ 21, 52, 53, 54, 71, 40, 41, 42, 72, 43, 44, 45, 73, 49, 50, 51, 74, 46, 47, 48, 75, ], dtype=np.int32, ) mapping += [lhand_mapping, rhand_mapping] if use_face: # end_idx = 127 + 17 * use_face_contour face_mapping = np.arange( 76, 127 + 17 * use_face_contour, dtype=np.int32 ) mapping += [face_mapping] return np.concatenate(mapping) else: raise ValueError("Unknown model type: {}".format(model_type)) elif openpose_format == "coco19": if model_type == "smpl": return np.array( [24, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4, 7, 25, 26, 27, 28], dtype=np.int32, ) elif model_type == "smplh": body_mapping = np.array( [52, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4, 7, 53, 54, 55, 56], dtype=np.int32, ) mapping = [body_mapping] if use_hands: lhand_mapping = np.array( [ 20, 34, 35, 36, 57, 22, 23, 24, 58, 25, 26, 27, 59, 31, 32, 33, 60, 28, 29, 30, 61, ], dtype=np.int32, ) rhand_mapping = np.array( [ 21, 49, 50, 51, 62, 37, 38, 39, 63, 40, 41, 42, 64, 46, 47, 48, 65, 43, 44, 45, 66, ], dtype=np.int32, ) mapping += [lhand_mapping, rhand_mapping] return np.concatenate(mapping) # SMPLX elif model_type == "smplx": body_mapping = np.array( [55, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4, 7, 56, 57, 58, 59], dtype=np.int32, ) mapping = [body_mapping] if use_hands: lhand_mapping = np.array( [ 20, 37, 38, 39, 60, 25, 26, 27, 61, 28, 29, 30, 62, 34, 35, 36, 63, 31, 32, 33, 64, ], dtype=np.int32, ) rhand_mapping = np.array( [ 21, 52, 53, 54, 65, 40, 41, 42, 66, 43, 44, 45, 67, 49, 50, 51, 68, 46, 47, 48, 69, ], dtype=np.int32, ) mapping += [lhand_mapping, rhand_mapping] if use_face: face_mapping = np.arange( 70, 70 + 51 + 17 * use_face_contour, dtype=np.int32 ) mapping += [face_mapping] return np.concatenate(mapping) else: raise ValueError("Unknown model type: {}".format(model_type)) else: raise ValueError("Unknown joint format: {}".format(openpose_format)) ================================================ FILE: src/egoallo/preprocessing/body_model/utils.py ================================================ from jaxtyping import Float, Int from typing import Tuple, Optional import numpy as np import torch from torch import Tensor import torch.nn.functional as F from torch import Tensor from ..geometry import ( get_rot_rep_shape, convert_rotation, batch_apply_Rt, make_transform, transform_rel_to_global, transform_global_to_rel, ) from .specs import SMPL_JOINTS, smpl_to_openpose, POSE_REFLECT_PERM from .skeleton import joint_angles_glob_to_rel __all__ = [ "run_smpl", "reflect_pose_aa", "reflect_root_trajectory", "forward_kinematics", "inverse_kinematics", "smpl_local_to_global", "select_smpl_joints", "get_openpose_from_smpl", "convert_local_pose_to_aa", "convert_global_pose_to_aa", "load_beta_conversion", "convert_model_betas", ] def run_smpl( body_model, mats_in: bool = False, return_verts: bool = True, **kwargs ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: """ helper function for running SMPL model with multiple leading dimensions return joints and optionally verts and faces :param body_model :param return_verts (optional bool=True) """ device = body_model.bm.shapedirs.device dims = (body_model.batch_size,) fields = ["pose_body", "root_orient", "trans", "betas"] dim_idcs = [-3 if mats_in else -1, -2 if mats_in else -1, -1, -1] for name, idx in zip(fields, dim_idcs): if name in kwargs: x = kwargs[name] if x is None: continue dims, sh = x.shape[:idx], x.shape[idx:] kwargs[name] = x.reshape(-1, *sh).to(device) if not return_verts: joints = body_model.forward_joints(**kwargs).reshape(*dims, -1, 3) return joints, None, None smpl = body_model(**kwargs) joints = smpl.Jtr.reshape(*dims, -1, 3) verts = smpl.v.reshape(*dims, -1, 3) return joints, verts, smpl.f def reflect_pose_aa(root_orient: Tensor, pose_body: Tensor): """ :param root_orient (*, 3) :param pose_body (*, (J-1)*3) return reflected root_orient and pose_body """ pose_full = torch.cat([root_orient, pose_body], dim=-1) # (*, J*3) pose_reflect = pose_full[..., POSE_REFLECT_PERM] pose_reflect[..., 1::3] = -pose_reflect[..., 1::3] pose_reflect[..., 2::3] = -pose_reflect[..., 2::3] return pose_reflect[..., :3], pose_reflect[..., 3:] def reflect_root_trajectory( rot_aa: Tensor, trans: Tensor, rot_aa_r: Tensor, root_loc: Tensor ) -> Tuple[Tensor, Tensor]: # rotation from t to world R_wt = convert_rotation(rot_aa, "aa", "mat") # get the transforms of the root in the world T_wt = make_transform(R_wt, trans + root_loc) # transform from t to previous T_pt = transform_global_to_rel(T_wt) # rotation from reflected t to world R_wtr = convert_rotation(rot_aa_r, "aa", "mat") # relative transforms R_prtr = transform_global_to_rel(R_wtr) # get the displacement between t and t-1 in t, the SOURCE frame # t_prt = R_prtr * t_tr, where t_tr = (-1, 1, 1) * t_t, and t_t = R_tp * t_pt t_tt = torch.einsum("tij,tj->ti", torch.linalg.inv(T_pt[:, :3, :3]), T_pt[:, :3, 3]) # reflect through x t_trtr = torch.cat([-t_tt[..., :1], t_tt[..., 1:]], dim=-1) # convert displacement into the t-1 TARGET frame t_prtr = torch.einsum("tij,tj->ti", R_prtr, t_trtr) # get back global trajectory T_prtr = make_transform(R_prtr, t_prtr) T_wtr = transform_rel_to_global(T_prtr) # get back the smpl translation trans_wtr = T_wtr[:, :3, 3] - root_loc return rot_aa_r, trans_wtr def forward_kinematics( rot_mats: Tensor, joints_in: Tensor, parents: Tensor, ) -> Tuple[Tensor, Tensor]: """ get the forward transformed joints very similar to smplx's batch_rigid_transform with more flexible batch dimensions :param rot_mats (*, J, 3, 3) joint rotations from joint i to parent :param joints_in (*, J, 3) :param parents (J) returns (*, J, 4, 4) tensor of transforms """ J = len(parents) joints_body_rel = ( joints_in[..., 1:J, :] - joints_in[..., parents[1:], :] ) # (*, J-1, 3) joints_rel = torch.cat( [joints_in[..., :1, :], joints_body_rel], dim=-2 ) # (*, J, 3) T_pi = make_transform(rot_mats, joints_rel) # (*, J, 4, 4) tforms_wp = [T_pi[..., 0, :, :]] for i in range(1, J): tforms_wp.append(torch.matmul(tforms_wp[parents[i]], T_pi[..., i, :, :])) transforms = torch.stack(tforms_wp, dim=-3) joints_posed = transforms[..., :3, 3] # (*, J, 3) rel_trans_h = F.pad( joints_posed - torch.einsum("...ij,...j->...i", transforms[..., :3, :3], joints_in), [0, 1], value=1.0, ).unsqueeze(-1) rel_transforms = torch.cat([transforms[..., :3], rel_trans_h], dim=-1) return joints_posed, rel_transforms def get_pose_offsets( pose_mats: Float[Tensor, "*batch J 3 3"], posedirs: Float[Tensor, "P N"], ) -> Float[Tensor, "*batch J 3"]: dims = pose_mats.shape[:-3] I = torch.eye(3, device=pose_mats.device).reshape(*(1,) * len(dims), 1, 3, 3) pose_feat = (pose_mats[..., 1:, :, :] - I).reshape(*dims, -1) # (*, P) return torch.einsum("...p,pn->...n", pose_feat, posedirs).reshape(*dims, -1, 3) def select_vert_params( idcs: Int[Tensor, "S"], v_template: Float[Tensor, "V 3"], shapedirs: Float[Tensor, "V 3 B"], posedirs: Float[Tensor, "P N"], lbs_weights: Float[Tensor, "V J"], ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: pose_idcs = torch.repeat_interleave(3 * idcs, 3, -1) pose_idcs[1::3] += 1 pose_idcs[2::3] += 2 return v_template[idcs], shapedirs[idcs], posedirs[:, pose_idcs], lbs_weights[idcs] def get_verts_with_transforms( betas: Float[Tensor, "*batch B"], pose_mats: Float[Tensor, "*batch J 3 3"], rel_transforms: Float[Tensor, "*batch J 4 4"], v_template: Float[Tensor, "V 3"], shapedirs: Float[Tensor, "V 3 B"], posedirs: Float[Tensor, "P N"], lbs_weights: Float[Tensor, "V J"], ): # (*, V, 3) v_shaped = v_template + torch.einsum("...l,mkl->...mk", betas, shapedirs) v_posed = v_shaped + get_pose_offsets(pose_mats, posedirs) T = torch.einsum("ij,...jkl->...ikl", lbs_weights, rel_transforms) # (*, V, 4, 4) v_out = torch.einsum("...ij,...j->...i", T[..., :3, :3], v_posed) + T[..., :3, 3] return v_out def inverse_kinematics(rot_mats: Tensor, joints: Tensor, parents: Tensor) -> Tensor: """ given the joint rotations and locations of a posed skeleton, invert and get the template skeleton :param rot_mats (*, J, 3, 3) rotation from joint i to parent (R_pi) :param joints (*, J, 3) posed joint locations :param parents (J) returns (*, J, 3) template joints """ # J = len(parents) J = joints.shape[-2] # delta between joint and parent in the world delta_w = torch.cat( [joints[..., :1, :], joints[..., 1:J, :] - joints[..., parents[1:J], :]], dim=-2 ) # rot mats from parent to joint i rots_ip = rot_mats.transpose(-1, -2) # get the world to parent rotation matrices rots_pw = [rots_ip[..., 0, :, :]] trans_p = [joints[..., 0, :]] for i in range(1, J): # R_iw = R_ip R_pw = (R_pi.T) R_pw R_pw = rots_pw[parents[i]] delta_p = torch.einsum("...ij,...j->...i", R_pw, delta_w[..., i, :]) trans_p.append(trans_p[parents[i]] + delta_p) if i >= J - 1: break rots_pw.append(torch.matmul(rots_ip[..., i, :, :], R_pw)) return torch.stack(trans_p, dim=-2) def smpl_local_to_global( R_root: Tensor, t_root: Tensor, points_l: Tensor, root_l: Tensor ) -> Tensor: """ transform local smpl body to global :param T_root (*, 4, 4) root transform from local to world :param points_l (*, N, 3) points to transform :param root_l (*, 1, 3) root in local coordinates """ return batch_apply_Rt(R_root, t_root, points_l - root_l) + root_l def select_smpl_joints(joints_full): """ select the first 22 smpl joints from the full joints :param joints_full (*, J, 3) """ return joints_full[..., : len(SMPL_JOINTS), :] def get_openpose_from_smpl(joints_smpl, model_type="smplh"): smpl2op_map = smpl_to_openpose( model_type, use_hands=False, use_face=False, use_face_contour=False, openpose_format="coco25", ) joints3d_op = joints_smpl[..., smpl2op_map, :] # hacky way to get hip joints that align with ViTPose keypoints # this could be moved elsewhere in the future (and done properly) joints3d_op[..., [9, 12], :] = ( joints3d_op[..., [9, 12], :] + 0.25 * (joints3d_op[..., [9, 12], :] - joints3d_op[..., [12, 9], :]) + 0.5 * ( joints3d_op[..., [8], :] - 0.5 * (joints3d_op[..., [9, 12], :] + joints3d_op[..., [12, 9], :]) ) ) return joints3d_op def convert_local_pose_to_aa(pose_body: Tensor, rot_rep: str): """ convert local pose in rotation representation into flattened axis-angle :param pose_body (*, J*D) :param rot_rep (str) returns (*, J*3) flattened aa pose """ if rot_rep == "aa": return pose_body dims = pose_body.shape[:-1] rot_sh = get_rot_rep_shape(rot_rep) pose_aa = convert_rotation( pose_body.reshape(*dims, -1, *rot_sh), rot_rep, "aa" ) # (*, J, 3) return pose_aa.reshape(*dims, -1) def convert_global_pose_to_aa(pose_glob: Tensor, rot_rep: str): """ :param pose_glob (*, J*D) :param rot_rep (str) returnns (*, J*3) local pose flattened aa """ dims = pose_glob.shape[:-1] rot_sh = get_rot_rep_shape(rot_rep) pose_glob_mat = convert_rotation( pose_glob.reshape(*dims, -1, *rot_sh), rot_rep, "mat" ) pose_rel_mat = joint_angles_glob_to_rel(pose_glob_mat) # (*, J, 3, 3) return convert_rotation(pose_rel_mat, "mat", "aa").reshape(*dims, -1) def load_beta_conversion(path: str) -> Tuple[Tensor, Tensor]: data = np.load(path) return torch.from_numpy(data["A"].astype("float32")), torch.from_numpy( data["b"].astype("float32") ) def convert_model_betas(beta: Tensor, A: Tensor, b: Tensor) -> Tensor: """ :param beta (*, B) :param A (B, B) :param b (B) beta_neutral = A @ beta_gender + b """ *dims, B = beta.shape A = A.reshape((*(1,) * len(dims), *A.shape)) b = b.reshape((*(1,) * len(dims), *b.shape)) return torch.einsum("...ij,...j->...i", A, beta) + b ================================================ FILE: src/egoallo/preprocessing/geometry/__init__.py ================================================ from .rotation import * from .helpers import * from . import plane from . import camera from . import transforms ================================================ FILE: src/egoallo/preprocessing/geometry/camera.py ================================================ from typing import Tuple import torch import numpy as np # import lietorch as tf from . import transforms as tf from .helpers import batch_apply_Rt def project_from_world(X_w, R_cw, t_cw, intrins): """ :param X_w (*, N, 3) :param cam_R (*, 3, 3) :param cam_t (*, 3) :param intrins (*, 4) """ return proj_2d(batch_apply_Rt(R_cw, t_cw, X_w), intrins) def proj_2d(xyz, intrins, eps=1e-4): """ :param xyz (*, 3/4) 3d/4d point in camera coordinates :param intrins (*, 4) fx, fy, cx, cy return (*, 2) of reprojected points, (*) of points in front of camera """ z = xyz[..., 2:3] valid_mask = z > eps disp = torch.where(valid_mask, 1.0 / (z + eps), torch.ones_like(z)) focal = intrins[..., :2] center = intrins[..., 2:] return focal * disp * xyz[..., :2] + center, valid_mask[..., 0] def proj_h(xyzw): """ project homogeneous point """ w = xyzw[..., -1:] return xyzw[..., :-1] * torch.where(w > 0, 1.0 / w, w) def iproj_depth(uv, z, intrins): """ inverse project into 3d coords from depth :param uv (*, 2) :param z (*, 1) :param intrins (*, 4) :returns (*, 3) """ focal = intrins[..., :2] center = intrins[..., 2:] return z * torch.cat([(uv - center) / focal, torch.ones_like(z)], dim=-1) def iproj(uv, disp, intrins): """ inverse project from disparity. returns 4d homogeneous :param uv (*, 2) :param disp (*, 1) :param intrins (*, 4) :returns (*, 4) """ x = normalize_coords(uv, intrins) X = torch.cat([x, torch.ones_like(disp), disp], dim=-1) return X def normalize_coords(uv, intrins): focal = intrins[..., :2] center = intrins[..., 2:] return (uv - center) / focal def iproj_to_world(uv, disp, intrins, extrins, ret_3d=True): """ inverse project disparity into world coords. default returns 3d :param uv (*, 2) :param disp (*, 1) :param intrins (*, 4) :param extrins (*, 7) :param ret_3d (optional bool) return in 3d coords, default True :returns (*, 3) """ T_wc = tf.SE3(extrins).inv() X_c = iproj(uv, disp, intrins) X_w = T_wc.act(X_c) if ret_3d: return proj_h(X_w) return X_w def reproject(pose_params, intrins, disps, uv, ii, jj): """ :param pose_params (T, *, 7) pose parameters :param intrins (T, *, 4) fx, fy, cx, cy :param uv (T, *, 2) coordinate grid :param disps (T, *, 1) disparity :param ii (N) source index array into parameters :param jj (N) target index array into parameters returns (N, *, 2) points in ii reprojected into jj """ T_i, T_j = tf.SE3(pose_params[ii]), tf.SE3(pose_params[jj]) Xh_i = iproj(uv, disps[ii], intrins[ii]) Xh_j = T_j.mul(T_i.inv()).act(Xh_i) return proj_2d(Xh_j, intrins[jj]) def proj_2d_jac(X, intrins): """ :param X (*, 4) point in camera coordinates :param intrins (*, 4) fx, fy, cx, cy return (*, 2, 4) """ fx, fy, cx, cy = intrins.unbind(dim=-1) X, Y, Z, D = X.unbind(dim=-1) d = torch.where(Z > 0.1, 1.0 / Z, torch.ones_like(Z)) o = torch.zeros_like(d) return torch.stack( [fx * d, o, -fx * X * d * d, o, o, fy * d, -fy * Y * d * d, o], dim=-1, ).reshape(*d.shape, 2, 4) def actp_jac(X1): """ :param X1 (*, 4) point after transformation """ x, y, z, d = X1.unbind(dim=-1) o = torch.zeros_like(d) return torch.stack( [d, o, o, o, z, -y, o, d, o, -z, o, x, o, o, d, y, -x, o, o, o, o, o, o, o], dim=-1, ).reshape(*d.shape, 4, 6) def iproj_jac(X): """ jacobian for inverse projection to 4d """ J = torch.zeros_like(X) J[..., -1] = 1 return J def make_homogeneous(x): """ :param x (*, 3) returns x in homogeneous coordinates """ return torch.cat([x, torch.ones_like(x[..., :1])], dim=-1) def make_transform(R, t): """ :param R (*, 3, 3) :param t (*, 3) return (*, 4, 4) """ dims = R.shape[:-2] bottom = ( torch.tensor([0, 0, 0, 1], dtype=R.dtype, device=R.device) .reshape(*(1,) * len(dims), 1, 4) .repeat(*dims, 1, 1) ) return torch.cat([torch.cat([R, t.unsqueeze(-1)], dim=-1), bottom], dim=-2) def focal2fov(focal, R): """ :param focal, focal length :param R, either W / 2 or H / 2 """ return 2 * np.arctan(R / focal) def fov2focal(fov, R): """ :param fov, field of view in radians :param R, either W / 2 or H / 2 """ return R / np.tan(fov / 2) def lookat_matrix(source_pos, target_pos, up): """ uses x right y down z forward opencv convention :param source_pos (*, 3) :param target_pos (*, 3) :param up (3,) """ *dims, _ = source_pos.shape up = up.reshape(*(1,) * len(dims), 3) up = up / torch.linalg.norm(up, dim=-1, keepdim=True) back = normalize(source_pos - target_pos) right = normalize(torch.linalg.cross(up, back)) up = normalize(torch.linalg.cross(back, right)) R = torch.stack([right, -up, -back], dim=-1) return make_transform(R, source_pos) def normalize(x): return x / torch.linalg.norm(x, dim=-1, keepdim=True) def view_matrix(z, up, pos): """ :param z (*, 3) up (*, 3) pos (*, 3) returns (*, 4, 4) """ *dims, _ = z.shape x = normalize(torch.linalg.cross(up, z)) y = normalize(torch.linalg.cross(z, x)) bottom = ( torch.tensor([0, 0, 0, 1], dtype=torch.float32) .reshape(*(1,) * len(dims), 1, 4) .expand(*dims, 1, 4) ) return torch.cat([torch.stack([x, y, z, pos], dim=-1), bottom], dim=-2) def average_pose(poses): """ :param poses (N, 4, 4) returns average pose (4, 4) """ center = poses[:, :3, 3].mean(0) up = normalize(poses[:, :3, 1].sum(0)) z = normalize(poses[:, :3, 2].sum(0)) return view_matrix(z, up, center) def make_translation(t): return make_transform(torch.eye(3, device=t.device), t) def make_rotation(rx=0, ry=0, rz=0, order="xyz"): Rx = rotx(rx) Ry = roty(ry) Rz = rotz(rz) if order == "xyz": R = Rz @ Ry @ Rx elif order == "xzy": R = Ry @ Rz @ Rx elif order == "yxz": R = Rz @ Rx @ Ry elif order == "yzx": R = Rx @ Rz @ Ry elif order == "zyx": R = Rx @ Ry @ Rz elif order == "zxy": R = Ry @ Rx @ Rz else: raise NotImplementedError return make_transform(R, torch.zeros(3)) def rotx(theta): return torch.tensor( [ [1, 0, 0], [0, np.cos(theta), -np.sin(theta)], [0, np.sin(theta), np.cos(theta)], ], dtype=torch.float32, ) def roty(theta): return torch.tensor( [ [np.cos(theta), 0, np.sin(theta)], [0, 1, 0], [-np.sin(theta), 0, np.cos(theta)], ], dtype=torch.float32, ) def rotz(theta): return torch.tensor( [ [np.cos(theta), -np.sin(theta), 0], [np.sin(theta), np.cos(theta), 0], [0, 0, 1], ], dtype=torch.float32, ) def identity(shape: Tuple, d=4, **kwargs): I = torch.eye(d, **kwargs) return I.reshape(*(1,) * len(shape), d, d).repeat(*shape, 1, 1) ================================================ FILE: src/egoallo/preprocessing/geometry/helpers.py ================================================ import torch import numpy as np from .rotation import convert_rotation def make_transform(R, t): """ :param R (*, 3, 3) :param t (*, 3) """ dims = R.shape[:-2] pose_3x4 = torch.cat([R, t.view(*dims, 3, 1)], dim=-1) bottom = ( torch.tensor([0, 0, 0, 1], device=R.device) .reshape(*(1,) * len(dims), 1, 4) .expand(*dims, 1, 4) ) return torch.cat([pose_3x4, bottom], dim=-2) def transform_points(T, x): """ :param T (*, 4, 4) :param x (*, N, 3) """ R = T[..., :3, :3] t = T[..., :3, 3] return batch_apply_Rt(R, t, x) def batch_apply_Rt(R, t, x): """ :param R (*, 3, 3) :param t (*, 3) :param x (*, N, 3) """ return torch.einsum("...ij,...nj->...ni", R, x) + t.unsqueeze(-2) def transform_global_to_rel(T_glob): """ get the relative transforms (diffs) from global transform of trajectory :param T_glob (*, T, 4, 4) root to world transform return root t->t-1 transform (*, T, 4, 4) """ T_rel = torch.matmul( torch.linalg.inv(T_glob[..., :-1, :, :]), T_glob[..., 1:, :, :] ) # (*, T-1, 4, 4) return torch.cat([T_glob[..., :1, :, :], T_rel], dim=-3) def transform_rel_to_global(T_rel): """ convert relative transforms into global trajectory :param T_rel (*, T, 4, 4) root t -> t-1 transform return root t -> world transform (*, T, 4, 4) """ N = T_rel.shape[-3] T_rel_list = T_rel.unbind(dim=-3) T_glob_list = [T_rel_list[0]] for t in range(1, N): T_cur = torch.matmul(T_glob_list[t - 1], T_rel_list[t]) T_glob_list.append(T_cur) return torch.stack(T_glob_list, dim=-3) def RT_global_to_rel(R_glob, t_glob): """ :param R_glob (*, T, 3, 3) root to world rotation :param t_glob (*, T, 3) root to world translation returns root t -> t-1 rotation (*, T, 3, 3) and translation (*, T, 3) """ T_glob = make_transform(R_glob, t_glob) # (*, T, 4, 4) root to world T_rel = transform_global_to_rel(T_glob) return T_rel[..., :3, :3], T_rel[..., :3, 3] def RT_rel_to_global(R_rel, t_rel): """ :param R_rel (*, T, 3, 3) root at t -> root at t-1 rotation :param t_rel (*, T, 3) root at t -> root at t-1 translation return root to world rotation (*, T, 3, 3) and translation (*, T, 3) """ T_rel = make_transform(R_rel, t_rel) # (*, T, 4, 4) T_glob = transform_rel_to_global(T_rel) return T_glob[..., :3, :3], T_glob[..., :3, 3] def joints_local_to_global( root_orient, trans, joints_loc, use_rel: bool = True, rot_rep: str = "6d" ): """ convert joints in local coords to global coordinates (X_w - root) = T_wl * (X_l - root) :param trans (*, T, 3) :param root_orient (*, T, *rot_shape) :param joints_loc (*, T, J * 3) :param use_rel (optional bool) if true, root trajectory specified as relative transforms returns global joint locations (*, T, J, 3) """ root_orient_mat = convert_rotation(root_orient, rot_rep, "mat") # (B, T, 3, 3) T_wl = make_transform(root_orient_mat, trans) if use_rel: # global translation and orientation are in diffs T_wl = transform_rel_to_global(T_wl) joints_loc = joints_loc.reshape(*trans.shape[:-1], -1, 3) root_loc = joints_loc[..., :1, :] return transform_points(T_wl, joints_loc - root_loc) + root_loc def joints_global_to_local(root_orient_mat, trans, joints_glob, joints_vel_glob=None): """ convert joints in global coords to local coords i.e. smpl output with zero root_orient and trans (X_w - root) = T_wl * (X_l - root) :param trans (*, 3) :param root_orient_mat (*, 3, 3) :param joints_glob (*, J, 3) :param joints_vel_glob (optional) (*, J, 3) returns local joint locations (*, J, 3) """ T_lw = torch.linalg.inv(make_transform(root_orient_mat, trans)) # (*, 4, 4) root_loc = joints_glob[..., :1, :] - trans.unsqueeze(-2) # (*, 1, 3) joints_loc = transform_points(T_lw, joints_glob - root_loc) + root_loc joints_vel_loc = None if joints_vel_glob is not None: # no translation joints_vel_loc = torch.einsum( "...ij,...nj->...ni", T_lw[..., :3, :3], joints_vel_glob ) # (*, J, 3) return joints_loc, joints_vel_loc def align_pcl(Y, X, weight=None, fixed_scale=False): """ align similarity transform to align X with Y using umeyama method X' = s * R * X + t is aligned with Y :param Y (*, N, 3) first trajectory :param X (*, N, 3) second trajectory :param weight (*, N, 1) optional weight of valid correspondences :returns s (*, 1), R (*, 3, 3), t (*, 3) """ *dims, N, _ = Y.shape device = X.device N = torch.ones(*dims, 1, 1, device=device) * N if weight is not None: N = weight.sum(dim=-2, keepdim=True) # (*, 1, 1) # subtract mean my = Y.sum(dim=-2) / N[..., 0] # (*, 3) mx = X.sum(dim=-2) / N[..., 0] y0 = Y - my[..., None, :] # (*, N, 3) x0 = X - mx[..., None, :] if weight is not None: y0 = y0 * weight x0 = x0 * weight # correlation C = torch.matmul(y0.transpose(-1, -2), x0) / N # (*, 3, 3) U, D, Vh = torch.linalg.svd(C) # (*, 3, 3), (*, 3), (*, 3, 3) S = torch.eye(3, device=device).reshape(*(1,) * (len(dims)), 3, 3).repeat(*dims, 1, 1) neg = torch.det(U) * torch.det(Vh.transpose(-1, -2)) < 0 S[neg, 2, 2] = -1 R = torch.matmul(U, torch.matmul(S, Vh)) # (*, 3, 3) D = torch.diag_embed(D) # (*, 3, 3) if fixed_scale: s = torch.ones(*dims, 1, device=device, dtype=torch.float32) else: var = torch.sum(torch.square(x0), dim=(-1, -2), keepdim=True) / N # (*, 1, 1) s = ( torch.diagonal(torch.matmul(D, S), dim1=-2, dim2=-1).sum( dim=-1, keepdim=True ) / var[..., 0] ) # (*, 1) t = my - s * torch.matmul(R, mx[..., None])[..., 0] # (*, 3) return s, R, t def get_translation_scale(fps=30): """ scale relative translation into (m/s), over average walking speed ~1.5 m/s i.e. scale delta such that average walking speed -> 1 """ return 1.0 * fps def estimate_velocity(data_seq, h=1 / 30): """ Given some data sequence of T timesteps in the shape (T, ...), estimates the velocity for the middle T-2 steps using a second order central difference scheme. - h : step size """ data_tp1 = data_seq[2:] data_tm1 = data_seq[0:-2] data_vel_seq = (data_tp1 - data_tm1) / (2 * h) return data_vel_seq def estimate_angular_velocity(rot_seq, h=1 / 30): """ Given a sequence of T rotation matrices, estimates angular velocity at T-2 steps. Input sequence should be of shape (T, ..., 3, 3) """ # see https://en.wikipedia.org/wiki/Angular_velocity#Calculation_from_the_orientation_matrix dRdt = estimate_velocity(rot_seq, h) R = rot_seq[1:-1] RT = np.swapaxes(R, -1, -2) # compute skew-symmetric angular velocity tensor w_mat = np.matmul(dRdt, RT) # pull out angular velocity vector # average symmetric entries w_x = (-w_mat[..., 1, 2] + w_mat[..., 2, 1]) / 2.0 w_y = (w_mat[..., 0, 2] - w_mat[..., 2, 0]) / 2.0 w_z = (-w_mat[..., 0, 1] + w_mat[..., 1, 0]) / 2.0 w = np.stack([w_x, w_y, w_z], axis=-1) return w ================================================ FILE: src/egoallo/preprocessing/geometry/plane.py ================================================ from jaxtyping import Float from typing import Tuple, Optional import torch from torch import Tensor import torch.nn.functional as F from .rotation import axis_angle_to_matrix from .helpers import make_transform def transform_align_body_right(root_orient_mat, trans, **kwargs): """ make the transform aligns body_right (-x) with x axis and moves trans to origin :param root_orient_mat (3, 3) :param trans (3,) """ # move first frame to origin and transform root orient x to [1, 0, 0] R_align_x = rotation_align_body_right(root_orient_mat, **kwargs) # (3, 3) t_align_x = -R_align_x @ trans return make_transform(R_align_x, t_align_x) def rotation_align_body_right( root_orient_mat, up=[0.0, 0.0, 1.0], right=[-1.0, 0.0, 0.0], **kwargs ): """ compute the rotation that aligns local body right vector (-x) with [1, 0, 0] (+x) via rotation about up axis (+z) :param root_orient_mat (*, 3, 3) :param up vector (*, 3) default [0, 0, 1] :param right vector (*, 3) default [1, 0, 0] returns (*, 3, 3) rotation matrix """ root_x = -root_orient_mat[..., 0] nldims = root_x.ndim - 1 up = torch.as_tensor(up, device=root_x.device) right = torch.as_tensor(right, device=root_x.device) if up.ndim < root_x.ndim: up = up.reshape(*(1,) * nldims, 3) if right.ndim < root_x.ndim: right = right.reshape(*(1,) * nldims, 3) # project root_x to floor plane (perpendicular to up) root_x = root_x - project_vector(root_x, up) return rotation_align_vecs(root_x, right) def compute_world2aligned(T_w0, **kwargs): """ compute alignment transform to take T_w0 to aligned frame where body right is -x, and up is as specified (default +z) :param T_w0 (*, 4, 4) return (*, 4, 4) """ R_aw = rotation_align_body_right(T_w0[..., :3, :3], **kwargs) # (*..., 3, 3) t_aw = torch.einsum("...ij,...j->...i", -R_aw, T_w0[..., :3, 3]) T_aw = make_transform(R_aw, t_aw) # (*, 4, 4) return T_aw def rotation_align_vecs(src, target): """ compute rotation taking src to target through the shared plane :param src (*, 3) :param target (*, 3) return (*, 3, 3) rotation matrix """ axis = F.normalize(torch.linalg.cross(src, target), dim=-1) angle = torch.arccos( (src * target).sum(dim=-1) / (src.norm(dim=-1) * target.norm(dim=-1)) ) return axis_angle_to_matrix(axis * angle.unsqueeze(-1)) def compute_point_height(point, floor_plane): """ compute height of point from floor_plane :param point (*, 3) :param floor_plane (*, 3) """ floor_plane_4d = parse_floor_plane(floor_plane) floor_normal = floor_plane_4d[..., :3] # compute the distance from root to ground plane _, s_root = compute_plane_intersection(point, -floor_normal, floor_plane_4d) return s_root def compute_world2floor( floor_plane_4d, root_orient_mat, trans ) -> Tuple[Tensor, Tensor]: """ compute the transform from world frame (opencv +x right, +y down, +z forward), to floor frame (-x body right, +y up, with origin at trans) :param floor_plane (*, 4) floor plane in world coordinates :param root_orient_mat (*, 3, 3) root orientation in world :param trans (*, 3) root trans in world """ floor_normal = floor_plane_4d[..., :3] # compute prior frame axes in the camera frame # right is body +x direction projected to floor plane root_x = root_orient_mat[..., 0] x = F.normalize(root_x - project_vector(root_x, floor_normal), dim=-1) y = floor_normal z = F.normalize(torch.linalg.cross(x, y), dim=-1) # floor frame in world is body x right, floor normal up R_wf = torch.stack([x, y, z], dim=-1) R_fw = torch.linalg.inv(R_wf) t_fw = torch.einsum("...ij,...j->...i", -R_fw, trans) return R_fw, t_fw def compute_plane_transform( plane_4d: Float[Tensor, "*batch 4"], up: Float[Tensor, "*batch 3"], origin: Optional[Float[Tensor, "*batch 3"]] = None, ): """ compute the R and t transform from identity, where plane normal is up """ normal = plane_4d[..., :3] offset = plane_4d[..., 3:] normal = F.normalize(normal, dim=-1) up = F.normalize(up, dim=-1) v = torch.linalg.cross(up, normal) # (*, 3) vnorm = torch.linalg.norm(v, dim=-1, keepdim=True) # (*, 1) s = torch.arcsin(vnorm) / vnorm R = axis_angle_to_matrix(v * s) # (*, 3, 3) if origin is not None: t, _ = compute_plane_intersection(origin, -normal, plane_4d) else: # translate plane along normal vector t = normal * offset # (*, 3) return R, t def fit_plane( points: Float[Tensor, "*batch N 3"], weights: Optional[Float[Tensor, "*batch N 1"]] = None, force_sign: int = -1, ) -> Float[Tensor, "*batch 4"]: """ :param points (*, N, 3) returns (*, 4) plane parameters (returns in (normal, offset) format) """ *dims, _ = points.shape device = points.device if weights is None: weights = torch.ones(*dims, 1, device=device) mean = (weights * points).sum(dim=-2, keepdim=True) / weights.sum( dim=-2, keepdim=True ) # (*, N, 3), (*, 3), (*, 3, 3) _, _, Vh = torch.linalg.svd(weights * (points - mean)) normal = Vh[..., -1, :] # (*, 3) offset = torch.einsum("...ij,...j->...i", points, normal) # (*, N) w = weights[..., 0] # (*, N) offset = ((w * offset).sum(dim=-1) / w.sum(dim=-1)).unsqueeze(-1) # (*, 1) if force_sign != 0: normal, offset = force_plane_direction(normal, offset, sign=force_sign) return torch.cat([normal, offset], dim=-1) def parse_floor_plane(floor_plane: Tensor, force_sign: int = -1) -> Tensor: """ Takes floor plane in the optimization form (Bx3 with a,b,c * d) and parses into (a,b,c,d) from with (a,b,c) normal facing "up in the camera frame and d the offset. """ if floor_plane.shape[-1] == 4: return floor_plane floor_offset = torch.linalg.norm(floor_plane, dim=-1, keepdim=True) floor_normal = floor_plane / (floor_offset + 1e-5) # there's ambiguity in the signs of the normal and offset, # force the sign of the normal to be positive or negative depending # on convention if force_sign != 0: floor_normal, floor_offset = force_plane_direction( floor_normal, floor_offset, sign=force_sign ) return torch.cat([floor_normal, floor_offset], dim=-1) def force_plane_direction( floor_normal: Float[Tensor, "*batch 3"], floor_offset: Float[Tensor, "*batch 1"], sign: int = -1, ) -> Tuple[Float[Tensor, "*batch 3"], Float[Tensor, "*batch 1"]]: assert sign != 0 if sign > 0: mask = floor_normal[..., 1:2] < 0 else: mask = floor_normal[..., 1:2] > 0 floor_normal = torch.where( mask.expand_as(floor_normal), -floor_normal, floor_normal ) floor_offset = torch.where(mask, -floor_offset, floor_offset) return floor_normal, floor_offset def compute_plane_intersection(point, direction, plane, eps=1e-5): """ given a ray defined by a point in space and a direction, compute the intersection point with the given plane. :param point (*, 3) :param direction (*, 3) :param plane (*, 4) (normal, offset) returns: - itsct_pt (*, 3) - s (*, 1) s.t. itsct_pt = point + s * direction """ plane_normal = plane[..., :3] plane_off = plane[..., 3:] s = (plane_off - bdot(plane_normal, point)) / (bdot(plane_normal, direction) + eps) itsct_pt = point + s * direction return itsct_pt, s def project_vector(x, d): """ project x onto d :param x, d (*, 3) """ d = F.normalize(d, dim=-1) return bdot(x, d) * d def bdot(A1, A2, keepdim=True, **kwargs): """ batched dot product :param A1, A2 (*, D) returs (*, 1) """ return (A1 * A2).sum(dim=-1, keepdim=keepdim, **kwargs) ================================================ FILE: src/egoallo/preprocessing/geometry/rotation.py ================================================ from typing import Tuple import torch from torch.nn import functional as F def get_rot_rep_shape(rot_rep:str) -> Tuple: assert rot_rep in ["aa", "quat", "6d", "mat"] if rot_rep == "6d": return (6,) if rot_rep == "aa": return (3,) if rot_rep == "quat": return (4,) return (3, 3) def convert_rotation(rot, src_rep, tgt_rep): src_rep, tgt_rep = src_rep.lower(), tgt_rep.lower() if src_rep == tgt_rep: return rot if src_rep == "aa": if tgt_rep == "mat": return axis_angle_to_matrix(rot) if tgt_rep == "quat": return axis_angle_to_quaternion(rot) if tgt_rep == "6d": return axis_angle_to_cont_6d(rot) raise NotImplementedError if src_rep == "quat": if tgt_rep == "aa": return quaternion_to_axis_angle(rot) if tgt_rep == "mat": return quaternion_to_matrix(rot) if tgt_rep == "6d": return matrix_to_cont_6d(quaternion_to_matrix(rot)) raise NotImplementedError if src_rep == "mat": if tgt_rep == "6d": return matrix_to_cont_6d(rot) if tgt_rep == "aa": return matrix_to_axis_angle(rot) if tgt_rep == "quat": return matrix_to_quaternion(rot) raise NotImplementedError if src_rep == "6d": if tgt_rep == "mat": return cont_6d_to_matrix(rot) if tgt_rep == "aa": return cont_6d_to_axis_angle(rot) if tgt_rep == "quat": return matrix_to_quaternion(cont_6d_to_matrix(rot)) raise NotImplementedError raise NotImplementedError def rodrigues_vec_to_matrix(rot_vecs, dtype=torch.float32): """ Calculates the rotation matrices for a batch of rotation vectors referenced from https://github.com/mkocabas/VIBE/blob/master/lib/utils/geometry.py :param rot_vecs (*, 3) axis-angle vectors :returns rot_mats (*, 3, 3) """ dims = rot_vecs.shape[:-1] # leading dimensions device, dtype = rot_vecs.device, rot_vecs.dtype angle = torch.norm(rot_vecs + 1e-8, dim=-1, keepdim=True) # (*, 1) rot_dir = rot_vecs / angle # (*, 3) cos = torch.unsqueeze(torch.cos(angle), dim=-2) # (*, 1, 1) sin = torch.unsqueeze(torch.sin(angle), dim=-2) # (*, 1, 1) rx, ry, rz = torch.split(rot_dir, 1, dim=-1) # (*, 1) each zeros = torch.zeros(*dims, 1, dtype=dtype, device=device) K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=-1).view( (*dims, 3, 3) ) I = torch.eye(3, dtype=dtype, device=device).reshape(*(1,) * len(dims), 3, 3) rot_mat = I + sin * K + (1 - cos) * torch.einsum("...ij,...jk->...ik", K, K) return rot_mat def matrix_to_axis_angle(matrix): """ Convert rotation matrix to Rodrigues vector """ quaternion = matrix_to_quaternion(matrix) aa = quaternion_to_axis_angle(quaternion) aa[torch.isnan(aa)] = 0.0 return aa def axis_angle_to_matrix(rot_vec): quaternion = axis_angle_to_quaternion(rot_vec) return quaternion_to_matrix(quaternion) def axis_angle_to_cont_6d(rot_vec): """ :param rot_vec (*, 3) :returns 6d vector (*, 6) """ rot_mat = axis_angle_to_matrix(rot_vec) return matrix_to_cont_6d(rot_mat) def matrix_to_cont_6d(matrix): """ :param matrix (*, 3, 3) :returns 6d vector (*, 6) """ return torch.cat([matrix[..., 0], matrix[..., 1]], dim=-1) def cont_6d_to_matrix(cont_6d): """ :param 6d vector (*, 6) :returns matrix (*, 3, 3) """ x1 = cont_6d[..., 0:3] y1 = cont_6d[..., 3:6] x = F.normalize(x1, dim=-1) y = F.normalize(y1 - (y1 * x).sum(dim=-1, keepdim=True) * x, dim=-1) z = torch.linalg.cross(x, y, dim=-1) return torch.stack([x, y, z], dim=-1) def cont_6d_to_axis_angle(cont_6d): rot_mat = cont_6d_to_matrix(cont_6d) return matrix_to_axis_angle(rot_mat) def quaternion_to_axis_angle(quaternion, eps=1e-5): """ This function is borrowed from https://github.com/kornia/kornia Convert quaternion vector to angle axis of rotation. Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h :param quaternion (*, 4) expects WXYZ :returns axis_angle (*, 3) """ # unpack input and compute conversion q1 = quaternion[..., 1] q2 = quaternion[..., 2] q3 = quaternion[..., 3] sin_squared_theta = q1 * q1 + q2 * q2 + q3 * q3 sin_theta = torch.sqrt(sin_squared_theta) cos_theta = quaternion[..., 0] two_theta = 2.0 * torch.where( cos_theta < -eps, torch.atan2(-sin_theta, -cos_theta), torch.atan2(sin_theta, cos_theta), ) k_pos = two_theta / sin_theta k_neg = 2.0 * torch.ones_like(sin_theta) k = torch.where(sin_squared_theta > eps, k_pos, k_neg) axis_angle = torch.zeros_like(quaternion)[..., :3] axis_angle[..., 0] += q1 * k axis_angle[..., 1] += q2 * k axis_angle[..., 2] += q3 * k return axis_angle def quaternion_to_matrix(quaternion): """ Convert a quaternion to a rotation matrix. Taken from https://github.com/kornia/kornia, based on https://github.com/matthew-brett/transforms3d/blob/8965c48401d9e8e66b6a8c37c65f2fc200a076fa/transforms3d/quaternions.py#L101 https://github.com/tensorflow/graphics/blob/master/tensorflow_graphics/geometry/transformation/rotation_matrix_3d.py#L247 :param quaternion (N, 4) expects WXYZ order returns rotation matrix (N, 3, 3) """ # normalize the input quaternion quaternion_norm = F.normalize(quaternion, p=2, dim=-1, eps=1e-8) *dims, _ = quaternion_norm.shape # unpack the normalized quaternion components w, x, y, z = torch.chunk(quaternion_norm, chunks=4, dim=-1) # compute the actual conversion tx = 2.0 * x ty = 2.0 * y tz = 2.0 * z twx = tx * w twy = ty * w twz = tz * w txx = tx * x txy = ty * x txz = tz * x tyy = ty * y tyz = tz * y tzz = tz * z one = torch.tensor(1.0) matrix = torch.stack( ( one - (tyy + tzz), txy - twz, txz + twy, txy + twz, one - (txx + tzz), tyz - twx, txz - twy, tyz + twx, one - (txx + tyy), ), dim=-1, ).view(*dims, 3, 3) return matrix def axis_angle_to_quaternion(axis_angle, eps=1e-5): """ This function is borrowed from https://github.com/kornia/kornia Convert angle axis to quaternion in WXYZ order :param axis_angle (*, 3) :returns quaternion (*, 4) WXYZ order """ theta = torch.linalg.norm(axis_angle, dim=-1, keepdim=True) theta_sq = torch.square(theta) # theta_sq = torch.sum(axis_angle ** 2, dim=-1, keepdim=True) # (*, 1) # theta = torch.sqrt(theta_sq + eps) # need to handle the zero rotation case valid = theta_sq > eps half_theta = 0.5 * theta ones = torch.ones_like(half_theta) # fill zero with the limit of sin ax / x -> a k = torch.where(valid, torch.sin(half_theta) / (theta + eps), 0.5 * ones) w = torch.where(valid, torch.cos(half_theta), ones) quat = torch.cat([w, k * axis_angle], dim=-1) return quat def matrix_to_quaternion(matrix, eps=1e-6): """ This function is borrowed from https://github.com/kornia/kornia Convert rotation matrix to 4d quaternion vector This algorithm is based on algorithm described in https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201 :param matrix (N, 3, 3) """ *dims, m, n = matrix.shape rmat_t = torch.transpose(matrix.reshape(-1, m, n), -1, -2) mask_d2 = rmat_t[:, 2, 2] < eps mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1] mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1] t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2] q0 = torch.stack( [ rmat_t[:, 1, 2] - rmat_t[:, 2, 1], t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0], rmat_t[:, 2, 0] + rmat_t[:, 0, 2], ], -1, ) t0_rep = t0.repeat(4, 1).t() t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2] q1 = torch.stack( [ rmat_t[:, 2, 0] - rmat_t[:, 0, 2], rmat_t[:, 0, 1] + rmat_t[:, 1, 0], t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1], ], -1, ) t1_rep = t1.repeat(4, 1).t() t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2] q2 = torch.stack( [ rmat_t[:, 0, 1] - rmat_t[:, 1, 0], rmat_t[:, 2, 0] + rmat_t[:, 0, 2], rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2, ], -1, ) t2_rep = t2.repeat(4, 1).t() t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2] q3 = torch.stack( [ t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1], rmat_t[:, 2, 0] - rmat_t[:, 0, 2], rmat_t[:, 0, 1] - rmat_t[:, 1, 0], ], -1, ) t3_rep = t3.repeat(4, 1).t() mask_c0 = mask_d2 * mask_d0_d1 mask_c1 = mask_d2 * ~mask_d0_d1 mask_c2 = ~mask_d2 * mask_d0_nd1 mask_c3 = ~mask_d2 * ~mask_d0_nd1 mask_c0 = mask_c0.view(-1, 1).type_as(q0) mask_c1 = mask_c1.view(-1, 1).type_as(q1) mask_c2 = mask_c2.view(-1, 1).type_as(q2) mask_c3 = mask_c3.view(-1, 1).type_as(q3) q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3 q /= torch.sqrt( t0_rep * mask_c0 + t1_rep * mask_c1 + t2_rep * mask_c2 # noqa + t3_rep * mask_c3 ) # noqa q *= 0.5 return q.reshape(*dims, 4) def quaternion_mul(q0, q1): """ EXPECTS WXYZ :param q0 (*, 4) :param q1 (*, 4) """ r0, r1 = q0[..., :1], q1[..., :1] v0, v1 = q0[..., 1:], q1[..., 1:] r = r0 * r1 - (v0 * v1).sum(dim=-1, keepdim=True) v = r0 * v1 + r1 * v0 + torch.linalg.cross(v0, v1) return torch.cat([r, v], dim=-1) def quaternion_inverse(q, eps=1e-5): """ EXPECTS WXYZ :param q (*, 4) """ conj = torch.cat([q[..., :1], -q[..., 1:]], dim=-1) mag = torch.square(q).sum(dim=-1, keepdim=True) + eps return conj / mag def quaternion_slerp(t, q0, q1, eps=1e-5): """ :param t (*, 1) must be between 0 and 1 :param q0 (*, 4) :param q1 (*, 4) """ dims = q0.shape[:-1] t = t.view(*dims, 1) q0 = F.normalize(q0, p=2, dim=-1) q1 = F.normalize(q1, p=2, dim=-1) dot = (q0 * q1).sum(dim=-1, keepdim=True) # make sure we give the shortest rotation path (< 180d) neg = dot < -eps q1 = torch.where(neg, -q1, q1) dot = torch.where(neg, -dot, dot) angle = torch.acos(dot) # if angle is too small, just do linear interpolation collin = torch.abs(dot) > 1 - eps fac = 1 / torch.sin(angle) w0 = torch.where(collin, 1 - t, torch.sin((1 - t) * angle) * fac) w1 = torch.where(collin, t, torch.sin(t * angle) * fac) slerp = q0 * w0 + q1 * w1 return slerp ================================================ FILE: src/egoallo/preprocessing/geometry/transforms/__init__.py ================================================ """Lie group interface for rigid transforms, ported from [jaxlie](https://github.com/brentyi/jaxlie). Used by `viser` internally and in examples. Implements SO(2), SO(3), SE(2), and SE(3) Lie groups. Rotations are parameterized via S^1 and S^3. """ from ._base import MatrixLieGroup as MatrixLieGroup from ._base import SEBase as SEBase from ._base import SOBase as SOBase from ._se2 import SE2 as SE2 from ._se3 import SE3 as SE3 from ._so2 import SO2 as SO2 from ._so3 import SO3 as SO3 ================================================ FILE: src/egoallo/preprocessing/geometry/transforms/_base.py ================================================ import abc from typing import ClassVar, Generic, Type, TypeVar, Union, overload, Optional, Tuple import torch from typing_extensions import final, override from . import hints GroupType = TypeVar("GroupType", bound="MatrixLieGroup") SEGroupType = TypeVar("SEGroupType", bound="SEBase") class MatrixLieGroup(abc.ABC): """Interface definition for matrix Lie groups.""" # Class properties. # > These will be set in `_utils.register_lie_group()`. matrix_dim: ClassVar[int] """Dimension of square matrix output from `.matrix()`.""" parameters_dim: ClassVar[int] """Dimension of underlying parameters, `.parameters()`.""" tangent_dim: ClassVar[int] """Dimension of tangent space.""" space_dim: ClassVar[int] """Dimension of coordinates that can be transformed.""" def __init__(self, parameters: torch.Tensor): """ Construct a group object from its underlying parameters. Notes: - For the constructor signature to be consistent with subclasses, `parameters` should be marked as positional-only. But this isn't possible in Python 3.7. - This method is implicitly overriden by the dataclass decorator and should _not_ be marked abstract. """ raise NotImplementedError() # Shared implementations. @overload def __mul__(self: GroupType, other: GroupType) -> GroupType: ... @overload def __mul__(self, other: hints.Array) -> torch.Tensor: ... def __mul__( self: GroupType, other: Union[GroupType, hints.Array] ) -> Union[GroupType, torch.Tensor]: """Overload for the `@` operator. Switches between the group action (`.act()`) and multiplication (`.mul()`) based on the type of `other`. """ if isinstance(other, hints.Array): return self.act(target=other) elif isinstance(other, MatrixLieGroup): assert self.space_dim == other.space_dim return self.mul(other=other) else: assert False, f"Invalid argument type for `@` operator: {type(other)}" # Factory. @classmethod @abc.abstractmethod def Identity( cls: Type[GroupType], shape: Optional[Tuple] = (), **kwargs ) -> GroupType: """Returns identity element. Returns: Identity element. """ @classmethod @abc.abstractmethod def from_matrix(cls: Type[GroupType], matrix: hints.Array) -> GroupType: """Get group member from matrix representation. Args: matrix: Matrix representaiton. Returns: Group member. """ # Accessors. @abc.abstractmethod def matrix(self) -> torch.Tensor: """Get transformation as a matrix. Homogeneous for SE groups.""" @abc.abstractmethod def parameters(self) -> torch.Tensor: """Get underlying representation.""" @property def data(self) -> torch.Tensor: return self.parameters() def __getitem__(self, index): return self.__class__(self.data[index]) @property def shape(self): return self.data.shape # Operations. @abc.abstractmethod def act(self, target: hints.Array) -> torch.Tensor: """Applies group action to a point. Args: target: Point to transform. Returns: Transformed point. """ @abc.abstractmethod def mul(self: GroupType, other: GroupType) -> GroupType: """Composes this transformation with another. Returns: self @ other """ @classmethod @abc.abstractmethod def exp(cls: Type[GroupType], tangent: hints.Array) -> GroupType: """Computes `expm(wedge(tangent))`. Args: tangent: Tangent vector to take the exponential of. Returns: Output. """ @abc.abstractmethod def log(self) -> torch.Tensor: """Computes `vee(logm(transformation matrix))`. Returns: Output. Shape should be `(tangent_dim,)`. """ @abc.abstractmethod def adjoint(self, **kwargs) -> torch.Tensor: """Computes the adjoint, which transforms tangent vectors between tangent spaces. More precisely, for a transform `GroupType`: ``` GroupType @ exp(omega) = exp(Adj_T @ omega) @ GroupType ``` used for e.g. transforming twists, wrenches, and Jacobians across different reference frames. Returns: Output. Shape should be `(tangent_dim, tangent_dim)`. """ @abc.abstractmethod def inv(self: GroupType) -> GroupType: """Computes the inv of our transform. Returns: Output. """ @abc.abstractmethod def normalize(self: GroupType) -> GroupType: """Normalize/projects values and returns. Returns: GroupType: Normalized group member. """ class SOBase(MatrixLieGroup): """Base class for special orthogonal groups.""" ContainedSOType = TypeVar("ContainedSOType", bound=SOBase) class SEBase(Generic[ContainedSOType], MatrixLieGroup): """Base class for special Euclidean groups. Each SE(N) group member contains an SO(N) rotation, as well as an N-dimensional translation vector. """ # SE-specific interface. @classmethod @abc.abstractmethod def from_rotation_and_translation( cls: Type[SEGroupType], rotation: ContainedSOType, translation: hints.Array, ) -> SEGroupType: """Construct a rigid transform from a rotation and a translation. Args: rotation: Rotation term. translation: translation term. Returns: Constructed transformation. """ @final @classmethod def from_rotation(cls: Type[SEGroupType], rotation: ContainedSOType) -> SEGroupType: data = rotation.parameters() return cls.from_rotation_and_translation( rotation=rotation, translation=torch.zeros( *data.shape[:-1], cls.space_dim, dtype=data.dtype, device=data.devce ), ) @classmethod @abc.abstractmethod def from_translation( cls: Type[SEGroupType], translation: torch.Tensor ) -> SEGroupType: """Construct a transform from a translation term.""" @abc.abstractmethod def rotation(self) -> ContainedSOType: """Returns a transform's rotation term.""" @abc.abstractmethod def translation(self) -> torch.Tensor: """Returns a transform's translation term.""" # Overrides. @final @override def act(self, target: hints.Array) -> torch.Tensor: """ apply transform to point """ d = self.space_dim if target.shape[-1] == d: return self.rotation().act(target) + self.translation() # type: ignore # homogeneous point assert target.shape[-1] == d + 1 X, W = torch.split(target, [d, 1], dim=-1) # (*, d), (*, 1) Xp = self.rotation().act(X) + W * self.translation() return torch.cat([Xp, W], dim=-1) @final @override def mul(self: SEGroupType, other: SEGroupType) -> SEGroupType: return type(self).from_rotation_and_translation( rotation=self.rotation().mul(other.rotation()), translation=self.rotation().act(other.translation()) + self.translation(), ) @final @override def inv(self: SEGroupType) -> SEGroupType: R_inv = self.rotation().inv() return type(self).from_rotation_and_translation( rotation=R_inv, translation=-R_inv.act(self.translation()), ) @final @override def normalize(self: SEGroupType) -> SEGroupType: return type(self).from_rotation_and_translation( rotation=self.rotation().normalize(), translation=self.translation(), ) ================================================ FILE: src/egoallo/preprocessing/geometry/transforms/_se2.py ================================================ import dataclasses from typing import Optional, Tuple import torch import numpy as onp from typing_extensions import override from . import _base, hints from ._so2 import SO2 from .utils import get_epsilon, register_lie_group @register_lie_group( matrix_dim=3, parameters_dim=4, tangent_dim=3, space_dim=2, ) @dataclasses.dataclass class SE2(_base.SEBase[SO2]): """Special Euclidean group for proper rigid transforms in 2D. Ported to pytorch from `jaxlie.SE2`. Internal parameterization is `(cos, sin, x, y)`. Tangent parameterization is `(vx, vy, omega)`. """ # SE2-specific. unit_complex_xy: torch.Tensor """Internal parameters. `(cos, sin, x, y)`.""" @override def __repr__(self) -> str: unit_complex = torch.round(self.unit_complex_xy[..., :2], decimals=5) xy = torch.round(self.unit_complex_xy[..., 2:], decimals=5) return f"{self.__class__.__name__}(unit_complex={unit_complex}, xy={xy})" @staticmethod def from_xy_theta(x: hints.Scalar, y: hints.Scalar, theta: hints.Scalar) -> "SE2": """Construct a transformation from standard 2D pose parameters. Note that this is not the same as integrating over a length-3 twist. """ cos = torch.cos(torch.as_tensor(theta)) sin = torch.sin(torch.as_tensor(theta)) x, y = torch.as_tensor(x), torch.as_tensor(y) return SE2(unit_complex_xy=torch.stack([cos, sin, x, y], dim=-1)) # SE-specific. @staticmethod @override def from_rotation_and_translation( rotation: SO2, translation: hints.Array, ) -> "SE2": assert translation.shape[-1] == 2 return SE2( unit_complex_xy=torch.cat([rotation.unit_complex, translation], dim=-1) ) @override @classmethod def from_translation(cls, translation: torch.Tensor) -> "SE2": return SE2.from_rotation_and_translation( SO2.Identity( shape=translation.shape[:-1], dtype=translation.dtype, device=translation.device, ), translation, ) @override def rotation(self) -> SO2: return SO2(unit_complex=self.unit_complex_xy[..., :2]) @override def translation(self) -> torch.Tensor: return self.unit_complex_xy[..., 2:] # Factory. @staticmethod @override def Identity(shape: Optional[Tuple] = (), **kwargs) -> "SE2": id_elem = ( torch.tensor([1.0, 0.0, 0.0, 0.0], **kwargs) .reshape(*(1,) * len(shape), 4) .repeat(*shape, 1) ) return SE2(unit_complex_xy=id_elem) @staticmethod @override def from_matrix(matrix: hints.Array) -> "SE2": assert matrix.shape[-2:] == (3, 3) # Currently assumes bottom row is [0, 0, 1]. return SE2.from_rotation_and_translation( rotation=SO2.from_matrix(matrix[..., :2, :2]), translation=matrix[..., :2, 2], ) # Accessors. @override def parameters(self) -> torch.Tensor: return self.unit_complex_xy @override def matrix(self) -> torch.Tensor: cos, sin, x, y = self.unit_complex_xy.unbind(dim=-1) zero = torch.zeros_like(x) one = torch.ones_like(x) return torch.stack( [cos, -sin, x, sin, cos, y, zero, zero, one], dim=-1 ).reshape(*cos.shape, 3, 3) # Operations. @staticmethod @override def exp(tangent: hints.Array) -> "SE2": # Reference: # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se2.hpp#L558 # Also see: # > http://ethaneade.com/lie.pdf assert tangent.shape[-1] == 3 theta = tangent[..., 2] # transform the translation vector use_taylor = torch.abs(theta) < get_epsilon(tangent.dtype) safe_theta = torch.where( use_taylor, torch.ones_like(theta), # Any non-zero value should do here. theta, ) theta_sq = theta ** 2 sin_over_theta = torch.where( use_taylor, 1.0 - theta_sq / 6.0, torch.sin(safe_theta) / safe_theta, ) one_minus_cos_over_theta = torch.where( use_taylor, 0.5 * theta - theta * theta_sq / 24.0, (1.0 - torch.cos(safe_theta)) / safe_theta, ) V = torch.stack( [ sin_over_theta, -one_minus_cos_over_theta, one_minus_cos_over_theta, sin_over_theta, ], dim=-1, ).reshape(*theta.shape, 2, 2) return SE2.from_rotation_and_translation( rotation=SO2.from_radians(theta), translation=torch.einsum("...ij,...j->...i", V, tangent[..., :2]), ) @override def log(self) -> torch.Tensor: # Reference: # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se2.hpp#L160 # Also see: # > http://ethaneade.com/lie.pdf theta = self.rotation().log()[..., 0] cos = torch.cos(theta) cos_minus_one = cos - 1.0 half_theta = theta / 2.0 use_taylor = torch.abs(cos_minus_one) < get_epsilon(theta.dtype) safe_cos_minus_one = torch.where( use_taylor, torch.ones_like(cos_minus_one), # Any non-zero value should do here. cos_minus_one, ) half_theta_over_tan_half_theta = torch.where( use_taylor, # Taylor approximation. 1.0 - theta ** 2 / 12.0, # Default. -(half_theta * onp.sin(theta)) / safe_cos_minus_one, ) V_inv = torch.stack( [ half_theta_over_tan_half_theta, half_theta, -half_theta, half_theta_over_tan_half_theta, ], dim=-1, ).reshape(*theta.shape, 2, 2) tangent = torch.cat( [ torch.einsum("...ij,...j->...i", V_inv, self.translation()), theta[..., None], ] ) return tangent @override def adjoint(self, **kwargs) -> torch.Tensor: cos, sin, x, y = self.unit_complex_xy.unbind(dim=-1) zero = torch.zeros_like(x) one = torch.ones_like(x) return torch.stack( [ cos, -sin, y, sin, cos, -x, zero, zero, one, ], dim=-1, ).reshape(*x.shape, 3, 3) ================================================ FILE: src/egoallo/preprocessing/geometry/transforms/_se3.py ================================================ from __future__ import annotations import dataclasses from typing import Optional, Tuple import torch from typing_extensions import override from . import _base from ._so3 import SO3 from .utils import get_epsilon, register_lie_group def _skew(omega: torch.Tensor) -> torch.Tensor: """ Returns the skew-symmetric form of a length-3 vector. :param omega (*, 3) :returns (*, 3, 3) """ wx, wy, wz = omega.unbind(dim=-1) o = torch.zeros_like(wx) return torch.stack( [o, -wz, wy, wz, o, -wx, -wy, wx, o], dim=-1, ).reshape(*wx.shape, 3, 3) @register_lie_group( matrix_dim=4, parameters_dim=7, tangent_dim=6, space_dim=3, ) @dataclasses.dataclass class SE3(_base.SEBase[SO3]): """Special Euclidean group for proper rigid transforms in 3D. Ported to pytorch from `jaxlie.SE3`. Internal parameterization is `(qw, qx, qy, qz, x, y, z)`. Tangent parameterization is `(vx, vy, vz, omega_x, omega_y, omega_z)`. """ # SE3-specific. wxyz_xyz: torch.Tensor """Internal parameters. wxyz quaternion followed by xyz translation.""" @override def __repr__(self) -> str: quat = torch.round(self.wxyz_xyz[..., :4], decimals=5) trans = torch.round(self.wxyz_xyz[..., 4:], decimals=5) return f"{self.__class__.__name__}(wxyz={quat}, xyz={trans})" # SE-specific. @staticmethod @override def from_rotation_and_translation( rotation: SO3, translation: torch.Tensor, ) -> "SE3": assert translation.shape[-1] == 3 return SE3(wxyz_xyz=torch.cat([rotation.wxyz, translation], dim=-1)) @override @classmethod def from_translation(cls, translation: torch.Tensor) -> "SE3": return SE3.from_rotation_and_translation( SO3.Identity( shape=translation.shape[:-1], dtype=translation.dtype, device=translation.device, ), translation, ) @override def rotation(self) -> SO3: return SO3(wxyz=self.wxyz_xyz[..., :4]) @override def translation(self) -> torch.Tensor: return self.wxyz_xyz[..., 4:] # Factory. @staticmethod @override def Identity(shape: Optional[Tuple] = (), **kwargs) -> "SE3": id_elem = ( torch.tensor([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], **kwargs) .reshape(*(1,) * len(shape), 7) .repeat(*shape, 1) ) return SE3(wxyz_xyz=id_elem) @staticmethod @override def from_matrix(matrix: torch.Tensor) -> "SE3": assert matrix.shape[-2:] == (4, 4) # Currently assumes bottom row is [0, 0, 0, 1]. return SE3.from_rotation_and_translation( rotation=SO3.from_matrix(matrix[..., :3, :3]), translation=matrix[..., :3, 3], ) # Accessors. @override def matrix(self) -> torch.Tensor: R = self.rotation().matrix() # (*, 3, 3) t = self.translation().unsqueeze(-1) # (*, 3, 1) dims = R.shape[:-2] bottom = ( torch.tensor([0, 0, 0, 1], dtype=R.dtype, device=R.device) .reshape(*(1,) * len(dims), 1, 4) .repeat(*dims, 1, 1) ) return torch.cat([torch.cat([R, t], dim=-1), bottom], dim=-2) @override def parameters(self) -> torch.Tensor: return self.wxyz_xyz # Operations. @staticmethod @override def exp(tangent: torch.Tensor) -> "SE3": """ :param tangent (*, 6) """ # Reference: # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se3.hpp#L761 # (x, y, z, omega_x, omega_y, omega_z) *dims, d = tangent.shape assert d == 6 trans, omega = torch.split(tangent, [3, 3], dim=-1) # (*, 3), (*, 3) rotation = SO3.exp(omega) # (*, 3) theta_squared = torch.square(omega).sum(dim=-1) # (*) use_taylor = theta_squared < get_epsilon(theta_squared.dtype) theta_squared_safe = torch.where( use_taylor, torch.ones_like(theta_squared), # Any non-zero value should do here. theta_squared, ) del theta_squared theta_safe = torch.sqrt(theta_squared_safe) skew_omega = _skew(omega) # (*, 3, 3) I = ( torch.eye(3, device=omega.device) .reshape(*(1,) * len(dims), 3, 3) .expand(*dims, 3, 3) ) f1 = (1.0 - torch.cos(theta_safe)) / (theta_squared_safe) f2 = (theta_safe - torch.sin(theta_safe)) / (theta_squared_safe * theta_safe) V = torch.where( use_taylor[..., None, None], rotation.matrix(), I + f1[..., None, None] * skew_omega + f2[..., None, None] * torch.matmul(skew_omega, skew_omega), ) return SE3.from_rotation_and_translation( rotation=rotation, translation=torch.einsum("...ij,...j->...i", V, trans), ) @override def log(self) -> torch.Tensor: # Reference: # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se3.hpp#L223 omega = self.rotation().log() # (*, 3) theta_squared = torch.square(omega).sum(dim=-1) # (*) use_taylor = theta_squared < get_epsilon(theta_squared.dtype) theta_squared_safe = torch.where( use_taylor, torch.ones_like(theta_squared), # Any non-zero value should do here. theta_squared, ) del theta_squared theta_safe = torch.sqrt(theta_squared_safe) half_theta_safe = theta_safe / 2.0 skew_omega = _skew(omega) # (*, 3, 3) skew_omega_sq = torch.matmul(skew_omega) I = torch.eye(3, device=omega.device).reshape(*(1,) * len(dims), 3, 3) f2 = ( 1.0 - theta_safe * torch.cos(half_theta_safe) / (2.0 * torch.sin(half_theta_safe)) ) / theta_squared_safe V_inv = torch.where( use_taylor, I - 0.5 * skew_omega + skew_omega_sq / 12.0, I - 0.5 * skew_omega + f2[..., None, None] * skew_omega_sq, ) return torch.cat( [torch.einsum("...ij,...j->...i", V_inv, self.translation()), omega], dim=-1 ) @override def adjoint(self) -> torch.Tensor: R = self.rotation().matrix() dims = R.shape[:-2] # (*, 6, 6) return torch.cat( [ torch.cat([R, torch.matmul(_skew(self.translation()), R)], dim=-1), torch.cat([torch.zeros((*dims, 3, 3)), R], dim=-1), ], dim=-2, ) ================================================ FILE: src/egoallo/preprocessing/geometry/transforms/_so2.py ================================================ from __future__ import annotations import dataclasses from typing import Optional, Tuple import torch from typing_extensions import override from . import _base, hints from .utils import register_lie_group @register_lie_group( matrix_dim=2, parameters_dim=2, tangent_dim=1, space_dim=2, ) @dataclasses.dataclass class SO2(_base.SOBase): """Special orthogonal group for 2D rotations. Ported to pytorch from `jaxlie.SO2`. Internal parameterization is `(cos, sin)`. Tangent parameterization is `(omega,)`. """ # SO2-specific. unit_complex: torch.Tensor """Internal parameters. `(cos, sin)`.""" @override def __repr__(self) -> str: unit_complex = torch.round(self.unit_complex, 5) return f"{self.__class__.__name__}(unit_complex={unit_complex})" @staticmethod def from_radians(theta: hints.Scalar) -> SO2: """Construct a rotation object from a scalar angle.""" theta = torch.as_tensor(theta) cos = torch.cos(theta) sin = torch.sin(theta) return SO2(unit_complex=torch.stack([cos, sin], dim=-1)) def as_radians(self) -> torch.Tensor: """Compute a scalar angle from a rotation object.""" radians = self.log()[..., 0] return radians # Factory. @staticmethod @override def Identity(shape: Optional[Tuple] = (), **kwargs) -> SO2: id_elem = ( torch.tensor([1.0, 0.0], **kwargs) .reshape(*(1,) * len(shape), 2) .repeat(*shape, 1) ) return SO2(unit_complex=id_elem) @staticmethod @override def from_matrix(matrix: torch.Tensor) -> SO2: assert matrix.shape[-2:] == (2, 2) return SO2(unit_complex=matrix[..., 0]) # Accessors. @override def matrix(self) -> torch.Tensor: """ [[cos, -sin], [sin, cos]] :returns (*, 2, 2) tensor """ cos, sin = self.unit_complex.unbind(dim=-1) return torch.stack([cos, -sin, sin, cos], dim=-1).reshape(*cos.shape, 2, 2) @override def parameters(self) -> torch.Tensor: return self.unit_complex # Operations. @override def act(self, target: torch.Tensor) -> torch.Tensor: assert target.shape[-1] == 2 return torch.einsum("...ij,...j->...i", self.matrix(), target) @override def mul(self, other: SO2) -> SO2: return SO2( unit_complex=torch.einsum( "...ij,...j->...i", self.matrix(), other.unit_complex ) ) @staticmethod @override def exp(tangent: torch.Tensor) -> SO2: return SO2( unit_complex=torch.stack([torch.cos(tangent), torch.sin(tangent)], dim=-1) ) @override def log(self) -> torch.Tensor: return torch.atan2( self.unit_complex[..., 1, None], self.unit_complex[..., 0, None] ) @override def adjoint(self, **kwargs) -> torch.Tensor: return torch.eye(1, **kwargs) @override def inv(self) -> SO2: cos, sin = self.unit_complex.unbind(dim=-1) return SO2(unit_complex=torch.stack([cos, -sin], dim=-1)) @override def normalize(self) -> SO2: return SO2( unit_complex=self.unit_complex / torch.linalg.norm(self.unit_complex, dim=-1, keepdim=True) ) ================================================ FILE: src/egoallo/preprocessing/geometry/transforms/_so3.py ================================================ from __future__ import annotations from typing import Optional, Tuple import dataclasses import math import torch from typing_extensions import override from . import _base, hints from .utils import get_epsilon, register_lie_group @register_lie_group( matrix_dim=3, parameters_dim=4, tangent_dim=3, space_dim=3, ) @dataclasses.dataclass class SO3(_base.SOBase): """Special orthogonal group for 3D rotations. Ported to pytorch from `jaxlie.SO3`. Internal parameterization is `(qw, qx, qy, qz)`. Tangent parameterization is `(omega_x, omega_y, omega_z)`. """ # SO3-specific. wxyz: torch.Tensor """Internal parameters. `(w, x, y, z)` quaternion.""" @override def __repr__(self) -> str: wxyz = torch.round(self.wxyz, decimals=5) return f"{self.__class__.__name__}(wxyz={wxyz})" @staticmethod def from_x_radians(theta: torch.Tensor) -> SO3: """ Generates a x-axis rotation. :param theta (tensor) x rotation :returns SO3 object """ zero = torch.zeros_like(theta) return SO3.exp(torch.stack([theta, zero, zero], dim=-1)) @staticmethod def from_y_radians(theta: torch.Tensor) -> SO3: """ Generates a y-axis rotation. :param theta (tensor) y rotation :returns SO3 object """ zero = torch.zeros_like(theta) return SO3.exp(torch.stack([zero, theta, zero], dim=-1)) @staticmethod def from_z_radians(theta: torch.Tensor) -> SO3: """ Generates a z-axis rotation. :param theta (tensor) z rotation :returns SO3 object """ zero = torch.zeros_like(theta) return SO3.exp(torch.stack([zero, zero, theta], dim=-1)) @staticmethod def from_rpy_radians( roll: torch.Tensor, pitch: torch.Tensor, yaw: torch.Tensor, ) -> SO3: """ Generates a transform from a set of Euler angles. Uses the ZYX convention. Args: roll: X rotation, in radians. Applied first. pitch: Y rotation, in radians. Applied second. yaw: Z rotation, in radians. Applied last. """ Rz = SO3.from_z_radians(yaw) Ry = SO3.from_y_radians(pitch) Rx = SO3.from_x_radians(roll) return Rz.mul(Ry.mul(Rx)) @staticmethod def from_quaternion_xyzw(xyzw: torch.Tensor) -> SO3: """ Construct a rotation from an `xyzw` quaternion. Note that `wxyz` quaternions can be constructed using the default dataclass constructor. :param xyzw (*, 4) quat in xyzw convention :returns SO3 object """ assert xyzw.shape[-1] == 4 return SO3(torch.roll(xyzw, shift=1, dims=-1)) def as_quaternion_xyzw(self) -> torch.Tensor: """Grab parameters as xyzw quaternion.""" return torch.roll(self.wxyz, shift=-1, dims=-1) def as_rpy_radians(self) -> hints.RollPitchYaw: """ Computes roll, pitch, and yaw angles. Uses the ZYX convention. Returns: Named tuple containing Euler angles in radians. """ return hints.RollPitchYaw( roll=self.compute_roll_radians(), pitch=self.compute_pitch_radians(), yaw=self.compute_yaw_radians(), ) def compute_roll_radians(self) -> torch.Tensor: """ Compute roll angle. Uses the ZYX convention. :returns angle (*) if wxyz is (*, 4) """ # https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles#Quaternion_to_Euler_angles_conversion q0, q1, q2, q3 = self.wxyz.unbind(dim=-1) return torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 ** 2 + q2 ** 2)) def compute_pitch_radians(self) -> torch.Tensor: """ Compute pitch angle. Uses the ZYX convention. :returns angle (*) if wxyz is (*, 4) """ # https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles#Quaternion_to_Euler_angles_conversion q0, q1, q2, q3 = self.wxyz.unbind(dim=-1) return torch.asin(2 * (q0 * q2 - q3 * q1)) def compute_yaw_radians(self) -> torch.Tensor: """ Compute yaw angle. Uses the ZYX convention. :returns angle (*) if wxyz is (*, 4) """ # https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles#Quaternion_to_Euler_angles_conversion q0, q1, q2, q3 = self.wxyz.unbind(dim=-1) return torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 ** 2 + q3 ** 2)) # Factory. @staticmethod @override def Identity(shape: Optional[Tuple] = (), **kwargs) -> SO3: id_elem = ( torch.tensor([1.0, 0.0, 0.0, 0.0], **kwargs) .reshape(*(1,) * len(shape), 4) .repeat(*shape, 1) ) return SO3(wxyz=id_elem) @staticmethod @override def from_matrix(matrix: torch.Tensor) -> SO3: assert matrix.shape[-2:] == (3, 3) # Modified from: # > "Converting a Rotation Matrix to a Quaternion" from Mike Day # > https://d3cw3dd2w32x2b.cloudfront.net/wp-content/uploads/2015/01/matrix-to-quat.pdf def case0(m): t = 1 + m[..., 0, 0] - m[..., 1, 1] - m[..., 2, 2] q = torch.stack( [ m[..., 2, 1] - m[..., 1, 2], t, m[..., 1, 0] + m[..., 0, 1], m[..., 0, 2] + m[..., 2, 0], ], dim=-1, ) return t, q def case1(m): t = 1 - m[..., 0, 0] + m[..., 1, 1] - m[..., 2, 2] q = torch.stack( [ m[..., 0, 2] - m[..., 2, 0], m[..., 1, 0] + m[..., 0, 1], t, m[..., 2, 1] + m[..., 1, 2], ], dim=-1, ) return t, q def case2(m): t = 1 - m[..., 0, 0] - m[..., 1, 1] + m[..., 2, 2] q = torch.stack( [ m[..., 1, 0] - m[..., 0, 1], m[..., 0, 2] + m[..., 2, 0], m[..., 2, 1] + m[..., 1, 2], t, ], dim=-1, ) return t, q def case3(m): t = 1 + m[..., 0, 0] + m[..., 1, 1] + m[..., 2, 2] q = torch.stack( [ t, m[..., 2, 1] - m[..., 1, 2], m[..., 0, 2] - m[..., 2, 0], m[..., 1, 0] - m[..., 0, 1], ], dim=-1, ) return t, q # Compute four cases, then pick the most precise one. # Probably worth revisiting this! case0_t, case0_q = case0(matrix) case1_t, case1_q = case1(matrix) case2_t, case2_q = case2(matrix) case3_t, case3_q = case3(matrix) cond0 = matrix[..., 2, 2] < 0 cond1 = matrix[..., 0, 0] > matrix[..., 1, 1] cond2 = matrix[..., 0, 0] < -matrix[..., 1, 1] t = torch.where( cond0, torch.where(cond1, case0_t, case1_t), torch.where(cond2, case2_t, case3_t), ).unsqueeze(-1) q = torch.where( cond0.unsqueeze(-1), torch.where(cond1.unsqueeze(-1), case0_q, case1_q), torch.where(cond2.unsqueeze(-1), case2_q, case3_q), ) return SO3(wxyz=q * 0.5 / torch.sqrt(t)) # Accessors. @override def matrix(self) -> torch.Tensor: norm_sq = torch.square(self.wxyz).sum(dim=-1, keepdim=True) qvec = self.wxyz * torch.sqrt(2.0 / norm_sq) # (*, 4) Q = torch.einsum("...i,...j->...ij", qvec, qvec) # (*, 4, 4) return torch.stack( [ 1.0 - Q[..., 2, 2] - Q[..., 3, 3], Q[..., 1, 2] - Q[..., 3, 0], Q[..., 1, 3] + Q[..., 2, 0], Q[..., 1, 2] + Q[..., 3, 0], 1.0 - Q[..., 1, 1] - Q[..., 3, 3], Q[..., 2, 3] - Q[..., 1, 0], Q[..., 1, 3] - Q[..., 2, 0], Q[..., 2, 3] + Q[..., 1, 0], 1.0 - Q[..., 1, 1] - Q[..., 2, 2], ], dim=-1, ).reshape(*qvec.shape[:-1], 3, 3) @override def parameters(self) -> torch.Tensor: return self.wxyz # Operations. @override def act(self, target: torch.Tensor) -> torch.Tensor: assert target.shape[-1] == 3 # Compute using quaternion muls. padded_target = torch.cat([torch.ones_like(target[..., :1]), target], dim=-1) out = self.mul(SO3(wxyz=padded_target).mul(self.inv())) return out.wxyz[..., 1:] @override def mul(self, other: SO3) -> SO3: w0, x0, y0, z0 = self.wxyz.unbind(dim=-1) w1, x1, y1, z1 = other.wxyz.unbind(dim=-1) wxyz2 = torch.stack( [ -x0 * x1 - y0 * y1 - z0 * z1 + w0 * w1, x0 * w1 + y0 * z1 - z0 * y1 + w0 * x1, -x0 * z1 + y0 * w1 + z0 * x1 + w0 * y1, x0 * y1 - y0 * x1 + z0 * w1 + w0 * z1, ], dim=-1, ) return SO3(wxyz=wxyz2) @staticmethod @override def exp(tangent: torch.Tensor) -> SO3: """ create SO3 object from axis angle tangent vector :param tangent (*, 3) """ # Reference: # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/so3.hpp#L583 assert tangent.shape[-1] == 3 theta_squared = torch.square(tangent).sum(dim=-1) # (*) theta_pow_4 = theta_squared * theta_squared use_taylor = theta_squared < get_epsilon(tangent.dtype) safe_theta = torch.sqrt( torch.where( use_taylor, torch.ones_like(theta_squared), # Any constant value should do here. theta_squared, ) ) safe_half_theta = 0.5 * safe_theta real_factor = torch.where( use_taylor, 1.0 - theta_squared / 8.0 + theta_pow_4 / 384.0, torch.cos(safe_half_theta), ) imaginary_factor = torch.where( use_taylor, 0.5 - theta_squared / 48.0 + theta_pow_4 / 3840.0, torch.sin(safe_half_theta) / safe_theta, ) return SO3( wxyz=torch.cat( [ real_factor[..., None], imaginary_factor[..., None] * tangent, ], dim=-1, ) ) @override def log(self) -> torch.Tensor: """ log map to tangent space :return (*, 3) tangent vector """ # Reference: # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/so3.hpp#L247 w, xyz = torch.split(self.wxyz, [1, 3], dim=-1) # (*, 1), (*, 3) norm_sq = torch.square(xyz).sum(dim=-1, keepdim=True) # (*, 1) use_taylor = norm_sq < get_epsilon(norm_sq.dtype) norm_safe = torch.sqrt( torch.where( use_taylor, torch.ones_like(norm_sq), # Any non-zero value should do here. norm_sq, ) ) w_safe = torch.where(use_taylor, w, torch.ones_like(w)) atan_n_over_w = torch.atan2( torch.where(w < 0, -norm_safe, norm_safe), torch.abs(w), ) atan_factor = torch.where( use_taylor, 2.0 / w_safe - 2.0 / 3.0 * norm_sq / w_safe ** 3, torch.where( torch.abs(w) < get_epsilon(w.dtype), torch.where(w > 0, 1.0, -1.0) * math.pi / norm_safe, 2.0 * atan_n_over_w / norm_safe, ), ) return atan_factor * xyz @override def adjoint(self) -> torch.Tensor: return self.matrix() @override def inv(self) -> SO3: # Negate complex terms. w, xyz = torch.split(self.wxyz, [1, 3], dim=-1) return SO3(wxyz=torch.cat([w, -xyz], dim=-1)) @override def normalize(self) -> SO3: return SO3(wxyz=self.wxyz / torch.linalg.norm(self.wxyz, dim=-1, keepdim=True)) ================================================ FILE: src/egoallo/preprocessing/geometry/transforms/hints/__init__.py ================================================ from typing import NamedTuple, Union import numpy as np import torch Array = torch.Tensor """Type alias for `torch.Tensor`.""" Scalar = Union[float, Array] """Type alias for `Union[float, Array]`.""" class RollPitchYaw(NamedTuple): """Tuple containing roll, pitch, and yaw Euler angles.""" roll: Scalar pitch: Scalar yaw: Scalar __all__ = [ "Array", "Scalar", "RollPitchYaw", ] ================================================ FILE: src/egoallo/preprocessing/geometry/transforms/utils/__init__.py ================================================ from ._utils import get_epsilon, register_lie_group __all__ = ["get_epsilon", "register_lie_group"] ================================================ FILE: src/egoallo/preprocessing/geometry/transforms/utils/_utils.py ================================================ from typing import TYPE_CHECKING, Callable, Type, TypeVar import torch if TYPE_CHECKING: from .._base import MatrixLieGroup T = TypeVar("T", bound="MatrixLieGroup") def get_epsilon(dtype: torch.dtype) -> float: """Helper for grabbing type-specific precision constants. Args: dtype: Datatype. Returns: Output float. """ return { torch.float32: 1e-5, torch.float64: 1e-10, }[dtype] def register_lie_group( *, matrix_dim: int, parameters_dim: int, tangent_dim: int, space_dim: int, ) -> Callable[[Type[T]], Type[T]]: """Decorator for registering Lie group dataclasses. Sets dimensionality class variables, and (formerly in the JAX version) marks all methods for JIT compilation. """ def _wrap(cls: Type[T]) -> Type[T]: # Register dimensions as class attributes. cls.matrix_dim = matrix_dim cls.parameters_dim = parameters_dim cls.tangent_dim = tangent_dim cls.space_dim = space_dim return cls return _wrap ================================================ FILE: src/egoallo/preprocessing/util/__init__.py ================================================ from .tensor import * ================================================ FILE: src/egoallo/preprocessing/util/tensor.py ================================================ from loguru import logger as guru from typing import TypeVar, Dict, List import torch from torch import Tensor import torch.nn.functional as F import numpy as np from PIL import Image def batch_sum(x, nldims=1): """ Sum across all but batch dimension(s) :param x (B, *) :param nldims (optional int=1) number of leading dims to keep """ if x.ndim > nldims: return x.sum(dim=tuple(range(nldims, x.ndim))) return x def batch_mean(x, nldims=1): """ Mean across all but batch dimension(s) :param x (B, *) :param nldims (optional int=1) number of leading dims to keep """ if x.ndim > nldims: return x.mean(dim=tuple(range(nldims, x.ndim))) return x def pad_dim(x, max_len, dim=0, start=0, **kwargs): """ pads x to max_len in specified dim :param x (tensor) :param max_len (int) :param start (int default 0) :param dim (optional int default 0) """ N = x.shape[dim] if max_len == N: return x if max_len < N: return torch.narrow(x, dim, start, max_len) if dim < 0: dim = x.ndim + dim pad = [0, 0] * x.ndim pad[2 * dim + 1] = start pad[2 * dim] = max_len - (N + start) return F.pad(x, pad[::-1], **kwargs) def pad_back(x, max_len, dim=0, **kwargs): return pad_dim(x, max_len, dim, 0, **kwargs) def pad_front(x, max_len, dim=0, **kwargs): N = x.shape[dim] return pad_dim(x, max_len, dim=dim, start=max_len - N, **kwargs) def read_image(path, scale=1): im = Image.open(path) if scale == 1: return np.array(im) W, H = im.size w, h = int(scale * W), int(scale * H) return np.array(im.resize((w, h), Image.ANTIALIAS)) T = TypeVar("T") def move_to(obj: T, device) -> T: if isinstance(obj, torch.Tensor): return obj.to(device) if isinstance(obj, dict): return {k: move_to(v, device) for k, v in obj.items()} # type: ignore if isinstance(obj, (list, tuple)): return [move_to(x, device) for x in obj] # type: ignore return obj # otherwise do nothing def detach_all(obj: T) -> T: if isinstance(obj, torch.Tensor): return obj.detach() if isinstance(obj, dict): return {k: detach_all(v) for k, v in obj.items()} # type: ignore if isinstance(obj, (list, tuple)): return [detach_all(x) for x in obj] # type: ignore return obj # otherwise do nothing def to_torch(obj): if isinstance(obj, np.ndarray): return torch.from_numpy(obj).float() if isinstance(obj, dict): return {k: to_torch(v) for k, v in obj.items()} if isinstance(obj, (list, tuple)): return [to_torch(x) for x in obj] return obj def to_np(obj): if isinstance(obj, torch.Tensor): return obj.numpy() if isinstance(obj, dict): return {k: to_np(v) for k, v in obj.items()} if isinstance(obj, (list, tuple)): return [to_np(x) for x in obj] return obj def load_npz_as_dict(path, **kwargs): npz = np.load(path, **kwargs) return {key: npz[key] for key in npz.files} def get_device(i=0): device = f"cuda:{i}" if torch.cuda.is_available() else "cpu" return torch.device(device) def invert_nested_dict(d): """ invert nesting of dict of dicts """ outer_keys = d.keys() # outer nested keys inner_keys = next(iter(d.values())).keys() # inner nested keys return { inner: {outer: d[outer][inner] for outer in outer_keys} for inner in inner_keys } def batchify_dicts(dict_list: List[Dict]) -> Dict: """ given a list of dicts with shared keys, return a dict of those keys stacked into lists """ keys = dict_list[0].keys() if not all(d.keys() == keys for d in dict_list): guru.warning("found dicts with not same keys! using first element's keys") return {k: [d[k] for d in dict_list] for k in keys} def batchify_recursive(dict_list: List[Dict], levels: int = -1): x = dict_list[0] keys = x.keys() out = {} for k in keys: if isinstance(x[k], dict) and levels != 0: # aggregate the values with this key into # a list of dicts and batchify recursively vals = batchify_recursive( [d[k] for d in dict_list], levels=levels - 1 ) elif isinstance(x[k], (list, tuple)) and levels != 0: # aggregate the values with this key into # a flattened list and batchify recursively vals = [x for d in dict_list for x in d[k]] # perhaps another list of dicts if isinstance(vals[0], dict) and levels != 0: vals = batchify_recursive( vals, levels=levels - 1) else: # aggregate the values with this key into a list as is vals = [d[k] for d in dict_list] out[k] = vals return out def unbatch_dict(batched_dict, batch_size): """ :param d (dict) of batched tensors return len B list of dicts of unbatched tensors """ out_list = [{} for _ in range(batch_size)] for k, v in batched_dict.items(): for b in range(batch_size): out_list[b][k] = get_batch_element(v, b, batch_size) return out_list def get_batch_element(batch, idx, batch_size): if isinstance(batch, torch.Tensor): return batch[idx] if idx < batch.shape[0] else batch if isinstance(batch, dict): return {k: get_batch_element(v, idx, batch_size) for k, v in batch.items()} if isinstance(batch, list): if len(batch) == batch_size: return batch[idx] return [get_batch_element(v, idx, batch_size) for v in batch] if isinstance(batch, tuple): if len(batch) == batch_size: return batch[idx] return tuple(get_batch_element(v, idx, batch_size) for v in batch) return batch def narrow_dict(input_dict, tdim, start, length): """ slice dict of tensors :param d (dict) :param idcs (tensor or list) """ input_batch = {} for k, v in input_dict.items(): input_batch[k] = narrow_obj(v, tdim, start, length) return input_batch def narrow_list(input_list, tdim, start, length): return [narrow_obj(x, tdim, start, length) for x in input_list] def narrow_obj(v, tdim, start, length): if isinstance(v, dict): return narrow_dict(v, tdim, start, length) if isinstance(v, (tuple, list)): return narrow_list(v, tdim, start, length) if not isinstance(v, Tensor): return v if v.ndim <= tdim or v.shape[tdim] < start + length: return v return v.narrow(tdim, start, length) def scatter_intervals(tensor, start, end, T: int = -1): """ Scatter the tensor contents into intervals from start to end output tensor indexed from 0 to end.max() :param tensor (B, S, *) :param start (B) start indices :param end (B) end indices :param T (int, optional) max length returns (B, T, *) scattered tensor """ assert isinstance(tensor, torch.Tensor) and tensor.ndim >= 2 if T < 0: T = end.max() assert torch.all(end <= T) B, S, *dims = tensor.shape start, end = start.long(), end.long() # get idcs that go past the last time step so we don't have repeat indices in scatter idcs = time_segment_idcs(start, end, min_len=T, clip=False) # (B, T) # mask out the extra padding mask = idcs >= end[:, None] tensor[mask] = 0 idcs = idcs.reshape(B, S, *(1,) * len(dims)).repeat(1, 1, *dims) output = torch.zeros( B, idcs.max() + 1, *dims, device=tensor.device, dtype=tensor.dtype ) output.scatter_(1, idcs, tensor) # slice out the extra segments return output[:, :T] def get_scatter_mask(start, end, T): """ get the mask of selected intervals """ B = start.shape[0] start, end = start.long(), end.long() assert torch.all(end <= T) idcs = time_segment_idcs(start, end, clip=True) mask = torch.zeros(B, T, device=start.device, dtype=torch.bool) mask.scatter_(1, idcs, 1) return mask def select_intervals(series, start, end, pad_len: int = -1): """ Select slices of a tensor from start to end will pad uneven sequences to all the max segment length :param series (B, T, *) :param start (B) :param end (B) returns (B, S, *) selected segments, S = max(end - start) """ B, T, *dims = series.shape assert torch.all(end <= T) sel = time_segment_idcs(start, end, min_len=pad_len, clip=True) S = sel.shape[1] sel = sel.reshape(B, S, *(1,) * len(dims)).repeat(1, 1, *dims) return torch.gather(series, 1, sel) def get_select_mask(start, end): """ get the mask of unpadded elementes for the selected time segments e.g. sel[mask] are the unpadded elements :param start (B) :param end (B) """ idcs = time_segment_idcs(start, end, clip=False) return idcs < end[:, None] # (B, S) def time_segment_idcs(start, end, min_len: int = -1, clip: bool = True): """ :param start (B) :param end (B) returns (B, S) long tensor of indices, where S = max(end - start) """ start, end = start.long(), end.long() S = max(int((end - start).max()), min_len) seg = torch.arange(S, dtype=torch.int64, device=start.device) idcs = start[:, None] + seg[None, :] # (B, S) if clip: # clip at the lengths of each track imax = torch.maximum(end - 1, start)[:, None] idcs = idcs.clamp(max=imax) return idcs ================================================ FILE: src/egoallo/py.typed ================================================ ================================================ FILE: src/egoallo/sampling.py ================================================ from __future__ import annotations import time import numpy as np import torch from jaxtyping import Float from torch import Tensor from tqdm.auto import tqdm from . import fncsmpl, network from .guidance_optimizer_jax import GuidanceMode, do_guidance_optimization from .hand_detection_structs import ( CorrespondedAriaHandWristPoseDetections, CorrespondedHamerDetections, ) from .tensor_dataclass import TensorDataclass from .transforms import SE3 def quadratic_ts() -> np.ndarray: """DDIM sampling schedule.""" end_step = 0 start_step = 1000 x = np.arange(end_step, int(np.sqrt(start_step))) ** 2 x[-1] = start_step return x[::-1] class CosineNoiseScheduleConstants(TensorDataclass): alpha_t: Float[Tensor, "T"] r"""$1 - \beta_t$""" alpha_bar_t: Float[Tensor, "T+1"] r"""$\Prod_{j=1}^t (1 - \beta_j)$""" @staticmethod def compute(timesteps: int, s: float = 0.008) -> CosineNoiseScheduleConstants: steps = timesteps + 1 x = torch.linspace(0, 1, steps, dtype=torch.float64) def get_betas(): alphas_cumprod = torch.cos((x + s) / (1 + s) * torch.pi * 0.5) ** 2 alphas_cumprod = alphas_cumprod / alphas_cumprod[0] betas = 1.0 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) return torch.clip(betas, 0, 0.999) alpha_t = 1.0 - get_betas() assert len(alpha_t.shape) == 1 alpha_cumprod_t = torch.cat( [torch.ones((1,)), torch.cumprod(alpha_t, dim=0)], dim=0, ) return CosineNoiseScheduleConstants( alpha_t=alpha_t, alpha_bar_t=alpha_cumprod_t ) def run_sampling_with_stitching( denoiser_network: network.EgoDenoiser, body_model: fncsmpl.SmplhModel, guidance_mode: GuidanceMode, guidance_post: bool, guidance_inner: bool, Ts_world_cpf: Float[Tensor, "time 7"], floor_z: float, hamer_detections: None | CorrespondedHamerDetections, aria_detections: None | CorrespondedAriaHandWristPoseDetections, num_samples: int, device: torch.device, guidance_verbose: bool = True, ) -> network.EgoDenoiseTraj: # Offset the T_world_cpf transform to place the floor at z=0 for the # denoiser network. All of the network outputs are local, so we don't need to # unoffset when returning. Ts_world_cpf_shifted = Ts_world_cpf.clone() Ts_world_cpf_shifted[..., 6] -= floor_z noise_constants = CosineNoiseScheduleConstants.compute(timesteps=1000).to( device=device ) alpha_bar_t = noise_constants.alpha_bar_t alpha_t = noise_constants.alpha_t T_cpf_tm1_cpf_t = ( SE3(Ts_world_cpf[..., :-1, :]).inverse() @ SE3(Ts_world_cpf[..., 1:, :]) ).wxyz_xyz x_t_packed = torch.randn( (num_samples, Ts_world_cpf.shape[0] - 1, denoiser_network.get_d_state()), device=device, ) x_t_list = [ network.EgoDenoiseTraj.unpack( x_t_packed, include_hands=denoiser_network.config.include_hands ) ] ts = quadratic_ts() seq_len = x_t_packed.shape[1] start_time = None window_size = 128 overlap_size = 32 canonical_overlap_weights = ( torch.from_numpy( np.minimum( # Make this shape /```\ overlap_size, np.minimum( # Make this shape: / np.arange(1, seq_len + 1), # Make this shape: \ np.arange(1, seq_len + 1)[::-1], ), ) / overlap_size, ) .to(device) .to(torch.float32) ) for i in tqdm(range(len(ts) - 1)): print(f"Sampling {i}/{len(ts) - 1}") t = ts[i] t_next = ts[i + 1] with torch.inference_mode(): # Chop everything into windows. x_0_packed_pred = torch.zeros_like(x_t_packed) overlap_weights = torch.zeros((1, seq_len, 1), device=x_t_packed.device) # Denoise each window. for start_t in range(0, seq_len, window_size - overlap_size): end_t = min(start_t + window_size, seq_len) assert end_t - start_t > 0 overlap_weights_slice = canonical_overlap_weights[ None, : end_t - start_t, None ] overlap_weights[:, start_t:end_t, :] += overlap_weights_slice x_0_packed_pred[:, start_t:end_t, :] += ( denoiser_network.forward( x_t_packed[:, start_t:end_t, :], torch.tensor([t], device=device).expand((num_samples,)), T_cpf_tm1_cpf_t=T_cpf_tm1_cpf_t[None, start_t:end_t, :].repeat( (num_samples, 1, 1) ), T_world_cpf=Ts_world_cpf_shifted[ None, start_t + 1 : end_t + 1, : ].repeat((num_samples, 1, 1)), project_output_rotmats=False, hand_positions_wrt_cpf=None, # TODO: this should be filled in!! mask=None, ) * overlap_weights_slice ) # Take the mean for overlapping regions. x_0_packed_pred /= overlap_weights x_0_packed_pred = network.EgoDenoiseTraj.unpack( x_0_packed_pred, include_hands=denoiser_network.config.include_hands, project_rotmats=True, ).pack() if torch.any(torch.isnan(x_0_packed_pred)): print("found nan", i) sigma_t = torch.cat( [ torch.zeros((1,), device=device), torch.sqrt( (1.0 - alpha_bar_t[:-1]) / (1 - alpha_bar_t[1:]) * (1 - alpha_t) ) * 0.8, ] ) if guidance_mode != "off" and guidance_inner: x_0_pred, _ = do_guidance_optimization( # It's important that we _don't_ use the shifted transforms here. Ts_world_cpf=Ts_world_cpf[1:, :], traj=network.EgoDenoiseTraj.unpack( x_0_packed_pred, include_hands=denoiser_network.config.include_hands ), body_model=body_model, guidance_mode=guidance_mode, phase="inner", hamer_detections=hamer_detections, aria_detections=aria_detections, verbose=guidance_verbose, ) x_0_packed_pred = x_0_pred.pack() del x_0_pred if start_time is None: start_time = time.time() # print(sigma_t) x_t_packed = ( torch.sqrt(alpha_bar_t[t_next]) * x_0_packed_pred + ( torch.sqrt(1 - alpha_bar_t[t_next] - sigma_t[t] ** 2) * (x_t_packed - torch.sqrt(alpha_bar_t[t]) * x_0_packed_pred) / torch.sqrt(1 - alpha_bar_t[t] + 1e-1) ) + sigma_t[t] * torch.randn(x_0_packed_pred.shape, device=device) ) x_t_list.append( network.EgoDenoiseTraj.unpack( x_t_packed, include_hands=denoiser_network.config.include_hands ) ) if guidance_mode != "off" and guidance_post: constrained_traj = x_t_list[-1] constrained_traj, _ = do_guidance_optimization( # It's important that we _don't_ use the shifted transforms here. Ts_world_cpf=Ts_world_cpf[1:, :], traj=constrained_traj, body_model=body_model, guidance_mode=guidance_mode, phase="post", hamer_detections=hamer_detections, aria_detections=aria_detections, verbose=guidance_verbose, ) assert start_time is not None print("RUNTIME (exclude first optimization)", time.time() - start_time) return constrained_traj else: assert start_time is not None print("RUNTIME (exclude first optimization)", time.time() - start_time) return x_t_list[-1] ================================================ FILE: src/egoallo/tensor_dataclass.py ================================================ import dataclasses from typing import Any, Callable, Self, dataclass_transform import torch @dataclass_transform() class TensorDataclass: """A lighter version of nerfstudio's TensorDataclass: https://github.com/nerfstudio-project/nerfstudio/blob/main/nerfstudio/utils/tensor_dataclass.py """ def __init_subclass__(cls) -> None: dataclasses.dataclass(cls) def to(self, device: torch.device | str) -> Self: """Move the tensors in the dataclass to the given device. Args: device: The device to move to. Returns: A new dataclass. """ return self.map(lambda x: x.to(device)) def as_nested_dict(self, numpy: bool) -> dict[str, Any]: """Convert the dataclass to a nested dictionary. Recurses into lists, tuples, and dictionaries. """ def _to_dict(obj: Any) -> Any: if isinstance(obj, TensorDataclass): return {k: _to_dict(v) for k, v in vars(obj).items()} elif isinstance(obj, (list, tuple)): return type(obj)(_to_dict(v) for v in obj) elif isinstance(obj, dict): return {k: _to_dict(v) for k, v in obj.items()} elif isinstance(obj, torch.Tensor) and numpy: return obj.numpy(force=True) else: return obj return _to_dict(self) def map(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Self: """Apply a function to all tensors in the dataclass. Also recurses into lists, tuples, and dictionaries. Args: fn: The function to apply to each tensor. Returns: A new dataclass. """ def _map_impl[MapT]( fn: Callable[[torch.Tensor], torch.Tensor], val: MapT, ) -> MapT: if isinstance(val, torch.Tensor): return fn(val) elif isinstance(val, TensorDataclass): return type(val)(**_map_impl(fn, vars(val))) elif isinstance(val, (list, tuple)): return type(val)(_map_impl(fn, v) for v in val) elif isinstance(val, dict): assert type(val) is dict # No subclass support. return {k: _map_impl(fn, v) for k, v in val.items()} # type: ignore else: return val return _map_impl(fn, self) ================================================ FILE: src/egoallo/training_loss.py ================================================ """Training loss configuration.""" import dataclasses from typing import Literal import torch.utils.data from jaxtyping import Bool, Float, Int from torch import Tensor from torch._dynamo import OptimizedModule from torch.nn.parallel import DistributedDataParallel from . import network from .data.amass import EgoTrainingData from .sampling import CosineNoiseScheduleConstants from .transforms import SO3 @dataclasses.dataclass(frozen=True) class TrainingLossConfig: cond_dropout_prob: float = 0.0 beta_coeff_weights: tuple[float, ...] = tuple(1 / (i + 1) for i in range(16)) loss_weights: dict[str, float] = dataclasses.field( default_factory={ "betas": 0.1, "body_rotmats": 1.0, "contacts": 0.1, # We don't have many hands in the AMASS dataset... "hand_rotmats": 0.01, }.copy ) weight_loss_by_t: Literal["emulate_eps_pred"] = "emulate_eps_pred" """Weights to apply to the loss at each noise level.""" class TrainingLossComputer: """Helper class for computing the training loss. Contains a single method for computing a training loss.""" def __init__(self, config: TrainingLossConfig, device: torch.device) -> None: self.config = config self.noise_constants = ( CosineNoiseScheduleConstants.compute(timesteps=1000) .to(device) .map(lambda tensor: tensor.to(torch.float32)) ) # Emulate loss weight that would be ~equivalent to epsilon prediction. # # This will penalize later errors (close to the end of sampling) much # more than earlier ones (at the start of sampling). assert self.config.weight_loss_by_t == "emulate_eps_pred" weight_t = self.noise_constants.alpha_bar_t / ( 1 - self.noise_constants.alpha_bar_t ) # Pad for numerical stability, and scale between [padding, 1.0]. padding = 0.01 self.weight_t = weight_t / weight_t[1] * (1.0 - padding) + padding def compute_denoising_loss( self, model: network.EgoDenoiser | DistributedDataParallel | OptimizedModule, unwrapped_model: network.EgoDenoiser, train_batch: EgoTrainingData, ) -> tuple[Tensor, dict[str, Tensor | float]]: """Compute a training loss for the EgoDenoiser model. Returns: A tuple (loss tensor, dictionary of things to log). """ log_outputs: dict[str, Tensor | float] = {} batch, time, num_joints, _ = train_batch.body_quats.shape assert num_joints == 21 if unwrapped_model.config.include_hands: assert train_batch.hand_quats is not None x_0 = network.EgoDenoiseTraj( betas=train_batch.betas.expand((batch, time, 16)), body_rotmats=SO3(train_batch.body_quats).as_matrix(), contacts=train_batch.contacts, hand_rotmats=SO3(train_batch.hand_quats).as_matrix(), ) else: x_0 = network.EgoDenoiseTraj( betas=train_batch.betas.expand((batch, time, 16)), body_rotmats=SO3(train_batch.body_quats).as_matrix(), contacts=train_batch.contacts, hand_rotmats=None, ) x_0_packed = x_0.pack() device = x_0_packed.device assert x_0_packed.shape == (batch, time, unwrapped_model.get_d_state()) # Diffuse. t = torch.randint( low=1, high=unwrapped_model.config.max_t + 1, size=(batch,), device=device, ) eps = torch.randn(x_0_packed.shape, dtype=x_0_packed.dtype, device=device) assert self.noise_constants.alpha_bar_t.shape == ( unwrapped_model.config.max_t + 1, ) alpha_bar_t = self.noise_constants.alpha_bar_t[t, None, None] assert alpha_bar_t.shape == (batch, 1, 1) x_t_packed = ( torch.sqrt(alpha_bar_t) * x_0_packed + torch.sqrt(1.0 - alpha_bar_t) * eps ) hand_positions_wrt_cpf: Tensor | None = None if unwrapped_model.config.include_hand_positions_cond: # Joints 19 and 20 are the hand positions. hand_positions_wrt_cpf = train_batch.joints_wrt_cpf[:, :, 19:21, :].reshape( (batch, time, 6) ) # Exclude hand positions for some items in the batch. We'll just do # this by passing in zeros. hand_positions_wrt_cpf = torch.where( # Uniformly drop out with some uniformly sampled probability. # :) ( torch.rand((batch, time, 1), device=device) < torch.rand((batch, 1, 1), device=device) ), hand_positions_wrt_cpf, 0.0, ) # Denoise. x_0_packed_pred = model.forward( x_t_packed=x_t_packed, t=t, T_world_cpf=train_batch.T_world_cpf, T_cpf_tm1_cpf_t=train_batch.T_cpf_tm1_cpf_t, hand_positions_wrt_cpf=hand_positions_wrt_cpf, project_output_rotmats=False, mask=train_batch.mask, cond_dropout_keep_mask=torch.rand((batch,), device=device) > self.config.cond_dropout_prob if self.config.cond_dropout_prob > 0.0 else None, ) assert isinstance(x_0_packed_pred, torch.Tensor) x_0_pred = network.EgoDenoiseTraj.unpack( x_0_packed_pred, include_hands=unwrapped_model.config.include_hands ) weight_t = self.weight_t[t].to(device) assert weight_t.shape == (batch,) def weight_and_mask_loss( loss_per_step: Float[Tensor, "b t d"], # bt stands for "batch time" bt_mask: Bool[Tensor, "b t"] = train_batch.mask, bt_mask_sum: Int[Tensor, ""] = torch.sum(train_batch.mask), ) -> Float[Tensor, ""]: """Weight and mask per-timestep losses (squared errors).""" _, _, d = loss_per_step.shape assert loss_per_step.shape == (batch, time, d) assert bt_mask.shape == (batch, time) assert weight_t.shape == (batch,) return ( # Sum across b axis. torch.sum( # Sum across t axis. torch.sum( # Mean across d axis. torch.mean(loss_per_step, dim=-1) * bt_mask, dim=-1, ) * weight_t ) / bt_mask_sum ) loss_terms: dict[str, Tensor | float] = { "betas": weight_and_mask_loss( # (b, t, 16) (x_0_pred.betas - x_0.betas) ** 2 # (16,) * x_0.betas.new_tensor(self.config.beta_coeff_weights), ), "body_rotmats": weight_and_mask_loss( # (b, t, 21 * 3 * 3) (x_0_pred.body_rotmats - x_0.body_rotmats).reshape( (batch, time, 21 * 3 * 3) ) ** 2, ), "contacts": weight_and_mask_loss((x_0_pred.contacts - x_0.contacts) ** 2), } # Include hand objective. # We didn't use this in the paper. if unwrapped_model.config.include_hands: assert x_0_pred.hand_rotmats is not None assert x_0.hand_rotmats is not None assert x_0.hand_rotmats.shape == (batch, time, 30, 3, 3) # Detect whether or not hands move in a sequence. # We should only supervise sequences where the hands are actully tracked / move; # we mask out hands in AMASS sequences where they are not tracked. gt_hand_flatmat = x_0.hand_rotmats.reshape((batch, time, -1)) hand_motion = ( torch.sum( # (b,) from (b, t) torch.sum( # (b, t) from (b, t, d) torch.abs(gt_hand_flatmat - gt_hand_flatmat[:, 0:1, :]), dim=-1 ) # Zero out changes in masked frames. * train_batch.mask, dim=-1, ) > 1e-5 ) assert hand_motion.shape == (batch,) hand_bt_mask = torch.logical_and(hand_motion[:, None], train_batch.mask) loss_terms["hand_rotmats"] = torch.sum( weight_and_mask_loss( (x_0_pred.hand_rotmats - x_0.hand_rotmats).reshape( batch, time, 30 * 3 * 3 ) ** 2, bt_mask=hand_bt_mask, # We want to weight the loss by the number of frames where # the hands actually move, but gradients here can be too # noisy and put NaNs into mixed-precision training when we # inevitably sample too few frames. So we clip the # denominator. bt_mask_sum=torch.maximum( torch.sum(hand_bt_mask), torch.tensor(256, device=device) ), ) ) # self.log( # "train/hand_motion_proportion", # torch.sum(hand_motion) / batch, # ) else: loss_terms["hand_rotmats"] = 0.0 assert loss_terms.keys() == self.config.loss_weights.keys() # Log loss terms. for name, term in loss_terms.items(): log_outputs[f"loss_term/{name}"] = term # Return loss. loss = sum([loss_terms[k] * self.config.loss_weights[k] for k in loss_terms]) assert isinstance(loss, Tensor) assert loss.shape == () log_outputs["train_loss"] = loss return loss, log_outputs ================================================ FILE: src/egoallo/training_utils.py ================================================ """Utilities for writing training scripts.""" import dataclasses import pdb import signal import subprocess import sys import time import traceback as tb from pathlib import Path from typing import ( Any, Dict, Generator, Iterable, Protocol, Sized, get_type_hints, overload, ) import torch def flattened_hparam_dict_from_dataclass( dataclass: Any, prefix: str | None = None ) -> Dict[str, Any]: """Convert a config object in the form of a nested dataclass into a flattened dictionary, for use with Tensorboard hparams.""" assert dataclasses.is_dataclass(dataclass) cls = type(dataclass) hints = get_type_hints(cls) output = {} for field in dataclasses.fields(dataclass): field_type = hints[field.name] value = getattr(dataclass, field.name) if dataclasses.is_dataclass(field_type): inner = flattened_hparam_dict_from_dataclass(value, prefix=None) inner = {".".join([field.name, k]): v for k, v in inner.items()} output.update(inner) # Cast to type supported by tensorboard hparams. elif isinstance(value, (int, float, str, bool, torch.Tensor)): output[field.name] = value else: output[field.name] = str(value) if prefix is None: return output else: return {f"{prefix}.{k}": v for k, v in output.items()} def pdb_safety_net(): """Attaches a "safety net" for unexpected errors in a Python script. When called, PDB will be automatically opened when either (a) the user hits Ctrl+C or (b) we encounter an uncaught exception. Helpful for bypassing minor errors, diagnosing problems, and rescuing unsaved models. """ # Open PDB on Ctrl+C def handler(sig, frame): pdb.set_trace() signal.signal(signal.SIGINT, handler) # Open PDB when we encounter an uncaught exception def excepthook(type_, value, traceback): # pragma: no cover (impossible to test) tb.print_exception(type_, value, traceback, limit=100) pdb.post_mortem(traceback) sys.excepthook = excepthook class SizedIterable[ContainedType](Iterable[ContainedType], Sized, Protocol): """Protocol for objects that define both `__iter__()` and `__len__()` methods. This is particularly useful for managing minibatches, which can be iterated over but only in order due to multiprocessing/prefetching optimizations, and for which length evaluation is useful for tools like `tqdm`.""" @dataclasses.dataclass class LoopMetrics: counter: int iterations_per_sec: float time_elapsed: float @overload def range_with_metrics(stop: int, /) -> SizedIterable[LoopMetrics]: ... @overload def range_with_metrics(start: int, stop: int, /) -> SizedIterable[LoopMetrics]: ... @overload def range_with_metrics( start: int, stop: int, step: int, / ) -> SizedIterable[LoopMetrics]: ... def range_with_metrics(*args: int) -> SizedIterable[LoopMetrics]: """Light wrapper for `fifteen.utils.loop_metric_generator()`, for use in place of `range()`. Yields a LoopMetrics object instead of an integer.""" return _RangeWithMetrics(args=args) @dataclasses.dataclass class _RangeWithMetrics: args: tuple[int, ...] def __iter__(self): loop_metrics = loop_metric_generator() for counter in range(*self.args): yield dataclasses.replace(next(loop_metrics), counter=counter) def __len__(self) -> int: return len(range(*self.args)) def loop_metric_generator(counter_init: int = 0) -> Generator[LoopMetrics, None, None]: """Generator for computing loop metrics. Note that the first `iteration_per_sec` metric will be 0.0. Example usage: ``` # Note that this is an infinite loop. for metric in loop_metric_generator(): time.sleep(1.0) print(metric) ``` or: ``` loop_metrics = loop_metric_generator() while True: time.sleep(1.0) print(next(loop_metrics).iterations_per_sec) ``` """ counter = counter_init del counter_init time_start = time.time() time_prev = time_start while True: time_now = time.time() yield LoopMetrics( counter=counter, iterations_per_sec=1.0 / (time_now - time_prev) if counter > 0 else 0.0, time_elapsed=time_now - time_start, ) time_prev = time_now counter += 1 def get_git_commit_hash(cwd: Path | None = None) -> str: """Returns the current Git commit hash.""" if cwd is None: cwd = Path.cwd() return ( subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=cwd.as_posix()) .decode("ascii") .strip() ) def get_git_diff(cwd: Path | None = None) -> str: """Returns the output of `git diff HEAD`.""" if cwd is None: cwd = Path.cwd() return ( subprocess.check_output(["git", "diff", "HEAD"], cwd=cwd.as_posix()) .decode("ascii") .strip() ) ================================================ FILE: src/egoallo/transforms/__init__.py ================================================ """Rigid transforms implemented in PyTorch, ported from jaxlie.""" from . import utils as utils from ._base import MatrixLieGroup as MatrixLieGroup from ._base import SEBase as SEBase from ._base import SOBase as SOBase from ._se3 import SE3 as SE3 from ._so3 import SO3 as SO3 ================================================ FILE: src/egoallo/transforms/_base.py ================================================ import abc from typing import ( ClassVar, Generic, Self, Tuple, Type, TypeVar, Union, final, overload, override, ) import numpy as onp import torch from torch import Tensor GroupType = TypeVar("GroupType", bound="MatrixLieGroup") SEGroupType = TypeVar("SEGroupType", bound="SEBase") class MatrixLieGroup(abc.ABC): """Interface definition for matrix Lie groups.""" # Class properties. # > These will be set in `_utils.register_lie_group()`. matrix_dim: ClassVar[int] """Dimension of square matrix output from `.as_matrix()`.""" parameters_dim: ClassVar[int] """Dimension of underlying parameters, `.parameters()`.""" tangent_dim: ClassVar[int] """Dimension of tangent space.""" space_dim: ClassVar[int] """Dimension of coordinates that can be transformed.""" def __init__( # Notes: # - For the constructor signature to be consistent with subclasses, `parameters` # should be marked as positional-only. But this isn't possible in Python 3.7. # - This method is implicitly overriden by the dataclass decorator and # should _not_ be marked abstract. self, parameters: Tensor, ): """Construct a group object from its underlying parameters.""" raise NotImplementedError() # Shared implementations. @overload def __matmul__(self: GroupType, other: GroupType) -> GroupType: ... @overload def __matmul__(self, other: Tensor) -> Tensor: ... def __matmul__( self: GroupType, other: Union[GroupType, Tensor] ) -> Union[GroupType, Tensor]: """Overload for the `@` operator. Switches between the group action (`.apply()`) and multiplication (`.multiply()`) based on the type of `other`. """ if isinstance(other, (onp.ndarray, Tensor)): return self.apply(target=other) elif isinstance(other, MatrixLieGroup): assert self.space_dim == other.space_dim return self.multiply(other=other) else: assert False, f"Invalid argument type for `@` operator: {type(other)}" # Factory. @classmethod @abc.abstractmethod def identity( cls: Type[GroupType], device: Union[torch.device, str], dtype: torch.dtype ) -> GroupType: """Returns identity element. Returns: Identity element. """ @classmethod @abc.abstractmethod def from_matrix(cls: Type[GroupType], matrix: Tensor) -> GroupType: """Get group member from matrix representation. Args: matrix: Matrix representaiton. Returns: Group member. """ # Accessors. @abc.abstractmethod def as_matrix(self) -> Tensor: """Get transformation as a matrix. Homogeneous for SE groups.""" @abc.abstractmethod def parameters(self) -> Tensor: """Get underlying representation.""" # Operations. @abc.abstractmethod def apply(self, target: Tensor) -> Tensor: """Applies group action to a point. Args: target: Point to transform. Returns: Transformed point. """ @abc.abstractmethod def multiply(self: Self, other: Self) -> Self: """Composes this transformation with another. Returns: self @ other """ @classmethod @abc.abstractmethod def exp(cls: Type[GroupType], tangent: Tensor) -> GroupType: """Computes `expm(wedge(tangent))`. Args: tangent: Tangent vector to take the exponential of. Returns: Output. """ @abc.abstractmethod def log(self) -> Tensor: """Computes `vee(logm(transformation matrix))`. Returns: Output. Shape should be `(tangent_dim,)`. """ @abc.abstractmethod def adjoint(self) -> Tensor: """Computes the adjoint, which transforms tangent vectors between tangent spaces. More precisely, for a transform `GroupType`: ``` GroupType @ exp(omega) = exp(Adj_T @ omega) @ GroupType ``` In robotics, typically used for transforming twists, wrenches, and Jacobians across different reference frames. Returns: Output. Shape should be `(tangent_dim, tangent_dim)`. """ @abc.abstractmethod def inverse(self: GroupType) -> GroupType: """Computes the inverse of our transform. Returns: Output. """ @abc.abstractmethod def normalize(self: GroupType) -> GroupType: """Normalize/projects values and returns. Returns: GroupType: Normalized group member. """ # @classmethod # @abc.abstractmethod # def sample_uniform(cls: Type[GroupType], key: Tensor) -> GroupType: # """Draw a uniform sample from the group. Translations (if applicable) are in the # range [-1, 1]. # # Args: # key: PRNG key, as returned by `jax.random.PRNGKey()`. # # Returns: # Sampled group member. # """ def get_batch_axes(self) -> Tuple[int, ...]: """Return any leading batch axes in contained parameters. If an array of shape `(100, 4)` is placed in the wxyz field of an SO3 object, for example, this will return `(100,)`.""" return self.parameters().shape[:-1] class SOBase(MatrixLieGroup): """Base class for special orthogonal groups.""" ContainedSOType = TypeVar("ContainedSOType", bound=SOBase) class SEBase(Generic[ContainedSOType], MatrixLieGroup): """Base class for special Euclidean groups. Each SE(N) group member contains an SO(N) rotation, as well as an N-dimensional translation vector. """ # SE-specific interface. @classmethod @abc.abstractmethod def from_rotation_and_translation( cls: Type[SEGroupType], rotation: ContainedSOType, translation: Tensor, ) -> SEGroupType: """Construct a rigid transform from a rotation and a translation. Args: rotation: Rotation term. translation: Translation term. Returns: Constructed transformation. """ @final @classmethod def from_rotation(cls: Type[SEGroupType], rotation: ContainedSOType) -> SEGroupType: return cls.from_rotation_and_translation( rotation=rotation, translation=rotation.parameters().new_zeros( (*rotation.parameters().shape[:-1], cls.space_dim), dtype=rotation.parameters().dtype, ), ) @abc.abstractmethod def rotation(self) -> ContainedSOType: """Returns a transform's rotation term.""" @abc.abstractmethod def translation(self) -> Tensor: """Returns a transform's translation term.""" # Overrides. @final @override def apply(self, target: Tensor) -> Tensor: return self.rotation() @ target + self.translation() # type: ignore @final @override def multiply(self: SEGroupType, other: SEGroupType) -> SEGroupType: # type: ignore return type(self).from_rotation_and_translation( rotation=self.rotation() @ other.rotation(), translation=(self.rotation() @ other.translation()) + self.translation(), ) @final @override def inverse(self: SEGroupType) -> SEGroupType: R_inv = self.rotation().inverse() return type(self).from_rotation_and_translation( rotation=R_inv, translation=-(R_inv @ self.translation()), ) @final @override def normalize(self: SEGroupType) -> SEGroupType: return type(self).from_rotation_and_translation( rotation=self.rotation().normalize(), translation=self.translation(), ) ================================================ FILE: src/egoallo/transforms/_se3.py ================================================ from __future__ import annotations from dataclasses import dataclass from typing import Union, cast, override import numpy as np import torch from torch import Tensor from . import _base from ._so3 import SO3 from .utils import get_epsilon, register_lie_group def _skew(omega: Tensor) -> Tensor: """ Returns the skew-symmetric form of a length-3 vector. :param omega (*, 3) :returns (*, 3, 3) """ wx, wy, wz = omega.unbind(dim=-1) o = torch.zeros_like(wx) return torch.stack( [o, -wz, wy, wz, o, -wx, -wy, wx, o], dim=-1, ).reshape(*wx.shape, 3, 3) @register_lie_group( matrix_dim=4, parameters_dim=7, tangent_dim=6, space_dim=3, ) @dataclass(frozen=True) class SE3(_base.SEBase[SO3]): """Special Euclidean group for proper rigid transforms in 3D. Internal parameterization is `(qw, qx, qy, qz, x, y, z)`. Tangent parameterization is `(vx, vy, vz, omega_x, omega_y, omega_z)`. """ # SE3-specific. wxyz_xyz: Tensor """Internal parameters. wxyz quaternion followed by xyz translation.""" @override def __repr__(self) -> str: quat = np.round(self.wxyz_xyz[..., :4].numpy(force=True), 5) trans = np.round(self.wxyz_xyz[..., 4:].numpy(force=True), 5) return f"{self.__class__.__name__}(wxyz={quat}, xyz={trans})" # SE-specific. @classmethod @override def from_rotation_and_translation( cls, rotation: SO3, translation: Tensor, ) -> SE3: assert translation.shape[-1] == 3 return SE3(wxyz_xyz=torch.cat([rotation.wxyz, translation], dim=-1)) @override def rotation(self) -> SO3: return SO3(wxyz=self.wxyz_xyz[..., :4]) @override def translation(self) -> Tensor: return self.wxyz_xyz[..., 4:] # Factory. @classmethod @override def identity(cls, device: Union[torch.device, str], dtype: torch.dtype) -> SE3: return SE3( wxyz_xyz=torch.tensor( [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], device=device, dtype=dtype ) ) @classmethod @override def from_matrix(cls, matrix: Tensor) -> SE3: assert matrix.shape[-2:] == (4, 4) or matrix.shape[-2:] == (3, 4) # Currently assumes bottom row is [0, 0, 0, 1]. return SE3.from_rotation_and_translation( rotation=SO3.from_matrix(matrix[..., :3, :3]), translation=matrix[..., :3, 3], ) # Accessors. @override def as_matrix(self) -> Tensor: R = self.rotation().as_matrix() # (*, 3, 3) t = self.translation().unsqueeze(-1) # (*, 3, 1) dims = R.shape[:-2] bottom = ( torch.tensor([0, 0, 0, 1], dtype=R.dtype, device=R.device) .reshape(*(1,) * len(dims), 1, 4) .repeat(*dims, 1, 1) ) return torch.cat([torch.cat([R, t], dim=-1), bottom], dim=-2) @override def parameters(self) -> Tensor: return self.wxyz_xyz # Operations. @classmethod @override def exp(cls, tangent: Tensor) -> SE3: # Reference: # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se3.hpp#L761 # (x, y, z, omega_x, omega_y, omega_z) assert tangent.shape[-1] == 6 rotation = SO3.exp(tangent[..., 3:]) theta_squared = torch.square(tangent[..., 3:]).sum(dim=-1) # (*) use_taylor = theta_squared < get_epsilon(theta_squared.dtype) theta_squared_safe = cast( Tensor, torch.where( use_taylor, 1.0, # Any non-zero value should do here. theta_squared, ), ) del theta_squared theta_safe = torch.sqrt(theta_squared_safe) skew_omega = _skew(tangent[..., 3:]) dtype = skew_omega.dtype device = skew_omega.device V = torch.where( use_taylor[..., None, None], rotation.as_matrix(), ( torch.eye(3, device=device, dtype=dtype) + ((1.0 - torch.cos(theta_safe)) / (theta_squared_safe))[ ..., None, None ] * skew_omega + ( (theta_safe - torch.sin(theta_safe)) / (theta_squared_safe * theta_safe) )[..., None, None] * torch.einsum("...ij,...jk->...ik", skew_omega, skew_omega) ), ) return SE3.from_rotation_and_translation( rotation=rotation, translation=torch.einsum("...ij,...j->...i", V, tangent[..., :3]), ) @override def log(self) -> Tensor: # Reference: # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se3.hpp#L223 omega = self.rotation().log() theta_squared = torch.square(omega).sum(dim=-1) # (*) use_taylor = theta_squared < get_epsilon(theta_squared.dtype) skew_omega = _skew(omega) # Shim to avoid NaNs in jnp.where branches, which cause failures for # reverse-mode AD. theta_squared_safe = torch.where( use_taylor, 1.0, # Any non-zero value should do here. theta_squared, ) del theta_squared theta_safe = torch.sqrt(theta_squared_safe) half_theta_safe = theta_safe / 2.0 dtype = omega.dtype device = omega.device V_inv = torch.where( use_taylor[..., None, None], torch.eye(3, device=device, dtype=dtype) - 0.5 * skew_omega + torch.matmul(skew_omega, skew_omega) / 12.0, ( torch.eye(3, device=device, dtype=dtype) - 0.5 * skew_omega + ( 1.0 - theta_safe * torch.cos(half_theta_safe) / (2.0 * torch.sin(half_theta_safe)) )[..., None, None] / theta_squared_safe[..., None, None] * torch.matmul(skew_omega, skew_omega) ), ) return torch.cat( [torch.einsum("...ij,...j->...i", V_inv, self.translation()), omega], dim=-1 ) @override def adjoint(self) -> Tensor: R = self.rotation().as_matrix() dims = R.shape[:-2] # (*, 6, 6) return torch.cat( [ torch.cat([R, torch.matmul(_skew(self.translation()), R)], dim=-1), torch.cat([torch.zeros((*dims, 3, 3)), R], dim=-1), ], dim=-2, ) ================================================ FILE: src/egoallo/transforms/_so3.py ================================================ from __future__ import annotations from dataclasses import dataclass from typing import Union, override import numpy as np import torch from torch import Tensor from . import _base from .utils import get_epsilon, register_lie_group @register_lie_group( matrix_dim=3, parameters_dim=4, tangent_dim=3, space_dim=3, ) @dataclass(frozen=True) class SO3(_base.SOBase): """Special orthogonal group for 3D rotations. Internal parameterization is `(qw, qx, qy, qz)`. Tangent parameterization is `(omega_x, omega_y, omega_z)`. """ # SO3-specific. wxyz: Tensor """Internal parameters. `(w, x, y, z)` quaternion.""" @override def __repr__(self) -> str: wxyz = np.round(self.wxyz.numpy(force=True), 5) return f"{self.__class__.__name__}(wxyz={wxyz})" @staticmethod def from_x_radians(theta: Tensor) -> SO3: """Generates a x-axis rotation. Args: angle: X rotation, in radians. Returns: Output. """ zeros = torch.zeros_like(theta) return SO3.exp(torch.stack([theta, zeros, zeros], dim=-1)) @staticmethod def from_y_radians(theta: Tensor) -> SO3: """Generates a y-axis rotation. Args: angle: Y rotation, in radians. Returns: Output. """ zeros = torch.zeros_like(theta) return SO3.exp(torch.stack([zeros, theta, zeros], dim=-1)) @staticmethod def from_z_radians(theta: Tensor) -> SO3: """Generates a z-axis rotation. Args: angle: Z rotation, in radians. Returns: Output. """ zeros = torch.zeros_like(theta) return SO3.exp(torch.stack([zeros, zeros, theta], dim=-1)) @staticmethod def from_rpy_radians( roll: Tensor, pitch: Tensor, yaw: Tensor, ) -> SO3: """Generates a transform from a set of Euler angles. Uses the ZYX mobile robot convention. Args: roll: X rotation, in radians. Applied first. pitch: Y rotation, in radians. Applied second. yaw: Z rotation, in radians. Applied last. Returns: Output. """ return ( SO3.from_z_radians(yaw) @ SO3.from_y_radians(pitch) @ SO3.from_x_radians(roll) ) @staticmethod def from_quaternion_xyzw(xyzw: Tensor) -> SO3: """Construct a rotation from an `xyzw` quaternion. Note that `wxyz` quaternions can be constructed using the default dataclass constructor. Args: xyzw: xyzw quaternion. Shape should be (4,). Returns: Output. """ assert xyzw.shape == (4,) return SO3(torch.roll(xyzw, shifts=1, dims=-1)) def as_quaternion_xyzw(self) -> Tensor: """Grab parameters as xyzw quaternion.""" return torch.roll(self.wxyz, shifts=-1, dims=-1) # Factory. @classmethod @override def identity(cls, device: Union[torch.device, str], dtype: torch.dtype) -> SO3: return SO3(wxyz=torch.tensor([1.0, 0.0, 0.0, 0.0], device=device, dtype=dtype)) @classmethod @override def from_matrix(cls, matrix: Tensor) -> SO3: assert matrix.shape[-2:] == (3, 3) # Modified from: # > "Converting a Rotation Matrix to a Quaternion" from Mike Day # > https://d3cw3dd2w32x2b.cloudfront.net/wp-content/uploads/2015/01/matrix-to-quat.pdf def case0(m): t = 1 + m[..., 0, 0] - m[..., 1, 1] - m[..., 2, 2] q = torch.stack( [ m[..., 2, 1] - m[..., 1, 2], t, m[..., 1, 0] + m[..., 0, 1], m[..., 0, 2] + m[..., 2, 0], ], dim=-1, ) return t, q def case1(m): t = 1 - m[..., 0, 0] + m[..., 1, 1] - m[..., 2, 2] q = torch.stack( [ m[..., 0, 2] - m[..., 2, 0], m[..., 1, 0] + m[..., 0, 1], t, m[..., 2, 1] + m[..., 1, 2], ], dim=-1, ) return t, q def case2(m): t = 1 - m[..., 0, 0] - m[..., 1, 1] + m[..., 2, 2] q = torch.stack( [ m[..., 1, 0] - m[..., 0, 1], m[..., 0, 2] + m[..., 2, 0], m[..., 2, 1] + m[..., 1, 2], t, ], dim=-1, ) return t, q def case3(m): t = 1 + m[..., 0, 0] + m[..., 1, 1] + m[..., 2, 2] q = torch.stack( [ t, m[..., 2, 1] - m[..., 1, 2], m[..., 0, 2] - m[..., 2, 0], m[..., 1, 0] - m[..., 0, 1], ], dim=-1, ) return t, q # Compute four cases, then pick the most precise one. # Probably worth revisiting this! case0_t, case0_q = case0(matrix) case1_t, case1_q = case1(matrix) case2_t, case2_q = case2(matrix) case3_t, case3_q = case3(matrix) cond0 = matrix[..., 2, 2] < 0 cond1 = matrix[..., 0, 0] > matrix[..., 1, 1] cond2 = matrix[..., 0, 0] < -matrix[..., 1, 1] t = torch.where( cond0, torch.where(cond1, case0_t, case1_t), torch.where(cond2, case2_t, case3_t), ) q = torch.where( cond0[..., None], torch.where(cond1[..., None], case0_q, case1_q), torch.where(cond2[..., None], case2_q, case3_q), ) return SO3(wxyz=q * 0.5 / torch.sqrt(t[..., None])) # Accessors. @override def as_matrix(self) -> Tensor: norm_sq = torch.square(self.wxyz).sum(dim=-1, keepdim=True) qvec = self.wxyz * torch.sqrt(2.0 / norm_sq) # (*, 4) Q = torch.einsum("...i,...j->...ij", qvec, qvec) # (*, 4, 4) return torch.stack( [ 1.0 - Q[..., 2, 2] - Q[..., 3, 3], Q[..., 1, 2] - Q[..., 3, 0], Q[..., 1, 3] + Q[..., 2, 0], Q[..., 1, 2] + Q[..., 3, 0], 1.0 - Q[..., 1, 1] - Q[..., 3, 3], Q[..., 2, 3] - Q[..., 1, 0], Q[..., 1, 3] - Q[..., 2, 0], Q[..., 2, 3] + Q[..., 1, 0], 1.0 - Q[..., 1, 1] - Q[..., 2, 2], ], dim=-1, ).reshape(*qvec.shape[:-1], 3, 3) @override def parameters(self) -> Tensor: return self.wxyz # Operations. @override def apply(self, target: Tensor) -> Tensor: assert target.shape[-1] == 3 # Compute using quaternion multiplys. padded_target = torch.cat([torch.ones_like(target[..., :1]), target], dim=-1) out = self.multiply(SO3(wxyz=padded_target).multiply(self.inverse())) return out.wxyz[..., 1:] @override def multiply(self, other: SO3) -> SO3: # type: ignore w0, x0, y0, z0 = self.wxyz.unbind(dim=-1) w1, x1, y1, z1 = other.wxyz.unbind(dim=-1) wxyz2 = torch.stack( [ -x0 * x1 - y0 * y1 - z0 * z1 + w0 * w1, x0 * w1 + y0 * z1 - z0 * y1 + w0 * x1, -x0 * z1 + y0 * w1 + z0 * x1 + w0 * y1, x0 * y1 - y0 * x1 + z0 * w1 + w0 * z1, ], dim=-1, ) return SO3(wxyz=wxyz2) @classmethod @override def exp(cls, tangent: Tensor) -> SO3: # Reference: # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/so3.hpp#L583 assert tangent.shape[-1] == 3 theta_squared = torch.square(tangent).sum(dim=-1) # (*) theta_pow_4 = theta_squared * theta_squared use_taylor = theta_squared < get_epsilon(tangent.dtype) safe_theta = torch.sqrt( torch.where( use_taylor, torch.ones_like(theta_squared), # Any constant value should do here. theta_squared, ) ) safe_half_theta = 0.5 * safe_theta real_factor = torch.where( use_taylor, 1.0 - theta_squared / 8.0 + theta_pow_4 / 384.0, torch.cos(safe_half_theta), ) imaginary_factor = torch.where( use_taylor, 0.5 - theta_squared / 48.0 + theta_pow_4 / 3840.0, torch.sin(safe_half_theta) / safe_theta, ) return SO3( wxyz=torch.cat( [ real_factor[..., None], imaginary_factor[..., None] * tangent, ], dim=-1, ) ) @override def log(self) -> Tensor: # Reference: # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/so3.hpp#L247 w, xyz = torch.split(self.wxyz, [1, 3], dim=-1) # (*, 1), (*, 3) norm_sq = torch.square(xyz).sum(dim=-1, keepdim=True) # (*, 1) use_taylor = norm_sq < get_epsilon(norm_sq.dtype) norm_safe = torch.sqrt( torch.where( use_taylor, torch.ones_like(norm_sq), # Any non-zero value should do here. norm_sq, ) ) w_safe = torch.where(use_taylor, w, torch.ones_like(w)) atan_n_over_w = torch.atan2( torch.where(w < 0, -norm_safe, norm_safe), torch.abs(w), ) atan_factor = torch.where( use_taylor, 2.0 / w_safe - 2.0 / 3.0 * norm_sq / w_safe**3, torch.where( torch.abs(w) < get_epsilon(w.dtype), torch.where(w > 0, 1.0, -1.0) * torch.pi / norm_safe, 2.0 * atan_n_over_w / norm_safe, ), ) return atan_factor * xyz @override def adjoint(self) -> Tensor: return self.as_matrix() @override def inverse(self) -> SO3: # Negate complex terms. w, xyz = torch.split(self.wxyz, [1, 3], dim=-1) return SO3(wxyz=torch.cat([w, -xyz], dim=-1)) @override def normalize(self) -> SO3: return SO3(wxyz=self.wxyz / torch.linalg.norm(self.wxyz, dim=-1, keepdim=True)) ================================================ FILE: src/egoallo/transforms/utils/__init__.py ================================================ from ._utils import get_epsilon, register_lie_group __all__ = ["get_epsilon", "register_lie_group"] ================================================ FILE: src/egoallo/transforms/utils/_utils.py ================================================ from typing import TYPE_CHECKING, Callable, Type, TypeVar import torch if TYPE_CHECKING: from .._base import MatrixLieGroup T = TypeVar("T", bound="MatrixLieGroup") def get_epsilon(dtype: torch.dtype) -> float: """Helper for grabbing type-specific precision constants. Args: dtype: Datatype. Returns: Output float. """ return { torch.float32: 1e-5, torch.float64: 1e-10, }[dtype] def register_lie_group( *, matrix_dim: int, parameters_dim: int, tangent_dim: int, space_dim: int, ) -> Callable[[Type[T]], Type[T]]: """Decorator for registering Lie group dataclasses. Sets dimensionality class variables. """ def _wrap(cls: Type[T]) -> Type[T]: # Register dimensions as class attributes. cls.matrix_dim = matrix_dim cls.parameters_dim = parameters_dim cls.tangent_dim = tangent_dim cls.space_dim = space_dim return cls return _wrap ================================================ FILE: src/egoallo/vis_helpers.py ================================================ import time from pathlib import Path from typing import Callable, TypedDict import numpy as np import numpy.typing as npt import torch import trimesh import viser import viser.transforms as vtf from jaxtyping import Float from plyfile import PlyData from torch import Tensor from . import fncsmpl, fncsmpl_extensions, network from .hand_detection_structs import ( CorrespondedAriaHandWristPoseDetections, CorrespondedHamerDetections, ) from .transforms import SE3, SO3 class SplatArgs(TypedDict): centers: npt.NDArray[np.floating] """(N, 3).""" rgbs: npt.NDArray[np.floating] """(N, 3). Range [0, 1].""" opacities: npt.NDArray[np.floating] """(N, 1). Range [0, 1].""" covariances: npt.NDArray[np.floating] """(N, 3, 3).""" def load_splat_file(splat_path: Path, center: bool = False) -> SplatArgs: """Load an antimatter15-style splat file.""" start_time = time.time() splat_buffer = splat_path.read_bytes() bytes_per_gaussian = ( # Each Gaussian is serialized as: # - position (vec3, float32) 3 * 4 # - xyz (vec3, float32) + 3 * 4 # - rgba (vec4, uint8) + 4 # - ijkl (vec4, uint8), where 0 => -1, 255 => 1. + 4 ) assert len(splat_buffer) % bytes_per_gaussian == 0 num_gaussians = len(splat_buffer) // bytes_per_gaussian # Reinterpret cast to dtypes that we want to extract. splat_uint8 = np.frombuffer(splat_buffer, dtype=np.uint8).reshape( (num_gaussians, bytes_per_gaussian) ) scales = splat_uint8[:, 12:24].copy().view(np.float32) wxyzs = splat_uint8[:, 28:32] / 255.0 * 2.0 - 1.0 Rs = vtf.SO3(wxyzs).as_matrix() covariances = np.einsum( "nij,njk,nlk->nil", Rs, np.eye(3)[None, :, :] * scales[:, None, :] ** 2, Rs ) centers = splat_uint8[:, 0:12].copy().view(np.float32) if center: centers -= np.mean(centers, axis=0, keepdims=True) print( f"Splat file with {num_gaussians=} loaded in {time.time() - start_time} seconds" ) return { "centers": centers, # Colors should have shape (N, 3). "rgbs": splat_uint8[:, 24:27] / 255.0, "opacities": splat_uint8[:, 27:28] / 255.0, # Covariances should have shape (N, 3, 3). "covariances": covariances, } def load_ply_file(ply_file_path: Path, center: bool = False) -> SplatArgs: """Load Gaussians stored in a PLY file.""" start_time = time.time() SH_C0 = 0.28209479177387814 plydata = PlyData.read(ply_file_path) v = plydata["vertex"] positions = np.stack([v["x"], v["y"], v["z"]], axis=-1) scales = np.exp(np.stack([v["scale_0"], v["scale_1"], v["scale_2"]], axis=-1)) wxyzs = np.stack([v["rot_0"], v["rot_1"], v["rot_2"], v["rot_3"]], axis=1) colors = 0.5 + SH_C0 * np.stack([v["f_dc_0"], v["f_dc_1"], v["f_dc_2"]], axis=1) opacities = 1.0 / (1.0 + np.exp(-v["opacity"][:, None])) Rs = vtf.SO3(wxyzs).as_matrix() covariances = np.einsum( "nij,njk,nlk->nil", Rs, np.eye(3)[None, :, :] * scales[:, None, :] ** 2, Rs ) if center: positions -= np.mean(positions, axis=0, keepdims=True) num_gaussians = len(v) print( f"PLY file with {num_gaussians=} loaded in {time.time() - start_time} seconds" ) return { "centers": positions, "rgbs": colors, "opacities": opacities, "covariances": covariances, } def add_splat_to_viser( splat_or_ply_path: Path, server: viser.ViserServer, z_offset: float = 0.0 ) -> None: """Add some Gaussian splats to the Viser server.""" if splat_or_ply_path.suffix.lower() == ".ply": splat_args = load_ply_file(splat_or_ply_path) elif splat_or_ply_path.suffix.lower() == ".splat": splat_args = load_splat_file(splat_or_ply_path) else: assert False server.scene.add_gaussian_splats( "/gaussian_splats", centers=splat_args["centers"], rgbs=splat_args["rgbs"], opacities=splat_args["opacities"], covariances=splat_args["covariances"], position=(0.0, 0.0, z_offset), ) def visualize_traj_and_hand_detections( server: viser.ViserServer, Ts_world_cpf: Float[Tensor, "timesteps 7"], traj: network.EgoDenoiseTraj | None, body_model: fncsmpl.SmplhModel, hamer_detections: CorrespondedHamerDetections | None = None, aria_detections: CorrespondedAriaHandWristPoseDetections | None = None, points_data: np.ndarray | None = None, splat_path: Path | None = None, floor_z: float = 0.0, show_joints: bool = False, get_ego_video: Callable[[int, int, float], bytes] | None = None, ) -> Callable[[], int]: """Chaotic mega-function for visualization. Returns a callback that should be called repeatedly in a loop.""" timesteps = Ts_world_cpf.shape[0] server.scene.add_grid( "/ground", plane="xy", cell_color=(80, 80, 80), section_color=(50, 50, 50), position=(0.0, 0.0, floor_z), ) if points_data is not None: point_cloud = server.scene.add_point_cloud( "/aria_points", points=points_data, colors=np.cos(points_data + np.arange(3)) / 3.0 + 0.7, # Make points colorful :) point_size=0.01, # point_size=0.1, point_shape="sparkle", ) size_slider = server.gui.add_slider( "Point cloud size", min=0.001, max=0.05, step=0.001, initial_value=0.005 ) @size_slider.on_update def _(_) -> None: if point_cloud is not None: point_cloud.point_size = size_slider.value if splat_path is not None: add_splat_to_viser(splat_path, server) # , z_offset=-floor_z) if traj is not None: betas = traj.betas timesteps = betas.shape[1] sample_count = betas.shape[0] assert betas.shape == (sample_count, timesteps, 16) body_quats = SO3.from_matrix(traj.body_rotmats).wxyz assert body_quats.shape == (sample_count, timesteps, 21, 4) device = body_quats.device if traj.hand_rotmats is not None: hand_quats = SO3.from_matrix(traj.hand_rotmats).wxyz left_hand_quats = hand_quats[..., :15, :] right_hand_quats = hand_quats[..., 15:30, :] else: left_hand_quats = None right_hand_quats = None shaped = body_model.with_shape(torch.mean(betas, dim=1, keepdim=True)) fk_outputs = shaped.with_pose_decomposed( T_world_root=SE3.identity( device=device, dtype=body_quats.dtype ).parameters(), body_quats=body_quats, left_hand_quats=left_hand_quats, right_hand_quats=right_hand_quats, ) assert Ts_world_cpf.shape == (timesteps, 7) T_world_root = fncsmpl_extensions.get_T_world_root_from_cpf_pose( # Batch axes of fk_outputs are (num_samples, time). # Batch axes of Ts_world_cpf are (time,). fk_outputs, Ts_world_cpf[None, ...], ) fk_outputs = fk_outputs.with_new_T_world_root(T_world_root) else: shaped = None fk_outputs = None sample_count = 0 glasses_mesh = trimesh.load("./data/glasses.stl") assert isinstance(glasses_mesh, trimesh.Trimesh) glasses_mesh.visual.face_colors = [10, 20, 20, 255] # type: ignore cpf_handle = server.scene.add_frame( "/cpf", show_axes=True, axes_length=0.05, axes_radius=0.004, ) server.scene.add_mesh_trimesh("/cpf/glasses", glasses_mesh, scale=0.001 * 1.05) # TODO: remove # hamer_detections = None # aria_detections = None joint_position_handles: list[viser.SceneNodeHandle] = [] timestep_handles: list[viser.FrameHandle] = [] hamer_handles: list[viser.MeshHandle | viser.PointCloudHandle] = [] aria_handles: list[viser.SceneNodeHandle] = [] for t in range(Ts_world_cpf.shape[0]): timestep_handles.append( server.scene.add_frame(f"/timesteps/{t}", show_axes=False) ) # Joints. if show_joints and fk_outputs is not None: assert traj is not None for j in range(sample_count): joints_colors = np.zeros((21, 3)) joints_colors[:, 0] = traj.contacts[j, t, :].numpy(force=True) joints_colors[:, 2] = 1.0 - traj.contacts[j, t, :].numpy(force=True) joint_position_handles.append( server.scene.add_point_cloud( f"/timesteps/{t}/joints", points=fk_outputs.Ts_world_joint[j, t, :21, 4:7].numpy( force=True ), colors=joints_colors, point_shape="circle", point_size=0.02, ) ) # Visualize HaMeR outputs. if hamer_detections is not None: T_world_cam = SE3(Ts_world_cpf[t]) @ SE3(hamer_detections.T_cpf_cam) server.scene.add_frame( f"/timesteps/{t}/cpf/cam", show_axes=True, axes_length=0.025, axes_radius=0.003, wxyz=T_world_cam.wxyz_xyz[..., :4].numpy(force=True), position=T_world_cam.wxyz_xyz[..., 4:7].numpy(force=True), ) hands_l = hamer_detections.detections_left_tuple[t] hands_r = hamer_detections.detections_right_tuple[t] if hands_l is not None: for j in range(hands_l["verts"].shape[0]): hamer_handles.append( server.scene.add_mesh_simple( f"/timesteps/{t}/cpf/cam/left_hand{j}", vertices=hands_l["verts"][j], faces=hamer_detections.mano_faces_left.numpy(force=True), visible=False, ) ) hamer_handles.append( server.scene.add_point_cloud( f"/timesteps/{t}/cpf/cam/lefft_keypoints3d", points=hands_l["keypoints_3d"][j], colors=(255, 127, 0), point_size=0.008, point_shape="square", visible=False, ) ) if hands_r is not None: for j in range(hands_r["verts"].shape[0]): hamer_handles.append( server.scene.add_mesh_simple( f"/timesteps/{t}/cpf/cam/right_hand{j}", vertices=hands_r["verts"][j], faces=hamer_detections.mano_faces_right.numpy(force=True), visible=False, ) ) hamer_handles.append( server.scene.add_point_cloud( f"/timesteps/{t}/cpf/cam/right_keypoints3d", points=hands_r["keypoints_3d"][j], colors=(0, 127, 255), point_size=0.008, point_shape="square", visible=False, ) ) # Visualize Aria detections. if aria_detections is not None: for side in ("left", "right"): detections = { "left": aria_detections.detections_left_concat, "right": aria_detections.detections_right_concat, }[side] if detections is None: continue indices = detections.indices index = torch.searchsorted(indices, t) if index < len(indices) and indices[index] == t: # found? aria_handles.append( server.scene.add_spline_catmull_rom( f"/timesteps/{t}/aria_detections/{side}", np.array( [ detections.wrist_position[index].numpy(force=True), detections.palm_position[index].numpy(force=True), ] ), line_width=3.0, color=(255, 0, 0) if side == "left" else (0, 255, 0), visible=False, ) ) body_handles = ( [ server.scene.add_mesh_skinned( f"/persons/{i}", vertices=shaped.verts_zero[i, 0, :, :].numpy(force=True), faces=body_model.faces.numpy(force=True), bone_wxyzs=vtf.SO3.identity( batch_axes=(body_model.get_num_joints() + 1,) ).wxyz, bone_positions=np.concatenate( [ np.zeros((1, 3)), # Indices are (batch, time, joint, positions). shaped.joints_zero[i, :, :, :] .numpy(force=True) .squeeze(axis=0), ], axis=0, ), color=(152, 93, 229), skin_weights=body_model.weights.numpy(force=True), ) for i in range(sample_count) ] if shaped is not None else [] ) gui_attach = server.gui.add_checkbox("Attach camera to CPF", initial_value=False) gui_attach_dist = server.gui.add_number("Attach distance", initial_value=0.3) gui_show_body = server.gui.add_checkbox("Show body", initial_value=True) gui_show_glasses = server.gui.add_checkbox("Show glasses", initial_value=True) gui_show_cpf_axes = server.gui.add_checkbox("Show CPF axes", initial_value=False) gui_wireframe = server.gui.add_checkbox("Wireframe", initial_value=False) gui_smpl_opacity = server.gui.add_slider( "SMPL Opacity", initial_value=1.0, min=0.0, max=1.0, step=0.01 ) gui_hamer_opacity = server.gui.add_slider( "HaMeR Opacity", initial_value=1.0, min=0.0, max=1.0, step=0.01 ) @gui_smpl_opacity.on_update def _(_) -> None: for handle in body_handles: handle.opacity = gui_smpl_opacity.value @gui_hamer_opacity.on_update def _(_) -> None: for handle in hamer_handles: if isinstance(handle, viser.MeshHandle): handle.opacity = gui_hamer_opacity.value gui_show_hamer_hands = server.gui.add_checkbox( "Show HaMeR hands", initial_value=False ) gui_show_aria_hands = server.gui.add_checkbox( "Show wrist detections", initial_value=False ) gui_body_color = server.gui.add_rgb("Body color", initial_value=(152, 93, 229)) if show_joints: gui_show_joints = server.gui.add_checkbox("Show joints", initial_value=True) @gui_show_joints.on_update def _(_) -> None: for handle in joint_position_handles: handle.visible = gui_show_joints.value @gui_show_body.on_update def _(_) -> None: for handle in body_handles: handle.visible = gui_show_body.value @gui_show_glasses.on_update def _(_) -> None: # The glasses are a child of the CPF frame. cpf_handle.visible = gui_show_glasses.value @gui_show_cpf_axes.on_update def _(_) -> None: cpf_handle.show_axes = gui_show_cpf_axes.value @gui_wireframe.on_update def _(_) -> None: for handle in body_handles: handle.wireframe = gui_wireframe.value @gui_show_hamer_hands.on_update def _(_) -> None: for handle in hamer_handles: handle.visible = gui_show_hamer_hands.value @gui_show_aria_hands.on_update def _(_) -> None: for handle in aria_handles: handle.visible = gui_show_aria_hands.value @gui_body_color.on_update def _(_) -> None: for handle in body_handles: handle.color = gui_body_color.value # Add playback UI. with server.gui.add_folder("Playback"): gui_timestep = server.gui.add_slider( "Timestep", min=0, max=timesteps - 1, step=1, initial_value=0, disabled=True, ) gui_start_end = server.gui.add_multi_slider( "Start/end", min=0, max=timesteps - 1, initial_value=(0, timesteps - 1), step=1, ) gui_next_frame = server.gui.add_button("Next Frame", disabled=True) gui_prev_frame = server.gui.add_button("Prev Frame", disabled=True) gui_playing = server.gui.add_checkbox("Playing", True) gui_framerate = server.gui.add_slider( "FPS", min=1, max=60, step=0.1, initial_value=15 ) gui_framerate_options = server.gui.add_button_group( "FPS options", ("10", "20", "30", "60") ) # Frame step buttons. @gui_next_frame.on_click def _(_) -> None: gui_timestep.value = (gui_timestep.value + 1) % timesteps @gui_prev_frame.on_click def _(_) -> None: gui_timestep.value = (gui_timestep.value - 1) % timesteps # Disable frame controls when we're playing. @gui_playing.on_update def _(_) -> None: gui_timestep.disabled = gui_playing.value gui_next_frame.disabled = gui_playing.value gui_prev_frame.disabled = gui_playing.value # Set the framerate when we click one of the options. @gui_framerate_options.on_click def _(_) -> None: gui_framerate.value = int(gui_framerate_options.value) Ts_world_cpf_numpy = Ts_world_cpf.numpy(force=True) def do_update() -> None: t = gui_timestep.value cpf_handle.wxyz = Ts_world_cpf_numpy[t, :4] cpf_handle.position = Ts_world_cpf_numpy[t, 4:7] if gui_attach.value: for client in server.get_clients().values(): client.camera.wxyz = ( vtf.SO3(cpf_handle.wxyz) @ vtf.SO3.from_z_radians(np.pi) ).wxyz client.camera.position = cpf_handle.position - vtf.SO3( cpf_handle.wxyz ) @ np.array([0.0, 0.0, gui_attach_dist.value]) if fk_outputs is not None: for i in range(sample_count): for b, bone_handle in enumerate(body_handles[i].bones): if b == 0: bone_transform = fk_outputs.T_world_root[i, t].numpy(force=True) else: bone_transform = fk_outputs.Ts_world_joint[i, t, b - 1].numpy( force=True ) bone_handle.wxyz = bone_transform[:4] bone_handle.position = bone_transform[4:7] for ii, timestep_frame in enumerate(timestep_handles): timestep_frame.visible = t == ii get_viser_file = server.gui.add_button("Get .viser file") if get_ego_video is not None: ego_video = server.gui.add_button("Get Ego Video") @ego_video.on_click def _(event: viser.GuiEvent) -> None: assert event.client is not None notif = event.client.add_notification( "Getting video...", body="", loading=True, with_close_button=False ) ego_video_bytes = get_ego_video( gui_start_end.value[0], gui_start_end.value[1], (gui_start_end.value[1] - gui_start_end.value[0]) / gui_framerate.value, ) notif.remove() event.client.send_file_download("ego_video.mp4", ego_video_bytes) prev_time = time.time() handle = None def loop_cb() -> int: start, end = gui_start_end.value duration = end - start if get_viser_file.value is False: nonlocal prev_time now = time.time() sleepdur = 1.0 / gui_framerate.value - (now - prev_time) if sleepdur > 0.0: time.sleep(sleepdur) prev_time = now if gui_playing.value: gui_timestep.value = (gui_timestep.value + 1 - start) % duration + start do_update() return gui_timestep.value else: # Save trajectory. nonlocal handle if handle is None: handle = server._start_scene_recording() handle.set_loop_start() gui_timestep.value = start assert handle is not None handle.insert_sleep(1.0 / gui_framerate.value) gui_timestep.value = (gui_timestep.value + 1 - start) % duration + start if gui_timestep.value == start: get_viser_file.value = False server.send_file_download( "recording.viser", content=handle.end_and_serialize() ) handle = None do_update() return gui_timestep.value return loop_cb