[
  {
    "path": ".github/workflows/pyright.yml",
    "content": "name: pyright\n\non:\n  push:\n    branches: [main]\n  pull_request:\n    branches: [main]\n\njobs:\n  pyright:\n    runs-on: ubuntu-latest\n    strategy:\n      matrix:\n        python-version: [\"3.12\"]\n\n    steps:\n      - uses: actions/checkout@v2\n      - name: Set up Python ${{ matrix.python-version }}\n        uses: actions/setup-python@v1\n        with:\n          python-version: ${{ matrix.python-version }}\n      - name: Install dependencies\n        run: |\n          pip install uv\n          uv pip install --system -e .\n          uv pip install --system jax\n          uv pip install --system git+https://github.com/brentyi/jaxls.git\n          uv pip install --system git+https://github.com/brentyi/hamer_helper.git\n          uv pip install --system pyright\n      - name: Run pyright\n        run: |\n          pyright .\n"
  },
  {
    "path": ".gitignore",
    "content": "*.swp\n*.swo\n*.pyc\n*.egg-info\n*.ipynb_checkpoints\n__pycache__\n.coverage\nhtmlcov\n.mypy_cache\n.dmypy.json\n.hypothesis\n.envrc\n.lvimrc\n.DS_Store\n.envrc\nlightning_logs/\noutputs/\n\ndata/\negoallo_checkpoint_*\negoallo_example_*\n"
  },
  {
    "path": "0a_preprocess_training_data.py",
    "content": "\"\"\"Convert raw AMASS data to HuMoR-style npz format.\n\nMostly taken from\nhttps://github.com/davrempe/humor/blob/main/humor/scripts/process_amass_data.py,\nbut added gender neutral beta conversion and other utilities.\n\"\"\"\n\nimport dataclasses\nimport os\nimport time\nfrom concurrent.futures import ProcessPoolExecutor\nfrom pathlib import Path\nfrom typing import Dict, Tuple\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport torch\nimport tyro\nfrom loguru import logger as guru\nfrom sklearn.cluster import DBSCAN\nfrom tqdm import tqdm\n\nfrom egoallo.preprocessing.body_model import (\n    KEYPT_VERTS,\n    SMPL_JOINTS,\n    BodyModel,\n    reflect_pose_aa,\n    reflect_root_trajectory,\n    run_smpl,\n)\nfrom egoallo.preprocessing.geometry import convert_rotation, joints_global_to_local\nfrom egoallo.preprocessing.util import move_to\n\nAMASS_SPLITS = {\n    \"train\": [\n        \"ACCAD\",\n        \"BMLhandball\",\n        \"BMLmovi\",\n        \"BioMotionLab_NTroje\",\n        \"CMU\",\n        \"DFaust_67\",\n        \"DanceDB\",\n        \"EKUT\",\n        \"Eyes_Japan_Dataset\",\n        \"KIT\",\n        \"MPI_Limits\",\n        \"TCD_handMocap\",\n        \"TotalCapture\",\n    ],\n    \"val\": [\n        \"HumanEva\",\n        \"MPI_HDM05\",\n        \"SFU\",\n        \"MPI_mosh\",\n    ],\n    \"test\": [\n        \"Transitions_mocap\",\n        \"SSM_synced\",\n    ],\n}\nAMASS_SPLITS[\"all\"] = AMASS_SPLITS[\"train\"] + AMASS_SPLITS[\"val\"] + AMASS_SPLITS[\"test\"]\n\n\ndef load_neutral_beta_conversion(gender: str) -> Tuple[np.ndarray, np.ndarray]:\n    assert gender in [\"female\", \"male\"]\n    data = np.load(f\"./data/smplh_gender_conversion/{gender}_to_neutral.npz\")\n    return data[\"A\"], data[\"b\"]\n\n\ndef convert_gender_neutral_beta(\n    beta: np.ndarray, A: np.ndarray, b: np.ndarray\n) -> np.ndarray:\n    \"\"\"\n    :param beta (*, B)\n    :param A (B, B)\n    :param b (B)\n    beta_neutral = A @ beta_gender + b\n    \"\"\"\n    *dims, B = beta.shape\n    A = A.reshape((*(1,) * len(dims), B, B))\n    b = b.reshape((*(1,) * len(dims), B))\n    return np.einsum(\"...ij,...j->...i\", A, beta) + b\n\n\ndef determine_floor_height_and_contacts(\n    body_joint_seq,\n    fps,\n    vis=False,\n    floor_vel_thresh=0.005,\n    floor_height_offset=0.01,\n    contact_vel_thresh=0.005,  # 0.015\n    contact_toe_height_thresh=0.04,  # if static toe above this height\n    contact_ankle_height_thresh=0.08,\n    terrain_height_thresh=0.04,\n    root_height_thresh=0.04,\n    cluster_size_thresh=0.25,\n    discard_terrain_seqs=False,  # throw away person steps onto objects (determined by a heuristic)\n):\n    \"\"\"\n    Taken from\n    https://github.com/davrempe/humor/blob/main/humor/scripts/process_amass_data.py\n\n    Input: body_joint_seq N x 21 x 3 numpy array\n    Contacts are N x 4 where N is number of frames and each row is left heel/toe, right heel/toe\n    \"\"\"\n    num_frames = body_joint_seq.shape[0]\n\n    # compute toe velocities\n    root_seq = body_joint_seq[:, SMPL_JOINTS[\"hips\"], :]\n    left_toe_seq = body_joint_seq[:, SMPL_JOINTS[\"leftToeBase\"], :]\n    right_toe_seq = body_joint_seq[:, SMPL_JOINTS[\"rightToeBase\"], :]\n    left_toe_vel = np.linalg.norm(left_toe_seq[1:] - left_toe_seq[:-1], axis=1)\n    left_toe_vel = np.append(left_toe_vel, left_toe_vel[-1])\n    right_toe_vel = np.linalg.norm(right_toe_seq[1:] - right_toe_seq[:-1], axis=1)\n    right_toe_vel = np.append(right_toe_vel, right_toe_vel[-1])\n\n    if 
vis:\n        plt.figure()\n        steps = np.arange(num_frames)\n        plt.plot(steps, left_toe_vel, \"-r\", label=\"left vel\")\n        plt.plot(steps, right_toe_vel, \"-b\", label=\"right vel\")\n        plt.legend()\n        plt.show()\n        plt.close()\n\n    # now foot heights (z is up)\n    left_toe_heights = left_toe_seq[:, 2]\n    right_toe_heights = right_toe_seq[:, 2]\n    root_heights = root_seq[:, 2]\n\n    if vis:\n        plt.figure()\n        steps = np.arange(num_frames)\n        plt.plot(steps, left_toe_heights, \"-r\", label=\"left toe height\")\n        plt.plot(steps, right_toe_heights, \"-b\", label=\"right toe height\")\n        plt.plot(steps, root_heights, \"-g\", label=\"root height\")\n        plt.legend()\n        plt.show()\n        plt.close()\n\n    # filter out heights when velocity is greater than some threshold (not in contact)\n    all_inds = np.arange(left_toe_heights.shape[0])\n    left_static_foot_heights = left_toe_heights[left_toe_vel < floor_vel_thresh]\n    left_static_inds = all_inds[left_toe_vel < floor_vel_thresh]\n    right_static_foot_heights = right_toe_heights[right_toe_vel < floor_vel_thresh]\n    right_static_inds = all_inds[right_toe_vel < floor_vel_thresh]\n\n    all_static_foot_heights = np.append(\n        left_static_foot_heights, right_static_foot_heights\n    )\n    all_static_inds = np.append(left_static_inds, right_static_inds)\n\n    if vis:\n        plt.figure()\n        steps = np.arange(left_static_foot_heights.shape[0])\n        plt.plot(steps, left_static_foot_heights, \"-r\", label=\"left static height\")\n        plt.legend()\n        plt.show()\n        plt.close()\n\n    discard_seq = False\n    if all_static_foot_heights.shape[0] > 0:\n        cluster_heights = []\n        cluster_root_heights = []\n        cluster_sizes = []\n        # cluster foot heights and find one with smallest median\n        clustering = DBSCAN(eps=0.005, min_samples=3).fit(\n            all_static_foot_heights.reshape(-1, 1)\n        )\n        all_labels = np.unique(clustering.labels_)\n        # print(all_labels)\n        if vis:\n            plt.figure()\n        min_median = min_root_median = float(\"inf\")\n        for cur_label in all_labels:\n            cur_clust = all_static_foot_heights[clustering.labels_ == cur_label]\n            cur_clust_inds = np.unique(\n                all_static_inds[clustering.labels_ == cur_label]\n            )  # inds in the original sequence that correspond to this cluster\n            if vis:\n                plt.scatter(\n                    cur_clust, np.zeros_like(cur_clust), label=\"foot %d\" % (cur_label)\n                )\n            # get median foot height and use this as height\n            cur_median = np.median(cur_clust)\n            cluster_heights.append(cur_median)\n            cluster_sizes.append(cur_clust.shape[0])\n\n            # get root information\n            cur_root_clust = root_heights[cur_clust_inds]\n            cur_root_median = np.median(cur_root_clust)\n            cluster_root_heights.append(cur_root_median)\n            if vis:\n                plt.scatter(\n                    cur_root_clust,\n                    np.zeros_like(cur_root_clust),\n                    label=\"root %d\" % (cur_label),\n                )\n\n            # update min info\n            if cur_median < min_median:\n                min_median = cur_median\n                min_root_median = cur_root_median\n\n        # print(cluster_heights)\n        # print(cluster_root_heights)\n        
# print(cluster_sizes)\n        if vis:\n            plt.show()\n            plt.close()\n\n        floor_height = min_median\n        offset_floor_height = (\n            floor_height - floor_height_offset\n        )  # toe joint is actually inside foot mesh a bit\n\n        if discard_terrain_seqs:\n            # print(min_median + TERRAIN_HEIGHT_THRESH)\n            # print(min_root_median + ROOT_HEIGHT_THRESH)\n            for cluster_root_height, cluster_height, cluster_size in zip(\n                cluster_root_heights, cluster_heights, cluster_sizes\n            ):\n                root_above_thresh = cluster_root_height > (\n                    min_root_median + root_height_thresh\n                )\n                toe_above_thresh = cluster_height > (min_median + terrain_height_thresh)\n                cluster_size_above_thresh = cluster_size > int(\n                    cluster_size_thresh * fps\n                )\n                if root_above_thresh and toe_above_thresh and cluster_size_above_thresh:\n                    discard_seq = True\n                    print(\"DISCARDING sequence based on terrain interaction!\")\n                    break\n    else:\n        floor_height = offset_floor_height = 0.0\n\n    # now find contacts (feet are below certain velocity and within certain range of floor)\n    # compute heel velocities\n    left_heel_seq = body_joint_seq[:, SMPL_JOINTS[\"leftFoot\"], :]\n    right_heel_seq = body_joint_seq[:, SMPL_JOINTS[\"rightFoot\"], :]\n    left_heel_vel = np.linalg.norm(left_heel_seq[1:] - left_heel_seq[:-1], axis=1)\n    left_heel_vel = np.append(left_heel_vel, left_heel_vel[-1])\n    right_heel_vel = np.linalg.norm(right_heel_seq[1:] - right_heel_seq[:-1], axis=1)\n    right_heel_vel = np.append(right_heel_vel, right_heel_vel[-1])\n\n    left_heel_contact = left_heel_vel < contact_vel_thresh\n    right_heel_contact = right_heel_vel < contact_vel_thresh\n    left_toe_contact = left_toe_vel < contact_vel_thresh\n    right_toe_contact = right_toe_vel < contact_vel_thresh\n\n    # compute heel heights\n    left_heel_heights = left_heel_seq[:, 2] - floor_height\n    right_heel_heights = right_heel_seq[:, 2] - floor_height\n    left_toe_heights = left_toe_heights - floor_height\n    right_toe_heights = right_toe_heights - floor_height\n\n    left_heel_contact = np.logical_and(\n        left_heel_contact, left_heel_heights < contact_ankle_height_thresh\n    )\n    right_heel_contact = np.logical_and(\n        right_heel_contact, right_heel_heights < contact_ankle_height_thresh\n    )\n    left_toe_contact = np.logical_and(\n        left_toe_contact, left_toe_heights < contact_toe_height_thresh\n    )\n    right_toe_contact = np.logical_and(\n        right_toe_contact, right_toe_heights < contact_toe_height_thresh\n    )\n\n    contacts = np.zeros((num_frames, len(SMPL_JOINTS)))\n    contacts[:, SMPL_JOINTS[\"leftFoot\"]] = left_heel_contact\n    contacts[:, SMPL_JOINTS[\"leftToeBase\"]] = left_toe_contact\n    contacts[:, SMPL_JOINTS[\"rightFoot\"]] = right_heel_contact\n    contacts[:, SMPL_JOINTS[\"rightToeBase\"]] = right_toe_contact\n\n    # hand contacts\n    left_hand_contact = detect_joint_contact(\n        body_joint_seq,\n        \"leftHand\",\n        floor_height,\n        contact_vel_thresh,\n        contact_ankle_height_thresh,\n    )\n    right_hand_contact = detect_joint_contact(\n        body_joint_seq,\n        \"rightHand\",\n        floor_height,\n        contact_vel_thresh,\n        contact_ankle_height_thresh,\n    )\n    
contacts[:, SMPL_JOINTS[\"leftHand\"]] = left_hand_contact\n    contacts[:, SMPL_JOINTS[\"rightHand\"]] = right_hand_contact\n\n    # knee contacts\n    left_knee_contact = detect_joint_contact(\n        body_joint_seq,\n        \"leftLeg\",\n        floor_height,\n        contact_vel_thresh,\n        contact_ankle_height_thresh,\n    )\n    right_knee_contact = detect_joint_contact(\n        body_joint_seq,\n        \"rightLeg\",\n        floor_height,\n        contact_vel_thresh,\n        contact_ankle_height_thresh,\n    )\n    contacts[:, SMPL_JOINTS[\"leftLeg\"]] = left_knee_contact\n    contacts[:, SMPL_JOINTS[\"rightLeg\"]] = right_knee_contact\n\n    return offset_floor_height, contacts, discard_seq\n\n\ndef detect_joint_contact(\n    body_joint_seq, joint_name, floor_height, vel_thresh, height_thresh\n):\n    \"\"\"\n    Taken from\n    https://github.com/davrempe/humor/blob/main/humor/scripts/process_amass_data.py\n    \"\"\"\n    # calc velocity\n    joint_seq = body_joint_seq[:, SMPL_JOINTS[joint_name], :]\n    joint_vel = np.linalg.norm(joint_seq[1:] - joint_seq[:-1], axis=1)\n    joint_vel = np.append(joint_vel, joint_vel[-1])\n    # determine contact by velocity\n    joint_contact = joint_vel < vel_thresh\n    # compute heights\n    joint_heights = joint_seq[:, 2] - floor_height\n    # compute contact by vel + height\n    joint_contact = np.logical_and(joint_contact, joint_heights < height_thresh)\n\n    return joint_contact\n\n\ndef compute_root_align_mats(root_orient):\n    \"\"\"\n    Taken from\n    https://github.com/davrempe/humor/blob/main/humor/scripts/process_amass_data.py\n\n    compute world to canonical frame for each timestep (rotation around up axis)\n    \"\"\"\n    root_orient = torch.as_tensor(root_orient).reshape(-1, 3)\n    # convert aa to matrices\n    root_orient_mat = convert_rotation(root_orient, \"aa\", \"mat\").numpy()\n\n    # rotate root so aligning local body right vector (-x) with world right vector (+x)\n    #       with a rotation around the up axis (+z)\n    # in body coordinates body x-axis is to the left\n    body_right = -root_orient_mat[:, :, 0]\n    world2aligned_mat, world2aligned_aa = compute_align_from_body_right(body_right)\n\n    return world2aligned_mat\n\n\ndef compute_joint_align_mats(joint_seq):\n    \"\"\"\n    Taken from\n    https://github.com/davrempe/humor/blob/main/humor/scripts/process_amass_data.py\n\n    Compute world to canonical frame for each timestep (rotation around up axis)\n    from the given joint seq (T x J x 3)\n    \"\"\"\n    left_idx = SMPL_JOINTS[\"leftUpLeg\"]\n    right_idx = SMPL_JOINTS[\"rightUpLeg\"]\n\n    body_right = joint_seq[:, right_idx] - joint_seq[:, left_idx]\n    body_right = body_right / np.linalg.norm(body_right, axis=1)[:, np.newaxis]\n\n    world2aligned_mat, world2aligned_aa = compute_align_from_body_right(body_right)\n\n    return world2aligned_mat\n\n\ndef compute_align_from_body_right(body_right):\n    \"\"\"\n    Taken from\n    https://github.com/davrempe/humor/blob/main/humor/scripts/process_amass_data.py\n    \"\"\"\n    world2aligned_angle = np.arccos(\n        body_right[:, 0] / (np.linalg.norm(body_right[:, :2], axis=1) + 1e-8)\n    )  # project to world x axis, and compute angle\n    body_right[:, 2] = 0.0\n    world2aligned_axis = np.cross(body_right, np.array([[1.0, 0.0, 0.0]]))\n\n    world2aligned_aa = (\n        world2aligned_axis\n        / (np.linalg.norm(world2aligned_axis, axis=1)[:, np.newaxis] + 1e-8)\n    ) * world2aligned_angle[:, np.newaxis]\n\n    
world2aligned_mat = convert_rotation(\n        torch.as_tensor(world2aligned_aa).reshape(-1, 3), \"aa\", \"mat\"\n    ).numpy()\n\n    return world2aligned_mat, world2aligned_aa\n\n\ndef estimate_velocity(data_seq, h):\n    \"\"\"\n    Taken from\n    https://github.com/davrempe/humor/blob/main/humor/scripts/process_amass_data.py\n\n    Given some data sequence of T timesteps in the shape (T, ...), estimates\n    the velocity for the middle T-2 steps using a second order central difference scheme.\n    - h : step size\n    \"\"\"\n    data_tp1 = data_seq[2:]\n    data_tm1 = data_seq[0:-2]\n    data_vel_seq = (data_tp1 - data_tm1) / (2 * h)\n    return data_vel_seq\n\n\ndef estimate_angular_velocity(rot_seq, h):\n    \"\"\"\n    Taken from\n    https://github.com/davrempe/humor/blob/main/humor/scripts/process_amass_data.py\n\n    Given a sequence of T rotation matrices, estimates angular velocity at T-2 steps.\n    Input sequence should be of shape (T, ..., 3, 3)\n    \"\"\"\n    # see https://en.wikipedia.org/wiki/Angular_velocity#Calculation_from_the_orientation_matrix\n    dRdt = estimate_velocity(rot_seq, h)\n    R = rot_seq[1:-1]\n    RT = np.swapaxes(R, -1, -2)\n    # compute skew-symmetric angular velocity tensor\n    w_mat = np.matmul(dRdt, RT)\n\n    # pull out angular velocity vector\n    # average symmetric entries\n    w_x = (-w_mat[..., 1, 2] + w_mat[..., 2, 1]) / 2.0\n    w_y = (w_mat[..., 0, 2] - w_mat[..., 2, 0]) / 2.0\n    w_z = (-w_mat[..., 0, 1] + w_mat[..., 1, 0]) / 2.0\n    w = np.stack([w_x, w_y, w_z], axis=-1)\n\n    return w\n\n\ndef load_seq_smpl_params(input_path: str, num_betas: int = 16):\n    guru.info(f\"Loading from {input_path}\")\n\n    # load in input data\n    # we leave out \"dmpls\" and \"marker_data\"/\"marker_label\" which are not present in all datasets\n    bdata = np.load(input_path)\n    gender = np.array(bdata[\"gender\"], ndmin=1)[0]\n    gender = str(gender, \"utf-8\") if isinstance(gender, bytes) else str(gender)\n    fps = bdata[\"mocap_framerate\"]\n    trans = bdata[\"trans\"][:]  # global translation\n    num_frames = len(trans)\n    root_orient = bdata[\"poses\"][:, :3]  # global root orientation (1 joint)\n    pose_body = bdata[\"poses\"][:, 3:66]  # body joint rotations (21 joints)\n    pose_hand = bdata[\"poses\"][:, 66:]  # finger articulation joint rotations\n    betas = np.tile(\n        bdata[\"betas\"][None, :num_betas], [num_frames, 1]\n    )  # body shape parameters\n\n    # correct mislabeled data\n    if input_path.find(\"BMLhandball\") >= 0:\n        fps = 240\n    if input_path.find(\"20160930_50032\") >= 0 or input_path.find(\"20161014_50033\") >= 0:\n        fps = 59\n\n    model_vars = {\n        \"trans\": trans,\n        \"root_orient\": root_orient,\n        \"pose_body\": pose_body,\n        \"pose_hand\": pose_hand,\n        \"betas\": betas,\n    }\n    meta = {\"fps\": fps, \"gender\": gender, \"num_frames\": num_frames}\n    guru.info(f\"meta {meta}\")\n    guru.info(f\"model var shapes {str({k: v.shape for k, v in model_vars.items()})}\")\n    return model_vars, meta\n\n\ndef run_batch_smpl(\n    body_model: BodyModel,\n    device: torch.device,\n    num_total: int,\n    batch_size: int,\n    return_verts: bool = True,\n    **kwargs,\n):\n    var_dims = body_model.var_dims\n    var_names = [name for name in kwargs if name in var_dims]\n    model_vars = {\n        name: torch.as_tensor(kwargs[name], dtype=torch.float32).reshape(\n            -1, var_dims[name]\n        )\n        for name in var_names\n    }\n    
fopts = {k: v for k, v in kwargs.items() if k not in var_names}\n\n    batch_joints, batch_verts = [], []\n    for sidx in range(0, num_total, batch_size):\n        eidx = min(sidx + batch_size, num_total)\n        batch_model_vars = move_to(\n            {name: x[sidx:eidx].contiguous() for name, x in model_vars.items()}, device\n        )\n        with torch.no_grad():\n            joints, verts, _ = run_smpl(\n                body_model, return_verts=return_verts, **batch_model_vars, **fopts\n            )\n        batch_joints.append(joints.detach().cpu())\n        if return_verts and verts is not None:\n            batch_verts.append(verts.detach().cpu())\n\n    joints_all = torch.cat(batch_joints, dim=0)\n    verts_all = torch.cat(batch_verts, dim=0) if len(batch_verts) > 0 else None\n    return joints_all, verts_all\n\n\ndef process_seq(\n    input_path: str,\n    out_path: str,\n    smplh_root: str,\n    dev_id: int,\n    beta_neutral: bool,\n    reflect: bool = False,\n    overwrite: bool = False,\n    **kwargs,\n):\n    if not overwrite and os.path.isfile(out_path):\n        guru.info(f\"{out_path} already exists, skipping.\")\n        return\n\n    guru.info(f\"process {input_path} to {out_path}\")\n\n    model_vars, meta = load_seq_smpl_params(input_path)\n\n    if beta_neutral:  # get the gender neutral beta\n        guru.info(\"converting betas to gender neutral\")\n        A_beta, b_beta = load_neutral_beta_conversion(meta[\"gender\"])\n        model_vars[\"betas\"] = convert_gender_neutral_beta(\n            model_vars[\"betas\"], A_beta, b_beta\n        )\n        meta[\"gender\"] = \"neutral\"\n\n    process_seq_data(\n        model_vars, meta, out_path, dev_id, smplh_root, reflect=reflect, **kwargs\n    )\n\n\ndef process_seq_data(\n    model_vars: Dict,\n    meta: Dict,\n    out_path: str,\n    dev_id: int,\n    smplh_root: str,\n    reflect: bool = False,\n    split_frame_limit: int = 2000,\n    discard_shorter_than: float = 1.0,  # seconds\n    out_fps: int = 30,\n    save_verts: bool = False,\n    save_velocities: bool = True,  # save all parameter velocities available\n):\n    guru.info(f\"Processing seq with meta {meta}\")\n    start_t = time.time()\n\n    gender = meta[\"gender\"]\n    src_fps = meta[\"fps\"]\n    num_frames = meta[\"num_frames\"]\n\n    # only keep middle 80% of sequences to avoid redundanct static poses\n    sidx, eidx = int(0.1 * num_frames), int(0.9 * num_frames)\n    num_frames = eidx - sidx\n    for name, x in model_vars.items():\n        model_vars[name] = x[sidx:eidx]\n    guru.info(str({k: v.shape for k, v in model_vars.items()}))\n\n    # discard if shorter than threshold\n    if num_frames < discard_shorter_than * src_fps:\n        guru.info(f\"Sequence shorter than {discard_shorter_than} s, discarding...\")\n        return\n\n    # must do SMPL forward pass to get joints\n    # split into manageable chunks to avoid running out of GPU memory for SMPL\n    device = (\n        torch.device(f\"cuda:{dev_id}\")\n        if torch.cuda.is_available()\n        else torch.device(\"cpu\")\n    )\n\n    # <HACKS>\n    # smplx tries to read shape properties, even when use_pca=False\n    from smplx.utils import Struct\n\n    Struct.hands_componentsl = np.zeros(100)  # type: ignore\n    Struct.hands_componentsr = np.zeros(100)  # type: ignore\n    Struct.hands_meanl = np.zeros(100)  # type: ignore\n    Struct.hands_meanr = np.zeros(100)  # type: ignore\n\n    # This defaults to 300, but we have 16 beta parameters. 
When\n    # 16<300 the SMPL class will set num_betas to 10...\n    from smplx import SMPLH\n\n    assert SMPLH.SHAPE_SPACE_DIM in (300, 16)\n    SMPLH.SHAPE_SPACE_DIM = 16\n    # <HACKS>\n\n    body_model = BodyModel(f\"{smplh_root}/{gender}/model.npz\", use_pca=False).to(device)\n    model_vars = {k: torch.as_tensor(v).float() for k, v in model_vars.items()}\n    if reflect:\n        rot_og = model_vars[\"root_orient\"]\n        rot_re, model_vars[\"pose_body\"] = reflect_pose_aa(\n            rot_og, model_vars[\"pose_body\"]\n        )\n        out = body_model.forward(betas=model_vars[\"betas\"][:1].to(device))\n        root_loc = out.Jtr[:, 0].cpu()  # type: ignore\n        model_vars[\"root_orient\"], model_vars[\"trans\"] = reflect_root_trajectory(\n            rot_og, model_vars[\"trans\"], rot_re, root_loc\n        )\n\n    body_joint_seq, body_vtx_seq = run_batch_smpl(\n        body_model,\n        device,\n        num_frames,\n        split_frame_limit,\n        return_verts=save_verts,\n        **model_vars,\n    )\n    joints_glob = body_joint_seq[:, : len(SMPL_JOINTS), :]\n    joint_seq = joints_glob.numpy()\n\n    guru.info(f\"Recovered joints and verts {joint_seq.shape}\")\n\n    out_dict = model_vars.copy()\n    out_dict[\"joints\"] = joint_seq\n    out_dict[\"joints_loc\"], _ = joints_global_to_local(\n        convert_rotation(model_vars[\"root_orient\"], \"aa\", \"mat\"),\n        model_vars[\"trans\"],\n        joints_glob,\n    )\n\n    if save_verts and body_vtx_seq is not None:\n        out_dict[\"mojo_verts\"] = body_vtx_seq[:, KEYPT_VERTS, :].numpy()\n\n    # determine floor height and foot contacts\n    floor_height, contacts, discard_seq = determine_floor_height_and_contacts(\n        joint_seq, src_fps\n    )\n\n    if discard_seq:\n        guru.info(\"Terrain interaction detected, discarding...\")\n        return\n\n    guru.info(f\"Floor height: {floor_height}\")\n    # translate so floor is at z=0\n    for name in [\"trans\", \"joints\", \"mojo_verts\"]:\n        if name not in out_dict:\n            continue\n        out_dict[name][..., 2] -= floor_height\n\n    # compute rotation to canonical frame (forward facing +y) for every frame\n    world2aligned_rot = compute_root_align_mats(model_vars[\"root_orient\"])\n\n    out_dict.update(\n        {\n            \"contacts\": contacts,\n            \"floor_height\": floor_height,\n            \"world2aligned_rot\": world2aligned_rot,\n        }\n    )\n\n    # estimate various velocities based on full frame rate\n    #       with second order central differences before downsampling\n    if save_velocities:\n        h = 1.0 / src_fps\n        lin_names = [\"trans\", \"joints\", \"mojo_verts\"]\n        ang_names = [\"root_orient\", \"pose_body\"]\n        cur_keys = lin_names + ang_names + [\"contacts\"]\n\n        for name in lin_names:\n            if name not in out_dict:\n                continue\n            out_dict[f\"{name}_vel\"] = estimate_velocity(out_dict[name], h)\n\n        # root orient\n        for name in ang_names:\n            if name not in out_dict:\n                continue\n            rot_aa = (\n                torch.as_tensor(out_dict[name]).reshape(num_frames, -1, 3).squeeze()\n            )\n            rot_mat = convert_rotation(rot_aa, \"aa\", \"mat\").numpy()\n            out_dict[f\"{name}_vel\"] = estimate_angular_velocity(rot_mat, h)\n\n        # joint up-axis angular velocity (need to compute joint frames first...)\n        # need the joint transform at all steps to find the 
angular velocity\n        joints_world2aligned_rot = compute_joint_align_mats(joint_seq)\n        joint_orient_vel = -estimate_angular_velocity(joints_world2aligned_rot, h)\n        # only need around z\n        out_dict[\"joint_orient_vel\"] = joint_orient_vel[:, 2]\n\n        # throw out edge frames for other data so velocities are accurate\n        for name in cur_keys:\n            if name not in out_dict:\n                continue\n            out_dict[name] = out_dict[name][1:-1]\n        num_frames = num_frames - 2\n\n    # downsample frames\n    fps_ratio = float(out_fps) / src_fps\n    guru.info(f\"Downsamp ratio: {fps_ratio}\")\n    new_num_frames = int(fps_ratio * num_frames)\n    guru.info(f\"Downsamp num frames: {new_num_frames}\")\n    downsamp_inds = np.linspace(0, num_frames - 1, num=new_num_frames, dtype=int)\n\n    for k, v in out_dict.items():\n        # print(k, type(v))\n        if not isinstance(v, (torch.Tensor, np.ndarray)):\n            continue\n        if v.ndim >= 1:\n            # print(\"downsampling\", k)\n            out_dict[k] = v[downsamp_inds]\n\n    meta = {\n        \"fps\": out_fps,\n        \"num_frames\": new_num_frames,\n        \"gender\": str(gender),\n    }\n\n    guru.info(f\"Seq process time: {time.time() - start_t} s\")\n    guru.info(f\"Saving data to {out_path}\")\n    os.makedirs(os.path.dirname(out_path), exist_ok=True)\n    np.savez(out_path, **meta, **out_dict)\n\n\n@dataclasses.dataclass\nclass Config:\n    data_root: str\n    \"\"\"Where the AMASS dataset is stored.\"\"\"\n\n    smplh_root: str = \"./data/smplh\"\n    out_root: str = \"./data/processed_30fps_no_skating/\"\n    devices: tuple[int, ...] = (0,)\n    \"\"\"CUDA devices. We use CPU if not available.\"\"\"\n    overwrite: bool = False\n\n\ndef check_skip(path_name: str) -> bool:\n    \"\"\"Copied conditions from https://github.com/davrempe/humor/blob/main/humor/scripts/cleanup_amass_data.py\"\"\"\n    if \"BioMotionLab_NTroje\" in path_name and (\n        \"treadmill\" in path_name or \"normal_\" in path_name\n    ):\n        return True\n    if \"MPI_HDM05\" in path_name and \"dg/HDM_dg_07-01\" in path_name:\n        return True\n    return False\n\n\ndef main(cfg: Config):\n    dsets = AMASS_SPLITS[\"all\"]\n    paths_to_process = []\n    for dset in dsets:\n        paths_to_process.extend(\n            map(str, Path(f\"{cfg.data_root}/{dset}\").glob(\"**/*_poses.npz\"))\n        )\n\n    dev_ids = cfg.devices\n    guru.info(f\"devices {dev_ids}\")\n\n    if len(dev_ids) <= 1:\n        guru.info(\"processing in sequence\")\n        for i, path in tqdm(enumerate(paths_to_process)):\n            if check_skip(path):\n                guru.info(f\"skipping {path}\")\n                continue\n            fname = path.split(cfg.data_root)[-1].rstrip(\"/\")\n            name, ext = os.path.splitext(fname)\n            out_path = f\"{cfg.out_root}/neutral/{name}{ext}\"\n            r_out_path = f\"{cfg.out_root}/neutral/{name}_reflect{ext}\"\n            process_seq(\n                path,\n                out_path,\n                cfg.smplh_root,\n                dev_ids[i % len(dev_ids)],\n                beta_neutral=True,\n                reflect=False,\n                overwrite=cfg.overwrite,\n            )\n            process_seq(\n                path,\n                r_out_path,\n                cfg.smplh_root,\n                dev_ids[i % len(dev_ids)],\n                beta_neutral=True,\n                reflect=True,\n                overwrite=cfg.overwrite,\n     
       )\n        return\n\n    with ProcessPoolExecutor(max_workers=len(dev_ids)) as exe:\n        for i, path in tqdm(enumerate(paths_to_process)):\n            if check_skip(path):\n                guru.info(f\"skipping {path}\")\n                continue\n            fname = path.split(cfg.data_root)[-1].rstrip(\"/\")\n            name, ext = os.path.splitext(fname)\n            out_path = f\"{cfg.out_root}/neutral/{name}{ext}\"\n            r_out_path = f\"{cfg.out_root}/neutral/{name}_reflect{ext}\"\n            exe.submit(\n                process_seq,\n                path,\n                out_path,\n                cfg.smplh_root,\n                dev_ids[i % len(dev_ids)],\n                beta_neutral=True,\n                reflect=False,\n                overwrite=cfg.overwrite,\n            )\n            exe.submit(\n                process_seq,\n                path,\n                r_out_path,\n                cfg.smplh_root,\n                dev_ids[i % len(dev_ids)],\n                beta_neutral=True,\n                reflect=True,\n                overwrite=cfg.overwrite,\n            )\n\n\nif __name__ == \"__main__\":\n    main(tyro.cli(Config))\n"
  },
  {
    "path": "0b_preprocess_training_data.py",
    "content": "\"\"\"Translate data from HuMoR-style npz format to an hdf5-based one.\n\nDue to AMASS licensing, we unfortunately can't re-distribute our preprocessed dataset. If you have questions\nor run into issues, please reach out.\n\"\"\"\n\nimport queue\nimport threading\nimport time\nfrom pathlib import Path\n\nimport h5py\nimport torch\nimport torch.cuda\nimport tyro\n\nfrom egoallo import fncsmpl\nfrom egoallo.data.amass import EgoTrainingData\n\n\ndef main(\n    smplh_npz_path: Path = Path(\"./data/smplh/neutral/model.npz\"),\n    data_npz_dir: Path = Path(\"./data/processed_30fps_no_skating/\"),\n    output_file: Path = Path(\"./data/egoalgo_no_skating_dataset.hdf5\"),\n    output_list_file: Path = Path(\"./data/egoalgo_no_skating_dataset_files.txt\"),\n    include_hands: bool = True,\n) -> None:\n    body_model = fncsmpl.SmplhModel.load(smplh_npz_path)\n\n    assert torch.cuda.is_available()\n\n    task_queue = queue.Queue[Path]()\n    for path in list(data_npz_dir.glob(\"**/*.npz\")):\n        task_queue.put_nowait(path)\n\n    total_count = task_queue.qsize()\n    start_time = time.time()\n\n    output_hdf5 = h5py.File(output_file, \"w\")\n    file_list: list[str] = []\n\n    def worker(device_idx: int) -> None:\n        device_body_model = body_model.to(\"cuda:\" + str(device_idx))\n\n        while True:\n            try:\n                npz_path = task_queue.get_nowait()\n            except queue.Empty:\n                break\n\n            print(f\"Processing {npz_path} on device {device_idx}...\")\n            train_data = EgoTrainingData.load_from_npz(\n                device_body_model, npz_path, include_hands=include_hands\n            )\n\n            assert \"neutral\" in str(npz_path)\n            group_name = str(npz_path).rpartition(\"neutral/\")[2]\n            print(f\"Writing to group {group_name} on {device_idx}...\")\n            group = output_hdf5.create_group(group_name)\n            file_list.append(group_name)\n\n            for k, v in vars(train_data).items():\n                # No need to write the mask, which will always be ones when we\n                # load from the npz file!\n                if k == \"mask\":\n                    continue\n\n                # Chunk into 32 timesteps at a time.\n                assert v.dtype == torch.float32\n                if v.shape[0] == train_data.T_world_cpf.shape[0]:\n                    chunks = (min(32, v.shape[0]),) + v.shape[1:]\n                else:\n                    assert v.shape[0] == 1\n                    chunks = v.shape\n                group.create_dataset(k, data=v.numpy(force=True), chunks=chunks)\n\n            print(\n                f\"Finished ~{total_count - task_queue.qsize()}/{total_count},\",\n                f\"{(total_count - task_queue.qsize()) / total_count * 100:.2f}% in\",\n                f\"{time.time() - start_time} seconds\",\n            )\n\n    workers = [\n        threading.Thread(target=worker, args=(i,))\n        for i in range(torch.cuda.device_count())\n    ]\n    for w in workers:\n        w.start()\n    for w in workers:\n        w.join()\n    output_list_file.write_text(\"\\n\".join(file_list))\n\n\nif __name__ == \"__main__\":\n    tyro.cli(main)\n"
  },
  {
    "path": "1_train_motion_prior.py",
    "content": "\"\"\"Training script for EgoAllo diffusion model using HuggingFace accelerate.\"\"\"\n\nimport dataclasses\nimport shutil\nfrom pathlib import Path\nfrom typing import Literal\n\nimport tensorboardX\nimport torch.optim.lr_scheduler\nimport torch.utils.data\nimport tyro\nimport yaml\nfrom accelerate import Accelerator, DataLoaderConfiguration\nfrom accelerate.utils import ProjectConfiguration\nfrom loguru import logger\n\nfrom egoallo import network, training_loss, training_utils\nfrom egoallo.data.amass import EgoAmassHdf5Dataset\nfrom egoallo.data.dataclass import collate_dataclass\n\n\n@dataclasses.dataclass(frozen=True)\nclass EgoAlloTrainConfig:\n    experiment_name: str\n    dataset_hdf5_path: Path\n    dataset_files_path: Path\n\n    model: network.EgoDenoiserConfig = network.EgoDenoiserConfig()\n    loss: training_loss.TrainingLossConfig = training_loss.TrainingLossConfig()\n\n    # Dataset arguments.\n    batch_size: int = 256\n    \"\"\"Effective batch size.\"\"\"\n    num_workers: int = 2\n    subseq_len: int = 128\n    dataset_slice_strategy: Literal[\n        \"deterministic\", \"random_uniform_len\", \"random_variable_len\"\n    ] = \"random_uniform_len\"\n    dataset_slice_random_variable_len_proportion: float = 0.3\n    \"\"\"Only used if dataset_slice_strategy == 'random_variable_len'.\"\"\"\n    train_splits: tuple[Literal[\"train\", \"val\", \"test\", \"just_humaneva\"], ...] = (\n        \"train\",\n        \"val\",\n    )\n\n    # Optimizer options.\n    learning_rate: float = 1e-4\n    weight_decay: float = 1e-4\n    warmup_steps: int = 1000\n    max_grad_norm: float = 1.0\n\n\ndef get_experiment_dir(experiment_name: str, version: int = 0) -> Path:\n    \"\"\"Creates a directory to put experiment files in, suffixed with a version\n    number. 
Similar to PyTorch lightning.\"\"\"\n    experiment_dir = (\n        Path(__file__).absolute().parent\n        / \"experiments\"\n        / experiment_name\n        / f\"v{version}\"\n    )\n    if experiment_dir.exists():\n        return get_experiment_dir(experiment_name, version + 1)\n    else:\n        return experiment_dir\n\n\ndef run_training(\n    config: EgoAlloTrainConfig,\n    restore_checkpoint_dir: Path | None = None,\n) -> None:\n    # Set up experiment directory + HF accelerate.\n    # We're getting to manage logging, checkpoint directories, etc manually,\n    # and just use `accelerate` for distibuted training.\n    experiment_dir = get_experiment_dir(config.experiment_name)\n    assert not experiment_dir.exists()\n    accelerator = Accelerator(\n        project_config=ProjectConfiguration(project_dir=str(experiment_dir)),\n        dataloader_config=DataLoaderConfiguration(split_batches=True),\n    )\n    writer = (\n        tensorboardX.SummaryWriter(logdir=str(experiment_dir), flush_secs=10)\n        if accelerator.is_main_process\n        else None\n    )\n    device = accelerator.device\n\n    # Initialize experiment.\n    if accelerator.is_main_process:\n        training_utils.pdb_safety_net()\n\n        # Save various things that might be useful.\n        experiment_dir.mkdir(exist_ok=True, parents=True)\n        (experiment_dir / \"git_commit.txt\").write_text(\n            training_utils.get_git_commit_hash()\n        )\n        (experiment_dir / \"git_diff.txt\").write_text(training_utils.get_git_diff())\n        (experiment_dir / \"run_config.yaml\").write_text(yaml.dump(config))\n        (experiment_dir / \"model_config.yaml\").write_text(yaml.dump(config.model))\n\n        # Add hyperparameters to TensorBoard.\n        assert writer is not None\n        writer.add_hparams(\n            hparam_dict=training_utils.flattened_hparam_dict_from_dataclass(config),\n            metric_dict={},\n            name=\".\",  # Hack to avoid timestamped subdirectory.\n        )\n\n        # Write logs to file.\n        logger.add(experiment_dir / \"trainlog.log\", rotation=\"100 MB\")\n\n    # Setup.\n    model = network.EgoDenoiser(config.model)\n    train_loader = torch.utils.data.DataLoader(\n        dataset=EgoAmassHdf5Dataset(\n            config.dataset_hdf5_path,\n            config.dataset_files_path,\n            splits=config.train_splits,\n            subseq_len=config.subseq_len,\n            cache_files=True,\n            slice_strategy=config.dataset_slice_strategy,\n            random_variable_len_proportion=config.dataset_slice_random_variable_len_proportion,\n        ),\n        batch_size=config.batch_size,\n        shuffle=True,\n        num_workers=config.num_workers,\n        persistent_workers=config.num_workers > 0,\n        pin_memory=True,\n        collate_fn=collate_dataclass,\n        drop_last=True,\n    )\n    optim = torch.optim.AdamW(  # type: ignore\n        model.parameters(),\n        lr=config.learning_rate,\n        weight_decay=config.weight_decay,\n    )\n    scheduler = torch.optim.lr_scheduler.LambdaLR(\n        optim, lr_lambda=lambda step: min(1.0, step / config.warmup_steps)\n    )\n\n    # HF accelerate setup. 
We use this for parallelism, etc!\n    model, train_loader, optim, scheduler = accelerator.prepare(\n        model, train_loader, optim, scheduler\n    )\n    accelerator.register_for_checkpointing(scheduler)\n\n    # Restore an existing model checkpoint.\n    if restore_checkpoint_dir is not None:\n        accelerator.load_state(str(restore_checkpoint_dir))\n\n    # Get the initial step count.\n    if restore_checkpoint_dir is not None and restore_checkpoint_dir.name.startswith(\n        \"checkpoint_\"\n    ):\n        step = int(restore_checkpoint_dir.name.partition(\"_\")[2])\n    else:\n        step = int(scheduler.state_dict()[\"last_epoch\"])\n        assert step == 0 or restore_checkpoint_dir is not None, step\n\n    # Save an initial checkpoint. Not a big deal but currently this has an\n    # off-by-one error, in that `step` means something different in this\n    # checkpoint vs the others.\n    accelerator.save_state(str(experiment_dir / f\"checkpoints_{step}\"))\n\n    # Run training loop!\n    loss_helper = training_loss.TrainingLossComputer(config.loss, device=device)\n    loop_metrics_gen = training_utils.loop_metric_generator(counter_init=step)\n    prev_checkpoint_path: Path | None = None\n    while True:\n        for train_batch in train_loader:\n            loop_metrics = next(loop_metrics_gen)\n            step = loop_metrics.counter\n\n            loss, log_outputs = loss_helper.compute_denoising_loss(\n                model,\n                unwrapped_model=accelerator.unwrap_model(model),\n                train_batch=train_batch,\n            )\n            log_outputs[\"learning_rate\"] = scheduler.get_last_lr()[0]\n            accelerator.log(log_outputs, step=step)\n            accelerator.backward(loss)\n            if accelerator.sync_gradients:\n                accelerator.clip_grad_norm_(model.parameters(), config.max_grad_norm)\n            optim.step()\n            scheduler.step()\n            optim.zero_grad(set_to_none=True)\n\n            # The rest of the loop will only be executed by the main process.\n            if not accelerator.is_main_process:\n                continue\n\n            # Logging.\n            if step % 10 == 0:\n                assert writer is not None\n                for k, v in log_outputs.items():\n                    writer.add_scalar(k, v, step)\n\n            # Print status update to terminal.\n            if step % 20 == 0:\n                mem_free, mem_total = torch.cuda.mem_get_info()\n                logger.info(\n                    f\"step: {step} ({loop_metrics.iterations_per_sec:.2f} it/sec)\"\n                    f\" mem: {(mem_total - mem_free) / 1024**3:.2f}/{mem_total / 1024**3:.2f}G\"\n                    f\" lr: {scheduler.get_last_lr()[0]:.7f}\"\n                    f\" loss: {loss.item():.6f}\"\n                )\n\n            # Checkpointing.\n            if step % 5000 == 0:\n                # Save checkpoint.\n                checkpoint_path = experiment_dir / f\"checkpoints_{step}\"\n                accelerator.save_state(str(checkpoint_path))\n                logger.info(f\"Saved checkpoint to {checkpoint_path}\")\n\n                # Keep checkpoints from only every 100k steps.\n                if prev_checkpoint_path is not None:\n                    shutil.rmtree(prev_checkpoint_path)\n                prev_checkpoint_path = None if step % 100_000 == 0 else checkpoint_path\n                del checkpoint_path\n\n\nif __name__ == \"__main__\":\n    tyro.cli(run_training)\n"
  },
  {
    "path": "2_run_hamer_on_vrs.py",
    "content": "\"\"\"Script to run HaMeR on VRS data and save outputs to a pickle file.\"\"\"\n\nimport pickle\nimport shutil\nfrom pathlib import Path\n\nimport cv2\nimport imageio.v3 as iio\nimport numpy as np\nimport tyro\nfrom egoallo.hand_detection_structs import (\n    SavedHamerOutputs,\n    SingleHandHamerOutputWrtCamera,\n)\nfrom hamer_helper import HamerHelper\nfrom projectaria_tools.core import calibration\nfrom projectaria_tools.core.data_provider import (\n    VrsDataProvider,\n    create_vrs_data_provider,\n)\nfrom tqdm.auto import tqdm\n\nfrom egoallo.inference_utils import InferenceTrajectoryPaths\n\n\ndef main(traj_root: Path, overwrite: bool = False) -> None:\n    \"\"\"Run HaMeR for on trajectory. We'll save outputs to\n    `traj_root/hamer_outputs.pkl` and `traj_root/hamer_outputs_render\".\n\n    Arguments:\n        traj_root: The root directory of the trajectory. We assume that there's\n            a VRS file in this directory.\n        overwrite: If True, overwrite any existing HaMeR outputs.\n    \"\"\"\n\n    paths = InferenceTrajectoryPaths.find(traj_root)\n\n    vrs_path = paths.vrs_file\n    assert vrs_path.exists()\n    pickle_out = traj_root / \"hamer_outputs.pkl\"\n    hamer_render_out = traj_root / \"hamer_outputs_render\"  # This is just for debugging.\n    run_hamer_and_save(vrs_path, pickle_out, hamer_render_out, overwrite)\n\n\ndef run_hamer_and_save(\n    vrs_path: Path, pickle_out: Path, hamer_render_out: Path, overwrite: bool\n) -> None:\n    if not overwrite:\n        assert not pickle_out.exists()\n        assert not hamer_render_out.exists()\n    else:\n        pickle_out.unlink(missing_ok=True)\n        shutil.rmtree(hamer_render_out, ignore_errors=True)\n\n    hamer_render_out.mkdir(exist_ok=True)\n    hamer_helper = HamerHelper()\n\n    # VRS data provider setup.\n    provider = create_vrs_data_provider(str(vrs_path.absolute()))\n    assert isinstance(provider, VrsDataProvider)\n    rgb_stream_id = provider.get_stream_id_from_label(\"camera-rgb\")\n    assert rgb_stream_id is not None\n\n    num_images = provider.get_num_data(rgb_stream_id)\n    print(f\"Found {num_images=}\")\n\n    # Get calibrations.\n    device_calib = provider.get_device_calibration()\n    assert device_calib is not None\n    camera_calib = device_calib.get_camera_calib(\"camera-rgb\")\n    assert camera_calib is not None\n    pinhole = calibration.get_linear_camera_calibration(1408, 1408, 450)\n\n    # Compute camera extrinsics!\n    sophus_T_device_camera = device_calib.get_transform_device_sensor(\"camera-rgb\")\n    sophus_T_cpf_camera = device_calib.get_transform_cpf_sensor(\"camera-rgb\")\n    assert sophus_T_device_camera is not None\n    assert sophus_T_cpf_camera is not None\n    T_device_cam = np.concatenate(\n        [\n            sophus_T_device_camera.rotation().to_quat().squeeze(axis=0),\n            sophus_T_device_camera.translation().squeeze(axis=0),\n        ]\n    )\n    T_cpf_cam = np.concatenate(\n        [\n            sophus_T_cpf_camera.rotation().to_quat().squeeze(axis=0),\n            sophus_T_cpf_camera.translation().squeeze(axis=0),\n        ]\n    )\n    assert T_device_cam.shape == T_cpf_cam.shape == (7,)\n\n    # Dict from capture timestamp in nanoseconds to fields we care about.\n    detections_left_wrt_cam: dict[int, SingleHandHamerOutputWrtCamera | None] = {}\n    detections_right_wrt_cam: dict[int, SingleHandHamerOutputWrtCamera | None] = {}\n\n    pbar = tqdm(range(num_images))\n    for i in pbar:\n        image_data, image_data_record 
= provider.get_image_data_by_index(\n            rgb_stream_id, i\n        )\n        undistorted_image = calibration.distort_by_calibration(\n            image_data.to_numpy_array(), pinhole, camera_calib\n        )\n\n        hamer_out_left, hamer_out_right = hamer_helper.look_for_hands(\n            undistorted_image,\n            focal_length=450,\n        )\n        timestamp_ns = image_data_record.capture_timestamp_ns\n\n        if hamer_out_left is None:\n            detections_left_wrt_cam[timestamp_ns] = None\n        else:\n            detections_left_wrt_cam[timestamp_ns] = {\n                \"verts\": hamer_out_left[\"verts\"],\n                \"keypoints_3d\": hamer_out_left[\"keypoints_3d\"],\n                \"mano_hand_pose\": hamer_out_left[\"mano_hand_pose\"],\n                \"mano_hand_betas\": hamer_out_left[\"mano_hand_betas\"],\n                \"mano_hand_global_orient\": hamer_out_left[\"mano_hand_global_orient\"],\n            }\n\n        if hamer_out_right is None:\n            detections_right_wrt_cam[timestamp_ns] = None\n        else:\n            detections_right_wrt_cam[timestamp_ns] = {\n                \"verts\": hamer_out_right[\"verts\"],\n                \"keypoints_3d\": hamer_out_right[\"keypoints_3d\"],\n                \"mano_hand_pose\": hamer_out_right[\"mano_hand_pose\"],\n                \"mano_hand_betas\": hamer_out_right[\"mano_hand_betas\"],\n                \"mano_hand_global_orient\": hamer_out_right[\"mano_hand_global_orient\"],\n            }\n\n        composited = undistorted_image\n        composited = hamer_helper.composite_detections(\n            composited,\n            hamer_out_left,\n            border_color=(255, 100, 100),\n            focal_length=450,\n        )\n        composited = hamer_helper.composite_detections(\n            composited,\n            hamer_out_right,\n            border_color=(100, 100, 255),\n            focal_length=450,\n        )\n        composited = put_text(\n            composited,\n            \"L detections: \"\n            + (\n                \"0\" if hamer_out_left is None else str(hamer_out_left[\"verts\"].shape[0])\n            ),\n            0,\n            color=(255, 100, 100),\n            font_scale=10.0 / 2880.0 * undistorted_image.shape[0],\n        )\n        composited = put_text(\n            composited,\n            \"R detections: \"\n            + (\n                \"0\"\n                if hamer_out_right is None\n                else str(hamer_out_right[\"verts\"].shape[0])\n            ),\n            1,\n            color=(100, 100, 255),\n            font_scale=10.0 / 2880.0 * undistorted_image.shape[0],\n        )\n        composited = put_text(\n            composited,\n            f\"ns={timestamp_ns}\",\n            2,\n            color=(255, 255, 255),\n            font_scale=10.0 / 2880.0 * undistorted_image.shape[0],\n        )\n\n        print(f\"Saving image {i:06d} to {hamer_render_out / f'{i:06d}.jpeg'}\")\n        iio.imwrite(\n            str(hamer_render_out / f\"{i:06d}.jpeg\"),\n            np.concatenate(\n                [\n                    # Darken input image, just for contrast...\n                    (undistorted_image * 0.6).astype(np.uint8),\n                    composited,\n                ],\n                axis=1,\n            ),\n            quality=90,\n        )\n\n    outputs = SavedHamerOutputs(\n        mano_faces_right=hamer_helper.get_mano_faces(\"right\"),\n        mano_faces_left=hamer_helper.get_mano_faces(\"left\"),\n     
   detections_right_wrt_cam=detections_right_wrt_cam,\n        detections_left_wrt_cam=detections_left_wrt_cam,\n        T_device_cam=T_device_cam,\n        T_cpf_cam=T_cpf_cam,\n    )\n    with open(pickle_out, \"wb\") as f:\n        pickle.dump(outputs, f)\n\n\ndef put_text(\n    image: np.ndarray,\n    text: str,\n    line_number: int,\n    color: tuple[int, int, int],\n    font_scale: float,\n) -> np.ndarray:\n    \"\"\"Put some text on the top-left corner of an image.\"\"\"\n    image = image.copy()\n    font = cv2.FONT_HERSHEY_PLAIN\n    cv2.putText(\n        image,\n        text=text,\n        org=(2, 1 + int(15 * font_scale * (line_number + 1))),\n        fontFace=font,\n        fontScale=font_scale,\n        color=(0, 0, 0),\n        thickness=max(int(font_scale), 1),\n        lineType=cv2.LINE_AA,\n    )\n    cv2.putText(\n        image,\n        text=text,\n        org=(2, 1 + int(15 * font_scale * (line_number + 1))),\n        fontFace=font,\n        fontScale=font_scale,\n        color=color,\n        thickness=max(int(font_scale), 1),\n        lineType=cv2.LINE_AA,\n    )\n    return image\n\n\nif __name__ == \"__main__\":\n    tyro.cli(main)\n"
  },
  {
    "path": "3_aria_inference.py",
    "content": "from __future__ import annotations\n\nimport dataclasses\nimport time\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nimport viser\nimport yaml\n\nfrom egoallo import fncsmpl, fncsmpl_extensions\nfrom egoallo.data.aria_mps import load_point_cloud_and_find_ground\nfrom egoallo.guidance_optimizer_jax import GuidanceMode\nfrom egoallo.hand_detection_structs import (\n    CorrespondedAriaHandWristPoseDetections,\n    CorrespondedHamerDetections,\n)\nfrom egoallo.inference_utils import (\n    InferenceInputTransforms,\n    InferenceTrajectoryPaths,\n    load_denoiser,\n)\nfrom egoallo.sampling import run_sampling_with_stitching\nfrom egoallo.transforms import SE3, SO3\nfrom egoallo.vis_helpers import visualize_traj_and_hand_detections\n\n\n@dataclasses.dataclass\nclass Args:\n    traj_root: Path\n    \"\"\"Search directory for trajectories. This should generally be laid out as something like:\n\n    traj_dir/\n        video.vrs\n        egoallo_outputs/\n            {date}_{start_index}-{end_index}.npz\n            ...\n        ...\n    \"\"\"\n    checkpoint_dir: Path = Path(\"./egoallo_checkpoint_april13/checkpoints_3000000/\")\n    smplh_npz_path: Path = Path(\"./data/smplh/neutral/model.npz\")\n\n    glasses_x_angle_offset: float = 0.0\n    \"\"\"Rotate the CPF poses by some X angle.\"\"\"\n    start_index: int = 0\n    \"\"\"Index within the downsampled trajectory to start inference at.\"\"\"\n    traj_length: int = 128\n    \"\"\"How many timesteps to estimate body motion for.\"\"\"\n    num_samples: int = 1\n    \"\"\"Number of samples to take.\"\"\"\n    guidance_mode: GuidanceMode = \"aria_hamer\"\n    \"\"\"Which guidance mode to use.\"\"\"\n    guidance_inner: bool = True\n    \"\"\"Whether to apply guidance optimizer between denoising steps. This is\n    important if we're doing anything with hands. 
It can be turned off to speed\n    up debugging/experiments, or if we only care about foot skating losses.\"\"\"\n    guidance_post: bool = True\n    \"\"\"Whether to apply guidance optimizer after diffusion sampling.\"\"\"\n    save_traj: bool = True\n    \"\"\"Whether to save the output trajectory, which will be placed under `traj_dir/egoallo_outputs/some_name.npz`.\"\"\"\n    visualize_traj: bool = False\n    \"\"\"Whether to visualize the trajectory after sampling.\"\"\"\n\n\ndef main(args: Args) -> None:\n    device = torch.device(\"cuda\")\n\n    traj_paths = InferenceTrajectoryPaths.find(args.traj_root)\n    if traj_paths.splat_path is not None:\n        print(\"Found splat at\", traj_paths.splat_path)\n    else:\n        print(\"No scene splat found.\")\n    # Get point cloud + floor.\n    points_data, floor_z = load_point_cloud_and_find_ground(traj_paths.points_path)\n\n    # Read transforms from VRS / MPS, downsampled.\n    transforms = InferenceInputTransforms.load(\n        traj_paths.vrs_file, traj_paths.slam_root_dir, fps=30\n    ).to(device=device)\n\n    # Note the off-by-one for Ts_world_cpf, which we need for relative transform computation.\n    Ts_world_cpf = (\n        SE3(\n            transforms.Ts_world_cpf[\n                args.start_index : args.start_index + args.traj_length + 1\n            ]\n        )\n        @ SE3.from_rotation(\n            SO3.from_x_radians(\n                transforms.Ts_world_cpf.new_tensor(args.glasses_x_angle_offset)\n            )\n        )\n    ).parameters()\n    pose_timestamps_sec = transforms.pose_timesteps[\n        args.start_index + 1 : args.start_index + args.traj_length + 1\n    ]\n    Ts_world_device = transforms.Ts_world_device[\n        args.start_index + 1 : args.start_index + args.traj_length + 1\n    ]\n    del transforms\n\n    # Get temporally corresponded HaMeR detections.\n    if traj_paths.hamer_outputs is not None:\n        hamer_detections = CorrespondedHamerDetections.load(\n            traj_paths.hamer_outputs,\n            pose_timestamps_sec,\n        ).to(device)\n    else:\n        print(\"No hand detections found.\")\n        hamer_detections = None\n\n    # Get temporally corresponded Aria wrist and palm estimates.\n    if traj_paths.wrist_and_palm_poses_csv is not None:\n        aria_detections = CorrespondedAriaHandWristPoseDetections.load(\n            traj_paths.wrist_and_palm_poses_csv,\n            pose_timestamps_sec,\n            Ts_world_device=Ts_world_device.numpy(force=True),\n        ).to(device)\n    else:\n        print(\"No Aria hand detections found.\")\n        aria_detections = None\n\n    print(f\"{Ts_world_cpf.shape=}\")\n\n    server = None\n    if args.visualize_traj:\n        server = viser.ViserServer()\n        server.gui.configure_theme(dark_mode=True)\n\n    denoiser_network = load_denoiser(args.checkpoint_dir).to(device)\n    body_model = fncsmpl.SmplhModel.load(args.smplh_npz_path).to(device)\n\n    traj = run_sampling_with_stitching(\n        denoiser_network,\n        body_model=body_model,\n        guidance_mode=args.guidance_mode,\n        guidance_inner=args.guidance_inner,\n        guidance_post=args.guidance_post,\n        Ts_world_cpf=Ts_world_cpf,\n        hamer_detections=hamer_detections,\n        aria_detections=aria_detections,\n        num_samples=args.num_samples,\n        device=device,\n        floor_z=floor_z,\n    )\n\n    # Save outputs in case we want to visualize later.\n    if args.save_traj:\n        save_name = (\n            
time.strftime(\"%Y%m%d-%H%M%S\")\n            + f\"_{args.start_index}-{args.start_index + args.traj_length}\"\n        )\n        out_path = args.traj_root / \"egoallo_outputs\" / (save_name + \".npz\")\n        out_path.parent.mkdir(parents=True, exist_ok=True)\n        assert not out_path.exists()\n        (args.traj_root / \"egoallo_outputs\" / (save_name + \"_args.yaml\")).write_text(\n            yaml.dump(dataclasses.asdict(args))\n        )\n\n        posed = traj.apply_to_body(body_model)\n        Ts_world_root = fncsmpl_extensions.get_T_world_root_from_cpf_pose(\n            posed, Ts_world_cpf[..., 1:, :]\n        )\n        print(f\"Saving to {out_path}...\", end=\"\")\n        np.savez(\n            out_path,\n            Ts_world_cpf=Ts_world_cpf[1:, :].numpy(force=True),\n            Ts_world_root=Ts_world_root.numpy(force=True),\n            body_quats=posed.local_quats[..., :21, :].numpy(force=True),\n            left_hand_quats=posed.local_quats[..., 21:36, :].numpy(force=True),\n            right_hand_quats=posed.local_quats[..., 36:51, :].numpy(force=True),\n            contacts=traj.contacts.numpy(force=True),  # Sometimes we forgot this...\n            betas=traj.betas.numpy(force=True),\n            frame_nums=np.arange(args.start_index, args.start_index + args.traj_length),\n            timestamps_ns=(np.array(pose_timestamps_sec) * 1e9).astype(np.int64),\n        )\n        print(\"saved!\")\n\n    # Visualize.\n    if args.visualize_traj:\n        assert server is not None\n        loop_cb = visualize_traj_and_hand_detections(\n            server,\n            Ts_world_cpf[1:],\n            traj,\n            body_model,\n            hamer_detections,\n            aria_detections,\n            points_data=points_data,\n            splat_path=traj_paths.splat_path,\n            floor_z=floor_z,\n        )\n        while True:\n            loop_cb()\n\n\nif __name__ == \"__main__\":\n    import tyro\n\n    main(tyro.cli(Args))\n"
  },
  {
    "path": "4_visualize_outputs.py",
    "content": "from __future__ import annotations\n\nimport io\nfrom pathlib import Path\nfrom typing import Callable\n\nimport cv2\nimport imageio.v3 as iio\nimport numpy as np\nimport torch\nimport tyro\nimport viser\nfrom projectaria_tools.core.data_provider import (\n    VrsDataProvider,\n    create_vrs_data_provider,\n)\nfrom projectaria_tools.core.sensor_data import TimeDomain\nfrom tqdm import tqdm\n\nfrom egoallo import fncsmpl\nfrom egoallo.data.aria_mps import load_point_cloud_and_find_ground\nfrom egoallo.hand_detection_structs import (\n    CorrespondedAriaHandWristPoseDetections,\n    CorrespondedHamerDetections,\n)\nfrom egoallo.inference_utils import InferenceTrajectoryPaths\nfrom egoallo.network import EgoDenoiseTraj\nfrom egoallo.transforms import SE3, SO3\nfrom egoallo.vis_helpers import visualize_traj_and_hand_detections\n\n\ndef main(\n    search_root_dir: Path,\n    smplh_npz_path: Path = Path(\"./data/smplh/neutral/model.npz\"),\n) -> None:\n    \"\"\"Visualization script for outputs from EgoAllo.\n\n    Arguments:\n        search_root_dir: Root directory where inputs/outputs are stored. All\n            NPZ files in this directory will be assumed to be outputs from EgoAllo.\n        smplh_npz_path: Path to the SMPLH model NPZ file.\n    \"\"\"\n    device = torch.device(\"cuda\")\n\n    body_model = fncsmpl.SmplhModel.load(smplh_npz_path).to(device)\n\n    server = viser.ViserServer()\n    server.gui.configure_theme(dark_mode=True)\n\n    def get_file_list():\n        return [\"None\"] + sorted(\n            str(p.relative_to(search_root_dir))\n            for p in search_root_dir.glob(\"**/egoallo_outputs/*.npz\")\n        )\n\n    options = get_file_list()\n    file_dropdown = server.gui.add_dropdown(\"File\", options=options)\n\n    refresh_file_list = server.gui.add_button(\"Refresh File List\")\n\n    @refresh_file_list.on_click\n    def _(_) -> None:\n        file_dropdown.options = get_file_list()\n\n    trajectory_folder = server.gui.add_folder(\"Trajectory\")\n\n    current_file = \"None\"\n    loop_cb = lambda: None\n\n    while True:\n        loop_cb()\n        if current_file != file_dropdown.value:\n            current_file = file_dropdown.value\n\n            # Clear the scene.\n            server.scene.reset()\n\n            if current_file != \"None\":\n                # Clear the folder by removing then re-adding it.\n                # Perhaps we should expose some API for looping through children?\n                trajectory_folder.remove()\n                trajectory_folder = server.gui.add_folder(\"Trajectory\")\n\n                with trajectory_folder:\n                    npz_path = Path(search_root_dir / current_file).resolve()\n                    loop_cb = load_and_visualize(\n                        server,\n                        npz_path,\n                        body_model,\n                        device=device,\n                    )\n                    args = npz_path.parent / (npz_path.stem + \"_args.yaml\")\n                    if args.exists():\n                        with server.gui.add_folder(\"Args\"):\n                            server.gui.add_markdown(\n                                \"```\\n\" + args.read_text() + \"\\n```\"\n                            )\n\n\ndef load_and_visualize(\n    server: viser.ViserServer,\n    npz_path: Path,\n    body_model: fncsmpl.SmplhModel,\n    device: torch.device,\n) -> Callable[[], int]:\n    # Here's how we saved:\n    #\n    # np.savez(\n    #     out_path,\n    #     
Ts_world_cpf=Ts_world_cpf[1:, :].numpy(force=True),\n    #     Ts_world_root=Ts_world_root.numpy(force=True),\n    #     body_quats=posed.local_quats[..., :21, :].numpy(force=True),\n    #     left_hand_quats=posed.local_quats[..., 21:36, :].numpy(force=True),\n    #     right_hand_quats=posed.local_quats[..., 36:51, :].numpy(force=True),\n    #     betas=traj.betas.numpy(force=True),\n    #     frame_nums=np.arange(args.start_index, args.start_index + args.traj_length),\n    #     timestamps_ns=(np.array(pose_timestamps_sec) * 1e9).astype(np.int64),\n    # )\n    outputs = np.load(npz_path)\n    expected_keys = [\n        \"Ts_world_cpf\",\n        \"Ts_world_root\",\n        \"body_quats\",\n        \"left_hand_quats\",\n        \"right_hand_quats\",\n        \"betas\",\n        \"frame_nums\",\n        \"timestamps_ns\",\n    ]\n    assert all(key in outputs for key in expected_keys), (\n        f\"Missing keys in NPZ file. Expected: {expected_keys}, Found: {list(outputs.keys())}\"\n    )\n    (num_samples, timesteps, _, _) = outputs[\"body_quats\"].shape\n\n    # We assume the directory structure is:\n    # - some trajectory root\n    #     - outputs\n    #         -  the npz file\n    traj_dir = npz_path.resolve().parent.parent\n    paths = InferenceTrajectoryPaths.find(traj_dir)\n\n    provider = create_vrs_data_provider(str(paths.vrs_file))\n    device_calib = provider.get_device_calibration()\n    T_device_cpf = SE3(\n        torch.from_numpy(\n            device_calib.get_transform_device_cpf().to_quat_and_translation()\n        )\n    )\n    assert T_device_cpf.wxyz_xyz.shape == (1, 7)\n    pose_timestamps_sec = outputs[\"timestamps_ns\"] / 1e9\n\n    Ts_world_device = (\n        SE3(torch.from_numpy(outputs[\"Ts_world_cpf\"])) @ T_device_cpf.inverse()\n    ).wxyz_xyz\n\n    # Get temporally corresponded HaMeR detections.\n    if paths.hamer_outputs is not None:\n        hamer_detections = CorrespondedHamerDetections.load(\n            paths.hamer_outputs,\n            pose_timestamps_sec,\n        )\n    else:\n        print(\"No hand detections found.\")\n        hamer_detections = None\n\n    # Get temporally corresponded Aria wrist and palm estimates.\n    if paths.wrist_and_palm_poses_csv is not None:\n        aria_detections = CorrespondedAriaHandWristPoseDetections.load(\n            paths.wrist_and_palm_poses_csv,\n            pose_timestamps_sec,\n            Ts_world_device=Ts_world_device.numpy(force=True),\n        )\n    else:\n        aria_detections = None\n\n    if paths.splat_path is not None:\n        print(\"Found splat at\", paths.splat_path)\n    else:\n        print(\"No scene splat found.\")\n\n    # Get point cloud + floor.\n    points_data, floor_z = load_point_cloud_and_find_ground(\n        paths.points_path, \"filtered\"\n    )\n\n    traj = EgoDenoiseTraj(\n        betas=torch.from_numpy(outputs[\"betas\"]).to(device),\n        body_rotmats=SO3(\n            torch.from_numpy(outputs[\"body_quats\"]),\n        )\n        .as_matrix()\n        .to(device),\n        # We weren't saving contacts originally. 
We added it September 28th.\n        contacts=torch.zeros((num_samples, timesteps, 21), device=device)\n        if \"contacts\" not in outputs\n        else torch.from_numpy(outputs[\"contacts\"]).to(device),\n        hand_rotmats=SO3(\n            torch.from_numpy(\n                np.concatenate(\n                    [\n                        outputs[\"left_hand_quats\"],\n                        outputs[\"right_hand_quats\"],\n                    ],\n                    axis=-2,\n                )\n            ).to(device)\n        ).as_matrix(),\n    )\n    Ts_world_cpf = torch.from_numpy(outputs[\"Ts_world_cpf\"]).to(device)\n\n    def get_ego_video(\n        start_index: int,\n        end_index: int,\n        total_duration: float,\n    ) -> bytes:\n        \"\"\"Helper function that returns the egocentric video corresponding to\n        some start/end pose index.\"\"\"\n        assert isinstance(provider, VrsDataProvider)\n        rgb_stream_id = provider.get_stream_id_from_label(\"camera-rgb\")\n        assert rgb_stream_id is not None\n        camera_fps = provider.get_configuration(rgb_stream_id).get_nominal_rate_hz()\n        print(f\"{camera_fps=}\")\n\n        start_ns = int(outputs[\"timestamps_ns\"][start_index])\n        first_ns = provider.get_first_time_ns(rgb_stream_id, TimeDomain.RECORD_TIME)\n\n        image_start_index = int((start_ns - first_ns) / 1e9 * camera_fps)\n        image_end_index = min(\n            int(image_start_index + (end_index - start_index) / 30.0 * camera_fps) + 5,\n            provider.get_num_data(rgb_stream_id),\n        )\n\n        frames = []\n        for i in tqdm(range(image_start_index, image_end_index)):\n            image_data = provider.get_image_data_by_index(rgb_stream_id, i)[0]\n            image_array = image_data.to_numpy_array().copy()\n            image_array = cv2.resize(\n                image_array, (800, 800), interpolation=cv2.INTER_AREA\n            )\n            image_array = cv2.rotate(image_array, cv2.ROTATE_90_CLOCKWISE)\n            frames.append(image_array)\n\n        fps = len(frames) / total_duration\n        output = io.BytesIO()\n        iio.imwrite(\n            output,\n            frames,\n            fps=fps,\n            extension=\".mp4\",\n            codec=\"libx264\",\n            pixelformat=\"yuv420p\",\n            quality=None,\n            ffmpeg_params=[\"-crf\", \"23\"],\n        )\n        return output.getvalue()\n\n    return visualize_traj_and_hand_detections(\n        server,\n        Ts_world_cpf,\n        traj,\n        body_model,\n        hamer_detections,\n        aria_detections,\n        points_data,\n        paths.splat_path,\n        floor_z=floor_z,\n        get_ego_video=get_ego_video,\n    )\n\n\nif __name__ == \"__main__\":\n    tyro.cli(main)\n"
  },
  {
    "path": "5_eval_body_metrics.py",
    "content": "\"\"\"Example script for computing body metrics on the test split of the AMASS dataset.\n\nThis is not the exact script we used for the paper metrics, but should have the\ndetails that matter matched. Below are some metrics from this script when our\nreleased checkpoint is passed in.\n\nFor --subseq-len 128:\n\n     mpjpe 118.340 +/- 1.350             (in paper: 119.7 +/- 1.3)\n     pampjpe 100.026 +/- 1.349           (in paper: 101.1 +/- 1.3)\n     T_head 0.006 +/- 0.000              (in paper: 0.0062 +/- 0.0001)\n     foot_contact (GND) 1.000 +/- 0.000  (in paper: 1.0 +/- 0.0)\n     foot_skate 0.417 +/- 0.017          (not reported in paper)\n\n\nFor --subseq-len 32:\n\n     mpjpe 129.193 +/- 1.108             (in paper: 129.8 +/- 1.1)\n     pampjpe 109.489 +/- 1.147           (in paper: 109.8 +/- 1.1)\n     T_head 0.006 +/- 0.000              (in paper: 0.0064 +/- 0.0001)\n     foot_contact (GND) 0.985 +/- 0.003  (in paper: 0.98 +/- 0.00)\n     foot_skate 0.185 +/- 0.005          (not reported in paper)\n\"\"\"\n\nfrom pathlib import Path\n\nimport jax.tree\nimport numpy as np\nimport torch.optim.lr_scheduler\nimport torch.utils.data\nimport tyro\n\nfrom egoallo import fncsmpl\nfrom egoallo.data.amass import EgoAmassHdf5Dataset\nfrom egoallo.fncsmpl_extensions import get_T_world_root_from_cpf_pose\nfrom egoallo.inference_utils import load_denoiser\nfrom egoallo.metrics_helpers import (\n    compute_foot_contact,\n    compute_foot_skate,\n    compute_head_trans,\n    compute_mpjpe,\n)\nfrom egoallo.sampling import run_sampling_with_stitching\nfrom egoallo.transforms import SE3, SO3\n\n\ndef main(\n    dataset_hdf5_path: Path,\n    dataset_files_path: Path,\n    subseq_len: int = 128,\n    guidance_inner: bool = False,\n    checkpoint_dir: Path = Path(\"./egoallo_checkpoint_april13/checkpoints_3000000/\"),\n    smplh_npz_path: Path = Path(\"./data/smplh/neutral/model.npz\"),\n    num_samples: int = 1,\n) -> None:\n    \"\"\"Compute body metrics on the test split of the AMASS dataset.\"\"\"\n    device = torch.device(\"cuda\")\n\n    # Setup.\n    denoiser_network = load_denoiser(checkpoint_dir).to(device)\n    dataset = EgoAmassHdf5Dataset(\n        dataset_hdf5_path,\n        dataset_files_path,\n        splits=(\"test\",),\n        # We need an extra timestep in order to compute the relative CPF pose. 
(T_cpf_tm1_cpf_t)\n        subseq_len=subseq_len + 1,\n        cache_files=True,\n        slice_strategy=\"deterministic\",\n        random_variable_len_proportion=0.0,\n    )\n    body_model = fncsmpl.SmplhModel.load(smplh_npz_path).to(device)\n\n    metrics = list[dict[str, np.ndarray]]()\n\n    for i in range(len(dataset)):\n        sequence = dataset[i].to(device)\n\n        samples = run_sampling_with_stitching(\n            denoiser_network,\n            body_model=body_model,\n            guidance_mode=\"no_hands\",\n            guidance_inner=guidance_inner,\n            guidance_post=True,\n            Ts_world_cpf=sequence.T_world_cpf,\n            hamer_detections=None,\n            aria_detections=None,\n            num_samples=num_samples,\n            floor_z=0.0,\n            device=device,\n            guidance_verbose=False,\n        )\n\n        assert samples.hand_rotmats is not None\n        assert samples.betas.shape == (num_samples, subseq_len, 16)\n        assert samples.body_rotmats.shape == (num_samples, subseq_len, 21, 3, 3)\n        assert samples.hand_rotmats.shape == (num_samples, subseq_len, 30, 3, 3)\n        assert sequence.hand_quats is not None\n\n        # We'll only use the body joint rotations.\n        pred_posed = body_model.with_shape(samples.betas).with_pose(\n            T_world_root=SE3.identity(device, torch.float32).wxyz_xyz,\n            local_quats=SO3.from_matrix(\n                torch.cat([samples.body_rotmats, samples.hand_rotmats], dim=2)\n            ).wxyz,\n        )\n        pred_posed = pred_posed.with_new_T_world_root(\n            get_T_world_root_from_cpf_pose(pred_posed, sequence.T_world_cpf[1:, ...])\n        )\n\n        label_posed = body_model.with_shape(sequence.betas[1:, ...]).with_pose(\n            sequence.T_world_root[1:, ...],\n            torch.cat(\n                [\n                    sequence.body_quats[1:, ...],\n                    sequence.hand_quats[1:, ...],\n                ],\n                dim=1,\n            ),\n        )\n\n        metrics.append(\n            {\n                \"mpjpe\": compute_mpjpe(\n                    label_T_world_root=label_posed.T_world_root,\n                    label_Ts_world_joint=label_posed.Ts_world_joint[:, :21, :],\n                    pred_T_world_root=pred_posed.T_world_root,\n                    pred_Ts_world_joint=pred_posed.Ts_world_joint[:, :, :21, :],\n                    per_frame_procrustes_align=False,\n                ),\n                \"pampjpe\": compute_mpjpe(\n                    label_T_world_root=label_posed.T_world_root,\n                    label_Ts_world_joint=label_posed.Ts_world_joint[:, :21, :],\n                    pred_T_world_root=pred_posed.T_world_root,\n                    pred_Ts_world_joint=pred_posed.Ts_world_joint[:, :, :21, :],\n                    per_frame_procrustes_align=True,\n                ),\n                # We didn't report foot skating metrics in the paper. 
It's not\n                # really meaningful: since we optimize foot skating in the\n                # guidance optimizer, it's easy to \"cheat\" this metric.\n                \"foot_skate\": compute_foot_skate(\n                    pred_Ts_world_joint=pred_posed.Ts_world_joint[:, :, :21, :],\n                ),\n                \"foot_contact (GND)\": compute_foot_contact(\n                    pred_Ts_world_joint=pred_posed.Ts_world_joint[:, :, :21, :],\n                ),\n                \"T_head\": compute_head_trans(\n                    label_Ts_world_joint=label_posed.Ts_world_joint[:, :21, :],\n                    pred_Ts_world_joint=pred_posed.Ts_world_joint[:, :, :21, :],\n                ),\n            }\n        )\n\n        print(\"=\" * 80)\n        print(\"=\" * 80)\n        print(\"=\" * 80)\n        print(f\"Metrics ({i}/{len(dataset)} processed)\")\n        for k, v in jax.tree.map(\n            lambda *x: f\"{np.mean(x):.3f} +/- {np.std(x) / np.sqrt(len(metrics) * num_samples):.3f}\",\n            *metrics,\n        ).items():\n            print(\"\\t\", k, v)\n        print(\"=\" * 80)\n        print(\"=\" * 80)\n        print(\"=\" * 80)\n\n\nif __name__ == \"__main__\":\n    tyro.cli(main)\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2024 Brent Yi\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# egoallo\n\n**[Project page](https://egoallo.github.io/) &bull;\n[arXiv](https://arxiv.org/abs/2410.03665)**\n\nCode release for our preprint:\n\n<table><tr><td>\n    Brent Yi<sup>1</sup>, Vickie Ye<sup>1</sup>, Maya Zheng<sup>1</sup>, Yunqi Li<sup>2</sup>, Lea M&uuml;ller<sup>1</sup>, Georgios Pavlakos<sup>3</sup>, Yi Ma<sup>1</sup>, Jitendra Malik<sup>1</sup>, and Angjoo Kanazawa<sup>1</sup>.\n    <strong>Estimating Body and Hand Motion in an Ego-sensed World.</strong>\n    arXiV, 2024.\n</td></tr>\n</table>\n<sup>1</sup><em>UC Berkeley</em>, <sup>2</sup><em>ShanghaiTech</em>, <sup>3</sup><em>UT Austin</em>\n\n---\n\n## Updates\n\n- **Oct 7, 2024:** Initial release. (training code, core implementation details)\n- **Oct 14, 2024:** Added model checkpoint, dataset preprocessing, inference, and visualization scripts.\n- **May 6, 2025:** Updated scripts + instructions for dataset preprocessing, which is now self-contained in this repository.\n\n## Overview\n\n**TLDR;** We use egocentric SLAM poses and images to estimate 3D human body pose, height, and hands.\n\nhttps://github.com/user-attachments/assets/7d28e07f-ab83-4749-ac6b-abe692d9ba20\n\nThis repository is structured as follows:\n\n```\n.\n├── download_checkpoint_and_data.sh\n│                            - Download model checkpoint and sample data.\n├── 0_preprocess_training_data.py\n│                            - Preprocessing script for training datasets.\n├── 1_train_motion_prior.py\n│                            - Training script for motion diffusion model.\n├── 2_run_hamer_on_vrs.py\n│                            - Run HaMeR on inference data (expects Aria VRS).\n├── 3_aria_inference.py\n│                            - Run full pipeline on inference data.\n├── 4_visualize_outputs.py\n│                            - Visualize outputs from inference.\n├── 5_eval_body_metrics.py\n│                            - Compute and print body estimation accuracy metrics.\n│\n├── src/egoallo/\n│   ├── data/                - Dataset utilities.\n│   ├── transforms/          - SO(3) / SE(3) transformation helpers.\n│   └── *.py                 - All core implementation.\n│\n└── pyproject.toml          - Python dependencies/package metadata.\n```\n\n## Getting started\n\nEgoAllo requires Python 3.12 or newer.\n\n1. **Clone the repository.**\n   ```bash\n   git clone https://github.com/brentyi/egoallo.git\n   ```\n2. **Install general dependencies.**\n   ```bash\n   cd egoallo\n   pip install -e .\n   ```\n3. **Download+unzip model checkpoint and sample data.**\n\n   ```bash\n   bash download_checkpoint_and_data.sh\n   ```\n\n   You can also download the zip files manually: here are links to the [checkpoint](https://drive.google.com/file/d/14bDkWixFgo3U6dgyrCRmLoXSsXkrDA2w/view?usp=drive_link) and [example trajectories](https://drive.google.com/file/d/14zQ95NYxL4XIT7KIlFgAYTPCRITWxQqu/view?usp=drive_link).\n\n4. **Download the SMPL-H model file.**\n\n   You can find the \"Extended SMPL+H model\" (16 shape parameters) from the [MANO project webpage](https://mano.is.tue.mpg.de/).\n   Our scripts assumes an npz file located at `./data/smplh/neutral/model.npz`, but this can be overridden at the command-line (`--smplh-npz-path {your path}`).\n\n5. **Visualize model outputs.**\n\n   The example trajectories directory includes example outputs from our model. You can visualize them with:\n\n   ```bash\n   python 4_visualize_outputs.py --search-root-dir ./egoallo_example_trajectories\n   ```\n\n## Running inference\n\n1. 
## Running inference\n\n1. **Installing inference dependencies.**\n\n   Our guidance optimization uses a Levenberg-Marquardt optimizer that's implemented in JAX. If you want to run this on an NVIDIA GPU, you'll need to install JAX with CUDA support:\n\n   ```bash\n   # Also see: https://jax.readthedocs.io/en/latest/installation.html\n   pip install \"jax[cuda12]==0.6.1\"\n   ```\n\n   You'll also need [jaxls](https://github.com/brentyi/jaxls):\n\n   ```bash\n   pip install git+https://github.com/brentyi/jaxls.git\n   ```\n\n2. **Running inference on example data.**\n\n   Here's an example command for running EgoAllo on the \"coffeemachine\" sequence:\n\n   ```bash\n   python 3_aria_inference.py --traj-root ./egoallo_example_trajectories/coffeemachine\n   ```\n\n   You can run `python 3_aria_inference.py --help` to see the full list of options.\n\n3. **Running inference on your own data.**\n\n   To run inference on your own data, you can copy the structure of the example trajectories. The key files are:\n\n   - A VRS file from Project Aria, which contains calibrations and images.\n   - SLAM outputs from Project Aria's MPS: `closed_loop_trajectory.csv` and `semidense_points.csv.gz`.\n   - (optional) HaMeR outputs, which we save to a `hamer_outputs.pkl`.\n   - (optional) Project Aria wrist and palm tracking outputs.\n\n4. **Running HaMeR on your own data.**\n\n   To generate the `hamer_outputs.pkl` file, you'll need to install [hamer_helper](https://github.com/brentyi/hamer_helper).\n\n   Then, as an example for running on our coffeemachine sequence:\n\n   ```bash\n   python 2_run_hamer_on_vrs.py --traj-root ./egoallo_example_trajectories/coffeemachine\n   ```\n\n## Preprocessing Training Data\n\nTo train the motion prior model, we use data from the [AMASS dataset](https://amass.is.tue.mpg.de/). Due to licensing constraints, we cannot redistribute the preprocessed data. Instead, we provide two sequential preprocessing scripts:\n\n1. **Download the AMASS dataset.**\n\n   Download the AMASS dataset from the [official website](https://amass.is.tue.mpg.de/). We use the following splits:\n\n   - **Training**: ACCAD, BioMotionLab_NTroje, BMLhandball, BMLmovi, CMU, DanceDB, DFaust_67, EKUT, Eyes_Japan_Dataset, KIT, MPI_Limits, TCD_handMocap, TotalCapture\n   - **Validation**: HumanEva, MPI_HDM05, SFU, MPI_mosh\n   - **Testing**: Transitions_mocap, SSM_synced\n\n2. **Run the first preprocessing script.**\n\n   ```bash\n   python 0a_preprocess_training_data.py --help\n   python 0a_preprocess_training_data.py --data-root /path/to/amass --smplh-root ./data/smplh\n   ```\n\n   This script, adapted from HuMoR, processes raw AMASS data by:\n\n   - Converting to gender-neutral SMPL-H parameters\n   - Computing contact labels for feet, hands, and knees\n   - Filtering out problematic sequences (treadmill walking, sequences with foot skating)\n   - Downsampling to 30fps\n\n3. 
**Run the second preprocessing script.**\n\n   ```bash\n   python 0b_preprocess_training_data.py --help\n   python 0b_preprocess_training_data.py --data-npz-dir ./data/processed_30fps_no_skating/\n   ```\n\n   This converts the processed NPZ files to a unified HDF5 format for more efficient training, with optimized chunk sizes for reading sequences.\n\n## Status\n\nThis repository currently contains:\n\n- `egoallo` package, which contains reference training and sampling implementation details.\n- Training script.\n- Model checkpoints.\n- Dataset preprocessing script.\n- Inference script.\n- Visualization script.\n- Setup instructions.\n\nWhile we've put effort into cleaning up our code for release, this is research\ncode and there's room for improvement. If you have questions or comments,\nplease reach out!\n"
  },
  {
    "path": "download_checkpoint_and_data.sh",
    "content": "# Script for downloading model checkpoint and example inputs/outputs.\n\n# egoallo_checkpoint_april13.zip (552 MB)\ngdown https://drive.google.com/file/d/14bDkWixFgo3U6dgyrCRmLoXSsXkrDA2w/view?usp=drive_link --fuzzy\nunzip egoallo_checkpoint_april13.zip\nrm egoallo_checkpoint_april13.zip\n\n# egoallo_example_trajectories.zip (8.17 GB)\ngdown https://drive.google.com/file/d/14zQ95NYxL4XIT7KIlFgAYTPCRITWxQqu/view?usp=drive_link --fuzzy\nunzip egoallo_example_trajectories.zip\nrm egoallo_example_trajectories.zip\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nrequires = [\"setuptools>=61.0\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"egoallo\"\nversion = \"0.0.0\"\ndescription = \"egoallo\"\nreadme = \"README.md\"\nlicense = { text=\"MIT\" }\nrequires-python = \">=3.12\"\nclassifiers = [\n    \"Programming Language :: Python :: 3.12\",\n    \"License :: OSI Approved :: MIT License\",\n    \"Operating System :: OS Independent\"\n]\ndependencies = [\n    \"torch==2.7.1\",\n    \"viser>=0.2.11\",\n    \"plyfile==1.1.2\",\n    \"typeguard==4.4.3\",\n    \"jaxtyping==0.3.2\",\n    \"einops==0.8.1\",\n    \"rotary-embedding-torch==0.8.6\",\n    \"h5py==3.13.0\",\n    \"tensorboard==2.19.0\",\n    \"projectaria_tools==1.6.0\",\n    \"accelerate==1.7.0\",\n    \"tensorboardX==2.6.2.2\",\n    \"loguru==0.7.3\",\n    \"projectaria-tools[all]==1.6.0\",\n    \"opencv-python==4.11.0.86\",\n    \"gdown==5.2.0\",\n    \"scikit-learn==1.6.1\", # Only needed for preprocessing\n    \"smplx==0.1.28\", # Only needed for preprocessing\n]\n\n[tool.setuptools.package-data]\negoallo = [\"py.typed\"]\n\n[tool.pyright]\nignore = [\"**/preprocessing/**\", \"./0a_preprocess_training_data.py\"]\n\n[tool.ruff.lint]\nselect = [\n    \"E\",  # pycodestyle errors.\n    \"F\",  # Pyflakes rules.\n    \"PLC\",  # Pylint convention warnings.\n    \"PLE\",  # Pylint errors.\n    \"PLR\",  # Pylint refactor recommendations.\n    \"PLW\",  # Pylint warnings.\n]\nignore = [\n    \"E731\",  # Do not assign a lambda expression, use a def.\n    \"E741\", # Ambiguous variable name. (l, O, or I)\n    \"E501\",  # Line too long.\n    \"E721\",  # Do not compare types, use `isinstance()`.\n    \"F722\",  # Forward annotation false positive from jaxtyping. Should be caught by pyright.\n    \"F821\",  # Forward annotation false positive from jaxtyping. Should be caught by pyright.\n    \"PLR2004\",  # Magic value used in comparison.\n    \"PLR0915\",  # Too many statements.\n    \"PLR0913\",  # Too many arguments.\n    \"PLC0414\",  # Import alias does not rename variable. (this is used for exporting names)\n    \"PLC1901\",  # Use falsey strings.\n    \"PLR5501\",  # Use `elif` instead of `else if`.\n    \"PLR0911\",  # Too many return statements.\n    \"PLR0912\",  # Too many branches.\n    \"PLW0603\",  # Globa statement updates are discouraged.\n    \"PLW2901\",  # For loop variable overwritten.\n]\n"
  },
  {
    "path": "src/egoallo/__init__.py",
    "content": ""
  },
  {
    "path": "src/egoallo/fncsmpl.py",
    "content": "\"\"\"Somewhat opinionated wrapper for the SMPL-H body model.\n\nVery little of it is specific to SMPL-H. This could very easily be adapted for other models in SMPL family.\n\nWe break down the SMPL-H into four stages, each with a corresponding data structure:\n- Loading the model itself:\n    `model = SmplhModel.load(path to npz)`\n- Applying a body shape to the model:\n    `shaped = model.with_shape(betas)`\n- Posing the body shape:\n    `posed = shaped.with_pose(root pose, local joint poses)`\n- Recovering the mesh with LBS:\n    `mesh = posed.lbs()`\n\nIn contrast to other SMPL wrappers:\n- Everything is stateless, so we can support arbitrary batch axes.\n- The root is no longer ever called a joint.\n- The `trans` and `root_orient` inputs are replaced by a single SE(3) root transformation.\n- We're using (4,) wxyz quaternion vectors for all rotations, (7,) wxyz_xyz vectors for all\n  rigid transforms.\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nfrom einops import einsum\nfrom jaxtyping import Float, Int\nfrom torch import Tensor\n\nfrom .tensor_dataclass import TensorDataclass\nfrom .transforms import SE3, SO3\n\n\nclass SmplhModel(TensorDataclass):\n    \"\"\"A human body model from the SMPL family.\"\"\"\n\n    faces: Int[Tensor, \"faces 3\"]\n    \"\"\"Vertex indices for mesh faces.\"\"\"\n    J_regressor: Float[Tensor, \"joints+1 verts\"]\n    \"\"\"Linear map from vertex to joint positions.\n    For SMPL-H, 1 root + 21 body joints + 2 * 15 hand joints.\"\"\"\n    parent_indices: tuple[int, ...]\n    \"\"\"Defines kinematic tree. Index of -1 signifies that a joint is defined\n    relative to the root.\"\"\"\n    weights: Float[Tensor, \"verts joints+1\"]\n    \"\"\"LBS weights.\"\"\"\n    posedirs: Float[Tensor, \"verts 3 joints*9\"]\n    \"\"\"Pose blend shape bases.\"\"\"\n    v_template: Float[Tensor, \"verts 3\"]\n    \"\"\"Canonical mesh verts.\"\"\"\n    shapedirs: Float[Tensor, \"verts 3 n_betas\"]\n    \"\"\"Shape bases.\"\"\"\n\n    @staticmethod\n    def load(model_path: Path) -> SmplhModel:\n        \"\"\"Load a body model from an NPZ file.\"\"\"\n        params_numpy: dict[str, np.ndarray] = {\n            k: _normalize_dtype(v)\n            for k, v in np.load(model_path, allow_pickle=True).items()\n        }\n        assert (\n            \"bs_style\" not in params_numpy\n            or params_numpy.pop(\"bs_style\").item() == b\"lbs\"\n        )\n        assert (\n            \"bs_type\" not in params_numpy\n            or params_numpy.pop(\"bs_type\").item() == b\"lrotmin\"\n        )\n        parent_indices = tuple(\n            int(index) for index in params_numpy.pop(\"kintree_table\")[0][1:] - 1\n        )\n        params = {\n            k: torch.from_numpy(v)\n            for k, v in params_numpy.items()\n            if v.dtype in (np.int32, np.float32)\n        }\n        return SmplhModel(\n            faces=params[\"f\"],\n            J_regressor=params[\"J_regressor\"],\n            parent_indices=parent_indices,\n            weights=params[\"weights\"],\n            posedirs=params[\"posedirs\"],\n            v_template=params[\"v_template\"],\n            shapedirs=params[\"shapedirs\"],\n        )\n\n    def get_num_joints(self) -> int:\n        \"\"\"Get the number of joints in this model.\"\"\"\n        return len(self.parent_indices)\n\n    def with_shape(self, betas: Float[Tensor, \"*#batch n_betas\"]) -> SmplhShaped:\n        \"\"\"Compute a new body model, with 
betas applied.\"\"\"\n        num_betas = betas.shape[-1]\n        assert num_betas <= self.shapedirs.shape[-1]\n        verts_with_shape = self.v_template + einsum(\n            self.shapedirs[:, :, :num_betas],\n            betas,\n            \"verts xyz beta, ... beta -> ... verts xyz\",\n        )\n        root_and_joints_pred = einsum(\n            self.J_regressor,\n            verts_with_shape,\n            \"jointsp1 verts, ... verts xyz -> ... jointsp1 xyz\",\n        )\n        root_offset = root_and_joints_pred[..., 0:1, :]\n        return SmplhShaped(\n            body_model=self,\n            root_offset=root_offset.unsqueeze(-2),\n            verts_zero=verts_with_shape - root_offset,\n            joints_zero=root_and_joints_pred[..., 1:, :] - root_offset,\n            t_parent_joint=root_and_joints_pred[..., 1:, :]\n            - root_and_joints_pred[..., np.array(self.parent_indices) + 1, :],\n        )\n\n\nclass SmplhShaped(TensorDataclass):\n    \"\"\"The SMPL-H body model with a body shape applied.\"\"\"\n\n    body_model: SmplhModel\n    \"\"\"The underlying body model.\"\"\"\n    root_offset: Float[Tensor, \"*#batch 3\"]\n    verts_zero: Float[Tensor, \"*#batch verts 3\"]\n    \"\"\"Vertices of shaped body _relative to the root joint_ at the zero\n    configuration.\"\"\"\n    joints_zero: Float[Tensor, \"*#batch joints 3\"]\n    \"\"\"Joints of shaped body _relative to the root joint_ at the zero\n    configuration.\"\"\"\n    t_parent_joint: Float[Tensor, \"*#batch joints 3\"]\n    \"\"\"Position of each shaped body joint relative to its parent. Does not\n    include root.\"\"\"\n\n    def with_pose_decomposed(\n        self,\n        T_world_root: Float[Tensor, \"*#batch 7\"],\n        body_quats: Float[Tensor, \"*#batch 21 4\"],\n        left_hand_quats: Float[Tensor, \"*#batch 15 4\"] | None = None,\n        right_hand_quats: Float[Tensor, \"*#batch 15 4\"] | None = None,\n    ) -> SmplhShapedAndPosed:\n        \"\"\"Pose our SMPL-H body model. Returns a set of joint and vertex outputs.\"\"\"\n\n        num_joints = self.body_model.get_num_joints()\n        batch_axes = body_quats.shape[:-2]\n        if left_hand_quats is None:\n            left_hand_quats = body_quats.new_zeros((*batch_axes, 15, 4))\n            left_hand_quats[..., 0] = 1.0\n        if right_hand_quats is None:\n            right_hand_quats = body_quats.new_zeros((*batch_axes, 15, 4))\n            right_hand_quats[..., 0] = 1.0\n        local_quats = broadcasting_cat(\n            [body_quats, left_hand_quats, right_hand_quats], dim=-2\n        )\n        assert local_quats.shape[-2:] == (num_joints, 4)\n        return self.with_pose(T_world_root, local_quats)\n\n    def with_pose(\n        self,\n        T_world_root: Float[Tensor, \"*#batch 7\"],\n        local_quats: Float[Tensor, \"*#batch joints 4\"],\n    ) -> SmplhShapedAndPosed:\n        \"\"\"Pose our SMPL-H body model. 
Returns a set of joint and vertex outputs.\"\"\"\n\n        # Forward kinematics.\n        num_joints = self.body_model.get_num_joints()\n        assert local_quats.shape[-2:] == (num_joints, 4)\n        Ts_world_joint = forward_kinematics(\n            T_world_root=T_world_root,\n            Rs_parent_joint=local_quats,\n            t_parent_joint=self.t_parent_joint,\n            parent_indices=self.body_model.parent_indices,\n        )\n        assert Ts_world_joint.shape[-2:] == (num_joints, 7)\n        return SmplhShapedAndPosed(\n            shaped_model=self,\n            T_world_root=T_world_root,\n            local_quats=local_quats,\n            Ts_world_joint=Ts_world_joint,\n        )\n\n\nclass SmplhShapedAndPosed(TensorDataclass):\n    shaped_model: SmplhShaped\n    \"\"\"Underlying shaped body model.\"\"\"\n\n    T_world_root: Float[Tensor, \"*#batch 7\"]\n    \"\"\"Root coordinate frame.\"\"\"\n\n    local_quats: Float[Tensor, \"*#batch joints 4\"]\n    \"\"\"Local joint orientations.\"\"\"\n\n    Ts_world_joint: Float[Tensor, \"*#batch joints 7\"]\n    \"\"\"Absolute transform for each joint. Does not include the root.\"\"\"\n\n    def with_new_T_world_root(\n        self, T_world_root: Float[Tensor, \"*#batch 7\"]\n    ) -> SmplhShapedAndPosed:\n        return SmplhShapedAndPosed(\n            shaped_model=self.shaped_model,\n            T_world_root=T_world_root,\n            local_quats=self.local_quats,\n            Ts_world_joint=(\n                SE3(T_world_root[..., None, :])\n                @ SE3(self.T_world_root[..., None, :]).inverse()\n                @ SE3(self.Ts_world_joint)\n            ).parameters(),\n        )\n\n    def lbs(self) -> SmplMesh:\n        \"\"\"Compute a mesh with LBS.\"\"\"\n        num_joints = self.local_quats.shape[-2]\n        verts_with_blend = self.shaped_model.verts_zero + einsum(\n            self.shaped_model.body_model.posedirs,\n            (\n                SO3(self.local_quats).as_matrix()\n                - torch.eye(\n                    3, dtype=self.local_quats.dtype, device=self.local_quats.device\n                )\n            ).reshape((*self.local_quats.shape[:-2], num_joints * 9)),\n            \"... verts j joints_times_9, ... joints_times_9 -> ... verts j\",\n        )\n        verts_transformed = einsum(\n            broadcasting_cat(\n                [\n                    SE3(self.T_world_root).as_matrix()[..., None, :3, :],\n                    SE3(self.Ts_world_joint).as_matrix()[..., :, :3, :],\n                ],\n                dim=-3,\n            ),\n            self.shaped_model.body_model.weights,\n            broadcasting_cat(\n                [\n                    verts_with_blend[..., :, None, :]\n                    - broadcasting_cat(  # Prepend root to joints zeros.\n                        [\n                            self.shaped_model.joints_zero.new_zeros(3),\n                            self.shaped_model.joints_zero[..., None, :, :],\n                        ],\n                        dim=-2,\n                    ),\n                    verts_with_blend.new_ones(\n                        (\n                            *verts_with_blend.shape[:-1],\n                            1 + self.shaped_model.joints_zero.shape[-2],\n                            1,\n                        )\n                    ),\n                ],\n                dim=-1,\n            ),\n            \"... joints_p1 i j, verts joints_p1, ... verts joints_p1 j -> ... 
verts i\",\n        )\n        assert (\n            verts_transformed.shape[-2:]\n            == self.shaped_model.body_model.v_template.shape\n        )\n        return SmplMesh(\n            posed_model=self,\n            verts=verts_transformed,\n            faces=self.shaped_model.body_model.faces,\n        )\n\n\nclass SmplMesh(TensorDataclass):\n    \"\"\"Outputs from the SMPL-H model.\"\"\"\n\n    posed_model: SmplhShapedAndPosed\n    \"\"\"Posed model that this mesh was computed for.\"\"\"\n\n    verts: Float[Tensor, \"*#batch verts 3\"]\n    \"\"\"Vertices for mesh.\"\"\"\n\n    faces: Int[Tensor, \"verts 3\"]\n    \"\"\"Faces for mesh.\"\"\"\n\n\ndef forward_kinematics(\n    T_world_root: Float[Tensor, \"*#batch 7\"],\n    Rs_parent_joint: Float[Tensor, \"*#batch joints 4\"],\n    t_parent_joint: Float[Tensor, \"*#batch joints 3\"],\n    parent_indices: tuple[int, ...],\n) -> Float[Tensor, \"*#batch joints 7\"]:\n    \"\"\"Run forward kinematics to compute absolute poses (T_world_joint) for\n    each joint. The output array containts pose parameters\n    (w, x, y, z, tx, ty, tz) for each joint. (this does not include the root!)\n\n    Args:\n        T_world_root: Transformation to world frame from root frame.\n        Rs_parent_joint: Local orientation of each joint.\n        t_parent_joint: Position of each joint with respect to its parent frame. (this does not\n            depend on local joint orientations)\n        parent_indices: Parent index for each joint. Index of -1 signifies that\n            a joint is defined relative to the root. We assume that this array is\n            sorted: parent joints should always precede child joints.\n\n    Returns:\n        Transformations to world frame from each joint frame.\n    \"\"\"\n\n    # Check shapes.\n    num_joints = len(parent_indices)\n    assert Rs_parent_joint.shape[-2:] == (num_joints, 4)\n    assert t_parent_joint.shape[-2:] == (num_joints, 3)\n\n    # Get relative transforms.\n    Ts_parent_child = broadcasting_cat([Rs_parent_joint, t_parent_joint], dim=-1)\n    assert Ts_parent_child.shape[-2:] == (num_joints, 7)\n\n    # Compute one joint at a time.\n    list_Ts_world_joint: list[Tensor] = []\n    for i in range(num_joints):\n        if parent_indices[i] == -1:\n            T_world_parent = T_world_root\n        else:\n            T_world_parent = list_Ts_world_joint[parent_indices[i]]\n        list_Ts_world_joint.append(\n            (SE3(T_world_parent) @ SE3(Ts_parent_child[..., i, :])).wxyz_xyz\n        )\n\n    Ts_world_joint = torch.stack(list_Ts_world_joint, dim=-2)\n    assert Ts_world_joint.shape[-2:] == (num_joints, 7)\n    return Ts_world_joint\n\n\ndef broadcasting_cat(tensors: list[Tensor], dim: int) -> Tensor:\n    \"\"\"Like torch.cat, but broadcasts.\"\"\"\n    assert len(tensors) > 0\n    output_dims = max(map(lambda t: len(t.shape), tensors))\n    tensors = [\n        t.reshape((1,) * (output_dims - len(t.shape)) + t.shape) for t in tensors\n    ]\n    max_sizes = [max(t.shape[i] for t in tensors) for i in range(output_dims)]\n    expanded_tensors = [\n        tensor.expand(\n            *(\n                tensor.shape[i] if i == dim % len(tensor.shape) else max_size\n                for i, max_size in enumerate(max_sizes)\n            )\n        )\n        for tensor in tensors\n    ]\n    return torch.cat(expanded_tensors, dim=dim)\n\n\ndef _normalize_dtype(v: np.ndarray) -> np.ndarray:\n    \"\"\"Normalize datatypes; all arrays should be either int32 or float32.\"\"\"\n    if \"int\" in 
str(v.dtype):\n        return v.astype(np.int32)\n    elif \"float\" in str(v.dtype):\n        return v.astype(np.float32)\n    else:\n        return v\n"
  },
  {
    "path": "src/egoallo/fncsmpl_extensions.py",
    "content": "\"\"\"EgoAllo-specific SMPL utilities.\"\"\"\n\nfrom __future__ import annotations\n\n\nimport numpy as np\nimport torch\nfrom jaxtyping import Float\nfrom torch import Tensor\n\nfrom . import fncsmpl, transforms\n\n\ndef get_T_world_cpf(mesh: fncsmpl.SmplMesh) -> Float[Tensor, \"*#batch 7\"]:\n    \"\"\"Get the central pupil frame from a mesh. This assumes that we're using the SMPL-H model.\"\"\"\n\n    assert mesh.verts.shape[-2:] == (6890, 3), \"Not using SMPL-H model!\"\n    right_eye = (mesh.verts[..., 6260, :] + mesh.verts[..., 6262, :]) / 2.0\n    left_eye = (mesh.verts[..., 2800, :] + mesh.verts[..., 2802, :]) / 2.0\n\n    # CPF is between the two eyes.\n    cpf_pos = (right_eye + left_eye) / 2.0\n    # Get orientation from head.\n    cpf_orientation = mesh.posed_model.Ts_world_joint[..., 14, :4]\n\n    return torch.cat([cpf_orientation, cpf_pos], dim=-1)\n\n\ndef get_T_head_cpf(shaped: fncsmpl.SmplhShaped) -> Float[Tensor, \"*#batch 7\"]:\n    \"\"\"Get the central pupil frame with respect to the head (joint 14). This\n    assumes that we're using the SMPL-H model.\"\"\"\n\n    verts_zero = shaped.verts_zero\n\n    assert verts_zero.shape[-2:] == (6890, 3), \"Not using SMPL-H model!\"\n    right_eye = (verts_zero[..., 6260, :] + verts_zero[..., 6262, :]) / 2.0\n    left_eye = (verts_zero[..., 2800, :] + verts_zero[..., 2802, :]) / 2.0\n\n    # CPF is between the two eyes.\n    cpf_pos_wrt_head = (right_eye + left_eye) / 2.0 - shaped.joints_zero[..., 14, :]\n\n    return fncsmpl.broadcasting_cat(\n        [\n            transforms.SO3.identity(\n                device=cpf_pos_wrt_head.device, dtype=cpf_pos_wrt_head.dtype\n            ).wxyz,\n            cpf_pos_wrt_head,\n        ],\n        dim=-1,\n    )\n\n\ndef get_T_world_root_from_cpf_pose(\n    posed: fncsmpl.SmplhShapedAndPosed,\n    Ts_world_cpf: Float[Tensor | np.ndarray, \"... 7\"],\n) -> Float[Tensor, \"... 7\"]:\n    \"\"\"Get the root transform that would align the CPF frame of `posed` to `Ts_world_cpf`.\"\"\"\n    device = posed.Ts_world_joint.device\n    dtype = posed.Ts_world_joint.dtype\n\n    if isinstance(Ts_world_cpf, np.ndarray):\n        Ts_world_cpf = torch.from_numpy(Ts_world_cpf).to(device=device, dtype=dtype)\n\n    assert Ts_world_cpf.shape[-1] == 7\n    T_world_root = (\n        # T_world_cpf\n        transforms.SE3(Ts_world_cpf)\n        # T_cpf_head\n        @ transforms.SE3(get_T_head_cpf(posed.shaped_model)).inverse()\n        # T_head_world\n        @ transforms.SE3(posed.Ts_world_joint[..., 14, :]).inverse()\n        # T_world_root\n        @ transforms.SE3(posed.T_world_root)\n    )\n    return T_world_root.wxyz_xyz\n"
  },
  {
    "path": "src/egoallo/fncsmpl_jax.py",
    "content": "\"\"\"SMPL-H model, implemented in JAX.\n\nVery little of it is specific to SMPL-H. This could very easily be adapted for other models in SMPL family.\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom pathlib import Path\nfrom typing import Sequence, cast\n\nimport jax\nimport jax_dataclasses as jdc\nimport jaxlie\nimport numpy as onp\nfrom einops import einsum\nfrom jax import Array\nfrom jax import numpy as jnp\nfrom jaxtyping import Float, Int\n\n\n@jdc.pytree_dataclass\nclass SmplhModel:\n    \"\"\"The SMPL-H human body model.\"\"\"\n\n    faces: Int[Array, \"faces 3\"]\n    \"\"\"Vertex indices for mesh faces.\"\"\"\n    J_regressor: Float[Array, \"joints+1 verts\"]\n    \"\"\"Linear map from vertex to joint positions.\n    22 body joints + 2 * 22 hand joints.\"\"\"\n    parent_indices: Int[Array, \"joints\"]\n    \"\"\"Defines kinematic tree. Index of -1 signifies that a joint is defined\n    relative to the root.\"\"\"\n    weights: Float[Array, \"verts joints+1\"]\n    \"\"\"LBS weights.\"\"\"\n    posedirs: Float[Array, \"verts 3 joints*9\"]\n    \"\"\"Pose blend shape bases.\"\"\"\n    v_template: Float[Array, \"verts 3\"]\n    \"\"\"Canonical mesh verts.\"\"\"\n    shapedirs: Float[Array, \"verts 3 n_betas\"]\n    \"\"\"Shape bases.\"\"\"\n\n    @staticmethod\n    def load(npz_path: Path) -> SmplhModel:\n        smplh_params: dict[str, onp.ndarray] = onp.load(npz_path, allow_pickle=True)\n        # assert smplh_params[\"bs_style\"].item() == b\"lbs\"\n        # assert smplh_params[\"bs_type\"].item() == b\"lrotmin\"\n        smplh_params = {k: _normalize_dtype(v) for k, v in smplh_params.items()}\n        return SmplhModel(\n            faces=jnp.array(smplh_params[\"f\"]),\n            J_regressor=jnp.array(smplh_params[\"J_regressor\"]),\n            parent_indices=jnp.array(smplh_params[\"kintree_table\"][0][1:] - 1),\n            weights=jnp.array(smplh_params[\"weights\"]),\n            posedirs=jnp.array(smplh_params[\"posedirs\"]),\n            v_template=jnp.array(smplh_params[\"v_template\"]),\n            shapedirs=jnp.array(smplh_params[\"shapedirs\"]),\n        )\n\n    def with_shape(\n        self, betas: Float[Array | onp.ndarray, \"... n_betas\"]\n    ) -> SmplhShaped:\n        \"\"\"Compute a new body model, with betas applied. betas vector should\n        have shape up to (16,).\"\"\"\n        num_betas = betas.shape[-1]\n        assert num_betas <= 16\n        verts_with_shape = self.v_template + einsum(\n            self.shapedirs[:, :, :num_betas],\n            betas,\n            \"verts xyz beta, ... beta -> ... verts xyz\",\n        )\n        root_and_joints_pred = einsum(\n            self.J_regressor,\n            verts_with_shape,\n            \"joints verts, ... verts xyz -> ... 
joints xyz\",\n        )\n        root_offset = root_and_joints_pred[..., 0:1, :]\n        return SmplhShaped(\n            body_model=self,\n            verts_zero=verts_with_shape - root_offset,\n            joints_zero=root_and_joints_pred[..., 1:, :] - root_offset,\n            t_parent_joint=root_and_joints_pred[..., 1:, :]\n            - root_and_joints_pred[..., self.parent_indices + 1, :],\n        )\n\n\n@jdc.pytree_dataclass\nclass SmplhShaped:\n    \"\"\"The SMPL-H body model with a body shape applied.\"\"\"\n\n    body_model: SmplhModel\n    verts_zero: Float[Array, \"verts 3\"]\n    \"\"\"Vertices of shaped body _relative to the root joint_ at the zero\n    configuration.\"\"\"\n    joints_zero: Float[Array, \"joints 3\"]\n    \"\"\"Joints of shaped body _relative to the root joint_ at the zero\n    configuration.\"\"\"\n    t_parent_joint: Float[Array, \"joints 3\"]\n    \"\"\"Position of each shaped body joint relative to its parent. Does not\n    include root.\"\"\"\n\n    def with_pose_decomposed(\n        self,\n        T_world_root: Float[Array | onp.ndarray, \"7\"],\n        body_quats: Float[Array | onp.ndarray, \"21 4\"],\n        left_hand_quats: Float[Array | onp.ndarray, \"15 4\"] | None = None,\n        right_hand_quats: Float[Array | onp.ndarray, \"15 4\"] | None = None,\n    ) -> SmplhShapedAndPosed:\n        \"\"\"Pose our SMPL-H body model. Returns a set of joint and vertex outputs.\"\"\"\n\n        if left_hand_quats is None:\n            left_hand_quats = jnp.zeros((15, 4)).at[:, 0].set(1.0)\n        if right_hand_quats is None:\n            right_hand_quats = jnp.zeros((15, 4)).at[:, 0].set(1.0)\n        local_quats = broadcasting_cat(\n            cast(list[jax.Array], [body_quats, left_hand_quats, right_hand_quats]),\n            axis=0,\n        )\n        assert local_quats.shape[-2:] == (51, 4)\n        return self.with_pose(T_world_root, local_quats)\n\n    def with_pose(\n        self,\n        T_world_root: Float[Array | onp.ndarray, \"... 7\"],\n        local_quats: Float[Array | onp.ndarray, \"... num_joints 4\"],\n    ) -> SmplhShapedAndPosed:\n        \"\"\"Pose our SMPL-H body model. 
Returns a set of joint and vertex outputs.\"\"\"\n\n        # Forward kinematics.\n        # assert local_quats.shape == (51, 4), local_quats.shape\n        parent_indices = self.body_model.parent_indices\n        (num_joints,) = parent_indices.shape[-1:]\n        num_active_joints, _ = local_quats.shape[-2:]\n        assert local_quats.shape[-1] == 4\n        assert num_active_joints <= num_joints\n        assert self.t_parent_joint.shape[-2:] == (num_joints, 3)\n\n        # Get relative transforms.\n        Ts_parent_child = broadcasting_cat(\n            [local_quats, self.t_parent_joint[..., :num_active_joints, :]], axis=-1\n        )\n        assert Ts_parent_child.shape[-2:] == (num_active_joints, 7)\n\n        # Compute one joint at a time.\n        def compute_joint(i: int, Ts_world_joint: Array) -> Array:\n            T_world_parent = jnp.where(\n                parent_indices[i] == -1,\n                T_world_root,\n                Ts_world_joint[..., parent_indices[i], :],\n            )\n            return Ts_world_joint.at[..., i, :].set(\n                (\n                    jaxlie.SE3(T_world_parent) @ jaxlie.SE3(Ts_parent_child[..., i, :])\n                ).wxyz_xyz\n            )\n\n        Ts_world_joint = jax.lax.fori_loop(\n            lower=0,\n            upper=num_joints,\n            body_fun=compute_joint,\n            init_val=jnp.zeros_like(Ts_parent_child),\n        )\n        assert Ts_world_joint.shape[-2:] == (num_active_joints, 7)\n\n        return SmplhShapedAndPosed(\n            shaped_model=self,\n            T_world_root=T_world_root,  # type: ignore\n            local_quats=local_quats,  # type: ignore\n            Ts_world_joint=Ts_world_joint,\n        )\n\n    def get_T_head_cpf(self) -> Float[Array, \"7\"]:\n        \"\"\"Get the central pupil frame with respect to the head (joint 14). This\n        assumes that we're using the SMPL-H model.\"\"\"\n\n        assert self.verts_zero.shape[-2:] == (6890, 3), \"Not using SMPL-H model!\"\n        right_eye = (\n            self.verts_zero[..., 6260, :] + self.verts_zero[..., 6262, :]\n        ) / 2.0\n        left_eye = (self.verts_zero[..., 2800, :] + self.verts_zero[..., 2802, :]) / 2.0\n\n        # CPF is between the two eyes.\n        cpf_pos_wrt_head = (right_eye + left_eye) / 2.0 - self.joints_zero[..., 14, :]\n\n        return broadcasting_cat([jaxlie.SO3.identity().wxyz, cpf_pos_wrt_head], axis=-1)\n\n\n@jdc.pytree_dataclass\nclass SmplhShapedAndPosed:\n    shaped_model: SmplhShaped\n    \"\"\"Underlying shaped body model.\"\"\"\n\n    T_world_root: Float[Array, \"*#batch 7\"]\n    \"\"\"Root coordinate frame.\"\"\"\n\n    local_quats: Float[Array, \"*#batch joints 4\"]\n    \"\"\"Local joint orientations.\"\"\"\n\n    Ts_world_joint: Float[Array, \"joints 7\"]\n    \"\"\"Absolute transform for each joint. 
Does not include the root.\"\"\"\n\n    def with_new_T_world_root(\n        self, T_world_root: Float[Array, \"*#batch 7\"]\n    ) -> SmplhShapedAndPosed:\n        return SmplhShapedAndPosed(\n            shaped_model=self.shaped_model,\n            T_world_root=T_world_root,\n            local_quats=self.local_quats,\n            Ts_world_joint=(\n                jaxlie.SE3(T_world_root[..., None, :])\n                @ jaxlie.SE3(self.T_world_root[..., None, :]).inverse()\n                @ jaxlie.SE3(self.Ts_world_joint)\n            ).parameters(),\n        )\n\n    def lbs(self) -> SmplhMesh:\n        assert (\n            self.local_quats.shape[0]\n            == self.shaped_model.body_model.parent_indices.shape[0]\n        ), (\n            \"It looks like only a partial set of joint rotations was passed into `with_pose()`. We need all of them for LBS.\"\n        )\n\n        # Linear blend skinning with a pose blend shape.\n        verts_with_blend = self.shaped_model.verts_zero + einsum(\n            self.shaped_model.body_model.posedirs,\n            (jaxlie.SO3(self.local_quats).as_matrix() - jnp.eye(3)).flatten(),\n            \"verts j joints_times_9, ... joints_times_9 -> ... verts j\",\n        )\n        verts_transformed = einsum(\n            broadcasting_cat(\n                [\n                    # (*, 1, 3, 4)\n                    jaxlie.SE3(self.T_world_root).as_matrix()[..., None, :3, :],\n                    # (*, 51, 3, 4)\n                    jaxlie.SE3(self.Ts_world_joint).as_matrix()[..., :3, :],\n                ],\n                axis=0,\n            ),\n            self.shaped_model.body_model.weights,\n            jnp.pad(\n                verts_with_blend[:, None, :]\n                - jnp.concatenate(\n                    [\n                        jnp.zeros((1, 1, 3)),  # Root joint.\n                        self.shaped_model.joints_zero[None, :, :],\n                    ],\n                    axis=1,\n                ),\n                ((0, 0), (0, 0), (0, 1)),\n                constant_values=1.0,\n            ),\n            \"joints_p1 i j, ... verts joints_p1, ... verts joints_p1 j -> ... 
verts i\",\n        )\n\n        return SmplhMesh(\n            posed_model=self,\n            verts=verts_transformed,\n            faces=self.shaped_model.body_model.faces,\n        )\n\n\n@jdc.pytree_dataclass\nclass SmplhMesh:\n    posed_model: SmplhShapedAndPosed\n\n    verts: Float[Array, \"verts 3\"]\n    \"\"\"Vertices for mesh.\"\"\"\n\n    faces: Int[Array, \"13776 3\"]\n    \"\"\"Faces for mesh.\"\"\"\n\n\ndef broadcasting_cat(arrays: Sequence[jax.Array | onp.ndarray], axis: int) -> jax.Array:\n    \"\"\"Like jnp.concatenate, but broadcasts leading axes.\"\"\"\n    assert len(arrays) > 0\n    output_dims = max(map(lambda t: len(t.shape), arrays))\n    arrays = [t.reshape((1,) * (output_dims - len(t.shape)) + t.shape) for t in arrays]\n    max_sizes = [max(t.shape[i] for t in arrays) for i in range(output_dims)]\n    expanded_arrays = [\n        jnp.broadcast_to(\n            array,\n            tuple(\n                array.shape[i] if i == axis % len(array.shape) else max_size\n                for i, max_size in enumerate(max_sizes)\n            ),\n        )\n        for array in arrays\n    ]\n    return jnp.concatenate(expanded_arrays, axis=axis)\n\n\ndef _normalize_dtype(v: onp.ndarray) -> onp.ndarray:\n    \"\"\"Normalize datatypes; all arrays should be either int32 or float32.\"\"\"\n    if \"int\" in str(v.dtype):\n        return v.astype(onp.int32)\n    elif \"float\" in str(v.dtype):\n        return v.astype(onp.float32)\n    else:\n        return v\n"
  },
  {
    "path": "src/egoallo/guidance_optimizer_jax.py",
    "content": "\"\"\"Optimize constraints using Levenberg-Marquardt.\"\"\"\n\nfrom __future__ import annotations\n\nimport os\n\nfrom .hand_detection_structs import (\n    CorrespondedAriaHandWristPoseDetections,\n    CorrespondedHamerDetections,\n)\n\n# Need to play nice with PyTorch!\nos.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"] = \"false\"\n\nimport dataclasses\nimport time\nfrom functools import partial\nfrom typing import Callable, Literal, Unpack, assert_never, cast\n\nimport jax\nimport jax_dataclasses as jdc\nimport jaxlie\nimport jaxls\nimport numpy as onp\nimport torch\nfrom jax import numpy as jnp\nfrom jaxtyping import Float, Int\nfrom torch import Tensor\n\nfrom . import fncsmpl, fncsmpl_jax, network\nfrom .transforms._so3 import SO3\n\n\ndef do_guidance_optimization(\n    Ts_world_cpf: Float[Tensor, \"time 7\"],\n    traj: network.EgoDenoiseTraj,\n    body_model: fncsmpl.SmplhModel,\n    guidance_mode: GuidanceMode,\n    phase: Literal[\"inner\", \"post\"],\n    hamer_detections: None | CorrespondedHamerDetections,\n    aria_detections: None | CorrespondedAriaHandWristPoseDetections,\n    verbose: bool,\n) -> tuple[network.EgoDenoiseTraj, dict]:\n    \"\"\"Run an optimizer to apply foot contact constraints.\"\"\"\n\n    assert traj.hand_rotmats is not None\n    guidance_params = JaxGuidanceParams.defaults(guidance_mode, phase)\n\n    start_time = time.time()\n    quats, debug_info = _optimize_vmapped(\n        body=fncsmpl_jax.SmplhModel(\n            faces=cast(jax.Array, body_model.faces.numpy(force=True)),\n            J_regressor=cast(jax.Array, body_model.J_regressor.numpy(force=True)),\n            parent_indices=cast(jax.Array, onp.array(body_model.parent_indices)),\n            weights=cast(jax.Array, body_model.weights.numpy(force=True)),\n            posedirs=cast(jax.Array, body_model.posedirs.numpy(force=True)),\n            v_template=cast(jax.Array, body_model.v_template.numpy(force=True)),\n            shapedirs=cast(jax.Array, body_model.shapedirs.numpy(force=True)),\n        ),\n        Ts_world_cpf=cast(jax.Array, Ts_world_cpf.numpy(force=True)),\n        betas=cast(jax.Array, traj.betas.numpy(force=True)),\n        body_rotmats=cast(jax.Array, traj.body_rotmats.numpy(force=True)),\n        hand_rotmats=cast(jax.Array, traj.hand_rotmats.numpy(force=True)),\n        contacts=cast(jax.Array, traj.contacts.numpy(force=True)),\n        guidance_params=guidance_params,\n        # The hand detections are a torch tensors in a TensorDataclass form. 
We\n        # use dictionaries to convert to pytrees.\n        hamer_detections=None\n        if hamer_detections is None\n        else hamer_detections.as_nested_dict(numpy=True),\n        aria_detections=None\n        if aria_detections is None\n        else aria_detections.as_nested_dict(numpy=True),\n        verbose=verbose,\n    )\n    rotmats = SO3(\n        torch.from_numpy(onp.array(quats))\n        .to(traj.body_rotmats.dtype)\n        .to(traj.body_rotmats.device)\n    ).as_matrix()\n\n    print(f\"Constraint optimization finished in {time.time() - start_time}sec\")\n    return dataclasses.replace(\n        traj,\n        body_rotmats=rotmats[:, :, :21, :],\n        hand_rotmats=rotmats[:, :, 21:, :],\n    ), debug_info\n\n\nclass _SmplhBodyPosesVar(\n    jaxls.Var[jax.Array],\n    default_factory=lambda: jnp.concatenate(\n        [jnp.ones((21, 1)), jnp.zeros((21, 3))], axis=-1\n    ),\n    retract_fn=lambda val, delta: (\n        jaxlie.SO3(val) @ jaxlie.SO3.exp(delta.reshape(21, 3))\n    ).wxyz,\n    tangent_dim=21 * 3,\n):\n    \"\"\"Variable containing local joint poses for a SMPL-H human.\"\"\"\n\n\nclass _SmplhSingleHandPosesVar(\n    jaxls.Var[jax.Array],\n    default_factory=lambda: jnp.concatenate(\n        [jnp.ones((15, 1)), jnp.zeros((15, 3))], axis=-1\n    ),\n    retract_fn=lambda val, delta: (\n        jaxlie.SO3(val) @ jaxlie.SO3.exp(delta.reshape(15, 3))\n    ).wxyz,\n    tangent_dim=15 * 3,\n):\n    \"\"\"Variable containing local joint poses for one hand of a SMPL-H human.\"\"\"\n\n\n@jdc.jit\ndef _optimize_vmapped(\n    Ts_world_cpf: jax.Array,\n    body: fncsmpl_jax.SmplhModel,\n    betas: jax.Array,\n    body_rotmats: jax.Array,\n    hand_rotmats: jax.Array,\n    contacts: jax.Array,\n    guidance_params: JaxGuidanceParams,\n    hamer_detections: dict | None,\n    aria_detections: dict | None,\n    verbose: jdc.Static[bool],\n) -> tuple[jax.Array, dict]:\n    return jax.vmap(\n        partial(\n            _optimize,\n            Ts_world_cpf=Ts_world_cpf,\n            body=body,\n            guidance_params=guidance_params,\n            hamer_detections=hamer_detections,\n            aria_detections=aria_detections,\n            verbose=verbose,\n        )\n    )(\n        betas=betas,\n        body_rotmats=body_rotmats,\n        hand_rotmats=hand_rotmats,\n        contacts=contacts,\n    )\n\n\n# Modes for guidance.\nGuidanceMode = Literal[\n    # Foot skating only.\n    \"no_hands\",\n    # Only use Aria wrist pose.\n    \"aria_wrist_only\",\n    # Use Aria wrist pose + HaMeR 3D estimates.\n    \"aria_hamer\",\n    # Use only HaMeR 3D estimates.\n    \"hamer_wrist\",\n    # Use HaMeR 3D estimates + reprojection.\n    \"hamer_reproj2\",\n]\n\n\n@jdc.pytree_dataclass\nclass JaxGuidanceParams:\n    prior_quat_weight: float = 1.0\n    prior_pos_weight: float = 5.0\n    body_quat_vel_smoothness_weight: float = 5.0\n    body_quat_smoothness_weight: float = 1.0\n    body_quat_delta_smoothness_weight: float = 10.0\n    skate_weight: float = 30.0\n\n    # Note: this should be quite high. 
If the hand quaternions aren't\n    # constrained enough, the reprojection loss can get wild.\n    hand_quats: jdc.Static[bool] = True\n    hand_quat_weight = 5.0\n\n    hand_quat_priors: jdc.Static[bool] = True\n    hand_quat_prior_weight = 0.1\n    hand_quat_smoothness_weight = 1.0\n\n    hamer_reproj: jdc.Static[bool] = True\n    hand_reproj_weight: float = 1.0\n\n    hamer_wrist_pose: jdc.Static[bool] = True\n    hamer_abspos_weight: float = 20.0\n    hamer_ori_weight: float = 5.0\n\n    aria_wrists: jdc.Static[bool] = True\n    aria_wrist_pos_weight: float = 50.0\n    aria_wrist_ori_weight: float = 10.0\n\n    # Optimization parameters.\n    lambda_initial: float = 0.1\n    max_iters: jdc.Static[int] = 20\n\n    @staticmethod\n    def defaults(\n        mode: GuidanceMode,\n        phase: Literal[\"inner\", \"post\"],\n    ) -> JaxGuidanceParams:\n        if mode == \"no_hands\":\n            return {\n                \"inner\": JaxGuidanceParams(\n                    hand_quats=False,\n                    hand_quat_priors=False,\n                    hamer_reproj=False,\n                    hamer_wrist_pose=False,\n                    aria_wrists=False,\n                    max_iters=5,\n                ),\n                \"post\": JaxGuidanceParams(\n                    hand_quats=False,\n                    hand_quat_priors=False,\n                    hamer_reproj=False,\n                    hamer_wrist_pose=False,\n                    aria_wrists=False,\n                    max_iters=20,\n                ),\n            }[phase]\n        elif mode == \"aria_wrist_only\":\n            return {\n                \"inner\": JaxGuidanceParams(\n                    hand_quats=False,\n                    hand_quat_priors=True,\n                    hamer_reproj=False,\n                    hamer_wrist_pose=False,\n                    aria_wrists=True,\n                    max_iters=5,\n                ),\n                \"post\": JaxGuidanceParams(\n                    hand_quats=False,\n                    hand_quat_priors=True,\n                    hamer_reproj=False,\n                    hamer_wrist_pose=False,\n                    aria_wrists=True,\n                    max_iters=20,\n                ),\n            }[phase]\n        elif mode == \"aria_hamer\":\n            return {\n                \"inner\": JaxGuidanceParams(\n                    hand_quats=True,\n                    hand_quat_priors=True,\n                    hamer_reproj=False,\n                    hamer_wrist_pose=False,\n                    aria_wrists=True,\n                    max_iters=5,\n                ),\n                \"post\": JaxGuidanceParams(\n                    hand_quats=True,\n                    hand_quat_priors=True,\n                    hamer_reproj=False,\n                    hamer_wrist_pose=False,\n                    aria_wrists=True,\n                    max_iters=20,\n                ),\n            }[phase]\n        elif mode == \"hamer_wrist\":\n            return {\n                \"inner\": JaxGuidanceParams(\n                    hand_quats=True,\n                    hand_quat_priors=True,\n                    # NOTE: we turn off reprojection during the inner loop optimization.\n                    hamer_reproj=False,\n                    hamer_wrist_pose=True,\n                    aria_wrists=False,\n                    max_iters=5,\n                ),\n                \"post\": JaxGuidanceParams(\n                    hand_quats=True,\n                    
hand_quat_priors=True,\n                    # Keep reprojection off; this mode only uses the HaMeR wrist pose.\n                    hamer_reproj=False,\n                    hamer_wrist_pose=True,\n                    aria_wrists=False,\n                    max_iters=20,\n                ),\n            }[phase]\n        elif mode == \"hamer_reproj2\":\n            return {\n                \"inner\": JaxGuidanceParams(\n                    hand_quats=True,\n                    hand_quat_priors=True,\n                    # NOTE: we turn off reprojection during the inner loop optimization.\n                    hamer_reproj=False,\n                    hamer_wrist_pose=True,\n                    aria_wrists=False,\n                    max_iters=5,\n                ),\n                \"post\": JaxGuidanceParams(\n                    hand_quats=True,\n                    hand_quat_priors=True,\n                    # Turn on reprojection.\n                    hamer_reproj=True,\n                    hamer_wrist_pose=True,\n                    aria_wrists=False,\n                    max_iters=20,\n                ),\n            }[phase]\n        else:\n            assert_never(mode)\n\n\ndef _optimize(\n    Ts_world_cpf: jax.Array,\n    body: fncsmpl_jax.SmplhModel,\n    betas: jax.Array,\n    body_rotmats: jax.Array,\n    hand_rotmats: jax.Array,\n    contacts: jax.Array,\n    guidance_params: JaxGuidanceParams,\n    hamer_detections: dict | None,\n    aria_detections: dict | None,\n    verbose: bool,\n) -> tuple[jax.Array, dict]:\n    \"\"\"Apply constraints using a Levenberg-Marquardt optimizer. Returns updated\n    body and hand joint orientations as wxyz quaternions.\"\"\"\n    timesteps = body_rotmats.shape[0]\n    assert Ts_world_cpf.shape == (timesteps, 7)\n    assert body_rotmats.shape == (timesteps, 21, 3, 3)\n    assert hand_rotmats.shape == (timesteps, 30, 3, 3)\n    assert contacts.shape == (timesteps, 21)\n    assert betas.shape == (timesteps, 16)\n\n    init_quats = jaxlie.SO3.from_matrix(\n        # body_rotmats\n        jnp.concatenate([body_rotmats, hand_rotmats], axis=1)\n    ).wxyz\n    assert init_quats.shape == (timesteps, 51, 4)\n\n    # Assume body shape is time-invariant.\n    shaped_body = body.with_shape(jnp.mean(betas, axis=0))\n    T_head_cpf = shaped_body.get_T_head_cpf()\n    T_cpf_head = jaxlie.SE3(T_head_cpf).inverse().parameters()\n    assert T_cpf_head.shape == (7,)\n\n    init_posed = shaped_body.with_pose(\n        jaxlie.SE3.identity(batch_axes=(timesteps,)).wxyz_xyz, init_quats\n    )\n    T_world_head = jaxlie.SE3(Ts_world_cpf) @ jaxlie.SE3(T_cpf_head)\n    T_root_head = jaxlie.SE3(init_posed.Ts_world_joint[:, 14])\n    init_posed = init_posed.with_new_T_world_root(\n        (T_world_head @ T_root_head.inverse()).wxyz_xyz\n    )\n    del T_world_head\n    del T_root_head\n\n    foot_joint_indices = jnp.array([6, 7, 9, 10])\n    num_foot_joints = foot_joint_indices.shape[0]\n\n    contacts = contacts[..., foot_joint_indices]\n    pairwise_contacts = (contacts[:-1, :] + contacts[1:, :]) / 2.0\n    assert pairwise_contacts.shape == (timesteps - 1, num_foot_joints)\n    del contacts\n\n    # We'll populate a list of factors (cost terms).\n    factors = list[jaxls.Cost]()\n\n    def cost_with_args[*CostArgs](\n        *args: Unpack[tuple[*CostArgs]],\n    ) -> Callable[\n        [Callable[[jaxls.VarValues, *CostArgs], jax.Array]],\n        Callable[[jaxls.VarValues, *CostArgs], jax.Array],\n    ]:\n        \"\"\"Decorator for appending to the factor list.\"\"\"\n\n        def inner(\n            cost_func: 
Callable[[jaxls.VarValues, *CostArgs], jax.Array],\n        ) -> Callable[[jaxls.VarValues, *CostArgs], jax.Array]:\n            factors.append(jaxls.Cost(cost_func, args))\n            return cost_func\n\n        return inner\n\n    def do_forward_kinematics(\n        vals: jaxls.VarValues,\n        var: _SmplhBodyPosesVar,\n        left_hand: _SmplhSingleHandPosesVar | None = None,\n        right_hand: _SmplhSingleHandPosesVar | None = None,\n        output_frame: Literal[\"world\", \"root\"] = \"world\",\n    ) -> fncsmpl_jax.SmplhShapedAndPosed:\n        \"\"\"Helper for computing forward kinematics from variables.\"\"\"\n        assert (left_hand is None) == (right_hand is None)\n        if left_hand is None and right_hand is None:\n            posed = shaped_body.with_pose(\n                T_world_root=jaxlie.SE3.identity().wxyz_xyz,\n                local_quats=vals[var],\n            )\n        elif left_hand is not None and right_hand is None:\n            posed = shaped_body.with_pose(\n                T_world_root=jaxlie.SE3.identity().wxyz_xyz,\n                local_quats=jnp.concatenate([vals[var], vals[left_hand]], axis=-2),\n            )\n        elif left_hand is not None and right_hand is not None:\n            posed = shaped_body.with_pose(\n                T_world_root=jaxlie.SE3.identity().wxyz_xyz,\n                local_quats=jnp.concatenate(\n                    [vals[var], vals[left_hand], vals[right_hand]], axis=-2\n                ),\n            )\n        else:\n            assert False\n\n        if output_frame == \"world\":\n            T_world_root = (\n                # T_world_cpf\n                jaxlie.SE3(Ts_world_cpf[var.id, :])\n                # T_cpf_head\n                @ jaxlie.SE3(T_cpf_head)\n                # T_head_root\n                @ jaxlie.SE3(posed.Ts_world_joint[14]).inverse()\n            )\n            return posed.with_new_T_world_root(T_world_root.wxyz_xyz)\n        elif output_frame == \"root\":\n            return posed\n\n    # HaMeR pose cost.\n    if hamer_detections is not None and guidance_params.hand_quat_priors:\n        hamer_left = hamer_detections[\"detections_left_concat\"]\n        hamer_right = hamer_detections[\"detections_right_concat\"]\n\n        # HaMeR local quaternion smoothness.\n        @(\n            cost_with_args(\n                _SmplhSingleHandPosesVar(jnp.arange(timesteps * 2 - 2)),\n                _SmplhSingleHandPosesVar(jnp.arange(2, timesteps * 2)),\n            )\n        )\n        def hand_smoothness(\n            vals: jaxls.VarValues,\n            hand_pose: _SmplhSingleHandPosesVar,\n            hand_pose_next: _SmplhSingleHandPosesVar,\n        ) -> jax.Array:\n            return (\n                guidance_params.hand_quat_smoothness_weight\n                * (\n                    jaxlie.SO3(vals[hand_pose]).inverse()\n                    @ jaxlie.SO3(vals[hand_pose_next])\n                )\n                .log()\n                .flatten()\n            )\n\n        # Hand prior loss.\n        @cost_with_args(\n            _SmplhSingleHandPosesVar(jnp.arange(timesteps * 2)),\n            init_quats[:, 21:51, :].reshape((timesteps * 2, 15, 4)),\n        )\n        def hand_prior(\n            vals: jaxls.VarValues,\n            hand_pose: _SmplhSingleHandPosesVar,\n            init_hand_quats: jax.Array,\n        ) -> jax.Array:\n            return (\n                guidance_params.hand_quat_prior_weight\n                * (jaxlie.SO3(vals[hand_pose]).inverse() @ 
jaxlie.SO3(init_hand_quats))\n                .log()\n                .flatten()\n            )\n\n    if hamer_detections is not None and guidance_params.hand_quats:\n        hamer_left = hamer_detections[\"detections_left_concat\"]\n        hamer_right = hamer_detections[\"detections_right_concat\"]\n\n        # HaMeR local pose matching.\n        @(\n            cost_with_args(\n                _SmplhSingleHandPosesVar(hamer_left[\"indices\"] * 2),\n                hamer_left[\"single_hand_quats\"],\n            )\n            if hamer_left is not None\n            else lambda x: x\n        )\n        @(\n            cost_with_args(\n                _SmplhSingleHandPosesVar(hamer_right[\"indices\"] * 2 + 1),\n                hamer_right[\"single_hand_quats\"],\n            )\n            if hamer_right is not None\n            else lambda x: x\n        )\n        def hamer_local_pose_cost(\n            vals: jaxls.VarValues,\n            hand_pose: _SmplhSingleHandPosesVar,\n            estimated_hand_quats: jax.Array,\n        ) -> jax.Array:\n            hand_quats = vals[hand_pose]\n            assert hand_quats.shape == estimated_hand_quats.shape\n            return guidance_params.hand_quat_weight * (\n                (jaxlie.SO3(hand_quats).inverse() @ jaxlie.SO3(estimated_hand_quats))\n                .log()\n                .flatten()\n            )\n\n    if hamer_detections is not None and (\n        guidance_params.hamer_reproj and guidance_params.hamer_wrist_pose\n    ):\n        hamer_left = hamer_detections[\"detections_left_concat\"]\n        hamer_right = hamer_detections[\"detections_right_concat\"]\n\n        # HaMeR reprojection.\n        mano_from_openpose_indices = _get_mano_from_openpose_indices(include_tips=False)\n\n        @(\n            cost_with_args(\n                _SmplhBodyPosesVar(hamer_left[\"indices\"]),\n                _SmplhSingleHandPosesVar(hamer_left[\"indices\"] * 2),\n                _SmplhSingleHandPosesVar(hamer_left[\"indices\"] * 2 + 1),\n                jnp.full_like(hamer_left[\"indices\"], fill_value=0),\n                hamer_left[\"keypoints_3d\"],\n                hamer_left[\"mano_hand_global_orient\"],\n            )\n            if hamer_left is not None\n            else lambda x: x\n        )\n        @(\n            cost_with_args(\n                _SmplhBodyPosesVar(hamer_right[\"indices\"]),\n                _SmplhSingleHandPosesVar(hamer_right[\"indices\"] * 2),\n                _SmplhSingleHandPosesVar(hamer_right[\"indices\"] * 2 + 1),\n                jnp.full_like(hamer_right[\"indices\"], fill_value=1),\n                hamer_right[\"keypoints_3d\"],\n                hamer_right[\"mano_hand_global_orient\"],\n            )\n            if hamer_right is not None\n            else lambda x: x\n        )\n        def hamer_wrist_and_reproj(\n            vals: jaxls.VarValues,\n            body_pose: _SmplhBodyPosesVar,\n            left_hand_pose: _SmplhSingleHandPosesVar,\n            right_hand_pose: _SmplhSingleHandPosesVar,\n            left0_right1: jax.Array,  # Set to 0 for left, 1 for right.\n            keypoints3d_wrt_cam: jax.Array,  # These are in OpenPose order!!\n            Rmat_cam_wrist: jax.Array,\n        ) -> jax.Array:\n            posed = do_forward_kinematics(\n                # The right hand comes _after_ the left hand, we can exclude it.\n                vals,\n                body_pose,\n                left_hand_pose,\n                right_hand_pose,\n                output_frame=\"root\",\n   
         )\n            Ts_root_joint = posed.Ts_world_joint  # Sorry for the naming...\n            del posed\n\n            # 19 for left wrist, 20 for right wrist.\n            wrist_index = 19 + left0_right1\n            hand_start_index = 21 + 15 * left0_right1\n\n            assert Ts_root_joint.shape == (51, 7)\n            joint_positions_wrt_root = Ts_root_joint[:, 4:7]\n            mano_joints_wrt_root = jnp.concatenate(\n                [\n                    jax.lax.dynamic_slice_in_dim(\n                        joint_positions_wrt_root,\n                        start_index=wrist_index,\n                        slice_size=1,\n                        axis=-2,\n                    ),\n                    jax.lax.dynamic_slice_in_dim(\n                        joint_positions_wrt_root,\n                        start_index=hand_start_index,\n                        slice_size=15,\n                        axis=-2,\n                    ),\n                ],\n                axis=0,\n            )\n            assert mano_joints_wrt_root.shape == (16, 3)\n            assert keypoints3d_wrt_cam.shape == (21, 3)  # In OpenPose.\n\n            T_cam_root = (\n                # T_cam_cpf (7,)\n                jaxlie.SE3(hamer_detections[\"T_cpf_cam\"]).inverse()\n                # T_cpf_head (7,)\n                @ jaxlie.SE3(T_cpf_head)\n                # T_head_root (7,)\n                @ jaxlie.SE3(Ts_root_joint[14, :]).inverse()\n            )\n            assert T_cam_root.parameters().shape == (7,)\n            mano_joints_wrt_cam = T_cam_root @ mano_joints_wrt_root\n            obs_joints_wrt_cam = keypoints3d_wrt_cam[mano_from_openpose_indices, :]\n\n            mano_uv_wrt_cam = mano_joints_wrt_cam[:, :2] / mano_joints_wrt_cam[:, 2:3]\n            obs_uv_wrt_cam = obs_joints_wrt_cam[:, :2] / obs_joints_wrt_cam[:, 2:3]\n\n            T_cam_wrist = jaxlie.SE3.from_rotation_and_translation(\n                T_cam_root.rotation() @ jaxlie.SO3(Ts_root_joint[wrist_index, :4]),\n                mano_joints_wrt_cam[0, :],\n            )\n            obs_T_cam_wrist = jaxlie.SE3.from_rotation_and_translation(\n                jaxlie.SO3.from_matrix(Rmat_cam_wrist),\n                obs_joints_wrt_cam[0, :],\n            )\n\n            return jnp.concatenate(\n                [\n                    (T_cam_wrist.inverse() @ obs_T_cam_wrist).log()\n                    * jnp.array(\n                        [guidance_params.hamer_abspos_weight] * 3\n                        + [guidance_params.hamer_ori_weight] * 3\n                    ),\n                    guidance_params.hand_reproj_weight\n                    * (mano_uv_wrt_cam - obs_uv_wrt_cam).flatten(),\n                ]\n            )\n    elif (\n        hamer_detections is not None\n        and not guidance_params.hamer_reproj\n        and guidance_params.hamer_wrist_pose\n    ):\n        hamer_left = hamer_detections[\"detections_left_concat\"]\n        hamer_right = hamer_detections[\"detections_right_concat\"]\n\n        @(\n            cost_with_args(\n                _SmplhBodyPosesVar(hamer_left[\"indices\"]),\n                jnp.full_like(hamer_left[\"indices\"], fill_value=0),\n                hamer_left[\"keypoints_3d\"],\n                hamer_left[\"mano_hand_global_orient\"],\n            )\n            if hamer_left is not None\n            else lambda x: x\n        )\n        @(\n            cost_with_args(\n                _SmplhBodyPosesVar(hamer_right[\"indices\"]),\n                
jnp.full_like(hamer_right[\"indices\"], fill_value=1),\n                hamer_right[\"keypoints_3d\"],\n                hamer_right[\"mano_hand_global_orient\"],\n            )\n            if hamer_right is not None\n            else lambda x: x\n        )\n        def hamer_wrist_only(\n            vals: jaxls.VarValues,\n            body_pose: _SmplhBodyPosesVar,\n            left0_right1: jax.Array,  # Set to 0 for left, 1 for right.\n            keypoints3d_wrt_cam: jax.Array,  # These are in OpenPose order!!\n            Rmat_cam_wrist: jax.Array,\n        ) -> jax.Array:\n            posed = do_forward_kinematics(vals, body_pose, output_frame=\"root\")\n            Ts_root_joint = posed.Ts_world_joint  # Sorry for the naming...\n            del posed\n\n            # 19 for left wrist, 20 for right wrist.\n            wrist_index = 19 + left0_right1\n\n            assert Ts_root_joint.shape == (21, 7)\n            wrist_position_wrt_root = Ts_root_joint[wrist_index, 4:7]\n\n            T_cam_root = (\n                # T_cam_cpf (7,)\n                jaxlie.SE3(hamer_detections[\"T_cpf_cam\"]).inverse()\n                # T_cpf_head (7,)\n                @ jaxlie.SE3(T_cpf_head)\n                # T_head_root (7,)\n                @ jaxlie.SE3(Ts_root_joint[14, :]).inverse()\n            )\n            assert T_cam_root.parameters().shape == (7,)\n            wrist_position_wrt_cam = T_cam_root @ wrist_position_wrt_root\n\n            # Assumes OpenPose root is same as Mano root!!\n            wrist_pos_wrt_cam = keypoints3d_wrt_cam[0, :]\n\n            T_cam_wrist = jaxlie.SE3.from_rotation_and_translation(\n                T_cam_root.rotation() @ jaxlie.SO3(Ts_root_joint[wrist_index, :4]),\n                wrist_position_wrt_cam,\n            )\n            obs_T_cam_wrist = jaxlie.SE3.from_rotation_and_translation(\n                jaxlie.SO3.from_matrix(Rmat_cam_wrist),\n                wrist_pos_wrt_cam,\n            )\n            return (T_cam_wrist.inverse() @ obs_T_cam_wrist).log() * jnp.array(\n                [guidance_params.hamer_abspos_weight] * 3\n                + [guidance_params.hamer_ori_weight] * 3\n            )\n\n    # Wrist pose cost.\n    if aria_detections is not None and guidance_params.aria_wrists:\n        aria_left = aria_detections[\"detections_left_concat\"]\n        aria_right = aria_detections[\"detections_right_concat\"]\n\n        @(\n            cost_with_args(\n                _SmplhBodyPosesVar(aria_left[\"indices\"]),\n                aria_left[\"confidence\"],\n                aria_left[\"wrist_position\"],\n                aria_left[\"palm_position\"],\n                aria_left[\"palm_normal\"],\n                jnp.full_like(aria_left[\"indices\"], fill_value=0),\n            )\n            if aria_left is not None\n            else lambda x: x\n        )\n        @(\n            cost_with_args(\n                _SmplhBodyPosesVar(aria_right[\"indices\"]),\n                aria_right[\"confidence\"],\n                aria_right[\"wrist_position\"],\n                aria_right[\"palm_position\"],\n                aria_right[\"palm_normal\"],\n                jnp.full_like(aria_right[\"indices\"], fill_value=1),\n            )\n            if aria_right is not None\n            else lambda x: x\n        )\n        def wrist_pose_cost(\n            vals: jaxls.VarValues,\n            pose: _SmplhBodyPosesVar,\n            confidence: jax.Array,\n            wrist_position: jax.Array,\n            palm_position: jax.Array,\n            
palm_normal: jax.Array,\n            left0_right1: jax.Array,  # Set to 0 for left, 1 for right.\n        ) -> jax.Array:\n            assert wrist_position.shape == (3,)\n            assert left0_right1.shape == ()\n            posed = do_forward_kinematics(vals, pose)\n\n            T_world_wrist = posed.Ts_world_joint[19 + left0_right1]\n\n            pos_cost = (\n                # Left wrist is joint 19, right is joint 20.\n                T_world_wrist[4:7] - wrist_position\n            )\n\n            # Estimate wrist orientation from forward + normal directions.\n            palm_forward = palm_position - wrist_position\n            palm_forward = palm_forward / jnp.linalg.norm(palm_forward)\n            palm_normal = palm_normal / jnp.linalg.norm(palm_normal)\n            palm_forward = (  # Flip palm forward if right hand.\n                palm_forward * jnp.array([1, -1])[left0_right1]\n            )\n            palm_forward = (  # Gram-schmidt for forward direction.\n                palm_forward - jnp.dot(palm_forward, palm_normal) * palm_normal\n            )\n            estimatedR_world_wrist = jaxlie.SO3.from_matrix(\n                jnp.stack(\n                    [\n                        palm_forward,\n                        -palm_normal,\n                        jnp.cross(palm_normal, palm_forward),\n                    ],\n                    axis=1,\n                )\n            )\n            R_world_wrist = jaxlie.SO3(T_world_wrist[:4])\n            ori_cost = (estimatedR_world_wrist.inverse() @ R_world_wrist).log()\n\n            return confidence * jnp.concatenate(\n                [\n                    guidance_params.aria_wrist_pos_weight * pos_cost,\n                    guidance_params.aria_wrist_ori_weight * ori_cost,\n                ]\n            )\n\n    # Per-frame regularization cost.\n    @cost_with_args(\n        _SmplhBodyPosesVar(jnp.arange(timesteps)),\n    )\n    def reg_cost(\n        vals: jaxls.VarValues,\n        pose: _SmplhBodyPosesVar,\n    ) -> jax.Array:\n        posed = do_forward_kinematics(vals, pose)\n\n        torso_indices = jnp.array([0, 1, 2, 5, 8])\n        return jnp.concatenate(\n            [\n                guidance_params.prior_quat_weight\n                * (\n                    jaxlie.SO3(vals[pose]).inverse()\n                    @ jaxlie.SO3(init_quats[pose.id, :21, :])\n                )\n                .log()\n                .flatten(),\n                # Only include some torso joints.\n                guidance_params.prior_pos_weight\n                * (\n                    posed.Ts_world_joint[torso_indices, 4:7]\n                    - init_posed.Ts_world_joint[pose.id, torso_indices, 4:7]\n                ).flatten(),\n            ]\n        )\n\n    @cost_with_args(\n        _SmplhBodyPosesVar(jnp.arange(timesteps - 1)),\n        _SmplhBodyPosesVar(jnp.arange(1, timesteps)),\n    )\n    def delta_smoothness_cost(\n        vals: jaxls.VarValues,\n        current: _SmplhBodyPosesVar,\n        next: _SmplhBodyPosesVar,\n    ) -> jax.Array:\n        curdelt = jaxlie.SO3(vals[current]).inverse() @ jaxlie.SO3(\n            init_quats[current.id, :21, :]\n        )\n        nexdelt = jaxlie.SO3(vals[next]).inverse() @ jaxlie.SO3(\n            init_quats[next.id, :21, :]\n        )\n        return jnp.concatenate(\n            [\n                guidance_params.body_quat_delta_smoothness_weight\n                * (curdelt.inverse() @ nexdelt).log().flatten(),\n                
guidance_params.body_quat_smoothness_weight\n                * (jaxlie.SO3(vals[current]).inverse() @ jaxlie.SO3(vals[next]))\n                .log()\n                .flatten(),\n            ]\n        )\n\n    @cost_with_args(\n        _SmplhBodyPosesVar(jnp.arange(timesteps - 2)),\n        _SmplhBodyPosesVar(jnp.arange(1, timesteps - 1)),\n        _SmplhBodyPosesVar(jnp.arange(2, timesteps)),\n    )\n    def vel_smoothness_cost(\n        vals: jaxls.VarValues,\n        t0: _SmplhBodyPosesVar,\n        t1: _SmplhBodyPosesVar,\n        t2: _SmplhBodyPosesVar,\n    ) -> jax.Array:\n        curdelt = jaxlie.SO3(vals[t0]).inverse() @ jaxlie.SO3(vals[t1])\n        nexdelt = jaxlie.SO3(vals[t1]).inverse() @ jaxlie.SO3(vals[t2])\n        return (\n            guidance_params.body_quat_vel_smoothness_weight\n            * (curdelt.inverse() @ nexdelt).log().flatten()\n        )\n\n    @cost_with_args(\n        _SmplhBodyPosesVar(jnp.arange(timesteps - 1)),\n        _SmplhBodyPosesVar(jnp.arange(1, timesteps)),\n        pairwise_contacts,\n    )\n    def skating_cost(\n        vals: jaxls.VarValues,\n        current: _SmplhBodyPosesVar,\n        next: _SmplhBodyPosesVar,\n        foot_contacts: jax.Array,\n    ) -> jax.Array:\n        # Do forward kinematics.\n        posed_current = do_forward_kinematics(vals, current)\n        posed_next = do_forward_kinematics(vals, next)\n        footpos_current = posed_current.Ts_world_joint[foot_joint_indices, 4:7]\n        footpos_next = posed_next.Ts_world_joint[foot_joint_indices, 4:7]\n        assert footpos_current.shape == footpos_next.shape == (num_foot_joints, 3)\n        assert foot_contacts.shape == (num_foot_joints,)\n\n        return (\n            guidance_params.skate_weight\n            * (foot_contacts[:, None] * (footpos_current - footpos_next)).flatten()\n        )\n\n    vars_body_pose = _SmplhBodyPosesVar(jnp.arange(timesteps))\n    vars_hand_pose = _SmplhSingleHandPosesVar(jnp.arange(timesteps * 2))\n    graph = jaxls.LeastSquaresProblem(\n        costs=factors, variables=[vars_body_pose, vars_hand_pose]\n    ).analyze()\n    solutions = graph.solve(\n        initial_vals=jaxls.VarValues.make(\n            [\n                vars_body_pose.with_value(init_quats[:, :21, :]),\n                vars_hand_pose.with_value(\n                    init_quats[:, 21:51, :].reshape((timesteps * 2, 15, 4))\n                ),\n            ]\n        ),\n        linear_solver=\"conjugate_gradient\",\n        trust_region=jaxls.TrustRegionConfig(\n            lambda_initial=guidance_params.lambda_initial\n        ),\n        termination=jaxls.TerminationConfig(max_iterations=guidance_params.max_iters),\n        verbose=verbose,\n    )\n    out_body_quats = solutions[_SmplhBodyPosesVar]\n    assert out_body_quats.shape == (timesteps, 21, 4)\n    out_hand_quats = solutions[_SmplhSingleHandPosesVar].reshape((timesteps, 30, 4))\n    assert out_hand_quats.shape == (timesteps, 30, 4)\n    return (\n        jnp.concatenate([out_body_quats, out_hand_quats], axis=-2),\n        {},  # Metadata dict that we use for debugging.\n    )\n\n\ndef _get_mano_from_openpose_indices(include_tips: bool) -> Int[onp.ndarray, \"21\"]:\n    # https://github.com/geopavlakos/hamer/blob/272d68f176e0ea8a506f761663dd3dca4a03ced0/hamer/models/mano_wrapper.py#L20\n    # fmt: off\n    mano_to_openpose = [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20]\n    # fmt: on\n    openpose_from_mano_idx = {\n        mano_idx: openpose_idx for openpose_idx, mano_idx in 
enumerate(mano_to_openpose)\n    }\n    return onp.array(\n        [openpose_from_mano_idx[i] for i in range(21 if include_tips else 16)]\n    )\n"
  },
  {
    "path": "src/egoallo/hand_detection_structs.py",
    "content": "\"\"\"Data structure definition that we use for hand detections.\n\nWe'll run HaMeR, produce the dictionary defined by `SavedHamerOutputs`, then\npickle this dictionary.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport pickle\nfrom pathlib import Path\nfrom typing import Protocol, TypedDict, cast\n\nimport numpy as np\nimport torch\nfrom jaxtyping import Float, Int\nfrom projectaria_tools.core import mps\nfrom projectaria_tools.core.mps.utils import get_nearest_wrist_and_palm_pose\nfrom torch import Tensor\n\nfrom .tensor_dataclass import TensorDataclass\nfrom .transforms import SE3, SO3\n\n\nclass SingleHandHamerOutputWrtCamera(TypedDict):\n    \"\"\"Hand outputs with respect to the camera frame. For use in pickle files.\"\"\"\n\n    verts: np.ndarray\n    keypoints_3d: np.ndarray\n    mano_hand_pose: np.ndarray\n    mano_hand_betas: np.ndarray\n    mano_hand_global_orient: np.ndarray\n\n\nclass SavedHamerOutputs(TypedDict):\n    \"\"\"Outputs from the HAMeR hand detection algorithm. This is the structure\n    to pickle.\n\n    `detections_left_wrt_cam` and `detections_right_wrt_cam` use nanosecond\n    timestamps as keys.\n    \"\"\"\n\n    mano_faces_right: np.ndarray\n    mano_faces_left: np.ndarray\n\n    detections_left_wrt_cam: dict[int, SingleHandHamerOutputWrtCamera | None]\n    detections_right_wrt_cam: dict[int, SingleHandHamerOutputWrtCamera | None]\n\n    T_device_cam: np.ndarray  # wxyz_xyz\n    T_cpf_cam: np.ndarray  # wxyz_xyz\n\n\nclass AriaHandWristPoseWrtWorld(TensorDataclass):\n    confidence: Float[Tensor, \"n_detections\"]\n    wrist_position: Float[Tensor, \"n_detections 3\"]\n    wrist_normal: Float[Tensor, \"n_detections 3\"]\n\n    palm_position: Float[Tensor, \"n_detections 3\"]\n    palm_normal: Float[Tensor, \"n_detections 3\"]\n\n    indices: Int[Tensor, \"n_detections\"]\n\n\nclass CorrespondedAriaHandWristPoseDetections(TensorDataclass):\n    detections_left_concat: AriaHandWristPoseWrtWorld | None\n    detections_right_concat: AriaHandWristPoseWrtWorld | None\n\n    @staticmethod\n    def load(\n        wrist_and_palm_poses_csv_path: Path,\n        target_timestamps_sec: tuple[float, ...],\n        Ts_world_device: Float[np.ndarray, \"timesteps 7\"],\n    ) -> CorrespondedAriaHandWristPoseDetections:\n        # API from runtime inspection of `projectaria_tools` outputs.\n        class WristAndPalmNormals(Protocol):\n            wrist_normal_device: np.ndarray\n            palm_normal_device: np.ndarray\n\n        class OneSide(Protocol):\n            confidence: float\n            wrist_position_device: np.ndarray\n            palm_position_device: np.ndarray\n            wrist_and_palm_normal_device: WristAndPalmNormals\n\n        wp_poses = mps.hand_tracking.read_wrist_and_palm_poses(\n            str(wrist_and_palm_poses_csv_path)\n        )\n        detections_left = list[OneSide]()\n        detections_right = list[OneSide]()\n        indices_left = list[int]()\n        indices_right = list[int]()\n        for i, time_sec in enumerate(target_timestamps_sec):\n            wp_pose = get_nearest_wrist_and_palm_pose(wp_poses, int(time_sec * 1e9))\n            if (\n                wp_pose is None\n                or abs(wp_pose.tracking_timestamp.total_seconds() - time_sec)\n                >= 1.0 / 30.0\n            ):\n                continue\n\n            if wp_pose.left_hand is not None and wp_pose.left_hand.confidence > 0.7:\n                indices_left.append(i)\n                detections_left.append(wp_pose.left_hand)\n\n   
         if wp_pose.right_hand is not None and wp_pose.right_hand.confidence > 0.7:\n                indices_right.append(i)\n                detections_right.append(wp_pose.right_hand)\n\n        def form_detections_concat(\n            detections: list[OneSide], indices: list[int]\n        ) -> AriaHandWristPoseWrtWorld | None:\n            assert len(detections) == len(indices)\n            if len(indices) == 0:\n                return None\n\n            Tslice_world_device = SE3(\n                torch.from_numpy(Ts_world_device[np.array(indices), :]).to(\n                    dtype=torch.float32\n                )\n            )\n            Rslice_world_device = SO3(\n                torch.from_numpy(Ts_world_device[np.array(indices), :4]).to(\n                    dtype=torch.float32\n                )\n            )\n\n            return AriaHandWristPoseWrtWorld(\n                confidence=torch.from_numpy(\n                    np.array([d.confidence for d in detections])\n                ),\n                wrist_position=Tslice_world_device\n                @ torch.from_numpy(\n                    np.array(\n                        [d.wrist_position_device for d in detections], dtype=np.float32\n                    )\n                ),\n                wrist_normal=Rslice_world_device\n                @ torch.from_numpy(\n                    np.array(\n                        [\n                            d.wrist_and_palm_normal_device.wrist_normal_device\n                            for d in detections\n                        ],\n                        dtype=np.float32,\n                    )\n                ),\n                palm_position=Tslice_world_device\n                @ torch.from_numpy(\n                    np.array(\n                        [d.palm_position_device for d in detections], dtype=np.float32\n                    )\n                ),\n                palm_normal=Rslice_world_device\n                @ torch.from_numpy(\n                    np.array(\n                        [\n                            d.wrist_and_palm_normal_device.palm_normal_device\n                            for d in detections\n                        ],\n                        dtype=np.float32,\n                    )\n                ),\n                indices=torch.from_numpy(np.array(indices, dtype=np.int64)),\n            )\n\n        return CorrespondedAriaHandWristPoseDetections(\n            detections_left_concat=form_detections_concat(\n                detections_left, indices_left\n            ),\n            detections_right_concat=form_detections_concat(\n                detections_right, indices_right\n            ),\n        )\n\n\nclass SingleHandHamerOutputWrtCameraConcatenated(TensorDataclass):\n    verts: Float[Tensor, \"n_detections n_verts 3\"]\n    keypoints_3d: Float[Tensor, \"n_detections n_keypoints 3\"]\n    mano_hand_global_orient: Float[Tensor, \"n_detections 3 3\"]\n    single_hand_quats: Float[Tensor, \"n_detections 15 3\"]\n    indices: Int[Tensor, \"n_detections\"]\n\n\nclass CorrespondedHamerDetections(TensorDataclass):\n    mano_faces_right: Tensor\n    mano_faces_left: Tensor\n    detections_left_tuple: tuple[None | SingleHandHamerOutputWrtCamera, ...]\n    detections_right_tuple: tuple[None | SingleHandHamerOutputWrtCamera, ...]\n    T_cpf_cam: Tensor\n    focal_length: float\n\n    # Concatenated detections will be None if there are no detections at all.\n    detections_left_concat: None | SingleHandHamerOutputWrtCameraConcatenated\n    
detections_right_concat: None | SingleHandHamerOutputWrtCameraConcatenated\n\n    def get_length(self) -> int:\n        assert len(self.detections_left_tuple) == len(self.detections_right_tuple)\n        return len(self.detections_left_tuple)\n\n    def slice(self, start_index: int, end_index: int) -> CorrespondedHamerDetections:\n        \"\"\"Slice the hand detections. Removes unused hand detections, and\n        shifts indices as necessary.\"\"\"\n\n        assert start_index < end_index\n\n        def _get_detections_in_window(\n            detections_side_concat: None | SingleHandHamerOutputWrtCameraConcatenated,\n        ) -> None | SingleHandHamerOutputWrtCameraConcatenated:\n            if detections_side_concat is None:\n                return None\n            else:\n                indices = detections_side_concat.indices\n                indices_mask = (indices >= start_index) & (indices < end_index)\n                out = detections_side_concat.map(lambda x: x[indices_mask].clone())\n                out.indices -= start_index\n                return out\n\n        return CorrespondedHamerDetections(\n            self.mano_faces_right,\n            self.mano_faces_left,\n            self.detections_left_tuple[start_index:end_index],\n            self.detections_right_tuple[start_index:end_index],\n            T_cpf_cam=self.T_cpf_cam,\n            focal_length=self.focal_length,\n            detections_left_concat=_get_detections_in_window(\n                self.detections_left_concat\n            ),\n            detections_right_concat=_get_detections_in_window(\n                self.detections_right_concat\n            ),\n        )\n\n    @staticmethod\n    def load(\n        hand_pkl_path: Path,\n        target_timestamps_sec: tuple[float, ...],\n    ) -> CorrespondedHamerDetections:\n        \"\"\"Helper which takes as input:\n\n        (1) A path to a pickle file containing hand detections through time.\n\n            See feb25_hamer_outputs_from_vrs.py for how this is generated.\n\n        (2) A set of target timestamps, sorted, in seconds.\n\n        We then output a data structure that has hand detections (or `None`) for each target timestamp.\n        \"\"\"\n\n        with open(hand_pkl_path, \"rb\") as f:\n            hamer_out = cast(SavedHamerOutputs, pickle.load(f))\n\n        def match_detections_to_targets(\n            detections_wrt_cam: dict[int, None | SingleHandHamerOutputWrtCamera],\n        ) -> list[None | SingleHandHamerOutputWrtCamera]:\n            # Approximate the frame rate of the detections.\n            est_fps = len(detections_wrt_cam) / (\n                (max(detections_wrt_cam.keys()) - min(detections_wrt_cam.keys())) / 1e9\n            )\n            # Usually framerate is either 10 FPS or 30 FPS. 
We might want to\n            # run on 1 FPS video in the future, we can tweak this assert if we\n            # run into that...\n            assert 5 < est_fps < 40\n\n            # Get nanosecond timestamps within our target timestamp window.\n            # Note that input dictionary keys are nanosecond timestamps!\n            detect_ns = sorted(\n                [\n                    time_ns\n                    for time_ns in detections_wrt_cam.keys()\n                    if time_ns / 1e9 >= target_timestamps_sec[0] - 1 / est_fps\n                    and time_ns / 1e9 <= target_timestamps_sec[-1] + 1 / est_fps\n                ]\n            )\n            delta_matrix = np.abs(\n                np.array(target_timestamps_sec)[:, None]\n                - np.array(detect_ns)[None, :] / 1e9\n            )\n\n            # For each target, which is the closest detection?\n            best_det_from_target = np.argmin(delta_matrix, axis=-1)\n\n            # For each detection, which is the closest target?\n            best_target_from_det = np.argmin(delta_matrix, axis=0)\n\n            # Get detection list; we do a cycle-consistency check to make sure\n            # we get a 1-to-1 mapping.\n            out: list[None | SingleHandHamerOutputWrtCamera] = []\n            for i in range(len(target_timestamps_sec)):\n                if best_target_from_det[best_det_from_target[i]] == i:\n                    out.append(detections_wrt_cam[detect_ns[best_det_from_target[i]]])\n                else:\n                    out.append(None)\n            return out\n\n        detections_left = match_detections_to_targets(\n            hamer_out[\"detections_left_wrt_cam\"]\n        )\n        detections_right = match_detections_to_targets(\n            hamer_out[\"detections_right_wrt_cam\"]\n        )\n        assert (\n            len(detections_left) == len(detections_right) == len(target_timestamps_sec)\n        )\n\n        def make_concat_detections(\n            detections_side: list[None | SingleHandHamerOutputWrtCamera],\n            detections_other_side: list[None | SingleHandHamerOutputWrtCamera],\n        ) -> None | SingleHandHamerOutputWrtCameraConcatenated:\n            detections_side_concat = None\n\n            # Filter out HaMeR detections that are in the same location as each\n            # other.\n            # Sometimes we have a left detections and a right detection both in\n            # the same location. 
This filters both out.\n            detections_side_filtered: list[None | SingleHandHamerOutputWrtCamera] = []\n            for i, d in enumerate(detections_side):\n                if d is None:\n                    detections_side_filtered.append(None)\n                    continue\n                num_d = d[\"verts\"].shape[0]\n\n                keep_mask = np.ones(d[\"keypoints_3d\"].shape[0], dtype=bool)\n                for offset in range(-15, 15):\n                    i_offset = i + offset\n                    if i_offset < 0 or i_offset >= len(detections_other_side):\n                        continue\n\n                    d_other = detections_other_side[i_offset]\n                    if d_other is None:\n                        # detections_side_filtered.append(d)\n                        continue\n\n                    num_d_other = d_other[\"verts\"].shape[0]\n\n                    dist_matrix = np.linalg.norm(\n                        d[\"keypoints_3d\"][:, None, 0, :]\n                        - d_other[\"keypoints_3d\"][None, :, 0, :],\n                        axis=-1,\n                    )\n                    assert dist_matrix.shape == (num_d, num_d_other)\n                    keep_mask = np.logical_and(\n                        keep_mask, np.all(dist_matrix > 0.1, axis=-1)\n                    )\n\n                if keep_mask.sum() == 0:\n                    detections_side_filtered.append(None)\n                else:\n                    detections_side_filtered.append(\n                        cast(\n                            SingleHandHamerOutputWrtCamera,\n                            {k: cast(np.ndarray, v)[keep_mask] for k, v in d.items()},\n                        )\n                    )\n            del detections_side\n\n            detections_side_not_none = [d is not None for d in detections_side_filtered]\n            if not any(detections_side_not_none):\n                return None\n            (valid_detection_indices,) = np.where(detections_side_not_none)\n\n            # We should be done with these.\n            del detections_side_not_none\n            del detections_other_side\n\n            detections_side_concat = SingleHandHamerOutputWrtCameraConcatenated(\n                verts=torch.from_numpy(\n                    np.stack(\n                        # Currently: we always just take the first hand detection.\n                        [\n                            d[\"verts\"][0]\n                            for d in detections_side_filtered\n                            if d is not None\n                        ]\n                    )\n                ).to(torch.float32),\n                keypoints_3d=torch.from_numpy(\n                    np.stack(\n                        [\n                            # Currently: we always just take the first hand detection.\n                            d[\"keypoints_3d\"][0]\n                            for d in detections_side_filtered\n                            if d is not None\n                        ]\n                    )\n                ).to(torch.float32),\n                mano_hand_global_orient=torch.from_numpy(\n                    np.stack(\n                        [\n                            # Currently: we always just take the first hand detection.\n                            d[\"mano_hand_global_orient\"][0]\n                            for d in detections_side_filtered\n                            if d is not None\n                        ]\n                    )\n                
).to(torch.float32),\n                single_hand_quats=SO3.from_matrix(\n                    torch.from_numpy(\n                        np.stack(\n                            [\n                                # Currently: we always just take the first hand detection.\n                                d[\"mano_hand_pose\"][0]\n                                for d in detections_side_filtered\n                                if d is not None\n                            ]\n                        )\n                    ).to(torch.float32)\n                ).wxyz,\n                indices=torch.from_numpy(valid_detection_indices),\n            )\n            return detections_side_concat\n\n        return CorrespondedHamerDetections(\n            mano_faces_right=torch.from_numpy(\n                hamer_out[\"mano_faces_right\"].astype(np.int64)\n            ),\n            mano_faces_left=torch.from_numpy(\n                hamer_out[\"mano_faces_left\"].astype(np.int64)\n            ),\n            detections_left_tuple=tuple(detections_left),\n            detections_right_tuple=tuple(detections_right),\n            T_cpf_cam=torch.from_numpy(hamer_out[\"T_cpf_cam\"]).to(torch.float32),\n            focal_length=450,\n            detections_left_concat=make_concat_detections(\n                detections_left, detections_right\n            ),\n            detections_right_concat=make_concat_detections(\n                detections_right, detections_left\n            ),\n        )\n"
  },
  {
    "path": "src/egoallo/inference_utils.py",
    "content": "\"\"\"Functions that are useful for inference scripts.\"\"\"\n\nfrom __future__ import annotations\n\nfrom dataclasses import dataclass\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nimport yaml\nfrom jaxtyping import Float\nfrom projectaria_tools.core import mps  # type: ignore\nfrom projectaria_tools.core.data_provider import create_vrs_data_provider\nfrom safetensors import safe_open\nfrom torch import Tensor\n\nfrom .network import EgoDenoiser, EgoDenoiserConfig\nfrom .tensor_dataclass import TensorDataclass\nfrom .transforms import SE3\n\n\ndef load_denoiser(checkpoint_dir: Path) -> EgoDenoiser:\n    \"\"\"Load a denoiser model.\"\"\"\n    checkpoint_dir = checkpoint_dir.absolute()\n    experiment_dir = checkpoint_dir.parent\n\n    config = yaml.load(\n        (experiment_dir / \"model_config.yaml\").read_text(), Loader=yaml.Loader\n    )\n    assert isinstance(config, EgoDenoiserConfig)\n\n    model = EgoDenoiser(config)\n    with safe_open(checkpoint_dir / \"model.safetensors\", framework=\"pt\") as f:  # type: ignore\n        state_dict = {k: f.get_tensor(k) for k in f.keys()}\n    model.load_state_dict(state_dict)\n\n    return model\n\n\n@dataclass(frozen=True)\nclass InferenceTrajectoryPaths:\n    \"\"\"Paths for running EgoAllo on a single sequence from Project Aria.\n\n    Our basic assumptions here are:\n    1. VRS file for images: there is exactly one VRS file in the trajectory root directory.\n    2. Aria MPS point cloud: there is either one semidense_points.csv.gz file or one global_points.csv.gz file.\n        - Its parent directory should contain other Aria MPS artifacts. (like poses)\n        - This is optionally used for guidance.\n    3. HaMeR outputs: The hamer_outputs.pkl file may or may not exist in the trajectory root directory.\n        - This is optionally used for guidance.\n    4. Aria MPS wrist/palm poses: There may be zero or one wrist_and_palm_poses.csv file.\n        - This is optionally used for guidance.\n    5. Scene splat/ply file: There may be a splat.ply or scene.splat file.\n        - This is only used for visualization.\n    \"\"\"\n\n    vrs_file: Path\n    slam_root_dir: Path\n    points_path: Path\n    hamer_outputs: Path | None\n    wrist_and_palm_poses_csv: Path | None\n    splat_path: Path | None\n\n    @staticmethod\n    def find(traj_root: Path) -> InferenceTrajectoryPaths:\n        vrs_files = tuple(traj_root.glob(\"**/*.vrs\"))\n        assert len(vrs_files) == 1, f\"Found {len(vrs_files)} VRS files!\"\n\n        points_paths = tuple(traj_root.glob(\"**/semidense_points.csv.gz\"))\n        assert len(points_paths) <= 1, f\"Found multiple points files! 
{points_paths}\"\n        if len(points_paths) == 0:\n            points_paths = tuple(traj_root.glob(\"**/global_points.csv.gz\"))\n        assert len(points_paths) == 1, f\"Found {len(points_paths)} files!\"\n\n        hamer_outputs = traj_root / \"hamer_outputs.pkl\"\n        if not hamer_outputs.exists():\n            hamer_outputs = None\n\n        wrist_and_palm_poses_csv = tuple(traj_root.glob(\"**/wrist_and_palm_poses.csv\"))\n        if len(wrist_and_palm_poses_csv) == 0:\n            wrist_and_palm_poses_csv = None\n        else:\n            assert len(wrist_and_palm_poses_csv) == 1, (\n                \"Found multiple wrist and palm poses files!\"\n            )\n\n        splat_path = traj_root / \"splat.ply\"\n        if not splat_path.exists():\n            splat_path = traj_root / \"scene.splat\"\n        if not splat_path.exists():\n            print(\"No scene splat found.\")\n            splat_path = None\n        else:\n            print(\"Found splat at\", splat_path)\n\n        return InferenceTrajectoryPaths(\n            vrs_file=vrs_files[0],\n            slam_root_dir=points_paths[0].parent,\n            points_path=points_paths[0],\n            hamer_outputs=hamer_outputs,\n            wrist_and_palm_poses_csv=wrist_and_palm_poses_csv[0]\n            if wrist_and_palm_poses_csv\n            else None,\n            splat_path=splat_path,\n        )\n\n\nclass InferenceInputTransforms(TensorDataclass):\n    \"\"\"Some relevant transforms for inference.\"\"\"\n\n    Ts_world_cpf: Float[Tensor, \"timesteps 7\"]\n    Ts_world_device: Float[Tensor, \"timesteps 7\"]\n    pose_timesteps: tuple[float, ...]\n\n    @staticmethod\n    def load(\n        vrs_path: Path,\n        slam_root_dir: Path,\n        fps: int = 30,\n    ) -> InferenceInputTransforms:\n        \"\"\"Read some useful transforms via MPS + the VRS calibration.\"\"\"\n        # Read device poses.\n        closed_loop_path = slam_root_dir / \"closed_loop_trajectory.csv\"\n        if not closed_loop_path.exists():\n            # Aria digital twins.\n            closed_loop_path = slam_root_dir / \"aria_trajectory.csv\"\n        closed_loop_traj = mps.read_closed_loop_trajectory(str(closed_loop_path))  # type: ignore\n\n        provider = create_vrs_data_provider(str(vrs_path))\n        device_calib = provider.get_device_calibration()\n        T_device_cpf = device_calib.get_transform_device_cpf().to_matrix()\n\n        # Get downsampled CPF frames.\n        aria_fps = len(closed_loop_traj) / (\n            closed_loop_traj[-1].tracking_timestamp.total_seconds()\n            - closed_loop_traj[0].tracking_timestamp.total_seconds()\n        )\n        num_poses = len(closed_loop_traj)\n        print(f\"Loaded {num_poses=} with {aria_fps=}, visualizing at {fps=}\")\n        Ts_world_device = []\n        Ts_world_cpf = []\n        out_timestamps_secs = []\n        for i in range(0, num_poses, int(aria_fps // fps)):\n            T_world_device = closed_loop_traj[i].transform_world_device.to_matrix()\n            assert T_world_device.shape == (4, 4)\n            Ts_world_device.append(T_world_device)\n            Ts_world_cpf.append(T_world_device @ T_device_cpf)\n            out_timestamps_secs.append(\n                closed_loop_traj[i].tracking_timestamp.total_seconds()\n            )\n\n        return InferenceInputTransforms(\n            Ts_world_device=SE3.from_matrix(torch.from_numpy(np.array(Ts_world_device)))\n            .parameters()\n            .to(torch.float32),\n            
Ts_world_cpf=SE3.from_matrix(torch.from_numpy(np.array(Ts_world_cpf)))\n            .parameters()\n            .to(torch.float32),\n            pose_timesteps=tuple(out_timestamps_secs),\n        )\n"
  },
  {
    "path": "src/egoallo/metrics_helpers.py",
    "content": "from typing import Literal, overload\n\nimport numpy as np\nimport torch\nfrom jaxtyping import Float\nfrom torch import Tensor\nfrom typing_extensions import assert_never\n\nfrom .transforms import SO3\n\n\ndef compute_foot_skate(\n    pred_Ts_world_joint: Float[Tensor, \"num_samples time 21 7\"],\n) -> np.ndarray:\n    (num_samples, time) = pred_Ts_world_joint.shape[:2]\n\n    # Drop the person to the floor.\n    # This is necessary for the foot skating metric to make sense for floating people...!\n    pred_Ts_world_joint = pred_Ts_world_joint.clone()\n    pred_Ts_world_joint[..., 6] -= torch.min(pred_Ts_world_joint[..., 6])\n\n    foot_indices = torch.tensor([6, 7, 9, 10], device=pred_Ts_world_joint.device)\n\n    foot_positions = pred_Ts_world_joint[:, :, foot_indices, 4:7]\n    foot_positions_diff = foot_positions[:, 1:, :, :2] - foot_positions[:, :-1, :, :2]\n    assert foot_positions_diff.shape == (num_samples, time - 1, 4, 2)\n\n    foot_positions_diff_norm = torch.sum(torch.abs(foot_positions_diff), dim=-1)\n    assert foot_positions_diff_norm.shape == (num_samples, time - 1, 4)\n\n    # From EgoEgo / kinpoly.\n    H_thresh = torch.tensor(\n        # To match indices above: (ankle, ankle, toe, toe)\n        [0.08, 0.08, 0.04, 0.04],\n        device=pred_Ts_world_joint.device,\n        dtype=torch.float32,\n    )\n\n    foot_positions_diff_norm = torch.sum(torch.abs(foot_positions_diff), dim=-1)\n    assert foot_positions_diff_norm.shape == (num_samples, time - 1, 4)\n\n    # Threshold.\n    foot_positions_diff_norm = foot_positions_diff_norm * (\n        foot_positions[..., 1:, :, 2] < H_thresh\n    )\n    fs_per_sample = torch.sum(\n        torch.sum(\n            foot_positions_diff_norm\n            * (2 - 2 ** (foot_positions[..., 1:, :, 2] / H_thresh)),\n            dim=-1,\n        ),\n        dim=-1,\n    )\n    assert fs_per_sample.shape == (num_samples,)\n\n    return fs_per_sample.numpy(force=True)\n\n\ndef compute_foot_contact(\n    pred_Ts_world_joint: Float[Tensor, \"num_samples time 21 7\"],\n) -> np.ndarray:\n    (num_samples, time) = pred_Ts_world_joint.shape[:2]\n\n    foot_indices = torch.tensor([6, 7, 9, 10], device=pred_Ts_world_joint.device)\n\n    # From EgoEgo / kinpoly.\n    H_thresh = torch.tensor(\n        # To match indices above: (ankle, ankle, toe, toe)\n        [0.08, 0.08, 0.04, 0.04],\n        device=pred_Ts_world_joint.device,\n        dtype=torch.float32,\n    )\n\n    foot_positions = pred_Ts_world_joint[:, :, foot_indices, 4:7]\n\n    any_contact = torch.any(\n        torch.any(foot_positions[..., 2] < H_thresh, dim=-1), dim=-1\n    ).to(torch.float32)\n    assert any_contact.shape == (num_samples,)\n\n    return any_contact.numpy(force=True)\n\n\ndef compute_head_ori(\n    label_Ts_world_joint: Float[Tensor, \"time 21 7\"],\n    pred_Ts_world_joint: Float[Tensor, \"num_samples time 21 7\"],\n) -> np.ndarray:\n    (num_samples, time) = pred_Ts_world_joint.shape[:2]\n    matrix_errors = (\n        SO3(pred_Ts_world_joint[:, :, 14, :4]).as_matrix()\n        @ SO3(label_Ts_world_joint[:, 14, :4]).inverse().as_matrix()\n    ) - torch.eye(3, device=label_Ts_world_joint.device)\n    assert matrix_errors.shape == (num_samples, time, 3, 3)\n\n    return torch.mean(\n        torch.linalg.norm(matrix_errors.reshape((num_samples, time, 9)), dim=-1),\n        dim=-1,\n    ).numpy(force=True)\n\n\ndef compute_head_trans(\n    label_Ts_world_joint: Float[Tensor, \"time 21 7\"],\n    pred_Ts_world_joint: Float[Tensor, \"num_samples time 21 
7\"],\n) -> np.ndarray:\n    (num_samples, time) = pred_Ts_world_joint.shape[:2]\n    errors = pred_Ts_world_joint[:, :, 14, 4:7] - label_Ts_world_joint[:, 14, 4:7]\n    assert errors.shape == (num_samples, time, 3)\n\n    return torch.mean(\n        torch.linalg.norm(errors, dim=-1),\n        dim=-1,\n    ).numpy(force=True)\n\n\ndef compute_mpjpe(\n    label_T_world_root: Float[Tensor, \"time 7\"],\n    label_Ts_world_joint: Float[Tensor, \"time 21 7\"],\n    pred_T_world_root: Float[Tensor, \"num_samples time 7\"],\n    pred_Ts_world_joint: Float[Tensor, \"num_samples time 21 7\"],\n    per_frame_procrustes_align: bool,\n) -> np.ndarray:\n    num_samples, time, _, _ = pred_Ts_world_joint.shape\n\n    # Concatenate the world root to the joints.\n    label_Ts_world_joint = torch.cat(\n        [label_T_world_root[..., None, :], label_Ts_world_joint], dim=-2\n    )\n    pred_Ts_world_joint = torch.cat(\n        [pred_T_world_root[..., None, :], pred_Ts_world_joint], dim=-2\n    )\n    del label_T_world_root, pred_T_world_root\n\n    pred_joint_positions = pred_Ts_world_joint[:, :, :, 4:7]\n    label_joint_positions = label_Ts_world_joint[None, :, :, 4:7].repeat(\n        num_samples, 1, 1, 1\n    )\n\n    if per_frame_procrustes_align:\n        pred_joint_positions = procrustes_align(\n            points_y=pred_joint_positions,\n            points_x=label_joint_positions,\n            output=\"aligned_x\",\n        )\n\n    position_differences = pred_joint_positions - label_joint_positions\n    assert position_differences.shape == (num_samples, time, 22, 3)\n\n    # Per-joint position errors, in millimeters.\n    pjpe = torch.linalg.norm(position_differences, dim=-1) * 1000.0\n    assert pjpe.shape == (num_samples, time, 22)\n\n    # Mean per-joint position errors.\n    mpjpe = torch.mean(pjpe.reshape((num_samples, -1)), dim=-1)\n    assert mpjpe.shape == (num_samples,)\n\n    return mpjpe.cpu().numpy()\n\n\n@overload\ndef procrustes_align(\n    points_y: Float[Tensor, \"*#batch N 3\"],\n    points_x: Float[Tensor, \"*#batch N 3\"],\n    output: Literal[\"transforms\"],\n    fix_scale: bool = False,\n) -> tuple[Tensor, Tensor, Tensor]: ...\n\n\n@overload\ndef procrustes_align(\n    points_y: Float[Tensor, \"*#batch N 3\"],\n    points_x: Float[Tensor, \"*#batch N 3\"],\n    output: Literal[\"aligned_x\"],\n    fix_scale: bool = False,\n) -> Tensor: ...\n\n\ndef procrustes_align(\n    points_y: Float[Tensor, \"*#batch N 3\"],\n    points_x: Float[Tensor, \"*#batch N 3\"],\n    output: Literal[\"transforms\", \"aligned_x\"],\n    fix_scale: bool = False,\n) -> tuple[Tensor, Tensor, Tensor] | Tensor:\n    \"\"\"Similarity transform alignment using the Umeyama method. 
Adapted from\n    SLAHMR: https://github.com/vye16/slahmr/blob/main/slahmr/geometry/pcl.py\n    Minimizes:\n        mean( || Y - (s * R @ X + t) ||^2 )\n    with respect to s, R, and t.\n    Returns an (s, R, t) tuple if output == \"transforms\", or the aligned points\n    s * R @ X + t if output == \"aligned_x\".\n    \"\"\"\n    *dims, N, _ = points_y.shape\n    device = points_y.device\n    N = torch.ones((*dims, 1, 1), device=device) * N\n\n    # subtract mean\n    my = points_y.sum(dim=-2) / N[..., 0]  # (*, 3)\n    mx = points_x.sum(dim=-2) / N[..., 0]\n    y0 = points_y - my[..., None, :]  # (*, N, 3)\n    x0 = points_x - mx[..., None, :]\n\n    # correlation\n    C = torch.matmul(y0.transpose(-1, -2), x0) / N  # (*, 3, 3)\n    U, D, Vh = torch.linalg.svd(C)  # (*, 3, 3), (*, 3), (*, 3, 3)\n\n    S = (\n        torch.eye(3, device=device)\n        .reshape(*(1,) * (len(dims)), 3, 3)\n        .repeat(*dims, 1, 1)\n    )\n    neg = torch.det(U) * torch.det(Vh.transpose(-1, -2)) < 0\n    S = torch.where(\n        neg.reshape(*dims, 1, 1),\n        S * torch.diag(torch.tensor([1, 1, -1], device=device)),\n        S,\n    )\n\n    R = torch.matmul(U, torch.matmul(S, Vh))  # (*, 3, 3)\n\n    D = torch.diag_embed(D)  # (*, 3, 3)\n    if fix_scale:\n        s = torch.ones(*dims, 1, device=device, dtype=torch.float32)\n    else:\n        var = torch.sum(torch.square(x0), dim=(-1, -2), keepdim=True) / N  # (*, 1, 1)\n        s = (\n            torch.diagonal(torch.matmul(D, S), dim1=-2, dim2=-1).sum(\n                dim=-1, keepdim=True\n            )\n            / var[..., 0]\n        )  # (*, 1)\n\n    t = my - s * torch.matmul(R, mx[..., None])[..., 0]  # (*, 3)\n\n    assert s.shape == (*dims, 1)\n    assert R.shape == (*dims, 3, 3)\n    assert t.shape == (*dims, 3)\n\n    if output == \"transforms\":\n        return s, R, t\n    elif output == \"aligned_x\":\n        aligned_x = (\n            s[..., None, :] * torch.einsum(\"...ij,...nj->...ni\", R, points_x)\n            + t[..., None, :]\n        )\n        assert aligned_x.shape == points_x.shape\n        return aligned_x\n    else:\n        assert_never(output)\n"
  },
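# --- Hypothetical usage sketch, not part of the repository ---
# A minimal check of `procrustes_align` from src/egoallo/metrics_helpers.py,
# assuming the package is importable as `egoallo`. It builds a known
# similarity transform and verifies that output="aligned_x" maps points_x
# back onto points_y; every other name below is local to this sketch.
import math

import torch

from egoallo.metrics_helpers import procrustes_align

torch.manual_seed(0)
points_x = torch.randn(2, 50, 3)  # (batch, N, 3)

# Ground-truth similarity transform: scale, rotation about z, translation.
s_true = 1.7
c, s = math.cos(0.3), math.sin(0.3)
R_true = torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
t_true = torch.tensor([0.5, -1.0, 2.0])
points_y = s_true * torch.einsum("ij,bnj->bni", R_true, points_x) + t_true

# output="aligned_x" returns s * R @ points_x + t; output="transforms"
# would return the estimated (s, R, t) tuple instead.
aligned = procrustes_align(points_y=points_y, points_x=points_x, output="aligned_x")
print(torch.allclose(aligned, points_y, atol=1e-3))  # expect True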
  {
    "path": "src/egoallo/network.py",
    "content": "from __future__ import annotations\n\nfrom dataclasses import dataclass\nfrom functools import cache, cached_property\nfrom typing import Literal, assert_never\n\nimport numpy as np\nimport torch\nfrom einops import rearrange\nfrom jaxtyping import Bool, Float\nfrom loguru import logger\nfrom rotary_embedding_torch import RotaryEmbedding\nfrom torch import Tensor, nn\n\nfrom .fncsmpl import SmplhModel, SmplhShapedAndPosed\nfrom .tensor_dataclass import TensorDataclass\nfrom .transforms import SE3, SO3\n\n\ndef project_rotmats_via_svd(\n    rotmats: Float[Tensor, \"*batch 3 3\"],\n) -> Float[Tensor, \"*batch 3 3\"]:\n    u, s, vh = torch.linalg.svd(rotmats)\n    del s\n    return torch.einsum(\"...ij,...jk->...ik\", u, vh)\n\n\nclass EgoDenoiseTraj(TensorDataclass):\n    \"\"\"Data structure for denoising. Contains tensors that we are denoising, as\n    well as utilities for packing + unpacking them.\"\"\"\n\n    betas: Float[Tensor, \"*#batch timesteps 16\"]\n    \"\"\"Body shape parameters. We don't really need the timesteps axis here,\n    it's just for convenience.\"\"\"\n\n    body_rotmats: Float[Tensor, \"*#batch timesteps 21 3 3\"]\n    \"\"\"Local orientations for each body joint.\"\"\"\n\n    contacts: Float[Tensor, \"*#batch timesteps 21\"]\n    \"\"\"Contact boolean for each joint.\"\"\"\n\n    hand_rotmats: Float[Tensor, \"*#batch timesteps 30 3 3\"] | None\n    \"\"\"Local orientations for each body joint.\"\"\"\n\n    @staticmethod\n    def get_packed_dim(include_hands: bool) -> int:\n        packed_dim = 16 + 21 * 9 + 21\n        if include_hands:\n            packed_dim += 30 * 9\n        return packed_dim\n\n    def apply_to_body(self, body_model: SmplhModel) -> SmplhShapedAndPosed:\n        device = self.betas.device\n        dtype = self.betas.dtype\n        assert self.hand_rotmats is not None\n        shaped = body_model.with_shape(self.betas)\n        posed = shaped.with_pose(\n            T_world_root=SE3.identity(device=device, dtype=dtype).parameters(),\n            local_quats=SO3.from_matrix(\n                torch.cat([self.body_rotmats, self.hand_rotmats], dim=-3)\n            ).wxyz,\n        )\n        return posed\n\n    def pack(self) -> Float[Tensor, \"*#batch timesteps d_state\"]:\n        \"\"\"Pack trajectory into a single flattened vector.\"\"\"\n        (*batch, time, num_joints, _, _) = self.body_rotmats.shape\n        assert num_joints == 21\n        return torch.cat(\n            [\n                x.reshape((*batch, time, -1))\n                for x in vars(self).values()\n                if x is not None\n            ],\n            dim=-1,\n        )\n\n    @classmethod\n    def unpack(\n        cls,\n        x: Float[Tensor, \"*#batch timesteps d_state\"],\n        include_hands: bool,\n        project_rotmats: bool = False,\n    ) -> EgoDenoiseTraj:\n        \"\"\"Unpack trajectory from a single flattened vector.\n\n        Args:\n            x: Packed trajectory.\n            project_rotmats: If True, project the rotation matrices to SO(3) via SVD.\n        \"\"\"\n        (*batch, time, d_state) = x.shape\n        assert d_state == cls.get_packed_dim(include_hands)\n\n        if include_hands:\n            betas, body_rotmats_flat, contacts, hand_rotmats_flat = torch.split(\n                x, [16, 21 * 9, 21, 30 * 9], dim=-1\n            )\n            body_rotmats = body_rotmats_flat.reshape((*batch, time, 21, 3, 3))\n            hand_rotmats = hand_rotmats_flat.reshape((*batch, time, 30, 3, 3))\n            assert 
betas.shape == (*batch, time, 16)\n        else:\n            betas, body_rotmats_flat, contacts = torch.split(\n                x, [16, 21 * 9, 21], dim=-1\n            )\n            body_rotmats = body_rotmats_flat.reshape((*batch, time, 21, 3, 3))\n            hand_rotmats = None\n            assert betas.shape == (*batch, time, 16)\n\n        if project_rotmats:\n            # We might want to handle the -1 determinant case as well.\n            body_rotmats = project_rotmats_via_svd(body_rotmats)\n\n        return EgoDenoiseTraj(\n            betas=betas,\n            body_rotmats=body_rotmats,\n            contacts=contacts,\n            hand_rotmats=hand_rotmats,\n        )\n\n\n@dataclass(frozen=True)\nclass EgoDenoiserConfig:\n    max_t: int = 1000\n    fourier_enc_freqs: int = 3\n    d_latent: int = 512\n    d_feedforward: int = 2048\n    d_noise_emb: int = 1024\n    num_heads: int = 4\n    encoder_layers: int = 6\n    decoder_layers: int = 6\n    dropout_p: float = 0.0\n    activation: Literal[\"gelu\", \"relu\"] = \"gelu\"\n\n    positional_encoding: Literal[\"transformer\", \"rope\"] = \"rope\"\n    noise_conditioning: Literal[\"token\", \"film\"] = \"token\"\n\n    xattn_mode: Literal[\"kv_from_cond_q_from_x\", \"kv_from_x_q_from_cond\"] = (\n        \"kv_from_cond_q_from_x\"\n    )\n\n    include_canonicalized_cpf_rotation_in_cond: bool = True\n    include_hands: bool = True\n    \"\"\"Whether to include hand joints (+15 per hand) in the denoised state.\"\"\"\n\n    cond_param: Literal[\n        \"ours\", \"canonicalized\", \"absolute\", \"absrel\", \"absrel_global_deltas\"\n    ] = \"ours\"\n    \"\"\"Which conditioning parameterization to use.\n\n    \"ours\" is the default, we try to be clever and design something with nice\n        equivariance properties.\n    \"canonicalized\" contains a transformation that's canonicalized to aligned\n        to the first frame.\n    \"absolute\" is the naive case, where we just pass in transformations\n        directly.\n    \"\"\"\n\n    include_hand_positions_cond: bool = False\n    \"\"\"Whether to include hand positions in the conditioning information.\"\"\"\n\n    @cached_property\n    def d_cond(self) -> int:\n        \"\"\"Dimensionality of conditioning vector.\"\"\"\n\n        if self.cond_param == \"ours\":\n            d_cond = 0\n            d_cond += 12  # Relative CPF pose, flattened 3x4 matrix.\n            d_cond += 1  # Floor height.\n            if self.include_canonicalized_cpf_rotation_in_cond:\n                d_cond += 9  # Canonicalized CPF rotation, flattened 3x3 matrix.\n        elif self.cond_param == \"canonicalized\":\n            d_cond = 12\n        elif self.cond_param == \"absolute\":\n            d_cond = 12\n        elif self.cond_param == \"absrel\":\n            # Both absolute and relative!\n            d_cond = 24\n        elif self.cond_param == \"absrel_global_deltas\":\n            # Both absolute and relative!\n            d_cond = 24\n        else:\n            assert_never(self.cond_param)\n\n        # Add two 3D positions to the conditioning dimension if we're including\n        # hand conditioning.\n        if self.include_hand_positions_cond:\n            d_cond = d_cond + 6\n\n        d_cond = d_cond + d_cond * self.fourier_enc_freqs * 2  # Fourier encoding.\n        return d_cond\n\n    def make_cond(\n        self,\n        T_cpf_tm1_cpf_t: Float[Tensor, \"batch time 7\"],\n        T_world_cpf: Float[Tensor, \"batch time 7\"],\n        hand_positions_wrt_cpf: Float[Tensor, \"batch time 
6\"] | None,\n    ) -> Float[Tensor, \"batch time d_cond\"]:\n        \"\"\"Construct conditioning information from CPF pose.\"\"\"\n\n        (batch, time, _) = T_cpf_tm1_cpf_t.shape\n\n        # Construct device pose conditioning.\n        if self.cond_param == \"ours\":\n            # Compute conditioning terms. +Z is up in the world frame. We want\n            # the translation to be invariant to translations in the world X/Y\n            # directions.\n            height_from_floor = T_world_cpf[..., 6:7]\n\n            cond_parts = [\n                SE3(T_cpf_tm1_cpf_t).as_matrix()[..., :3, :].reshape((batch, time, 12)),\n                height_from_floor,\n            ]\n            if self.include_canonicalized_cpf_rotation_in_cond:\n                # We want the rotation to be invariant to rotations around the\n                # world Z axis. Visualization of what's happening here:\n                #\n                # https://gist.github.com/brentyi/9226d082d2707132af39dea92b8609f6\n                #\n                # (The coordinate frame may differ by some axis-swapping\n                # compared to the exact equations in the paper. But to the\n                # network these will all look the same.)\n                R_world_cpf = SE3(T_world_cpf).rotation().wxyz\n                forward_cpf = R_world_cpf.new_tensor([0.0, 0.0, 1.0])\n                forward_world = SO3(R_world_cpf) @ forward_cpf\n                assert forward_world.shape == (batch, time, 3)\n                R_canonical_world = SO3.from_z_radians(\n                    -torch.arctan2(forward_world[..., 1], forward_world[..., 0])\n                ).wxyz\n                assert R_canonical_world.shape == (batch, time, 4)\n                cond_parts.append(\n                    (SO3(R_canonical_world) @ SO3(R_world_cpf))\n                    .as_matrix()\n                    .reshape((batch, time, 9)),\n                )\n            cond = torch.cat(cond_parts, dim=-1)\n        elif self.cond_param == \"canonicalized\":\n            # Align the first timestep.\n            # Put poses so start is at origin, facing forward.\n            R_world_cpf = SE3(T_world_cpf[:, 0:1, :]).rotation().wxyz\n            forward_cpf = R_world_cpf.new_tensor([0.0, 0.0, 1.0])\n            forward_world = SO3(R_world_cpf) @ forward_cpf\n            assert forward_world.shape == (batch, 1, 3)\n            R_canonical_world = SO3.from_z_radians(\n                -torch.arctan2(forward_world[..., 1], forward_world[..., 0])\n            ).wxyz\n            assert R_canonical_world.shape == (batch, 1, 4)\n\n            R_canonical_cpf = SO3(R_canonical_world) @ SE3(T_world_cpf).rotation()\n            t_canonical_cpf = SO3(R_canonical_world) @ SE3(T_world_cpf).translation()\n            t_canonical_cpf = t_canonical_cpf - t_canonical_cpf[:, 0:1, :]\n\n            cond = (\n                SE3.from_rotation_and_translation(R_canonical_cpf, t_canonical_cpf)\n                .as_matrix()[..., :3, :4]\n                .reshape((batch, time, 12))\n            )\n        elif self.cond_param == \"absolute\":\n            cond = SE3(T_world_cpf).as_matrix()[..., :3, :4].reshape((batch, time, 12))\n        elif self.cond_param == \"absrel\":\n            cond = torch.concatenate(\n                [\n                    SE3(T_world_cpf)\n                    .as_matrix()[..., :3, :4]\n                    .reshape((batch, time, 12)),\n                    SE3(T_cpf_tm1_cpf_t)\n                    .as_matrix()[..., :3, :4]\n                    
.reshape((batch, time, 12)),\n                ],\n                dim=-1,\n            )\n        elif self.cond_param == \"absrel_global_deltas\":\n            cond = torch.concatenate(\n                [\n                    SE3(T_world_cpf)\n                    .as_matrix()[..., :3, :4]\n                    .reshape((batch, time, 12)),\n                    SE3(T_cpf_tm1_cpf_t)\n                    .rotation()\n                    .as_matrix()\n                    .reshape((batch, time, 9)),\n                    (\n                        SE3(T_world_cpf).rotation()\n                        @ SE3(T_cpf_tm1_cpf_t).inverse().translation()\n                    ).reshape((batch, time, 3)),\n                ],\n                dim=-1,\n            )\n        else:\n            assert_never(self.cond_param)\n\n        # Condition on hand poses as well.\n        # We didn't use this for the paper.\n        if self.include_hand_positions_cond:\n            if hand_positions_wrt_cpf is None:\n                logger.warning(\n                    \"Model is looking for hand conditioning but none was provided. Passing in zeros.\"\n                )\n                hand_positions_wrt_cpf = torch.zeros(\n                    (batch, time, 6), device=T_world_cpf.device\n                )\n            assert hand_positions_wrt_cpf.shape == (batch, time, 6)\n            cond = torch.cat([cond, hand_positions_wrt_cpf], dim=-1)\n\n        cond = fourier_encode(cond, freqs=self.fourier_enc_freqs)\n        assert cond.shape == (batch, time, self.d_cond)\n        return cond\n\n\nclass EgoDenoiser(nn.Module):\n    \"\"\"Denoising network for human motion.\n\n    Inputs are noisy trajectory, conditioning information, and timestep.\n    Output is denoised trajectory.\n    \"\"\"\n\n    def __init__(self, config: EgoDenoiserConfig):\n        super().__init__()\n\n        self.config = config\n        Activation = {\"gelu\": nn.GELU, \"relu\": nn.ReLU}[config.activation]\n\n        # MLP encoders and decoders for each modality we want to denoise.\n        modality_dims: dict[str, int] = {\n            \"betas\": 16,\n            \"body_rotmats\": 21 * 9,\n            \"contacts\": 21,\n        }\n        if config.include_hands:\n            modality_dims[\"hand_rotmats\"] = 30 * 9\n\n        assert sum(modality_dims.values()) == self.get_d_state()\n        self.encoders = nn.ModuleDict(\n            {\n                k: nn.Sequential(\n                    nn.Linear(modality_dim, config.d_latent),\n                    Activation(),\n                    nn.Linear(config.d_latent, config.d_latent),\n                    Activation(),\n                    nn.Linear(config.d_latent, config.d_latent),\n                )\n                for k, modality_dim in modality_dims.items()\n            }\n        )\n        self.decoders = nn.ModuleDict(\n            {\n                k: nn.Sequential(\n                    nn.Linear(config.d_latent, config.d_latent),\n                    nn.LayerNorm(normalized_shape=config.d_latent),\n                    Activation(),\n                    nn.Linear(config.d_latent, config.d_latent),\n                    Activation(),\n                    nn.Linear(config.d_latent, modality_dim),\n                )\n                for k, modality_dim in modality_dims.items()\n            }\n        )\n\n        # Helpers for converting between input dimensionality and latent dimensionality.\n        self.latent_from_cond = nn.Linear(config.d_cond, config.d_latent)\n\n        # Noise 
embedder.\n        self.noise_emb = nn.Embedding(\n            # index 0 will be t=1\n            # index 999 will be t=1000\n            num_embeddings=config.max_t,\n            embedding_dim=config.d_noise_emb,\n        )\n        self.noise_emb_token_proj = (\n            nn.Linear(config.d_noise_emb, config.d_latent, bias=False)\n            if config.noise_conditioning == \"token\"\n            else None\n        )\n\n        # Encoder / decoder layers.\n        # Inputs are conditioning (current noise level, observations); output\n        # is encoded conditioning information.\n        self.encoder_layers = nn.ModuleList(\n            [\n                TransformerBlock(\n                    TransformerBlockConfig(\n                        d_latent=config.d_latent,\n                        d_noise_emb=config.d_noise_emb,\n                        d_feedforward=config.d_feedforward,\n                        n_heads=config.num_heads,\n                        dropout_p=config.dropout_p,\n                        activation=config.activation,\n                        include_xattn=False,  # No conditioning for encoder.\n                        use_rope_embedding=config.positional_encoding == \"rope\",\n                        use_film_noise_conditioning=config.noise_conditioning == \"film\",\n                        xattn_mode=config.xattn_mode,\n                    )\n                )\n                for _ in range(config.encoder_layers)\n            ]\n        )\n        self.decoder_layers = nn.ModuleList(\n            [\n                TransformerBlock(\n                    TransformerBlockConfig(\n                        d_latent=config.d_latent,\n                        d_noise_emb=config.d_noise_emb,\n                        d_feedforward=config.d_feedforward,\n                        n_heads=config.num_heads,\n                        dropout_p=config.dropout_p,\n                        activation=config.activation,\n                        include_xattn=True,  # Include conditioning for the decoder.\n                        use_rope_embedding=config.positional_encoding == \"rope\",\n                        use_film_noise_conditioning=config.noise_conditioning == \"film\",\n                        xattn_mode=config.xattn_mode,\n                    )\n                )\n                for _ in range(config.decoder_layers)\n            ]\n        )\n\n    def get_d_state(self) -> int:\n        return EgoDenoiseTraj.get_packed_dim(self.config.include_hands)\n\n    def forward(\n        self,\n        x_t_packed: Float[Tensor, \"batch time state_dim\"],\n        t: Float[Tensor, \"batch\"],\n        *,\n        T_world_cpf: Float[Tensor, \"batch time 7\"],\n        T_cpf_tm1_cpf_t: Float[Tensor, \"batch time 7\"],\n        project_output_rotmats: bool,\n        # Observed hand positions, relative to the CPF.\n        hand_positions_wrt_cpf: Float[Tensor, \"batch time 6\"] | None,\n        # Attention mask for using shorter sequences.\n        mask: Bool[Tensor, \"batch time\"] | None,\n        # Mask for when to drop out / keep conditioning information.\n        cond_dropout_keep_mask: Bool[Tensor, \"batch\"] | None = None,\n    ) -> Float[Tensor, \"batch time state_dim\"]:\n        \"\"\"Predict a denoised trajectory. 
Note that `t` refers to a noise\n        level, not a timestep.\"\"\"\n        config = self.config\n\n        x_t = EgoDenoiseTraj.unpack(x_t_packed, include_hands=self.config.include_hands)\n        (batch, time, num_body_joints, _, _) = x_t.body_rotmats.shape\n        assert num_body_joints == 21\n\n        # Encode the trajectory into a single vector per timestep.\n        x_t_encoded = (\n            self.encoders[\"betas\"](x_t.betas.reshape((batch, time, -1)))\n            + self.encoders[\"body_rotmats\"](x_t.body_rotmats.reshape((batch, time, -1)))\n            + self.encoders[\"contacts\"](x_t.contacts)\n        )\n        if self.config.include_hands:\n            assert x_t.hand_rotmats is not None\n            x_t_encoded = x_t_encoded + self.encoders[\"hand_rotmats\"](\n                x_t.hand_rotmats.reshape((batch, time, -1))\n            )\n        assert x_t_encoded.shape == (batch, time, config.d_latent)\n\n        # Embed the diffusion noise level.\n        assert t.shape == (batch,)\n        noise_emb = self.noise_emb(t - 1)\n        assert noise_emb.shape == (batch, config.d_noise_emb)\n\n        # Prepare conditioning information.\n        cond = config.make_cond(\n            T_cpf_tm1_cpf_t,\n            T_world_cpf=T_world_cpf,\n            hand_positions_wrt_cpf=hand_positions_wrt_cpf,\n        )\n\n        # Randomly drop out conditioning information; this serves as a\n        # regularizer that aims to improve sample diversity.\n        if cond_dropout_keep_mask is not None:\n            assert cond_dropout_keep_mask.shape == (batch,)\n            cond = cond * cond_dropout_keep_mask[:, None, None]\n\n        # Prepare encoder and decoder inputs.\n        if config.positional_encoding == \"rope\":\n            pos_enc = 0\n        elif config.positional_encoding == \"transformer\":\n            pos_enc = make_positional_encoding(\n                d_latent=config.d_latent,\n                length=time,\n                dtype=cond.dtype,\n            )[None, ...].to(x_t_encoded.device)\n            assert pos_enc.shape == (1, time, config.d_latent)\n        else:\n            assert_never(config.positional_encoding)\n\n        encoder_out = self.latent_from_cond(cond) + pos_enc\n        decoder_out = x_t_encoded + pos_enc\n\n        # Append the noise embedding to the encoder and decoder inputs.\n        # This is weird if we're using rotary embeddings!\n        if self.noise_emb_token_proj is not None:\n            noise_emb_token = self.noise_emb_token_proj(noise_emb)\n            assert noise_emb_token.shape == (batch, config.d_latent)\n            encoder_out = torch.cat([noise_emb_token[:, None, :], encoder_out], dim=1)\n            decoder_out = torch.cat([noise_emb_token[:, None, :], decoder_out], dim=1)\n            assert (\n                encoder_out.shape\n                == decoder_out.shape\n                == (batch, time + 1, config.d_latent)\n            )\n            num_tokens = time + 1\n        else:\n            num_tokens = time\n\n        # Compute the attention mask; it should be boolean, with shape\n        # (batch, 1, tokens, tokens).\n        if mask is None:\n            attn_mask = None\n        else:\n            assert mask.shape == (batch, time)\n            assert mask.dtype == torch.bool\n            if self.noise_emb_token_proj is not None:  # Account for noise token.\n                mask = torch.cat([mask.new_ones((batch, 1)), mask], dim=1)\n            # Last two dimensions of mask are (query, key). 
We're masking out only keys;\n            # it's annoying for the softmax to mask out entire rows without getting NaNs.\n            attn_mask = mask[:, None, None, :].repeat(1, 1, num_tokens, 1)\n            assert attn_mask.shape == (batch, 1, num_tokens, num_tokens)\n            assert attn_mask.dtype == torch.bool\n\n        # Forward pass through transformer.\n        for layer in self.encoder_layers:\n            encoder_out = layer(encoder_out, attn_mask, noise_emb=noise_emb)\n        for layer in self.decoder_layers:\n            decoder_out = layer(\n                decoder_out, attn_mask, noise_emb=noise_emb, cond=encoder_out\n            )\n\n        # Remove the extra token corresponding to the noise embedding.\n        if self.noise_emb_token_proj is not None:\n            decoder_out = decoder_out[:, 1:, :]\n        assert isinstance(decoder_out, Tensor)\n        assert decoder_out.shape == (batch, time, config.d_latent)\n\n        packed_output = torch.cat(\n            [\n                # Project rotation matrices for body_rotmats via SVD,\n                (\n                    project_rotmats_via_svd(\n                        modality_decoder(decoder_out).reshape((-1, 3, 3))\n                    ).reshape(\n                        (batch, time, {\"body_rotmats\": 21, \"hand_rotmats\": 30}[key] * 9)\n                    )\n                    # if enabled,\n                    if project_output_rotmats\n                    and key in (\"body_rotmats\", \"hand_rotmats\")\n                    # otherwise, just decode normally.\n                    else modality_decoder(decoder_out)\n                )\n                for key, modality_decoder in self.decoders.items()\n            ],\n            dim=-1,\n        )\n        assert packed_output.shape == (batch, time, self.get_d_state())\n\n        # Return packed output.\n        return packed_output\n\n\n@cache\ndef make_positional_encoding(\n    d_latent: int, length: int, dtype: torch.dtype\n) -> Float[Tensor, \"length d_latent\"]:\n    \"\"\"Computes standard Transformer positional encoding.\"\"\"\n    pe = torch.zeros(length, d_latent, dtype=dtype)\n    position = torch.arange(0, length, dtype=torch.float).unsqueeze(1)\n    div_term = torch.exp(\n        torch.arange(0, d_latent, 2).float() * (-np.log(10000.0) / d_latent)\n    )\n    pe[:, 0::2] = torch.sin(position * div_term)\n    pe[:, 1::2] = torch.cos(position * div_term)\n    assert pe.shape == (length, d_latent)\n    return pe\n\n\ndef fourier_encode(\n    x: Float[Tensor, \"*#batch channels\"], freqs: int\n) -> Float[Tensor, \"*#batch channels+2*freqs*channels\"]:\n    \"\"\"Apply Fourier encoding to a tensor.\"\"\"\n    *batch_axes, x_dim = x.shape\n    coeffs = 2.0 ** torch.arange(freqs, device=x.device)\n    scaled = (x[..., None] * coeffs).reshape((*batch_axes, x_dim * freqs))\n    return torch.cat(\n        [\n            x,\n            torch.sin(torch.cat([scaled, scaled + torch.pi / 2.0], dim=-1)),\n        ],\n        dim=-1,\n    )\n\n\n@dataclass(frozen=True)\nclass TransformerBlockConfig:\n    d_latent: int\n    d_noise_emb: int\n    d_feedforward: int\n    n_heads: int\n    dropout_p: float\n    activation: Literal[\"gelu\", \"relu\"]\n    include_xattn: bool\n    use_rope_embedding: bool\n    use_film_noise_conditioning: bool\n    xattn_mode: Literal[\"kv_from_cond_q_from_x\", \"kv_from_x_q_from_cond\"]\n\n\nclass TransformerBlock(nn.Module):\n    \"\"\"An even-tempered Transformer block.\"\"\"\n\n    def __init__(self, config: 
TransformerBlockConfig) -> None:\n        super().__init__()\n        self.sattn_qkv_proj = nn.Linear(\n            config.d_latent, config.d_latent * 3, bias=False\n        )\n        self.sattn_out_proj = nn.Linear(config.d_latent, config.d_latent, bias=False)\n\n        self.layernorm1 = nn.LayerNorm(config.d_latent)\n        self.layernorm2 = nn.LayerNorm(config.d_latent)\n\n        assert config.d_latent % config.n_heads == 0\n        self.rotary_emb = (\n            RotaryEmbedding(config.d_latent // config.n_heads)\n            if config.use_rope_embedding\n            else None\n        )\n\n        if config.include_xattn:\n            self.xattn_kv_proj = nn.Linear(\n                config.d_latent, config.d_latent * 2, bias=False\n            )\n            self.xattn_q_proj = nn.Linear(config.d_latent, config.d_latent, bias=False)\n            self.xattn_layernorm = nn.LayerNorm(config.d_latent)\n            self.xattn_out_proj = nn.Linear(\n                config.d_latent, config.d_latent, bias=False\n            )\n\n        self.norm_no_learnable = nn.LayerNorm(\n            config.d_feedforward, elementwise_affine=False, bias=False\n        )\n        self.activation = {\"gelu\": nn.GELU, \"relu\": nn.ReLU}[config.activation]()\n        self.dropout = nn.Dropout(config.dropout_p)\n\n        self.mlp0 = nn.Linear(config.d_latent, config.d_feedforward)\n        self.mlp_film_cond_proj = (\n            zero_module(\n                nn.Linear(config.d_noise_emb, config.d_feedforward * 2, bias=False)\n            )\n            if config.use_film_noise_conditioning\n            else None\n        )\n        self.mlp1 = nn.Linear(config.d_feedforward, config.d_latent)\n        self.config = config\n\n    def forward(\n        self,\n        x: Float[Tensor, \"batch tokens d_latent\"],\n        attn_mask: Bool[Tensor, \"batch 1 tokens tokens\"] | None,\n        noise_emb: Float[Tensor, \"batch d_noise_emb\"],\n        cond: Float[Tensor, \"batch tokens d_latent\"] | None = None,\n    ) -> Float[Tensor, \"batch tokens d_latent\"]:\n        config = self.config\n        (batch, time, d_latent) = x.shape\n\n        # Self-attention.\n        # We put layer normalization after the residual connection.\n        x = self.layernorm1(x + self._sattn(x, attn_mask))\n\n        # Include conditioning.\n        if config.include_xattn:\n            assert cond is not None\n            x = self.xattn_layernorm(x + self._xattn(x, attn_mask, cond=cond))\n\n        mlp_out = x\n        mlp_out = self.mlp0(mlp_out)\n        mlp_out = self.activation(mlp_out)\n\n        # FiLM-style conditioning.\n        if self.mlp_film_cond_proj is not None:\n            scale, shift = torch.chunk(\n                self.mlp_film_cond_proj(noise_emb), chunks=2, dim=-1\n            )\n            assert scale.shape == shift.shape == (batch, config.d_feedforward)\n            mlp_out = (\n                self.norm_no_learnable(mlp_out) * (1.0 + scale[:, None, :])\n                + shift[:, None, :]\n            )\n\n        mlp_out = self.dropout(mlp_out)\n        mlp_out = self.mlp1(mlp_out)\n\n        x = self.layernorm2(x + mlp_out)\n        assert x.shape == (batch, time, d_latent)\n        return x\n\n    def _sattn(self, x: Tensor, attn_mask: Tensor | None) -> Tensor:\n        \"\"\"Multi-head self-attention.\"\"\"\n        config = self.config\n        q, k, v = rearrange(\n            self.sattn_qkv_proj(x),\n            \"b t (qkv nh dh) -> qkv b nh t dh\",\n            qkv=3,\n            
nh=config.n_heads,\n        )\n        if self.rotary_emb is not None:\n            q = self.rotary_emb.rotate_queries_or_keys(q, seq_dim=-2)\n            k = self.rotary_emb.rotate_queries_or_keys(k, seq_dim=-2)\n        x = torch.nn.functional.scaled_dot_product_attention(\n            q, k, v, dropout_p=config.dropout_p, attn_mask=attn_mask\n        )\n        x = self.dropout(x)\n        x = rearrange(x, \"b nh t dh -> b t (nh dh)\", nh=config.n_heads)\n        x = self.sattn_out_proj(x)\n        return x\n\n    def _xattn(self, x: Tensor, attn_mask: Tensor | None, cond: Tensor) -> Tensor:\n        \"\"\"Multi-head cross-attention.\"\"\"\n        config = self.config\n        k, v = rearrange(\n            self.xattn_kv_proj(\n                {\n                    \"kv_from_cond_q_from_x\": cond,\n                    \"kv_from_x_q_from_cond\": x,\n                }[self.config.xattn_mode]\n            ),\n            \"b t (qk nh dh) -> qk b nh t dh\",\n            qk=2,\n            nh=config.n_heads,\n        )\n        q = rearrange(\n            self.xattn_q_proj(\n                {\n                    \"kv_from_cond_q_from_x\": x,\n                    \"kv_from_x_q_from_cond\": cond,\n                }[self.config.xattn_mode]\n            ),\n            \"b t (nh dh) -> b nh t dh\",\n            nh=config.n_heads,\n        )\n        if self.rotary_emb is not None:\n            q = self.rotary_emb.rotate_queries_or_keys(q, seq_dim=-2)\n            k = self.rotary_emb.rotate_queries_or_keys(k, seq_dim=-2)\n        x = torch.nn.functional.scaled_dot_product_attention(\n            q, k, v, dropout_p=config.dropout_p, attn_mask=attn_mask\n        )\n        x = rearrange(x, \"b nh t dh -> b t (nh dh)\")\n        x = self.xattn_out_proj(x)\n\n        return x\n\n\ndef zero_module(module):\n    \"\"\"Zero out the parameters of a module and return it.\"\"\"\n    for p in module.parameters():\n        p.detach().zero_()\n    return module\n"
  },
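# --- Hypothetical shape-check sketch, not part of the repository ---
# Exercises EgoDenoiser / EgoDenoiseTraj from src/egoallo/network.py end to
# end, assuming the package is importable as `egoallo`. Identity CPF poses
# (unit wxyz quaternion + zero translation) keep the conditioning valid; the
# batch and sequence sizes below are arbitrary.
import torch

from egoallo.network import EgoDenoiser, EgoDenoiserConfig, EgoDenoiseTraj

config = EgoDenoiserConfig()
model = EgoDenoiser(config)

batch, time = 2, 16
d_state = EgoDenoiseTraj.get_packed_dim(config.include_hands)

# Identity poses as (qw, qx, qy, qz, x, y, z).
T_identity = torch.zeros(batch, time, 7)
T_identity[..., 0] = 1.0

x_t = torch.randn(batch, time, d_state)  # noisy packed trajectory
t = torch.randint(low=1, high=config.max_t + 1, size=(batch,))  # noise levels

with torch.no_grad():
    denoised = model(
        x_t,
        t,
        T_world_cpf=T_identity,
        T_cpf_tm1_cpf_t=T_identity,
        project_output_rotmats=True,
        hand_positions_wrt_cpf=None,
        mask=None,
    )
print(denoised.shape)  # expect (batch, time, d_state)

# Round trip through the packed representation.
traj = EgoDenoiseTraj.unpack(denoised, include_hands=config.include_hands)
assert traj.pack().shape == (batch, time, d_state)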
  {
    "path": "src/egoallo/preprocessing/__init__.py",
    "content": ""
  },
  {
    "path": "src/egoallo/preprocessing/body_model/__init__.py",
    "content": "from .body_model import BodyModel\nfrom .skeleton import *\nfrom .specs import *\nfrom .utils import *\n"
  },
  {
    "path": "src/egoallo/preprocessing/body_model/body_model.py",
    "content": "from loguru import logger as guru\nimport os\nfrom einops import rearrange\nimport torch\nimport torch.nn as nn\nfrom typing import Tuple, Dict\n\nfrom smplx import SMPLLayer, SMPLHLayer\nfrom smplx.vertex_ids import vertex_ids\nfrom smplx.utils import Struct\nfrom smplx.lbs import lbs as smpl_lbs\n\nfrom ..geometry import convert_rotation\nfrom ..util.tensor import pad_dim\n\nfrom .specs import SMPL_JOINTS\nfrom .utils import (\n    forward_kinematics,\n    inverse_kinematics,\n    select_vert_params,\n    get_verts_with_transforms,\n)\n\n\nclass BodyModel(nn.Module):\n    \"\"\"\n    Wrapper around SMPLX body model class.\n    \"\"\"\n\n    def __init__(\n        self,\n        bm_path,\n        model_type: str = \"smplh\",\n        use_pca: bool = True,\n        num_pca_comps: int = 6,\n        batch_size: int = 1,\n        use_vtx_selector: bool = True,\n        **kwargs,\n    ):\n        \"\"\"\n        Creates the body model object at the given path.\n        :param bm_path: path to the body model file\n        :param model_type: one of [smpl, smplh]\n        :param use_vtx_selector:\n            if true, returns additional vertices as joints that correspond to OpenPose joints\n        \"\"\"\n        super().__init__()\n        assert model_type in [\"smpl\", \"smplh\"]\n        self.model_type = model_type\n\n        self.use_pca = use_pca\n        self.num_pca_comps = num_pca_comps\n        self.use_vtx_selector = use_vtx_selector\n        cur_vertex_ids = None\n        if self.use_vtx_selector:\n            cur_vertex_ids = vertex_ids[model_type]\n        kwargs[\"vertex_ids\"] = cur_vertex_ids\n\n        ext = os.path.splitext(bm_path)[-1][1:]\n        if model_type == \"smpl\":\n            cls = SMPLLayer\n            self.hand_dim = 0\n            self.num_betas = 10\n        else:\n            cls = SMPLHLayer\n            self.hand_dim = cls.NUM_HAND_JOINTS * 3 if not use_pca else num_pca_comps\n            self.num_betas = 16\n\n        self.batch_size = batch_size\n        self.num_joints = cls.NUM_JOINTS + 1  # include root\n        self.num_body_joints = cls.NUM_BODY_JOINTS\n        # create body model without default parameters\n        self.bm = cls(\n            bm_path,\n            ext=ext,\n            num_betas=self.num_betas,\n            use_pca=use_pca,\n            num_pca_comps=num_pca_comps,\n            batch_size=batch_size,\n            **kwargs,\n        )\n        guru.info(f\"loading body model from {bm_path}, batch size {batch_size}\")\n\n        # make our own default buffers\n        self.var_dims = {\n            \"root_orient\": 3,\n            \"pose_body\": cls.NUM_BODY_JOINTS * 3,\n            \"pose_hand\": self.hand_dim * 2,\n            \"betas\": self.num_betas,\n            \"trans\": 3,\n        }\n        guru.info(f\"variable dims {self.var_dims}\")\n        for name, sh in self.var_dims.items():\n            self.register_buffer(name, torch.zeros(batch_size, sh))\n\n        # save the template joints\n        v_template = self.bm.v_template  # (V, 3)\n        J_regressor = self.bm.J_regressor  # (J, V)\n        joint_template = torch.matmul(J_regressor, v_template)[None]  # type: ignore\n\n        # save the extra joints from the vertex template\n\n        # (1, J, 3)\n        self.register_buffer(\"joint_template\", joint_template)\n        self.parents = self.bm.parents\n\n        shapedirs = self.bm.shapedirs  # (V, 3, B)\n        j_shapedirs = rearrange(\n            torch.einsum(\"jv,v...->j...\", J_regressor, 
shapedirs), \"a b c -> (a b) c\"\n        )\n        # (J * 3, B)\n        self.register_buffer(\"joint_shapedirs\", j_shapedirs)\n        # (B, J * 3)\n        # because overparameterized, use the fewest smpl joints\n        J = len(SMPL_JOINTS)\n        self.register_buffer(\n            \"joint_shapedirs_pinv\", torch.linalg.pinv(j_shapedirs[: J * 3])\n        )\n        self._recompute_inverse_beta_mat = False\n        # self.register_buffer(\"joint_shapedirs_pinv\", torch.linalg.pinv(j_shapedirs))\n        # self._recompute_inverse_beta_mat = True\n\n        for p in self.parameters():\n            p.requires_grad_(False)\n\n    def _fill_default_vars(self, model_args) -> Tuple[Dict, Dict]:\n        \"\"\"\n        fill in the missing variables with defaults padded to correct batch size\n        \"\"\"\n        B = self._get_batch_size(**model_args)\n\n        model_vars = {}\n        for name in self.var_dims:\n            var = model_args.pop(name, None)\n            if var is None:\n                var = self._get_default_model_var(name, B)\n            model_vars[name] = var\n        return model_vars, model_args\n\n    def _get_batch_size(self, **model_args) -> int:\n        \"\"\"\n        get the batch size of the input args\n        \"\"\"\n        B = self.batch_size\n        for name in self.var_dims:\n            if name in model_args and type(var := model_args[name]) == torch.Tensor:\n                B = var.shape[0]\n                break\n        return B\n\n    def _get_default_model_var(self, name: str, batch_size: int):\n        \"\"\"\n        if we have the desired variable, return it, otherwise return the default value\n        get model var with desired batch size\n        \"\"\"\n        return pad_dim(getattr(self, name), batch_size)\n\n    def get_full_pose_mats(self, model_args, add_mean: bool = True):\n        \"\"\"\n        get the full pose from provided model args\n        \"\"\"\n        B = self._get_batch_size(**model_args)\n        names = [\"root_orient\", \"pose_body\", \"pose_hand\"]\n        model_vars = {\n            k: model_args.get(k, self._get_default_model_var(k, B)) for k in names\n        }\n        root_mat = model_vars[\"root_orient\"]\n        if root_mat.ndim == 2:\n            root_mat = convert_rotation(\n                root_mat.unsqueeze(-2), \"aa\", \"mat\"\n            )  # (B, 1, 3, 3)\n        body_mat = model_vars[\"pose_body\"]\n        if body_mat.ndim == 2:\n            body_mat = convert_rotation(\n                body_mat.reshape(B, -1, 3), \"aa\", \"mat\"\n            )  # (B, J, 3, 3)\n        hand_mat = self.get_hand_pose_mat(\n            model_vars[\"pose_hand\"], add_mean=add_mean\n        )  # (B, H, 3, 3)\n        full_pose = torch.cat([root_mat, body_mat, hand_mat], dim=-3)\n        return full_pose\n\n    def get_hand_pose_mat(\n        self, pose_hand: torch.Tensor, add_mean: bool = True\n    ) -> torch.Tensor:\n        \"\"\"\n        get the hand joint rotations if applicable\n        :param pose_hand (*, D)\n        \"\"\"\n        if self.hand_dim == 0:\n            return pose_hand\n\n        B = pose_hand.shape[0]\n        if self.use_pca:\n            left_hand_pose = torch.einsum(\n                \"...i,ij->...j\",\n                pose_hand[..., : self.hand_dim],\n                self.bm.left_hand_components,\n            )\n            right_hand_pose = torch.einsum(\n                \"...i,ij->...j\",\n                pose_hand[..., self.hand_dim :],\n                
self.bm.right_hand_components,\n            )\n            pose_hand = torch.cat([left_hand_pose, right_hand_pose], dim=-1)\n            if add_mean:\n                J = self.num_body_joints + 1\n                hand_mean = self.bm.pose_mean[..., 3 * J :]  # type: ignore\n                pose_hand += hand_mean\n\n        if pose_hand.ndim == 2:\n            pose_hand = convert_rotation(pose_hand.reshape(B, -1, 3), \"aa\", \"mat\")\n\n        return pose_hand\n\n    def forward_joints(self, **kwargs):\n        \"\"\"\n        forward on joints only\n        returns (*, J, 3) joints\n        \"\"\"\n        model_vars, _ = self._fill_default_vars(kwargs)\n\n        rot_mats = self.get_full_pose_mats(model_vars)\n        B = rot_mats.shape[0]\n        shape_diffs = torch.einsum(\n            \"ij,nj->ni\", self.joint_shapedirs, model_vars[\"betas\"]\n        )\n        shape_diffs = shape_diffs.reshape(B, -1, 3)\n        joints_shaped = self.joint_template + shape_diffs\n\n        joints_local, rel_transforms = forward_kinematics(\n            rot_mats, joints_shaped, self.parents  # type: ignore\n        )\n        if self.use_vtx_selector:\n            extra_joints = self.get_extra_joints(\n                model_vars[\"betas\"], rot_mats, rel_transforms\n            )\n            joints_local = torch.cat([joints_local, extra_joints], dim=-2)\n\n        return joints_local + model_vars[\"trans\"].unsqueeze(-2)\n\n    def get_extra_joints(self, betas, pose_mats, rel_transforms):\n        vtx_idcs = self.bm.vertex_joint_selector.extra_joints_idxs\n        v_template, shapedirs, posedirs, lbs_weights = select_vert_params(\n            vtx_idcs,  # type: ignore\n            self.bm.v_template,  # type: ignore\n            self.bm.shapedirs,  # type: ignore\n            self.bm.posedirs,  # type: ignore\n            self.bm.lbs_weights,  # type: ignore\n        )\n        return get_verts_with_transforms(\n            betas,\n            pose_mats,\n            rel_transforms,\n            v_template,\n            shapedirs,\n            posedirs,\n            lbs_weights,\n        )\n\n    def inverse_joints(self, joints: torch.Tensor, **kwargs):\n        \"\"\"\n        get the unposed joints (template pose)\n        \"\"\"\n        model_vars, _ = self._fill_default_vars(kwargs)\n        rot_mats = self.get_full_pose_mats(model_vars)\n        joints_local = joints - model_vars[\"trans\"].unsqueeze(-2)\n        return inverse_kinematics(rot_mats, joints_local, self.parents)  # type: ignore\n\n    def joints_to_beta(self, joint_unposed: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        get the nearest beta such that\n        joint_unposed = joint_template + A @ beta\n        :param (*, J, 3) joints\n        \"\"\"\n        # get the residual with the template\n        J = len(SMPL_JOINTS)\n        if self._recompute_inverse_beta_mat:\n            self.joint_shapedirs_pinv = torch.linalg.pinv(self.joint_shapedirs[: J * 3])  # type: ignore\n            self._recompute_inverse_beta_mat = False\n        dims = joint_unposed.shape[:-2]\n        joint_unposed = joint_unposed[..., :J, :]\n        joint_template = self.joint_template[..., :J, :]  # type: ignore\n        joint_delta = (joint_unposed - joint_template).reshape(*dims, J * 3)\n        betas = torch.einsum(\n            \"ij,...j->...i\", self.joint_shapedirs_pinv, joint_delta\n        )  # (*, B)\n        return betas\n\n    def forward(self, **kwargs):\n        \"\"\"\n        forward pass of smpl model\n        expects kwargs in 
[root_orient, pose_body, pose_hand, betas, trans]\n        to have same leading dimension if included, otherwise will pad itself\n        \"\"\"\n        model_vars, kwargs = self._fill_default_vars(kwargs)\n\n        rot_mats = self.get_full_pose_mats(model_vars)\n        verts, joints = smpl_lbs(\n            model_vars[\"betas\"],\n            rot_mats,  # type: ignore\n            self.bm.v_template,  # type: ignore\n            self.bm.shapedirs,  # type: ignore\n            self.bm.posedirs,  # type: ignore\n            self.bm.J_regressor,  # type: ignore\n            self.bm.parents,  # type: ignore\n            self.bm.lbs_weights,  # type: ignore\n            pose2rot=False,\n        )\n        joints = self.bm.vertex_joint_selector(verts, joints)\n        trans = model_vars[\"trans\"].unsqueeze(-2)\n        joints += trans\n        verts += trans\n\n        out = {\n            \"v\": verts,\n            \"f\": self.bm.faces_tensor,\n            \"Jtr\": joints,\n            \"full_pose\": rot_mats,\n        }\n\n        if not self.use_vtx_selector:  # don't need extra joints\n            out[\"Jtr\"] = out[\"Jtr\"][:, : self.num_joints]\n\n        return Struct(**out)\n"
  },
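# --- Hypothetical usage sketch, not part of the repository ---
# Runs the BodyModel wrapper from
# src/egoallo/preprocessing/body_model/body_model.py once with zero poses,
# assuming the package is importable as `egoallo`. The SMPL-H checkpoint path
# below is an assumption; point it at wherever the (license-gated) body model
# file actually lives.
import torch

from egoallo.preprocessing.body_model import BodyModel

body_model = BodyModel(
    "./data/smplh/neutral/model.npz",  # assumed path, not shipped with the repo
    model_type="smplh",
    batch_size=4,
)

# Missing inputs (betas, pose_hand, ...) are filled with the zero-valued
# default buffers, padded to the leading batch dimension.
out = body_model(
    root_orient=torch.zeros(4, 3),  # axis-angle root orientation
    pose_body=torch.zeros(4, 21 * 3),  # axis-angle body pose
    trans=torch.zeros(4, 3),
)
print(out.v.shape, out.Jtr.shape)  # vertices (4, V, 3) and joints (4, J, 3)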
  {
    "path": "src/egoallo/preprocessing/body_model/skeleton.py",
    "content": "import numpy as np\nimport torch\n\nfrom .specs import SMPL_PARENTS, SMPL_JOINTS\n\n\n__all__ = [\n    \"NUM_KINEMATIC_CHAINS\",\n    \"smpl_kinematic_tree\",\n    \"joint_angles_rel_to_glob\",\n    \"joint_angles_glob_to_rel\",\n]\n\n\nNUM_KINEMATIC_CHAINS = 5\n\n\ndef smpl_kinematic_tree():\n    \"\"\"\n    get the SMPL kinematic tree as a list of chains of joint indices\n    \"\"\"\n    joint_idcs = list(range(len(SMPL_JOINTS)))\n    tree = []\n    chains = {}  # key: last vertex so far, value: chain(s))))\n    for joint in joint_idcs[::-1]:\n        parent = SMPL_PARENTS[joint]\n        if parent in chains or parent < 0:\n            continue\n        chains[parent] = [parent] + chains.pop(joint, [joint])\n    tree = []\n    for joint, chain in chains.items():\n        parent = SMPL_PARENTS[joint]\n        if parent >= 0:\n            chain = [parent] + chain\n        tree.insert(0, np.array(chain))\n    return tree\n\n\ndef joint_angles_rel_to_glob(rel_mats):\n    \"\"\"\n    convert joint angles\n    from relative (wrt to previous branch on kinematic chain)\n    to global (wrt to root of skeleton)\n    :param rotation matrices (*, 21, 3, 3)\n    return (*, 21, 3, 3)\n    \"\"\"\n    assert rel_mats.shape[-3] == len(SMPL_JOINTS) - 1\n    glob_mats = torch.zeros_like(rel_mats)\n\n    # aggregate transforms from parent to children\n    kin_tree = smpl_kinematic_tree()\n    for chain in kin_tree:\n        for pidx, cidx in zip(chain[:-1], chain[1:]):\n            # R_c0 = R_cp * R_p0\n            if pidx == 0:\n                glob_mats[..., cidx - 1, :, :] = rel_mats[..., cidx - 1, :, :]\n            else:\n                glob_mats[..., cidx - 1, :, :] = torch.matmul(\n                    rel_mats[..., cidx - 1, :, :], glob_mats[..., pidx - 1, :, :]\n                )\n    return glob_mats\n\n\ndef joint_angles_glob_to_rel(glob_mats):\n    \"\"\"\n    convert joint angles\n    from global (wrt to root of skeleton)\n    to relative (wrt to previous branch on kinematic chain)\n    :param rotation matrices (*, 21, 3, 3)\n    return (*, 21, 3, 3)\n    \"\"\"\n    rel_mats = torch.zeros_like(glob_mats)\n    assert glob_mats.shape[-3] == len(SMPL_JOINTS) - 1\n\n    # add the root matrix to global rotations\n    dims = glob_mats.shape[:-3]\n    I = (\n        torch.eye(3, device=glob_mats.device)\n        .reshape(*(1,) * len(dims), 1, 3, 3)\n        .expand(*dims, 1, 3, 3)\n    )\n    glob_mats = torch.cat([I, glob_mats], dim=-3)\n\n    # invert transforms from parent to children\n    kin_tree = smpl_kinematic_tree()\n    for chain in kin_tree:\n        pidx, cidx = chain[:-1], chain[1:]\n        # R_cp = R_c0 * R_0p\n        rel_mats[..., cidx - 1, :, :] = torch.matmul(\n            glob_mats[..., cidx, :, :], glob_mats[..., pidx, :, :].transpose(-1, -2)\n        )\n    return rel_mats\n"
  },
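# --- Hypothetical round-trip check, not part of the repository ---
# Sanity-checks the skeleton utilities in
# src/egoallo/preprocessing/body_model/skeleton.py, assuming the package is
# importable as `egoallo`: converting the 21 non-root joint rotations from
# parent-relative to global and back should be the identity map.
import torch

from egoallo.preprocessing.body_model.skeleton import (
    joint_angles_glob_to_rel,
    joint_angles_rel_to_glob,
)

torch.manual_seed(0)

# Random proper rotation matrices: orthonormalize Gaussian matrices with QR,
# then flip the sign of any matrix whose determinant is -1.
A = torch.randn(21, 3, 3)
Q, _ = torch.linalg.qr(A)
Q = Q * torch.sign(torch.det(Q))[:, None, None]

glob = joint_angles_rel_to_glob(Q)
rel_again = joint_angles_glob_to_rel(glob)
print(torch.allclose(rel_again, Q, atol=1e-4))  # expect True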
  {
    "path": "src/egoallo/preprocessing/body_model/specs.py",
    "content": "import numpy as np\n\nSMPL_JOINTS = {\n    \"hips\": 0,\n    \"leftUpLeg\": 1,\n    \"rightUpLeg\": 2,\n    \"spine\": 3,\n    \"leftLeg\": 4,\n    \"rightLeg\": 5,\n    \"spine1\": 6,\n    \"leftFoot\": 7,\n    \"rightFoot\": 8,\n    \"spine2\": 9,\n    \"leftToeBase\": 10,\n    \"rightToeBase\": 11,\n    \"neck\": 12,\n    \"leftShoulder\": 13,\n    \"rightShoulder\": 14,\n    \"head\": 15,\n    \"leftArm\": 16,\n    \"rightArm\": 17,\n    \"leftForeArm\": 18,\n    \"rightForeArm\": 19,\n    \"leftHand\": 20,\n    \"rightHand\": 21,\n}\n\nSMPL_PARENTS = np.array(\n    [\n        -1,\n        0,\n        0,\n        0,\n        1,\n        2,\n        3,\n        4,\n        5,\n        6,\n        7,\n        8,\n        9,\n        12,\n        12,\n        12,\n        13,\n        14,\n        16,\n        17,\n        18,\n        19,\n    ]\n)\n\n# reflect joints\nRIGHT_CHAIN = np.array([2, 5, 8, 11, 14, 17, 19, 21])\nLEFT_CHAIN = np.array([1, 4, 7, 10, 13, 16, 18, 20])\nREFLECT_PERM = np.array(\n    [\n        0,\n        2,\n        1,\n        3,\n        5,\n        4,\n        6,\n        8,\n        7,\n        9,\n        11,\n        10,\n        12,\n        14,\n        13,\n        15,\n        17,\n        16,\n        19,\n        18,\n        21,\n        20,\n    ]\n)\nPOSE_REFLECT_PERM = np.concatenate([3 * i + np.arange(3) for i in REFLECT_PERM], axis=0)\n\n# root, left knee, right knee, left heel, right heel,\n# left toe, right toe, left hand, right hand\nCONTACT_JOINTS = [\n    \"hips\",\n    \"leftLeg\",\n    \"rightLeg\",\n    \"leftFoot\",\n    \"rightFoot\",\n    \"leftToeBase\",\n    \"rightToeBase\",\n    \"leftHand\",\n    \"rightHand\",\n]\nCONTACT_INDS = [SMPL_JOINTS[joint] for joint in CONTACT_JOINTS]\n\nFEET_JOINTS = [\n    \"leftToeBase\",\n    \"rightToeBase\",\n]\nFEET_INDS = [SMPL_JOINTS[joint] for joint in FEET_JOINTS]\n\n# chosen virtual mocap markers that are \"keypoints\" to work with\nKEYPT_VERTS = [\n    4404,\n    920,\n    3076,\n    3169,\n    823,\n    4310,\n    1010,\n    1085,\n    4495,\n    4569,\n    6615,\n    3217,\n    3313,\n    6713,\n    6785,\n    3383,\n    6607,\n    3207,\n    1241,\n    1508,\n    4797,\n    4122,\n    1618,\n    1569,\n    5135,\n    5040,\n    5691,\n    5636,\n    5404,\n    2230,\n    2173,\n    2108,\n    134,\n    3645,\n    6543,\n    3123,\n    3024,\n    4194,\n    1306,\n    182,\n    3694,\n    4294,\n    744,\n]\n\n\n\"\"\"\nOpenpose\n\"\"\"\nOP_NUM_JOINTS = 25\n# OP_IGNORE_JOINTS = [1, 9, 12]  # neck and left/right hip\nOP_IGNORE_JOINTS = [1]  # neck\nOP_EDGE_LIST = [\n    [1, 8],\n    [1, 2],\n    [1, 5],\n    [2, 3],\n    [3, 4],\n    [5, 6],\n    [6, 7],\n    [8, 9],\n    [9, 10],\n    [10, 11],\n    [8, 12],\n    [12, 13],\n    [13, 14],\n    [1, 0],\n    [0, 15],\n    [15, 17],\n    [0, 16],\n    [16, 18],\n    [14, 19],\n    [19, 20],\n    [14, 21],\n    [11, 22],\n    [22, 23],\n    [11, 24],\n]\n# indices to map an openpose detection to its flipped version\nOP_FLIP_MAP = [\n    0,\n    1,\n    5,\n    6,\n    7,\n    2,\n    3,\n    4,\n    8,\n    12,\n    13,\n    14,\n    9,\n    10,\n    11,\n    16,\n    15,\n    18,\n    17,\n    22,\n    23,\n    24,\n    19,\n    20,\n    21,\n]\n\n\n# From https://github.com/vchoutas/smplify-x/blob/master/smplifyx/utils.py\n# Please see license for usage restrictions.\ndef smpl_to_openpose(\n    model_type=\"smplh\",\n    use_hands=False,\n    use_face=False,\n    use_face_contour=False,\n    openpose_format=\"coco25\",\n):\n   
 \"\"\"Returns the indices of the permutation that maps SMPL to OpenPose\n\n    Parameters\n    ----------\n    model_type: str, optional\n        The type of SMPL-like model that is used. The default mapping\n        returned is for the SMPLX model\n    use_hands: bool, optional\n        Flag for adding to the returned permutation the mapping for the\n        hand keypoints. Defaults to True\n    use_face: bool, optional\n        Flag for adding to the returned permutation the mapping for the\n        face keypoints. Defaults to True\n    use_face_contour: bool, optional\n        Flag for appending the facial contour keypoints. Defaults to False\n    openpose_format: bool, optional\n        The output format of OpenPose. For now only COCO-25 and COCO-19 is\n        supported. Defaults to 'coco25'\n\n    \"\"\"\n    if openpose_format.lower() == \"coco25\":\n        if model_type == \"smpl\":\n            return np.array(\n                [\n                    24,\n                    12,\n                    17,\n                    19,\n                    21,\n                    16,\n                    18,\n                    20,\n                    0,\n                    2,\n                    5,\n                    8,\n                    1,\n                    4,\n                    7,\n                    25,\n                    26,\n                    27,\n                    28,\n                    29,\n                    30,\n                    31,\n                    32,\n                    33,\n                    34,\n                ],\n                dtype=np.int32,\n            )\n        elif model_type == \"smplh\":\n            body_mapping = np.array(\n                [\n                    52,\n                    12,\n                    17,\n                    19,\n                    21,\n                    16,\n                    18,\n                    20,\n                    0,\n                    2,\n                    5,\n                    8,\n                    1,\n                    4,\n                    7,\n                    53,\n                    54,\n                    55,\n                    56,\n                    57,\n                    58,\n                    59,\n                    60,\n                    61,\n                    62,\n                ],\n                dtype=np.int32,\n            )\n            mapping = [body_mapping]\n            if use_hands:\n                lhand_mapping = np.array(\n                    [\n                        20,\n                        34,\n                        35,\n                        36,\n                        63,\n                        22,\n                        23,\n                        24,\n                        64,\n                        25,\n                        26,\n                        27,\n                        65,\n                        31,\n                        32,\n                        33,\n                        66,\n                        28,\n                        29,\n                        30,\n                        67,\n                    ],\n                    dtype=np.int32,\n                )\n                rhand_mapping = np.array(\n                    [\n                        21,\n                        49,\n                        50,\n                        51,\n                        68,\n                        37,\n                        38,\n                        39,\n             
           69,\n                        40,\n                        41,\n                        42,\n                        70,\n                        46,\n                        47,\n                        48,\n                        71,\n                        43,\n                        44,\n                        45,\n                        72,\n                    ],\n                    dtype=np.int32,\n                )\n                mapping += [lhand_mapping, rhand_mapping]\n            return np.concatenate(mapping)\n        # SMPLX\n        elif model_type == \"smplx\":\n            body_mapping = np.array(\n                [\n                    55,\n                    12,\n                    17,\n                    19,\n                    21,\n                    16,\n                    18,\n                    20,\n                    0,\n                    2,\n                    5,\n                    8,\n                    1,\n                    4,\n                    7,\n                    56,\n                    57,\n                    58,\n                    59,\n                    60,\n                    61,\n                    62,\n                    63,\n                    64,\n                    65,\n                ],\n                dtype=np.int32,\n            )\n            mapping = [body_mapping]\n            if use_hands:\n                lhand_mapping = np.array(\n                    [\n                        20,\n                        37,\n                        38,\n                        39,\n                        66,\n                        25,\n                        26,\n                        27,\n                        67,\n                        28,\n                        29,\n                        30,\n                        68,\n                        34,\n                        35,\n                        36,\n                        69,\n                        31,\n                        32,\n                        33,\n                        70,\n                    ],\n                    dtype=np.int32,\n                )\n                rhand_mapping = np.array(\n                    [\n                        21,\n                        52,\n                        53,\n                        54,\n                        71,\n                        40,\n                        41,\n                        42,\n                        72,\n                        43,\n                        44,\n                        45,\n                        73,\n                        49,\n                        50,\n                        51,\n                        74,\n                        46,\n                        47,\n                        48,\n                        75,\n                    ],\n                    dtype=np.int32,\n                )\n\n                mapping += [lhand_mapping, rhand_mapping]\n            if use_face:\n                #  end_idx = 127 + 17 * use_face_contour\n                face_mapping = np.arange(\n                    76, 127 + 17 * use_face_contour, dtype=np.int32\n                )\n                mapping += [face_mapping]\n\n            return np.concatenate(mapping)\n        else:\n            raise ValueError(\"Unknown model type: {}\".format(model_type))\n    elif openpose_format == \"coco19\":\n        if model_type == \"smpl\":\n            return np.array(\n                [24, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4, 7, 
25, 26, 27, 28],\n                dtype=np.int32,\n            )\n        elif model_type == \"smplh\":\n            body_mapping = np.array(\n                [52, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4, 7, 53, 54, 55, 56],\n                dtype=np.int32,\n            )\n            mapping = [body_mapping]\n            if use_hands:\n                lhand_mapping = np.array(\n                    [\n                        20,\n                        34,\n                        35,\n                        36,\n                        57,\n                        22,\n                        23,\n                        24,\n                        58,\n                        25,\n                        26,\n                        27,\n                        59,\n                        31,\n                        32,\n                        33,\n                        60,\n                        28,\n                        29,\n                        30,\n                        61,\n                    ],\n                    dtype=np.int32,\n                )\n                rhand_mapping = np.array(\n                    [\n                        21,\n                        49,\n                        50,\n                        51,\n                        62,\n                        37,\n                        38,\n                        39,\n                        63,\n                        40,\n                        41,\n                        42,\n                        64,\n                        46,\n                        47,\n                        48,\n                        65,\n                        43,\n                        44,\n                        45,\n                        66,\n                    ],\n                    dtype=np.int32,\n                )\n                mapping += [lhand_mapping, rhand_mapping]\n            return np.concatenate(mapping)\n        # SMPLX\n        elif model_type == \"smplx\":\n            body_mapping = np.array(\n                [55, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4, 7, 56, 57, 58, 59],\n                dtype=np.int32,\n            )\n            mapping = [body_mapping]\n            if use_hands:\n                lhand_mapping = np.array(\n                    [\n                        20,\n                        37,\n                        38,\n                        39,\n                        60,\n                        25,\n                        26,\n                        27,\n                        61,\n                        28,\n                        29,\n                        30,\n                        62,\n                        34,\n                        35,\n                        36,\n                        63,\n                        31,\n                        32,\n                        33,\n                        64,\n                    ],\n                    dtype=np.int32,\n                )\n                rhand_mapping = np.array(\n                    [\n                        21,\n                        52,\n                        53,\n                        54,\n                        65,\n                        40,\n                        41,\n                        42,\n                        66,\n                        43,\n                        44,\n                        45,\n                        67,\n                        49,\n                        50,\n                        51,\n                  
      68,\n                        46,\n                        47,\n                        48,\n                        69,\n                    ],\n                    dtype=np.int32,\n                )\n\n                mapping += [lhand_mapping, rhand_mapping]\n            if use_face:\n                face_mapping = np.arange(\n                    70, 70 + 51 + 17 * use_face_contour, dtype=np.int32\n                )\n                mapping += [face_mapping]\n\n            return np.concatenate(mapping)\n        else:\n            raise ValueError(\"Unknown model type: {}\".format(model_type))\n    else:\n        raise ValueError(\"Unknown joint format: {}\".format(openpose_format))\n"
  },
  {
    "path": "src/egoallo/preprocessing/body_model/utils.py",
    "content": "from jaxtyping import Float, Int\nfrom typing import Tuple, Optional\nimport numpy as np\nimport torch\nfrom torch import Tensor\nimport torch.nn.functional as F\nfrom torch import Tensor\n\nfrom ..geometry import (\n    get_rot_rep_shape,\n    convert_rotation,\n    batch_apply_Rt,\n    make_transform,\n    transform_rel_to_global,\n    transform_global_to_rel,\n)\nfrom .specs import SMPL_JOINTS, smpl_to_openpose, POSE_REFLECT_PERM\nfrom .skeleton import joint_angles_glob_to_rel\n\n\n__all__ = [\n    \"run_smpl\",\n    \"reflect_pose_aa\",\n    \"reflect_root_trajectory\",\n    \"forward_kinematics\",\n    \"inverse_kinematics\",\n    \"smpl_local_to_global\",\n    \"select_smpl_joints\",\n    \"get_openpose_from_smpl\",\n    \"convert_local_pose_to_aa\",\n    \"convert_global_pose_to_aa\",\n    \"load_beta_conversion\",\n    \"convert_model_betas\",\n]\n\n\ndef run_smpl(\n    body_model, mats_in: bool = False, return_verts: bool = True, **kwargs\n) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:\n    \"\"\"\n    helper function for running SMPL model with multiple leading dimensions\n    return joints and optionally verts and faces\n    :param body_model\n    :param return_verts (optional bool=True)\n    \"\"\"\n    device = body_model.bm.shapedirs.device\n    dims = (body_model.batch_size,)\n    fields = [\"pose_body\", \"root_orient\", \"trans\", \"betas\"]\n    dim_idcs = [-3 if mats_in else -1, -2 if mats_in else -1, -1, -1]\n    for name, idx in zip(fields, dim_idcs):\n        if name in kwargs:\n            x = kwargs[name]\n            if x is None:\n                continue\n            dims, sh = x.shape[:idx], x.shape[idx:]\n            kwargs[name] = x.reshape(-1, *sh).to(device)\n\n    if not return_verts:\n        joints = body_model.forward_joints(**kwargs).reshape(*dims, -1, 3)\n        return joints, None, None\n\n    smpl = body_model(**kwargs)\n    joints = smpl.Jtr.reshape(*dims, -1, 3)\n    verts = smpl.v.reshape(*dims, -1, 3)\n    return joints, verts, smpl.f\n\n\ndef reflect_pose_aa(root_orient: Tensor, pose_body: Tensor):\n    \"\"\"\n    :param root_orient (*, 3)\n    :param pose_body (*, (J-1)*3)\n    return reflected root_orient and pose_body\n    \"\"\"\n    pose_full = torch.cat([root_orient, pose_body], dim=-1)  # (*, J*3)\n    pose_reflect = pose_full[..., POSE_REFLECT_PERM]\n    pose_reflect[..., 1::3] = -pose_reflect[..., 1::3]\n    pose_reflect[..., 2::3] = -pose_reflect[..., 2::3]\n    return pose_reflect[..., :3], pose_reflect[..., 3:]\n\n\ndef reflect_root_trajectory(\n    rot_aa: Tensor, trans: Tensor, rot_aa_r: Tensor, root_loc: Tensor\n) -> Tuple[Tensor, Tensor]:\n    # rotation from t to world\n    R_wt = convert_rotation(rot_aa, \"aa\", \"mat\")\n    # get the transforms of the root in the world\n    T_wt = make_transform(R_wt, trans + root_loc)\n    # transform from t to previous\n    T_pt = transform_global_to_rel(T_wt)\n    # rotation from reflected t to world\n    R_wtr = convert_rotation(rot_aa_r, \"aa\", \"mat\")\n    # relative transforms\n    R_prtr = transform_global_to_rel(R_wtr)\n    # get the displacement between t and t-1 in t, the SOURCE frame\n    # t_prt = R_prtr * t_tr, where t_tr = (-1, 1, 1) * t_t, and t_t = R_tp * t_pt\n    t_tt = torch.einsum(\"tij,tj->ti\", torch.linalg.inv(T_pt[:, :3, :3]), T_pt[:, :3, 3])\n    # reflect through x\n    t_trtr = torch.cat([-t_tt[..., :1], t_tt[..., 1:]], dim=-1)\n    # convert displacement into the t-1 TARGET frame\n    t_prtr = torch.einsum(\"tij,tj->ti\", R_prtr, 
t_trtr)\n    # get back global trajectory\n    T_prtr = make_transform(R_prtr, t_prtr)\n    T_wtr = transform_rel_to_global(T_prtr)\n    # get back the smpl translation\n    trans_wtr = T_wtr[:, :3, 3] - root_loc\n    return rot_aa_r, trans_wtr\n\n\ndef forward_kinematics(\n    rot_mats: Tensor,\n    joints_in: Tensor,\n    parents: Tensor,\n) -> Tuple[Tensor, Tensor]:\n    \"\"\"\n    get the forward transformed joints\n    very similar to smplx's batch_rigid_transform with more flexible batch dimensions\n    :param rot_mats (*, J, 3, 3) joint rotations from joint i to parent\n    :param joints_in (*, J, 3)\n    :param parents (J)\n    returns (*, J, 4, 4) tensor of transforms\n    \"\"\"\n    J = len(parents)\n    joints_body_rel = (\n        joints_in[..., 1:J, :] - joints_in[..., parents[1:], :]\n    )  # (*, J-1, 3)\n    joints_rel = torch.cat(\n        [joints_in[..., :1, :], joints_body_rel], dim=-2\n    )  # (*, J, 3)\n    T_pi = make_transform(rot_mats, joints_rel)  # (*, J, 4, 4)\n    tforms_wp = [T_pi[..., 0, :, :]]\n    for i in range(1, J):\n        tforms_wp.append(torch.matmul(tforms_wp[parents[i]], T_pi[..., i, :, :]))\n    transforms = torch.stack(tforms_wp, dim=-3)\n    joints_posed = transforms[..., :3, 3]  # (*, J, 3)\n    rel_trans_h = F.pad(\n        joints_posed\n        - torch.einsum(\"...ij,...j->...i\", transforms[..., :3, :3], joints_in),\n        [0, 1],\n        value=1.0,\n    ).unsqueeze(-1)\n    rel_transforms = torch.cat([transforms[..., :3], rel_trans_h], dim=-1)\n    return joints_posed, rel_transforms\n\n\ndef get_pose_offsets(\n    pose_mats: Float[Tensor, \"*batch J 3 3\"],\n    posedirs: Float[Tensor, \"P N\"],\n) -> Float[Tensor, \"*batch J 3\"]:\n    dims = pose_mats.shape[:-3]\n    I = torch.eye(3, device=pose_mats.device).reshape(*(1,) * len(dims), 1, 3, 3)\n    pose_feat = (pose_mats[..., 1:, :, :] - I).reshape(*dims, -1)  # (*, P)\n    return torch.einsum(\"...p,pn->...n\", pose_feat, posedirs).reshape(*dims, -1, 3)\n\n\ndef select_vert_params(\n    idcs: Int[Tensor, \"S\"],\n    v_template: Float[Tensor, \"V 3\"],\n    shapedirs: Float[Tensor, \"V 3 B\"],\n    posedirs: Float[Tensor, \"P N\"],\n    lbs_weights: Float[Tensor, \"V J\"],\n) -> Tuple[Tensor, Tensor, Tensor, Tensor]:\n    pose_idcs = torch.repeat_interleave(3 * idcs, 3, -1)\n    pose_idcs[1::3] += 1\n    pose_idcs[2::3] += 2\n    return v_template[idcs], shapedirs[idcs], posedirs[:, pose_idcs], lbs_weights[idcs]\n\n\ndef get_verts_with_transforms(\n    betas: Float[Tensor, \"*batch B\"],\n    pose_mats: Float[Tensor, \"*batch J 3 3\"],\n    rel_transforms: Float[Tensor, \"*batch J 4 4\"],\n    v_template: Float[Tensor, \"V 3\"],\n    shapedirs: Float[Tensor, \"V 3 B\"],\n    posedirs: Float[Tensor, \"P N\"],\n    lbs_weights: Float[Tensor, \"V J\"],\n):\n    # (*, V, 3)\n    v_shaped = v_template + torch.einsum(\"...l,mkl->...mk\", betas, shapedirs)\n    v_posed = v_shaped + get_pose_offsets(pose_mats, posedirs)\n    T = torch.einsum(\"ij,...jkl->...ikl\", lbs_weights, rel_transforms)  # (*, V, 4, 4)\n    v_out = torch.einsum(\"...ij,...j->...i\", T[..., :3, :3], v_posed) + T[..., :3, 3]\n    return v_out\n\n\ndef inverse_kinematics(rot_mats: Tensor, joints: Tensor, parents: Tensor) -> Tensor:\n    \"\"\"\n    given the joint rotations and locations of a posed skeleton,\n    invert and get the template skeleton\n    :param rot_mats (*, J, 3, 3) rotation from joint i to parent (R_pi)\n    :param joints (*, J, 3) posed joint locations\n    :param parents (J)\n    returns (*, J, 3) 
template joints\n    \"\"\"\n    # J = len(parents)\n    J = joints.shape[-2]\n    # delta between joint and parent in the world\n    delta_w = torch.cat(\n        [joints[..., :1, :], joints[..., 1:J, :] - joints[..., parents[1:J], :]], dim=-2\n    )\n    # rot mats from parent to joint i\n    rots_ip = rot_mats.transpose(-1, -2)\n    # get the world to parent rotation matrices\n    rots_pw = [rots_ip[..., 0, :, :]]\n    trans_p = [joints[..., 0, :]]\n    for i in range(1, J):\n        # R_iw = R_ip R_pw = (R_pi.T) R_pw\n        R_pw = rots_pw[parents[i]]\n        delta_p = torch.einsum(\"...ij,...j->...i\", R_pw, delta_w[..., i, :])\n        trans_p.append(trans_p[parents[i]] + delta_p)\n        if i >= J - 1:\n            break\n        rots_pw.append(torch.matmul(rots_ip[..., i, :, :], R_pw))\n    return torch.stack(trans_p, dim=-2)\n\n\ndef smpl_local_to_global(\n    R_root: Tensor, t_root: Tensor, points_l: Tensor, root_l: Tensor\n) -> Tensor:\n    \"\"\"\n    transform local smpl body to global\n    :param T_root (*, 4, 4) root transform from local to world\n    :param points_l (*, N, 3) points to transform\n    :param root_l (*, 1, 3) root in local coordinates\n    \"\"\"\n    return batch_apply_Rt(R_root, t_root, points_l - root_l) + root_l\n\n\ndef select_smpl_joints(joints_full):\n    \"\"\"\n    select the first 22 smpl joints from the full joints\n    :param joints_full (*, J, 3)\n    \"\"\"\n    return joints_full[..., : len(SMPL_JOINTS), :]\n\n\ndef get_openpose_from_smpl(joints_smpl, model_type=\"smplh\"):\n    smpl2op_map = smpl_to_openpose(\n        model_type,\n        use_hands=False,\n        use_face=False,\n        use_face_contour=False,\n        openpose_format=\"coco25\",\n    )\n    joints3d_op = joints_smpl[..., smpl2op_map, :]\n    # hacky way to get hip joints that align with ViTPose keypoints\n    # this could be moved elsewhere in the future (and done properly)\n    joints3d_op[..., [9, 12], :] = (\n        joints3d_op[..., [9, 12], :]\n        + 0.25 * (joints3d_op[..., [9, 12], :] - joints3d_op[..., [12, 9], :])\n        + 0.5\n        * (\n            joints3d_op[..., [8], :]\n            - 0.5 * (joints3d_op[..., [9, 12], :] + joints3d_op[..., [12, 9], :])\n        )\n    )\n    return joints3d_op\n\n\ndef convert_local_pose_to_aa(pose_body: Tensor, rot_rep: str):\n    \"\"\"\n    convert local pose in rotation representation into flattened axis-angle\n    :param pose_body (*, J*D)\n    :param rot_rep (str)\n    returns (*, J*3) flattened aa pose\n    \"\"\"\n    if rot_rep == \"aa\":\n        return pose_body\n    dims = pose_body.shape[:-1]\n    rot_sh = get_rot_rep_shape(rot_rep)\n    pose_aa = convert_rotation(\n        pose_body.reshape(*dims, -1, *rot_sh), rot_rep, \"aa\"\n    )  # (*, J, 3)\n    return pose_aa.reshape(*dims, -1)\n\n\ndef convert_global_pose_to_aa(pose_glob: Tensor, rot_rep: str):\n    \"\"\"\n    :param pose_glob (*, J*D)\n    :param rot_rep (str)\n    returnns (*, J*3) local pose flattened aa\n    \"\"\"\n    dims = pose_glob.shape[:-1]\n    rot_sh = get_rot_rep_shape(rot_rep)\n    pose_glob_mat = convert_rotation(\n        pose_glob.reshape(*dims, -1, *rot_sh), rot_rep, \"mat\"\n    )\n    pose_rel_mat = joint_angles_glob_to_rel(pose_glob_mat)  # (*, J, 3, 3)\n    return convert_rotation(pose_rel_mat, \"mat\", \"aa\").reshape(*dims, -1)\n\n\ndef load_beta_conversion(path: str) -> Tuple[Tensor, Tensor]:\n    data = np.load(path)\n    return torch.from_numpy(data[\"A\"].astype(\"float32\")), torch.from_numpy(\n        
data[\"b\"].astype(\"float32\")\n    )\n\n\ndef convert_model_betas(beta: Tensor, A: Tensor, b: Tensor) -> Tensor:\n    \"\"\"\n    :param beta (*, B)\n    :param A (B, B)\n    :param b (B)\n    beta_neutral = A @ beta_gender + b\n    \"\"\"\n    *dims, B = beta.shape\n    A = A.reshape((*(1,) * len(dims), *A.shape))\n    b = b.reshape((*(1,) * len(dims), *b.shape))\n    return torch.einsum(\"...ij,...j->...i\", A, beta) + b\n"
  },
  {
    "path": "src/egoallo/preprocessing/geometry/__init__.py",
    "content": "from .rotation import *\nfrom .helpers import *\nfrom . import plane\nfrom . import camera\nfrom . import transforms\n"
  },
  {
    "path": "src/egoallo/preprocessing/geometry/camera.py",
    "content": "from typing import Tuple\nimport torch\nimport numpy as np\n\n# import lietorch as tf\nfrom . import transforms as tf\nfrom .helpers import batch_apply_Rt\n\n\ndef project_from_world(X_w, R_cw, t_cw, intrins):\n    \"\"\"\n    :param X_w (*, N, 3)\n    :param cam_R (*, 3, 3)\n    :param cam_t (*, 3)\n    :param intrins (*, 4)\n    \"\"\"\n    return proj_2d(batch_apply_Rt(R_cw, t_cw, X_w), intrins)\n\n\ndef proj_2d(xyz, intrins, eps=1e-4):\n    \"\"\"\n    :param xyz (*, 3/4) 3d/4d point in camera coordinates\n    :param intrins (*, 4) fx, fy, cx, cy\n    return (*, 2) of reprojected points, (*) of points in front of camera\n    \"\"\"\n    z = xyz[..., 2:3]\n    valid_mask = z > eps\n    disp = torch.where(valid_mask, 1.0 / (z + eps), torch.ones_like(z))\n    focal = intrins[..., :2]\n    center = intrins[..., 2:]\n    return focal * disp * xyz[..., :2] + center, valid_mask[..., 0]\n\n\ndef proj_h(xyzw):\n    \"\"\"\n    project homogeneous point\n    \"\"\"\n    w = xyzw[..., -1:]\n    return xyzw[..., :-1] * torch.where(w > 0, 1.0 / w, w)\n\n\ndef iproj_depth(uv, z, intrins):\n    \"\"\"\n    inverse project into 3d coords from depth\n    :param uv (*, 2)\n    :param z (*, 1)\n    :param intrins (*, 4)\n    :returns (*, 3)\n    \"\"\"\n    focal = intrins[..., :2]\n    center = intrins[..., 2:]\n    return z * torch.cat([(uv - center) / focal, torch.ones_like(z)], dim=-1)\n\n\ndef iproj(uv, disp, intrins):\n    \"\"\"\n    inverse project from disparity. returns 4d homogeneous\n    :param uv (*, 2)\n    :param disp (*, 1)\n    :param intrins (*, 4)\n    :returns (*, 4)\n    \"\"\"\n    x = normalize_coords(uv, intrins)\n    X = torch.cat([x, torch.ones_like(disp), disp], dim=-1)\n    return X\n\n\ndef normalize_coords(uv, intrins):\n    focal = intrins[..., :2]\n    center = intrins[..., 2:]\n    return (uv - center) / focal\n\n\ndef iproj_to_world(uv, disp, intrins, extrins, ret_3d=True):\n    \"\"\"\n    inverse project disparity into world coords. 
default returns 3d\n    :param uv (*, 2)\n    :param disp (*, 1)\n    :param intrins (*, 4)\n    :param extrins (*, 7)\n    :param ret_3d (optional bool) return in 3d coords, default True\n    :returns (*, 3)\n    \"\"\"\n    T_wc = tf.SE3(extrins).inv()\n    X_c = iproj(uv, disp, intrins)\n    X_w = T_wc.act(X_c)\n    if ret_3d:\n        return proj_h(X_w)\n    return X_w\n\n\ndef reproject(pose_params, intrins, disps, uv, ii, jj):\n    \"\"\"\n    :param pose_params (T, *, 7) pose parameters\n    :param intrins (T, *, 4) fx, fy, cx, cy\n    :param uv (T, *, 2) coordinate grid\n    :param disps (T, *, 1) disparity\n    :param ii (N) source index array into parameters\n    :param jj (N) target index array into parameters\n    returns (N, *, 2) points in ii reprojected into jj\n    \"\"\"\n    T_i, T_j = tf.SE3(pose_params[ii]), tf.SE3(pose_params[jj])\n    Xh_i = iproj(uv, disps[ii], intrins[ii])\n    Xh_j = T_j.mul(T_i.inv()).act(Xh_i)\n    return proj_2d(Xh_j, intrins[jj])\n\n\ndef proj_2d_jac(X, intrins):\n    \"\"\"\n    :param X (*, 4) point in camera coordinates\n    :param intrins (*, 4) fx, fy, cx, cy\n    return (*, 2, 4)\n    \"\"\"\n    fx, fy, cx, cy = intrins.unbind(dim=-1)\n    X, Y, Z, D = X.unbind(dim=-1)\n    d = torch.where(Z > 0.1, 1.0 / Z, torch.ones_like(Z))\n    o = torch.zeros_like(d)\n    return torch.stack(\n        [fx * d, o, -fx * X * d * d, o, o, fy * d, -fy * Y * d * d, o],\n        dim=-1,\n    ).reshape(*d.shape, 2, 4)\n\n\ndef actp_jac(X1):\n    \"\"\"\n    :param X1 (*, 4) point after transformation\n    \"\"\"\n    x, y, z, d = X1.unbind(dim=-1)\n    o = torch.zeros_like(d)\n    return torch.stack(\n        [d, o, o, o, z, -y, o, d, o, -z, o, x, o, o, d, y, -x, o, o, o, o, o, o, o],\n        dim=-1,\n    ).reshape(*d.shape, 4, 6)\n\n\ndef iproj_jac(X):\n    \"\"\"\n    jacobian for inverse projection to 4d\n    \"\"\"\n    J = torch.zeros_like(X)\n    J[..., -1] = 1\n    return J\n\n\ndef make_homogeneous(x):\n    \"\"\"\n    :param x (*, 3)\n    returns x in homogeneous coordinates\n    \"\"\"\n    return torch.cat([x, torch.ones_like(x[..., :1])], dim=-1)\n\n\ndef make_transform(R, t):\n    \"\"\"\n    :param R (*, 3, 3)\n    :param t (*, 3)\n    return (*, 4, 4)\n    \"\"\"\n    dims = R.shape[:-2]\n    bottom = (\n        torch.tensor([0, 0, 0, 1], dtype=R.dtype, device=R.device)\n        .reshape(*(1,) * len(dims), 1, 4)\n        .repeat(*dims, 1, 1)\n    )\n    return torch.cat([torch.cat([R, t.unsqueeze(-1)], dim=-1), bottom], dim=-2)\n\n\ndef focal2fov(focal, R):\n    \"\"\"\n    :param focal, focal length\n    :param R, either W / 2 or H / 2\n    \"\"\"\n    return 2 * np.arctan(R / focal)\n\n\ndef fov2focal(fov, R):\n    \"\"\"\n    :param fov, field of view in radians\n    :param R, either W / 2 or H / 2\n    \"\"\"\n    return R / np.tan(fov / 2)\n\n\ndef lookat_matrix(source_pos, target_pos, up):\n    \"\"\"\n    uses x right y down z forward opencv convention\n    :param source_pos (*, 3)\n    :param target_pos (*, 3)\n    :param up (3,)\n    \"\"\"\n    *dims, _ = source_pos.shape\n    up = up.reshape(*(1,) * len(dims), 3)\n    up = up / torch.linalg.norm(up, dim=-1, keepdim=True)\n    back = normalize(source_pos - target_pos)\n    right = normalize(torch.linalg.cross(up, back))\n    up = normalize(torch.linalg.cross(back, right))\n    R = torch.stack([right, -up, -back], dim=-1)\n    return make_transform(R, source_pos)\n\n\ndef normalize(x):\n    return x / torch.linalg.norm(x, dim=-1, keepdim=True)\n\n\ndef view_matrix(z, up, pos):\n    
\"\"\"\n    :param z (*, 3) up (*, 3) pos (*, 3)\n    returns (*, 4, 4)\n    \"\"\"\n    *dims, _ = z.shape\n    x = normalize(torch.linalg.cross(up, z))\n    y = normalize(torch.linalg.cross(z, x))\n    bottom = (\n        torch.tensor([0, 0, 0, 1], dtype=torch.float32)\n        .reshape(*(1,) * len(dims), 1, 4)\n        .expand(*dims, 1, 4)\n    )\n\n    return torch.cat([torch.stack([x, y, z, pos], dim=-1), bottom], dim=-2)\n\n\ndef average_pose(poses):\n    \"\"\"\n    :param poses (N, 4, 4)\n    returns average pose (4, 4)\n    \"\"\"\n    center = poses[:, :3, 3].mean(0)\n    up = normalize(poses[:, :3, 1].sum(0))\n    z = normalize(poses[:, :3, 2].sum(0))\n    return view_matrix(z, up, center)\n\n\ndef make_translation(t):\n    return make_transform(torch.eye(3, device=t.device), t)\n\n\ndef make_rotation(rx=0, ry=0, rz=0, order=\"xyz\"):\n    Rx = rotx(rx)\n    Ry = roty(ry)\n    Rz = rotz(rz)\n    if order == \"xyz\":\n        R = Rz @ Ry @ Rx\n    elif order == \"xzy\":\n        R = Ry @ Rz @ Rx\n    elif order == \"yxz\":\n        R = Rz @ Rx @ Ry\n    elif order == \"yzx\":\n        R = Rx @ Rz @ Ry\n    elif order == \"zyx\":\n        R = Rx @ Ry @ Rz\n    elif order == \"zxy\":\n        R = Ry @ Rx @ Rz\n    else:\n        raise NotImplementedError\n    return make_transform(R, torch.zeros(3))\n\n\ndef rotx(theta):\n    return torch.tensor(\n        [\n            [1, 0, 0],\n            [0, np.cos(theta), -np.sin(theta)],\n            [0, np.sin(theta), np.cos(theta)],\n        ],\n        dtype=torch.float32,\n    )\n\n\ndef roty(theta):\n    return torch.tensor(\n        [\n            [np.cos(theta), 0, np.sin(theta)],\n            [0, 1, 0],\n            [-np.sin(theta), 0, np.cos(theta)],\n        ],\n        dtype=torch.float32,\n    )\n\n\ndef rotz(theta):\n    return torch.tensor(\n        [\n            [np.cos(theta), -np.sin(theta), 0],\n            [np.sin(theta), np.cos(theta), 0],\n            [0, 0, 1],\n        ],\n        dtype=torch.float32,\n    )\n\n\ndef identity(shape: Tuple, d=4, **kwargs):\n    I = torch.eye(d, **kwargs)\n    return I.reshape(*(1,) * len(shape), d, d).repeat(*shape, 1, 1)\n"
  },
  {
    "path": "src/egoallo/preprocessing/geometry/helpers.py",
    "content": "import torch\nimport numpy as np\n\nfrom .rotation import convert_rotation\n\n\ndef make_transform(R, t):\n    \"\"\"\n    :param R (*, 3, 3)\n    :param t (*, 3)\n    \"\"\"\n    dims = R.shape[:-2]\n    pose_3x4 = torch.cat([R, t.view(*dims, 3, 1)], dim=-1)\n    bottom = (\n        torch.tensor([0, 0, 0, 1], device=R.device)\n        .reshape(*(1,) * len(dims), 1, 4)\n        .expand(*dims, 1, 4)\n    )\n    return torch.cat([pose_3x4, bottom], dim=-2)\n\n\ndef transform_points(T, x):\n    \"\"\"\n    :param T (*, 4, 4)\n    :param x (*, N, 3)\n    \"\"\"\n    R = T[..., :3, :3]\n    t = T[..., :3, 3]\n    return batch_apply_Rt(R, t, x)\n\n\ndef batch_apply_Rt(R, t, x):\n    \"\"\"\n    :param R (*, 3, 3)\n    :param t (*, 3)\n    :param x (*, N, 3)\n    \"\"\"\n    return torch.einsum(\"...ij,...nj->...ni\", R, x) + t.unsqueeze(-2)\n\n\ndef transform_global_to_rel(T_glob):\n    \"\"\"\n    get the relative transforms (diffs) from global transform of trajectory\n    :param T_glob (*, T, 4, 4) root to world transform\n    return root t->t-1 transform (*, T, 4, 4)\n    \"\"\"\n    T_rel = torch.matmul(\n        torch.linalg.inv(T_glob[..., :-1, :, :]), T_glob[..., 1:, :, :]\n    )  # (*, T-1, 4, 4)\n    return torch.cat([T_glob[..., :1, :, :], T_rel], dim=-3)\n\n\ndef transform_rel_to_global(T_rel):\n    \"\"\"\n    convert relative transforms into global trajectory\n    :param T_rel (*, T, 4, 4) root t -> t-1 transform\n    return root t -> world transform (*, T, 4, 4)\n    \"\"\"\n    N = T_rel.shape[-3]\n    T_rel_list = T_rel.unbind(dim=-3)\n    T_glob_list = [T_rel_list[0]]\n    for t in range(1, N):\n        T_cur = torch.matmul(T_glob_list[t - 1], T_rel_list[t])\n        T_glob_list.append(T_cur)\n    return torch.stack(T_glob_list, dim=-3)\n\n\ndef RT_global_to_rel(R_glob, t_glob):\n    \"\"\"\n    :param R_glob (*, T, 3, 3) root to world rotation\n    :param t_glob (*, T, 3) root to world translation\n    returns root t -> t-1 rotation (*, T, 3, 3) and translation (*, T, 3)\n    \"\"\"\n    T_glob = make_transform(R_glob, t_glob)  # (*, T, 4, 4) root to world\n    T_rel = transform_global_to_rel(T_glob)\n    return T_rel[..., :3, :3], T_rel[..., :3, 3]\n\n\ndef RT_rel_to_global(R_rel, t_rel):\n    \"\"\"\n    :param R_rel (*, T, 3, 3) root at t -> root at t-1 rotation\n    :param t_rel (*, T, 3) root at t -> root at t-1 translation\n    return root to world rotation (*, T, 3, 3) and translation (*, T, 3)\n    \"\"\"\n    T_rel = make_transform(R_rel, t_rel)  # (*, T, 4, 4)\n    T_glob = transform_rel_to_global(T_rel)\n    return T_glob[..., :3, :3], T_glob[..., :3, 3]\n\n\ndef joints_local_to_global(\n    root_orient, trans, joints_loc, use_rel: bool = True, rot_rep: str = \"6d\"\n):\n    \"\"\"\n    convert joints in local coords to global coordinates\n    (X_w - root) = T_wl * (X_l - root)\n    :param trans (*, T, 3)\n    :param root_orient (*, T, *rot_shape)\n    :param joints_loc (*, T, J * 3)\n    :param use_rel (optional bool) if true, root trajectory specified as relative transforms\n    returns global joint locations (*, T, J, 3)\n    \"\"\"\n    root_orient_mat = convert_rotation(root_orient, rot_rep, \"mat\")  # (B, T, 3, 3)\n    T_wl = make_transform(root_orient_mat, trans)\n    if use_rel:  # global translation and orientation are in diffs\n        T_wl = transform_rel_to_global(T_wl)\n\n    joints_loc = joints_loc.reshape(*trans.shape[:-1], -1, 3)\n    root_loc = joints_loc[..., :1, :]\n    return transform_points(T_wl, joints_loc - root_loc) + 
root_loc\n\n\ndef joints_global_to_local(root_orient_mat, trans, joints_glob, joints_vel_glob=None):\n    \"\"\"\n    convert joints in global coords to local coords\n    i.e. smpl output with zero root_orient and trans\n    (X_w - root) = T_wl * (X_l - root)\n    :param trans (*, 3)\n    :param root_orient_mat (*, 3, 3)\n    :param joints_glob (*, J, 3)\n    :param joints_vel_glob (optional) (*, J, 3)\n    returns local joint locations (*, J, 3)\n    \"\"\"\n    T_lw = torch.linalg.inv(make_transform(root_orient_mat, trans))  # (*, 4, 4)\n    root_loc = joints_glob[..., :1, :] - trans.unsqueeze(-2)  # (*, 1, 3)\n    joints_loc = transform_points(T_lw, joints_glob - root_loc) + root_loc\n    joints_vel_loc = None\n    if joints_vel_glob is not None:  # no translation\n        joints_vel_loc = torch.einsum(\n            \"...ij,...nj->...ni\", T_lw[..., :3, :3], joints_vel_glob\n        )  # (*, J, 3)\n    return joints_loc, joints_vel_loc\n\n\ndef align_pcl(Y, X, weight=None, fixed_scale=False):\n    \"\"\"\n    align similarity transform to align X with Y using umeyama method\n    X' = s * R * X + t is aligned with Y\n    :param Y (*, N, 3) first trajectory\n    :param X (*, N, 3) second trajectory\n    :param weight (*, N, 1) optional weight of valid correspondences\n    :returns s (*, 1), R (*, 3, 3), t (*, 3)\n    \"\"\"\n    *dims, N, _ = Y.shape\n    device = X.device\n    N = torch.ones(*dims, 1, 1, device=device) * N\n\n    if weight is not None:\n        N = weight.sum(dim=-2, keepdim=True)  # (*, 1, 1)\n\n    # subtract mean\n    my = Y.sum(dim=-2) / N[..., 0]  # (*, 3)\n    mx = X.sum(dim=-2) / N[..., 0]\n    y0 = Y - my[..., None, :]  # (*, N, 3)\n    x0 = X - mx[..., None, :]\n\n    if weight is not None:\n        y0 = y0 * weight\n        x0 = x0 * weight\n\n    # correlation\n    C = torch.matmul(y0.transpose(-1, -2), x0) / N  # (*, 3, 3)\n    U, D, Vh = torch.linalg.svd(C)  # (*, 3, 3), (*, 3), (*, 3, 3)\n\n    S = torch.eye(3, device=device).reshape(*(1,) * (len(dims)), 3, 3).repeat(*dims, 1, 1)\n    neg = torch.det(U) * torch.det(Vh.transpose(-1, -2)) < 0\n    S[neg, 2, 2] = -1\n\n    R = torch.matmul(U, torch.matmul(S, Vh))  # (*, 3, 3)\n\n    D = torch.diag_embed(D)  # (*, 3, 3)\n    if fixed_scale:\n        s = torch.ones(*dims, 1, device=device, dtype=torch.float32)\n    else:\n        var = torch.sum(torch.square(x0), dim=(-1, -2), keepdim=True) / N  # (*, 1, 1)\n        s = (\n            torch.diagonal(torch.matmul(D, S), dim1=-2, dim2=-1).sum(\n                dim=-1, keepdim=True\n            )\n            / var[..., 0]\n        )  # (*, 1)\n\n    t = my - s * torch.matmul(R, mx[..., None])[..., 0]  # (*, 3)\n\n    return s, R, t\n\n\ndef get_translation_scale(fps=30):\n    \"\"\"\n    scale relative translation into (m/s), over average walking speed ~1.5 m/s\n    i.e. 
scale delta such that average walking speed -> 1\n    \"\"\"\n    return 1.0 * fps\n\n\ndef estimate_velocity(data_seq, h=1 / 30):\n    \"\"\"\n    Given some data sequence of T timesteps in the shape (T, ...), estimates\n    the velocity for the middle T-2 steps using a second order central difference scheme.\n    - h : step size\n    \"\"\"\n    data_tp1 = data_seq[2:]\n    data_tm1 = data_seq[0:-2]\n    data_vel_seq = (data_tp1 - data_tm1) / (2 * h)\n    return data_vel_seq\n\n\ndef estimate_angular_velocity(rot_seq, h=1 / 30):\n    \"\"\"\n    Given a sequence of T rotation matrices, estimates angular velocity at T-2 steps.\n    Input sequence should be of shape (T, ..., 3, 3)\n    \"\"\"\n    # see https://en.wikipedia.org/wiki/Angular_velocity#Calculation_from_the_orientation_matrix\n    dRdt = estimate_velocity(rot_seq, h)\n    R = rot_seq[1:-1]\n    RT = np.swapaxes(R, -1, -2)\n    # compute skew-symmetric angular velocity tensor\n    w_mat = np.matmul(dRdt, RT)\n\n    # pull out angular velocity vector\n    # average symmetric entries\n    w_x = (-w_mat[..., 1, 2] + w_mat[..., 2, 1]) / 2.0\n    w_y = (w_mat[..., 0, 2] - w_mat[..., 2, 0]) / 2.0\n    w_z = (-w_mat[..., 0, 1] + w_mat[..., 1, 0]) / 2.0\n    w = np.stack([w_x, w_y, w_z], axis=-1)\n    return w\n"
  },
  {
    "path": "src/egoallo/preprocessing/geometry/plane.py",
    "content": "from jaxtyping import Float\nfrom typing import Tuple, Optional\nimport torch\nfrom torch import Tensor\nimport torch.nn.functional as F\nfrom .rotation import axis_angle_to_matrix\nfrom .helpers import make_transform\n\n\ndef transform_align_body_right(root_orient_mat, trans, **kwargs):\n    \"\"\"\n    make the transform aligns body_right (-x) with x axis and moves trans to origin\n    :param root_orient_mat (3, 3)\n    :param trans (3,)\n    \"\"\"\n    # move first frame to origin and transform root orient x to [1, 0, 0]\n    R_align_x = rotation_align_body_right(root_orient_mat, **kwargs)  # (3, 3)\n    t_align_x = -R_align_x @ trans\n    return make_transform(R_align_x, t_align_x)\n\n\ndef rotation_align_body_right(\n    root_orient_mat, up=[0.0, 0.0, 1.0], right=[-1.0, 0.0, 0.0], **kwargs\n):\n    \"\"\"\n    compute the rotation that aligns local body right vector (-x)\n    with [1, 0, 0] (+x) via rotation about up axis (+z)\n    :param root_orient_mat (*, 3, 3)\n    :param up vector (*, 3) default [0, 0, 1]\n    :param right vector (*, 3) default [1, 0, 0]\n    returns (*, 3, 3) rotation matrix\n    \"\"\"\n    root_x = -root_orient_mat[..., 0]\n    nldims = root_x.ndim - 1\n    up = torch.as_tensor(up, device=root_x.device)\n    right = torch.as_tensor(right, device=root_x.device)\n    if up.ndim < root_x.ndim:\n        up = up.reshape(*(1,) * nldims, 3)\n    if right.ndim < root_x.ndim:\n        right = right.reshape(*(1,) * nldims, 3)\n\n    # project root_x to floor plane (perpendicular to up)\n    root_x = root_x - project_vector(root_x, up)\n    return rotation_align_vecs(root_x, right)\n\n\ndef compute_world2aligned(T_w0, **kwargs):\n    \"\"\"\n    compute alignment transform to take T_w0 to aligned frame\n    where body right is -x, and up is as specified (default +z)\n    :param T_w0 (*, 4, 4)\n    return (*, 4, 4)\n    \"\"\"\n    R_aw = rotation_align_body_right(T_w0[..., :3, :3], **kwargs)  # (*..., 3, 3)\n    t_aw = torch.einsum(\"...ij,...j->...i\", -R_aw, T_w0[..., :3, 3])\n    T_aw = make_transform(R_aw, t_aw)  # (*, 4, 4)\n    return T_aw\n\n\ndef rotation_align_vecs(src, target):\n    \"\"\"\n    compute rotation taking src to target through the shared plane\n    :param src (*, 3)\n    :param target (*, 3)\n    return (*, 3, 3) rotation matrix\n    \"\"\"\n    axis = F.normalize(torch.linalg.cross(src, target), dim=-1)\n    angle = torch.arccos(\n        (src * target).sum(dim=-1) / (src.norm(dim=-1) * target.norm(dim=-1))\n    )\n    return axis_angle_to_matrix(axis * angle.unsqueeze(-1))\n\n\ndef compute_point_height(point, floor_plane):\n    \"\"\"\n    compute height of point from floor_plane\n    :param point (*, 3)\n    :param floor_plane (*, 3)\n    \"\"\"\n    floor_plane_4d = parse_floor_plane(floor_plane)\n    floor_normal = floor_plane_4d[..., :3]\n    # compute the distance from root to ground plane\n    _, s_root = compute_plane_intersection(point, -floor_normal, floor_plane_4d)\n    return s_root\n\n\ndef compute_world2floor(\n    floor_plane_4d, root_orient_mat, trans\n) -> Tuple[Tensor, Tensor]:\n    \"\"\"\n    compute the transform from world frame (opencv +x right, +y down, +z forward),\n    to floor frame (-x body right, +y up, with origin at trans)\n    :param floor_plane (*, 4) floor plane in world coordinates\n    :param root_orient_mat (*, 3, 3) root orientation in world\n    :param trans (*, 3) root trans in world\n    \"\"\"\n    floor_normal = floor_plane_4d[..., :3]\n\n    # compute prior frame axes in the camera 
frame\n    # right is body +x direction projected to floor plane\n    root_x = root_orient_mat[..., 0]\n    x = F.normalize(root_x - project_vector(root_x, floor_normal), dim=-1)\n    y = floor_normal\n    z = F.normalize(torch.linalg.cross(x, y), dim=-1)\n\n    # floor frame in world is body x right, floor normal up\n    R_wf = torch.stack([x, y, z], dim=-1)\n    R_fw = torch.linalg.inv(R_wf)\n    t_fw = torch.einsum(\"...ij,...j->...i\", -R_fw, trans)\n    return R_fw, t_fw\n\n\ndef compute_plane_transform(\n    plane_4d: Float[Tensor, \"*batch 4\"],\n    up: Float[Tensor, \"*batch 3\"],\n    origin: Optional[Float[Tensor, \"*batch 3\"]] = None,\n):\n    \"\"\"\n    compute the R and t transform from identity, where plane normal is up\n    \"\"\"\n    normal = plane_4d[..., :3]\n    offset = plane_4d[..., 3:]\n    normal = F.normalize(normal, dim=-1)\n    up = F.normalize(up, dim=-1)\n    v = torch.linalg.cross(up, normal)  # (*, 3)\n    vnorm = torch.linalg.norm(v, dim=-1, keepdim=True)  # (*, 1)\n    s = torch.arcsin(vnorm) / vnorm\n    R = axis_angle_to_matrix(v * s)  # (*, 3, 3)\n    if origin is not None:\n        t, _ = compute_plane_intersection(origin, -normal, plane_4d)\n    else:\n        # translate plane along normal vector\n        t = normal * offset  # (*, 3)\n    return R, t\n\n\ndef fit_plane(\n    points: Float[Tensor, \"*batch N 3\"],\n    weights: Optional[Float[Tensor, \"*batch N 1\"]] = None,\n    force_sign: int = -1,\n) -> Float[Tensor, \"*batch 4\"]:\n    \"\"\"\n    :param points (*, N, 3)\n    returns (*, 4) plane parameters (returns in (normal, offset) format)\n    \"\"\"\n    *dims, _ = points.shape\n    device = points.device\n    if weights is None:\n        weights = torch.ones(*dims, 1, device=device)\n\n    mean = (weights * points).sum(dim=-2, keepdim=True) / weights.sum(\n        dim=-2, keepdim=True\n    )\n    # (*, N, 3), (*, 3), (*, 3, 3)\n    _, _, Vh = torch.linalg.svd(weights * (points - mean))\n    normal = Vh[..., -1, :]  # (*, 3)\n    offset = torch.einsum(\"...ij,...j->...i\", points, normal)  # (*, N)\n    w = weights[..., 0]  # (*, N)\n    offset = ((w * offset).sum(dim=-1) / w.sum(dim=-1)).unsqueeze(-1)  # (*, 1)\n    if force_sign != 0:\n        normal, offset = force_plane_direction(normal, offset, sign=force_sign)\n    return torch.cat([normal, offset], dim=-1)\n\n\ndef parse_floor_plane(floor_plane: Tensor, force_sign: int = -1) -> Tensor:\n    \"\"\"\n    Takes floor plane in the optimization form (Bx3 with a,b,c * d) and parses into\n    (a,b,c,d) from with (a,b,c) normal facing \"up in the camera frame and d the offset.\n    \"\"\"\n    if floor_plane.shape[-1] == 4:\n        return floor_plane\n\n    floor_offset = torch.linalg.norm(floor_plane, dim=-1, keepdim=True)\n    floor_normal = floor_plane / (floor_offset + 1e-5)\n\n    # there's ambiguity in the signs of the normal and offset,\n    # force the sign of the normal to be positive or negative depending\n    # on convention\n    if force_sign != 0:\n        floor_normal, floor_offset = force_plane_direction(\n            floor_normal, floor_offset, sign=force_sign\n        )\n\n    return torch.cat([floor_normal, floor_offset], dim=-1)\n\n\ndef force_plane_direction(\n    floor_normal: Float[Tensor, \"*batch 3\"],\n    floor_offset: Float[Tensor, \"*batch 1\"],\n    sign: int = -1,\n) -> Tuple[Float[Tensor, \"*batch 3\"], Float[Tensor, \"*batch 1\"]]:\n    assert sign != 0\n    if sign > 0:\n        mask = floor_normal[..., 1:2] < 0\n    else:\n        mask = 
floor_normal[..., 1:2] > 0\n\n    floor_normal = torch.where(\n        mask.expand_as(floor_normal), -floor_normal, floor_normal\n    )\n    floor_offset = torch.where(mask, -floor_offset, floor_offset)\n    return floor_normal, floor_offset\n\n\ndef compute_plane_intersection(point, direction, plane, eps=1e-5):\n    \"\"\"\n    given a ray defined by a point in space and a direction,\n    compute the intersection point with the given plane.\n    :param point (*, 3)\n    :param direction (*, 3)\n    :param plane (*, 4) (normal, offset)\n    returns:\n        - itsct_pt (*, 3)\n        - s (*, 1) s.t. itsct_pt = point + s * direction\n    \"\"\"\n    plane_normal = plane[..., :3]\n    plane_off = plane[..., 3:]\n    s = (plane_off - bdot(plane_normal, point)) / (bdot(plane_normal, direction) + eps)\n    itsct_pt = point + s * direction\n    return itsct_pt, s\n\n\ndef project_vector(x, d):\n    \"\"\"\n    project x onto d\n    :param x, d (*, 3)\n    \"\"\"\n    d = F.normalize(d, dim=-1)\n    return bdot(x, d) * d\n\n\ndef bdot(A1, A2, keepdim=True, **kwargs):\n    \"\"\"\n    batched dot product\n    :param A1, A2 (*, D)\n    returs (*, 1)\n    \"\"\"\n    return (A1 * A2).sum(dim=-1, keepdim=keepdim, **kwargs)\n"
  },
  {
    "path": "src/egoallo/preprocessing/geometry/rotation.py",
    "content": "from typing import Tuple\nimport torch\nfrom torch.nn import functional as F\n\n\ndef get_rot_rep_shape(rot_rep:str) -> Tuple:\n    assert rot_rep in [\"aa\", \"quat\", \"6d\", \"mat\"]\n    if rot_rep == \"6d\":\n        return (6,)\n    if rot_rep == \"aa\":\n        return (3,)\n    if rot_rep == \"quat\":\n        return (4,)\n    return (3, 3)\n\n\ndef convert_rotation(rot, src_rep, tgt_rep):\n    src_rep, tgt_rep = src_rep.lower(), tgt_rep.lower()\n    if src_rep == tgt_rep:\n        return rot\n\n    if src_rep == \"aa\":\n        if tgt_rep == \"mat\":\n            return axis_angle_to_matrix(rot)\n        if tgt_rep == \"quat\":\n            return axis_angle_to_quaternion(rot)\n        if tgt_rep == \"6d\":\n            return axis_angle_to_cont_6d(rot)\n        raise NotImplementedError\n    if src_rep == \"quat\":\n        if tgt_rep == \"aa\":\n            return quaternion_to_axis_angle(rot)\n        if tgt_rep == \"mat\":\n            return quaternion_to_matrix(rot)\n        if tgt_rep == \"6d\":\n            return matrix_to_cont_6d(quaternion_to_matrix(rot))\n        raise NotImplementedError\n    if src_rep == \"mat\":\n        if tgt_rep == \"6d\":\n            return matrix_to_cont_6d(rot)\n        if tgt_rep == \"aa\":\n            return matrix_to_axis_angle(rot)\n        if tgt_rep == \"quat\":\n            return matrix_to_quaternion(rot)\n        raise NotImplementedError\n    if src_rep == \"6d\":\n        if tgt_rep == \"mat\":\n            return cont_6d_to_matrix(rot)\n        if tgt_rep == \"aa\":\n            return cont_6d_to_axis_angle(rot)\n        if tgt_rep == \"quat\":\n            return matrix_to_quaternion(cont_6d_to_matrix(rot))\n        raise NotImplementedError\n    raise NotImplementedError\n\n\ndef rodrigues_vec_to_matrix(rot_vecs, dtype=torch.float32):\n    \"\"\"\n    Calculates the rotation matrices for a batch of rotation vectors\n    referenced from https://github.com/mkocabas/VIBE/blob/master/lib/utils/geometry.py\n    :param rot_vecs (*, 3) axis-angle vectors\n    :returns rot_mats (*, 3, 3)\n    \"\"\"\n    dims = rot_vecs.shape[:-1]  # leading dimensions\n    device, dtype = rot_vecs.device, rot_vecs.dtype\n\n    angle = torch.norm(rot_vecs + 1e-8, dim=-1, keepdim=True)  # (*, 1)\n    rot_dir = rot_vecs / angle  # (*, 3)\n\n    cos = torch.unsqueeze(torch.cos(angle), dim=-2)  # (*, 1, 1)\n    sin = torch.unsqueeze(torch.sin(angle), dim=-2)  # (*, 1, 1)\n\n    rx, ry, rz = torch.split(rot_dir, 1, dim=-1)  # (*, 1) each\n    zeros = torch.zeros(*dims, 1, dtype=dtype, device=device)\n    K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=-1).view(\n        (*dims, 3, 3)\n    )\n    I = torch.eye(3, dtype=dtype, device=device).reshape(*(1,) * len(dims), 3, 3)\n    rot_mat = I + sin * K + (1 - cos) * torch.einsum(\"...ij,...jk->...ik\", K, K)\n    return rot_mat\n\n\ndef matrix_to_axis_angle(matrix):\n    \"\"\"\n    Convert rotation matrix to Rodrigues vector\n    \"\"\"\n    quaternion = matrix_to_quaternion(matrix)\n    aa = quaternion_to_axis_angle(quaternion)\n    aa[torch.isnan(aa)] = 0.0\n    return aa\n\n\ndef axis_angle_to_matrix(rot_vec):\n    quaternion = axis_angle_to_quaternion(rot_vec)\n    return quaternion_to_matrix(quaternion)\n\n\ndef axis_angle_to_cont_6d(rot_vec):\n    \"\"\"\n    :param rot_vec (*, 3)\n    :returns 6d vector (*, 6)\n    \"\"\"\n    rot_mat = axis_angle_to_matrix(rot_vec)\n    return matrix_to_cont_6d(rot_mat)\n\n\ndef matrix_to_cont_6d(matrix):\n    \"\"\"\n    :param 
matrix (*, 3, 3)\n    :returns 6d vector (*, 6)\n    \"\"\"\n    return torch.cat([matrix[..., 0], matrix[..., 1]], dim=-1)\n\n\ndef cont_6d_to_matrix(cont_6d):\n    \"\"\"\n    :param 6d vector (*, 6)\n    :returns matrix (*, 3, 3)\n    \"\"\"\n    x1 = cont_6d[..., 0:3]\n    y1 = cont_6d[..., 3:6]\n\n    x = F.normalize(x1, dim=-1)\n    y = F.normalize(y1 - (y1 * x).sum(dim=-1, keepdim=True) * x, dim=-1)\n    z = torch.linalg.cross(x, y, dim=-1)\n\n    return torch.stack([x, y, z], dim=-1)\n\n\ndef cont_6d_to_axis_angle(cont_6d):\n    rot_mat = cont_6d_to_matrix(cont_6d)\n    return matrix_to_axis_angle(rot_mat)\n\n\ndef quaternion_to_axis_angle(quaternion, eps=1e-5):\n    \"\"\"\n    This function is borrowed from https://github.com/kornia/kornia\n\n    Convert quaternion vector to angle axis of rotation.\n    Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h\n\n    :param quaternion (*, 4) expects WXYZ\n    :returns axis_angle (*, 3)\n    \"\"\"\n    # unpack input and compute conversion\n    q1 = quaternion[..., 1]\n    q2 = quaternion[..., 2]\n    q3 = quaternion[..., 3]\n    sin_squared_theta = q1 * q1 + q2 * q2 + q3 * q3\n\n    sin_theta = torch.sqrt(sin_squared_theta)\n    cos_theta = quaternion[..., 0]\n    two_theta = 2.0 * torch.where(\n        cos_theta < -eps,\n        torch.atan2(-sin_theta, -cos_theta),\n        torch.atan2(sin_theta, cos_theta),\n    )\n\n    k_pos = two_theta / sin_theta\n    k_neg = 2.0 * torch.ones_like(sin_theta)\n    k = torch.where(sin_squared_theta > eps, k_pos, k_neg)\n\n    axis_angle = torch.zeros_like(quaternion)[..., :3]\n    axis_angle[..., 0] += q1 * k\n    axis_angle[..., 1] += q2 * k\n    axis_angle[..., 2] += q3 * k\n    return axis_angle\n\n\ndef quaternion_to_matrix(quaternion):\n    \"\"\"\n    Convert a quaternion to a rotation matrix.\n    Taken from https://github.com/kornia/kornia, based on\n    https://github.com/matthew-brett/transforms3d/blob/8965c48401d9e8e66b6a8c37c65f2fc200a076fa/transforms3d/quaternions.py#L101\n    https://github.com/tensorflow/graphics/blob/master/tensorflow_graphics/geometry/transformation/rotation_matrix_3d.py#L247\n    :param quaternion (N, 4) expects WXYZ order\n    returns rotation matrix (N, 3, 3)\n    \"\"\"\n    # normalize the input quaternion\n    quaternion_norm = F.normalize(quaternion, p=2, dim=-1, eps=1e-8)\n    *dims, _ = quaternion_norm.shape\n\n    # unpack the normalized quaternion components\n    w, x, y, z = torch.chunk(quaternion_norm, chunks=4, dim=-1)\n\n    # compute the actual conversion\n    tx = 2.0 * x\n    ty = 2.0 * y\n    tz = 2.0 * z\n    twx = tx * w\n    twy = ty * w\n    twz = tz * w\n    txx = tx * x\n    txy = ty * x\n    txz = tz * x\n    tyy = ty * y\n    tyz = tz * y\n    tzz = tz * z\n    one = torch.tensor(1.0)\n\n    matrix = torch.stack(\n        (\n            one - (tyy + tzz),\n            txy - twz,\n            txz + twy,\n            txy + twz,\n            one - (txx + tzz),\n            tyz - twx,\n            txz - twy,\n            tyz + twx,\n            one - (txx + tyy),\n        ),\n        dim=-1,\n    ).view(*dims, 3, 3)\n    return matrix\n\n\ndef axis_angle_to_quaternion(axis_angle, eps=1e-5):\n    \"\"\"\n    This function is borrowed from https://github.com/kornia/kornia\n    Convert angle axis to quaternion in WXYZ order\n    :param axis_angle (*, 3)\n    :returns quaternion (*, 4) WXYZ order\n    \"\"\"\n    theta = torch.linalg.norm(axis_angle, dim=-1, keepdim=True)\n    theta_sq = torch.square(theta)\n    # theta_sq = 
torch.sum(axis_angle ** 2, dim=-1, keepdim=True)  # (*, 1)\n    # theta = torch.sqrt(theta_sq + eps)\n    # need to handle the zero rotation case\n    valid = theta_sq > eps\n    half_theta = 0.5 * theta\n    ones = torch.ones_like(half_theta)\n    # fill zero with the limit of sin ax / x -> a\n    k = torch.where(valid, torch.sin(half_theta) / (theta + eps), 0.5 * ones)\n    w = torch.where(valid, torch.cos(half_theta), ones)\n    quat = torch.cat([w, k * axis_angle], dim=-1)\n    return quat\n\n\ndef matrix_to_quaternion(matrix, eps=1e-6):\n    \"\"\"\n    This function is borrowed from https://github.com/kornia/kornia\n    Convert rotation matrix to 4d quaternion vector\n    This algorithm is based on algorithm described in\n    https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201\n\n    :param matrix (N, 3, 3)\n    \"\"\"\n    *dims, m, n = matrix.shape\n    rmat_t = torch.transpose(matrix.reshape(-1, m, n), -1, -2)\n\n    mask_d2 = rmat_t[:, 2, 2] < eps\n\n    mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1]\n    mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]\n\n    t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]\n    q0 = torch.stack(\n        [\n            rmat_t[:, 1, 2] - rmat_t[:, 2, 1],\n            t0,\n            rmat_t[:, 0, 1] + rmat_t[:, 1, 0],\n            rmat_t[:, 2, 0] + rmat_t[:, 0, 2],\n        ],\n        -1,\n    )\n    t0_rep = t0.repeat(4, 1).t()\n\n    t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]\n    q1 = torch.stack(\n        [\n            rmat_t[:, 2, 0] - rmat_t[:, 0, 2],\n            rmat_t[:, 0, 1] + rmat_t[:, 1, 0],\n            t1,\n            rmat_t[:, 1, 2] + rmat_t[:, 2, 1],\n        ],\n        -1,\n    )\n    t1_rep = t1.repeat(4, 1).t()\n\n    t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]\n    q2 = torch.stack(\n        [\n            rmat_t[:, 0, 1] - rmat_t[:, 1, 0],\n            rmat_t[:, 2, 0] + rmat_t[:, 0, 2],\n            rmat_t[:, 1, 2] + rmat_t[:, 2, 1],\n            t2,\n        ],\n        -1,\n    )\n    t2_rep = t2.repeat(4, 1).t()\n\n    t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]\n    q3 = torch.stack(\n        [\n            t3,\n            rmat_t[:, 1, 2] - rmat_t[:, 2, 1],\n            rmat_t[:, 2, 0] - rmat_t[:, 0, 2],\n            rmat_t[:, 0, 1] - rmat_t[:, 1, 0],\n        ],\n        -1,\n    )\n    t3_rep = t3.repeat(4, 1).t()\n\n    mask_c0 = mask_d2 * mask_d0_d1\n    mask_c1 = mask_d2 * ~mask_d0_d1\n    mask_c2 = ~mask_d2 * mask_d0_nd1\n    mask_c3 = ~mask_d2 * ~mask_d0_nd1\n    mask_c0 = mask_c0.view(-1, 1).type_as(q0)\n    mask_c1 = mask_c1.view(-1, 1).type_as(q1)\n    mask_c2 = mask_c2.view(-1, 1).type_as(q2)\n    mask_c3 = mask_c3.view(-1, 1).type_as(q3)\n\n    q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3\n    q /= torch.sqrt(\n        t0_rep * mask_c0\n        + t1_rep * mask_c1\n        + t2_rep * mask_c2  # noqa\n        + t3_rep * mask_c3\n    )  # noqa\n    q *= 0.5\n    return q.reshape(*dims, 4)\n\n\ndef quaternion_mul(q0, q1):\n    \"\"\"\n    EXPECTS WXYZ\n    :param q0 (*, 4)\n    :param q1 (*, 4)\n    \"\"\"\n    r0, r1 = q0[..., :1], q1[..., :1]\n    v0, v1 = q0[..., 1:], q1[..., 1:]\n    r = r0 * r1 - (v0 * v1).sum(dim=-1, keepdim=True)\n    v = r0 * v1 + r1 * v0 + torch.linalg.cross(v0, v1)\n    return torch.cat([r, v], dim=-1)\n\n\ndef quaternion_inverse(q, eps=1e-5):\n    \"\"\"\n    EXPECTS WXYZ\n    :param q (*, 4)\n    \"\"\"\n    conj = torch.cat([q[..., :1], -q[..., 1:]], dim=-1)\n    
mag = torch.square(q).sum(dim=-1, keepdim=True) + eps\n    return conj / mag\n\n\ndef quaternion_slerp(t, q0, q1, eps=1e-5):\n    \"\"\"\n    :param t (*, 1)  must be between 0 and 1\n    :param q0 (*, 4)\n    :param q1 (*, 4)\n    \"\"\"\n    dims = q0.shape[:-1]\n    t = t.view(*dims, 1)\n\n    q0 = F.normalize(q0, p=2, dim=-1)\n    q1 = F.normalize(q1, p=2, dim=-1)\n    dot = (q0 * q1).sum(dim=-1, keepdim=True)\n\n    # make sure we give the shortest rotation path (< 180d)\n    neg = dot < -eps\n    q1 = torch.where(neg, -q1, q1)\n    dot = torch.where(neg, -dot, dot)\n    angle = torch.acos(dot)\n\n    # if angle is too small, just do linear interpolation\n    collin = torch.abs(dot) > 1 - eps\n    fac = 1 / torch.sin(angle)\n    w0 = torch.where(collin, 1 - t, torch.sin((1 - t) * angle) * fac)\n    w1 = torch.where(collin, t, torch.sin(t * angle) * fac)\n    slerp = q0 * w0 + q1 * w1\n    return slerp\n"
  },
  {
    "path": "src/egoallo/preprocessing/geometry/transforms/__init__.py",
    "content": "\"\"\"Lie group interface for rigid transforms, ported from\n[jaxlie](https://github.com/brentyi/jaxlie). Used by `viser` internally and\nin examples.\n\nImplements SO(2), SO(3), SE(2), and SE(3) Lie groups. Rotations are parameterized\nvia S^1 and S^3.\n\"\"\"\n\nfrom ._base import MatrixLieGroup as MatrixLieGroup\nfrom ._base import SEBase as SEBase\nfrom ._base import SOBase as SOBase\nfrom ._se2 import SE2 as SE2\nfrom ._se3 import SE3 as SE3\nfrom ._so2 import SO2 as SO2\nfrom ._so3 import SO3 as SO3\n"
  },
  {
    "path": "src/egoallo/preprocessing/geometry/transforms/_base.py",
    "content": "import abc\nfrom typing import ClassVar, Generic, Type, TypeVar, Union, overload, Optional, Tuple\n\nimport torch\nfrom typing_extensions import final, override\n\nfrom . import hints\n\nGroupType = TypeVar(\"GroupType\", bound=\"MatrixLieGroup\")\nSEGroupType = TypeVar(\"SEGroupType\", bound=\"SEBase\")\n\n\nclass MatrixLieGroup(abc.ABC):\n    \"\"\"Interface definition for matrix Lie groups.\"\"\"\n\n    # Class properties.\n    # > These will be set in `_utils.register_lie_group()`.\n\n    matrix_dim: ClassVar[int]\n    \"\"\"Dimension of square matrix output from `.matrix()`.\"\"\"\n\n    parameters_dim: ClassVar[int]\n    \"\"\"Dimension of underlying parameters, `.parameters()`.\"\"\"\n\n    tangent_dim: ClassVar[int]\n    \"\"\"Dimension of tangent space.\"\"\"\n\n    space_dim: ClassVar[int]\n    \"\"\"Dimension of coordinates that can be transformed.\"\"\"\n\n    def __init__(self, parameters: torch.Tensor):\n        \"\"\"\n        Construct a group object from its underlying parameters.\n        Notes:\n        - For the constructor signature to be consistent with subclasses, `parameters`\n          should be marked as positional-only. But this isn't possible in Python 3.7.\n        - This method is implicitly overriden by the dataclass decorator and\n          should _not_ be marked abstract.\n        \"\"\"\n        raise NotImplementedError()\n\n    # Shared implementations.\n\n    @overload\n    def __mul__(self: GroupType, other: GroupType) -> GroupType:\n        ...\n\n    @overload\n    def __mul__(self, other: hints.Array) -> torch.Tensor:\n        ...\n\n    def __mul__(\n        self: GroupType, other: Union[GroupType, hints.Array]\n    ) -> Union[GroupType, torch.Tensor]:\n        \"\"\"Overload for the `@` operator.\n\n        Switches between the group action (`.act()`) and multiplication\n        (`.mul()`) based on the type of `other`.\n        \"\"\"\n        if isinstance(other, hints.Array):\n            return self.act(target=other)\n        elif isinstance(other, MatrixLieGroup):\n            assert self.space_dim == other.space_dim\n            return self.mul(other=other)\n        else:\n            assert False, f\"Invalid argument type for `@` operator: {type(other)}\"\n\n    # Factory.\n\n    @classmethod\n    @abc.abstractmethod\n    def Identity(\n        cls: Type[GroupType], shape: Optional[Tuple] = (), **kwargs\n    ) -> GroupType:\n        \"\"\"Returns identity element.\n\n        Returns:\n            Identity element.\n        \"\"\"\n\n    @classmethod\n    @abc.abstractmethod\n    def from_matrix(cls: Type[GroupType], matrix: hints.Array) -> GroupType:\n        \"\"\"Get group member from matrix representation.\n\n        Args:\n            matrix: Matrix representaiton.\n\n        Returns:\n            Group member.\n        \"\"\"\n\n    # Accessors.\n\n    @abc.abstractmethod\n    def matrix(self) -> torch.Tensor:\n        \"\"\"Get transformation as a matrix. 
Homogeneous for SE groups.\"\"\"\n\n    @abc.abstractmethod\n    def parameters(self) -> torch.Tensor:\n        \"\"\"Get underlying representation.\"\"\"\n\n    @property\n    def data(self) -> torch.Tensor:\n        return self.parameters()\n\n    def __getitem__(self, index):\n        return self.__class__(self.data[index])\n\n    @property\n    def shape(self):\n        return self.data.shape\n\n    # Operations.\n\n    @abc.abstractmethod\n    def act(self, target: hints.Array) -> torch.Tensor:\n        \"\"\"Applies group action to a point.\n\n        Args:\n            target: Point to transform.\n\n        Returns:\n            Transformed point.\n        \"\"\"\n\n    @abc.abstractmethod\n    def mul(self: GroupType, other: GroupType) -> GroupType:\n        \"\"\"Composes this transformation with another.\n\n        Returns:\n            self @ other\n        \"\"\"\n\n    @classmethod\n    @abc.abstractmethod\n    def exp(cls: Type[GroupType], tangent: hints.Array) -> GroupType:\n        \"\"\"Computes `expm(wedge(tangent))`.\n\n        Args:\n            tangent: Tangent vector to take the exponential of.\n\n        Returns:\n            Output.\n        \"\"\"\n\n    @abc.abstractmethod\n    def log(self) -> torch.Tensor:\n        \"\"\"Computes `vee(logm(transformation matrix))`.\n\n        Returns:\n            Output. Shape should be `(tangent_dim,)`.\n        \"\"\"\n\n    @abc.abstractmethod\n    def adjoint(self, **kwargs) -> torch.Tensor:\n        \"\"\"Computes the adjoint, which transforms tangent vectors between tangent\n        spaces.\n\n        More precisely, for a transform `GroupType`:\n        ```\n        GroupType @ exp(omega) = exp(Adj_T @ omega) @ GroupType\n        ```\n        used for e.g. transforming twists, wrenches, and Jacobians\n        across different reference frames.\n\n        Returns:\n            Output. 
Shape should be `(tangent_dim, tangent_dim)`.\n        \"\"\"\n\n    @abc.abstractmethod\n    def inv(self: GroupType) -> GroupType:\n        \"\"\"Computes the inverse of our transform.\n\n        Returns:\n            Output.\n        \"\"\"\n\n    @abc.abstractmethod\n    def normalize(self: GroupType) -> GroupType:\n        \"\"\"Normalizes/projects values and returns the result.\n\n        Returns:\n            GroupType: Normalized group member.\n        \"\"\"\n\n\nclass SOBase(MatrixLieGroup):\n    \"\"\"Base class for special orthogonal groups.\"\"\"\n\n\nContainedSOType = TypeVar(\"ContainedSOType\", bound=SOBase)\n\n\nclass SEBase(Generic[ContainedSOType], MatrixLieGroup):\n    \"\"\"Base class for special Euclidean groups.\n\n    Each SE(N) group member contains an SO(N) rotation, as well as an N-dimensional\n    translation vector.\n    \"\"\"\n\n    # SE-specific interface.\n\n    @classmethod\n    @abc.abstractmethod\n    def from_rotation_and_translation(\n        cls: Type[SEGroupType],\n        rotation: ContainedSOType,\n        translation: hints.Array,\n    ) -> SEGroupType:\n        \"\"\"Construct a rigid transform from a rotation and a translation.\n\n        Args:\n            rotation: Rotation term.\n            translation: Translation term.\n\n        Returns:\n            Constructed transformation.\n        \"\"\"\n\n    @final\n    @classmethod\n    def from_rotation(cls: Type[SEGroupType], rotation: ContainedSOType) -> SEGroupType:\n        data = rotation.parameters()\n        return cls.from_rotation_and_translation(\n            rotation=rotation,\n            translation=torch.zeros(\n                *data.shape[:-1], cls.space_dim, dtype=data.dtype, device=data.device\n            ),\n        )\n\n    @classmethod\n    @abc.abstractmethod\n    def from_translation(\n        cls: Type[SEGroupType], translation: torch.Tensor\n    ) -> SEGroupType:\n        \"\"\"Construct a transform from a translation term.\"\"\"\n\n    @abc.abstractmethod\n    def rotation(self) -> ContainedSOType:\n        \"\"\"Returns a transform's rotation term.\"\"\"\n\n    @abc.abstractmethod\n    def translation(self) -> torch.Tensor:\n        \"\"\"Returns a transform's translation term.\"\"\"\n\n    # Overrides.\n\n    @final\n    @override\n    def act(self, target: hints.Array) -> torch.Tensor:\n        \"\"\"\n        Apply the transform to a point, given in Cartesian or homogeneous coordinates.\n        \"\"\"\n        d = self.space_dim\n        if target.shape[-1] == d:\n            return self.rotation().act(target) + self.translation()  # type: ignore\n\n        # homogeneous point\n        assert target.shape[-1] == d + 1\n        X, W = torch.split(target, [d, 1], dim=-1)  # (*, d), (*, 1)\n        Xp = self.rotation().act(X) + W * self.translation()\n        return torch.cat([Xp, W], dim=-1)\n\n    @final\n    @override\n    def mul(self: SEGroupType, other: SEGroupType) -> SEGroupType:\n        return type(self).from_rotation_and_translation(\n            rotation=self.rotation().mul(other.rotation()),\n            translation=self.rotation().act(other.translation()) + self.translation(),\n        )\n\n    @final\n    @override\n    def inv(self: SEGroupType) -> SEGroupType:\n        R_inv = self.rotation().inv()\n        return type(self).from_rotation_and_translation(\n            rotation=R_inv,\n            translation=-R_inv.act(self.translation()),\n        )\n\n    @final\n    @override\n    def normalize(self: SEGroupType) -> SEGroupType:\n        return type(self).from_rotation_and_translation(\n            rotation=self.rotation().normalize(),\n            translation=self.translation(),\n        )\n"
  },
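  {
    "path": "examples/transforms_interface_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the original repository: the path, the\n`interpolate` helper, and the numbers below are made up. It shows how the shared\n`MatrixLieGroup` interface (`exp`, `log`, `mul`, `inv`) lets generic helpers work\nfor any of the SO(2)/SO(3)/SE(2)/SE(3) groups defined in this package.\n\"\"\"\n\nfrom typing import TypeVar\n\nimport torch\n\nfrom egoallo.preprocessing.geometry.transforms import MatrixLieGroup, SO2\n\nGroupType = TypeVar(\"GroupType\", bound=MatrixLieGroup)\n\n\ndef interpolate(T0: GroupType, T1: GroupType, alpha: float) -> GroupType:\n    \"\"\"Geodesic interpolation: T0 * exp(alpha * log(T0^-1 * T1)).\"\"\"\n    return T0.mul(type(T0).exp(alpha * T0.inv().mul(T1).log()))\n\n\n# Halfway between two planar rotations; the same call works for SE3, etc.\nhalfway = interpolate(\n    SO2.from_radians(torch.tensor(0.3)),\n    SO2.from_radians(torch.tensor(1.1)),\n    alpha=0.5,\n)\nprint(halfway.as_radians())  # ~0.7\n"
  },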
  {
    "path": "src/egoallo/preprocessing/geometry/transforms/_se2.py",
    "content": "import dataclasses\nfrom typing import Optional, Tuple\n\nimport torch\nimport numpy as onp\nfrom typing_extensions import override\n\nfrom . import _base, hints\nfrom ._so2 import SO2\nfrom .utils import get_epsilon, register_lie_group\n\n\n@register_lie_group(\n    matrix_dim=3,\n    parameters_dim=4,\n    tangent_dim=3,\n    space_dim=2,\n)\n@dataclasses.dataclass\nclass SE2(_base.SEBase[SO2]):\n    \"\"\"Special Euclidean group for proper rigid transforms in 2D.\n\n    Ported to pytorch from `jaxlie.SE2`.\n\n    Internal parameterization is `(cos, sin, x, y)`. Tangent parameterization is `(vx,\n    vy, omega)`.\n    \"\"\"\n\n    # SE2-specific.\n\n    unit_complex_xy: torch.Tensor\n    \"\"\"Internal parameters. `(cos, sin, x, y)`.\"\"\"\n\n    @override\n    def __repr__(self) -> str:\n        unit_complex = torch.round(self.unit_complex_xy[..., :2], decimals=5)\n        xy = torch.round(self.unit_complex_xy[..., 2:], decimals=5)\n        return f\"{self.__class__.__name__}(unit_complex={unit_complex}, xy={xy})\"\n\n    @staticmethod\n    def from_xy_theta(x: hints.Scalar, y: hints.Scalar, theta: hints.Scalar) -> \"SE2\":\n        \"\"\"Construct a transformation from standard 2D pose parameters.\n\n        Note that this is not the same as integrating over a length-3 twist.\n        \"\"\"\n        cos = torch.cos(torch.as_tensor(theta))\n        sin = torch.sin(torch.as_tensor(theta))\n        x, y = torch.as_tensor(x), torch.as_tensor(y)\n        return SE2(unit_complex_xy=torch.stack([cos, sin, x, y], dim=-1))\n\n    # SE-specific.\n\n    @staticmethod\n    @override\n    def from_rotation_and_translation(\n        rotation: SO2,\n        translation: hints.Array,\n    ) -> \"SE2\":\n        assert translation.shape[-1] == 2\n        return SE2(\n            unit_complex_xy=torch.cat([rotation.unit_complex, translation], dim=-1)\n        )\n\n    @override\n    @classmethod\n    def from_translation(cls, translation: torch.Tensor) -> \"SE2\":\n        return SE2.from_rotation_and_translation(\n            SO2.Identity(\n                shape=translation.shape[:-1],\n                dtype=translation.dtype,\n                device=translation.device,\n            ),\n            translation,\n        )\n\n    @override\n    def rotation(self) -> SO2:\n        return SO2(unit_complex=self.unit_complex_xy[..., :2])\n\n    @override\n    def translation(self) -> torch.Tensor:\n        return self.unit_complex_xy[..., 2:]\n\n    # Factory.\n\n    @staticmethod\n    @override\n    def Identity(shape: Optional[Tuple] = (), **kwargs) -> \"SE2\":\n        id_elem = (\n            torch.tensor([1.0, 0.0, 0.0, 0.0], **kwargs)\n            .reshape(*(1,) * len(shape), 4)\n            .repeat(*shape, 1)\n        )\n        return SE2(unit_complex_xy=id_elem)\n\n    @staticmethod\n    @override\n    def from_matrix(matrix: hints.Array) -> \"SE2\":\n        assert matrix.shape[-2:] == (3, 3)\n        # Currently assumes bottom row is [0, 0, 1].\n        return SE2.from_rotation_and_translation(\n            rotation=SO2.from_matrix(matrix[..., :2, :2]),\n            translation=matrix[..., :2, 2],\n        )\n\n    # Accessors.\n\n    @override\n    def parameters(self) -> torch.Tensor:\n        return self.unit_complex_xy\n\n    @override\n    def matrix(self) -> torch.Tensor:\n        cos, sin, x, y = self.unit_complex_xy.unbind(dim=-1)\n        zero = torch.zeros_like(x)\n        one = torch.ones_like(x)\n        return torch.stack(\n            [cos, -sin, x, sin, cos, y, 
zero, zero, one], dim=-1\n        ).reshape(*cos.shape, 3, 3)\n\n    # Operations.\n\n    @staticmethod\n    @override\n    def exp(tangent: hints.Array) -> \"SE2\":\n        # Reference:\n        # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se2.hpp#L558\n        # Also see:\n        # > http://ethaneade.com/lie.pdf\n\n        assert tangent.shape[-1] == 3\n\n        theta = tangent[..., 2]\n\n        # transform the translation vector\n        use_taylor = torch.abs(theta) < get_epsilon(tangent.dtype)\n        safe_theta = torch.where(\n            use_taylor,\n            torch.ones_like(theta),  # Any non-zero value should do here.\n            theta,\n        )\n\n        theta_sq = theta ** 2\n        sin_over_theta = torch.where(\n            use_taylor,\n            1.0 - theta_sq / 6.0,\n            torch.sin(safe_theta) / safe_theta,\n        )\n        one_minus_cos_over_theta = torch.where(\n            use_taylor,\n            0.5 * theta - theta * theta_sq / 24.0,\n            (1.0 - torch.cos(safe_theta)) / safe_theta,\n        )\n\n        V = torch.stack(\n            [\n                sin_over_theta,\n                -one_minus_cos_over_theta,\n                one_minus_cos_over_theta,\n                sin_over_theta,\n            ],\n            dim=-1,\n        ).reshape(*theta.shape, 2, 2)\n\n        return SE2.from_rotation_and_translation(\n            rotation=SO2.from_radians(theta),\n            translation=torch.einsum(\"...ij,...j->...i\", V, tangent[..., :2]),\n        )\n\n    @override\n    def log(self) -> torch.Tensor:\n        # Reference:\n        # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se2.hpp#L160\n        # Also see:\n        # > http://ethaneade.com/lie.pdf\n\n        theta = self.rotation().log()[..., 0]\n\n        cos = torch.cos(theta)\n        cos_minus_one = cos - 1.0\n        half_theta = theta / 2.0\n        use_taylor = torch.abs(cos_minus_one) < get_epsilon(theta.dtype)\n\n        safe_cos_minus_one = torch.where(\n            use_taylor,\n            torch.ones_like(cos_minus_one),  # Any non-zero value should do here.\n            cos_minus_one,\n        )\n\n        half_theta_over_tan_half_theta = torch.where(\n            use_taylor,\n            # Taylor approximation.\n            1.0 - theta ** 2 / 12.0,\n            # Default.\n            -(half_theta * torch.sin(theta)) / safe_cos_minus_one,\n        )\n\n        V_inv = torch.stack(\n            [\n                half_theta_over_tan_half_theta,\n                half_theta,\n                -half_theta,\n                half_theta_over_tan_half_theta,\n            ],\n            dim=-1,\n        ).reshape(*theta.shape, 2, 2)\n\n        tangent = torch.cat(\n            [\n                torch.einsum(\"...ij,...j->...i\", V_inv, self.translation()),\n                theta[..., None],\n            ],\n            dim=-1,\n        )\n        return tangent\n\n    @override\n    def adjoint(self, **kwargs) -> torch.Tensor:\n        cos, sin, x, y = self.unit_complex_xy.unbind(dim=-1)\n        zero = torch.zeros_like(x)\n        one = torch.ones_like(x)\n        return torch.stack(\n            [\n                cos,\n                -sin,\n                y,\n                sin,\n                cos,\n                -x,\n                zero,\n                zero,\n                one,\n            ],\n            dim=-1,\n        ).reshape(*x.shape, 3, 3)\n"
  },
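  {
    "path": "examples/transforms_se2_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the original repository: the path, frame\nnames, and numbers below are made up. It shows how a planar rigid transform is\nbuilt from `(x, y, theta)`, applied to points, composed, and inverted.\n\"\"\"\n\nimport torch\n\nfrom egoallo.preprocessing.geometry.transforms import SE2\n\n# Build a transform from standard 2D pose parameters.\nT_world_robot = SE2.from_xy_theta(1.0, 2.0, 0.5)\n\n# Apply to a batch of 2D points; `*` dispatches to `.act()` for tensors.\npoints_robot = torch.randn(10, 2)\npoints_world = T_world_robot * points_robot\n\n# Composition and inverse; `*` dispatches to `.mul()` for group members.\nT_robot_world = T_world_robot.inv()\nassert torch.allclose(\n    (T_robot_world * T_world_robot).matrix(),\n    SE2.Identity().matrix(),\n    atol=1e-6,\n)\n"
  },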
  {
    "path": "src/egoallo/preprocessing/geometry/transforms/_se3.py",
    "content": "from __future__ import annotations\n\nimport dataclasses\nfrom typing import Optional, Tuple\n\nimport torch\nfrom typing_extensions import override\n\nfrom . import _base\nfrom ._so3 import SO3\nfrom .utils import get_epsilon, register_lie_group\n\n\ndef _skew(omega: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Returns the skew-symmetric form of a length-3 vector.\n    :param omega (*, 3)\n    :returns (*, 3, 3)\n    \"\"\"\n\n    wx, wy, wz = omega.unbind(dim=-1)\n    o = torch.zeros_like(wx)\n    return torch.stack(\n        [o, -wz, wy, wz, o, -wx, -wy, wx, o],\n        dim=-1,\n    ).reshape(*wx.shape, 3, 3)\n\n\n@register_lie_group(\n    matrix_dim=4,\n    parameters_dim=7,\n    tangent_dim=6,\n    space_dim=3,\n)\n@dataclasses.dataclass\nclass SE3(_base.SEBase[SO3]):\n    \"\"\"Special Euclidean group for proper rigid transforms in 3D.\n\n    Ported to pytorch from `jaxlie.SE3`.\n\n    Internal parameterization is `(qw, qx, qy, qz, x, y, z)`. Tangent parameterization\n    is `(vx, vy, vz, omega_x, omega_y, omega_z)`.\n    \"\"\"\n\n    # SE3-specific.\n\n    wxyz_xyz: torch.Tensor\n    \"\"\"Internal parameters. wxyz quaternion followed by xyz translation.\"\"\"\n\n    @override\n    def __repr__(self) -> str:\n        quat = torch.round(self.wxyz_xyz[..., :4], decimals=5)\n        trans = torch.round(self.wxyz_xyz[..., 4:], decimals=5)\n        return f\"{self.__class__.__name__}(wxyz={quat}, xyz={trans})\"\n\n    # SE-specific.\n\n    @staticmethod\n    @override\n    def from_rotation_and_translation(\n        rotation: SO3,\n        translation: torch.Tensor,\n    ) -> \"SE3\":\n        assert translation.shape[-1] == 3\n        return SE3(wxyz_xyz=torch.cat([rotation.wxyz, translation], dim=-1))\n\n    @override\n    @classmethod\n    def from_translation(cls, translation: torch.Tensor) -> \"SE3\":\n        return SE3.from_rotation_and_translation(\n            SO3.Identity(\n                shape=translation.shape[:-1],\n                dtype=translation.dtype,\n                device=translation.device,\n            ),\n            translation,\n        )\n\n    @override\n    def rotation(self) -> SO3:\n        return SO3(wxyz=self.wxyz_xyz[..., :4])\n\n    @override\n    def translation(self) -> torch.Tensor:\n        return self.wxyz_xyz[..., 4:]\n\n    # Factory.\n\n    @staticmethod\n    @override\n    def Identity(shape: Optional[Tuple] = (), **kwargs) -> \"SE3\":\n        id_elem = (\n            torch.tensor([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], **kwargs)\n            .reshape(*(1,) * len(shape), 7)\n            .repeat(*shape, 1)\n        )\n        return SE3(wxyz_xyz=id_elem)\n\n    @staticmethod\n    @override\n    def from_matrix(matrix: torch.Tensor) -> \"SE3\":\n        assert matrix.shape[-2:] == (4, 4)\n        # Currently assumes bottom row is [0, 0, 0, 1].\n        return SE3.from_rotation_and_translation(\n            rotation=SO3.from_matrix(matrix[..., :3, :3]),\n            translation=matrix[..., :3, 3],\n        )\n\n    # Accessors.\n\n    @override\n    def matrix(self) -> torch.Tensor:\n        R = self.rotation().matrix()  # (*, 3, 3)\n        t = self.translation().unsqueeze(-1)  # (*, 3, 1)\n        dims = R.shape[:-2]\n        bottom = (\n            torch.tensor([0, 0, 0, 1], dtype=R.dtype, device=R.device)\n            .reshape(*(1,) * len(dims), 1, 4)\n            .repeat(*dims, 1, 1)\n        )\n        return torch.cat([torch.cat([R, t], dim=-1), bottom], dim=-2)\n\n    @override\n    def parameters(self) -> 
torch.Tensor:\n        return self.wxyz_xyz\n\n    # Operations.\n\n    @staticmethod\n    @override\n    def exp(tangent: torch.Tensor) -> \"SE3\":\n        \"\"\"\n        :param tangent (*, 6)\n        \"\"\"\n        # Reference:\n        # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se3.hpp#L761\n\n        # (x, y, z, omega_x, omega_y, omega_z)\n        *dims, d = tangent.shape\n        assert d == 6\n\n        trans, omega = torch.split(tangent, [3, 3], dim=-1)  # (*, 3), (*, 3)\n\n        rotation = SO3.exp(omega)  # SO3 object with batch shape (*,)\n        theta_squared = torch.square(omega).sum(dim=-1)  # (*)\n        use_taylor = theta_squared < get_epsilon(theta_squared.dtype)\n\n        theta_squared_safe = torch.where(\n            use_taylor,\n            torch.ones_like(theta_squared),  # Any non-zero value should do here.\n            theta_squared,\n        )\n        del theta_squared\n        theta_safe = torch.sqrt(theta_squared_safe)\n\n        skew_omega = _skew(omega)  # (*, 3, 3)\n        I = (\n            torch.eye(3, dtype=omega.dtype, device=omega.device)\n            .reshape(*(1,) * len(dims), 3, 3)\n            .expand(*dims, 3, 3)\n        )\n        f1 = (1.0 - torch.cos(theta_safe)) / (theta_squared_safe)\n        f2 = (theta_safe - torch.sin(theta_safe)) / (theta_squared_safe * theta_safe)\n        V = torch.where(\n            use_taylor[..., None, None],\n            rotation.matrix(),\n            I\n            + f1[..., None, None] * skew_omega\n            + f2[..., None, None] * torch.matmul(skew_omega, skew_omega),\n        )\n        return SE3.from_rotation_and_translation(\n            rotation=rotation,\n            translation=torch.einsum(\"...ij,...j->...i\", V, trans),\n        )\n\n    @override\n    def log(self) -> torch.Tensor:\n        # Reference:\n        # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se3.hpp#L223\n        omega = self.rotation().log()  # (*, 3)\n        theta_squared = torch.square(omega).sum(dim=-1)  # (*)\n        use_taylor = theta_squared < get_epsilon(theta_squared.dtype)\n\n        theta_squared_safe = torch.where(\n            use_taylor,\n            torch.ones_like(theta_squared),  # Any non-zero value should do here.\n            theta_squared,\n        )\n        del theta_squared\n        theta_safe = torch.sqrt(theta_squared_safe)\n        half_theta_safe = theta_safe / 2.0\n\n        dims = omega.shape[:-1]\n        skew_omega = _skew(omega)  # (*, 3, 3)\n        skew_omega_sq = torch.matmul(skew_omega, skew_omega)\n        I = torch.eye(3, dtype=omega.dtype, device=omega.device).reshape(\n            *(1,) * len(dims), 3, 3\n        )\n        f2 = (\n            1.0\n            - theta_safe\n            * torch.cos(half_theta_safe)\n            / (2.0 * torch.sin(half_theta_safe))\n        ) / theta_squared_safe\n\n        V_inv = torch.where(\n            use_taylor[..., None, None],\n            I - 0.5 * skew_omega + skew_omega_sq / 12.0,\n            I - 0.5 * skew_omega + f2[..., None, None] * skew_omega_sq,\n        )\n        return torch.cat(\n            [torch.einsum(\"...ij,...j->...i\", V_inv, self.translation()), omega], dim=-1\n        )\n\n    @override\n    def adjoint(self) -> torch.Tensor:\n        R = self.rotation().matrix()\n        dims = R.shape[:-2]\n        # (*, 6, 6)\n        return torch.cat(\n            [\n                torch.cat([R, torch.matmul(_skew(self.translation()), R)], dim=-1),\n                torch.cat(\n                    [torch.zeros((*dims, 3, 3), dtype=R.dtype, device=R.device), R],\n                    dim=-1,\n                ),\n            ],\n            dim=-2,\n        )\n"
  },
  {
    "path": "src/egoallo/preprocessing/geometry/transforms/_so2.py",
    "content": "from __future__ import annotations\n\nimport dataclasses\nfrom typing import Optional, Tuple\n\nimport torch\nfrom typing_extensions import override\n\nfrom . import _base, hints\nfrom .utils import register_lie_group\n\n\n@register_lie_group(\n    matrix_dim=2,\n    parameters_dim=2,\n    tangent_dim=1,\n    space_dim=2,\n)\n@dataclasses.dataclass\nclass SO2(_base.SOBase):\n    \"\"\"Special orthogonal group for 2D rotations.\n\n    Ported to pytorch from `jaxlie.SO2`.\n\n    Internal parameterization is `(cos, sin)`. Tangent parameterization is `(omega,)`.\n    \"\"\"\n\n    # SO2-specific.\n\n    unit_complex: torch.Tensor\n    \"\"\"Internal parameters. `(cos, sin)`.\"\"\"\n\n    @override\n    def __repr__(self) -> str:\n        unit_complex = torch.round(self.unit_complex, 5)\n        return f\"{self.__class__.__name__}(unit_complex={unit_complex})\"\n\n    @staticmethod\n    def from_radians(theta: hints.Scalar) -> SO2:\n        \"\"\"Construct a rotation object from a scalar angle.\"\"\"\n        theta = torch.as_tensor(theta)\n        cos = torch.cos(theta)\n        sin = torch.sin(theta)\n        return SO2(unit_complex=torch.stack([cos, sin], dim=-1))\n\n    def as_radians(self) -> torch.Tensor:\n        \"\"\"Compute a scalar angle from a rotation object.\"\"\"\n        radians = self.log()[..., 0]\n        return radians\n\n    # Factory.\n\n    @staticmethod\n    @override\n    def Identity(shape: Optional[Tuple] = (), **kwargs) -> SO2:\n        id_elem = (\n            torch.tensor([1.0, 0.0], **kwargs)\n            .reshape(*(1,) * len(shape), 2)\n            .repeat(*shape, 1)\n        )\n        return SO2(unit_complex=id_elem)\n\n    @staticmethod\n    @override\n    def from_matrix(matrix: torch.Tensor) -> SO2:\n        assert matrix.shape[-2:] == (2, 2)\n        return SO2(unit_complex=matrix[..., 0])\n\n    # Accessors.\n\n    @override\n    def matrix(self) -> torch.Tensor:\n        \"\"\"\n        [[cos, -sin], [sin, cos]]\n        :returns (*, 2, 2) tensor\n        \"\"\"\n        cos, sin = self.unit_complex.unbind(dim=-1)\n        return torch.stack([cos, -sin, sin, cos], dim=-1).reshape(*cos.shape, 2, 2)\n\n    @override\n    def parameters(self) -> torch.Tensor:\n        return self.unit_complex\n\n    # Operations.\n\n    @override\n    def act(self, target: torch.Tensor) -> torch.Tensor:\n        assert target.shape[-1] == 2\n        return torch.einsum(\"...ij,...j->...i\", self.matrix(), target)\n\n    @override\n    def mul(self, other: SO2) -> SO2:\n        return SO2(\n            unit_complex=torch.einsum(\n                \"...ij,...j->...i\", self.matrix(), other.unit_complex\n            )\n        )\n\n    @staticmethod\n    @override\n    def exp(tangent: torch.Tensor) -> SO2:\n        return SO2(\n            unit_complex=torch.stack([torch.cos(tangent), torch.sin(tangent)], dim=-1)\n        )\n\n    @override\n    def log(self) -> torch.Tensor:\n        return torch.atan2(\n            self.unit_complex[..., 1, None], self.unit_complex[..., 0, None]\n        )\n\n    @override\n    def adjoint(self, **kwargs) -> torch.Tensor:\n        return torch.eye(1, **kwargs)\n\n    @override\n    def inv(self) -> SO2:\n        cos, sin = self.unit_complex.unbind(dim=-1)\n        return SO2(unit_complex=torch.stack([cos, -sin], dim=-1))\n\n    @override\n    def normalize(self) -> SO2:\n        return SO2(\n            unit_complex=self.unit_complex\n            / torch.linalg.norm(self.unit_complex, dim=-1, keepdim=True)\n        )\n"
  },
  {
    "path": "src/egoallo/preprocessing/geometry/transforms/_so3.py",
    "content": "from __future__ import annotations\n\nfrom typing import Optional, Tuple\nimport dataclasses\n\nimport math\nimport torch\nfrom typing_extensions import override\n\nfrom . import _base, hints\nfrom .utils import get_epsilon, register_lie_group\n\n\n@register_lie_group(\n    matrix_dim=3,\n    parameters_dim=4,\n    tangent_dim=3,\n    space_dim=3,\n)\n@dataclasses.dataclass\nclass SO3(_base.SOBase):\n    \"\"\"Special orthogonal group for 3D rotations.\n\n    Ported to pytorch from `jaxlie.SO3`.\n\n    Internal parameterization is `(qw, qx, qy, qz)`. Tangent parameterization is\n    `(omega_x, omega_y, omega_z)`.\n    \"\"\"\n\n    # SO3-specific.\n\n    wxyz: torch.Tensor\n    \"\"\"Internal parameters. `(w, x, y, z)` quaternion.\"\"\"\n\n    @override\n    def __repr__(self) -> str:\n        wxyz = torch.round(self.wxyz, decimals=5)\n        return f\"{self.__class__.__name__}(wxyz={wxyz})\"\n\n    @staticmethod\n    def from_x_radians(theta: torch.Tensor) -> SO3:\n        \"\"\"\n        Generates a x-axis rotation.\n        :param theta (tensor) x rotation\n        :returns SO3 object\n        \"\"\"\n        zero = torch.zeros_like(theta)\n        return SO3.exp(torch.stack([theta, zero, zero], dim=-1))\n\n    @staticmethod\n    def from_y_radians(theta: torch.Tensor) -> SO3:\n        \"\"\"\n        Generates a y-axis rotation.\n        :param theta (tensor) y rotation\n        :returns SO3 object\n        \"\"\"\n        zero = torch.zeros_like(theta)\n        return SO3.exp(torch.stack([zero, theta, zero], dim=-1))\n\n    @staticmethod\n    def from_z_radians(theta: torch.Tensor) -> SO3:\n        \"\"\"\n        Generates a z-axis rotation.\n        :param theta (tensor) z rotation\n        :returns SO3 object\n        \"\"\"\n        zero = torch.zeros_like(theta)\n        return SO3.exp(torch.stack([zero, zero, theta], dim=-1))\n\n    @staticmethod\n    def from_rpy_radians(\n        roll: torch.Tensor,\n        pitch: torch.Tensor,\n        yaw: torch.Tensor,\n    ) -> SO3:\n        \"\"\"\n        Generates a transform from a set of Euler angles. Uses the ZYX convention.\n        Args:\n            roll: X rotation, in radians. Applied first.\n            pitch: Y rotation, in radians. Applied second.\n            yaw: Z rotation, in radians. Applied last.\n        \"\"\"\n        Rz = SO3.from_z_radians(yaw)\n        Ry = SO3.from_y_radians(pitch)\n        Rx = SO3.from_x_radians(roll)\n        return Rz.mul(Ry.mul(Rx))\n\n    @staticmethod\n    def from_quaternion_xyzw(xyzw: torch.Tensor) -> SO3:\n        \"\"\"\n        Construct a rotation from an `xyzw` quaternion.\n        Note that `wxyz` quaternions can be constructed using the default dataclass\n        constructor.\n        :param xyzw (*, 4) quat in xyzw convention\n        :returns SO3 object\n        \"\"\"\n        assert xyzw.shape[-1] == 4\n        return SO3(torch.roll(xyzw, shift=1, dims=-1))\n\n    def as_quaternion_xyzw(self) -> torch.Tensor:\n        \"\"\"Grab parameters as xyzw quaternion.\"\"\"\n        return torch.roll(self.wxyz, shift=-1, dims=-1)\n\n    def as_rpy_radians(self) -> hints.RollPitchYaw:\n        \"\"\"\n        Computes roll, pitch, and yaw angles. 
Uses the ZYX convention.\n        Returns:\n            Named tuple containing Euler angles in radians.\n        \"\"\"\n        return hints.RollPitchYaw(\n            roll=self.compute_roll_radians(),\n            pitch=self.compute_pitch_radians(),\n            yaw=self.compute_yaw_radians(),\n        )\n\n    def compute_roll_radians(self) -> torch.Tensor:\n        \"\"\"\n        Compute roll angle. Uses the ZYX convention.\n        :returns angle (*) if wxyz is (*, 4)\n        \"\"\"\n        # https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles#Quaternion_to_Euler_angles_conversion\n        q0, q1, q2, q3 = self.wxyz.unbind(dim=-1)\n        return torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 ** 2 + q2 ** 2))\n\n    def compute_pitch_radians(self) -> torch.Tensor:\n        \"\"\"\n        Compute pitch angle. Uses the ZYX convention.\n        :returns angle (*) if wxyz is (*, 4)\n        \"\"\"\n        # https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles#Quaternion_to_Euler_angles_conversion\n        q0, q1, q2, q3 = self.wxyz.unbind(dim=-1)\n        return torch.asin(2 * (q0 * q2 - q3 * q1))\n\n    def compute_yaw_radians(self) -> torch.Tensor:\n        \"\"\"\n        Compute yaw angle. Uses the ZYX convention.\n        :returns angle (*) if wxyz is (*, 4)\n        \"\"\"\n        # https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles#Quaternion_to_Euler_angles_conversion\n        q0, q1, q2, q3 = self.wxyz.unbind(dim=-1)\n        return torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 ** 2 + q3 ** 2))\n\n    # Factory.\n\n    @staticmethod\n    @override\n    def Identity(shape: Optional[Tuple] = (), **kwargs) -> SO3:\n        id_elem = (\n            torch.tensor([1.0, 0.0, 0.0, 0.0], **kwargs)\n            .reshape(*(1,) * len(shape), 4)\n            .repeat(*shape, 1)\n        )\n        return SO3(wxyz=id_elem)\n\n    @staticmethod\n    @override\n    def from_matrix(matrix: torch.Tensor) -> SO3:\n        assert matrix.shape[-2:] == (3, 3)\n\n        # Modified from:\n        # > \"Converting a Rotation Matrix to a Quaternion\" from Mike Day\n        # > https://d3cw3dd2w32x2b.cloudfront.net/wp-content/uploads/2015/01/matrix-to-quat.pdf\n\n        def case0(m):\n            t = 1 + m[..., 0, 0] - m[..., 1, 1] - m[..., 2, 2]\n            q = torch.stack(\n                [\n                    m[..., 2, 1] - m[..., 1, 2],\n                    t,\n                    m[..., 1, 0] + m[..., 0, 1],\n                    m[..., 0, 2] + m[..., 2, 0],\n                ],\n                dim=-1,\n            )\n            return t, q\n\n        def case1(m):\n            t = 1 - m[..., 0, 0] + m[..., 1, 1] - m[..., 2, 2]\n            q = torch.stack(\n                [\n                    m[..., 0, 2] - m[..., 2, 0],\n                    m[..., 1, 0] + m[..., 0, 1],\n                    t,\n                    m[..., 2, 1] + m[..., 1, 2],\n                ],\n                dim=-1,\n            )\n            return t, q\n\n        def case2(m):\n            t = 1 - m[..., 0, 0] - m[..., 1, 1] + m[..., 2, 2]\n            q = torch.stack(\n                [\n                    m[..., 1, 0] - m[..., 0, 1],\n                    m[..., 0, 2] + m[..., 2, 0],\n                    m[..., 2, 1] + m[..., 1, 2],\n                    t,\n                ],\n                dim=-1,\n            )\n            return t, q\n\n        def case3(m):\n            t = 1 + m[..., 0, 0] + m[..., 1, 1] + m[..., 2, 2]\n   
         q = torch.stack(\n                [\n                    t,\n                    m[..., 2, 1] - m[..., 1, 2],\n                    m[..., 0, 2] - m[..., 2, 0],\n                    m[..., 1, 0] - m[..., 0, 1],\n                ],\n                dim=-1,\n            )\n            return t, q\n\n        # Compute four cases, then pick the most precise one.\n        # Probably worth revisiting this!\n        case0_t, case0_q = case0(matrix)\n        case1_t, case1_q = case1(matrix)\n        case2_t, case2_q = case2(matrix)\n        case3_t, case3_q = case3(matrix)\n\n        cond0 = matrix[..., 2, 2] < 0\n        cond1 = matrix[..., 0, 0] > matrix[..., 1, 1]\n        cond2 = matrix[..., 0, 0] < -matrix[..., 1, 1]\n\n        t = torch.where(\n            cond0,\n            torch.where(cond1, case0_t, case1_t),\n            torch.where(cond2, case2_t, case3_t),\n        ).unsqueeze(-1)\n        q = torch.where(\n            cond0.unsqueeze(-1),\n            torch.where(cond1.unsqueeze(-1), case0_q, case1_q),\n            torch.where(cond2.unsqueeze(-1), case2_q, case3_q),\n        )\n\n        return SO3(wxyz=q * 0.5 / torch.sqrt(t))\n\n    # Accessors.\n\n    @override\n    def matrix(self) -> torch.Tensor:\n        norm_sq = torch.square(self.wxyz).sum(dim=-1, keepdim=True)\n        qvec = self.wxyz * torch.sqrt(2.0 / norm_sq)  # (*, 4)\n        Q = torch.einsum(\"...i,...j->...ij\", qvec, qvec)  # (*, 4, 4)\n        return torch.stack(\n            [\n                1.0 - Q[..., 2, 2] - Q[..., 3, 3],\n                Q[..., 1, 2] - Q[..., 3, 0],\n                Q[..., 1, 3] + Q[..., 2, 0],\n                Q[..., 1, 2] + Q[..., 3, 0],\n                1.0 - Q[..., 1, 1] - Q[..., 3, 3],\n                Q[..., 2, 3] - Q[..., 1, 0],\n                Q[..., 1, 3] - Q[..., 2, 0],\n                Q[..., 2, 3] + Q[..., 1, 0],\n                1.0 - Q[..., 1, 1] - Q[..., 2, 2],\n            ],\n            dim=-1,\n        ).reshape(*qvec.shape[:-1], 3, 3)\n\n    @override\n    def parameters(self) -> torch.Tensor:\n        return self.wxyz\n\n    # Operations.\n\n    @override\n    def act(self, target: torch.Tensor) -> torch.Tensor:\n        assert target.shape[-1] == 3\n        # Compute using quaternion muls.\n        padded_target = torch.cat([torch.ones_like(target[..., :1]), target], dim=-1)\n        out = self.mul(SO3(wxyz=padded_target).mul(self.inv()))\n        return out.wxyz[..., 1:]\n\n    @override\n    def mul(self, other: SO3) -> SO3:\n        w0, x0, y0, z0 = self.wxyz.unbind(dim=-1)\n        w1, x1, y1, z1 = other.wxyz.unbind(dim=-1)\n        wxyz2 = torch.stack(\n            [\n                -x0 * x1 - y0 * y1 - z0 * z1 + w0 * w1,\n                x0 * w1 + y0 * z1 - z0 * y1 + w0 * x1,\n                -x0 * z1 + y0 * w1 + z0 * x1 + w0 * y1,\n                x0 * y1 - y0 * x1 + z0 * w1 + w0 * z1,\n            ],\n            dim=-1,\n        )\n\n        return SO3(wxyz=wxyz2)\n\n    @staticmethod\n    @override\n    def exp(tangent: torch.Tensor) -> SO3:\n        \"\"\"\n        create SO3 object from axis angle tangent vector\n        :param tangent (*, 3)\n        \"\"\"\n        # Reference:\n        # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/so3.hpp#L583\n\n        assert tangent.shape[-1] == 3\n\n        theta_squared = torch.square(tangent).sum(dim=-1)  # (*)\n        theta_pow_4 = theta_squared * theta_squared\n        use_taylor = theta_squared < get_epsilon(tangent.dtype)\n\n        safe_theta = 
torch.sqrt(\n            torch.where(\n                use_taylor,\n                torch.ones_like(theta_squared),  # Any constant value should do here.\n                theta_squared,\n            )\n        )\n        safe_half_theta = 0.5 * safe_theta\n\n        real_factor = torch.where(\n            use_taylor,\n            1.0 - theta_squared / 8.0 + theta_pow_4 / 384.0,\n            torch.cos(safe_half_theta),\n        )\n\n        imaginary_factor = torch.where(\n            use_taylor,\n            0.5 - theta_squared / 48.0 + theta_pow_4 / 3840.0,\n            torch.sin(safe_half_theta) / safe_theta,\n        )\n\n        return SO3(\n            wxyz=torch.cat(\n                [\n                    real_factor[..., None],\n                    imaginary_factor[..., None] * tangent,\n                ],\n                dim=-1,\n            )\n        )\n\n    @override\n    def log(self) -> torch.Tensor:\n        \"\"\"\n        log map to tangent space\n        :return (*, 3) tangent vector\n        \"\"\"\n        # Reference:\n        # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/so3.hpp#L247\n\n        w, xyz = torch.split(self.wxyz, [1, 3], dim=-1)  # (*, 1), (*, 3)\n        norm_sq = torch.square(xyz).sum(dim=-1, keepdim=True)  # (*, 1)\n        use_taylor = norm_sq < get_epsilon(norm_sq.dtype)\n\n        norm_safe = torch.sqrt(\n            torch.where(\n                use_taylor,\n                torch.ones_like(norm_sq),  # Any non-zero value should do here.\n                norm_sq,\n            )\n        )\n        w_safe = torch.where(use_taylor, w, torch.ones_like(w))\n        atan_n_over_w = torch.atan2(\n            torch.where(w < 0, -norm_safe, norm_safe),\n            torch.abs(w),\n        )\n        atan_factor = torch.where(\n            use_taylor,\n            2.0 / w_safe - 2.0 / 3.0 * norm_sq / w_safe ** 3,\n            torch.where(\n                torch.abs(w) < get_epsilon(w.dtype),\n                torch.where(w > 0, 1.0, -1.0) * math.pi / norm_safe,\n                2.0 * atan_n_over_w / norm_safe,\n            ),\n        )\n\n        return atan_factor * xyz\n\n    @override\n    def adjoint(self) -> torch.Tensor:\n        return self.matrix()\n\n    @override\n    def inv(self) -> SO3:\n        # Negate complex terms.\n        w, xyz = torch.split(self.wxyz, [1, 3], dim=-1)\n        return SO3(wxyz=torch.cat([w, -xyz], dim=-1))\n\n    @override\n    def normalize(self) -> SO3:\n        return SO3(wxyz=self.wxyz / torch.linalg.norm(self.wxyz, dim=-1, keepdim=True))\n"
  },
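  {
    "path": "examples/transforms_se3_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the original repository: the path, frame\nnames, and numbers below are made up. It shows the `(qw, qx, qy, qz, x, y, z)`\nparameterization of `SE3`, the exp/log round trip over the\n`(vx, vy, vz, omega_x, omega_y, omega_z)` tangent space, and the group action.\n\"\"\"\n\nimport torch\n\nfrom egoallo.preprocessing.geometry.transforms import SE3, SO3\n\n# Rotation from Euler angles (ZYX convention), then a rigid transform.\nR_world_cam = SO3.from_rpy_radians(\n    torch.tensor(0.1), torch.tensor(-0.2), torch.tensor(0.3)\n)\nT_world_cam = SE3.from_rotation_and_translation(\n    R_world_cam, torch.tensor([1.0, 0.0, 0.5])\n)\n\n# Applying the transform and then its inverse recovers the input points.\npoints_cam = torch.randn(100, 3)\npoints_world = T_world_cam * points_cam  # `*` dispatches to `.act()` for tensors\nassert torch.allclose(T_world_cam.inv() * points_world, points_cam, atol=1e-5)\n\n# exp() and log() map between the group and its 6-DoF tangent space.\ntangent = torch.tensor([0.1, -0.2, 0.3, 0.05, 0.0, -0.1])\nassert torch.allclose(SE3.exp(tangent).log(), tangent, atol=1e-5)\n"
  },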
  {
    "path": "src/egoallo/preprocessing/geometry/transforms/hints/__init__.py",
    "content": "from typing import NamedTuple, Union\n\nimport numpy as np\nimport torch\n\n\nArray = torch.Tensor\n\"\"\"Type alias for `torch.Tensor`.\"\"\"\n\nScalar = Union[float, Array]\n\"\"\"Type alias for `Union[float, Array]`.\"\"\"\n\n\nclass RollPitchYaw(NamedTuple):\n    \"\"\"Tuple containing roll, pitch, and yaw Euler angles.\"\"\"\n\n    roll: Scalar\n    pitch: Scalar\n    yaw: Scalar\n\n\n__all__ = [\n    \"Array\",\n    \"Scalar\",\n    \"RollPitchYaw\",\n]\n"
  },
  {
    "path": "src/egoallo/preprocessing/geometry/transforms/utils/__init__.py",
    "content": "from ._utils import get_epsilon, register_lie_group\n\n__all__ = [\"get_epsilon\", \"register_lie_group\"]\n"
  },
  {
    "path": "src/egoallo/preprocessing/geometry/transforms/utils/_utils.py",
    "content": "from typing import TYPE_CHECKING, Callable, Type, TypeVar\n\nimport torch\n\nif TYPE_CHECKING:\n    from .._base import MatrixLieGroup\n\n\nT = TypeVar(\"T\", bound=\"MatrixLieGroup\")\n\n\ndef get_epsilon(dtype: torch.dtype) -> float:\n    \"\"\"Helper for grabbing type-specific precision constants.\n\n    Args:\n        dtype: Datatype.\n\n    Returns:\n        Output float.\n    \"\"\"\n    return {\n        torch.float32: 1e-5,\n        torch.float64: 1e-10,\n    }[dtype]\n\n\ndef register_lie_group(\n    *,\n    matrix_dim: int,\n    parameters_dim: int,\n    tangent_dim: int,\n    space_dim: int,\n) -> Callable[[Type[T]], Type[T]]:\n    \"\"\"Decorator for registering Lie group dataclasses.\n\n    Sets dimensionality class variables, and (formerly in the JAX version) marks all methods for JIT compilation.\n    \"\"\"\n\n    def _wrap(cls: Type[T]) -> Type[T]:\n        # Register dimensions as class attributes.\n        cls.matrix_dim = matrix_dim\n        cls.parameters_dim = parameters_dim\n        cls.tangent_dim = tangent_dim\n        cls.space_dim = space_dim\n\n        return cls\n\n    return _wrap\n"
  },
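  {
    "path": "examples/transforms_registration_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the original repository (path and prints are\nfor inspection only). It shows what `register_lie_group` above actually does:\nit records the dimensionality constants on each group class, while `get_epsilon`\npicks the dtype-dependent tolerance used by the Taylor-series branches.\n\"\"\"\n\nimport torch\n\nfrom egoallo.preprocessing.geometry.transforms import SE3, SO3\nfrom egoallo.preprocessing.geometry.transforms.utils import get_epsilon\n\n# Dimensions registered by the decorator.\nprint(SO3.parameters_dim, SO3.tangent_dim)  # 4, 3\nprint(SE3.matrix_dim, SE3.space_dim)  # 4, 3\n\n# dtype-dependent epsilon used to switch to Taylor approximations near zero.\nprint(get_epsilon(torch.float32), get_epsilon(torch.float64))  # 1e-05, 1e-10\n"
  },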
  {
    "path": "src/egoallo/preprocessing/util/__init__.py",
    "content": "from .tensor import *\n"
  },
  {
    "path": "src/egoallo/preprocessing/util/tensor.py",
    "content": "from loguru import logger as guru\nfrom typing import TypeVar, Dict, List\nimport torch\nfrom torch import Tensor\nimport torch.nn.functional as F\nimport numpy as np\nfrom PIL import Image\n\n\ndef batch_sum(x, nldims=1):\n    \"\"\"\n    Sum across all but batch dimension(s)\n    :param x (B, *)\n    :param nldims (optional int=1) number of leading dims to keep\n    \"\"\"\n    if x.ndim > nldims:\n        return x.sum(dim=tuple(range(nldims, x.ndim)))\n    return x\n\n\ndef batch_mean(x, nldims=1):\n    \"\"\"\n    Mean across all but batch dimension(s)\n    :param x (B, *)\n    :param nldims (optional int=1) number of leading dims to keep\n    \"\"\"\n    if x.ndim > nldims:\n        return x.mean(dim=tuple(range(nldims, x.ndim)))\n    return x\n\n\ndef pad_dim(x, max_len, dim=0, start=0, **kwargs):\n    \"\"\"\n    pads x to max_len in specified dim\n    :param x (tensor)\n    :param max_len (int)\n    :param start (int default 0)\n    :param dim (optional int default 0)\n    \"\"\"\n    N = x.shape[dim]\n    if max_len == N:\n        return x\n\n    if max_len < N:\n        return torch.narrow(x, dim, start, max_len)\n\n    if dim < 0:\n        dim = x.ndim + dim\n    pad = [0, 0] * x.ndim\n    pad[2 * dim + 1] = start\n    pad[2 * dim] = max_len - (N + start)\n    return F.pad(x, pad[::-1], **kwargs)\n\n\ndef pad_back(x, max_len, dim=0, **kwargs):\n    return pad_dim(x, max_len, dim, 0, **kwargs)\n\n\ndef pad_front(x, max_len, dim=0, **kwargs):\n    N = x.shape[dim]\n    return pad_dim(x, max_len, dim=dim, start=max_len - N, **kwargs)\n\n\ndef read_image(path, scale=1):\n    im = Image.open(path)\n    if scale == 1:\n        return np.array(im)\n    W, H = im.size\n    w, h = int(scale * W), int(scale * H)\n    return np.array(im.resize((w, h), Image.ANTIALIAS))\n\n\nT = TypeVar(\"T\")\n\n\ndef move_to(obj: T, device) -> T:\n    if isinstance(obj, torch.Tensor):\n        return obj.to(device)\n    if isinstance(obj, dict):\n        return {k: move_to(v, device) for k, v in obj.items()}  # type: ignore\n    if isinstance(obj, (list, tuple)):\n        return [move_to(x, device) for x in obj]  # type: ignore\n    return obj  # otherwise do nothing\n\n\ndef detach_all(obj: T) -> T:\n    if isinstance(obj, torch.Tensor):\n        return obj.detach()\n    if isinstance(obj, dict):\n        return {k: detach_all(v) for k, v in obj.items()}  # type: ignore\n    if isinstance(obj, (list, tuple)):\n        return [detach_all(x) for x in obj]  # type: ignore\n    return obj  # otherwise do nothing\n\n\ndef to_torch(obj):\n    if isinstance(obj, np.ndarray):\n        return torch.from_numpy(obj).float()\n    if isinstance(obj, dict):\n        return {k: to_torch(v) for k, v in obj.items()}\n    if isinstance(obj, (list, tuple)):\n        return [to_torch(x) for x in obj]\n    return obj\n\n\ndef to_np(obj):\n    if isinstance(obj, torch.Tensor):\n        return obj.numpy()\n    if isinstance(obj, dict):\n        return {k: to_np(v) for k, v in obj.items()}\n    if isinstance(obj, (list, tuple)):\n        return [to_np(x) for x in obj]\n    return obj\n\n\ndef load_npz_as_dict(path, **kwargs):\n    npz = np.load(path, **kwargs)\n    return {key: npz[key] for key in npz.files}\n\n\ndef get_device(i=0):\n    device = f\"cuda:{i}\" if torch.cuda.is_available() else \"cpu\"\n    return torch.device(device)\n\n\ndef invert_nested_dict(d):\n    \"\"\"\n    invert nesting of dict of dicts\n    \"\"\"\n    outer_keys = d.keys()  # outer nested keys\n    inner_keys = 
next(iter(d.values())).keys()  # inner nested keys\n    return {\n        inner: {outer: d[outer][inner] for outer in outer_keys} for inner in inner_keys\n    }\n\n\ndef batchify_dicts(dict_list: List[Dict]) -> Dict:\n    \"\"\"\n    given a list of dicts with shared keys,\n    return a dict of those keys stacked into lists\n    \"\"\"\n    keys = dict_list[0].keys()\n    if not all(d.keys() == keys for d in dict_list):\n        guru.warning(\"found dicts with not same keys! using first element's keys\")\n    return {k: [d[k] for d in dict_list] for k in keys}\n\n\ndef batchify_recursive(dict_list: List[Dict], levels: int = -1):\n    x = dict_list[0]\n    keys = x.keys()\n    out = {}\n    for k in keys:\n        if isinstance(x[k], dict) and levels != 0:\n            # aggregate the values with this key into\n            # a list of dicts and batchify recursively\n            vals = batchify_recursive(\n                [d[k] for d in dict_list], levels=levels - 1\n            )\n        elif isinstance(x[k], (list, tuple)) and levels != 0:\n            # aggregate the values with this key into\n            # a flattened list and batchify recursively\n            vals = [x for d in dict_list for x in d[k]]\n            # perhaps another list of dicts\n            if isinstance(vals[0], dict) and levels != 0:\n                vals = batchify_recursive( vals, levels=levels - 1)\n        else:\n            # aggregate the values with this key into a list as is\n            vals = [d[k] for d in dict_list]\n        out[k] = vals\n    return out\n\n\ndef unbatch_dict(batched_dict, batch_size):\n    \"\"\"\n    :param d (dict) of batched tensors\n    return len B list of dicts of unbatched tensors\n    \"\"\"\n    out_list = [{} for _ in range(batch_size)]\n    for k, v in batched_dict.items():\n        for b in range(batch_size):\n            out_list[b][k] = get_batch_element(v, b, batch_size)\n    return out_list\n\n\ndef get_batch_element(batch, idx, batch_size):\n    if isinstance(batch, torch.Tensor):\n        return batch[idx] if idx < batch.shape[0] else batch\n    if isinstance(batch, dict):\n        return {k: get_batch_element(v, idx, batch_size) for k, v in batch.items()}\n    if isinstance(batch, list):\n        if len(batch) == batch_size:\n            return batch[idx]\n        return [get_batch_element(v, idx, batch_size) for v in batch]\n    if isinstance(batch, tuple):\n        if len(batch) == batch_size:\n            return batch[idx]\n        return tuple(get_batch_element(v, idx, batch_size) for v in batch)\n    return batch\n\n\ndef narrow_dict(input_dict, tdim, start, length):\n    \"\"\"\n    slice dict of tensors\n    :param d (dict)\n    :param idcs (tensor or list)\n    \"\"\"\n    input_batch = {}\n    for k, v in input_dict.items():\n        input_batch[k] = narrow_obj(v, tdim, start, length)\n    return input_batch\n\n\ndef narrow_list(input_list, tdim, start, length):\n    return [narrow_obj(x, tdim, start, length) for x in input_list]\n\n\ndef narrow_obj(v, tdim, start, length):\n    if isinstance(v, dict):\n        return narrow_dict(v, tdim, start, length)\n    if isinstance(v, (tuple, list)):\n        return narrow_list(v, tdim, start, length)\n    if not isinstance(v, Tensor):\n        return v\n    if v.ndim <= tdim or v.shape[tdim] < start + length:\n        return v\n    return v.narrow(tdim, start, length)\n\n\ndef scatter_intervals(tensor, start, end, T: int = -1):\n    \"\"\"\n    Scatter the tensor contents into intervals from start to end\n    output 
tensor indexed from 0 to end.max()\n    :param tensor (B, S, *)\n    :param start (B) start indices\n    :param end (B) end indices\n    :param T (int, optional) max length\n    returns (B, T, *) scattered tensor\n    \"\"\"\n    assert isinstance(tensor, torch.Tensor) and tensor.ndim >= 2\n    if T < 0:\n        T = end.max()\n    assert torch.all(end <= T)\n\n    B, S, *dims = tensor.shape\n    start, end = start.long(), end.long()\n    # get idcs that go past the last time step so we don't have repeat indices in scatter\n    idcs = time_segment_idcs(start, end, min_len=T, clip=False)  # (B, T)\n    # mask out the extra padding\n    mask = idcs >= end[:, None]\n    tensor[mask] = 0\n\n    idcs = idcs.reshape(B, S, *(1,) * len(dims)).repeat(1, 1, *dims)\n    output = torch.zeros(\n        B, idcs.max() + 1, *dims, device=tensor.device, dtype=tensor.dtype\n    )\n    output.scatter_(1, idcs, tensor)\n    # slice out the extra segments\n    return output[:, :T]\n\n\ndef get_scatter_mask(start, end, T):\n    \"\"\"\n    get the mask of selected intervals\n    \"\"\"\n    B = start.shape[0]\n    start, end = start.long(), end.long()\n    assert torch.all(end <= T)\n    idcs = time_segment_idcs(start, end, clip=True)\n    mask = torch.zeros(B, T, device=start.device, dtype=torch.bool)\n    mask.scatter_(1, idcs, 1)\n    return mask\n\n\ndef select_intervals(series, start, end, pad_len: int = -1):\n    \"\"\"\n    Select slices of a tensor from start to end\n    will pad uneven sequences to all the max segment length\n    :param series (B, T, *)\n    :param start (B)\n    :param end (B)\n    returns (B, S, *) selected segments, S = max(end - start)\n    \"\"\"\n    B, T, *dims = series.shape\n    assert torch.all(end <= T)\n    sel = time_segment_idcs(start, end, min_len=pad_len, clip=True)\n    S = sel.shape[1]\n    sel = sel.reshape(B, S, *(1,) * len(dims)).repeat(1, 1, *dims)\n    return torch.gather(series, 1, sel)\n\n\ndef get_select_mask(start, end):\n    \"\"\"\n    get the mask of unpadded elementes for the selected time segments\n    e.g. sel[mask] are the unpadded elements\n    :param start (B)\n    :param end (B)\n    \"\"\"\n    idcs = time_segment_idcs(start, end, clip=False)\n    return idcs < end[:, None]  # (B, S)\n\n\ndef time_segment_idcs(start, end, min_len: int = -1, clip: bool = True):\n    \"\"\"\n    :param start (B)\n    :param end (B)\n    returns (B, S) long tensor of indices, where S = max(end - start)\n    \"\"\"\n    start, end = start.long(), end.long()\n    S = max(int((end - start).max()), min_len)\n    seg = torch.arange(S, dtype=torch.int64, device=start.device)\n    idcs = start[:, None] + seg[None, :]  # (B, S)\n    if clip:\n        # clip at the lengths of each track\n        imax = torch.maximum(end - 1, start)[:, None]\n        idcs = idcs.clamp(max=imax)\n    return idcs\n"
  },
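  {
    "path": "examples/tensor_utils_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the original repository: the path and the\ntoy values below are made up. It shows how the interval helpers above behave:\n`select_intervals` gathers a per-row time segment (padding shorter rows by\nrepeating their last valid index), `get_select_mask` marks which selected\nentries are real rather than padding, and `pad_front` zero-pads a sequence.\n\"\"\"\n\nimport torch\n\nfrom egoallo.preprocessing.util.tensor import get_select_mask, pad_front, select_intervals\n\nseries = torch.arange(12.0).reshape(2, 6)  # (B=2, T=6)\nstart = torch.tensor([1, 0])\nend = torch.tensor([4, 6])\n\nsegments = select_intervals(series, start, end)  # (2, 6); row 0 repeats index 3\nmask = get_select_mask(start, end)  # (2, 6) bool; only 3 entries of row 0 are real\nprint(segments)\nprint(segments[mask])  # unpadded elements: [1, 2, 3, 6, 7, ..., 11]\n\n# Zero-pad a shorter sequence at the front along dim 0.\npadded = pad_front(torch.ones(3, 2), max_len=5)  # (5, 2); first two rows are zeros\nprint(padded)\n"
  },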
  {
    "path": "src/egoallo/py.typed",
    "content": ""
  },
  {
    "path": "src/egoallo/sampling.py",
    "content": "from __future__ import annotations\n\nimport time\n\nimport numpy as np\nimport torch\nfrom jaxtyping import Float\nfrom torch import Tensor\nfrom tqdm.auto import tqdm\n\nfrom . import fncsmpl, network\nfrom .guidance_optimizer_jax import GuidanceMode, do_guidance_optimization\nfrom .hand_detection_structs import (\n    CorrespondedAriaHandWristPoseDetections,\n    CorrespondedHamerDetections,\n)\nfrom .tensor_dataclass import TensorDataclass\nfrom .transforms import SE3\n\n\ndef quadratic_ts() -> np.ndarray:\n    \"\"\"DDIM sampling schedule.\"\"\"\n    end_step = 0\n    start_step = 1000\n    x = np.arange(end_step, int(np.sqrt(start_step))) ** 2\n    x[-1] = start_step\n    return x[::-1]\n\n\nclass CosineNoiseScheduleConstants(TensorDataclass):\n    alpha_t: Float[Tensor, \"T\"]\n    r\"\"\"$1 - \\beta_t$\"\"\"\n\n    alpha_bar_t: Float[Tensor, \"T+1\"]\n    r\"\"\"$\\Prod_{j=1}^t (1 - \\beta_j)$\"\"\"\n\n    @staticmethod\n    def compute(timesteps: int, s: float = 0.008) -> CosineNoiseScheduleConstants:\n        steps = timesteps + 1\n        x = torch.linspace(0, 1, steps, dtype=torch.float64)\n\n        def get_betas():\n            alphas_cumprod = torch.cos((x + s) / (1 + s) * torch.pi * 0.5) ** 2\n            alphas_cumprod = alphas_cumprod / alphas_cumprod[0]\n            betas = 1.0 - (alphas_cumprod[1:] / alphas_cumprod[:-1])\n            return torch.clip(betas, 0, 0.999)\n\n        alpha_t = 1.0 - get_betas()\n        assert len(alpha_t.shape) == 1\n        alpha_cumprod_t = torch.cat(\n            [torch.ones((1,)), torch.cumprod(alpha_t, dim=0)],\n            dim=0,\n        )\n        return CosineNoiseScheduleConstants(\n            alpha_t=alpha_t, alpha_bar_t=alpha_cumprod_t\n        )\n\n\ndef run_sampling_with_stitching(\n    denoiser_network: network.EgoDenoiser,\n    body_model: fncsmpl.SmplhModel,\n    guidance_mode: GuidanceMode,\n    guidance_post: bool,\n    guidance_inner: bool,\n    Ts_world_cpf: Float[Tensor, \"time 7\"],\n    floor_z: float,\n    hamer_detections: None | CorrespondedHamerDetections,\n    aria_detections: None | CorrespondedAriaHandWristPoseDetections,\n    num_samples: int,\n    device: torch.device,\n    guidance_verbose: bool = True,\n) -> network.EgoDenoiseTraj:\n    # Offset the T_world_cpf transform to place the floor at z=0 for the\n    # denoiser network. 
All of the network outputs are local, so we don't need to\n    # unoffset when returning.\n    Ts_world_cpf_shifted = Ts_world_cpf.clone()\n    Ts_world_cpf_shifted[..., 6] -= floor_z\n\n    noise_constants = CosineNoiseScheduleConstants.compute(timesteps=1000).to(\n        device=device\n    )\n    alpha_bar_t = noise_constants.alpha_bar_t\n    alpha_t = noise_constants.alpha_t\n\n    T_cpf_tm1_cpf_t = (\n        SE3(Ts_world_cpf[..., :-1, :]).inverse() @ SE3(Ts_world_cpf[..., 1:, :])\n    ).wxyz_xyz\n\n    x_t_packed = torch.randn(\n        (num_samples, Ts_world_cpf.shape[0] - 1, denoiser_network.get_d_state()),\n        device=device,\n    )\n    x_t_list = [\n        network.EgoDenoiseTraj.unpack(\n            x_t_packed, include_hands=denoiser_network.config.include_hands\n        )\n    ]\n    ts = quadratic_ts()\n\n    seq_len = x_t_packed.shape[1]\n\n    start_time = None\n\n    window_size = 128\n    overlap_size = 32\n    canonical_overlap_weights = (\n        torch.from_numpy(\n            np.minimum(\n                # Make this shape /```\\\n                overlap_size,\n                np.minimum(\n                    # Make this shape: /\n                    np.arange(1, seq_len + 1),\n                    # Make this shape: \\\n                    np.arange(1, seq_len + 1)[::-1],\n                ),\n            )\n            / overlap_size,\n        )\n        .to(device)\n        .to(torch.float32)\n    )\n    for i in tqdm(range(len(ts) - 1)):\n        print(f\"Sampling {i}/{len(ts) - 1}\")\n        t = ts[i]\n        t_next = ts[i + 1]\n\n        with torch.inference_mode():\n            # Chop everything into windows.\n            x_0_packed_pred = torch.zeros_like(x_t_packed)\n            overlap_weights = torch.zeros((1, seq_len, 1), device=x_t_packed.device)\n\n            # Denoise each window.\n            for start_t in range(0, seq_len, window_size - overlap_size):\n                end_t = min(start_t + window_size, seq_len)\n                assert end_t - start_t > 0\n                overlap_weights_slice = canonical_overlap_weights[\n                    None, : end_t - start_t, None\n                ]\n                overlap_weights[:, start_t:end_t, :] += overlap_weights_slice\n                x_0_packed_pred[:, start_t:end_t, :] += (\n                    denoiser_network.forward(\n                        x_t_packed[:, start_t:end_t, :],\n                        torch.tensor([t], device=device).expand((num_samples,)),\n                        T_cpf_tm1_cpf_t=T_cpf_tm1_cpf_t[None, start_t:end_t, :].repeat(\n                            (num_samples, 1, 1)\n                        ),\n                        T_world_cpf=Ts_world_cpf_shifted[\n                            None, start_t + 1 : end_t + 1, :\n                        ].repeat((num_samples, 1, 1)),\n                        project_output_rotmats=False,\n                        hand_positions_wrt_cpf=None,  # TODO: this should be filled in!!\n                        mask=None,\n                    )\n                    * overlap_weights_slice\n                )\n\n            # Take the mean for overlapping regions.\n            x_0_packed_pred /= overlap_weights\n\n            x_0_packed_pred = network.EgoDenoiseTraj.unpack(\n                x_0_packed_pred,\n                include_hands=denoiser_network.config.include_hands,\n                project_rotmats=True,\n            ).pack()\n\n        if torch.any(torch.isnan(x_0_packed_pred)):\n            print(\"found nan\", i)\n        sigma_t = 
torch.cat(\n            [\n                torch.zeros((1,), device=device),\n                torch.sqrt(\n                    (1.0 - alpha_bar_t[:-1]) / (1 - alpha_bar_t[1:]) * (1 - alpha_t)\n                )\n                * 0.8,\n            ]\n        )\n\n        if guidance_mode != \"off\" and guidance_inner:\n            x_0_pred, _ = do_guidance_optimization(\n                # It's important that we _don't_ use the shifted transforms here.\n                Ts_world_cpf=Ts_world_cpf[1:, :],\n                traj=network.EgoDenoiseTraj.unpack(\n                    x_0_packed_pred, include_hands=denoiser_network.config.include_hands\n                ),\n                body_model=body_model,\n                guidance_mode=guidance_mode,\n                phase=\"inner\",\n                hamer_detections=hamer_detections,\n                aria_detections=aria_detections,\n                verbose=guidance_verbose,\n            )\n            x_0_packed_pred = x_0_pred.pack()\n            del x_0_pred\n\n        if start_time is None:\n            start_time = time.time()\n\n        # print(sigma_t)\n        x_t_packed = (\n            torch.sqrt(alpha_bar_t[t_next]) * x_0_packed_pred\n            + (\n                torch.sqrt(1 - alpha_bar_t[t_next] - sigma_t[t] ** 2)\n                * (x_t_packed - torch.sqrt(alpha_bar_t[t]) * x_0_packed_pred)\n                / torch.sqrt(1 - alpha_bar_t[t] + 1e-1)\n            )\n            + sigma_t[t] * torch.randn(x_0_packed_pred.shape, device=device)\n        )\n        x_t_list.append(\n            network.EgoDenoiseTraj.unpack(\n                x_t_packed, include_hands=denoiser_network.config.include_hands\n            )\n        )\n\n    if guidance_mode != \"off\" and guidance_post:\n        constrained_traj = x_t_list[-1]\n        constrained_traj, _ = do_guidance_optimization(\n            # It's important that we _don't_ use the shifted transforms here.\n            Ts_world_cpf=Ts_world_cpf[1:, :],\n            traj=constrained_traj,\n            body_model=body_model,\n            guidance_mode=guidance_mode,\n            phase=\"post\",\n            hamer_detections=hamer_detections,\n            aria_detections=aria_detections,\n            verbose=guidance_verbose,\n        )\n        assert start_time is not None\n        print(\"RUNTIME (exclude first optimization)\", time.time() - start_time)\n        return constrained_traj\n    else:\n        assert start_time is not None\n        print(\"RUNTIME (exclude first optimization)\", time.time() - start_time)\n        return x_t_list[-1]\n"
  },
  {
    "path": "src/egoallo/tensor_dataclass.py",
    "content": "import dataclasses\nfrom typing import Any, Callable, Self, dataclass_transform\n\nimport torch\n\n\n@dataclass_transform()\nclass TensorDataclass:\n    \"\"\"A lighter version of nerfstudio's TensorDataclass:\n    https://github.com/nerfstudio-project/nerfstudio/blob/main/nerfstudio/utils/tensor_dataclass.py\n    \"\"\"\n\n    def __init_subclass__(cls) -> None:\n        dataclasses.dataclass(cls)\n\n    def to(self, device: torch.device | str) -> Self:\n        \"\"\"Move the tensors in the dataclass to the given device.\n\n        Args:\n            device: The device to move to.\n\n        Returns:\n            A new dataclass.\n        \"\"\"\n        return self.map(lambda x: x.to(device))\n\n    def as_nested_dict(self, numpy: bool) -> dict[str, Any]:\n        \"\"\"Convert the dataclass to a nested dictionary.\n\n        Recurses into lists, tuples, and dictionaries.\n        \"\"\"\n\n        def _to_dict(obj: Any) -> Any:\n            if isinstance(obj, TensorDataclass):\n                return {k: _to_dict(v) for k, v in vars(obj).items()}\n            elif isinstance(obj, (list, tuple)):\n                return type(obj)(_to_dict(v) for v in obj)\n            elif isinstance(obj, dict):\n                return {k: _to_dict(v) for k, v in obj.items()}\n            elif isinstance(obj, torch.Tensor) and numpy:\n                return obj.numpy(force=True)\n            else:\n                return obj\n\n        return _to_dict(self)\n\n    def map(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Self:\n        \"\"\"Apply a function to all tensors in the dataclass.\n\n        Also recurses into lists, tuples, and dictionaries.\n\n        Args:\n            fn: The function to apply to each tensor.\n\n        Returns:\n            A new dataclass.\n        \"\"\"\n\n        def _map_impl[MapT](\n            fn: Callable[[torch.Tensor], torch.Tensor],\n            val: MapT,\n        ) -> MapT:\n            if isinstance(val, torch.Tensor):\n                return fn(val)\n            elif isinstance(val, TensorDataclass):\n                return type(val)(**_map_impl(fn, vars(val)))\n            elif isinstance(val, (list, tuple)):\n                return type(val)(_map_impl(fn, v) for v in val)\n            elif isinstance(val, dict):\n                assert type(val) is dict  # No subclass support.\n                return {k: _map_impl(fn, v) for k, v in val.items()}  # type: ignore\n            else:\n                return val\n\n        return _map_impl(fn, self)\n"
  },
  {
    "path": "src/egoallo/training_loss.py",
    "content": "\"\"\"Training loss configuration.\"\"\"\n\nimport dataclasses\nfrom typing import Literal\n\nimport torch.utils.data\nfrom jaxtyping import Bool, Float, Int\nfrom torch import Tensor\nfrom torch._dynamo import OptimizedModule\nfrom torch.nn.parallel import DistributedDataParallel\n\nfrom . import network\nfrom .data.amass import EgoTrainingData\nfrom .sampling import CosineNoiseScheduleConstants\nfrom .transforms import SO3\n\n\n@dataclasses.dataclass(frozen=True)\nclass TrainingLossConfig:\n    cond_dropout_prob: float = 0.0\n    beta_coeff_weights: tuple[float, ...] = tuple(1 / (i + 1) for i in range(16))\n    loss_weights: dict[str, float] = dataclasses.field(\n        default_factory={\n            \"betas\": 0.1,\n            \"body_rotmats\": 1.0,\n            \"contacts\": 0.1,\n            # We don't have many hands in the AMASS dataset...\n            \"hand_rotmats\": 0.01,\n        }.copy\n    )\n    weight_loss_by_t: Literal[\"emulate_eps_pred\"] = \"emulate_eps_pred\"\n    \"\"\"Weights to apply to the loss at each noise level.\"\"\"\n\n\nclass TrainingLossComputer:\n    \"\"\"Helper class for computing the training loss. Contains a single method\n    for computing a training loss.\"\"\"\n\n    def __init__(self, config: TrainingLossConfig, device: torch.device) -> None:\n        self.config = config\n        self.noise_constants = (\n            CosineNoiseScheduleConstants.compute(timesteps=1000)\n            .to(device)\n            .map(lambda tensor: tensor.to(torch.float32))\n        )\n\n        # Emulate loss weight that would be ~equivalent to epsilon prediction.\n        #\n        # This will penalize later errors (close to the end of sampling) much\n        # more than earlier ones (at the start of sampling).\n        assert self.config.weight_loss_by_t == \"emulate_eps_pred\"\n        weight_t = self.noise_constants.alpha_bar_t / (\n            1 - self.noise_constants.alpha_bar_t\n        )\n        # Pad for numerical stability, and scale between [padding, 1.0].\n        padding = 0.01\n        self.weight_t = weight_t / weight_t[1] * (1.0 - padding) + padding\n\n    def compute_denoising_loss(\n        self,\n        model: network.EgoDenoiser | DistributedDataParallel | OptimizedModule,\n        unwrapped_model: network.EgoDenoiser,\n        train_batch: EgoTrainingData,\n    ) -> tuple[Tensor, dict[str, Tensor | float]]:\n        \"\"\"Compute a training loss for the EgoDenoiser model.\n\n        Returns:\n            A tuple (loss tensor, dictionary of things to log).\n        \"\"\"\n        log_outputs: dict[str, Tensor | float] = {}\n\n        batch, time, num_joints, _ = train_batch.body_quats.shape\n        assert num_joints == 21\n        if unwrapped_model.config.include_hands:\n            assert train_batch.hand_quats is not None\n            x_0 = network.EgoDenoiseTraj(\n                betas=train_batch.betas.expand((batch, time, 16)),\n                body_rotmats=SO3(train_batch.body_quats).as_matrix(),\n                contacts=train_batch.contacts,\n                hand_rotmats=SO3(train_batch.hand_quats).as_matrix(),\n            )\n        else:\n            x_0 = network.EgoDenoiseTraj(\n                betas=train_batch.betas.expand((batch, time, 16)),\n                body_rotmats=SO3(train_batch.body_quats).as_matrix(),\n                contacts=train_batch.contacts,\n                hand_rotmats=None,\n            )\n        x_0_packed = x_0.pack()\n        device = x_0_packed.device\n        assert x_0_packed.shape == 
(batch, time, unwrapped_model.get_d_state())\n\n        # Diffuse.\n        t = torch.randint(\n            low=1,\n            high=unwrapped_model.config.max_t + 1,\n            size=(batch,),\n            device=device,\n        )\n        eps = torch.randn(x_0_packed.shape, dtype=x_0_packed.dtype, device=device)\n        assert self.noise_constants.alpha_bar_t.shape == (\n            unwrapped_model.config.max_t + 1,\n        )\n        alpha_bar_t = self.noise_constants.alpha_bar_t[t, None, None]\n        assert alpha_bar_t.shape == (batch, 1, 1)\n        x_t_packed = (\n            torch.sqrt(alpha_bar_t) * x_0_packed + torch.sqrt(1.0 - alpha_bar_t) * eps\n        )\n\n        hand_positions_wrt_cpf: Tensor | None = None\n        if unwrapped_model.config.include_hand_positions_cond:\n            # Joints 19 and 20 are the hand positions.\n            hand_positions_wrt_cpf = train_batch.joints_wrt_cpf[:, :, 19:21, :].reshape(\n                (batch, time, 6)\n            )\n\n            # Exclude hand positions for some items in the batch. We'll just do\n            # this by passing in zeros.\n            hand_positions_wrt_cpf = torch.where(\n                # Uniformly drop out with some uniformly sampled probability.\n                # :)\n                (\n                    torch.rand((batch, time, 1), device=device)\n                    < torch.rand((batch, 1, 1), device=device)\n                ),\n                hand_positions_wrt_cpf,\n                0.0,\n            )\n\n        # Denoise.\n        x_0_packed_pred = model.forward(\n            x_t_packed=x_t_packed,\n            t=t,\n            T_world_cpf=train_batch.T_world_cpf,\n            T_cpf_tm1_cpf_t=train_batch.T_cpf_tm1_cpf_t,\n            hand_positions_wrt_cpf=hand_positions_wrt_cpf,\n            project_output_rotmats=False,\n            mask=train_batch.mask,\n            cond_dropout_keep_mask=torch.rand((batch,), device=device)\n            > self.config.cond_dropout_prob\n            if self.config.cond_dropout_prob > 0.0\n            else None,\n        )\n        assert isinstance(x_0_packed_pred, torch.Tensor)\n        x_0_pred = network.EgoDenoiseTraj.unpack(\n            x_0_packed_pred, include_hands=unwrapped_model.config.include_hands\n        )\n\n        weight_t = self.weight_t[t].to(device)\n        assert weight_t.shape == (batch,)\n\n        def weight_and_mask_loss(\n            loss_per_step: Float[Tensor, \"b t d\"],\n            # bt stands for \"batch time\"\n            bt_mask: Bool[Tensor, \"b t\"] = train_batch.mask,\n            bt_mask_sum: Int[Tensor, \"\"] = torch.sum(train_batch.mask),\n        ) -> Float[Tensor, \"\"]:\n            \"\"\"Weight and mask per-timestep losses (squared errors).\"\"\"\n            _, _, d = loss_per_step.shape\n            assert loss_per_step.shape == (batch, time, d)\n            assert bt_mask.shape == (batch, time)\n            assert weight_t.shape == (batch,)\n            return (\n                # Sum across b axis.\n                torch.sum(\n                    # Sum across t axis.\n                    torch.sum(\n                        # Mean across d axis.\n                        torch.mean(loss_per_step, dim=-1) * bt_mask,\n                        dim=-1,\n                    )\n                    * weight_t\n                )\n                / bt_mask_sum\n            )\n\n        loss_terms: dict[str, Tensor | float] = {\n            \"betas\": weight_and_mask_loss(\n                # (b, t, 16)\n                
(x_0_pred.betas - x_0.betas) ** 2\n                # (16,)\n                * x_0.betas.new_tensor(self.config.beta_coeff_weights),\n            ),\n            \"body_rotmats\": weight_and_mask_loss(\n                # (b, t, 21 * 3 * 3)\n                (x_0_pred.body_rotmats - x_0.body_rotmats).reshape(\n                    (batch, time, 21 * 3 * 3)\n                )\n                ** 2,\n            ),\n            \"contacts\": weight_and_mask_loss((x_0_pred.contacts - x_0.contacts) ** 2),\n        }\n\n        # Include hand objective.\n        # We didn't use this in the paper.\n        if unwrapped_model.config.include_hands:\n            assert x_0_pred.hand_rotmats is not None\n            assert x_0.hand_rotmats is not None\n            assert x_0.hand_rotmats.shape == (batch, time, 30, 3, 3)\n\n            # Detect whether or not hands move in a sequence.\n            # We should only supervise sequences where the hands are actually tracked / move;\n            # we mask out hands in AMASS sequences where they are not tracked.\n            gt_hand_flatmat = x_0.hand_rotmats.reshape((batch, time, -1))\n            hand_motion = (\n                torch.sum(  # (b,) from (b, t)\n                    torch.sum(  # (b, t) from (b, t, d)\n                        torch.abs(gt_hand_flatmat - gt_hand_flatmat[:, 0:1, :]), dim=-1\n                    )\n                    # Zero out changes in masked frames.\n                    * train_batch.mask,\n                    dim=-1,\n                )\n                > 1e-5\n            )\n            assert hand_motion.shape == (batch,)\n\n            hand_bt_mask = torch.logical_and(hand_motion[:, None], train_batch.mask)\n            loss_terms[\"hand_rotmats\"] = torch.sum(\n                weight_and_mask_loss(\n                    (x_0_pred.hand_rotmats - x_0.hand_rotmats).reshape(\n                        batch, time, 30 * 3 * 3\n                    )\n                    ** 2,\n                    bt_mask=hand_bt_mask,\n                    # We want to weight the loss by the number of frames where\n                    # the hands actually move, but gradients here can be too\n                    # noisy and put NaNs into mixed-precision training when we\n                    # inevitably sample too few frames. So we clip the\n                    # denominator.\n                    bt_mask_sum=torch.maximum(\n                        torch.sum(hand_bt_mask), torch.tensor(256, device=device)\n                    ),\n                )\n            )\n            # self.log(\n            #     \"train/hand_motion_proportion\",\n            #     torch.sum(hand_motion) / batch,\n            # )\n        else:\n            loss_terms[\"hand_rotmats\"] = 0.0\n\n        assert loss_terms.keys() == self.config.loss_weights.keys()\n\n        # Log loss terms.\n        for name, term in loss_terms.items():\n            log_outputs[f\"loss_term/{name}\"] = term\n\n        # Return loss.\n        loss = sum([loss_terms[k] * self.config.loss_weights[k] for k in loss_terms])\n        assert isinstance(loss, Tensor)\n        assert loss.shape == ()\n        log_outputs[\"train_loss\"] = loss\n\n        return loss, log_outputs\n"
  },
  {
    "path": "src/egoallo/training_utils.py",
    "content": "\"\"\"Utilities for writing training scripts.\"\"\"\n\nimport dataclasses\nimport pdb\nimport signal\nimport subprocess\nimport sys\nimport time\nimport traceback as tb\nfrom pathlib import Path\nfrom typing import (\n    Any,\n    Dict,\n    Generator,\n    Iterable,\n    Protocol,\n    Sized,\n    get_type_hints,\n    overload,\n)\n\nimport torch\n\n\ndef flattened_hparam_dict_from_dataclass(\n    dataclass: Any, prefix: str | None = None\n) -> Dict[str, Any]:\n    \"\"\"Convert a config object in the form of a nested dataclass into a\n    flattened dictionary, for use with Tensorboard hparams.\"\"\"\n    assert dataclasses.is_dataclass(dataclass)\n    cls = type(dataclass)\n    hints = get_type_hints(cls)\n\n    output = {}\n    for field in dataclasses.fields(dataclass):\n        field_type = hints[field.name]\n        value = getattr(dataclass, field.name)\n        if dataclasses.is_dataclass(field_type):\n            inner = flattened_hparam_dict_from_dataclass(value, prefix=None)\n            inner = {\".\".join([field.name, k]): v for k, v in inner.items()}\n            output.update(inner)\n        # Cast to type supported by tensorboard hparams.\n        elif isinstance(value, (int, float, str, bool, torch.Tensor)):\n            output[field.name] = value\n        else:\n            output[field.name] = str(value)\n\n    if prefix is None:\n        return output\n    else:\n        return {f\"{prefix}.{k}\": v for k, v in output.items()}\n\n\ndef pdb_safety_net():\n    \"\"\"Attaches a \"safety net\" for unexpected errors in a Python script.\n\n    When called, PDB will be automatically opened when either (a) the user hits Ctrl+C\n    or (b) we encounter an uncaught exception. Helpful for bypassing minor errors,\n    diagnosing problems, and rescuing unsaved models.\n    \"\"\"\n\n    # Open PDB on Ctrl+C\n    def handler(sig, frame):\n        pdb.set_trace()\n\n    signal.signal(signal.SIGINT, handler)\n\n    # Open PDB when we encounter an uncaught exception\n    def excepthook(type_, value, traceback):  # pragma: no cover (impossible to test)\n        tb.print_exception(type_, value, traceback, limit=100)\n        pdb.post_mortem(traceback)\n\n    sys.excepthook = excepthook\n\n\nclass SizedIterable[ContainedType](Iterable[ContainedType], Sized, Protocol):\n    \"\"\"Protocol for objects that define both `__iter__()` and `__len__()` methods.\n\n    This is particularly useful for managing minibatches, which can be iterated over but\n    only in order due to multiprocessing/prefetching optimizations, and for which length\n    evaluation is useful for tools like `tqdm`.\"\"\"\n\n\n@dataclasses.dataclass\nclass LoopMetrics:\n    counter: int\n    iterations_per_sec: float\n    time_elapsed: float\n\n\n@overload\ndef range_with_metrics(stop: int, /) -> SizedIterable[LoopMetrics]: ...\n\n\n@overload\ndef range_with_metrics(start: int, stop: int, /) -> SizedIterable[LoopMetrics]: ...\n\n\n@overload\ndef range_with_metrics(\n    start: int, stop: int, step: int, /\n) -> SizedIterable[LoopMetrics]: ...\n\n\ndef range_with_metrics(*args: int) -> SizedIterable[LoopMetrics]:\n    \"\"\"Light wrapper for `fifteen.utils.loop_metric_generator()`, for use in place of\n    `range()`. 
Yields a LoopMetrics object instead of an integer.\"\"\"\n    return _RangeWithMetrics(args=args)\n\n\n@dataclasses.dataclass\nclass _RangeWithMetrics:\n    args: tuple[int, ...]\n\n    def __iter__(self):\n        loop_metrics = loop_metric_generator()\n        for counter in range(*self.args):\n            yield dataclasses.replace(next(loop_metrics), counter=counter)\n\n    def __len__(self) -> int:\n        return len(range(*self.args))\n\n\ndef loop_metric_generator(counter_init: int = 0) -> Generator[LoopMetrics, None, None]:\n    \"\"\"Generator for computing loop metrics.\n\n    Note that the first `iterations_per_sec` metric will be 0.0.\n\n    Example usage:\n    ```\n    # Note that this is an infinite loop.\n    for metric in loop_metric_generator():\n        time.sleep(1.0)\n        print(metric)\n    ```\n\n    or:\n    ```\n    loop_metrics = loop_metric_generator()\n    while True:\n        time.sleep(1.0)\n        print(next(loop_metrics).iterations_per_sec)\n    ```\n    \"\"\"\n\n    counter = counter_init\n    del counter_init\n    time_start = time.time()\n    time_prev = time_start\n    while True:\n        time_now = time.time()\n        yield LoopMetrics(\n            counter=counter,\n            iterations_per_sec=1.0 / (time_now - time_prev) if counter > 0 else 0.0,\n            time_elapsed=time_now - time_start,\n        )\n        time_prev = time_now\n        counter += 1\n\n\ndef get_git_commit_hash(cwd: Path | None = None) -> str:\n    \"\"\"Returns the current Git commit hash.\"\"\"\n    if cwd is None:\n        cwd = Path.cwd()\n    return (\n        subprocess.check_output([\"git\", \"rev-parse\", \"HEAD\"], cwd=cwd.as_posix())\n        .decode(\"ascii\")\n        .strip()\n    )\n\n\ndef get_git_diff(cwd: Path | None = None) -> str:\n    \"\"\"Returns the output of `git diff HEAD`.\"\"\"\n    if cwd is None:\n        cwd = Path.cwd()\n    return (\n        subprocess.check_output([\"git\", \"diff\", \"HEAD\"], cwd=cwd.as_posix())\n        .decode(\"ascii\")\n        .strip()\n    )\n"
  },
  {
    "path": "src/egoallo/transforms/__init__.py",
    "content": "\"\"\"Rigid transforms implemented in PyTorch, ported from jaxlie.\"\"\"\n\nfrom . import utils as utils\nfrom ._base import MatrixLieGroup as MatrixLieGroup\nfrom ._base import SEBase as SEBase\nfrom ._base import SOBase as SOBase\nfrom ._se3 import SE3 as SE3\nfrom ._so3 import SO3 as SO3\n"
  },
  {
    "path": "src/egoallo/transforms/_base.py",
    "content": "import abc\nfrom typing import (\n    ClassVar,\n    Generic,\n    Self,\n    Tuple,\n    Type,\n    TypeVar,\n    Union,\n    final,\n    overload,\n    override,\n)\n\nimport numpy as onp\nimport torch\nfrom torch import Tensor\n\nGroupType = TypeVar(\"GroupType\", bound=\"MatrixLieGroup\")\nSEGroupType = TypeVar(\"SEGroupType\", bound=\"SEBase\")\n\n\nclass MatrixLieGroup(abc.ABC):\n    \"\"\"Interface definition for matrix Lie groups.\"\"\"\n\n    # Class properties.\n    # > These will be set in `_utils.register_lie_group()`.\n\n    matrix_dim: ClassVar[int]\n    \"\"\"Dimension of square matrix output from `.as_matrix()`.\"\"\"\n\n    parameters_dim: ClassVar[int]\n    \"\"\"Dimension of underlying parameters, `.parameters()`.\"\"\"\n\n    tangent_dim: ClassVar[int]\n    \"\"\"Dimension of tangent space.\"\"\"\n\n    space_dim: ClassVar[int]\n    \"\"\"Dimension of coordinates that can be transformed.\"\"\"\n\n    def __init__(\n        # Notes:\n        # - For the constructor signature to be consistent with subclasses, `parameters`\n        #   should be marked as positional-only. But this isn't possible in Python 3.7.\n        # - This method is implicitly overriden by the dataclass decorator and\n        #   should _not_ be marked abstract.\n        self,\n        parameters: Tensor,\n    ):\n        \"\"\"Construct a group object from its underlying parameters.\"\"\"\n        raise NotImplementedError()\n\n    # Shared implementations.\n\n    @overload\n    def __matmul__(self: GroupType, other: GroupType) -> GroupType: ...\n\n    @overload\n    def __matmul__(self, other: Tensor) -> Tensor: ...\n\n    def __matmul__(\n        self: GroupType, other: Union[GroupType, Tensor]\n    ) -> Union[GroupType, Tensor]:\n        \"\"\"Overload for the `@` operator.\n\n        Switches between the group action (`.apply()`) and multiplication\n        (`.multiply()`) based on the type of `other`.\n        \"\"\"\n        if isinstance(other, (onp.ndarray, Tensor)):\n            return self.apply(target=other)\n        elif isinstance(other, MatrixLieGroup):\n            assert self.space_dim == other.space_dim\n            return self.multiply(other=other)\n        else:\n            assert False, f\"Invalid argument type for `@` operator: {type(other)}\"\n\n    # Factory.\n\n    @classmethod\n    @abc.abstractmethod\n    def identity(\n        cls: Type[GroupType], device: Union[torch.device, str], dtype: torch.dtype\n    ) -> GroupType:\n        \"\"\"Returns identity element.\n\n        Returns:\n            Identity element.\n        \"\"\"\n\n    @classmethod\n    @abc.abstractmethod\n    def from_matrix(cls: Type[GroupType], matrix: Tensor) -> GroupType:\n        \"\"\"Get group member from matrix representation.\n\n        Args:\n            matrix: Matrix representaiton.\n\n        Returns:\n            Group member.\n        \"\"\"\n\n    # Accessors.\n\n    @abc.abstractmethod\n    def as_matrix(self) -> Tensor:\n        \"\"\"Get transformation as a matrix. 
Homogeneous for SE groups.\"\"\"\n\n    @abc.abstractmethod\n    def parameters(self) -> Tensor:\n        \"\"\"Get underlying representation.\"\"\"\n\n    # Operations.\n\n    @abc.abstractmethod\n    def apply(self, target: Tensor) -> Tensor:\n        \"\"\"Applies group action to a point.\n\n        Args:\n            target: Point to transform.\n\n        Returns:\n            Transformed point.\n        \"\"\"\n\n    @abc.abstractmethod\n    def multiply(self: Self, other: Self) -> Self:\n        \"\"\"Composes this transformation with another.\n\n        Returns:\n            self @ other\n        \"\"\"\n\n    @classmethod\n    @abc.abstractmethod\n    def exp(cls: Type[GroupType], tangent: Tensor) -> GroupType:\n        \"\"\"Computes `expm(wedge(tangent))`.\n\n        Args:\n            tangent: Tangent vector to take the exponential of.\n\n        Returns:\n            Output.\n        \"\"\"\n\n    @abc.abstractmethod\n    def log(self) -> Tensor:\n        \"\"\"Computes `vee(logm(transformation matrix))`.\n\n        Returns:\n            Output. Shape should be `(tangent_dim,)`.\n        \"\"\"\n\n    @abc.abstractmethod\n    def adjoint(self) -> Tensor:\n        \"\"\"Computes the adjoint, which transforms tangent vectors between tangent\n        spaces.\n\n        More precisely, for a transform `GroupType`:\n        ```\n        GroupType @ exp(omega) = exp(Adj_T @ omega) @ GroupType\n        ```\n\n        In robotics, typically used for transforming twists, wrenches, and Jacobians\n        across different reference frames.\n\n        Returns:\n            Output. Shape should be `(tangent_dim, tangent_dim)`.\n        \"\"\"\n\n    @abc.abstractmethod\n    def inverse(self: GroupType) -> GroupType:\n        \"\"\"Computes the inverse of our transform.\n\n        Returns:\n            Output.\n        \"\"\"\n\n    @abc.abstractmethod\n    def normalize(self: GroupType) -> GroupType:\n        \"\"\"Normalize/projects values and returns.\n\n        Returns:\n            GroupType: Normalized group member.\n        \"\"\"\n\n    # @classmethod\n    # @abc.abstractmethod\n    # def sample_uniform(cls: Type[GroupType], key: Tensor) -> GroupType:\n    #     \"\"\"Draw a uniform sample from the group. Translations (if applicable) are in the\n    #     range [-1, 1].\n    #\n    #     Args:\n    #         key: PRNG key, as returned by `jax.random.PRNGKey()`.\n    #\n    #     Returns:\n    #         Sampled group member.\n    #     \"\"\"\n\n    def get_batch_axes(self) -> Tuple[int, ...]:\n        \"\"\"Return any leading batch axes in contained parameters. 
If an array of shape\n        `(100, 4)` is placed in the wxyz field of an SO3 object, for example, this will\n        return `(100,)`.\"\"\"\n        return self.parameters().shape[:-1]\n\n\nclass SOBase(MatrixLieGroup):\n    \"\"\"Base class for special orthogonal groups.\"\"\"\n\n\nContainedSOType = TypeVar(\"ContainedSOType\", bound=SOBase)\n\n\nclass SEBase(Generic[ContainedSOType], MatrixLieGroup):\n    \"\"\"Base class for special Euclidean groups.\n\n    Each SE(N) group member contains an SO(N) rotation, as well as an N-dimensional\n    translation vector.\n    \"\"\"\n\n    # SE-specific interface.\n\n    @classmethod\n    @abc.abstractmethod\n    def from_rotation_and_translation(\n        cls: Type[SEGroupType],\n        rotation: ContainedSOType,\n        translation: Tensor,\n    ) -> SEGroupType:\n        \"\"\"Construct a rigid transform from a rotation and a translation.\n\n        Args:\n            rotation: Rotation term.\n            translation: Translation term.\n\n        Returns:\n            Constructed transformation.\n        \"\"\"\n\n    @final\n    @classmethod\n    def from_rotation(cls: Type[SEGroupType], rotation: ContainedSOType) -> SEGroupType:\n        return cls.from_rotation_and_translation(\n            rotation=rotation,\n            translation=rotation.parameters().new_zeros(\n                (*rotation.parameters().shape[:-1], cls.space_dim),\n                dtype=rotation.parameters().dtype,\n            ),\n        )\n\n    @abc.abstractmethod\n    def rotation(self) -> ContainedSOType:\n        \"\"\"Returns a transform's rotation term.\"\"\"\n\n    @abc.abstractmethod\n    def translation(self) -> Tensor:\n        \"\"\"Returns a transform's translation term.\"\"\"\n\n    # Overrides.\n\n    @final\n    @override\n    def apply(self, target: Tensor) -> Tensor:\n        return self.rotation() @ target + self.translation()  # type: ignore\n\n    @final\n    @override\n    def multiply(self: SEGroupType, other: SEGroupType) -> SEGroupType:  # type: ignore\n        return type(self).from_rotation_and_translation(\n            rotation=self.rotation() @ other.rotation(),\n            translation=(self.rotation() @ other.translation()) + self.translation(),\n        )\n\n    @final\n    @override\n    def inverse(self: SEGroupType) -> SEGroupType:\n        R_inv = self.rotation().inverse()\n        return type(self).from_rotation_and_translation(\n            rotation=R_inv,\n            translation=-(R_inv @ self.translation()),\n        )\n\n    @final\n    @override\n    def normalize(self: SEGroupType) -> SEGroupType:\n        return type(self).from_rotation_and_translation(\n            rotation=self.rotation().normalize(),\n            translation=self.translation(),\n        )\n"
  },
  {
    "path": "src/egoallo/transforms/_se3.py",
    "content": "from __future__ import annotations\n\nfrom dataclasses import dataclass\nfrom typing import Union, cast, override\n\nimport numpy as np\nimport torch\nfrom torch import Tensor\n\nfrom . import _base\nfrom ._so3 import SO3\nfrom .utils import get_epsilon, register_lie_group\n\n\ndef _skew(omega: Tensor) -> Tensor:\n    \"\"\"\n    Returns the skew-symmetric form of a length-3 vector.\n    :param omega (*, 3)\n    :returns (*, 3, 3)\n    \"\"\"\n\n    wx, wy, wz = omega.unbind(dim=-1)\n    o = torch.zeros_like(wx)\n    return torch.stack(\n        [o, -wz, wy, wz, o, -wx, -wy, wx, o],\n        dim=-1,\n    ).reshape(*wx.shape, 3, 3)\n\n\n@register_lie_group(\n    matrix_dim=4,\n    parameters_dim=7,\n    tangent_dim=6,\n    space_dim=3,\n)\n@dataclass(frozen=True)\nclass SE3(_base.SEBase[SO3]):\n    \"\"\"Special Euclidean group for proper rigid transforms in 3D.\n\n    Internal parameterization is `(qw, qx, qy, qz, x, y, z)`. Tangent parameterization\n    is `(vx, vy, vz, omega_x, omega_y, omega_z)`.\n    \"\"\"\n\n    # SE3-specific.\n\n    wxyz_xyz: Tensor\n    \"\"\"Internal parameters. wxyz quaternion followed by xyz translation.\"\"\"\n\n    @override\n    def __repr__(self) -> str:\n        quat = np.round(self.wxyz_xyz[..., :4].numpy(force=True), 5)\n        trans = np.round(self.wxyz_xyz[..., 4:].numpy(force=True), 5)\n        return f\"{self.__class__.__name__}(wxyz={quat}, xyz={trans})\"\n\n    # SE-specific.\n\n    @classmethod\n    @override\n    def from_rotation_and_translation(\n        cls,\n        rotation: SO3,\n        translation: Tensor,\n    ) -> SE3:\n        assert translation.shape[-1] == 3\n        return SE3(wxyz_xyz=torch.cat([rotation.wxyz, translation], dim=-1))\n\n    @override\n    def rotation(self) -> SO3:\n        return SO3(wxyz=self.wxyz_xyz[..., :4])\n\n    @override\n    def translation(self) -> Tensor:\n        return self.wxyz_xyz[..., 4:]\n\n    # Factory.\n\n    @classmethod\n    @override\n    def identity(cls, device: Union[torch.device, str], dtype: torch.dtype) -> SE3:\n        return SE3(\n            wxyz_xyz=torch.tensor(\n                [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], device=device, dtype=dtype\n            )\n        )\n\n    @classmethod\n    @override\n    def from_matrix(cls, matrix: Tensor) -> SE3:\n        assert matrix.shape[-2:] == (4, 4) or matrix.shape[-2:] == (3, 4)\n        # Currently assumes bottom row is [0, 0, 0, 1].\n        return SE3.from_rotation_and_translation(\n            rotation=SO3.from_matrix(matrix[..., :3, :3]),\n            translation=matrix[..., :3, 3],\n        )\n\n    # Accessors.\n\n    @override\n    def as_matrix(self) -> Tensor:\n        R = self.rotation().as_matrix()  # (*, 3, 3)\n        t = self.translation().unsqueeze(-1)  # (*, 3, 1)\n        dims = R.shape[:-2]\n        bottom = (\n            torch.tensor([0, 0, 0, 1], dtype=R.dtype, device=R.device)\n            .reshape(*(1,) * len(dims), 1, 4)\n            .repeat(*dims, 1, 1)\n        )\n        return torch.cat([torch.cat([R, t], dim=-1), bottom], dim=-2)\n\n    @override\n    def parameters(self) -> Tensor:\n        return self.wxyz_xyz\n\n    # Operations.\n\n    @classmethod\n    @override\n    def exp(cls, tangent: Tensor) -> SE3:\n        # Reference:\n        # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se3.hpp#L761\n\n        # (x, y, z, omega_x, omega_y, omega_z)\n        assert tangent.shape[-1] == 6\n\n        rotation = SO3.exp(tangent[..., 3:])\n\n        
theta_squared = torch.square(tangent[..., 3:]).sum(dim=-1)  # (*)\n        use_taylor = theta_squared < get_epsilon(theta_squared.dtype)\n\n        theta_squared_safe = cast(\n            Tensor,\n            torch.where(\n                use_taylor,\n                1.0,  # Any non-zero value should do here.\n                theta_squared,\n            ),\n        )\n        del theta_squared\n        theta_safe = torch.sqrt(theta_squared_safe)\n\n        skew_omega = _skew(tangent[..., 3:])\n        dtype = skew_omega.dtype\n        device = skew_omega.device\n        V = torch.where(\n            use_taylor[..., None, None],\n            rotation.as_matrix(),\n            (\n                torch.eye(3, device=device, dtype=dtype)\n                + ((1.0 - torch.cos(theta_safe)) / (theta_squared_safe))[\n                    ..., None, None\n                ]\n                * skew_omega\n                + (\n                    (theta_safe - torch.sin(theta_safe))\n                    / (theta_squared_safe * theta_safe)\n                )[..., None, None]\n                * torch.einsum(\"...ij,...jk->...ik\", skew_omega, skew_omega)\n            ),\n        )\n\n        return SE3.from_rotation_and_translation(\n            rotation=rotation,\n            translation=torch.einsum(\"...ij,...j->...i\", V, tangent[..., :3]),\n        )\n\n    @override\n    def log(self) -> Tensor:\n        # Reference:\n        # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se3.hpp#L223\n        omega = self.rotation().log()\n        theta_squared = torch.square(omega).sum(dim=-1)  # (*)\n        use_taylor = theta_squared < get_epsilon(theta_squared.dtype)\n\n        skew_omega = _skew(omega)\n\n        # Shim to avoid NaNs in torch.where branches, which cause failures for\n        # reverse-mode AD.\n        theta_squared_safe = torch.where(\n            use_taylor,\n            1.0,  # Any non-zero value should do here.\n            theta_squared,\n        )\n        del theta_squared\n        theta_safe = torch.sqrt(theta_squared_safe)\n        half_theta_safe = theta_safe / 2.0\n\n        dtype = omega.dtype\n        device = omega.device\n        V_inv = torch.where(\n            use_taylor[..., None, None],\n            torch.eye(3, device=device, dtype=dtype)\n            - 0.5 * skew_omega\n            + torch.matmul(skew_omega, skew_omega) / 12.0,\n            (\n                torch.eye(3, device=device, dtype=dtype)\n                - 0.5 * skew_omega\n                + (\n                    1.0\n                    - theta_safe\n                    * torch.cos(half_theta_safe)\n                    / (2.0 * torch.sin(half_theta_safe))\n                )[..., None, None]\n                / theta_squared_safe[..., None, None]\n                * torch.matmul(skew_omega, skew_omega)\n            ),\n        )\n        return torch.cat(\n            [torch.einsum(\"...ij,...j->...i\", V_inv, self.translation()), omega], dim=-1\n        )\n\n    @override\n    def adjoint(self) -> Tensor:\n        R = self.rotation().as_matrix()\n        dims = R.shape[:-2]\n        # (*, 6, 6)\n        return torch.cat(\n            [\n                torch.cat([R, torch.matmul(_skew(self.translation()), R)], dim=-1),\n                torch.cat(\n                    [torch.zeros((*dims, 3, 3), device=R.device, dtype=R.dtype), R],\n                    dim=-1,\n                ),\n            ],\n            dim=-2,\n        )\n"
  },
  {
    "path": "src/egoallo/transforms/_so3.py",
    "content": "from __future__ import annotations\n\nfrom dataclasses import dataclass\nfrom typing import Union, override\n\nimport numpy as np\nimport torch\nfrom torch import Tensor\n\nfrom . import _base\nfrom .utils import get_epsilon, register_lie_group\n\n\n@register_lie_group(\n    matrix_dim=3,\n    parameters_dim=4,\n    tangent_dim=3,\n    space_dim=3,\n)\n@dataclass(frozen=True)\nclass SO3(_base.SOBase):\n    \"\"\"Special orthogonal group for 3D rotations.\n\n    Internal parameterization is `(qw, qx, qy, qz)`. Tangent parameterization is\n    `(omega_x, omega_y, omega_z)`.\n    \"\"\"\n\n    # SO3-specific.\n\n    wxyz: Tensor\n    \"\"\"Internal parameters. `(w, x, y, z)` quaternion.\"\"\"\n\n    @override\n    def __repr__(self) -> str:\n        wxyz = np.round(self.wxyz.numpy(force=True), 5)\n        return f\"{self.__class__.__name__}(wxyz={wxyz})\"\n\n    @staticmethod\n    def from_x_radians(theta: Tensor) -> SO3:\n        \"\"\"Generates a x-axis rotation.\n\n        Args:\n            angle: X rotation, in radians.\n\n        Returns:\n            Output.\n        \"\"\"\n        zeros = torch.zeros_like(theta)\n        return SO3.exp(torch.stack([theta, zeros, zeros], dim=-1))\n\n    @staticmethod\n    def from_y_radians(theta: Tensor) -> SO3:\n        \"\"\"Generates a y-axis rotation.\n\n        Args:\n            angle: Y rotation, in radians.\n\n        Returns:\n            Output.\n        \"\"\"\n        zeros = torch.zeros_like(theta)\n        return SO3.exp(torch.stack([zeros, theta, zeros], dim=-1))\n\n    @staticmethod\n    def from_z_radians(theta: Tensor) -> SO3:\n        \"\"\"Generates a z-axis rotation.\n\n        Args:\n            angle: Z rotation, in radians.\n\n        Returns:\n            Output.\n        \"\"\"\n        zeros = torch.zeros_like(theta)\n        return SO3.exp(torch.stack([zeros, zeros, theta], dim=-1))\n\n    @staticmethod\n    def from_rpy_radians(\n        roll: Tensor,\n        pitch: Tensor,\n        yaw: Tensor,\n    ) -> SO3:\n        \"\"\"Generates a transform from a set of Euler angles. Uses the ZYX mobile robot\n        convention.\n\n        Args:\n            roll: X rotation, in radians. Applied first.\n            pitch: Y rotation, in radians. Applied second.\n            yaw: Z rotation, in radians. Applied last.\n\n        Returns:\n            Output.\n        \"\"\"\n        return (\n            SO3.from_z_radians(yaw)\n            @ SO3.from_y_radians(pitch)\n            @ SO3.from_x_radians(roll)\n        )\n\n    @staticmethod\n    def from_quaternion_xyzw(xyzw: Tensor) -> SO3:\n        \"\"\"Construct a rotation from an `xyzw` quaternion.\n\n        Note that `wxyz` quaternions can be constructed using the default dataclass\n        constructor.\n\n        Args:\n            xyzw: xyzw quaternion. 
Shape should be (4,).\n\n        Returns:\n            Output.\n        \"\"\"\n        assert xyzw.shape == (4,)\n        return SO3(torch.roll(xyzw, shifts=1, dims=-1))\n\n    def as_quaternion_xyzw(self) -> Tensor:\n        \"\"\"Grab parameters as xyzw quaternion.\"\"\"\n        return torch.roll(self.wxyz, shifts=-1, dims=-1)\n\n    # Factory.\n\n    @classmethod\n    @override\n    def identity(cls, device: Union[torch.device, str], dtype: torch.dtype) -> SO3:\n        return SO3(wxyz=torch.tensor([1.0, 0.0, 0.0, 0.0], device=device, dtype=dtype))\n\n    @classmethod\n    @override\n    def from_matrix(cls, matrix: Tensor) -> SO3:\n        assert matrix.shape[-2:] == (3, 3)\n\n        # Modified from:\n        # > \"Converting a Rotation Matrix to a Quaternion\" from Mike Day\n        # > https://d3cw3dd2w32x2b.cloudfront.net/wp-content/uploads/2015/01/matrix-to-quat.pdf\n\n        def case0(m):\n            t = 1 + m[..., 0, 0] - m[..., 1, 1] - m[..., 2, 2]\n            q = torch.stack(\n                [\n                    m[..., 2, 1] - m[..., 1, 2],\n                    t,\n                    m[..., 1, 0] + m[..., 0, 1],\n                    m[..., 0, 2] + m[..., 2, 0],\n                ],\n                dim=-1,\n            )\n            return t, q\n\n        def case1(m):\n            t = 1 - m[..., 0, 0] + m[..., 1, 1] - m[..., 2, 2]\n            q = torch.stack(\n                [\n                    m[..., 0, 2] - m[..., 2, 0],\n                    m[..., 1, 0] + m[..., 0, 1],\n                    t,\n                    m[..., 2, 1] + m[..., 1, 2],\n                ],\n                dim=-1,\n            )\n            return t, q\n\n        def case2(m):\n            t = 1 - m[..., 0, 0] - m[..., 1, 1] + m[..., 2, 2]\n            q = torch.stack(\n                [\n                    m[..., 1, 0] - m[..., 0, 1],\n                    m[..., 0, 2] + m[..., 2, 0],\n                    m[..., 2, 1] + m[..., 1, 2],\n                    t,\n                ],\n                dim=-1,\n            )\n            return t, q\n\n        def case3(m):\n            t = 1 + m[..., 0, 0] + m[..., 1, 1] + m[..., 2, 2]\n            q = torch.stack(\n                [\n                    t,\n                    m[..., 2, 1] - m[..., 1, 2],\n                    m[..., 0, 2] - m[..., 2, 0],\n                    m[..., 1, 0] - m[..., 0, 1],\n                ],\n                dim=-1,\n            )\n            return t, q\n\n        # Compute four cases, then pick the most precise one.\n        # Probably worth revisiting this!\n        case0_t, case0_q = case0(matrix)\n        case1_t, case1_q = case1(matrix)\n        case2_t, case2_q = case2(matrix)\n        case3_t, case3_q = case3(matrix)\n\n        cond0 = matrix[..., 2, 2] < 0\n        cond1 = matrix[..., 0, 0] > matrix[..., 1, 1]\n        cond2 = matrix[..., 0, 0] < -matrix[..., 1, 1]\n\n        t = torch.where(\n            cond0,\n            torch.where(cond1, case0_t, case1_t),\n            torch.where(cond2, case2_t, case3_t),\n        )\n        q = torch.where(\n            cond0[..., None],\n            torch.where(cond1[..., None], case0_q, case1_q),\n            torch.where(cond2[..., None], case2_q, case3_q),\n        )\n        return SO3(wxyz=q * 0.5 / torch.sqrt(t[..., None]))\n\n    # Accessors.\n\n    @override\n    def as_matrix(self) -> Tensor:\n        norm_sq = torch.square(self.wxyz).sum(dim=-1, keepdim=True)\n        qvec = self.wxyz * torch.sqrt(2.0 / norm_sq)  # (*, 4)\n        Q = 
torch.einsum(\"...i,...j->...ij\", qvec, qvec)  # (*, 4, 4)\n        return torch.stack(\n            [\n                1.0 - Q[..., 2, 2] - Q[..., 3, 3],\n                Q[..., 1, 2] - Q[..., 3, 0],\n                Q[..., 1, 3] + Q[..., 2, 0],\n                Q[..., 1, 2] + Q[..., 3, 0],\n                1.0 - Q[..., 1, 1] - Q[..., 3, 3],\n                Q[..., 2, 3] - Q[..., 1, 0],\n                Q[..., 1, 3] - Q[..., 2, 0],\n                Q[..., 2, 3] + Q[..., 1, 0],\n                1.0 - Q[..., 1, 1] - Q[..., 2, 2],\n            ],\n            dim=-1,\n        ).reshape(*qvec.shape[:-1], 3, 3)\n\n    @override\n    def parameters(self) -> Tensor:\n        return self.wxyz\n\n    # Operations.\n\n    @override\n    def apply(self, target: Tensor) -> Tensor:\n        assert target.shape[-1] == 3\n\n        # Compute using quaternion multiplys.\n        padded_target = torch.cat([torch.ones_like(target[..., :1]), target], dim=-1)\n        out = self.multiply(SO3(wxyz=padded_target).multiply(self.inverse()))\n        return out.wxyz[..., 1:]\n\n    @override\n    def multiply(self, other: SO3) -> SO3:  # type: ignore\n        w0, x0, y0, z0 = self.wxyz.unbind(dim=-1)\n        w1, x1, y1, z1 = other.wxyz.unbind(dim=-1)\n        wxyz2 = torch.stack(\n            [\n                -x0 * x1 - y0 * y1 - z0 * z1 + w0 * w1,\n                x0 * w1 + y0 * z1 - z0 * y1 + w0 * x1,\n                -x0 * z1 + y0 * w1 + z0 * x1 + w0 * y1,\n                x0 * y1 - y0 * x1 + z0 * w1 + w0 * z1,\n            ],\n            dim=-1,\n        )\n\n        return SO3(wxyz=wxyz2)\n\n    @classmethod\n    @override\n    def exp(cls, tangent: Tensor) -> SO3:\n        # Reference:\n        # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/so3.hpp#L583\n\n        assert tangent.shape[-1] == 3\n\n        theta_squared = torch.square(tangent).sum(dim=-1)  # (*)\n        theta_pow_4 = theta_squared * theta_squared\n        use_taylor = theta_squared < get_epsilon(tangent.dtype)\n\n        safe_theta = torch.sqrt(\n            torch.where(\n                use_taylor,\n                torch.ones_like(theta_squared),  # Any constant value should do here.\n                theta_squared,\n            )\n        )\n        safe_half_theta = 0.5 * safe_theta\n\n        real_factor = torch.where(\n            use_taylor,\n            1.0 - theta_squared / 8.0 + theta_pow_4 / 384.0,\n            torch.cos(safe_half_theta),\n        )\n\n        imaginary_factor = torch.where(\n            use_taylor,\n            0.5 - theta_squared / 48.0 + theta_pow_4 / 3840.0,\n            torch.sin(safe_half_theta) / safe_theta,\n        )\n\n        return SO3(\n            wxyz=torch.cat(\n                [\n                    real_factor[..., None],\n                    imaginary_factor[..., None] * tangent,\n                ],\n                dim=-1,\n            )\n        )\n\n    @override\n    def log(self) -> Tensor:\n        # Reference:\n        # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/so3.hpp#L247\n\n        w, xyz = torch.split(self.wxyz, [1, 3], dim=-1)  # (*, 1), (*, 3)\n        norm_sq = torch.square(xyz).sum(dim=-1, keepdim=True)  # (*, 1)\n        use_taylor = norm_sq < get_epsilon(norm_sq.dtype)\n\n        norm_safe = torch.sqrt(\n            torch.where(\n                use_taylor,\n                torch.ones_like(norm_sq),  # Any non-zero value should do here.\n                norm_sq,\n            )\n      
  )\n        w_safe = torch.where(use_taylor, w, torch.ones_like(w))\n        atan_n_over_w = torch.atan2(\n            torch.where(w < 0, -norm_safe, norm_safe),\n            torch.abs(w),\n        )\n        atan_factor = torch.where(\n            use_taylor,\n            2.0 / w_safe - 2.0 / 3.0 * norm_sq / w_safe**3,\n            torch.where(\n                torch.abs(w) < get_epsilon(w.dtype),\n                torch.where(w > 0, 1.0, -1.0) * torch.pi / norm_safe,\n                2.0 * atan_n_over_w / norm_safe,\n            ),\n        )\n\n        return atan_factor * xyz\n\n    @override\n    def adjoint(self) -> Tensor:\n        return self.as_matrix()\n\n    @override\n    def inverse(self) -> SO3:\n        # Negate complex terms.\n        w, xyz = torch.split(self.wxyz, [1, 3], dim=-1)\n        return SO3(wxyz=torch.cat([w, -xyz], dim=-1))\n\n    @override\n    def normalize(self) -> SO3:\n        return SO3(wxyz=self.wxyz / torch.linalg.norm(self.wxyz, dim=-1, keepdim=True))\n"
  },
  {
    "path": "src/egoallo/transforms/utils/__init__.py",
    "content": "from ._utils import get_epsilon, register_lie_group\n\n__all__ = [\"get_epsilon\", \"register_lie_group\"]\n"
  },
  {
    "path": "src/egoallo/transforms/utils/_utils.py",
    "content": "from typing import TYPE_CHECKING, Callable, Type, TypeVar\n\nimport torch\n\nif TYPE_CHECKING:\n    from .._base import MatrixLieGroup\n\n\nT = TypeVar(\"T\", bound=\"MatrixLieGroup\")\n\n\ndef get_epsilon(dtype: torch.dtype) -> float:\n    \"\"\"Helper for grabbing type-specific precision constants.\n\n    Args:\n        dtype: Datatype.\n\n    Returns:\n        Output float.\n    \"\"\"\n    return {\n        torch.float32: 1e-5,\n        torch.float64: 1e-10,\n    }[dtype]\n\n\ndef register_lie_group(\n    *,\n    matrix_dim: int,\n    parameters_dim: int,\n    tangent_dim: int,\n    space_dim: int,\n) -> Callable[[Type[T]], Type[T]]:\n    \"\"\"Decorator for registering Lie group dataclasses.\n\n    Sets dimensionality class variables.\n    \"\"\"\n\n    def _wrap(cls: Type[T]) -> Type[T]:\n        # Register dimensions as class attributes.\n        cls.matrix_dim = matrix_dim\n        cls.parameters_dim = parameters_dim\n        cls.tangent_dim = tangent_dim\n        cls.space_dim = space_dim\n\n        return cls\n\n    return _wrap\n"
  },
  {
    "path": "src/egoallo/vis_helpers.py",
    "content": "import time\nfrom pathlib import Path\nfrom typing import Callable, TypedDict\n\nimport numpy as np\nimport numpy.typing as npt\nimport torch\nimport trimesh\nimport viser\nimport viser.transforms as vtf\nfrom jaxtyping import Float\nfrom plyfile import PlyData\nfrom torch import Tensor\n\nfrom . import fncsmpl, fncsmpl_extensions, network\nfrom .hand_detection_structs import (\n    CorrespondedAriaHandWristPoseDetections,\n    CorrespondedHamerDetections,\n)\nfrom .transforms import SE3, SO3\n\n\nclass SplatArgs(TypedDict):\n    centers: npt.NDArray[np.floating]\n    \"\"\"(N, 3).\"\"\"\n    rgbs: npt.NDArray[np.floating]\n    \"\"\"(N, 3). Range [0, 1].\"\"\"\n    opacities: npt.NDArray[np.floating]\n    \"\"\"(N, 1). Range [0, 1].\"\"\"\n    covariances: npt.NDArray[np.floating]\n    \"\"\"(N, 3, 3).\"\"\"\n\n\ndef load_splat_file(splat_path: Path, center: bool = False) -> SplatArgs:\n    \"\"\"Load an antimatter15-style splat file.\"\"\"\n    start_time = time.time()\n    splat_buffer = splat_path.read_bytes()\n    bytes_per_gaussian = (\n        # Each Gaussian is serialized as:\n        # - position (vec3, float32)\n        3 * 4\n        # - xyz (vec3, float32)\n        + 3 * 4\n        # - rgba (vec4, uint8)\n        + 4\n        # - ijkl (vec4, uint8), where 0 => -1, 255 => 1.\n        + 4\n    )\n    assert len(splat_buffer) % bytes_per_gaussian == 0\n    num_gaussians = len(splat_buffer) // bytes_per_gaussian\n\n    # Reinterpret cast to dtypes that we want to extract.\n    splat_uint8 = np.frombuffer(splat_buffer, dtype=np.uint8).reshape(\n        (num_gaussians, bytes_per_gaussian)\n    )\n    scales = splat_uint8[:, 12:24].copy().view(np.float32)\n    wxyzs = splat_uint8[:, 28:32] / 255.0 * 2.0 - 1.0\n    Rs = vtf.SO3(wxyzs).as_matrix()\n    covariances = np.einsum(\n        \"nij,njk,nlk->nil\", Rs, np.eye(3)[None, :, :] * scales[:, None, :] ** 2, Rs\n    )\n    centers = splat_uint8[:, 0:12].copy().view(np.float32)\n    if center:\n        centers -= np.mean(centers, axis=0, keepdims=True)\n    print(\n        f\"Splat file with {num_gaussians=} loaded in {time.time() - start_time} seconds\"\n    )\n    return {\n        \"centers\": centers,\n        # Colors should have shape (N, 3).\n        \"rgbs\": splat_uint8[:, 24:27] / 255.0,\n        \"opacities\": splat_uint8[:, 27:28] / 255.0,\n        # Covariances should have shape (N, 3, 3).\n        \"covariances\": covariances,\n    }\n\n\ndef load_ply_file(ply_file_path: Path, center: bool = False) -> SplatArgs:\n    \"\"\"Load Gaussians stored in a PLY file.\"\"\"\n    start_time = time.time()\n\n    SH_C0 = 0.28209479177387814\n\n    plydata = PlyData.read(ply_file_path)\n    v = plydata[\"vertex\"]\n    positions = np.stack([v[\"x\"], v[\"y\"], v[\"z\"]], axis=-1)\n    scales = np.exp(np.stack([v[\"scale_0\"], v[\"scale_1\"], v[\"scale_2\"]], axis=-1))\n    wxyzs = np.stack([v[\"rot_0\"], v[\"rot_1\"], v[\"rot_2\"], v[\"rot_3\"]], axis=1)\n    colors = 0.5 + SH_C0 * np.stack([v[\"f_dc_0\"], v[\"f_dc_1\"], v[\"f_dc_2\"]], axis=1)\n    opacities = 1.0 / (1.0 + np.exp(-v[\"opacity\"][:, None]))\n\n    Rs = vtf.SO3(wxyzs).as_matrix()\n    covariances = np.einsum(\n        \"nij,njk,nlk->nil\", Rs, np.eye(3)[None, :, :] * scales[:, None, :] ** 2, Rs\n    )\n    if center:\n        positions -= np.mean(positions, axis=0, keepdims=True)\n\n    num_gaussians = len(v)\n    print(\n        f\"PLY file with {num_gaussians=} loaded in {time.time() - start_time} seconds\"\n    )\n    return {\n        \"centers\": 
positions,\n        \"rgbs\": colors,\n        \"opacities\": opacities,\n        \"covariances\": covariances,\n    }\n\n\ndef add_splat_to_viser(\n    splat_or_ply_path: Path, server: viser.ViserServer, z_offset: float = 0.0\n) -> None:\n    \"\"\"Add some Gaussian splats to the Viser server.\"\"\"\n    if splat_or_ply_path.suffix.lower() == \".ply\":\n        splat_args = load_ply_file(splat_or_ply_path)\n    elif splat_or_ply_path.suffix.lower() == \".splat\":\n        splat_args = load_splat_file(splat_or_ply_path)\n    else:\n        assert False\n    server.scene.add_gaussian_splats(\n        \"/gaussian_splats\",\n        centers=splat_args[\"centers\"],\n        rgbs=splat_args[\"rgbs\"],\n        opacities=splat_args[\"opacities\"],\n        covariances=splat_args[\"covariances\"],\n        position=(0.0, 0.0, z_offset),\n    )\n\n\ndef visualize_traj_and_hand_detections(\n    server: viser.ViserServer,\n    Ts_world_cpf: Float[Tensor, \"timesteps 7\"],\n    traj: network.EgoDenoiseTraj | None,\n    body_model: fncsmpl.SmplhModel,\n    hamer_detections: CorrespondedHamerDetections | None = None,\n    aria_detections: CorrespondedAriaHandWristPoseDetections | None = None,\n    points_data: np.ndarray | None = None,\n    splat_path: Path | None = None,\n    floor_z: float = 0.0,\n    show_joints: bool = False,\n    get_ego_video: Callable[[int, int, float], bytes] | None = None,\n) -> Callable[[], int]:\n    \"\"\"Chaotic mega-function for visualization. Returns a callback that should\n    be called repeatedly in a loop.\"\"\"\n\n    timesteps = Ts_world_cpf.shape[0]\n\n    server.scene.add_grid(\n        \"/ground\",\n        plane=\"xy\",\n        cell_color=(80, 80, 80),\n        section_color=(50, 50, 50),\n        position=(0.0, 0.0, floor_z),\n    )\n\n    if points_data is not None:\n        point_cloud = server.scene.add_point_cloud(\n            \"/aria_points\",\n            points=points_data,\n            colors=np.cos(points_data + np.arange(3)) / 3.0\n            + 0.7,  # Make points colorful :)\n            point_size=0.01,\n            # point_size=0.1,\n            point_shape=\"sparkle\",\n        )\n        size_slider = server.gui.add_slider(\n            \"Point cloud size\", min=0.001, max=0.05, step=0.001, initial_value=0.005\n        )\n\n        @size_slider.on_update\n        def _(_) -> None:\n            if point_cloud is not None:\n                point_cloud.point_size = size_slider.value\n\n    if splat_path is not None:\n        add_splat_to_viser(splat_path, server)  # , z_offset=-floor_z)\n\n    if traj is not None:\n        betas = traj.betas\n        timesteps = betas.shape[1]\n        sample_count = betas.shape[0]\n        assert betas.shape == (sample_count, timesteps, 16)\n        body_quats = SO3.from_matrix(traj.body_rotmats).wxyz\n        assert body_quats.shape == (sample_count, timesteps, 21, 4)\n        device = body_quats.device\n\n        if traj.hand_rotmats is not None:\n            hand_quats = SO3.from_matrix(traj.hand_rotmats).wxyz\n            left_hand_quats = hand_quats[..., :15, :]\n            right_hand_quats = hand_quats[..., 15:30, :]\n        else:\n            left_hand_quats = None\n            right_hand_quats = None\n\n        shaped = body_model.with_shape(torch.mean(betas, dim=1, keepdim=True))\n        fk_outputs = shaped.with_pose_decomposed(\n            T_world_root=SE3.identity(\n                device=device, dtype=body_quats.dtype\n            ).parameters(),\n            body_quats=body_quats,\n            
left_hand_quats=left_hand_quats,\n            right_hand_quats=right_hand_quats,\n        )\n\n        assert Ts_world_cpf.shape == (timesteps, 7)\n        T_world_root = fncsmpl_extensions.get_T_world_root_from_cpf_pose(\n            # Batch axes of fk_outputs are (num_samples, time).\n            # Batch axes of Ts_world_cpf are (time,).\n            fk_outputs,\n            Ts_world_cpf[None, ...],\n        )\n        fk_outputs = fk_outputs.with_new_T_world_root(T_world_root)\n    else:\n        shaped = None\n        fk_outputs = None\n        sample_count = 0\n\n    glasses_mesh = trimesh.load(\"./data/glasses.stl\")\n    assert isinstance(glasses_mesh, trimesh.Trimesh)\n    glasses_mesh.visual.face_colors = [10, 20, 20, 255]  # type: ignore\n\n    cpf_handle = server.scene.add_frame(\n        \"/cpf\",\n        show_axes=True,\n        axes_length=0.05,\n        axes_radius=0.004,\n    )\n    server.scene.add_mesh_trimesh(\"/cpf/glasses\", glasses_mesh, scale=0.001 * 1.05)\n\n    # TODO: remove\n    # hamer_detections = None\n    # aria_detections = None\n\n    joint_position_handles: list[viser.SceneNodeHandle] = []\n    timestep_handles: list[viser.FrameHandle] = []\n    hamer_handles: list[viser.MeshHandle | viser.PointCloudHandle] = []\n    aria_handles: list[viser.SceneNodeHandle] = []\n    for t in range(Ts_world_cpf.shape[0]):\n        timestep_handles.append(\n            server.scene.add_frame(f\"/timesteps/{t}\", show_axes=False)\n        )\n\n        # Joints.\n        if show_joints and fk_outputs is not None:\n            assert traj is not None\n            for j in range(sample_count):\n                joints_colors = np.zeros((21, 3))\n                joints_colors[:, 0] = traj.contacts[j, t, :].numpy(force=True)\n                joints_colors[:, 2] = 1.0 - traj.contacts[j, t, :].numpy(force=True)\n                joint_position_handles.append(\n                    server.scene.add_point_cloud(\n                        f\"/timesteps/{t}/joints\",\n                        points=fk_outputs.Ts_world_joint[j, t, :21, 4:7].numpy(\n                            force=True\n                        ),\n                        colors=joints_colors,\n                        point_shape=\"circle\",\n                        point_size=0.02,\n                    )\n                )\n\n        # Visualize HaMeR outputs.\n        if hamer_detections is not None:\n            T_world_cam = SE3(Ts_world_cpf[t]) @ SE3(hamer_detections.T_cpf_cam)\n            server.scene.add_frame(\n                f\"/timesteps/{t}/cpf/cam\",\n                show_axes=True,\n                axes_length=0.025,\n                axes_radius=0.003,\n                wxyz=T_world_cam.wxyz_xyz[..., :4].numpy(force=True),\n                position=T_world_cam.wxyz_xyz[..., 4:7].numpy(force=True),\n            )\n            hands_l = hamer_detections.detections_left_tuple[t]\n            hands_r = hamer_detections.detections_right_tuple[t]\n            if hands_l is not None:\n                for j in range(hands_l[\"verts\"].shape[0]):\n                    hamer_handles.append(\n                        server.scene.add_mesh_simple(\n                            f\"/timesteps/{t}/cpf/cam/left_hand{j}\",\n                            vertices=hands_l[\"verts\"][j],\n                            faces=hamer_detections.mano_faces_left.numpy(force=True),\n                            visible=False,\n                        )\n                    )\n                    hamer_handles.append(\n                        
server.scene.add_point_cloud(\n                            f\"/timesteps/{t}/cpf/cam/left_keypoints3d\",\n                            points=hands_l[\"keypoints_3d\"][j],\n                            colors=(255, 127, 0),\n                            point_size=0.008,\n                            point_shape=\"square\",\n                            visible=False,\n                        )\n                    )\n            if hands_r is not None:\n                for j in range(hands_r[\"verts\"].shape[0]):\n                    hamer_handles.append(\n                        server.scene.add_mesh_simple(\n                            f\"/timesteps/{t}/cpf/cam/right_hand{j}\",\n                            vertices=hands_r[\"verts\"][j],\n                            faces=hamer_detections.mano_faces_right.numpy(force=True),\n                            visible=False,\n                        )\n                    )\n                    hamer_handles.append(\n                        server.scene.add_point_cloud(\n                            f\"/timesteps/{t}/cpf/cam/right_keypoints3d\",\n                            points=hands_r[\"keypoints_3d\"][j],\n                            colors=(0, 127, 255),\n                            point_size=0.008,\n                            point_shape=\"square\",\n                            visible=False,\n                        )\n                    )\n\n        # Visualize Aria detections.\n        if aria_detections is not None:\n            for side in (\"left\", \"right\"):\n                detections = {\n                    \"left\": aria_detections.detections_left_concat,\n                    \"right\": aria_detections.detections_right_concat,\n                }[side]\n                if detections is None:\n                    continue\n                indices = detections.indices\n                index = torch.searchsorted(indices, t)\n                if index < len(indices) and indices[index] == t:  # found?\n                    aria_handles.append(\n                        server.scene.add_spline_catmull_rom(\n                            f\"/timesteps/{t}/aria_detections/{side}\",\n                            np.array(\n                                [\n                                    detections.wrist_position[index].numpy(force=True),\n                                    detections.palm_position[index].numpy(force=True),\n                                ]\n                            ),\n                            line_width=3.0,\n                            color=(255, 0, 0) if side == \"left\" else (0, 255, 0),\n                            visible=False,\n                        )\n                    )\n\n    body_handles = (\n        [\n            server.scene.add_mesh_skinned(\n                f\"/persons/{i}\",\n                vertices=shaped.verts_zero[i, 0, :, :].numpy(force=True),\n                faces=body_model.faces.numpy(force=True),\n                bone_wxyzs=vtf.SO3.identity(\n                    batch_axes=(body_model.get_num_joints() + 1,)\n                ).wxyz,\n                bone_positions=np.concatenate(\n                    [\n                        np.zeros((1, 3)),\n                        # Indices are (batch, time, joint, positions).\n                        shaped.joints_zero[i, :, :, :]\n                        .numpy(force=True)\n                        .squeeze(axis=0),\n                    ],\n                    axis=0,\n                ),\n                color=(152, 93, 229),\n               
 skin_weights=body_model.weights.numpy(force=True),\n            )\n            for i in range(sample_count)\n        ]\n        if shaped is not None\n        else []\n    )\n\n    gui_attach = server.gui.add_checkbox(\"Attach camera to CPF\", initial_value=False)\n    gui_attach_dist = server.gui.add_number(\"Attach distance\", initial_value=0.3)\n    gui_show_body = server.gui.add_checkbox(\"Show body\", initial_value=True)\n    gui_show_glasses = server.gui.add_checkbox(\"Show glasses\", initial_value=True)\n    gui_show_cpf_axes = server.gui.add_checkbox(\"Show CPF axes\", initial_value=False)\n    gui_wireframe = server.gui.add_checkbox(\"Wireframe\", initial_value=False)\n    gui_smpl_opacity = server.gui.add_slider(\n        \"SMPL Opacity\", initial_value=1.0, min=0.0, max=1.0, step=0.01\n    )\n    gui_hamer_opacity = server.gui.add_slider(\n        \"HaMeR Opacity\", initial_value=1.0, min=0.0, max=1.0, step=0.01\n    )\n\n    @gui_smpl_opacity.on_update\n    def _(_) -> None:\n        for handle in body_handles:\n            handle.opacity = gui_smpl_opacity.value\n\n    @gui_hamer_opacity.on_update\n    def _(_) -> None:\n        for handle in hamer_handles:\n            if isinstance(handle, viser.MeshHandle):\n                handle.opacity = gui_hamer_opacity.value\n\n    gui_show_hamer_hands = server.gui.add_checkbox(\n        \"Show HaMeR hands\", initial_value=False\n    )\n    gui_show_aria_hands = server.gui.add_checkbox(\n        \"Show wrist detections\", initial_value=False\n    )\n    gui_body_color = server.gui.add_rgb(\"Body color\", initial_value=(152, 93, 229))\n\n    if show_joints:\n        gui_show_joints = server.gui.add_checkbox(\"Show joints\", initial_value=True)\n\n        @gui_show_joints.on_update\n        def _(_) -> None:\n            for handle in joint_position_handles:\n                handle.visible = gui_show_joints.value\n\n    @gui_show_body.on_update\n    def _(_) -> None:\n        for handle in body_handles:\n            handle.visible = gui_show_body.value\n\n    @gui_show_glasses.on_update\n    def _(_) -> None:\n        # The glasses are a child of the CPF frame.\n        cpf_handle.visible = gui_show_glasses.value\n\n    @gui_show_cpf_axes.on_update\n    def _(_) -> None:\n        cpf_handle.show_axes = gui_show_cpf_axes.value\n\n    @gui_wireframe.on_update\n    def _(_) -> None:\n        for handle in body_handles:\n            handle.wireframe = gui_wireframe.value\n\n    @gui_show_hamer_hands.on_update\n    def _(_) -> None:\n        for handle in hamer_handles:\n            handle.visible = gui_show_hamer_hands.value\n\n    @gui_show_aria_hands.on_update\n    def _(_) -> None:\n        for handle in aria_handles:\n            handle.visible = gui_show_aria_hands.value\n\n    @gui_body_color.on_update\n    def _(_) -> None:\n        for handle in body_handles:\n            handle.color = gui_body_color.value\n\n    # Add playback UI.\n    with server.gui.add_folder(\"Playback\"):\n        gui_timestep = server.gui.add_slider(\n            \"Timestep\",\n            min=0,\n            max=timesteps - 1,\n            step=1,\n            initial_value=0,\n            disabled=True,\n        )\n        gui_start_end = server.gui.add_multi_slider(\n            \"Start/end\",\n            min=0,\n            max=timesteps - 1,\n            initial_value=(0, timesteps - 1),\n            step=1,\n        )\n        gui_next_frame = server.gui.add_button(\"Next Frame\", disabled=True)\n        gui_prev_frame = server.gui.add_button(\"Prev 
Frame\", disabled=True)\n        gui_playing = server.gui.add_checkbox(\"Playing\", True)\n        gui_framerate = server.gui.add_slider(\n            \"FPS\", min=1, max=60, step=0.1, initial_value=15\n        )\n        gui_framerate_options = server.gui.add_button_group(\n            \"FPS options\", (\"10\", \"20\", \"30\", \"60\")\n        )\n\n    # Frame step buttons.\n    @gui_next_frame.on_click\n    def _(_) -> None:\n        gui_timestep.value = (gui_timestep.value + 1) % timesteps\n\n    @gui_prev_frame.on_click\n    def _(_) -> None:\n        gui_timestep.value = (gui_timestep.value - 1) % timesteps\n\n    # Disable frame controls when we're playing.\n    @gui_playing.on_update\n    def _(_) -> None:\n        gui_timestep.disabled = gui_playing.value\n        gui_next_frame.disabled = gui_playing.value\n        gui_prev_frame.disabled = gui_playing.value\n\n    # Set the framerate when we click one of the options.\n    @gui_framerate_options.on_click\n    def _(_) -> None:\n        gui_framerate.value = int(gui_framerate_options.value)\n\n    Ts_world_cpf_numpy = Ts_world_cpf.numpy(force=True)\n\n    def do_update() -> None:\n        t = gui_timestep.value\n        cpf_handle.wxyz = Ts_world_cpf_numpy[t, :4]\n        cpf_handle.position = Ts_world_cpf_numpy[t, 4:7]\n\n        if gui_attach.value:\n            for client in server.get_clients().values():\n                client.camera.wxyz = (\n                    vtf.SO3(cpf_handle.wxyz) @ vtf.SO3.from_z_radians(np.pi)\n                ).wxyz\n                client.camera.position = cpf_handle.position - vtf.SO3(\n                    cpf_handle.wxyz\n                ) @ np.array([0.0, 0.0, gui_attach_dist.value])\n\n        if fk_outputs is not None:\n            for i in range(sample_count):\n                for b, bone_handle in enumerate(body_handles[i].bones):\n                    if b == 0:\n                        bone_transform = fk_outputs.T_world_root[i, t].numpy(force=True)\n                    else:\n                        bone_transform = fk_outputs.Ts_world_joint[i, t, b - 1].numpy(\n                            force=True\n                        )\n                    bone_handle.wxyz = bone_transform[:4]\n                    bone_handle.position = bone_transform[4:7]\n\n        for ii, timestep_frame in enumerate(timestep_handles):\n            timestep_frame.visible = t == ii\n\n    get_viser_file = server.gui.add_button(\"Get .viser file\")\n\n    if get_ego_video is not None:\n        ego_video = server.gui.add_button(\"Get Ego Video\")\n\n        @ego_video.on_click\n        def _(event: viser.GuiEvent) -> None:\n            assert event.client is not None\n            notif = event.client.add_notification(\n                \"Getting video...\", body=\"\", loading=True, with_close_button=False\n            )\n            ego_video_bytes = get_ego_video(\n                gui_start_end.value[0],\n                gui_start_end.value[1],\n                (gui_start_end.value[1] - gui_start_end.value[0]) / gui_framerate.value,\n            )\n            notif.remove()\n            event.client.send_file_download(\"ego_video.mp4\", ego_video_bytes)\n\n    prev_time = time.time()\n    handle = None\n\n    def loop_cb() -> int:\n        start, end = gui_start_end.value\n        duration = end - start\n\n        if get_viser_file.value is False:\n            nonlocal prev_time\n            now = time.time()\n            sleepdur = 1.0 / gui_framerate.value - (now - prev_time)\n            if sleepdur > 0.0:\n      
          time.sleep(sleepdur)\n            prev_time = now\n            if gui_playing.value:\n                gui_timestep.value = (gui_timestep.value + 1 - start) % duration + start\n            do_update()\n            return gui_timestep.value\n        else:\n            # Save trajectory.\n            nonlocal handle\n            if handle is None:\n                handle = server._start_scene_recording()\n                handle.set_loop_start()\n                gui_timestep.value = start\n\n            assert handle is not None\n            handle.insert_sleep(1.0 / gui_framerate.value)\n            gui_timestep.value = (gui_timestep.value + 1 - start) % duration + start\n\n            if gui_timestep.value == start:\n                get_viser_file.value = False\n                server.send_file_download(\n                    \"recording.viser\", content=handle.end_and_serialize()\n                )\n                handle = None\n\n            do_update()\n\n            return gui_timestep.value\n\n    return loop_cb\n"
  }
]