Repository: kbrodt/sketch2pose
Branch: main
Commit: c027b65219b1
Files: 20
Total size: 145.0 KB

Directory structure:
gitextract__nwva2u6/

├── README.md
├── patches/
│   ├── selfcontact.diff
│   ├── smplx.diff
│   └── torchgeometry.diff
├── requirements.txt
├── scripts/
│   ├── download.sh
│   ├── prepare.sh
│   └── run.sh
└── src/
    ├── fist_pose.py
    ├── hist_cub.py
    ├── losses.py
    ├── pose.py
    ├── pose_estimation.py
    ├── renderer.py
    ├── spin/
    │   ├── __init__.py
    │   ├── constants.py
    │   ├── hmr.py
    │   ├── smpl.py
    │   └── utils.py
    └── utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: README.md
================================================
# Sketch2Pose: Estimating a 3D Character Pose from a Bitmap Sketch

Artists frequently capture character poses via raster sketches, then use these
drawings as a reference while posing a 3D character in a specialized 3D
software --- a time-consuming process, requiring specialized 3D training and
mental effort. We tackle this challenge by proposing the first system for
automatically inferring a 3D character pose from a single bitmap sketch,
producing poses consistent with viewer expectations. Algorithmically
interpreting bitmap sketches is challenging, as they contain significantly
distorted proportions and foreshortening. We address this by predicting three
key elements of a drawing, necessary to disambiguate the drawn poses: 2D bone
tangents, self-contacts, and bone foreshortening. These elements are then
leveraged in an optimization inferring the 3D character pose consistent with
the artist's intent. Our optimization balances cues derived from artistic
literature and perception research to compensate for distorted character
proportions. We demonstrate a gallery of results on sketches of numerous
styles. We validate our method via numerical evaluations, user studies, and
comparisons to manually posed characters and previous work.
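
As a rough illustration of the first cue, the sketch below compares 2D bone tangents with the image-plane directions of a posed 3D skeleton. It is a minimal sketch under stated assumptions, not the loss used by this repository (see `src/losses.py`): it assumes an orthographic camera looking along the z axis, and the helper names and the `(1 - cos)` penalty are illustrative choices.

```python
import numpy as np


def unit_rows(vectors, eps=1e-8):
    """Normalise each row to unit length (eps guards against zero-length rows)."""
    vectors = np.asarray(vectors, dtype=float)
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    return vectors / np.maximum(norms, eps)


def bone_tangents_2d(joints_2d, bones):
    """Unit 2D tangent of every bone (i, j), pointing from joint i to joint j."""
    joints_2d = np.asarray(joints_2d, dtype=float)
    return unit_rows([joints_2d[j] - joints_2d[i] for i, j in bones])


def tangent_alignment_loss(joints_3d, target_tangents, bones):
    """Mean (1 - cos) between projected 3D bones and target 2D bone tangents.

    Hypothetical: assumes an orthographic camera along z, so projecting a
    joint simply drops its z coordinate.
    """
    projected = np.asarray(joints_3d, dtype=float)[:, :2]
    pred = bone_tangents_2d(projected, bones)
    target = unit_rows(target_tangents)
    cos = np.sum(pred * target, axis=1)
    return float(np.mean(1.0 - cos))


# Toy chain of three joints: the bones point right, then up, at arbitrary depths.
bones = [(0, 1), (1, 2)]
joints_3d = [[0, 0, 0], [1, 0, 5], [1, 1, -3]]
print(tangent_alignment_loss(joints_3d, [[1, 0], [0, 1]], bones))  # 0.0
```

Because the projection drops z, the bones in the example can have any depth without changing the loss; that depth ambiguity is what the foreshortening cue is meant to resolve.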

[Project Page](http://www-labs.iro.umontreal.ca/~bmpix/sketch2pose/)

# Prerequisites

- [GNU/Linux](https://www.gnu.org/gnu/linux-and-gnu.en.html)
- [`python`](https://python.org)
- [`pytorch`](https://pytorch.org/)
- NVIDIA GPU (optional, but highly recommended)

## Download body model (SMPL-X)

Download the SMPL-X body model (`models_smplx_v1_1.zip`) from
[https://smpl-x.is.tue.mpg.de](https://smpl-x.is.tue.mpg.de) and save the archive to `./assets`.

See [`download.sh`](./scripts/download.sh) and run

```bash
bash ./scripts/download.sh
```

## Virtual environment

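If needed, change the [`pytorch`](https://pytorch.org/) build in
[`prepare.sh`](./scripts/prepare.sh) (by default it uses the CUDA 11.3 wheel index when `nvcc` is found and the CPU index otherwise), then run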

```bash
bash ./scripts/prepare.sh
```
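
Besides installing PyTorch and `requirements.txt`, `prepare.sh` applies every file in `patches/` to the installed `selfcontact`, `smplx` and `torchgeometry` packages. The patch headers hard-code `venv/lib/python3.10/...`, so the script first rewrites that string for the running interpreter. A minimal Python equivalent of that rewrite step, for illustration only (`retarget_patch` is a hypothetical helper, not part of the repository):

```python
import sys


def retarget_patch(patch_text, version=None):
    """Point the hard-coded python3.10 paths in a patch at another interpreter.

    Mirrors `sed "s/python3.10/python${v}/"` in scripts/prepare.sh: only the
    first occurrence on each line is replaced, as sed does without the g flag.
    Unlike sed, the "." is matched literally here.
    """
    if version is None:
        version = f"{sys.version_info.major}.{sys.version_info.minor}"
    lines = patch_text.splitlines(keepends=True)
    return "".join(line.replace("python3.10", f"python{version}", 1) for line in lines)
```

Each rewritten patch can then be fed to `patch -p0` on stdin, for example with `subprocess.run(["patch", "-p0"], input=text, text=True, check=True)`, which is what the shell loop does.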

# Demo

Activate the virtual environment with `. venv/bin/activate` and run

```bash
bash ./scripts/run.sh

# or, for a single image (set out_dir and img_path first)

python src/pose.py \
        --save-path "${out_dir}" \
        --img-path "${img_path}" \
        --use-contacts \
        --use-natural \
        --use-cos \
        --use-angle-transf

# or without contacts

python src/pose.py \
        --save-path "${out_dir}" \
        --img-path "${img_path}" \
        --use-natural \
        --use-cos \
        --use-angle-transf
```
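
If you prefer to drive the demo from Python, the following minimal sketch only assembles the same command lines as above; `build_pose_command` and `build_batch` are hypothetical helpers, and the sketch assumes it is run from the repository root. It does not import anything from `src/`.

```python
from pathlib import Path

# Flags used by every demo command in this README; --use-contacts is optional.
DEMO_FLAGS = ["--use-natural", "--use-cos", "--use-angle-transf"]


def build_pose_command(img_path, save_path, use_contacts=True, python="python"):
    """Return the argv for one src/pose.py run, as in the demo above."""
    cmd = [python, "src/pose.py",
           "--save-path", str(save_path),
           "--img-path", str(img_path)]
    if use_contacts:
        cmd.append("--use-contacts")
    return cmd + DEMO_FLAGS


def build_batch(img_dir, save_path, use_contacts=True):
    """One command per regular file directly inside img_dir, roughly what
    scripts/run.sh does with `find -mindepth 1 -maxdepth 1 -type f`."""
    return [build_pose_command(p, save_path, use_contacts)
            for p in sorted(Path(img_dir).iterdir()) if p.is_file()]
```

Each command can then be executed with `subprocess.run(cmd, check=True)`, e.g. over `build_batch("./data/images", "./output")`, the default directories in `scripts/run.sh`.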

# Citation

```
@article{brodt2022sketch2pose,
    author = {Kirill Brodt and Mikhail Bessmeltsev},
    title = {Sketch2Pose: Estimating a 3D Character Pose from a Bitmap Sketch},
    journal = {ACM Transactions on Graphics},
    year = {2022},
    month = {7},
    volume = {41},
    number = {4},
    doi = {10.1145/3528223.3530106},
}
```

# Useful links

- [Deep High-Resolution Representation Learning for Human Pose Estimation](https://github.com/leoxiaobin/deep-high-resolution-net.pytorch/)
- [SMPLify-X](https://github.com/vchoutas/smplify-x) ([project](https://smpl-x.is.tue.mpg.de/))
- [SPIN](https://github.com/nkolot/SPIN) ([project](https://www.seas.upenn.edu/~nkolot/projects/spin/))
- [eft](https://github.com/facebookresearch/eft)
- [SMPLify-XMC](https://github.com/muelea/smplify-xmc), [selfcontact](https://github.com/muelea/selfcontact) ([project](https://tuch.is.tue.mpg.de/))
- [Mixamo](https://www.mixamo.com) models with animations and a
  [script](https://forums.unrealengine.com/community/community-content-tools-and-tutorials/1376068-script-mixamo-download-script)
  to download them
- Quaternion-based [Forward
  Kinematics](https://github.com/facebookresearch/QuaterNet)


================================================
FILE: patches/selfcontact.diff
================================================
+++ venv/lib/python3.10/site-packages/selfcontact/body_segmentation.py
@@ -14,6 +14,8 @@
 #
 # Contact: ps-license@tuebingen.mpg.de
 
+from pathlib import Path
+
 import torch
 import trimesh
 import torch.nn as nn
@@ -22,6 +24,17 @@
 
 from .utils.mesh import winding_numbers
 
+
+def load_pkl(path):
+    with open(path, "rb") as fin:
+        return pickle.load(fin)
+
+
+def save_pkl(obj, path):
+    with open(path, "wb") as fout:
+        pickle.dump(obj, fout)
+
+
 class BodySegment(nn.Module):
     def __init__(self,
                  name,
@@ -63,9 +76,17 @@
         self.register_buffer('segment_faces', segment_faces)
 
         # create vector to select vertices form faces
-        tri_vidx = []
-        for ii in range(faces.max().item()+1):
-            tri_vidx += [torch.nonzero(faces==ii)[0].tolist()]
+        segments_folder = Path(segments_folder)
+        tri_vidx_path = segments_folder / "tri_vidx.pkl"
+        if not tri_vidx_path.is_file():
+            tri_vidx = []
+            for ii in range(faces.max().item()+1):
+                tri_vidx += [torch.nonzero(faces==ii)[0].tolist()]
+
+            save_pkl(tri_vidx, tri_vidx_path)
+        else:
+            tri_vidx = load_pkl(tri_vidx_path)
+
         self.register_buffer('tri_vidx', torch.tensor(tri_vidx))
 
     def create_band_faces(self):
@@ -149,7 +170,7 @@
         self.segmentation = {}
         for idx, name in enumerate(names):
             self.segmentation[name] = BodySegment(name, faces, segments_folder,
-                model_type).to('cuda')
+                model_type).to(device)
 
     def batch_has_self_isec_verts(self, vertices):
         """
+++ venv/lib/python3.10/site-packages/selfcontact/selfcontact.py
@@ -41,6 +41,7 @@
         test_segments=True,
         compute_hd=False,
         buffer_geodists=False,
+        device="cuda",
     ):
         super().__init__()
 
@@ -95,7 +96,7 @@
         if self.test_segments:
             sxseg = pickle.load(open(segments_bounds_path, 'rb'))
             self.segments = BatchBodySegment(
-                [x for x in sxseg.keys()], faces, segments_folder, self.model_type
+                [x for x in sxseg.keys()], faces, segments_folder, self.model_type, device=device,
             )
 
         # load regressor to get high density mesh
@@ -106,7 +107,7 @@
                 torch.tensor(hd_operator['values']),
                 torch.Size(hd_operator['size']))
             self.register_buffer('hd_operator',
-                torch.tensor(hd_operator).float())
+                hd_operator.clone().detach().float())
 
             with open(point_vert_corres_path, 'rb') as f:
                 hd_geovec = pickle.load(f)['faces_vert_is_sampled_from']
@@ -135,9 +136,13 @@
         # split because of memory into two chunks
         exterior = torch.zeros((bs, nv), device=vertices.device,
             dtype=torch.bool)
-        exterior[:, :5000] = winding_numbers(vertices[:,:5000,:],
+        exterior[:, :3000] = winding_numbers(vertices[:,:3000,:],
             triangles).le(0.99)
-        exterior[:, 5000:] = winding_numbers(vertices[:,5000:,:],
+        exterior[:, 3000:6000] = winding_numbers(vertices[:,3000:6000,:],
+            triangles).le(0.99)
+        exterior[:, 6000:9000] = winding_numbers(vertices[:,6000:9000,:],
+            triangles).le(0.99)
+        exterior[:, 9000:] = winding_numbers(vertices[:,9000:,:],
             triangles).le(0.99)
 
         # check if intersections happen within segments
@@ -173,9 +178,13 @@
         # split because of memory into two chunks
         exterior = torch.zeros((bs, np), device=points.device,
             dtype=torch.bool)
-        exterior[:, :6000] = winding_numbers(points[:,:6000,:],
+        exterior[:, :3000] = winding_numbers(points[:,:3000,:],
+            triangles).le(0.99)
+        exterior[:, 3000:6000] = winding_numbers(points[:,3000:6000,:],
             triangles).le(0.99)
-        exterior[:, 6000:] = winding_numbers(points[:,6000:,:],
+        exterior[:, 6000:9000] = winding_numbers(points[:,6000:9000,:],
+            triangles).le(0.99)
+        exterior[:, 9000:] = winding_numbers(points[:,9000:,:],
             triangles).le(0.99)
 
         return exterior
@@ -371,6 +380,23 @@
 
         return hd_v2v_mins, hd_exteriors, hd_points, hd_faces_in_contacts
 
+    def verts_in_contact(self, vertices, return_idx=False):
+
+        # get pairwise distances of vertices
+        v2v = self.get_pairwise_dists(vertices, vertices, squared=True)
+
+        # mask v2v with Euclidean and geodesic distance
+        euclmask = v2v < self.euclthres**2
+        mask = euclmask * self.geomask
+
+        # a vertex is in contact if at least one other vertex passes both masks
+        in_contact = mask.sum(1) > 0
+
+        if return_idx:
+            in_contact = torch.where(in_contact)
+
+        return in_contact
+
 
 
 class SelfContactSmall(nn.Module):
+++ venv/lib/python3.10/site-packages/selfcontact/utils/mesh.py
@@ -82,7 +82,7 @@
     if valid_vals > 0:
         loss = (mask * dists).sum() / valid_vals
     else:
-        loss = torch.Tensor([0]).cuda()
+        loss = mask.new_tensor([0])
     return loss
 
 def batch_index_select(inp, dim, index):
@@ -103,6 +103,7 @@
     xx = torch.bmm(x, x.transpose(2, 1))
     yy = torch.bmm(y, y.transpose(2, 1))
     zz = torch.bmm(x, y.transpose(2, 1))
+    use_cuda = x.device.type == "cuda"
     if use_cuda:
         dtype = torch.cuda.LongTensor
     else:


================================================
FILE: patches/smplx.diff
================================================
+++ venv/lib/python3.10/site-packages/smplx/body_models.py
@@ -366,7 +366,7 @@
             num_repeats = int(batch_size / betas.shape[0])
             betas = betas.expand(num_repeats, -1)
 
-        vertices, joints = lbs(betas, full_pose, self.v_template,
+        vertices, joints, _ = lbs(betas, full_pose, self.v_template,
                                self.shapedirs, self.posedirs,
                                self.J_regressor, self.parents,
                                self.lbs_weights, pose2rot=pose2rot)
@@ -1228,7 +1228,7 @@
 
         shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1)
 
-        vertices, joints = lbs(shape_components, full_pose, self.v_template,
+        vertices, joints, A = lbs(shape_components, full_pose, self.v_template,
                                shapedirs, self.posedirs,
                                self.J_regressor, self.parents,
                                self.lbs_weights, pose2rot=pose2rot,
@@ -1283,7 +1283,9 @@
                              right_hand_pose=right_hand_pose,
                              jaw_pose=jaw_pose,
                              v_shaped=v_shaped,
-                             full_pose=full_pose if return_full_pose else None)
+                             full_pose=full_pose if return_full_pose else None,
+                             A=A,
+                             )
         return output
 
 
+++ venv/lib/python3.10/site-packages/smplx/lbs.py
@@ -245,7 +245,7 @@
 
     verts = v_homo[:, :, :3, 0]
 
-    return verts, J_transformed
+    return verts, J_transformed, (A, J)
 
 
 def vertices2joints(J_regressor: Tensor, vertices: Tensor) -> Tensor:
+++ venv/lib/python3.10/site-packages/smplx/utils.py
@@ -71,6 +71,7 @@
 class SMPLXOutput(SMPLHOutput):
     expression: Optional[Tensor] = None
     jaw_pose: Optional[Tensor] = None
+    A: Optional[Tensor] = None
 
 
 @dataclass


================================================
FILE: patches/torchgeometry.diff
================================================
+++ venv/lib/python3.10/site-packages/torchgeometry/core/conversions.py
@@ -298,6 +298,9 @@
                       rmat_t[:, 0, 1] - rmat_t[:, 1, 0]], -1)
     t3_rep = t3.repeat(4, 1).t()
 
+    mask_d2 = mask_d2.float()
+    mask_d0_d1 = mask_d0_d1.float()
+    mask_d0_nd1 = mask_d0_nd1.float()
     mask_c0 = mask_d2 * mask_d0_d1
     mask_c1 = mask_d2 * (1 - mask_d0_d1)
     mask_c2 = (1 - mask_d2) * mask_d0_nd1


================================================
FILE: requirements.txt
================================================
matplotlib>=3.5.1
numpy>=1.22.3
opencv_python>=4.5.5.64
Pillow>=9.1.0
plotly>=5.7.0
pyrender>=0.1.45
scikit_image>=0.19.2
scipy>=1.8.0
Shapely>=1.8.1.post1
tensorboard>=2.8.0
torchgeometry>=0.1.2
tqdm>=4.64.0
trimesh>=3.10.8
git+https://github.com/muelea/selfcontact.git@08da422526419c24736c0616bca49623e442c26a
git+https://github.com/vchoutas/smplx.git@5fa20519735cceda19afed0beeabd53caef711cd


================================================
FILE: scripts/download.sh
================================================
#!/usr/bin/env bash


set -euo pipefail


asset_dir="./assets"

if [ ! -e "${asset_dir}"/models_smplx_v1_1.zip ]; then
    echo "Error: Download SMPL-X body model from https://smpl-x.is.tue.mpg.de and save zip archive to ${asset_dir}" >&2
    exit 1
fi

asset_urls=(
    # Download constants (SPIN)
    http://visiondata.cis.upenn.edu/spin/data.tar.gz

    # Download essentials (SMPLify-XMC)
    https://download.is.tue.mpg.de/tuch/smplify-xmc-essentials.zip

    # Download sketch2pose models
    http://www-labs.iro.umontreal.ca/~bmpix/sketch2pose/models.zip

    # Download test images
    http://www-labs.iro.umontreal.ca/~bmpix/sketch2pose/images.zip
)
for asset_url in "${asset_urls[@]}"; do
    wget \
        -nc \
        -c \
        --directory-prefix "${asset_dir}" \
        "${asset_url}"
done

models_dir="./models"
mkdir -p "${models_dir}"

model_files=(
    # Unzip smplx models
    models_smplx_v1_1.zip

    # Unzip essentials (SMPLify-XMC)
    smplify-xmc-essentials.zip

    # Unzip sketch2pose models
    models.zip
)

for model_file in "${model_files[@]}"; do
    unzip \
        -u \
        -d "${models_dir}" \
        "${asset_dir}"/"${model_file}"
done

# Unzip constants (SPIN)
tar \
    --skip-old-files \
    -xvf "${asset_dir}"/data.tar.gz \
    -C "${models_dir}" \
    data/smpl_mean_params.npz

data_dir="./data"
mkdir -p "${data_dir}"

# Unzip test images
unzip \
    -u \
    -d "${data_dir}" \
    "${asset_dir}"/images.zip


================================================
FILE: scripts/prepare.sh
================================================
#!/usr/bin/env bash


set -euo pipefail

venv_dir=venv
python -m venv --clear "${venv_dir}"

. "${venv_dir}"/bin/activate

pip install -U pip setuptools

extra="cpu"
[ -x "$(command -v nvcc)" ] && extra="cu113"
pip install \
    torch \
    torchvision \
    --extra-index-url https://download.pytorch.org/whl/"${extra}"

pip install -r requirements.txt

v=$(python -c 'import sys; v = sys.version_info; print(f"{v.major}.{v.minor}")')
for p in patches/*.diff; do
    patch -p0 < <(sed "s/python3.10/python${v}/" "${p}")
done


================================================
FILE: scripts/run.sh
================================================
#!/usr/bin/env bash


set -euo pipefail


img_dir="./data/images"
out_dir="./output"

find "${img_dir}" -mindepth 1 -maxdepth 1 -type f -print0 \
    | xargs -0 -I "{}" python src/pose.py \
        --save-path "${out_dir}" \
        --img-path "{}" \
        --use-contacts \
        --use-natural \
        --use-cos \
        --use-angle-transf

# Stop here; remove the `exit` below to also run the baseline and ablations.
exit

# baseline (SMPLify-XMC)

find "${img_dir}" -mindepth 1 -maxdepth 1 -type f -print0 \
    | xargs -0 -I "{}" python src/pose.py \
        --save-path "${out_dir}_baseline" \
        --img-path "{}" \
        --c-mse 1 \
        --c-par 0 \
        --use-contacts \
        --use-cos \
        --use-angle-transf

# ablation

find "${img_dir}" -mindepth 1 -maxdepth 1 -type f -print0 \
    | xargs -0 -I "{}" python src/pose.py \
        --save-path "${out_dir}_wocostransform" \
        --img-path "{}" \
        --use-contacts \
        --use-natural \
        --use-cos


find "${img_dir}" -mindepth 1 -maxdepth 1 -type f -print0 \
    | xargs -0 -I "{}" python src/pose.py \
        --save-path "${out_dir}_wocontacts" \
        --img-path "{}" \
        --use-msc \
        --use-natural \
        --use-cos \
        --use-angle-transf


find "${img_dir}" -mindepth 1 -maxdepth 1 -type f -print0 \
    | xargs -0 -I "{}" python src/pose.py \
        --save-path "${out_dir}_wonatural" \
        --img-path "{}" \
        --use-contacts \
        --use-cos \
        --use-angle-transf


================================================
FILE: src/fist_pose.py
================================================
left_fist = [
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.4183167815208435, 0.10645648092031479, -1.6593892574310303,
    0.15252035856246948, -0.14700782299041748, -1.3719955682754517,
    -0.04432843625545502, -0.15799851715564728, -0.938068151473999,
    -0.12218914180994034, 0.073341965675354, -1.6415189504623413,
    -0.14376045763492584, 0.1927780956029892, -1.3593589067459106,
    -0.0851994976401329, 0.01652289740741253, -0.7474589347839355,
    -0.9881719946861267, -0.3987707793712616, -1.3535722494125366,
    -0.6686224937438965, 0.1261960119009018, -1.080643892288208,
    -0.8101894855499268, -0.1306752860546112, -0.8412265777587891,
    -0.3495230972766876, -0.17784251272678375, -1.4433038234710693,
    -0.46278536319732666, 0.13677796721458435, -1.467200517654419,
    -0.3681888282299042, 0.003404417773708701, -0.7764251232147217,
    0.850964367389679, 0.2769227623939514, -0.09154807031154633,
    0.14500413835048676, 0.09604815393686295, 0.219278022646904,
    1.0451993942260742, 0.16911321878433228, -0.2426234930753708,
    0.11167845129966736, -0.04289207234978676, 0.41644084453582764,
    0.10881128907203674, 0.06598565727472305, 0.756219744682312,
    -0.0963931530714035, 0.09091583639383316, 0.18845966458320618,
    -0.11809506267309189, -0.050943851470947266, 0.5295845866203308,
    -0.14369848370552063, -0.055241718888282776, 0.704857349395752,
    -0.019182899966835976, 0.0923367589712143, 0.3379131853580475,
    -0.45703303813934326, 0.1962839663028717, 0.6254575848579407,
    -0.21465237438678741, 0.06599827855825424, 0.5068942308425903,
    -0.36972442269325256, 0.0603446289896965, 0.07949023693799973,
    -0.14186954498291016, 0.08585254102945328, 0.6355276107788086,
    -0.3033415675163269, 0.05788097903132439, 0.6313892006874084,
    -0.17612087726593018, 0.13209305703639984, 0.3733545243740082,
    0.850964367389679, -0.2769227623939514, 0.09154807031154633,
    -0.4998386800289154, -0.026556432247161865, -0.052880801260471344,
    0.5355585217475891, -0.045960985124111176, 0.27735769748687744,
]

left_right_fist = [
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, -0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.4183167815208435, 0.10645648092031479, -1.6593892574310303,
    0.15252035856246948, -0.14700782299041748, -1.3719955682754517,
    -0.04432843625545502, -0.15799851715564728, -0.938068151473999,
    -0.12218914180994034, 0.073341965675354, -1.6415189504623413,
    -0.14376045763492584, 0.1927780956029892, -1.3593589067459106,
    -0.0851994976401329, 0.01652289740741253, -0.7474589347839355,
    -0.9881719946861267, -0.3987707793712616, -1.3535722494125366,
    -0.6686224937438965, 0.1261960119009018, -1.080643892288208,
    -0.8101894855499268, -0.1306752860546112, -0.8412265777587891,
    -0.3495230972766876, -0.17784251272678375, -1.4433038234710693,
    -0.46278536319732666, 0.13677796721458435, -1.467200517654419,
    -0.3681888282299042, 0.003404417773708701, -0.7764251232147217,
    0.850964367389679, 0.2769227623939514, -0.09154807031154633,
    0.14500413835048676, 0.09604815393686295, 0.219278022646904,
    1.0451993942260742, 0.16911321878433228, -0.2426234930753708,
    0.4183167815208435, -0.10645647346973419, 1.6593892574310303,
    0.15252038836479187, 0.14700786769390106, 1.3719956874847412,
    -0.04432841017842293, 0.15799842774868011, 0.9380677938461304,
    -0.12218913435935974, -0.0733419880270958, 1.6415191888809204,
    -0.14376048743724823, -0.19277812540531158, 1.3593589067459106,
    -0.08519953489303589, -0.016522908583283424, 0.7474592328071594,
    -0.9881719350814819, 0.3987707495689392, 1.3535723686218262,
    -0.6686226725578308, -0.12619605660438538, 1.080644130706787,
    -0.8101896643638611, 0.1306752860546112, 0.8412266373634338,
    -0.34952324628829956, 0.17784248292446136, 1.443304181098938,
    -0.46278542280197144, -0.13677802681922913, 1.467200517654419,
    -0.36818885803222656, -0.0034044249914586544, 0.7764251232147217,
    0.8509642481803894, -0.2769228219985962, 0.09154807776212692,
    0.14500458538532257, -0.09604845196008682, -0.21927869319915771,
    1.0451991558074951, -0.1691131889820099, 0.242623433470726,
]

right_fist = []
for lf, lrf in zip(left_fist, left_right_fist):
    if lf != lrf:
        right_fist.append(lrf)
    else:
        right_fist.append(0)


left_flat_up = [
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0, 1.5129635334014893,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
]

left_flat_down = [
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0, -1.4648663997650146,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
]

right_flat_up = [
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0, -1.5021973848342896,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
]

right_flat_down = [
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0, 0, 1.494218111038208,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
]

relaxed = [
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.0, 0.0, 0.0,
    0.11167845129966736, 0.04289207234978676, -0.41644084453582764,
    0.10881128907203674, -0.06598565727472305, -0.756219744682312,
    -0.0963931530714035, -0.09091583639383316, -0.18845966458320618,
    -0.11809506267309189, 0.050943851470947266, -0.5295845866203308,
    -0.14369848370552063, 0.055241718888282776, -0.704857349395752,
    -0.019182899966835976, -0.0923367589712143, -0.3379131853580475,
    -0.45703303813934326, -0.1962839663028717, -0.6254575848579407,
    -0.21465237438678741, -0.06599827855825424, -0.5068942308425903,
    -0.36972442269325256, -0.0603446289896965, -0.07949023693799973,
    -0.14186954498291016, -0.08585254102945328, -0.6355276107788086,
    -0.3033415675163269, -0.05788097903132439, -0.6313892006874084,
    -0.17612087726593018, -0.13209305703639984, -0.3733545243740082,
    0.850964367389679, 0.2769227623939514, -0.09154807031154633,
    -0.4998386800289154, 0.026556432247161865, 0.052880801260471344,
    0.5355585217475891, 0.045960985124111176, -0.27735769748687744,
    0.11167845129966736, -0.04289207234978676, 0.41644084453582764,
    0.10881128907203674, 0.06598565727472305, 0.756219744682312,
    -0.0963931530714035, 0.09091583639383316, 0.18845966458320618,
    -0.11809506267309189, -0.050943851470947266, 0.5295845866203308,
    -0.14369848370552063, -0.055241718888282776, 0.704857349395752,
    -0.019182899966835976, 0.0923367589712143, 0.3379131853580475,
    -0.45703303813934326, 0.1962839663028717, 0.6254575848579407,
    -0.21465237438678741, 0.06599827855825424, 0.5068942308425903,
    -0.36972442269325256, 0.0603446289896965, 0.07949023693799973,
    -0.14186954498291016, 0.08585254102945328, 0.6355276107788086,
    -0.3033415675163269, 0.05788097903132439, 0.6313892006874084,
    -0.17612087726593018, 0.13209305703639984, 0.3733545243740082,
    0.850964367389679, -0.2769227623939514, 0.09154807031154633,
    -0.4998386800289154, -0.026556432247161865, -0.052880801260471344,
    0.5355585217475891, -0.045960985124111176, 0.27735769748687744,
]

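# SMPL-X full pose layout (axis-angle, 3 values per joint):
# 25 body joints (global orient, 21 body joints, jaw, two eyes) + 15 left-hand + 15 right-hand joints
# smpl(left_hand_pose, right_hand_pose)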

left_start = 25 * 3
left_end = left_start + 15 * 3
right_end = left_end + 15 * 3

LEFT_FIST = left_fist[left_start:left_end]
RIGHT_FIST = right_fist[left_end:right_end]

LEFT_FLAT_UP = left_flat_up[20 * 3 : 20 * 3 + 3]
LEFT_FLAT_DOWN = left_flat_down[20 * 3 : 20 * 3 + 3]

RIGHT_FLAT_UP = right_flat_up[21 * 3 : 21 * 3 + 3]
RIGHT_FLAT_DOWN = right_flat_down[21 * 3 : 21 * 3 + 3]

LEFT_RELAXED = relaxed[left_start:left_end]
RIGHT_RELAXED = relaxed[left_end:right_end]

INT_TO_FIST = {
    "lfl": None,
    "lf": LEFT_FIST,
    "lu": LEFT_FLAT_UP,
    "ld": LEFT_FLAT_DOWN,
    "rfl": None,
    "rf": RIGHT_FIST,
    "ru": RIGHT_FLAT_UP,
    "rd": RIGHT_FLAT_DOWN,
}


================================================
FILE: src/hist_cub.py
================================================
import itertools
import functools
import math
import multiprocessing
from pathlib import Path

import matplotlib
matplotlib.rcParams.update({'font.size': 24})
matplotlib.rcParams.update({
  "text.usetex": True,
  "text.latex.preamble": r"\usepackage{biolinum} \usepackage{libertineRoman} \usepackage{libertineMono} \usepackage[libertine]{newtxmath}",
  'ps.usedistiller': "xpdf",
})

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import tqdm
from scipy.stats import wasserstein_distance

import pose_estimation


def cub(x, a, b, c):
    x2 = x * x
    x3 = x2 * x

    y = a * x3 + b * x2 + c * x

    return y


def subsample(a, p=0.0005, seed=0):
    np.random.seed(seed)
    N = len(a)
    inds = np.random.choice(range(N), size=int(p * N))
    a = a[inds].copy()

    return a


def read_cos_opt(path, fname="cos_hist.npy"):
    cos_opt = []
    for p in Path(path).rglob(fname):
        d = np.load(p)
        cos_opt.append(d)

    cos_opt = np.array(cos_opt)

    return cos_opt


def plot_hist(cos_opt_dir, hist_smpl_fpath, params, out_dir, bins=10, xy=None):
    cos_opt = read_cos_opt(cos_opt_dir)
    angle_opt = np.arccos(cos_opt)
    angle_opt2 = cub(angle_opt, *params)

    cos_opt2 = np.cos(angle_opt2)
    cos_smpl = np.load(hist_smpl_fpath)
    # cos_smpl = subsample(cos_smpl)
    print(cos_smpl.shape)

    cos_smpl = np.clip(cos_smpl, -1, 1)

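    # From here on the cos_* names hold angles (radians, converted to degrees
    # below), not cosines: these are the values the histograms plot.
    cos_opt = angle_opt
    cos_opt2 = angle_opt2
    cos_smpl = np.arccos(cos_smpl)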

    cos_opt = 180 / math.pi * cos_opt
    cos_opt2 = 180 / math.pi * cos_opt2
    cos_smpl = 180 / math.pi * cos_smpl
    max_range = 90  # math.pi / 2

    xticks = [0, 15, 30, 45, 60, 75, 90]
    for idx, bone in enumerate(pose_estimation.SKELETON):
        i, j = bone
        i_name = pose_estimation.KPS[i]
        j_name = pose_estimation.KPS[j]
        if i_name != "Left Upper Leg":
            continue

        name = f"{i_name}_{j_name}"

        gs = gridspec.GridSpec(2, 4)
        fig = plt.figure(tight_layout=True, figsize=(16, 8), dpi=300)

        ax0 = fig.add_subplot(gs[0, 0])
        ax0.hist(cos_smpl[:, idx], bins=bins, range=(0, max_range), density=True)
        ax0.set_xticks(xticks)
        ax0.tick_params(labelbottom=False, labelleft=True)

        ax1 = fig.add_subplot(gs[1, 0], sharex=ax0)
        ax1.hist(cos_opt[:, idx], bins=bins, range=(0, max_range), density=True)
        ax1.set_xticks(xticks)

        if xy is not None:
            ax2 = fig.add_subplot(gs[:, 1:3])
            ax2.plot(xy[0], xy[1], linewidth=8)
            ax2.plot(xy[0], xy[0], linewidth=4, linestyle="dashed")
            ax2.set_xticks(xticks)
            ax2.set_yticks(xticks)

        ax3 = fig.add_subplot(gs[0, 3], sharey=ax0)
        ax3.hist(cos_opt2[:, idx], bins=bins, range=(0, max_range), density=True)
        ax3.set_xticks(xticks)
        ax3.tick_params(labelbottom=False, labelleft=False)

        ax4 = fig.add_subplot(gs[1, 3], sharex=ax3, sharey=ax1)
        alpha = 0.5
        ax4.hist(cos_opt[:, idx], bins=bins, range=(0, max_range), density=True, label=r"$\mathcal{B}_i$", alpha=alpha)
        ax4.hist(cos_opt2[:, idx], bins=bins, range=(0, max_range), density=True, label=r"$f(\mathcal{B}_i)$", alpha=alpha)
        ax4.hist(cos_smpl[:, idx], bins=bins, range=(0, max_range), density=True, label=r"$\mathcal{A}_i$", alpha=alpha)
        ax4.set_xticks(xticks)
        ax4.tick_params(labelbottom=True, labelleft=False)
        ax4.legend()

        fig.savefig(out_dir / f"hist_{name}.png")
        plt.close(fig)


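def kldiv(p_hist, q_hist):
    # Despite the name, this is the 1-D Wasserstein distance, not a KL
    # divergence. The histogram heights are passed as *samples* (no bin
    # positions or weights), so the result depends only on the multiset of
    # bin heights, not on where the mass sits along the angle axis.
    wd = wasserstein_distance(p_hist, q_hist)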

    return wd
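
# Illustrative sketch (hypothetical helper, not used by `step`): a
# histogram-aware alternative to `kldiv` that places each bin height at its
# bin center and passes it as a weight, so SciPy's `wasserstein_distance`
# measures how far probability mass moves along the angle axis. Assumes both
# histograms share the same bins over `value_range`, as `calc_histogram`
# produces with its defaults.
import numpy as np
from scipy.stats import wasserstein_distance


def _sketch_hist_wasserstein(p_hist, q_hist, value_range=(0.0, 1.0)):
    edges = np.linspace(value_range[0], value_range[1], len(p_hist) + 1)
    centers = 0.5 * (edges[:-1] + edges[1:])  # one x-position per bin
    return wasserstein_distance(centers, centers, u_weights=p_hist, v_weights=q_hist)

# Shifting all mass by one bin of width 0.1 gives a distance of about 0.1
# here, whereas `kldiv` returns 0 for the same pair of histograms.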


def calc_histogram(x, bins=10, range=(0, 1)):
    h, _ = np.histogram(x, bins=bins, range=range, density=True)

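    return h


def step(params, angles_opt, p_hist, bone_idx=None):
    # Reject cubics with f(1) = a + b + c > 1 early; the range check below
    # rejects any other curve that leaves [0, 1] on the data.
    if sum(params) > 1:
        return math.inf, params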

    kl = 0
    for i, _ in enumerate(pose_estimation.SKELETON):
        if bone_idx is not None and i != bone_idx:
            continue

        angles_opt2 = cub(angles_opt[:, i], *params)
        if angles_opt2.max() > 1 or angles_opt2.min() < 0:
            kl = math.inf

            break

        q_hist = calc_histogram(angles_opt2)

        kl += kldiv(p_hist[i], q_hist)

    return kl, params
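
# Illustrative sketch (hypothetical helper): `optimize` evaluates `step` on
# every (a, b, c) of a Cartesian grid and keeps the argmin. This is the same
# exhaustive search without multiprocessing, for any objective that returns
# (loss, params) the way `step` does.
import itertools
import math


def _sketch_grid_search(objective, *axes):
    best_loss, best_params = math.inf, None
    for params in itertools.product(*axes):  # every combination of axis values
        loss, params = objective(params)
        if loss < best_loss:
            best_loss, best_params = loss, params
    return best_loss, best_params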


def optimize(cos_opt_dir, hist_smpl_fpath, bone_idx=None):
    cos_opt = read_cos_opt(cos_opt_dir)
    angles_opt = np.arccos(cos_opt) / (math.pi / 2)
    cos_smpl = np.load(hist_smpl_fpath)
    # cos_smpl = subsample(cos_smpl)
    print(cos_smpl.shape)
    cos_smpl = np.clip(cos_smpl, -1, 1)
    mask = cos_smpl <= 1
    assert np.all(mask), (~mask).mean()
    mask = cos_smpl >= 0
    assert np.all(mask), (~mask).mean()
    angles_smpl = np.arccos(cos_smpl) / (math.pi / 2)
    p_hist = [
        calc_histogram(angles_smpl[:, i])
        for i, _ in enumerate(pose_estimation.SKELETON)
    ]

    with multiprocessing.Pool(8) as p:
        results = list(
            tqdm.tqdm(
                p.imap_unordered(
                    functools.partial(step, angles_opt=angles_opt, p_hist=p_hist, bone_idx=bone_idx),
                    itertools.product(
                        np.linspace(0, 20, 100),
                        np.linspace(-20, 20, 200),
                        np.linspace(-20, 1, 100),
                    ),
                ),
                total=(100 * 200 * 100),
            )
        )

    kls, params = zip(*results)
    ind = np.argmin(kls)
    best_params = params[ind]

    print(kls[ind], best_params)

    inds = np.argsort(kls)
    for i in inds[:10]:
        print(kls[i])
        print(params[i])
        print()

    return best_params


def main():
    cos_opt_dir = "paper_single2_150mse"
    hist_smpl_fpath = "./data/hist_smpl.npy"
    # hist_smpl_fpath = "./testtest.npy"
    params = optimize(cos_opt_dir, hist_smpl_fpath)
    # params = (1.2121212121212122, -1.105527638190953, 0.787878787878789)
    # params = (0.20202020202020202, 0.30150753768844396, 0.3636363636363633)
    print(params)

    x = np.linspace(0, math.pi / 2, 100)
    y = cub(x / (math.pi / 2), *params) * (math.pi / 2)
    x = x * 180 / math.pi
    y = y * 180 / math.pi

    out_dir = Path("hists")
    out_dir.mkdir(parents=True, exist_ok=True)
    plot_hist(cos_opt_dir, hist_smpl_fpath, params, out_dir, xy=(x, y))

    plt.figure(figsize=(4, 4), dpi=300)
    plt.plot(x, y, linewidth=6)
    plt.plot(x, x, linewidth=2, linestyle="dashed")
    xticks = [0, 15, 30, 45, 60, 75, 90]
    plt.xticks(xticks)
    plt.yticks(xticks)
    plt.axis("equal")
    plt.tight_layout()
    plt.savefig(out_dir / "new_out.png")


if __name__ == "__main__":
    main()


================================================
FILE: src/losses.py
================================================
import itertools

import torch
import torch.nn as nn

import pose_estimation


class MSE(nn.Module):
    def __init__(self, ignore=None):
        super().__init__()

        self.mse = torch.nn.MSELoss(reduction="none")
        self.ignore = ignore if ignore is not None else []

    def forward(self, y_pred, y_data):
        loss = self.mse(y_pred, y_data)

        if len(self.ignore) > 0:
            loss[self.ignore] *= 0

        return loss.sum() / (len(loss) - len(self.ignore))
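
# Illustrative sketch (hypothetical helper): the same reduction as `MSE`,
# written with a boolean row mask instead of zeroing ignored rows in place.
# Squared errors are summed over the x/y coordinates of every kept keypoint
# and divided by the number of kept keypoints.
import torch


def _sketch_masked_mse(y_pred, y_data, ignore=()):
    keep = torch.ones(len(y_pred), dtype=torch.bool)
    for i in ignore:  # drop ignored keypoint rows
        keep[i] = False
    sq_err = (y_pred[keep] - y_data[keep]) ** 2
    return sq_err.sum() / keep.sum()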


class Parallel(nn.Module):
    def __init__(self, skeleton, ignore=None, ground_parallel=None):
        super().__init__()

        self.skeleton = skeleton
        if ignore is not None:
            self.ignore = set(ignore)
        else:
            self.ignore = set()

        self.ground_parallel = ground_parallel if ground_parallel is not None else []
        self.parallel_in_3d = []

        self.cos = None

    def forward(self, y_pred3d, y_data, z, spine_j, writer=None, global_step=0):
        y_pred = y_pred3d[:, :2]
        rleg, lleg = spine_j

        Lcon2d = Lcount = 0
        if hasattr(self, "contact_2d"):
            for c2d in self.contact_2d:
                for (
                    (src_1, dst_1, t_1),
                    (src_2, dst_2, t_2),
                ) in itertools.combinations(c2d, 2):

                    a_1 = torch.lerp(y_data[src_1], y_data[dst_1], t_1)
                    a_2 = torch.lerp(y_data[src_2], y_data[dst_2], t_2)
                    a = a_2 - a_1

                    b_1 = torch.lerp(y_pred[src_1], y_pred[dst_1], t_1)
                    b_2 = torch.lerp(y_pred[src_2], y_pred[dst_2], t_2)
                    b = b_2 - b_1

                    lcon2d = ((a - b) ** 2).sum()
                    Lcon2d = Lcon2d + lcon2d
                    Lcount += 1

        if Lcount > 0:
            Lcon2d = Lcon2d / Lcount

        Ltan = Lpar = Lcos = Lcount = 0
        Lspine = 0
        for i, bone in enumerate(self.skeleton):
            if bone in self.ignore:
                continue

            src, dst = bone

            b = y_data[dst] - y_data[src]
            t = nn.functional.normalize(b, dim=0)
            n = torch.stack([-t[1], t[0]])

            if src == 10 and dst == 11:  # right leg
                a = rleg
            elif src == 13 and dst == 14:  # left leg
                a = lleg
            else:
                a = y_pred[dst] - y_pred[src]

            bone_name = f"{pose_estimation.KPS[src]}_{pose_estimation.KPS[dst]}"
            c = a - b
            lcos_loc = ltan_loc = lpar_loc = 0
            if self.cos is not None:
                if bone not in [
                    (1, 2),  # Neck + Right Shoulder
                    (1, 5),  # Neck + Left Shoulder
                    (9, 10),  # Hips + Right Upper Leg
                    (9, 13),  # Hips + Left Upper Leg
                ]:
                    a = y_pred[dst] - y_pred[src]
                    l2d = torch.norm(a, dim=0)
                    l3d = torch.norm(y_pred3d[dst] - y_pred3d[src], dim=0)
                    lcos = self.cos[i]

                    lcos_loc = (l2d / l3d - lcos) ** 2
                    Lcos = Lcos + lcos_loc
                    lpar_loc = ((a / l2d) * n).sum() ** 2
                    Lpar = Lpar + lpar_loc
            else:
                ltan_loc = ((c * t).sum()) ** 2
                Ltan = Ltan + ltan_loc
                lpar_loc = (c * n).sum() ** 2
                Lpar = Lpar + lpar_loc

            if writer is not None:
                writer.add_scalar(f"tan/{bone_name}", ltan_loc, global_step=global_step)
                writer.add_scalar(f"cos/{bone_name}", lcos_loc, global_step=global_step)
                writer.add_scalar(f"par/{bone_name}", lpar_loc, global_step=global_step)

            Lcount += 1

        if Lcount > 0:
            Ltan = Ltan / Lcount
            Lcos = Lcos / Lcount
            Lpar = Lpar / Lcount
            Lspine = Lspine / Lcount

        Lgr = Lcount = 0
        for (src, dst), value in self.ground_parallel:
            bone = y_pred[dst] - y_pred[src]
            bone = nn.functional.normalize(bone, dim=0)
            l = (torch.abs(bone[0]) - value) ** 2

            Lgr = Lgr + l
            Lcount += 1

        if Lcount > 0:
            Lgr = Lgr / Lcount

        Lstraight3d = Lcount = 0
        for (i, j), (k, l) in self.parallel_in_3d:
            a = z[j] - z[i]
            a = nn.functional.normalize(a, dim=0)
            b = z[l] - z[k]
            b = nn.functional.normalize(b, dim=0)
            lo = ((a * b).sum() - 1) ** 2
            Lstraight3d = Lstraight3d + lo
            Lcount += 1

        if Lcount > 0:
            Lstraight3d = Lstraight3d / Lcount

        return Ltan, Lcos, Lpar, Lspine, Lgr, Lstraight3d, Lcon2d
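
# Illustrative sketch (hypothetical helpers) of the per-bone terms in
# `Parallel.forward`. With target bone b, unit tangent t = b / |b| and normal
# n = (-t_y, t_x), the residual c = a - b splits into a component along the
# bone (ltan) and one across it (lpar). The foreshortening term compares the
# ratio of projected to 3D bone length with a predicted value (`self.cos[i]`
# in the class, called `target_ratio` here).
import torch
import torch.nn.functional as F


def _sketch_tan_par(a, b):
    t = F.normalize(b, dim=0)  # unit tangent of the target bone
    n = torch.stack([-t[1], t[0]])  # unit normal (t rotated by 90 degrees)
    c = a - b
    return (c * t).sum() ** 2, (c * n).sum() ** 2


def _sketch_foreshortening(bone_3d, target_ratio):
    l2d = torch.norm(bone_3d[:2], dim=0)  # length after dropping depth
    l3d = torch.norm(bone_3d, dim=0)
    return (l2d / l3d - target_ratio) ** 2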


class MimickedSelfContactLoss(nn.Module):
    """
    Loss that lets vertices in contact on the presented mesh attract vertices that are close.
    """

    def __init__(self, geodesics_mask):
        super().__init__()
        # geodesic distance mask
        self.register_buffer("geomask", geodesics_mask)

    def forward(
        self,
        presented_contact,
        vertices,
        v2v=None,
        contact_mode="dist_tanh",
        contact_thresh=1,
    ):

        contactloss = 0.0

        if v2v is None:
            # compute pairwise distances
            verts = vertices.contiguous()
            nv = verts.shape[1]
            v2v = verts.squeeze().unsqueeze(1).expand(
                nv, nv, 3
            ) - verts.squeeze().unsqueeze(0).expand(nv, nv, 3)
            v2v = torch.norm(v2v, 2, 2)

        # loss for self-contact from the mimicked pose
        if len(presented_contact) > 0:
            # For every vertex in contact, find the closest other contact
            # vertex; pairs rejected by the geodesic mask get infinite weight
            # and are never selected.
            with torch.no_grad():
                cvertstobody = v2v[presented_contact, :]
                cvertstobody = cvertstobody[:, presented_contact]
                maskgeo = self.geomask[presented_contact, :]
                maskgeo = maskgeo[:, presented_contact]
                # ones_like already places the weights on the device of v2v
                # (`verts` only exists when v2v is computed above)
                weights = torch.ones_like(cvertstobody)
                weights[~maskgeo] = float("inf")
                min_idx = torch.min((cvertstobody + 1) * weights, 1)[1]
                min_idx = presented_contact[min_idx.cpu().numpy()]

            v2v_min = v2v[presented_contact, min_idx]

            # tanh will not pull vertices that are more than ~contact_thresh apart
            if contact_mode == "dist_tanh":
                contactloss = contact_thresh * torch.tanh(v2v_min / contact_thresh)
                contactloss = contactloss.mean()
            else:
                contactloss = v2v_min.mean()

        return contactloss
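
# Illustrative sketch (hypothetical helpers): `torch.cdist` gives the same
# pairwise vertex distances (up to floating-point error) as the
# expand-and-subtract code in `forward`, and the "dist_tanh" mode caps each
# distance's contribution at `contact_thresh`: roughly d for
# d << contact_thresh, saturating at contact_thresh for distant pairs.
import torch


def _sketch_pairwise_dists(verts):
    v = verts.reshape(-1, 3)  # (1, V, 3) or (V, 3) -> (V, 3)
    return torch.cdist(v, v)  # (V, V) Euclidean distances


def _sketch_tanh_contact(d, contact_thresh=1.0):
    return contact_thresh * torch.tanh(d / contact_thresh)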


================================================
FILE: src/pose.py
================================================
import argparse
import math
from pathlib import Path

import cv2
import numpy as np
import PIL.Image as Image
import selfcontact
import selfcontact.losses
import shapely.geometry
import torch
import torch.nn as nn
import torch.optim as optim
import torchgeometry
import tqdm
import trimesh
from skimage import measure
from torch.utils.tensorboard.writer import SummaryWriter

import fist_pose
import hist_cub
import losses
import pose_estimation
import spin
import utils

PE_KSP_TO_SPIN = {
    "Head": "Head",
    "Neck": "Neck",
    "Right Shoulder": "Right ForeArm",
    "Right Arm": "Right Arm",
    "Right Hand": "Right Hand",
    "Left Shoulder": "Left ForeArm",
    "Left Arm": "Left Arm",
    "Left Hand": "Left Hand",
    "Spine": "Spine1",
    "Hips": "Hips",
    "Right Upper Leg": "Right Upper Leg",
    "Right Leg": "Right Leg",
    "Right Foot": "Right Foot",
    "Left Upper Leg": "Left Upper Leg",
    "Left Leg": "Left Leg",
    "Left Foot": "Left Foot",
    "Left Toe": "Left Toe",
    "Right Toe": "Right Toe",
}
MODELS_DIR = "models"


def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--pose-estimation-model-path",
        type=str,
        default=f"./{MODELS_DIR}/hrn_w48_384x288.onnx",
        help="Pose Estimation model",
    )

    parser.add_argument(
        "--contact-model-path",
        type=str,
        default=f"./{MODELS_DIR}/contact_hrn_w32_256x192.onnx",
        help="Contact model",
    )

    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cpu", "cuda"],
        help="Torch device",
    )

    parser.add_argument(
        "--spin-model-path",
        type=str,
        default=f"./{MODELS_DIR}/spin_model_smplx_eft_18.pt",
        help="SPIN model path",
    )

    parser.add_argument(
        "--smpl-type",
        type=str,
        default="smplx",
        choices=["smplx"],
        help="SMPL model type",
    )
    parser.add_argument(
        "--smpl-model-dir",
        type=str,
        default=f"./{MODELS_DIR}/models/smplx",
        help="SMPL model dir",
    )
    parser.add_argument(
        "--smpl-mean-params-path",
        type=str,
        default=f"./{MODELS_DIR}/data/smpl_mean_params.npz",
        help="SMPL mean params",
    )
    parser.add_argument(
        "--essentials-dir",
        type=str,
        default=f"./{MODELS_DIR}/smplify-xmc-essentials",
        help="SMPL Essentials folder for contacts",
    )

    parser.add_argument(
        "--parametrization-path",
        type=str,
        default=f"./{MODELS_DIR}/smplx_parametrization/parametrization.npy",
        help="Parametrization path",
    )
    parser.add_argument(
        "--bone-parametrization-path",
        type=str,
        default=f"./{MODELS_DIR}/smplx_parametrization/bone_to_param2.npy",
        help="Bone parametrization path",
    )
    parser.add_argument(
        "--foot-inds-path",
        type=str,
        default=f"./{MODELS_DIR}/smplx_parametrization/foot_inds.npy",
        help="Foot indices",
    )

    parser.add_argument(
        "--save-path",
        type=str,
        required=True,
        help="Path to save the results",
    )

    parser.add_argument(
        "--img-path",
        type=str,
        required=True,
        help="Path to the input image",
    )

    parser.add_argument(
        "--use-contacts",
        action="store_true",
        help="Use contact model",
    )
    parser.add_argument(
        "--use-msc",
        action="store_true",
        help="Use MSC loss",
    )
    parser.add_argument(
        "--use-natural",
        action="store_true",
        help="Use regularity",
    )
    parser.add_argument(
        "--use-cos",
        action="store_true",
        help="Use cos model",
    )
    parser.add_argument(
        "--use-angle-transf",
        action="store_true",
        help="Use cube foreshortening transformation",
    )

    parser.add_argument(
        "--c-mse",
        type=float,
        default=0,
        help="MSE weight",
    )
    parser.add_argument(
        "--c-par",
        type=float,
        default=10,
        help="Parallel weight",
    )

    parser.add_argument(
        "--c-f",
        type=float,
        default=1000,
        help="Foreshortening (cos) loss weight",
    )
    parser.add_argument(
        "--c-parallel",
        type=float,
        default=100,
        help="Parallel weight",
    )
    parser.add_argument(
        "--c-reg",
        type=float,
        default=1000,
        help="Regularity weight",
    )
    parser.add_argument(
        "--c-cont2d",
        type=float,
        default=1,
        help="Contact 2D weight",
    )
    parser.add_argument(
        "--c-msc",
        type=float,
        default=17_500,
        help="MSC weight",
    )

    parser.add_argument(
        "--fist",
        nargs="+",
        type=str,
        choices=list(fist_pose.INT_TO_FIST),
    )

    args = parser.parse_args()

    return args


def freeze_layers(model):
    for module in model.modules():
        # Put BatchNorm and Dropout layers in eval mode and freeze their parameters
        if isinstance(module, (nn.modules.batchnorm._BatchNorm, nn.Dropout)):
            module.eval()
            for m in module.parameters():
                m.requires_grad = False


def project_and_normalize_to_spin(vertices_3d, camera):
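    # Weak-perspective projection; all three coordinates are kept and callers
    # take [:, :2] where they need 2D points.
    vertices_2d = vertices_3d

    # camera = (scale, tx, ty); the translation is zero-padded along z
    scale = camera[0]
    translate = scale.new_zeros(3)
    translate[:2] = camera[1:]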

    vertices_2d = vertices_2d + translate
    vertices_2d = scale * vertices_2d + 1
    vertices_2d = spin.constants.IMG_RES / 2 * vertices_2d

    return vertices_2d
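
# Illustrative sketch (hypothetical helper) of the mapping above in NumPy:
# a point X is shifted by (tx, ty, 0), scaled by s, and mapped from [-1, 1]
# to pixel coordinates of a square crop. `img_res=224` is an assumption
# standing in for `spin.constants.IMG_RES`.
import numpy as np


def _sketch_weak_perspective(points, s, tx, ty, img_res=224):
    points = np.asarray(points, dtype=float)
    translate = np.array([tx, ty, 0.0])
    return img_res / 2 * (s * (points + translate) + 1)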


def project_and_normalize_to_spin_legs(vertices_3d, A, camera):
    A, J = A
    A = A[0]
    J = J[0]
    L = vertices_3d.new_tensor(
        [
            [0.98619063, 0.16560926, 0.00127302],
            [-0.16560601, 0.98603675, 0.01749799],
            [0.00164258, -0.01746717, 0.99984609],
        ]
    )
    R = vertices_3d.new_tensor(
        [
            [0.9910211, -0.13368178, -0.0025208],
            [0.13367888, 0.99027076, 0.03864949],
            [-0.00267045, -0.03863944, 0.99924965],
        ]
    )
    scale = camera[0]
    R = A[2, :3, :3] @ R  # 2 - right
    L = A[1, :3, :3] @ L  # 1 - left
    r = J[5] - J[2]
    l = J[4] - J[1]

    rleg = scale * spin.constants.IMG_RES / 2 * R @ r
    lleg = scale * spin.constants.IMG_RES / 2 * L @ l

    rleg = rleg[:2]
    lleg = lleg[:2]

    return rleg, lleg


def rotation_matrix_to_angle_axis(rotmat):
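    bs, n_joints, *_ = rotmat.size()
    # torchgeometry expects 3x4 matrices: append a padding column to every 3x3
    # rotation (one per joint per batch element)
    rotmat = torch.cat(
        [
            rotmat.view(-1, 3, 3),
            rotmat.new_tensor([0.0, 0.0, 1.0])
            .view(1, 3, 1)
            .expand(bs * n_joints, -1, -1),
        ],
        dim=-1,
    )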
    aa = torchgeometry.rotation_matrix_to_angle_axis(rotmat)
    aa = aa.reshape(bs, 3 * n_joints)

    return aa
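
# Illustrative sketch (hypothetical helper, assumes SciPy is installed): a
# NumPy/SciPy reference for the conversion above, useful for spot-checking
# `rotation_matrix_to_angle_axis` on a few matrices. It returns one rotation
# vector (unit axis times angle in radians) per 3x3 matrix, flattened per
# batch element like the torch version.
import numpy as np
from scipy.spatial.transform import Rotation


def _sketch_rotmat_to_axis_angle(rotmat):
    rotmat = np.asarray(rotmat, dtype=float)  # (bs, n_joints, 3, 3)
    bs, n_joints = rotmat.shape[:2]
    rotvec = Rotation.from_matrix(rotmat.reshape(-1, 3, 3)).as_rotvec()
    return rotvec.reshape(bs, 3 * n_joints)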


def get_smpl_output(smpl, rotmat, betas, use_betas=True, zero_hands=False):
    if smpl.name() == "SMPL":
        smpl_output = smpl(
            betas=betas if use_betas else None,
            body_pose=rotmat[:, 1:],
            global_orient=rotmat[:, 0].unsqueeze(1),
            pose2rot=False,
        )
    elif smpl.name() == "SMPL-X":
        rotmat = rotation_matrix_to_angle_axis(rotmat)
        if zero_hands:
            for i in [20, 21]:
                rotmat[:, 3 * i : 3 * (i + 1)] = 0

            for i in [12, 15]:  # neck, head
                rotmat[:, 3 * i + 1] = 0  # y
        smpl_output = smpl(
            betas=betas if use_betas else None,
            body_pose=rotmat[:, 3:],
            global_orient=rotmat[:, :3],
            pose2rot=True,
        )
    else:
        raise NotImplementedError

    return smpl_output, rotmat


def get_predictions(model_hmr, smpl, input_img, use_betas=True, zero_hands=False):
    input_img = input_img.unsqueeze(0)
    rotmat, betas, camera = model_hmr(input_img)

    smpl_output, rotmat = get_smpl_output(
        smpl, rotmat, betas, use_betas=use_betas, zero_hands=zero_hands
    )

    rotmat = rotmat.squeeze(0)
    betas = betas.squeeze(0)
    camera = camera.squeeze(0)
    z = smpl_output.joints
    z = z.squeeze(0)

    return rotmat, betas, camera, smpl_output, z


def get_pred_and_data(
    model_hmr, smpl, selector, input_img, use_betas=True, zero_hands=False
):
    rotmat, betas, camera, smpl_output, zz = get_predictions(
        model_hmr, smpl, input_img, use_betas=use_betas, zero_hands=zero_hands
    )

    joints = smpl_output.joints.squeeze(0)
    joints_2d = project_and_normalize_to_spin(joints, camera)
    rleg, lleg = project_and_normalize_to_spin_legs(joints, smpl_output.A, camera)
    joints_2d_orig = joints_2d
    joints_2d = joints_2d[selector]

    vertices = smpl_output.vertices.squeeze(0)
    vertices_2d = project_and_normalize_to_spin(vertices, camera)

    zz = zz[selector]

    return (
        rotmat,
        betas,
        camera,
        joints_2d,
        zz,
        vertices_2d,
        smpl_output,
        (rleg, lleg),
        joints_2d_orig,
    )


def normalize_keypoints_to_spin(keypoints_2d, img_size):
    h, w = img_size
    if h > w:  # portrait: pad along x
        ax1 = 1
        ax2 = 0
    else:  # landscape or square: pad along y
        ax1 = 0
        ax2 = 1

    shift = (img_size[ax1] - img_size[ax2]) / 2
    scale = spin.constants.IMG_RES / img_size[ax2]
    keypoints_2d_normalized = np.copy(keypoints_2d)
    keypoints_2d_normalized[:, ax2] -= shift
    keypoints_2d_normalized *= scale

    return keypoints_2d_normalized, shift, scale, ax2


def unnormalize_keypoints_from_spin(keypoints_2d, shift, scale, ax2):
    keypoints_2d_normalized = np.copy(keypoints_2d)
    keypoints_2d_normalized /= scale
    keypoints_2d_normalized[:, ax2] += shift

    return keypoints_2d_normalized
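
# Illustrative sketch (hypothetical helpers) of the square-crop mapping used
# by `normalize_keypoints_to_spin` / `unnormalize_keypoints_from_spin`:
# keypoints are (x, y); the shorter image side is padded symmetrically to a
# square, which is then resized to `img_res` pixels. `img_res=224` is an
# assumption standing in for `spin.constants.IMG_RES`.
import numpy as np


def _sketch_to_square(kps, img_size, img_res=224):
    h, w = img_size
    kps = np.asarray(kps, dtype=float).copy()
    if h > w:
        kps[:, 0] += (h - w) / 2  # portrait: pad along x
    else:
        kps[:, 1] += (w - h) / 2  # landscape: pad along y
    return kps * img_res / max(h, w)


def _sketch_from_square(kps, img_size, img_res=224):
    h, w = img_size
    kps = np.asarray(kps, dtype=float) * max(h, w) / img_res
    if h > w:
        kps[:, 0] -= (h - w) / 2
    else:
        kps[:, 1] -= (w - h) / 2
    return kps

# For a 400x200 (h x w) image, the centre (100, 200) maps to (112, 112).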


def get_vertices_in_heatmap(contact_heatmap):
    contact_heatmap_size = contact_heatmap.shape[:2]
    label = measure.label(contact_heatmap)

    y_data_conts = []
    for i in range(1, label.max() + 1):
        predicted_kps_contact = np.vstack(np.nonzero(label == i)[::-1]).T.astype(
            "float"
        )
        predicted_kps_contact_scaled, *_ = normalize_keypoints_to_spin(
            predicted_kps_contact, contact_heatmap_size
        )
        y_data_cont = torch.from_numpy(predicted_kps_contact_scaled).int().tolist()
        y_data_cont = shapely.geometry.MultiPoint(y_data_cont).convex_hull
        y_data_conts.append(y_data_cont)

    return y_data_conts


def get_contact_heatmap(model_contact, img_path, thresh=0.5):
    contact_heatmap = pose_estimation.infer_single_image(
        model_contact,
        img_path,
        input_img_size=(192, 256),
        return_kps=False,
    )
    contact_heatmap = contact_heatmap.squeeze(0)
    contact_heatmap_orig = contact_heatmap.copy()

    mi = contact_heatmap.min()
    ma = contact_heatmap.max()
    contact_heatmap = (contact_heatmap - mi) / (ma - mi)
    contact_heatmap_ = ((contact_heatmap > thresh) * 255).astype("uint8")

    contact_heatmap = np.repeat(contact_heatmap[..., None], repeats=3, axis=-1)
    contact_heatmap = (contact_heatmap * 255).astype("uint8")

    return contact_heatmap_, contact_heatmap, contact_heatmap_orig
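
# Illustrative sketch (hypothetical helper): the min-max normalization and
# thresholding in `get_contact_heatmap`, followed by the connected-component
# labelling that `get_vertices_in_heatmap` applies to the binary mask; each
# labelled component is one contact region. The small `eps` guarding against
# a constant heatmap is an addition not present in the code above.
import numpy as np
from skimage import measure


def _sketch_contact_regions(heatmap, thresh=0.5, eps=1e-8):
    hm = np.asarray(heatmap, dtype=float)
    hm = (hm - hm.min()) / (hm.max() - hm.min() + eps)  # rescale to [0, 1]
    labels = measure.label(hm > thresh)  # 0 = background, 1..N = regions
    return labels, labels.max()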


def discretize(parametrization, n_bins=100):
    bins = np.linspace(0, 1, n_bins + 1)
    inds = np.digitize(parametrization, bins)
    disc_parametrization = bins[inds - 1]

    return disc_parametrization


def get_mapping_from_params_to_verts(verts, params):
    mapping = {}
    for v, t in zip(verts, params):
        mapping.setdefault(t, []).append(v)

    return mapping


def find_contacts(y_data_conts, keypoints_2d, bone_to_params, thresh=12, step=0.0072246375):
    n_bins = int(math.ceil(1 / step)) - 1  # mean face's circumradius
    contact = []
    contact_2d = []
    for_mask = []
    for y_data_cont in y_data_conts:
        contact_loc = []
        contact_2d_loc = []
        buffer = y_data_cont.buffer(thresh)
        mask_add = False
        for i, j in pose_estimation.SKELETON:
            verts, t3d = bone_to_params[(i, j)]
            if len(verts) == 0:
                continue

            t3d = discretize(t3d, n_bins=n_bins)
            t3d_to_verts = get_mapping_from_params_to_verts(verts, t3d)
            t3d_to_verts_sorted = sorted(t3d_to_verts.items(), key=lambda x: x[0])
            t3d_sorted_np = np.array([x for x, _ in t3d_to_verts_sorted])

            line = shapely.geometry.LineString([keypoints_2d[i], keypoints_2d[j]])
            lint = buffer.intersection(line)
            if len(lint.boundary.geoms) < 2:
                continue

            t2d_start = line.project(lint.boundary.geoms[0], normalized=True)
            t2d_end = line.project(lint.boundary.geoms[1], normalized=True)
            assert t2d_start <= t2d_end

            t2ds = discretize(
                np.linspace(t2d_start, t2d_end, n_bins + 1), n_bins=n_bins
            )
            to_add = False
            for t2d in t2ds:
                if t2d < t3d_sorted_np[0] or t2d > t3d_sorted_np[-1]:
                    continue

                t2d_ind = np.searchsorted(t3d_sorted_np, t2d)
                c = t3d_to_verts_sorted[t2d_ind][1]

                contact_loc.extend(c)
                to_add = True
                mask_add = True

                if t2d_ind + 1 < len(t3d_to_verts_sorted):
                    c = t3d_to_verts_sorted[t2d_ind + 1][1]
                    contact_loc.extend(c)

                if t2d_ind > 0:
                    c = t3d_to_verts_sorted[t2d_ind - 1][1]
                    contact_loc.extend(c)

            if to_add:
                contact_2d_loc.append((i, j, t2d_start + 0.5 * (t2d_end - t2d_start)))

        if mask_add:
            for_mask.append(buffer.exterior.coords.xy)

        contact_loc = sorted(set(contact_loc))
        contact_loc = np.array(contact_loc, dtype="int")
        contact.append(contact_loc)
        contact_2d.append(contact_2d_loc)

    for_mask = [np.stack((x, y), axis=0).T[:, None].astype("int") for x, y in for_mask]

    return contact, contact_2d, for_mask
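
# Illustrative sketch (hypothetical helper) of the 2D step in
# `find_contacts`: a contact region is grown by `thresh` pixels, intersected
# with a bone segment, and the overlap is expressed as normalized positions
# t in [0, 1] along the bone. Assumes the overlap is a single LineString
# (true for convex regions such as the convex hulls built above).
import shapely.geometry


def _sketch_bone_overlap(region, start, end, thresh=12):
    line = shapely.geometry.LineString([start, end])
    overlap = region.buffer(thresh).intersection(line)
    if overlap.is_empty or overlap.geom_type != "LineString":
        return None
    ts = [line.project(shapely.geometry.Point(p), normalized=True) for p in overlap.coords]
    return min(ts), max(ts)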


def optimize(
    model_hmr,
    smpl,
    selector,
    input_img,
    keypoints_2d,
    optimizer,
    args,
    loss_mse=None,
    loss_parallel=None,
    c_mse=0.0,
    c_new_mse=1.0,
    c_beta=1e-3,
    sc_crit=None,
    msc_crit=None,
    contact=None,
    n_steps=60,
    save_path=None,
    writer=None,
    i_ini=0,
):
    to_save = False
    if save_path is not None:
        (
            img_original,
            predicted_keypoints_2d,
            save_path,
            shift,
            scale,
            ax2,
            prefix,
        ) = save_path
        to_save = True

    mean_zfoot_val = {}
    with tqdm.trange(n_steps) as pbar:
        for i in pbar:
            global_step = i + i_ini
            optimizer.zero_grad()

            (
                rotmat_pred,
                betas_pred,
                camera_pred,
                keypoints_3d_pred,
                z,
                vertices_2d_pred,
                smpl_output,
                (rleg, lleg),
                joints_2d_orig,
            ) = get_pred_and_data(
                model_hmr,
                smpl,
                selector,
                input_img,
            )
            keypoints_2d_pred = keypoints_3d_pred[:, :2]
            if to_save:
                utils.save_results_image(
                    camera=camera_pred.detach().cpu().numpy(),
                    focal_length_x=spin.constants.FOCAL_LENGTH,
                    focal_length_y=spin.constants.FOCAL_LENGTH,
                    vertices=smpl_output.vertices.detach()[0].cpu().numpy(),
                    input_img=img_original,
                    faces=smpl.faces,
                    keypoints=predicted_keypoints_2d,
                    keypoints_2=unnormalize_keypoints_from_spin(
                        keypoints_2d_pred.detach().cpu().numpy(), shift, scale, ax2
                    ),
                    # keypoints_2=unnormalize_keypoints_from_spin(joints_2d_orig.detach().cpu().numpy(), shift, scale, ax2),
                    # heatmap=predicted_contact_heatmap_raw,
                    filename=save_path / f"{prefix}_{i:0>4}.png",
                    contactlist=contact,
                    user_study=False,
                )

            loss = l2 = 0.0
            if c_mse > 0 and loss_mse is not None:
                l2 = loss_mse(keypoints_2d_pred, keypoints_2d)
                loss = loss + c_mse * l2

                if writer is not None:
                    writer.add_scalar("mse", l2, global_step=global_step)

            vertices_pred = smpl_output.vertices

            lpar = z_loss = loss_sh = 0.0
            if c_new_mse > 0 and loss_parallel is not None:
                Ltan, Lcos, Lpar, Lspine, Lgr, Lstraight3d, Lcon2d = loss_parallel(
                    keypoints_3d_pred,
                    keypoints_2d,
                    z,
                    (rleg, lleg),
                    writer=writer,
                    global_step=global_step,
                )
                lpar = (
                    Ltan
                    + c_new_mse * (args.c_f * Lcos + args.c_parallel * Lpar)
                    + Lspine
                    + args.c_reg * Lgr
                    + args.c_reg * Lstraight3d
                    + args.c_cont2d * Lcon2d
                )
                loss = loss + 300 * lpar

                if writer is not None:
                    writer.add_scalar("tan", Ltan, global_step=global_step)
                    writer.add_scalar("cos", Lcos, global_step=global_step)
                    writer.add_scalar("par", Lpar, global_step=global_step)
                    writer.add_scalar("spine", Lspine, global_step=global_step)
                    writer.add_scalar("ground/chain", Lgr, global_step=global_step)
                    writer.add_scalar(
                        "straight_in_3d", Lstraight3d, global_step=global_step
                    )
                    writer.add_scalar("contact/con2d", Lcon2d, global_step=global_step)

                for side in ["left", "right"]:
                    attr = f"{side}_foot_inds"
                    if hasattr(loss_parallel, attr):
                        foot_inds = getattr(loss_parallel, attr)
                        zind = 1
                        if attr not in mean_zfoot_val:
                            with torch.no_grad():
                                mean_zfoot_val[attr] = torch.median(
                                    vertices_pred[0, foot_inds, zind], dim=0
                                ).values

                        loss_foot = (
                            (vertices_pred[0, foot_inds, zind] - mean_zfoot_val[attr])
                            ** 2
                        ).sum()
                        loss = loss + args.c_reg * loss_foot

                        if writer is not None:
                            writer.add_scalar(
                                f"ground/{side} foot",
                                loss_foot,
                                global_step=global_step,
                            )

                if hasattr(loss_parallel, "silhuette_vertices_inds"):
                    inds = loss_parallel.silhuette_vertices_inds
                    loss_sh = (
                        (vertices_pred[0, inds, 1] - loss_parallel.ground) ** 2
                    ).sum()
                    loss = loss + args.c_reg * loss_sh

                    if writer is not None:
                        writer.add_scalar(
                            "ground/silhuette", loss_sh, global_step=global_step
                        )

            lbeta = (betas_pred**2).mean()
            lcam = ((torch.exp(-camera_pred[0] * 10)) ** 2).mean()
            loss = loss + c_beta * lbeta + lcam

            if writer is not None:
                writer.add_scalar("loss/beta", lbeta, global_step=global_step)
                writer.add_scalar("loss/cam", lcam, global_step=global_step)

            lgsc_a = gsc_contact_loss = faces_angle_loss = 0.0
            if sc_crit is not None:
                gsc_contact_loss, faces_angle_loss = sc_crit(
                    vertices_pred,
                )
                lgsc_a = 1000 * gsc_contact_loss + 0.1 * faces_angle_loss
                loss = loss + lgsc_a

                if writer is not None:
                    writer.add_scalar(
                        "contact/gsc", gsc_contact_loss, global_step=global_step
                    )
                    writer.add_scalar(
                        "contact/faces_angle", faces_angle_loss, global_step=global_step
                    )

            msc_loss = 0.0
            if contact is not None and len(contact) > 0 and msc_crit is not None:
                if not isinstance(contact, list):
                    contact = [contact]

                for cntct in contact:
                    msc_loss = msc_crit(
                        cntct,
                        vertices_pred,
                    )
                    loss = loss + args.c_msc * msc_loss

                    if writer is not None:
                        writer.add_scalar(
                            "contact/msc", msc_loss, global_step=global_step
                        )

            loss.backward()
            optimizer.step()

            epoch_loss = loss.item()
            pbar.set_postfix(
                **{
                    "l": f"{epoch_loss:.3}",
                    "l2": f"{l2:.3}",
                    "par": f"{lpar:.3}",
                    "beta": f"{lbeta:.3}",
                    "cam": f"{lcam:.3}",
                    "z": f"{z_loss:.3}",
                    "gsc_contact": f"{float(gsc_contact_loss):.3}",
                    "faces_angle": f"{float(faces_angle_loss):.3}",
                    "msc": f"{float(msc_loss):.3}",
                }
            )

    with torch.no_grad():
        (
            rotmat_pred,
            betas_pred,
            camera_pred,
            keypoints_3d_pred,
            z,
            vertices_2d_pred,
            smpl_output,
            (rleg, lleg),
            joints_2d_orig,
        ) = get_pred_and_data(
            model_hmr,
            smpl,
            selector,
            input_img,
            zero_hands=True,
        )

    return (
        rotmat_pred,
        betas_pred,
        camera_pred,
        keypoints_3d_pred,
        vertices_2d_pred,
        smpl_output,
        z,
        joints_2d_orig,
    )

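
# How the foot ground term used by ``optimize`` and ``optimize_ft`` behaves.
# This is a minimal standalone sketch that is not part of the pipeline, and
# ``_foot_flatness_loss`` is a hypothetical helper. On the first step the
# median height of the foot vertices is frozen as a reference. Every later
# step penalizes the squared deviation of each foot vertex from that
# reference, which pulls the sole towards a flat contact. For an even count,
# ``torch.median`` returns the lower of the two middle values, so
# ``statistics.median_low`` mirrors it here.
import statistics


def _foot_flatness_loss(foot_heights, reference=None):
    """Return (loss, reference) for a list of foot-vertex heights."""
    if reference is None:
        # Frozen once, like ``mean_zfoot_val[attr]`` in the optimization loops.
        reference = statistics.median_low(foot_heights)
    loss = sum((h - reference) ** 2 for h in foot_heights)
    return loss, reference


# Usage: the reference comes from the first call and is reused afterwards.
# loss0, ref = _foot_flatness_loss([0.9, 1.0, 1.1, 1.2])    # ref == 1.0, loss0 ~ 0.06
# loss1, _ = _foot_flatness_loss([1.0, 1.0, 1.0, 1.0], ref)  # loss1 == 0.0
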

def optimize_ft(
    theta,
    camera,
    smpl,
    selector,
    input_img,
    keypoints_2d,
    args,
    loss_mse=None,
    loss_parallel=None,
    c_mse=0.0,
    c_new_mse=1.0,
    sc_crit=None,
    msc_crit=None,
    contact=None,
    n_steps=60,
    save_path=None,
    writer=None,
    i_ini=0,
    zero_hands=False,
    fist=None,
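):
    """Refine an axis-angle pose and camera directly, without the network.

    Unlike ``optimize``, which fine-tunes the HMR weights, this runs Adam on
    copies of ``theta`` and ``camera`` themselves. A squared-distance prior
    keeps both close to their initial values. After optimization,
    ``zero_hands`` resets the wrists and the y component of the neck and head
    axis-angle rotations, and ``fist`` applies the preset hand or wrist poses
    from ``fist_pose``. Returns the detached pose vector and the final SMPL-X
    output.
    """
    to_save = False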
    if save_path is not None:
        (
            img_original,
            predicted_keypoints_2d,
            save_path,
            shift,
            scale,
            ax2,
            prefix,
        ) = save_path
        to_save = True

    mean_zfoot_val = {}

    theta = theta.detach().clone()
    camera = camera.detach().clone()
    rotmat_pred = nn.Parameter(theta)
    camera_pred = nn.Parameter(camera)
    optimizer = torch.optim.Adam(
        [
            rotmat_pred,
            camera_pred,
        ],
        lr=1e-3,
    )
    global_step = i_ini

    with tqdm.trange(n_steps) as pbar:
        for i in pbar:
            global_step = i + i_ini
            optimizer.zero_grad()

            global_orient = rotmat_pred[:3]
            body_pose = rotmat_pred[3:]
            smpl_output = smpl(
                global_orient=global_orient.unsqueeze(0),
                body_pose=body_pose.unsqueeze(0),
                pose2rot=True,
            )

            z = smpl_output.joints
            z = z.squeeze(0)

            joints = smpl_output.joints.squeeze(0)
            joints_2d = project_and_normalize_to_spin(joints, camera_pred)
            rleg, lleg = project_and_normalize_to_spin_legs(
                joints, smpl_output.A, camera_pred
            )
            joints_2d = joints_2d[selector]
            z = z[selector]
            keypoints_3d_pred = joints_2d

            keypoints_2d_pred = keypoints_3d_pred[:, :2]
            if to_save:
                utils.save_results_image(
                    camera=camera_pred.detach().cpu().numpy(),
                    focal_length_x=spin.constants.FOCAL_LENGTH,
                    focal_length_y=spin.constants.FOCAL_LENGTH,
                    vertices=smpl_output.vertices.detach()[0].cpu().numpy(),
                    input_img=img_original,
                    faces=smpl.faces,
                    keypoints=predicted_keypoints_2d,
                    keypoints_2=unnormalize_keypoints_from_spin(
                        keypoints_2d_pred.detach().cpu().numpy(), shift, scale, ax2
                    ),
                    # keypoints_2=unnormalize_keypoints_from_spin(joints_2d_orig.detach().cpu().numpy(), shift, scale, ax2),
                    # heatmap=predicted_contact_heatmap_raw,
                    filename=save_path / f"{prefix}_{i:0>4}.png",
                    contactlist=contact,
                    user_study=False,
                )

            lprior = ((rotmat_pred - theta) ** 2).sum() + (
                (camera_pred - camera) ** 2
            ).sum()
            loss = lprior

            l2 = 0.0
            if c_mse > 0 and loss_mse is not None:
                l2 = loss_mse(keypoints_2d_pred, keypoints_2d)
                loss = loss + c_mse * l2

                if writer is not None:
                    writer.add_scalar("mse", l2, global_step=global_step)

            vertices_pred = smpl_output.vertices

            lpar = z_loss = loss_sh = 0.0
            if c_new_mse > 0 and loss_parallel is not None:
                Ltan, Lcos, Lpar, Lspine, Lgr, Lstraight3d, Lcon2d = loss_parallel(
                    keypoints_3d_pred,
                    keypoints_2d,
                    z,
                    (rleg, lleg),
                    writer=writer,
                    global_step=global_step,
                )
                lpar = (
                    Ltan
                    + c_new_mse * (args.c_f * Lcos + args.c_parallel * Lpar)
                    + Lspine
                    + args.c_reg * Lgr
                    + args.c_reg * Lstraight3d
                    + args.c_cont2d * Lcon2d
                )
                loss = loss + 300 * lpar

                if writer is not None:
                    writer.add_scalar("tan", Ltan, global_step=global_step)
                    writer.add_scalar("cos", Lcos, global_step=global_step)
                    writer.add_scalar("par", Lpar, global_step=global_step)
                    writer.add_scalar("spine", Lspine, global_step=global_step)
                    writer.add_scalar("ground/chain", Lgr, global_step=global_step)
                    writer.add_scalar(
                        "straight_in_3d", Lstraight3d, global_step=global_step
                    )
                    writer.add_scalar("contact/con2d", Lcon2d, global_step=global_step)

                for side in ["left", "right"]:
                    attr = f"{side}_foot_inds"
                    if hasattr(loss_parallel, attr):
                        foot_inds = getattr(loss_parallel, attr)
                        zind = 1
                        if attr not in mean_zfoot_val:
                            with torch.no_grad():
                                mean_zfoot_val[attr] = torch.median(
                                    vertices_pred[0, foot_inds, zind], dim=0
                                ).values

                        loss_foot = (
                            (vertices_pred[0, foot_inds, zind] - mean_zfoot_val[attr])
                            ** 2
                        ).sum()
                        loss = loss + args.c_reg * loss_foot

                        if writer is not None:
                            writer.add_scalar(
                                f"ground/{side} foot",
                                loss_foot,
                                global_step=global_step,
                            )

                if hasattr(loss_parallel, "silhuette_vertices_inds"):
                    inds = loss_parallel.silhuette_vertices_inds
                    loss_sh = (
                        (vertices_pred[0, inds, 1] - loss_parallel.ground) ** 2
                    ).sum()
                    loss = loss + args.c_reg * loss_sh

                    if writer is not None:
                        writer.add_scalar(
                            "ground/silhuette", loss_sh, global_step=global_step
                        )

            lgsc_a = gsc_contact_loss = faces_angle_loss = 0.0
            if sc_crit is not None:
                gsc_contact_loss, faces_angle_loss = sc_crit(vertices_pred)
                lgsc_a = 1000 * gsc_contact_loss + 0.1 * faces_angle_loss
                loss = loss + lgsc_a

                if writer is not None:
                    writer.add_scalar(
                        "contact/gsc", gsc_contact_loss, global_step=global_step
                    )
                    writer.add_scalar(
                        "contact/faces_angle", faces_angle_loss, global_step=global_step
                    )

            msc_loss = 0.0
            if contact is not None and len(contact) > 0 and msc_crit is not None:
                if not isinstance(contact, list):
                    contact = [contact]

                for cntct in contact:
                    msc_loss = msc_crit(
                        cntct,
                        vertices_pred,
                    )
                    loss = loss + args.c_msc * msc_loss

                    if writer is not None:
                        writer.add_scalar(
                            "contact/msc", msc_loss, global_step=global_step
                        )

            loss.backward()
            optimizer.step()

            epoch_loss = loss.item()
            pbar.set_postfix(
                **{
                    "l": f"{epoch_loss:.3}",
                    "l2": f"{l2:.3}",
                    "par": f"{lpar:.3}",
                    "z": f"{z_loss:.3}",
                    "gsc_contact": f"{float(gsc_contact_loss):.3}",
                    "faces_angle": f"{float(faces_angle_loss):.3}",
                    "msc": f"{float(msc_loss):.3}",
                }
            )

    rotmat_pred = rotmat_pred.detach()

    if zero_hands:
        for i in [20, 21]:  # left, right wrist
            rotmat_pred[3 * i : 3 * (i + 1)] = 0

        for i in [12, 15]:  # neck, head
            rotmat_pred[3 * i + 1] = 0  # y

    global_orient = rotmat_pred[:3]
    body_pose = rotmat_pred[3:]
    left_hand_pose = None
    right_hand_pose = None
    if fist is not None:
        left_hand_pose = rotmat_pred.new_tensor(fist_pose.LEFT_RELAXED).unsqueeze(0)
        right_hand_pose = rotmat_pred.new_tensor(fist_pose.RIGHT_RELAXED).unsqueeze(0)
        for f in fist:
            pp = fist_pose.INT_TO_FIST[f]
            if pp is not None:
                pp = rotmat_pred.new_tensor(pp).unsqueeze(0)

            if f.startswith("lf"):
                left_hand_pose = pp
            elif f.startswith("rf"):
                right_hand_pose = pp
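            elif f.startswith("l"):
                # body_pose omits the global orientation, so slot 19 is
                # joint 20, the left wrist.
                body_pose[19 * 3 : 19 * 3 + 3] = pp
                left_hand_pose = None
            elif f.startswith("r"):
                # Slot 20 is joint 21, the right wrist.
                body_pose[20 * 3 : 20 * 3 + 3] = pp
                right_hand_pose = None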
            else:
                raise RuntimeError(f"No such hand pose: {f}")

    with torch.no_grad():
        smpl_output = smpl(
            global_orient=global_orient.unsqueeze(0),
            body_pose=body_pose.unsqueeze(0),
            left_hand_pose=left_hand_pose,
            right_hand_pose=right_hand_pose,
            pose2rot=True,
        )

    return rotmat_pred, smpl_output

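
# Index layout assumed by the ``zero_hands`` and ``fist`` code in
# ``optimize_ft``. This is a minimal illustrative sketch, and
# ``_joint_slice`` is a hypothetical helper. ``rotmat_pred`` is a flat
# axis-angle vector with three values per joint, where joint 0 is the global
# orientation, so joint ``i`` occupies ``[3 * i, 3 * (i + 1))``. Because
# ``body_pose = rotmat_pred[3:]`` drops joint 0, the wrists (joints 20 and 21
# in the SMPL-X body ordering) appear as body-pose slots 19 and 20.
def _joint_slice(joint, include_global_orient=True):
    """Return the slice of the flat axis-angle vector that holds ``joint``."""
    offset = joint if include_global_orient else joint - 1
    return slice(3 * offset, 3 * (offset + 1))


# Usage: _joint_slice(20) selects rotmat_pred[60:63], and
# _joint_slice(20, include_global_orient=False) selects body_pose[57:60].
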

def create_bone(i, j, keypoints_2d):
    a = keypoints_2d[i]
    b = keypoints_2d[j]
    ab = b - a
    ab = torch.nn.functional.normalize(ab, dim=0)

    return ab


def is_parallel_to_plane(bone, thresh=21):
    return abs(bone[0]) > math.cos(math.radians(thresh))

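
# ``is_parallel_to_plane`` compares the first component of a unit 2D bone
# direction with ``cos(thresh)``. That component is the cosine of the angle
# between the bone and axis 0, so the test holds exactly when the bone lies
# within ``thresh`` degrees of that axis, in either direction. The following
# is a minimal standalone check of that equivalence. ``_angle_to_axis0_deg``
# is a hypothetical helper that the pipeline does not use.
import math


def _angle_to_axis0_deg(bone):
    """Unsigned angle in degrees between a unit 2D vector and axis 0."""
    # min() guards against rounding pushing |bone[0]| just above 1.
    return math.degrees(math.acos(min(1.0, abs(bone[0]))))


# Usage: a bone at 160 degrees is 20 degrees from axis 0, so it passes the
# default 21-degree test, just like a bone at 20 degrees.
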

def is_close_to_plane(bone, plane, thresh):
    dist = abs(bone[0] - plane)

    return dist < thresh


def get_selector():
    selector = []
    for kp in pose_estimation.KPS:
        tmp = spin.JOINT_NAMES.index(PE_KSP_TO_SPIN[kp])
        selector.append(tmp)

    return selector


def calc_cos(joints_2d, joints_3d):
    cos = []
    for i, j in pose_estimation.SKELETON:
        a = joints_2d[i] - joints_2d[j]
        a = nn.functional.normalize(a, dim=0)

        b = joints_3d[i] - joints_3d[j]
        b = nn.functional.normalize(b, dim=0)[:2]

        c = (a * b).sum()
        cos.append(c)

    cos = torch.stack(cos, dim=0)

    return cos

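
# What ``calc_cos`` measures on a single bone. This is a minimal standalone
# sketch, and ``_foreshortening_cos`` is a hypothetical helper. The 2D bone
# is normalized in the image plane. The 3D bone is normalized in space and
# then truncated to its first two coordinates. When the 2D bone is the xy
# part of the 3D one, as in ``get_cos``, their dot product equals the length
# of the truncated 3D unit vector. That length is the cosine of the angle
# between the bone and the image plane: 1 for a bone lying in the plane and
# 0 for a bone pointing along the viewing axis.
import math


def _foreshortening_cos(bone_3d):
    norm = math.sqrt(sum(c * c for c in bone_3d))
    ux, uy, _ = (c / norm for c in bone_3d)
    norm_2d = math.hypot(ux, uy)
    if norm_2d == 0.0:
        # Bone along the view axis: no visible 2D direction. F.normalize
        # returns a zero vector here, which also gives a cosine of 0.
        return 0.0
    # Dot product of the normalized projection with the truncated 3D unit vector.
    return (ux / norm_2d) * ux + (uy / norm_2d) * uy


# Usage: _foreshortening_cos((1.0, 0.0, 1.0)) is about 0.7071, which
# corresponds to a bone tilted 45 degrees out of the image plane.
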

def get_natural(keypoints_2d, vertices, right_foot_inds, left_foot_inds, loss_parallel, smpl):
    height_2d = (
        keypoints_2d.max(dim=0).values[0] - keypoints_2d.min(dim=0).values[0]
    ).item()
    plane_2d = keypoints_2d.max(dim=0).values[0].item()

    ground_parallel = []
    parallel_in_3d = []
    parallel3d_bones = set()

    # parallel chains
    for i, j, k in [
        ("Right Upper Leg", "Right Leg", "Right Foot"),
        ("Right Leg", "Right Foot", "Right Toe"),  # to remove?
        ("Left Upper Leg", "Left Leg", "Left Foot"),
        ("Left Leg", "Left Foot", "Left Toe"),  # to remove?
        ("Right Shoulder", "Right Arm", "Right Hand"),
        ("Left Shoulder", "Left Arm", "Left Hand"),
        # ("Hips", "Spine", "Neck"),
        # ("Spine", "Neck", "Head"),
    ]:
        i = pose_estimation.KPS.index(i)
        j = pose_estimation.KPS.index(j)
        k = pose_estimation.KPS.index(k)
        upleg_leg = create_bone(i, j, keypoints_2d)
        leg_foot = create_bone(j, k, keypoints_2d)

        if is_parallel_to_plane(upleg_leg) and is_parallel_to_plane(leg_foot):
            if is_close_to_plane(
                upleg_leg, plane_2d, thresh=0.1 * height_2d
            ) or is_close_to_plane(leg_foot, plane_2d, thresh=0.1 * height_2d):
                ground_parallel.append(((i, j), 1))
                ground_parallel.append(((j, k), 1))

        if (upleg_leg * leg_foot).sum() > math.cos(math.radians(21)):
            parallel_in_3d.append(((i, j), (j, k)))
            parallel3d_bones.add((i, j))
            parallel3d_bones.add((j, k))

    # parallel feet
    for i, j in [
        ("Right Foot", "Right Toe"),
        ("Left Foot", "Left Toe"),
    ]:
        i = pose_estimation.KPS.index(i)
        j = pose_estimation.KPS.index(j)
        if (i, j) in parallel3d_bones:
            continue

        foot_toe = create_bone(i, j, keypoints_2d)
        if is_parallel_to_plane(foot_toe, thresh=25):
            if "Right" in pose_estimation.KPS[i]:
                loss_parallel.right_foot_inds = right_foot_inds
            else:
                loss_parallel.left_foot_inds = left_foot_inds

    loss_parallel.ground_parallel = ground_parallel
    loss_parallel.parallel_in_3d = parallel_in_3d

    vertices_np = vertices[0].cpu().numpy()
    if len(ground_parallel) > 0:
        # Silhouette vertices
        mesh = trimesh.Trimesh(vertices=vertices_np, faces=smpl.faces, process=False)
        silhuette_vertices_mask_1 = np.abs(mesh.vertex_normals[..., 2]) < 2e-1
        height_3d = vertices_np[:, 1].max() - vertices_np[:, 1].min()
        plane_3d = vertices_np[:, 1].max()
        silhuette_vertices_mask_2 = (
            np.abs(vertices_np[:, 1] - plane_3d) < 0.15 * height_3d
        )
        silhuette_vertices_mask = np.logical_and(
            silhuette_vertices_mask_1, silhuette_vertices_mask_2
        )
        (silhuette_vertices_inds,) = np.where(silhuette_vertices_mask)
        if len(silhuette_vertices_inds) > 0:
            loss_parallel.silhuette_vertices_inds = silhuette_vertices_inds
            loss_parallel.ground = plane_3d

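
# The silhouette-vertex selection in ``get_natural`` keeps only vertices
# that meet two conditions. First, the vertex normal must be nearly
# perpendicular to the viewing (z) axis. Second, the vertex height must lie
# within 15% of the mesh extent from ``vertices[:, 1].max()``. The following
# is a minimal NumPy sketch of the same masking on toy arrays.
# ``_silhouette_mask`` is a hypothetical helper; the real code takes its
# normals from ``trimesh``.
import numpy as np


def _silhouette_mask(vertices, normals, normal_thresh=0.2, band=0.15):
    # Normals with a small z component are seen edge-on, i.e. on the outline.
    grazing = np.abs(normals[:, 2]) < normal_thresh
    y = vertices[:, 1]
    # Keep vertices close to the extreme y value, measured relative to height.
    near_plane = np.abs(y - y.max()) < band * (y.max() - y.min())
    return np.flatnonzero(grazing & near_plane)
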

def get_cos(keypoints_3d_pred, use_angle_transf, loss_parallel):
    keypoints_2d_pred = keypoints_3d_pred[:, :2]
    with torch.no_grad():
        cos_r = calc_cos(keypoints_2d_pred, keypoints_3d_pred)

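    # Clamp before acos: rounding can push a dot product of unit vectors
    # slightly outside [-1, 1], which would produce NaN angles.
    alpha = torch.acos(cos_r.clamp(-1.0, 1.0))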
    if use_angle_transf:
        leg_inds = [
            5,
            6,  # right leg
            7,
            8,  # left leg
        ]
        foot_inds = [15, 16]
        nleg_inds = sorted(
            set(range(len(pose_estimation.SKELETON))) - set(leg_inds) - set(foot_inds)
        )
        alpha[nleg_inds] = alpha[nleg_inds] - alpha[nleg_inds].min()

        amli = alpha[leg_inds].min()
        leg_inds.extend(foot_inds)
        alpha[leg_inds] = alpha[leg_inds] - amli

        angles = alpha.detach().cpu().numpy()
        angles = hist_cub.cub(
            angles / (math.pi / 2),
            a=1.2121212121212122,
            b=-1.105527638190953,
            c=0.787878787878789,
        ) * (math.pi / 2)
        alpha = alpha.new_tensor(angles)

    loss_parallel.cos = torch.cos(alpha)

    return cos_r

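
# With ``use_angle_transf``, ``get_cos`` shifts the foreshortening angles so
# that the least foreshortened bone of each group gets angle zero. There are
# two groups: the leg and foot bones, and all remaining bones. For the first
# group the minimum is taken over the leg bones only and then subtracted
# from the feet as well, so foot angles can become negative. The following
# is a minimal pure-Python sketch of that shift. ``_shift_angles`` is a
# hypothetical helper, and the later ``hist_cub.cub`` remapping is not
# reproduced. The sketch also clamps cosines into [-1, 1] before ``acos``.
import math


def _shift_angles(cosines, leg_inds, foot_inds):
    alpha = [math.acos(max(-1.0, min(1.0, c))) for c in cosines]
    other = [i for i in range(len(alpha)) if i not in leg_inds and i not in foot_inds]
    if other:
        m = min(alpha[i] for i in other)
        for i in other:
            alpha[i] -= m
    # The legs' minimum is applied to both legs and feet.
    m_leg = min(alpha[i] for i in leg_inds)
    for i in list(leg_inds) + list(foot_inds):
        alpha[i] -= m_leg
    return alpha
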

def save_mesh_with_winding_numbers(sc_module, vertices, smpl, save_path):
    triangles = sc_module.triangles(vertices)
    exterior = sc_module.get_intersection_mask(vertices, triangles, test_segments=False)
    exterior = exterior.cpu().numpy().squeeze(0)
    utils.save_mesh_with_colors(
        vertices[0].cpu().numpy(),
        smpl.faces,
        save_path / "winding_numbers.ply",
        mask=exterior,
    )

    exterior = sc_module.get_intersection_mask(vertices, triangles)
    exterior = exterior.cpu().numpy().squeeze(0)
    utils.save_mesh_with_colors(
        vertices[0].cpu().numpy(),
        smpl.faces,
        save_path / "winding_numbers_filtered.ply",
        mask=exterior,
    )


def get_contacts(
    args,
    sc_module,
    y_data_conts,
    keypoints_2d,
    vertices,
    bone_to_params,
    loss_parallel,
    img_size_original,
    save_path,
):
    use_contacts = args.use_contacts
    use_msc = args.use_msc
    c_mse = args.c_mse

    if use_contacts:
        assert c_mse == 0
        contact, contact_2d, for_mask = find_contacts(
            y_data_conts, keypoints_2d, bone_to_params
        )
        if len(contact_2d) > 0:
            loss_parallel.contact_2d = contact_2d

            mask = np.zeros((spin.constants.IMG_RES, spin.constants.IMG_RES), dtype="uint8")
            mask += 255
            cv2.drawContours(mask, for_mask, -1, 0, 2)
            mask = cv2.resize(mask, img_size_original[::-1])
            cv2.imwrite(str(save_path / "mask.png"), mask)

        if len(contact) == 0:
            _, contact = sc_module.verts_in_contact(vertices, return_idx=True)
            contact = contact.cpu().numpy().ravel()
    elif use_msc:
        _, contact = sc_module.verts_in_contact(vertices, return_idx=True)
        contact = contact.cpu().numpy().ravel()
    else:
        contact = np.array([])

    return contact


def save_all(
    keypoints_3d_pred,
    rotmat_pred,
    camera_pred,
    betas_pred,
    smpl,
    contact,
    img_original,
    predicted_keypoints_2d,
    predicted_contact_heatmap_raw,
    loss_parallel,
    smpl_output,
    shift,
    scale,
    ax2,
    summary_writer,
    save_path,
    fname,
):
    keypoints_2d_pred = keypoints_3d_pred[:, :2]

    vertices = smpl_output.vertices.detach()
    betas_pred = betas_pred.detach().cpu().numpy()

    utils.save_pose_params(
        rotmat_pred,
        camera_pred,
        betas_pred,
        vertices,
        smpl,
        contact,
        save_path / f"{fname}.pkl",
    )

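    if hasattr(loss_parallel, "silhuette_vertices_inds"):
        # Build a new list instead of appending in place. The caller's
        # ``contact`` is reused by later optimization stages, and it may be
        # a bare index array (or None) rather than a list.
        if contact is None:
            contact = []
        elif not isinstance(contact, list):
            contact = [contact]
        contact = contact + [loss_parallel.silhuette_vertices_inds]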

    img_sw = utils.save_results_image(
        camera=camera_pred.detach().cpu().numpy(),
        focal_length_x=spin.constants.FOCAL_LENGTH,
        focal_length_y=spin.constants.FOCAL_LENGTH,
        vertices=vertices[0].cpu().numpy(),
        input_img=img_original,
        faces=smpl.faces,
        keypoints=predicted_keypoints_2d,
        keypoints_2=unnormalize_keypoints_from_spin(
            keypoints_2d_pred.cpu().numpy(), shift, scale, ax2
        )
        if shift is not None
        else None,
        # keypoints_2=unnormalize_keypoints_from_spin(joints_2d_orig.detach().cpu().numpy(), shift, scale, ax2) if shift is not None else None,
        heatmap=predicted_contact_heatmap_raw,
        filename=save_path / f"{fname}.png",
        contactlist=contact,
        contact2dlist=loss_parallel.contact_2d
        if hasattr(loss_parallel, "contact_2d")
        else None,
        cos=loss_parallel.cos.tolist() if loss_parallel.cos is not None else None,
    )

    utils.save_mesh_with_colors(
        smpl_output.vertices[0].cpu().numpy(),
        smpl.faces,
        save_path / f"{fname}.ply",
        inds=contact,
    )

    joints = smpl_output.joints.squeeze(0).cpu().numpy()
    fig = utils.plot_3D(joints, vertices.squeeze(0).cpu().numpy(), smpl.faces)
    fig.write_html(save_path / f"{fname}.html")

    summary_writer.add_image(
        fname, np.array(img_sw).astype("float32") / 255, dataformats="HWC"
    )
    summary_writer.add_mesh(
        fname,
        vertices=(
            vertices.cpu().float()[0]
            @ torch.tensor([[1, 0, 0], [0, -1, 0], [0, 0, -1.0]])
        ).unsqueeze(0),
        faces=torch.from_numpy(smpl.faces[None].astype("int64")),
    )


def spin_step(
    model_hmr,
    smpl,
    selector,
    input_img,
    img_original,
    predicted_keypoints_2d,
    predicted_contact_heatmap_raw,
    loss_parallel,
    shift,
    scale,
    ax2,
    summary_writer,
    save_path,
):
    with torch.no_grad():
        (
            rotmat_pred,
            betas_pred,
            camera_pred,
            keypoints_3d_pred,
            _,
            _,
            smpl_output,
            _,
            _,
        ) = get_pred_and_data(
            model_hmr,
            smpl,
            selector,
            input_img,
            zero_hands=True,
        )

    save_all(
        keypoints_3d_pred,
        rotmat_pred,
        camera_pred,
        betas_pred,
        smpl,
        None,
        img_original,
        predicted_keypoints_2d,
        predicted_contact_heatmap_raw,
        loss_parallel,
        smpl_output,
        shift,
        scale,
        ax2,
        summary_writer,
        save_path,
        "spin",
    )


def eft_step(
    model_hmr,
    smpl,
    selector,
    input_img,
    keypoints_2d,
    optimizer,
    args,
    loss_mse,
    loss_parallel,
    c_beta,
    sc_module,
    y_data_conts,
    bone_to_params,
    img_original,
    predicted_keypoints_2d,
    predicted_contact_heatmap_raw,
    shift,
    scale,
    ax2,
    summary_writer,
    save_path,
):
    img_size_original = img_original.shape[:2]
    (
        rotmat_pred,
        betas_pred,
        camera_pred,
        keypoints_3d_pred,
        _,
        smpl_output,
        _,
        _,
    ) = optimize(
        model_hmr,
        smpl,
        selector,
        input_img,
        keypoints_2d,
        optimizer,
        args,
        loss_mse=loss_mse,
        loss_parallel=loss_parallel,
        c_mse=1,
        c_new_mse=0,
        c_beta=c_beta,
        sc_crit=None,
        msc_crit=None,
        contact=None,
        n_steps=60 + 90,
        writer=summary_writer,
    )

    # find contacts
    vertices = smpl_output.vertices.detach()
    contact = get_contacts(
        args,
        sc_module,
        y_data_conts,
        keypoints_2d,
        vertices,
        bone_to_params,
        loss_parallel,
        img_size_original,
        save_path,
    )

    save_all(
        keypoints_3d_pred,
        rotmat_pred,
        camera_pred,
        betas_pred,
        smpl,
        contact,
        img_original,
        predicted_keypoints_2d,
        predicted_contact_heatmap_raw,
        loss_parallel,
        smpl_output,
        shift,
        scale,
        ax2,
        summary_writer,
        save_path,
        "eft",
    )

    if sc_module is not None:
        save_mesh_with_winding_numbers(sc_module, vertices, smpl, save_path)

    return vertices, keypoints_3d_pred, contact


def dc_step(
    model_hmr,
    smpl,
    selector,
    input_img,
    keypoints_2d,
    optimizer,
    args,
    loss_mse,
    loss_parallel,
    c_mse,
    c_new_mse,
    c_beta,
    sc_crit,
    msc_crit,
    contact,
    use_contacts,
    use_msc,
    img_original,
    predicted_keypoints_2d,
    predicted_contact_heatmap_raw,
    shift,
    scale,
    ax2,
    summary_writer,
    save_path,
):
    (
        rotmat_pred,
        betas_pred,
        camera_pred,
        keypoints_3d_pred,
        _,
        smpl_output,
        _,
        _,
    ) = optimize(
        model_hmr,
        smpl,
        selector,
        input_img,
        keypoints_2d,
        optimizer,
        args,
        loss_mse=loss_mse,
        loss_parallel=loss_parallel,
        c_mse=c_mse,
        c_new_mse=c_new_mse,
        c_beta=c_beta,
        sc_crit=sc_crit,
        msc_crit=msc_crit if use_contacts or use_msc else None,
        contact=contact if use_contacts or use_msc else None,
        n_steps=60 if use_contacts or use_msc else 0,  # + 60,
        # save_path=(img_original, predicted_keypoints_2d, save_path, shift, scale, ax2, "dc"),
        writer=summary_writer,
        i_ini=60 + 90,
    )

    save_all(
        keypoints_3d_pred,
        rotmat_pred,
        camera_pred,
        betas_pred,
        smpl,
        contact,
        img_original,
        predicted_keypoints_2d,
        predicted_contact_heatmap_raw,
        loss_parallel,
        smpl_output,
        shift,
        scale,
        ax2,
        summary_writer,
        save_path,
        "dc",
    )

    return rotmat_pred


def us_step(
    model_hmr,
    smpl,
    selector,
    input_img,
    rotmat_pred,
    keypoints_2d,
    args,
    loss_mse,
    loss_parallel,
    c_mse,
    c_new_mse,
    sc_crit,
    msc_crit,
    contact,
    use_contacts,
    use_msc,
    img_original,
    keypoints_3d_pred,
    summary_writer,
    save_path,
):
    # Only the network's camera estimate is needed here, so skip autograd.
    with torch.no_grad():
        (_, _, camera_pred_us, _, _, _, smpl_output_us, _, _) = get_pred_and_data(
            model_hmr,
            smpl,
            selector,
            input_img,
            use_betas=False,
            zero_hands=True,
        )

    rotmat_pred_us, smpl_output_us = optimize_ft(
        rotmat_pred,
        camera_pred_us,
        smpl,
        selector,
        input_img,
        keypoints_2d,
        args,
        loss_mse=loss_mse,
        loss_parallel=loss_parallel,
        c_mse=c_mse,
        c_new_mse=c_new_mse,
        sc_crit=sc_crit,
        msc_crit=msc_crit if use_contacts or use_msc else None,
        contact=contact if use_contacts or use_msc else None,
        n_steps=60 if use_contacts or use_msc else 0,  # + 60,
        # save_path=(img_original, predicted_keypoints_2d, save_path, shift, scale, ax2, "dc"),
        writer=summary_writer,
        i_ini=60 + 90 + 60,
        zero_hands=True,
        fist=args.fist,
    )

    save_all(
        keypoints_3d_pred,
        rotmat_pred_us,
        camera_pred_us,
        torch.zeros(1, 10, dtype=torch.float32),
        smpl,
        None,
        img_original,
        None,
        None,
        loss_parallel,
        smpl_output_us,
        None,
        None,
        None,
        summary_writer,
        save_path,
        "us",
    )


def main():
    args = parse_args()
    print(args)

    # models
    model_pose = cv2.dnn.readNetFromONNX(
        args.pose_estimation_model_path
    )  # "hrn_w48_384x288.onnx"
    model_contact = cv2.dnn.readNetFromONNX(
        args.contact_model_path
    )  # "contact_hrn_w32_256x192.onnx"

    device = (
        torch.device(args.device) if torch.cuda.is_available() else torch.device("cpu")
    )
    model_hmr = spin.hmr(args.smpl_mean_params_path)  # "smpl_mean_params.npz"
    model_hmr.to(device)
    checkpoint = torch.load(
        args.spin_model_path,  # "spin_model_smplx_eft_18.pt"
        map_location="cpu"
    )

    smpl = spin.SMPLX(
        args.smpl_model_dir,  # "models/smplx"
        batch_size=1,
        create_transl=False,
        use_pca=False,
        flat_hand_mean=args.fist is not None,
    )
    smpl.to(device)

    selector = get_selector()

    use_contacts = args.use_contacts
    use_msc = args.use_msc

    bone_to_params = np.load(args.bone_parametrization_path, allow_pickle=True).item()
    foot_inds = np.load(args.foot_inds_path, allow_pickle=True).item()
    left_foot_inds = foot_inds["left_foot_inds"]
    right_foot_inds = foot_inds["right_foot_inds"]

    if use_contacts:
        model_type = args.smpl_type
        sc_module = selfcontact.SelfContact(
            essentials_folder=args.essentials_dir,  # "smplify-xmc-essentials"
            geothres=0.3,
            euclthres=0.02,
            test_segments=True,
            compute_hd=True,
            model_type=model_type,
            device=device,
        )
        sc_module.to(device)

        sc_crit = selfcontact.losses.SelfContactLoss(
            contact_module=sc_module,
            inside_loss_weight=0.5,
            outside_loss_weight=0.0,
            contact_loss_weight=0.5,
            align_faces=True,
            use_hd=True,
            test_segments=True,
            device=device,
            model_type=model_type,
        )
        sc_crit.to(device)

        msc_crit = losses.MimickedSelfContactLoss(geodesics_mask=sc_module.geomask)
        msc_crit.to(device)
    else:
        sc_module = None
        sc_crit = None
        msc_crit = None

    loss_mse = losses.MSE([1, 10, 13])  # Neck + Right Upper Leg + Left Upper Leg

    ignore = (
        (1, 2),  # Neck + Right Shoulder
        (1, 5),  # Neck + Left Shoulder
        (9, 10),  # Hips + Right Upper Leg
        (9, 13),  # Hips + Left Upper Leg
    )
    loss_parallel = losses.Parallel(
        skeleton=pose_estimation.SKELETON,
        ignore=ignore,
    )

    c_mse = args.c_mse
    c_new_mse = args.c_par
    c_beta = 1e-3

    if c_mse > 0:
        assert c_new_mse == 0
    elif c_mse == 0:
        assert c_new_mse > 0

    root_path = Path(args.save_path)
    root_path.mkdir(exist_ok=True, parents=True)

    path_to_imgs = Path(args.img_path)
    if path_to_imgs.is_dir():
        path_to_imgs = path_to_imgs.iterdir()
    else:
        path_to_imgs = [path_to_imgs]

    for img_path in path_to_imgs:
        if not any(
            img_path.name.lower().endswith(ext) for ext in [".jpg", ".png", ".jpeg"]
        ):
            continue

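        img_name = img_path.stem

        # ``get_natural``, ``get_cos`` and ``get_contacts`` attach per-image
        # attributes (foot and silhouette indices, contact_2d, cos) to the
        # loss object, so start from a fresh one for every image.
        loss_parallel = losses.Parallel(
            skeleton=pose_estimation.SKELETON,
            ignore=ignore,
        )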

        # use 2d keypoints detection
        (
            img_original,
            predicted_keypoints_2d,
            _,
            _,
        ) = pose_estimation.infer_single_image(
            model_pose,
            img_path,
            input_img_size=pose_estimation.IMG_SIZE,
            return_kps=True,
        )

        save_path = root_path / img_name
        save_path.mkdir(exist_ok=True, parents=True)
        # if (save_path / "us_orig.png").is_file():
        #     continue

        summary_writer = SummaryWriter(log_dir=save_path / f"runDoknc2_{c_new_mse}")

        img_original = cv2.cvtColor(img_original, cv2.COLOR_BGR2RGB)
        img_size_original = img_original.shape[:2]
        keypoints_2d, shift, scale, ax2 = normalize_keypoints_to_spin(
            predicted_keypoints_2d, img_size_original
        )
        keypoints_2d = torch.from_numpy(keypoints_2d)
        keypoints_2d = keypoints_2d.to(device)

        (
            predicted_contact_heatmap,
            predicted_contact_heatmap_raw,
            very_hm_raw,
        ) = get_contact_heatmap(model_contact, img_path)
        predicted_contact_heatmap_raw = cv2.resize(very_hm_raw, img_size_original[::-1])

        if c_new_mse == 0:
            predicted_contact_heatmap_raw = None

        y_data_conts = get_vertices_in_heatmap(predicted_contact_heatmap)

        model_hmr.load_state_dict(checkpoint["model"], strict=True)
        model_hmr.train()
        freeze_layers(model_hmr)

        _, input_img = spin.process_image(img_path, input_res=spin.constants.IMG_RES)
        input_img = input_img.to(device)

        spin_step(
            model_hmr,
            smpl,
            selector,
            input_img,
            img_original,
            predicted_keypoints_2d,
            predicted_contact_heatmap_raw,
            loss_parallel,
            shift,
            scale,
            ax2,
            summary_writer,
            save_path,
        )

        optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, model_hmr.parameters()),
            lr=1e-6,
        )

        vertices, keypoints_3d_pred, contact = eft_step(
            model_hmr,
            smpl,
            selector,
            input_img,
            keypoints_2d,
            optimizer,
            args,
            loss_mse,
            loss_parallel,
            c_beta,
            sc_module,
            y_data_conts,
            bone_to_params,
            img_original,
            predicted_keypoints_2d,
            predicted_contact_heatmap_raw,
            shift,
            scale,
            ax2,
            summary_writer,
            save_path,
        )

        if args.use_natural:
            get_natural(
                keypoints_2d,
                vertices,
                right_foot_inds,
                left_foot_inds,
                loss_parallel,
                smpl,
            )

        if args.use_cos:
            cos_r = get_cos(keypoints_3d_pred, args.use_angle_transf, loss_parallel)
            np.save(save_path / "cos_hist", cos_r.cpu().numpy())

        rotmat_pred = dc_step(
            model_hmr,
            smpl,
            selector,
            input_img,
            keypoints_2d,
            optimizer,
            args,
            loss_mse,
            loss_parallel,
            c_mse,
            c_new_mse,
            c_beta,
            sc_crit,
            msc_crit,
            contact,
            use_contacts,
            use_msc,
            img_original,
            predicted_keypoints_2d,
            predicted_contact_heatmap_raw,
            shift,
            scale,
            ax2,
            summary_writer,
            save_path,
        )

        us_step(
            model_hmr,
            smpl,
            selector,
            input_img,
            rotmat_pred,
            keypoints_2d,
            args,
            loss_mse,
            loss_parallel,
            c_mse,
            c_new_mse,
            sc_crit,
            msc_crit,
            contact,
            use_contacts,
            use_msc,
            img_original,
            keypoints_3d_pred,
            summary_writer,
            save_path,
        )


if __name__ == "__main__":
    main()


================================================
FILE: src/pose_estimation.py
================================================
import math

import cv2
import numpy as np

IMG_SIZE = (288, 384)
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])

KPS = (
    "Head",
    "Neck",
    "Right Shoulder",
    "Right Arm",
    "Right Hand",
    "Left Shoulder",
    "Left Arm",
    "Left Hand",
    "Spine",
    "Hips",
    "Right Upper Leg",
    "Right Leg",
    "Right Foot",
    "Left Upper Leg",
    "Left Leg",
    "Left Foot",
    "Left Toe",
    "Right Toe",
)

SKELETON = (
    (0, 1),
    (1, 8),
    (8, 9),
    (9, 10),
    (9, 13),
    (10, 11),
    (11, 12),
    (13, 14),
    (14, 15),
    (1, 2),
    (2, 3),
    (3, 4),
    (1, 5),
    (5, 6),
    (6, 7),
    (15, 16),
    (12, 17),
)


# Position i is the OpenPose BODY_25 joint index; the value is the matching
# index in KPS (-1: no corresponding keypoint).
OPENPOSE_TO_GESTURE = (
    0,  # 0 Head
    1,  # 1 Neck
    2,  # 2 Right Shoulder
    3,  # 3 Right Arm
    4,  # 4 Right Hand
    5,  # 5 Left Shoulder
    6,  # 6 Left Arm
    7,  # 7 Left Hand
    9,  # 8 Hips
    10,  # 9 Right Upper Leg
    11,  # 10 Right Leg
    12,  # 11 Right Foot
    13,  # 12 Left Upper Leg
    14,  # 13 Left Leg
    15,  # 14 Left Foot
    -1,  # 15
    -1,  # 16
    -1,  # 17
    -1,  # 18
    16,  # 19 Left Toe
    -1,  # 20
    -1,  # 21
    17,  # 22 Right Toe
    -1,  # 23
    -1,  # 24
)


def transform(img):
    img = img.astype("float32") / 255

    img = (img - MEAN) / STD

    return np.transpose(img, axes=(2, 0, 1))


def get_affine_transform(
    center,
    scale,
    rot,
    output_size,
    shift=np.array([0, 0], dtype=np.float32),
    inv=0,
    pixel_std=200,
):
    if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
        scale = np.array([scale, scale])

    scale_tmp = scale * pixel_std
    src_w = scale_tmp[0]
    dst_w = output_size[0]
    dst_h = output_size[1]

    rot_rad = np.pi * rot / 180
    src_dir = get_dir([0, src_w * -0.5], rot_rad)
    dst_dir = np.array([0, dst_w * -0.5], np.float32)
    src = np.zeros((3, 2), dtype=np.float32)
    dst = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = center + scale_tmp * shift
    src[1, :] = center + src_dir + scale_tmp * shift
    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir

    src[2:, :] = get_3rd_point(src[0, :], src[1, :])
    dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])

    if inv:
        trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
    else:
        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))

    return trans


def get_3rd_point(a, b):
    direct = a - b
    return b + np.array([-direct[1], direct[0]], dtype=np.float32)


def get_dir(src_point, rot_rad):
    sn, cs = np.sin(rot_rad), np.cos(rot_rad)

    src_result = [0, 0]
    src_result[0] = src_point[0] * cs - src_point[1] * sn
    src_result[1] = src_point[0] * sn + src_point[1] * cs

    return src_result


def process_image(path, input_img_size, pixel_std=200):
    data_numpy = cv2.imread(path, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
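    # Known issue (flagged by the author): the image stays in BGR order,
    # although MEAN/STD above are RGB ImageNet statistics. Enabling the
    # conversion below would also require dropping the BGR->RGB conversion
    # that main() applies to the returned image.
    # data_numpy = cv2.cvtColor(data_numpy, cv2.COLOR_BGR2RGB)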

    h, w = data_numpy.shape[:2]
    c = np.array([w / 2, h / 2], dtype=np.float32)

    aspect_ratio = input_img_size[0] / input_img_size[1]
    if w > aspect_ratio * h:
        h = w * 1.0 / aspect_ratio
    elif w < aspect_ratio * h:
        w = h * aspect_ratio

    s = np.array([w / pixel_std, h / pixel_std], dtype=np.float32) * 1.25
    r = 0
    trans = get_affine_transform(c, s, r, input_img_size, pixel_std=pixel_std)
    model_input = cv2.warpAffine(
        data_numpy, trans, input_img_size, flags=cv2.INTER_LINEAR
    )

    model_input = transform(model_input)

    return model_input, data_numpy, c, s


def get_final_preds(batch_heatmaps, center, scale, post_process=False):
    coords, maxvals = get_max_preds(batch_heatmaps)

    heatmap_height = batch_heatmaps.shape[2]
    heatmap_width = batch_heatmaps.shape[3]

    # post-processing
    if post_process:
        for n in range(coords.shape[0]):
            for p in range(coords.shape[1]):
                hm = batch_heatmaps[n][p]
                px = int(math.floor(coords[n][p][0] + 0.5))
                py = int(math.floor(coords[n][p][1] + 0.5))
                if 1 < px < heatmap_width - 1 and 1 < py < heatmap_height - 1:
                    diff = np.array(
                        [
                            hm[py][px + 1] - hm[py][px - 1],
                            hm[py + 1][px] - hm[py - 1][px],
                        ]
                    )
                    coords[n][p] += np.sign(diff) * 0.25

    preds = coords.copy()

    # Transform back
    for i in range(coords.shape[0]):
        preds[i] = transform_preds(
            coords[i], center[i], scale[i], [heatmap_width, heatmap_height]
        )

    return preds, maxvals


def transform_preds(coords, center, scale, output_size):
    target_coords = np.zeros(coords.shape)
    trans = get_affine_transform(center, scale, 0, output_size, inv=1)
    for p in range(coords.shape[0]):
        target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans)
    return target_coords


def affine_transform(pt, t):
    new_pt = np.array([pt[0], pt[1], 1.0]).T
    new_pt = np.dot(t, new_pt)
    return new_pt[:2]


def get_max_preds(batch_heatmaps):
    """
    get predictions from score maps
    heatmaps: numpy.ndarray([batch_size, num_joints, height, width])
    """
    assert isinstance(
        batch_heatmaps, np.ndarray
    ), "batch_heatmaps should be numpy.ndarray"
    assert batch_heatmaps.ndim == 4, "batch_images should be 4-ndim"

    batch_size = batch_heatmaps.shape[0]
    num_joints = batch_heatmaps.shape[1]
    width = batch_heatmaps.shape[3]
    heatmaps_reshaped = batch_heatmaps.reshape((batch_size, num_joints, -1))
    idx = np.argmax(heatmaps_reshaped, 2)
    maxvals = np.amax(heatmaps_reshaped, 2)

    maxvals = maxvals.reshape((batch_size, num_joints, 1))
    idx = idx.reshape((batch_size, num_joints, 1))

    preds = np.tile(idx, (1, 1, 2)).astype(np.float32)

    preds[:, :, 0] = (preds[:, :, 0]) % width
    preds[:, :, 1] = np.floor((preds[:, :, 1]) / width)

    pred_mask = np.tile(np.greater(maxvals, 0.0), (1, 1, 2))
    pred_mask = pred_mask.astype(np.float32)

    preds *= pred_mask
    return preds, maxvals


def infer_single_image(model, img_path, input_img_size=(288, 384), return_kps=True):
    img_path = str(img_path)
    pose_input, img, center, scale = process_image(
        img_path, input_img_size=input_img_size
    )
    model.setInput(pose_input[None])
    predicted_heatmap = model.forward()

    if not return_kps:
        return predicted_heatmap.squeeze(0)

    predicted_keypoints, confidence = get_final_preds(
        predicted_heatmap, center[None], scale[None], post_process=True
    )

    predicted_keypoints = predicted_keypoints.squeeze(0)
    confidence = confidence.squeeze(0)
    predicted_heatmap = predicted_heatmap.squeeze(0)

    return img, predicted_keypoints, confidence, predicted_heatmap
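The heatmap decoding in `get_max_preds`, together with the optional post-processing in `get_final_preds`, can be summarised by the NumPy sketch below. It is a standalone re-implementation for illustration only (the name `decode_heatmaps` is hypothetical, not part of this file): each joint is placed at the arg-max of its heatmap, joints whose peak is not positive are zeroed, and the estimate is nudged a quarter pixel towards the stronger horizontal and vertical neighbour. The final mapping back to image coordinates with the inverse affine transform is left out.

```python
import numpy as np


def decode_heatmaps(heatmaps, refine=True):
    """Decode (N, J, H, W) heatmaps into (N, J, 2) (x, y) heatmap coordinates."""
    n, j, h, w = heatmaps.shape
    flat = heatmaps.reshape(n, j, -1)
    idx = flat.argmax(axis=2)
    maxvals = flat.max(axis=2)[..., None]  # (N, J, 1) peak confidence

    # Flat index -> (x, y): column is idx % width, row is idx // width.
    coords = np.stack([idx % w, idx // w], axis=-1).astype(np.float32)
    # Joints whose peak is not positive are zeroed, as in get_max_preds.
    coords *= maxvals > 0.0

    if refine:
        for b in range(n):
            for k in range(j):
                hm = heatmaps[b, k]
                px = int(np.floor(coords[b, k, 0] + 0.5))
                py = int(np.floor(coords[b, k, 1] + 0.5))
                # Only refine when both neighbours exist in each direction.
                if 1 < px < w - 1 and 1 < py < h - 1:
                    diff = np.array(
                        [
                            hm[py, px + 1] - hm[py, px - 1],
                            hm[py + 1, px] - hm[py - 1, px],
                        ]
                    )
                    # Shift a quarter pixel towards the higher neighbour.
                    coords[b, k] += np.sign(diff) * 0.25
    return coords, maxvals


# Usage: a single 8x8 heatmap peaking at x=4, y=3, with the right
# neighbour stronger than the left one.
hm = np.zeros((1, 1, 8, 8), dtype=np.float32)
hm[0, 0, 3, 4] = 1.0
hm[0, 0, 3, 5] = 0.5
coords, conf = decode_heatmaps(hm)
# coords[0, 0] is (4.25, 3.0); conf[0, 0, 0] is 1.0
```

The quarter-pixel shift is a cheap substitute for fitting a sub-pixel peak: it only uses the sign of the local gradient, so it cannot move the estimate by more than 0.25 heatmap pixels per axis.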


================================================
FILE: src/renderer.py
================================================
import numpy as np
import pyrender
import torch
import trimesh
from torchvision.utils import make_grid


class Renderer:
    """
    Renderer used for visualizing the SMPL model
    Code adapted from https://github.com/vchoutas/smplify-x
    """

    def __init__(self, focal_length=5000, img_res=224, faces=None):
        self.renderer = pyrender.OffscreenRenderer(
            viewport_width=img_res, viewport_height=img_res, point_size=1.0
        )
        self.focal_length = focal_length
        self.camera_center = [img_res // 2, img_res // 2]
        self.faces = faces

    def visualize_tb(self, vertices, camera_translation, images):
        vertices = vertices.cpu().numpy()
        camera_translation = camera_translation.cpu().numpy()
        images = images.cpu()
        images_np = np.transpose(images.numpy(), (0, 2, 3, 1))
        rend_imgs = []
        for i in range(vertices.shape[0]):
            rend_img = torch.from_numpy(
                np.transpose(
                    self.__call__(vertices[i], camera_translation[i], images_np[i]),
                    (2, 0, 1),
                )
            ).float()
            rend_imgs.append(images[i])
            rend_imgs.append(rend_img)
        rend_imgs = make_grid(rend_imgs, nrow=2)
        return rend_imgs

    def __call__(self, vertices, camera_translation, image):
        material = pyrender.MetallicRoughnessMaterial(
            metallicFactor=0.2, alphaMode="OPAQUE", baseColorFactor=(0.8, 0.3, 0.3, 1.0)
        )

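        # Copy before flipping the x axis so the caller's array is not modified.
        camera_translation = camera_translation.copy()
        camera_translation[0] *= -1.0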

        mesh = trimesh.Trimesh(vertices, self.faces)
        rot = trimesh.transformations.rotation_matrix(np.radians(180), [1, 0, 0])
        mesh.apply_transform(rot)
        mesh = pyrender.Mesh.from_trimesh(mesh, material=material)

        scene = pyrender.Scene(ambient_light=(0.5, 0.5, 0.5))
        scene.add(mesh, "mesh")

        camera_pose = np.eye(4)
        camera_pose[:3, 3] = camera_translation
        camera = pyrender.IntrinsicsCamera(
            fx=self.focal_length,
            fy=self.focal_length,
            cx=self.camera_center[0],
            cy=self.camera_center[1],
        )
        scene.add(camera, pose=camera_pose)

        light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=1)
        light_pose = np.eye(4)

        light_pose[:3, 3] = np.array([0, -1, 1])
        scene.add(light, pose=light_pose)

        light_pose[:3, 3] = np.array([0, 1, 1])
        scene.add(light, pose=light_pose)

        light_pose[:3, 3] = np.array([1, 1, 2])
        scene.add(light, pose=light_pose)

        color, rend_depth = self.renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
        color = color.astype(np.float32) / 255.0
        valid_mask = (rend_depth > 0)[:, :, None]
        output_img = color[:, :, :3] * valid_mask + (1 - valid_mask) * image
        return output_img


def overlay_mesh(
    verts,
    faces,
    camera_transl,
    focal_length_x,
    focal_length_y,
    camera_center,
    H,
    W,
    img,
    camera_rotation=None,
    rotaround=None,
    contactlist=None,
    color=False,
    scale=1,
):

    material = pyrender.MetallicRoughnessMaterial(
        metallicFactor=0.0, alphaMode="OPAQUE", baseColorFactor=(1.0, 1.0, 0.9, 1.0)
    )
    out_mesh = trimesh.Trimesh(verts, faces, process=False)
    out_mesh_col = np.array(out_mesh.visual.vertex_colors)

    if contactlist is not None and len(contactlist) > 0:
        contact_color = [255, 0, 0, 255]  # highlight contact vertices in red
        out_mesh_col[contactlist] = contact_color
        out_mesh.visual.vertex_colors = out_mesh_col

    if camera_rotation is None:
        camera_rotation = np.eye(3)
    else:
        camera_rotation = camera_rotation[0]

    # rotate mesh and stack output images
    if rotaround is None:
        out_mesh.vertices = np.matmul(verts, camera_rotation.T) + camera_transl
    else:
        base_mesh = trimesh.Trimesh(verts, faces, process=False)
        # rot_center = (base_mesh.vertices[5615] + base_mesh.vertices[5614] ) / 2
        rot = trimesh.transformations.rotation_matrix(
            np.radians(rotaround), [0, 1, 0], base_mesh.vertices[4297]
        )
        base_mesh.apply_transform(rot)
        out_mesh.vertices = (
            np.matmul(base_mesh.vertices, camera_rotation.T) + camera_transl
        )

    out_mesh.vertices += np.array([0, 0, 50])
    # add mesh to scene
    mesh = pyrender.Mesh.from_trimesh(
        out_mesh,
        material=material,
        smooth=False,
    )
    if img is not None:
        scene = pyrender.Scene(
            bg_color=[0.0, 0.0, 0.0, 0.0],
            ambient_light=(0.3, 0.3, 0.3, 1.0),
        )
    else:
        scene = pyrender.Scene(
            bg_color=[1.0, 1.0, 1.0, 1.0],
            ambient_light=(0.3, 0.3, 0.3, 1.0),
        )
    scene.add(mesh, "mesh")

    # create and add camera
    camera_pose = np.eye(4)
    camera_pose[1, :] = -camera_pose[1, :]
    camera_pose[2, :] = -camera_pose[2, :]
    pyrencamera = pyrender.camera.OrthographicCamera(
        camera_transl[2],
        camera_transl[2],
        znear=1e-6,
        zfar=1000000,
    )
    scene.add(pyrencamera, pose=camera_pose)

    # create and add light
    light = pyrender.PointLight(
        color=[1.0, 1.0, 1.0],
        intensity=1,
    )
    light_pose = np.eye(4)
    for lp in [[1, 1, -1], [-1, 1, -1], [1, -1, -1], [-1, -1, -1]]:
        light_pose[:3, 3] = out_mesh.vertices.mean(0) + np.array(lp)
        scene.add(light, pose=light_pose)

    r = pyrender.OffscreenRenderer(
        viewport_width=int(scale * W),
        viewport_height=int(scale * H),
        point_size=1.0,
    )
    color, _ = r.render(scene, flags=pyrender.RenderFlags.RGBA)
    r.delete()  # release the offscreen OpenGL context created for this call
    color = color.astype(np.float32) / 255.0

    if img is not None:
        valid_mask = (color[:, :, -1] > 0)[:, :, np.newaxis]
        output_img = color[:, :, :-1] * valid_mask + (1 - valid_mask) * img
    else:
        output_img = color

    output_img = (output_img * 255).astype(np.uint8)[..., :3]

    return output_img
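Both `Renderer.__call__` and `overlay_mesh` composite the rendering over the input image with a binary mask: `depth > 0` in the former, `alpha > 0` in the latter. A minimal NumPy sketch of that step, assuming float images in [0, 1] of identical size (the helper name `composite_over` is hypothetical):

```python
import numpy as np


def composite_over(rendered_rgb, mask, background):
    """Take rendered pixels where mask is True and background pixels elsewhere.

    rendered_rgb, background: float arrays of shape (H, W, 3) in [0, 1].
    mask: boolean array of shape (H, W), e.g. depth > 0 or alpha > 0.
    """
    m = mask[:, :, None].astype(np.float32)  # (H, W, 1) broadcasts over RGB
    return rendered_rgb * m + (1.0 - m) * background


# Usage: one red "mesh" pixel over a grey 2x2 background.
rendered = np.zeros((2, 2, 3), dtype=np.float32)
rendered[0, 0] = [1.0, 0.0, 0.0]
mask = np.zeros((2, 2), dtype=bool)
mask[0, 0] = True
background = np.full((2, 2, 3), 0.5, dtype=np.float32)

out = composite_over(rendered, mask, background)
out_uint8 = (out * 255).astype(np.uint8)  # same conversion as overlay_mesh
```

Because the mask is binary, mesh edges are not blended with the background; a soft composite would multiply by the rendered alpha channel instead of a thresholded mask.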


================================================
FILE: src/spin/__init__.py
================================================
from .constants import JOINT_NAMES
from .hmr import hmr
from .smpl import SMPLX
from .utils import process_image

__all__ = [
    "hmr",
    "SMPLX",
    "process_image",
    "JOINT_NAMES",
]


================================================
FILE: src/spin/constants.py
================================================
FOCAL_LENGTH = 5000.0
IMG_RES = 224

# Mean and standard deviation for normalizing input image
IMG_NORM_MEAN = [0.485, 0.456, 0.406]
IMG_NORM_STD = [0.229, 0.224, 0.225]

"""
We create a superset of joints containing the OpenPose joints together with the ones that each dataset provides.
We keep a superset of 24 joints such that we include all joints from every dataset.
If a dataset doesn't provide annotations for a specific joint, we simply ignore it.
The joints used here are the following:
"""
JOINT_NAMES = (
    "Hips",
    "Left Upper Leg",
    "Right Upper Leg",
    "Spine",
    "Left Leg",
    "Right Leg",
    "Spine1",
    "Left Foot",
    "Right Foot",
    "Thorax",
    "Left Toe",
    "Right Toe",
    "Neck",
    "Left Shoulder",
    "Right Shoulder",
    "Head",
    "Left ForeArm",
    "Right ForeArm",
    "Left Arm",
    "Right Arm",
    "Left Hand",
    "Right Hand",
    # 25 OpenPose joints (in the order provided by OpenPose)
    # "OP Nose",
    # "OP Neck",
    # "OP RShoulder",
    # "OP RElbow",
    # "OP RWrist",
    # "OP LShoulder",
    # "OP LElbow",
    # "OP LWrist",
    # "OP MidHip",
    # "OP RHip",
    # "OP RKnee",
    # "OP RAnkle",
    # "OP LHip",
    # "OP LKnee",
    # "OP LAnkle",
    # "OP REye",
    # "OP LEye",
    # "OP REar",
    # "OP LEar",
    # "OP LBigToe",
    # "OP LSmallToe",
    # "OP LHeel",
    # "OP RBigToe",
    # "OP RSmallToe",
    # "OP RHeel",
    ## 24 Ground Truth joints (superset of joints from different datasets)
    # "Right Ankle",
    # "Right Knee",
    # "Right Hip",
    # "Left Hip",
    # "Left Knee",
    # "Left Ankle",
    # "Right Wrist",
    # "Right Elbow",
    # "Right Shoulder",
    # "Left Shoulder",
    # "Left Elbow",
    # "Left Wrist",
    # "Neck (LSP)",
    # "Top of Head (LSP)",
    # "Pelvis (MPII)",
    # "Thorax (MPII)",
    # "Spine (H36M)",
    # "Jaw (H36M)",
    # "Head (H36M)",
    # "Nose",
    # "Left Eye",
    # "Right Eye",
    # "Left Ear",
    # "Right Ear",
    # "OP MidHip",
    # "Spine1",
    # "Spine2",
    # "Spine3",
    # "OP Neck",
    # "Head",
)

# Dict containing the joints in numerical order
JOINT_IDS = {JOINT_NAMES[i]: i for i in range(len(JOINT_NAMES))}

# Map joints to SMPL joints
JOINT_MAP = {
    "Hips": 0,
    "Left Upper Leg": 1,
    "Right Upper Leg": 2,
    "Spine": 3,
    "Left Leg": 4,
    "Right Leg": 5,
    "Spine1": 6,
    "Left Foot": 7,
    "Right Foot": 8,
    "Thorax": 9,
    "Left Toe": 10,
    "Right Toe": 11,
    "Neck": 12,
    "Left Shoulder": 13,
    "Right Shoulder": 14,
    "Head": 15,
    "Left ForeArm": 16,
    "Right ForeArm": 17,
    "Left Arm": 18,
    "Right Arm": 19,
    "Left Hand": 20,
    "Right Hand": 21,
    # "OP Nose": 24,
    # "OP Neck": 12,
    # "OP RShoulder": 17,
    # "OP RElbow": 19,
    # "OP RWrist": 21,
    # "OP LShoulder": 16,
    # "OP LElbow": 18,
    # "OP LWrist": 20,
    # "OP MidHip": 0,
    # "OP RHip": 2,
    # "OP RKnee": 5,
    # "OP RAnkle": 8,
    # "OP LHip": 1,
    # "OP LKnee": 4,
    # "OP LAnkle": 7,
    # "OP REye": 25,
    # "OP LEye": 26,
    # "OP REar": 27,
    # "OP LEar": 28,
    # "OP LBigToe": 29,
    # "OP LSmallToe": 30,
    # "OP LHeel": 31,
    # "OP RBigToe": 32,
    # "OP RSmallToe": 33,
    # "OP RHeel": 34,
    # "Right Ankle": 8,
    # "Right Knee": 5,
    # "Right Hip": 45,
    # "Left Hip": 46,
    # "Left Knee": 4,
    # "Left Ankle": 7,
    # "Right Wrist": 21,
    # "Right Elbow": 19,
    # "Right Shoulder": 17,
    # "Left Shoulder": 16,
    # "Left Elbow": 18,
    # "Left Wrist": 20,
    # "Neck (LSP)": 47,
    # "Top of Head (LSP)": 15, # 48,
    # "Pelvis (MPII)": 49,
    # "Thorax (MPII)": 50,
    # "Spine (H36M)": 51,
    # "Jaw (H36M)": 52,
    # "Head (H36M)": 15, # 53,
    # "Nose": 24,
    # "Left Eye": 26,
    # "Right Eye": 25,
    # "Left Ear": 28,
    # "Right Ear": 27,
    # "Spine1": 3,
    # "Spine2": 6,
    # "Spine3": 9,
    # "Head": 15,
}

# Joint selectors
# Indices to get the 14 LSP joints from the 17 H36M joints
H36M_TO_J17 = [6, 5, 4, 1, 2, 3, 16, 15, 14, 11, 12, 13, 8, 10, 0, 7, 9]
H36M_TO_J14 = H36M_TO_J17[:14]
# Indices to get the 14 LSP joints from the ground truth joints
J24_TO_J17 = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18, 14, 16, 17]
J24_TO_J14 = J24_TO_J17[:14]

# Permutation of SMPL pose parameters when flipping the shape
SMPL_JOINTS_FLIP_PERM = [
    0, 2, 1, 3, 5, 4, 6, 8, 7, 9, 11, 10,
    12, 14, 13, 15, 17, 16, 19, 18, 21, 20, 23, 22,
]
SMPL_POSE_FLIP_PERM = []
for i in SMPL_JOINTS_FLIP_PERM:
    SMPL_POSE_FLIP_PERM.append(3 * i)
    SMPL_POSE_FLIP_PERM.append(3 * i + 1)
    SMPL_POSE_FLIP_PERM.append(3 * i + 2)
# Permutation indices for the 24 ground truth joints
J24_FLIP_PERM = [
    5, 4, 3, 2, 1, 0, 11, 10, 9, 8, 7, 6,
    12, 13, 14, 15, 16, 17, 18, 19, 21, 20, 23, 22,
]
# Permutation indices for the full set of 49 joints
J49_FLIP_PERM = [
    0, 1, 5, 6, 7, 2, 3, 4, 8, 12, 13, 14, 9,
    10, 11, 16, 15, 18, 17, 22, 23, 24, 19, 20, 21,
] + [25 + i for i in J24_FLIP_PERM]
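`SMPL_POSE_FLIP_PERM` expands the per-joint permutation to the 72 axis-angle pose parameters (3 per joint). In SPIN, from which these constants appear to be taken, a pose is mirrored by permuting it with this list and negating the y and z components of each axis-angle vector. That usage is not shown in this excerpt, so the sketch below is an illustration under that assumption, with `flip_pose` written here as a local re-implementation. Mirroring twice must return the original pose, which also checks that the permutation is an involution.

```python
import numpy as np

# Copied from constants.py: left/right swap of the 24 SMPL joints.
SMPL_JOINTS_FLIP_PERM = [
    0, 2, 1, 3, 5, 4, 6, 8, 7, 9, 11, 10,
    12, 14, 13, 15, 17, 16, 19, 18, 21, 20, 23, 22,
]
# Same expansion as the loop in constants.py: joint i -> params 3i, 3i+1, 3i+2.
SMPL_POSE_FLIP_PERM = [3 * i + k for i in SMPL_JOINTS_FLIP_PERM for k in range(3)]


def flip_pose(pose):
    """Mirror a (72,) SMPL axis-angle pose across the x = 0 plane."""
    flipped = pose[SMPL_POSE_FLIP_PERM].copy()
    # Axis-angle vectors are pseudovectors: reflecting x keeps the x
    # component and negates the y and z components.
    flipped[1::3] *= -1.0
    flipped[2::3] *= -1.0
    return flipped


pose = np.zeros(72)
pose[3:6] = [0.1, 0.2, 0.3]  # rotate joint 1 ("Left Upper Leg")
mirrored = flip_pose(pose)
# The rotation moves to joint 2 ("Right Upper Leg"): values 0.1, -0.2, -0.3.
```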


================================================
FILE: src/spin/hmr.py
================================================
import math

import numpy as np
import torch
import torch.nn as nn
import torchvision.models.resnet as resnet


def rot6d_to_rotmat(x):
    """Convert 6D rotation representation to 3x3 rotation matrix.
    Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
    Input:
        (B,6) Batch of 6-D rotation representations
    Output:
        (B,3,3) Batch of corresponding rotation matrices
    """

    x = x.view(-1, 3, 2)
    a1 = x[:, :, 0]
    a2 = x[:, :, 1]
    b1 = nn.functional.normalize(a1)
    b2 = nn.functional.normalize(
        a2 - torch.einsum("bi,bi->b", b1, a2).unsqueeze(-1) * b1
    )

    b3 = torch.cross(b1, b2, dim=1)  # explicit dim: b1, b2 have shape (B, 3)

    return torch.stack((b1, b2, b3), dim=-1)


class Bottleneck(nn.Module):
    """Redefinition of Bottleneck residual block
    Adapted from the official PyTorch implementation
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class HMR(nn.Module):
    """SMPL Iterative Regressor with ResNet50 backbone"""

    def __init__(self, block, layers, smpl_mean_params):
        self.inplanes = 64
        super(HMR, self).__init__()
        self.n_shape = 10
        self.n_cam = 3
        self.n_joints = 24
        npose = self.n_joints * 6
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc1 = nn.Linear(512 * block.expansion + npose + self.n_shape + self.n_cam, 1024)
        self.drop1 = nn.Dropout()
        self.fc2 = nn.Linear(1024, 1024)
        self.drop2 = nn.Dropout()
        self.decpose = nn.Linear(1024, npose)
        self.decshape = nn.Linear(1024, self.n_shape)
        self.deccam = nn.Linear(1024, self.n_cam)
        nn.init.xavier_uniform_(self.decpose.weight, gain=0.01)
        nn.init.xavier_uniform_(self.decshape.weight, gain=0.01)
        nn.init.xavier_uniform_(self.deccam.weight, gain=0.01)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.0 / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        mean_params = np.load(smpl_mean_params)
        init_pose = torch.from_numpy(mean_params["pose"][:]).unsqueeze(0)
        init_shape = torch.from_numpy(
            mean_params["shape"][:].astype("float32")
        ).unsqueeze(0)
        init_cam = torch.from_numpy(mean_params["cam"]).unsqueeze(0)
        self.register_buffer("init_pose", init_pose)
        self.register_buffer("init_shape", init_shape)
        self.register_buffer("init_cam", init_cam)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x, init_pose=None, init_shape=None, init_cam=None, n_iter=3):

        batch_size = x.shape[0]

        if init_pose is None:
            init_pose = self.init_pose.expand(batch_size, -1)
        if init_shape is None:
            init_shape = self.init_shape.expand(batch_size, -1)
        if init_cam is None:
            init_cam = self.init_cam.expand(batch_size, -1)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)

        xf = self.avgpool(x4)
        xf = xf.view(xf.size(0), -1)

        pred_pose = init_pose
        pred_shape = init_shape
        pred_cam = init_cam
        for _ in range(n_iter):
            xc = torch.cat([xf, pred_pose, pred_shape, pred_cam], 1)
            xc = self.fc1(xc)
            xc = self.drop1(xc)
            xc = self.fc2(xc)
            xc = self.drop2(xc)
            pred_pose = self.decpose(xc) + pred_pose
            pred_shape = self.decshape(xc) + pred_shape
            pred_cam = self.deccam(xc) + pred_cam

        pred_rotmat = rot6d_to_rotmat(pred_pose).view(batch_size, self.n_joints, 3, 3)

        return pred_rotmat, pred_shape, pred_cam


def hmr(smpl_mean_params, pretrained=True, **kwargs):
    """Constructs an HMR model with ResNet50 backbone.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = HMR(Bottleneck, [3, 4, 6, 3], smpl_mean_params, **kwargs)
    if pretrained:
        resnet_imagenet = resnet.resnet50(pretrained=True)
        model.load_state_dict(resnet_imagenet.state_dict(), strict=False)
    return model
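`rot6d_to_rotmat` is a Gram-Schmidt step: the 6 numbers are read as the two columns of a 3x2 matrix (matching `x.view(-1, 3, 2)`), the first column is normalised, the second is made orthogonal to it and normalised, and the third is their cross product, which yields a proper rotation. HMR regresses 24 * 6 = 144 such numbers per image; together with 512 * 4 = 2048 pooled ResNet-50 features, 10 shape and 3 camera parameters, this gives the 2205 inputs of `fc1` in each iterative-regression step. Below is a NumPy port for checking the math outside PyTorch (the name `rot6d_to_rotmat_np` is hypothetical; unlike `torch.nn.functional.normalize`, it has no epsilon guard against zero-length vectors).

```python
import numpy as np


def rot6d_to_rotmat_np(x):
    """NumPy version of rot6d_to_rotmat: (B, 6) -> (B, 3, 3) rotation matrices."""
    x = np.asarray(x, dtype=np.float64).reshape(-1, 3, 2)
    a1, a2 = x[:, :, 0], x[:, :, 1]
    # First column: normalised a1.
    b1 = a1 / np.linalg.norm(a1, axis=1, keepdims=True)
    # Second column: a2 minus its projection onto b1, then normalised.
    a2_perp = a2 - np.sum(b1 * a2, axis=1, keepdims=True) * b1
    b2 = a2_perp / np.linalg.norm(a2_perp, axis=1, keepdims=True)
    # Third column completes a right-handed orthonormal basis.
    b3 = np.cross(b1, b2)
    return np.stack((b1, b2, b3), axis=-1)


# The 6D identity: columns (1, 0, 0) and (0, 1, 0), stored row-major as 3x2.
R_id = rot6d_to_rotmat_np([[1, 0, 0, 1, 0, 0]])

# Arbitrary inputs still map to rotations: orthonormal with determinant +1.
rng = np.random.default_rng(0)
R = rot6d_to_rotmat_np(rng.normal(size=(5, 6)))
```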


================================================
FILE: src/spin/smpl.py
================================================
import numpy as np
import torch
from smplx import SMPL as _SMPL
from smplx import SMPLX as _SMPLX
from smplx.body_models import SMPLOutput, SMPLXOutput
from smplx.lbs import vertices2joints

from .constants import JOINT_MAP, JOINT_NAMES

# Hand joints
SMPLX_HAND_TO_PANOPTIC = [
    0,
    13,
    14,
    15,
    16,
    1,
    2,
    3,
    17,
    4,
    5,
    6,
    18,
    10,
    11,
    12,
    19,
    7,
    8,
    9,
    20,
]  # Wrist Thumb to Pinky



class SMPL(_SMPL):
    """Extension of the official SMPL implementation to support more joints"""

    JOINTS = (
        "Hips",
        "Left Upper Leg",
        "Right Upper Leg",
        "Spine",
        "Left Leg",
        "Right Leg",
        "Spine1",
        "Left Foot",
        "Right Foot",
        "Thorax",
        "Left Toe",
        "Right Toe",
        "Neck",
        "Left Shoulder",
        "Right Shoulder",
        "Head",
        "Left ForeArm",
        "Right ForeArm",
        "Left Arm",
        "Right Arm",
        "Left Hand",
        "Right Hand",
        "Left Finger",
        "Right Finger",
    )

    SKELETON = (
        (0, 1),
        (0, 2),
        (0, 3),
        (1, 4),
        (2, 5),
        (3, 6),
        (4, 7),
        (5, 8),
        (6, 9),
        (7, 10),
        (8, 11),
        (9, 12),
        (12, 13),
        (12, 14),
        (12, 15),
        (13, 16),
        (14, 17),
        (16, 18),
        (17, 19),
        (18, 20),
        (19, 21),
        (20, 22),
        (21, 23),
    )

    def __init__(self, *args, **kwargs):
        super(SMPL, self).__init__(*args, **kwargs)
        joints = [JOINT_MAP[i] for i in JOINT_NAMES]
        joint_regressor_extra = kwargs["joint_regressor_extra_path"]
        J_regressor_extra = np.load(joint_regressor_extra)
        self.register_buffer(
            "J_regressor_extra", torch.tensor(J_regressor_extra, dtype=torch.float32)
        )
        self.joint_map = torch.tensor(joints, dtype=torch.long)

    def forward(self, *args, **kwargs):
        kwargs["get_skin"] = True
        smpl_output = super(SMPL, self).forward(*args, **kwargs)
        extra_joints = vertices2joints(
            self.J_regressor_extra, smpl_output.vertices
        )  # Additional 9 joints #Check doc/J_regressor_extra.png
        joints = torch.cat(
            [smpl_output.joints, extra_joints], dim=1
        )  # [N, 24 + 21, 3]  + [N, 9, 3]
        joints = joints[:, self.joint_map, :]
        output = SMPLOutput(
            vertices=smpl_output.vertices,
            global_orient=smpl_output.global_orient,
            body_pose=smpl_output.body_pose,
            joints=joints,
            betas=smpl_output.betas,
            full_pose=smpl_output.full_pose,
        )
        return output


class SMPLX(_SMPLX):
    """Extension of the official SMPL implementation to support more joints"""

    JOINTS = (
        "Hips",
        "Left Upper Leg",
        "Right Upper Leg",
        "Spine",
        "Left Leg",
        "Right Leg",
        "Spine1",
        "Left Foot",
        "Right Foot",
        "Thorax",
        "Left Toe",
        "Right Toe",
        "Neck",
        "Left Shoulder",
        "Right Shoulder",
        "Head",
        "Left ForeArm",
        "Right ForeArm",
        "Left Arm",
        "Right Arm",
        "Left Hand",
        "Right Hand",
    )

    SKELETON = (
        (0, 1),
        (0, 2),
        (0, 3),
        (1, 4),
        (2, 5),
        (3, 6),
        (4, 7),
        (5, 8),
        (6, 9),
        (7, 10),
        (8, 11),
        (9, 12),
        (12, 13),
        (12, 14),
        (12, 15),
        (13, 16),
        (14, 17),
        (16, 18),
        (17, 19),
        (18, 20),
        (19, 21),
    )

    def __init__(self, *args, **kwargs):
        kwargs["ext"] = "pkl"  # We have pkl file
        super(SMPLX, self).__init__(*args, **kwargs)
        joints = [JOINT_MAP[i] for i in JOINT_NAMES]
        self.joint_map = torch.tensor(joints, dtype=torch.long)

    def forward(self, *args, **kwargs):
        kwargs["get_skin"] = True

        # If the pose comes from SMPL (23 body joints, ignoring the root), drop the
        # last two joints (on the palms, not used) so that it matches the 21 body
        # joints of SMPL-X.
        try:
            if kwargs["body_pose"].shape[1] == 69:
                # Axis-angle input: 23 * 3 -> 21 * 3 values
                kwargs["body_pose"] = kwargs["body_pose"][:, : -2 * 3]

            if kwargs["body_pose"].shape[1] == 23:
                # Per-joint input, e.g. rotation matrices: 23 -> 21 joints
                kwargs["body_pose"] = kwargs["body_pose"][:, :-2]
        except (KeyError, AttributeError, IndexError):
            # body_pose is missing, None, or not batched: leave it unchanged
            pass

        smpl_output = super(SMPLX, self).forward(*args, **kwargs)

        # SMPL-X Joint order: https://docs.google.com/spreadsheets/d/1_1dLdaX-sbMkCKr_JzJW_RZCpwBwd7rcKkWT_VgAQ_0/edit#gid=0
        smplx_to_smpl = (
            list(range(0, 22)) + [28, 43] + list(range(55, 76))
        )  # 28: left middle finger 1, 43: right middle finger 1
        smpl_joints = smpl_output.joints[
            :, smplx_to_smpl, :
        ]  # Convert SMPL-X to SMPL: 127 -> 45 joints
        joints = smpl_joints
        joints = joints[:, self.joint_map, :]

        smplx_lhand = (
            [20] + list(range(25, 40)) + list(range(66, 71))
        )  # 20 for left wrist. 20 finger joints
        lhand_joints = smpl_output.joints[:, smplx_lhand, :]  # (N,21,3)
        lhand_joints = lhand_joints[
            :, SMPLX_HAND_TO_PANOPTIC, :
        ]  # Convert SMPL-X hand order to Panoptic hand order

        smplx_rhand = (
            [21] + list(range(40, 55)) + list(range(71, 76))
        )  # 21 for right wrist. 20 finger joints
        rhand_joints = smpl_output.joints[:, smplx_rhand, :]  # (N,21,3)
        rhand_joints = rhand_joints[
            :, SMPLX_HAND_TO_PANOPTIC, :
        ]  # Convert SMPL-X hand order to Panoptic hand order

        output = SMPLXOutput(
            vertices=smpl_output.vertices,
            global_orient=smpl_output.global_orient,
            body_pose=smpl_output.body_pose,
            joints=joints,
            right_hand_pose=rhand_joints,  # N,21,3
            left_hand_pose=lhand_joints,  # N,21,3
            betas=smpl_output.betas,
            full_pose=smpl_output.full_pose,
            A=smpl_output.A,
        )
        return output


# SMPL-X Joints:
"""
0	pelvis',
1	left_hip',
2	right_hip',
3	spine1',
4	left_knee',
5	right_knee',
6	spine2',
7	left_ankle',
8	right_ankle',
9	spine3',
10	left_foot',
11	right_foot',
12	neck',
13	left_collar',
14	right_collar',
15	head',
16	left_shoulder',
17	right_shoulder',
18	left_elbow',
19	right_elbow',
20	left_wrist',
21	right_wrist',
22	jaw',
23	left_eye_smplhf',
24	right_eye_smplhf',
25	left_index1',
26	left_index2',
27	left_index3',
28	left_middle1',
29	left_middle2',
30	left_middle3',
31	left_pinky1',
32	left_pinky2',
33	left_pinky3',
34	left_ring1',
35	left_ring2',
36	left_ring3',
37	left_thumb1',
38	left_thumb2',
39	left_thumb3',
40	right_index1',
41	right_index2',
42	right_index3',
43	right_middle1',
44	right_middle2',
45	right_middle3',
46	right_pinky1',
47	right_pinky2',
48	right_pinky3',
49	right_ring1',
50	right_ring2',
51	right_ring3',
52	right_thumb1',
53	right_thumb2',
54	right_thumb3',
55	nose',
56	right_eye',
57	left_eye',
58	right_ear',
59	left_ear',
60	left_big_toe',
61	left_small_toe',
62	left_heel',
63	right_big_toe',
64	right_small_toe',
65	right_heel',
66	left_thumb',
67	left_index',
68	left_middle',
69	left_ring',
70	left_pinky',
71	right_thumb',
72	right_index',
73	right_middle',
74	right_ring',
75	right_pinky',
76	right_eye_brow1',
77	right_eye_brow2',
78	right_eye_brow3',
79	right_eye_brow4',
80	right_eye_brow5',
81	left_eye_brow5',
82	left_eye_brow4',
83	left_eye_brow3',
84	left_eye_brow2',
85	left_eye_brow1',
86	nose1',
87	nose2',
88	nose3',
89	nose4',
90	right_nose_2',
91	right_nose_1',
92	nose_middle',
93	left_nose_1',
94	left_nose_2',
95	right_eye1',
96	right_eye2',
97	right_eye3',
98	right_eye4',
99	right_eye5',
100	right_eye6',
101	left_eye4',
102	left_eye3',
103	left_eye2',
104	left_eye1',
105	left_eye6',
106	left_eye5',
107	right_mouth_1',
108	right_mouth_2',
109	right_mouth_3',
110	mouth_top',
111	left_mouth_3',
112	left_mouth_2',
113	left_mouth_1',
114	left_mouth_5', # 59 in OpenPose output
115	left_mouth_4', # 58 in OpenPose output
116	mouth_bottom',
117	right_mouth_4',
118	right_mouth_5',
119	right_lip_1',
120	right_lip_2',
121	lip_top',
122	left_lip_2',
123	left_lip_1',
124	left_lip_3',
125	lip_bottom',
126	right_lip_3',
127	right_contour_1',
128	right_contour_2',
129	right_contour_3',
130	right_contour_4',
131	right_contour_5',
132	right_contour_6',
133	right_contour_7',
134	right_contour_8',
135	contour_middle',
136	left_contour_8',
137	left_contour_7',
138	left_contour_6',
139	left_contour_5',
140	left_contour_4',
141	left_contour_3',
142	left_contour_2',
143	left_contour_1'
"""


# SMPL Joints:
"""
0	pelvis',
1	left_hip',
2	right_hip',
3	spine1',
4	left_knee',
5	right_knee',
6	spine2',
7	left_ankle',
8	right_ankle',
9	spine3',
10	left_foot',
11	right_foot',
12	neck',
13	left_collar',
14	right_collar',
15	head',
16	left_shoulder',
17	right_shoulder',
18	left_elbow',
19	right_elbow',
20	left_wrist',
21	right_wrist',
22	left_hand',
23	right_hand',
24	nose',
25	right_eye',
26	left_eye',
27	right_ear',
28	left_ear',
29	left_big_toe',
30	left_small_toe',
31	left_heel',
32	right_big_toe',
33	right_small_toe',
34	right_heel',
35	left_thumb',
36	left_index',
37	left_middle',
38	left_ring',
39	left_pinky',
40	right_thumb',
41	right_index',
42	right_middle',
43	right_ring',
44	right_pinky',
"""
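
The joint bookkeeping in `SMPLX.forward` (trimming an SMPL body pose to 21 joints and picking 45 SMPL-style joints out of the SMPL-X output) can be checked without loading a body model. The NumPy sketch below reproduces only that index logic; `SMPLX_TO_SMPL`, `smplx_joints_to_smpl45` and `trim_smpl_body_pose` are hypothetical names, and the (N, J, 3) array stands in for `smpl_output.joints`.

```python
import numpy as np

# Same index list as `smplx_to_smpl` in SMPLX.forward: the 22 body joints shared
# by both models, two middle-finger joints standing in for SMPL's hand joints
# (28: left_middle1, 43: right_middle1), and the 21 vertex-based extra joints
# (55-75: nose ... right_pinky in the SMPL-X list above).
SMPLX_TO_SMPL = list(range(0, 22)) + [28, 43] + list(range(55, 76))


def smplx_joints_to_smpl45(joints):
    """Select the 45 SMPL-style joints from an (N, J, 3) SMPL-X joint array."""
    joints = np.asarray(joints)
    needed = max(SMPLX_TO_SMPL) + 1
    if joints.ndim != 3 or joints.shape[1] < needed:
        raise ValueError(f"expected shape (N, >={needed}, 3), got {joints.shape}")
    return joints[:, SMPLX_TO_SMPL, :]


def trim_smpl_body_pose(body_pose):
    """Drop SMPL's last two body joints from a flat axis-angle pose.

    SMPL has 23 non-root joints (69 values); SMPL-X has 21 (63 values).
    Poses that already have 63 values are returned unchanged.
    """
    body_pose = np.asarray(body_pose)
    if body_pose.shape[1] == 69:
        return body_pose[:, : -2 * 3]
    return body_pose
```

Slots 22 and 23 of the result hold the middle-finger joints, which is why the SMPL joint list above names them after the hands.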


================================================
FILE: src/spin/utils.py
================================================
import json

import cv2
import numpy as np
import torch
from skimage.transform import resize, rotate
from torchvision.transforms import Normalize

from .constants import IMG_NORM_MEAN, IMG_NORM_STD, IMG_RES


def get_transform(center, scale, res, rot=0):
    """Generate transformation matrix."""
    h = 200 * scale
    t = np.zeros((3, 3))
    t[0, 0] = float(res[1]) / h
    t[1, 1] = float(res[0]) / h
    t[0, 2] = res[1] * (-float(center[0]) / h + 0.5)
    t[1, 2] = res[0] * (-float(center[1]) / h + 0.5)
    t[2, 2] = 1
    if rot != 0:
        rot = -rot  # To match direction of rotation from cropping
        rot_mat = np.zeros((3, 3))
        rot_rad = rot * np.pi / 180
        sn, cs = np.sin(rot_rad), np.cos(rot_rad)
        rot_mat[0, :2] = [cs, -sn]
        rot_mat[1, :2] = [sn, cs]
        rot_mat[2, 2] = 1
        # Need to rotate around center
        t_mat = np.eye(3)
        t_mat[0, 2] = -res[1] / 2
        t_mat[1, 2] = -res[0] / 2
        t_inv = t_mat.copy()
        t_inv[:2, 2] *= -1
        t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))

    return t


def transform(pt, center, scale, res, invert=0, rot=0):
    """Transform pixel location to different reference."""
    t = get_transform(center, scale, res, rot=rot)
    if invert:
        t = np.linalg.inv(t)
    new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.0]).T
    new_pt = np.dot(t, new_pt)

    return new_pt[:2].astype(int) + 1


def crop(img, center, scale, res, rot=0):
    """Crop image according to the supplied bounding box."""
    # Upper left point
    ul = np.array(transform([1, 1], center, scale, res, invert=1)) - 1
    # Bottom right point
    br = np.array(transform([res[0] + 1, res[1] + 1], center, scale, res, invert=1)) - 1

    # Padding so that when rotated proper amount of context is included
    pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
    if rot != 0:
        ul -= pad
        br += pad

    new_shape = [br[1] - ul[1], br[0] - ul[0]]
    if len(img.shape) > 2:
        new_shape += [img.shape[2]]
    new_img = np.zeros(new_shape)

    # Range to fill new array
    new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
    new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
    # Range to sample from original image
    old_x = max(0, ul[0]), min(len(img[0]), br[0])
    old_y = max(0, ul[1]), min(len(img), br[1])
    new_img[new_y[0] : new_y[1], new_x[0] : new_x[1]] = img[
        old_y[0] : old_y[1], old_x[0] : old_x[1]
    ]

    if rot != 0:
        # Rotate, then remove the padding added above
        new_img = rotate(new_img, rot)
        if pad > 0:  # with pad == 0 the slice [0:-0] would drop everything
            new_img = new_img[pad:-pad, pad:-pad]

    new_img = resize(new_img, res)

    return new_img


def bbox_from_openpose(openpose_file, rescale=1.2, detection_thresh=0.2):
    """Get center and scale for bounding box from openpose detections."""
    with open(openpose_file, "r") as f:
        keypoints = json.load(f)["people"][0]["pose_keypoints_2d"]
    keypoints = np.reshape(np.array(keypoints), (-1, 3))
    valid = keypoints[:, -1] > detection_thresh
    valid_keypoints = keypoints[valid][:, :-1]
    center = valid_keypoints.mean(axis=0)
    bbox_size = (valid_keypoints.max(axis=0) - valid_keypoints.min(axis=0)).max()
    # adjust bounding box tightness
    scale = bbox_size / 200.0
    scale *= rescale

    return center, scale


def bbox_from_json(bbox_file):
    """Get center and scale of bounding box from bounding box annotations.
    The expected format is [top_left(x), top_left(y), width, height].
    """
    with open(bbox_file, "r") as f:
        bbox = np.array(json.load(f)["bbox"]).astype(np.float32)
    ul_corner = bbox[:2]
    center = ul_corner + 0.5 * bbox[2:]
    # make sure the bounding box is square
    width = max(bbox[2], bbox[3])
    scale = width / 200.0
    return center, scale


def process_image(img_file, bbox_file=None, openpose_file=None, input_res=IMG_RES):
    """Read image, do preprocessing and possibly crop it according to the bounding box.
    If there are bounding box annotations, use them to crop the image.
    If no bounding box is specified but openpose detections are available, use them to get the bounding box.
    """
    img_file = str(img_file)
    normalize_img = Normalize(mean=IMG_NORM_MEAN, std=IMG_NORM_STD)
    img = cv2.imread(img_file)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {img_file}")
    # BGR -> RGB; copy because PyTorch does not support negative strides
    img = img[:, :, ::-1].copy()
    if bbox_file is None and openpose_file is None:
        # Assume that the person is centered in the image
        height = img.shape[0]
        width = img.shape[1]
        center = np.array([width // 2, height // 2])
        scale = max(height, width) / 200
    else:
        if bbox_file is not None:
            center, scale = bbox_from_json(bbox_file)
        elif openpose_file is not None:
            center, scale = bbox_from_openpose(openpose_file)

    img = crop(img, center, scale, (input_res, input_res))
    img = img.astype(np.float32) / 255.0
    img = torch.from_numpy(img).permute(2, 0, 1)
    norm_img = normalize_img(img.clone())

    return img, norm_img
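
The SPIN-style cropping in `get_transform`, `transform` and `bbox_from_json` describes a box by its center and a scale in units of 200 px. A minimal sketch of that convention, assuming no rotation and working in floating-point pixel coordinates (it skips the 1-based offsets and integer rounding that `transform` applies); `bbox_to_center_scale`, `crop_affine` and `apply_affine` are hypothetical helpers, not part of the repository:

```python
import numpy as np


def bbox_to_center_scale(bbox):
    """[x_top_left, y_top_left, width, height] -> (center, scale).

    Mirrors bbox_from_json: the box is made square by taking the longer side,
    and the scale is that side expressed in units of 200 px.
    """
    bbox = np.asarray(bbox, dtype=np.float64)
    center = bbox[:2] + 0.5 * bbox[2:]
    scale = max(bbox[2], bbox[3]) / 200.0
    return center, scale


def crop_affine(center, scale, res):
    """3x3 matrix mapping original-image pixels to crop pixels (no rotation).

    res is (rows, cols) of the crop, indexed the same way as in get_transform.
    """
    h = 200.0 * scale  # side of the square box in original pixels
    t = np.eye(3)
    t[0, 0] = res[1] / h
    t[1, 1] = res[0] / h
    t[0, 2] = res[1] * (-center[0] / h + 0.5)
    t[1, 2] = res[0] * (-center[1] / h + 0.5)
    return t


def apply_affine(t, pt):
    """Apply a 3x3 homogeneous transform to a 2D point."""
    x, y, w = t @ np.array([pt[0], pt[1], 1.0])
    return np.array([x, y]) / w
```

The box center lands in the middle of the crop, and the inverse matrix maps crop pixels back to the original image, which is what `crop` relies on when it calls `transform(..., invert=1)` for the crop corners.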


================================================
FILE: src/utils.py
================================================
import colorsys
import itertools
import json
import pickle

import cv2
import plotly.graph_objects as go
import trimesh
import torch
import numpy as np
import PIL.Image as pil_img
import PIL.ImageDraw as ImageDraw
from PIL import Image, ImageChops
from skimage import exposure

import spin
import renderer
import pose_estimation


def load_json(path):
    with open(path) as f:
        return json.load(f)


def save_json(o, path):
    with open(path, "w") as f:
        json.dump(o, f)


def load_pkl(path):
    with open(path, "rb") as f:
        return pickle.load(f)


def save_pkl(o, path):
    with open(path, "wb") as f:
        pickle.dump(o, f)


def plot_3D(joints, vertices, faces):
    x, y, z = joints.T
    x1, y1, z1 = vertices.T
    i, j, k = faces.T

    data = [
        go.Mesh3d(
            x=x1,
            y=y1,
            z=z1,
            i=i,
            j=j,
            k=k,
        ),
        go.Scatter3d(
            x=x,
            y=y,
            z=z,
            mode="markers",
            marker_size=5,
        ),
    ]

    fig = go.Figure(
        data=data,
    )

    return fig


def draw_keypoints(
    input_img_kp,
    keypoints,
    skeleton,
    r,
    color,
    contact2dlist=None,
    contact2dlist_color="green",
    cos=None,
):
    if keypoints is not None:
        draw = ImageDraw.Draw(input_img_kp)

        for skidx, (i, j) in enumerate(skeleton):
            a = keypoints[i]
            b = keypoints[j]

            xy = [a[0], a[1], b[0], b[1]]
            if cos is not None:
                c = colorsys.hsv_to_rgb(cos[skidx] ** 8, 1, 1)
                c = tuple(int(c_ * 255) for c_ in c)
                draw.line(xy, fill=c, width=r)
            else:
                draw.line(xy, fill=color, width=r)

        draw_kpts = [(p[0] - r, p[1] - r, p[0] + r, p[1] + r) for p in keypoints]
        for ellipse in draw_kpts:
            draw.ellipse(ellipse, fill="black", outline="black")

        if contact2dlist is not None:
            keypoints_torch = torch.from_numpy(keypoints)
            for c2d in contact2dlist:
                for (src_1, dst_1, t_1), (src_2, dst_2, t_2) in itertools.combinations(
                    c2d, 2
                ):
                    a = torch.lerp(
                        keypoints_torch[src_1], keypoints_torch[dst_1], t_1
                    ).tolist()
                    b = torch.lerp(
                        keypoints_torch[src_2], keypoints_torch[dst_2], t_2
                    ).tolist()

                    xy = [a[0], a[1], b[0], b[1]]
                    draw.line(xy, fill=contact2dlist_color, width=max(r // 3, 10))

    return input_img_kp


def save_results_image(
    camera,
    focal_length_x,
    focal_length_y,
    input_img,
    vertices,
    faces,
    filename,
    keypoints=None,
    keypoints_2=None,
    heatmap=None,
    cvt_camera=True,
    contactlist=None,
    contact2dlist=None,
    user_study=True,
    cos=None,
):
    if isinstance(contactlist, list) and len(contactlist) > 0:
        contactlist = np.concatenate(contactlist)

    H, W, _ = input_img.shape
    HW = max(H, W)
    camera_center = np.array([W // 2, H // 2])
    if not cvt_camera:
        camera_transl = camera.copy()
    else:
        camera_transl = np.stack(
            [
                camera[1],
                camera[2],
                1 / camera[0],
            ],
        )

    # draw keypoints
    input_img_kp = pil_img.fromarray(input_img)
    if keypoints is not None:
        draw_keypoints(
            input_img_kp,
            keypoints,
            pose_estimation.SKELETON,
            r=int(HW * 0.01),
            color=(255, 0, 0, 255),
            # contact2dlist=contact2dlist,
            # contact2dlist_color="orange",
        )

    if cos is not None:
        input_img_kp_cos = pil_img.fromarray(input_img)
        if keypoints is not None:
            draw_keypoints(
                input_img_kp_cos,
                keypoints,
                pose_estimation.SKELETON,
                r=int(HW * 0.01),
                color=(255, 0, 0, 255),
                # contact2dlist=contact2dlist,
                # contact2dlist_color="orange",
                cos=cos,
            )

    input_img_kp_2 = input_img_kp.copy()
    if keypoints_2 is not None:
        draw_keypoints(
            input_img_kp_2,
            keypoints_2,
            # spin.SMPLX.SKELETON if "eft" in str(filename) else pose_estimation.SKELETON,
            pose_estimation.SKELETON,
            r=int(HW * 0.01),
            color=(0, 0, 255, 255),
            contact2dlist=contact2dlist,
            contact2dlist_color="purple",
        )

    if heatmap is not None:
        # Colorize the heatmap with the JET colormap and save it, together with
        # the 2D-keypoint overlay, next to the main result image
        gray_img = exposure.rescale_intensity(heatmap, out_range=(0, 255))
        gray_img = gray_img.astype(np.uint8)
        heatmap_img = cv2.applyColorMap(gray_img, cv2.COLORMAP_JET)
        hm = pil_img.fromarray(cv2.cvtColor(heatmap_img, cv2.COLOR_BGR2RGB))
        hm.save(filename.with_stem(f"{filename.stem}_heatmap"))
        input_img_kp.save(filename.with_stem(f"{filename.stem}_2dkps"))
        heatmap = pil_img.fromarray(heatmap)

        if cos is not None:
            input_img_kp_cos.save(filename.with_stem(f"{filename.stem}_2dkpscos"))

    # render fitted mesh from different views
    overlay_fit_img = renderer.overlay_mesh(
        vertices,
        faces,
        camera_transl,
        focal_length_x,
        focal_length_y,
        camera_center,
        H,
        W,
        input_img.astype("float32") / 255,
        None,
        rotaround=None,
    )

    # overlay_fit_img = pil_img.fromarray(overlay_fit_img)
    # draw_keypoints(overlay_fit_img, keypoints_2, r=int(HW * 0.01), color=(0, 0, 255, 255))

    # camera_transl[-1] *= 1
    view1_fit = renderer.overlay_mesh(
        vertices,
        faces,
        camera_transl.astype(np.float32),
        focal_length_x,
        focal_length_y,
        camera_center,
        H,
        W,
        None,
        None,
        rotaround=-45,
        contactlist=contactlist,
    )
    view2_fit = renderer.overlay_mesh(
        vertices,
        faces,
        camera_transl.astype(np.float32),
        focal_length_x,
        focal_length_y,
        camera_center,
        H,
        W,
        None,
        None,
        rotaround=None,
        contactlist=contactlist,
    )
    view3_fit = renderer.overlay_mesh(
        vertices,
        faces,
        camera_transl.astype(np.float32),
        focal_length_x,
        focal_length_y,
        camera_center,
        H,
        W,
        None,
        None,
        rotaround=90,
        contactlist=contactlist,
        scale=1,
    )

    IMG = np.vstack(
        (
            np.hstack(
                (
                    np.asarray(input_img_kp)
                    if keypoints is not None
                    else 255 * np.ones_like(np.asarray(input_img_kp)),
                    np.asarray(input_img_kp_2),
                    overlay_fit_img,
                    # np.asanyarray(overlay_fit_img),
                ),
            ),
            np.hstack(
                (
                    view1_fit,
                    view2_fit,
                    view3_fit,
                ),
            ),
        ),
    )
    IMG = pil_img.fromarray(IMG)
    IMG.thumbnail((2000, 2000))

    IMG.save(filename)

    if user_study:
        w = 768
        input_img_kp.thumbnail((w, w))
        input_img_kp.save(filename.with_stem(f"{filename.stem}_orig"))
        W, H = input_img_kp.size

        camera_transl[-1] *= 2
        view2_fit = renderer.overlay_mesh(
            vertices,
            faces,
            camera_transl.astype(np.float32),
            focal_length_x,
            focal_length_y,
            camera_center,
            H,
            W,
            None,
            None,
            rotaround=None,
            contactlist=contactlist,
            scale=2,
        )
        view2_fit = pil_img.fromarray(view2_fit)
        w *= 2
        view2_fit.thumbnail((w, w))
        view2_fit.save(filename.with_stem(f"{filename.stem}_same"))
        view3_fit = renderer.overlay_mesh(
            vertices,
            faces,
            camera_transl.astype(np.float32),
            focal_length_x,
            focal_length_y,
            camera_center,
            H,
            W,
            None,
            None,
            rotaround=90,
            contactlist=contactlist,
            scale=2,
        )
        view3_fit = pil_img.fromarray(view3_fit)
        view3_fit.thumbnail((w, w))
        view3_fit.save(filename.with_stem(f"{filename.stem}_alt"))

    return IMG


def save_3d_model_on_img(
    camera,
    vertices,
    faces,
    img,
    filename,
    save_path,
):
    img_res = max(img.shape[:2])
    r = renderer.Renderer(
        focal_length=spin.constants.FOCAL_LENGTH,
        img_res=img_res,
        faces=faces,
    )

    # Calculate camera parameters for rendering
    camera_translation = np.stack(
        [
            camera[1],
            camera[2],
            2 * spin.constants.FOCAL_LENGTH / (img_res * camera[0] + 1e-9),
        ],
    )
    # Render parametric shape
    img_shape = r(vertices, camera_translation, img)
    img_shape = (255 * img_shape).astype("uint8")
    img_shape = cv2.cvtColor(img_shape, cv2.COLOR_RGB2BGR)
    cv2.imwrite(str(save_path / f"shape_{filename}.png"), img_shape)

    # Render side views
    aroundy = cv2.Rodrigues(np.array([0, np.radians(90.0), 0]))[0]
    center = vertices.mean(axis=0)
    rot_vertices = np.dot((vertices - center), aroundy) + center

    # Render the side view (mesh rotated 90 degrees about the y-axis) on a blank background
    img_shape = r(rot_vertices, camera_translation, np.ones_like(img))
    img_shape = (255 * img_shape).astype("uint8")
    img_shape = cv2.cvtColor(img_shape, cv2.COLOR_RGB2BGR)
    cv2.imwrite(str(save_path / f"shape_rot_{filename}.png"), img_shape)


def save_mesh_with_colors(vertices, faces, save_path, mask=None, inds=None):
    if inds is not None and isinstance(inds, list):
        inds = np.concatenate(inds)
    mesh = trimesh.Trimesh(
        vertices=vertices,
        faces=faces,
        process=False,
    )
    color = np.array(mesh.visual.vertex_colors)
    color[:] = [233, 233, 233, 255]
    if mask is not None and any(mask):
        color[~mask] = [255, 0, 0, 255]
    elif inds is not None and len(inds) > 0:
        color[inds] = [255, 0, 0, 255]
    mesh.visual.vertex_colors = color
    mesh.export(save_path)


def save_pose_params(rotmat, camera, betas, vertices, smpl, contact, save_path):
    if contact is not None and isinstance(contact, list) and len(contact) > 0:
        contact = np.concatenate(contact)

    rotmat = rotmat.detach()
    camera = camera.detach()
    if smpl.name() == "SMPL-X":
        rotmat = rotmat[: -2 * 3]

    res = {
        "camera_s_t": camera.unsqueeze(0).cpu().numpy(),
        "global_orient": rotmat[:3].unsqueeze(0).cpu().numpy(),
        "betas": betas,
        "body_pose": rotmat[3:].unsqueeze(0).cpu().numpy(),
        "left_hand_pose": smpl.left_hand_pose.unsqueeze(0).detach().cpu().numpy(),
        "right_hand_pose": smpl.right_hand_pose.unsqueeze(0).detach().cpu().numpy(),
        "model": smpl.name().replace("-", ""),
        "gender": smpl.gender,
        "vertices": vertices[0].cpu().numpy(),
    }

    if contact is not None and len(contact) > 0:
        contact = np.array(contact)
        res["contact"] = contact
    else:
        res["v"] = vertices[0].cpu().numpy()

    save_pkl(res, save_path)

    np.savez(save_path, **res)
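
`save_3d_model_on_img` turns a SPIN-style weak-perspective camera `[s, tx, ty]` into a perspective translation with depth `2 * FOCAL_LENGTH / (img_res * s)`. The depth follows from matching scales: under weak perspective a unit length spans `s * img_res / 2` pixels (normalized coordinates in [-1, 1] are stretched over `img_res` pixels), while under a pinhole camera at depth `tz` it spans `focal_length / tz` pixels. A sketch of this conversion, assuming a focal length of 5000 (the default of `Renderer.__init__`); `weak_perspective_to_translation` is a hypothetical helper, and it does not cover the `1 / camera[0]` variant used in `save_results_image`:

```python
import numpy as np


def weak_perspective_to_translation(camera, focal_length=5000.0, img_res=224, eps=1e-9):
    """Map a weak-perspective camera [s, tx, ty] to a translation [tx, ty, tz].

    tz is chosen so that s * img_res / 2 == focal_length / tz, i.e. a unit
    length has the same on-screen size under both camera models.
    eps avoids division by zero for a degenerate scale, as in the code above.
    """
    s, tx, ty = camera
    tz = 2.0 * focal_length / (img_res * s + eps)
    return np.array([tx, ty, tz])
```

The in-plane offsets pass through unchanged; only the scale is converted into a depth, so a larger predicted `s` places the mesh closer to the camera.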
Condensed preview — 20 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (156K chars).
[
  {
    "path": "README.md",
    "chars": 3592,
    "preview": "# Sketch2Pose: Estimating a 3D Character Pose from a Bitmap Sketch\n\nArtists frequently capture character poses via raste"
  },
  {
    "path": "patches/selfcontact.diff",
    "chars": 5486,
    "preview": "+++ venv/lib/python3.10/site-packages/selfcontact/body_segmentation.py\n@@ -14,6 +14,8 @@\n #\n # Contact: ps-license@tuebi"
  },
  {
    "path": "patches/smplx.diff",
    "chars": 1902,
    "preview": "+++ venv/lib/python3.10/site-packages/smplx/body_models.py\n@@ -366,7 +366,7 @@\n             num_repeats = int(batch_size"
  },
  {
    "path": "patches/torchgeometry.diff",
    "chars": 419,
    "preview": "+++ venv/lib/python3.10/site-packages/torchgeometry/core/conversions.py\n@@ -298,6 +298,9 @@\n                       rmat_"
  },
  {
    "path": "requirements.txt",
    "chars": 416,
    "preview": "matplotlib>=3.5.1\nnumpy>=1.22.3\nopencv_python>=4.5.5.64\nPillow>=9.1.0\nplotly>=5.7.0\npyrender>=0.1.45\nscikit_image>=0.19."
  },
  {
    "path": "scripts/download.sh",
    "chars": 1472,
    "preview": "#!/usr/bin/env sh\n\n\nset -euo pipefail\n\n\nasset_dir=\"./assets\"\n\n[ ! -e \"${asset_dir}\"/models_smplx_v1_1.zip ] \\\n    && ech"
  },
  {
    "path": "scripts/prepare.sh",
    "chars": 524,
    "preview": "#!/usr/bin/env sh\n\n\nset -euo pipefail\n\nvenv_dir=venv\npython -m venv --clear \"${venv_dir}\"\n\n. \"${venv_dir}\"/bin/activate\n"
  },
  {
    "path": "scripts/run.sh",
    "chars": 1452,
    "preview": "#!/usr/bin/env sh\n\n\nset -euo pipefail\n\n\nimg_dir=\"./data/images\"\nout_dir=\"./output\"\n\nfind \"${img_dir}\" -mindepth 1 -maxde"
  },
  {
    "path": "src/fist_pose.py",
    "chars": 12752,
    "preview": "left_fist = [\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0."
  },
  {
    "path": "src/hist_cub.py",
    "chars": 6680,
    "preview": "import itertools\nimport functools\nimport math\nimport multiprocessing\nfrom pathlib import Path\n\nimport matplotlib\nmatplot"
  },
  {
    "path": "src/losses.py",
    "chars": 6870,
    "preview": "import itertools\n\nimport torch\nimport torch.nn as nn\n\nimport pose_estimation\n\n\nclass MSE(nn.Module):\n    def __init__(se"
  },
  {
    "path": "src/pose.py",
    "chars": 55260,
    "preview": "import argparse\nimport math\nfrom pathlib import Path\n\nimport cv2\nimport numpy as np\nimport PIL.Image as Image\nimport sel"
  },
  {
    "path": "src/pose_estimation.py",
    "chars": 7074,
    "preview": "import math\n\nimport cv2\nimport numpy as np\n\nIMG_SIZE = (288, 384)\nMEAN = np.array([0.485, 0.456, 0.406])\nSTD = np.array("
  },
  {
    "path": "src/renderer.py",
    "chars": 6007,
    "preview": "import numpy as np\nimport pyrender\nimport torch\nimport trimesh\nfrom torchvision.utils import make_grid\n\n\nclass Renderer:"
  },
  {
    "path": "src/spin/__init__.py",
    "chars": 192,
    "preview": "from .constants import JOINT_NAMES\nfrom .hmr import hmr\nfrom .smpl import SMPLX\nfrom .utils import process_image\n\n__all_"
  },
  {
    "path": "src/spin/constants.py",
    "chars": 5279,
    "preview": "FOCAL_LENGTH = 5000.0\nIMG_RES = 224\n\n# Mean and standard deviation for normalizing input image\nIMG_NORM_MEAN = [0.485, 0"
  },
  {
    "path": "src/spin/hmr.py",
    "chars": 6685,
    "preview": "import math\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torchvision.models.resnet as resnet\n\n\ndef rot6"
  },
  {
    "path": "src/spin/smpl.py",
    "chars": 9570,
    "preview": "import numpy as np\nimport torch\nfrom smplx import SMPL as _SMPL\nfrom smplx import SMPLX as _SMPLX\nfrom smplx.body_models"
  },
  {
    "path": "src/spin/utils.py",
    "chars": 5077,
    "preview": "import json\n\nimport cv2\nimport numpy as np\nimport torch\nfrom skimage.transform import resize, rotate\nfrom torchvision.tr"
  },
  {
    "path": "src/utils.py",
    "chars": 11760,
    "preview": "import colorsys\nimport itertools\nimport json\nimport pickle\n\nimport cv2\nimport plotly.graph_objects as go\nimport trimesh\n"
  }
]

About this extraction

This page contains the full source code of the kbrodt/sketch2pose GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 20 files (145.0 KB), approximately 45.8k tokens, and a symbol index with 96 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
