[
  {
    "path": "README.md",
    "content": "# Sketch2Pose: Estimating a 3D Character Pose from a Bitmap Sketch\n\nArtists frequently capture character poses via raster sketches, then use these\ndrawings as a reference while posing a 3D character in a specialized 3D\nsoftware --- a time-consuming process, requiring specialized 3D training and\nmental effort. We tackle this challenge by proposing the first system for\nautomatically inferring a 3D character pose from a single bitmap sketch,\nproducing poses consistent with viewer expectations. Algorithmically\ninterpreting bitmap sketches is challenging, as they contain significantly\ndistorted proportions and foreshortening. We address this by predicting three\nkey elements of a drawing, necessary to disambiguate the drawn poses: 2D bone\ntangents, self-contacts, and bone foreshortening. These elements are then\nleveraged in an optimization inferring the 3D character pose consistent with\nthe artist's intent. Our optimization balances cues derived from artistic\nliterature and perception research to compensate for distorted character\nproportions. We demonstrate a gallery of results on sketches of numerous\nstyles. 
We validate our method via numerical evaluations, user studies, and\ncomparisons to manually posed characters and previous work.\n\n[Project Page](http://www-labs.iro.umontreal.ca/~bmpix/sketch2pose/)\n\n# Prerequisites\n\n- [GNU/Linux](https://www.gnu.org/gnu/linux-and-gnu.en.html)\n- [`python`](https://python.org)\n- [`pytorch`](https://pytorch.org/)\n- NVIDIA GPU (optional, but highly recommended)\n\n## Download body model (SMPL-X)\n\nDownload the SMPL-X body model from\n[https://smpl-x.is.tue.mpg.de](https://smpl-x.is.tue.mpg.de) and save the zip\narchive (`models_smplx_v1_1.zip`) to `./assets`.\n\nSee [`download.sh`](./scripts/download.sh) for the remaining assets, then run\n\n```bash\nbash ./scripts/download.sh\n```\n\n## Virtual environment\n\nIf needed, change the [`pytorch`](https://pytorch.org/) version in\n[`prepare.sh`](./scripts/prepare.sh), then run\n\n```bash\nbash ./scripts/prepare.sh\n```\n\n# Demo\n\nActivate the virtual environment with `. venv/bin/activate` and run\n\n```bash\nbash ./scripts/run.sh\n\n# or\n\npython src/pose.py \\\n        --save-path \"${out_dir}\" \\\n        --img-path \"${img_path}\" \\\n        --use-contacts \\\n        --use-natural \\\n        --use-cos \\\n        --use-angle-transf\n\n# or without contacts\n\npython src/pose.py \\\n        --save-path \"${out_dir}\" \\\n        --img-path \"${img_path}\" \\\n        --use-natural \\\n        --use-cos \\\n        --use-angle-transf\n```\n\n# Citation\n\n```\n@article{brodt2022sketch2pose,\n    author = {Kirill Brodt and Mikhail Bessmeltsev},\n    title = {Sketch2Pose: Estimating a 3D Character Pose from a Bitmap Sketch},\n    journal = {ACM Transactions on Graphics},\n    year = {2022},\n    month = {7},\n    volume = {41},\n    number = {4},\n    doi = {10.1145/3528223.3530106},\n}\n```\n\n# Useful links\n\n- [Deep High-Resolution Representation Learning for Human Pose Estimation](https://github.com/leoxiaobin/deep-high-resolution-net.pytorch/)\n- [SMPLify-X](https://github.com/vchoutas/smplify-x) ([project](https://smpl-x.is.tue.mpg.de/))\n- 
[SPIN](https://github.com/nkolot/SPIN) ([project](https://www.seas.upenn.edu/~nkolot/projects/spin/))\n- [eft](https://github.com/facebookresearch/eft)\n- [SMPLify-XMC](https://github.com/muelea/smplify-xmc), [selfcontact](https://github.com/muelea/selfcontact) ([project](https://tuch.is.tue.mpg.de/))\n- [Mixamo](https://www.mixamo.com) models with animations and a\n  [script](https://forums.unrealengine.com/community/community-content-tools-and-tutorials/1376068-script-mixamo-download-script)\n  to download them\n- Quaternion-based [Forward\n  Kinematics](https://github.com/facebookresearch/QuaterNet)\n"
  },
  {
    "path": "patches/selfcontact.diff",
    "content": "+++ venv/lib/python3.10/site-packages/selfcontact/body_segmentation.py\n@@ -14,6 +14,8 @@\n #\n # Contact: ps-license@tuebingen.mpg.de\n \n+from pathlib import Path\n+\n import torch\n import trimesh\n import torch.nn as nn\n@@ -22,6 +24,17 @@\n \n from .utils.mesh import winding_numbers\n \n+\n+def load_pkl(path):\n+    with open(path, \"rb\") as fin:\n+        return pickle.load(fin)\n+\n+\n+def save_pkl(obj, path):\n+    with open(path, \"wb\") as fout:\n+        pickle.dump(obj, fout)\n+\n+\n class BodySegment(nn.Module):\n     def __init__(self,\n                  name,\n@@ -63,9 +76,17 @@\n         self.register_buffer('segment_faces', segment_faces)\n \n         # create vector to select vertices form faces\n-        tri_vidx = []\n-        for ii in range(faces.max().item()+1):\n-            tri_vidx += [torch.nonzero(faces==ii)[0].tolist()]\n+        segments_folder = Path(segments_folder)\n+        tri_vidx_path = segments_folder / \"tri_vidx.pkl\"\n+        if not tri_vidx_path.is_file():\n+            tri_vidx = []\n+            for ii in range(faces.max().item()+1):\n+                tri_vidx += [torch.nonzero(faces==ii)[0].tolist()]\n+\n+            save_pkl(tri_vidx, tri_vidx_path)\n+        else:\n+            tri_vidx = load_pkl(tri_vidx_path)\n+\n         self.register_buffer('tri_vidx', torch.tensor(tri_vidx))\n \n     def create_band_faces(self):\n@@ -149,7 +170,7 @@\n         self.segmentation = {}\n         for idx, name in enumerate(names):\n             self.segmentation[name] = BodySegment(name, faces, segments_folder,\n-                model_type).to('cuda')\n+                model_type).to(device)\n \n     def batch_has_self_isec_verts(self, vertices):\n         \"\"\"\n+++ venv/lib/python3.10/site-packages/selfcontact/selfcontact.py\n@@ -41,6 +41,7 @@\n         test_segments=True,\n         compute_hd=False,\n         buffer_geodists=False,\n+        device=\"cuda\",\n     ):\n         super().__init__()\n \n@@ -95,7 
+96,7 @@\n         if self.test_segments:\n             sxseg = pickle.load(open(segments_bounds_path, 'rb'))\n             self.segments = BatchBodySegment(\n-                [x for x in sxseg.keys()], faces, segments_folder, self.model_type\n+                [x for x in sxseg.keys()], faces, segments_folder, self.model_type, device=device,\n             )\n \n         # load regressor to get high density mesh\n@@ -106,7 +107,7 @@\n                 torch.tensor(hd_operator['values']),\n                 torch.Size(hd_operator['size']))\n             self.register_buffer('hd_operator',\n-                torch.tensor(hd_operator).float())\n+                hd_operator.clone().detach().float())\n \n             with open(point_vert_corres_path, 'rb') as f:\n                 hd_geovec = pickle.load(f)['faces_vert_is_sampled_from']\n@@ -135,9 +136,13 @@\n         # split because of memory into two chunks\n         exterior = torch.zeros((bs, nv), device=vertices.device,\n             dtype=torch.bool)\n-        exterior[:, :5000] = winding_numbers(vertices[:,:5000,:],\n+        exterior[:, :3000] = winding_numbers(vertices[:,:3000,:],\n             triangles).le(0.99)\n-        exterior[:, 5000:] = winding_numbers(vertices[:,5000:,:],\n+        exterior[:, 3000:6000] = winding_numbers(vertices[:,3000:6000,:],\n+            triangles).le(0.99)\n+        exterior[:, 6000:9000] = winding_numbers(vertices[:,6000:9000,:],\n+            triangles).le(0.99)\n+        exterior[:, 9000:] = winding_numbers(vertices[:,9000:,:],\n             triangles).le(0.99)\n \n         # check if intersections happen within segments\n@@ -173,9 +178,13 @@\n         # split because of memory into two chunks\n         exterior = torch.zeros((bs, np), device=points.device,\n             dtype=torch.bool)\n-        exterior[:, :6000] = winding_numbers(points[:,:6000,:],\n+        exterior[:, :3000] = winding_numbers(points[:,:3000,:],\n+            triangles).le(0.99)\n+        exterior[:, 
3000:6000] = winding_numbers(points[:,3000:6000,:],\n             triangles).le(0.99)\n-        exterior[:, 6000:] = winding_numbers(points[:,6000:,:],\n+        exterior[:, 6000:9000] = winding_numbers(points[:,6000:9000,:],\n+            triangles).le(0.99)\n+        exterior[:, 9000:] = winding_numbers(points[:,9000:,:],\n             triangles).le(0.99)\n \n         return exterior\n@@ -371,6 +380,23 @@\n \n         return hd_v2v_mins, hd_exteriors, hd_points, hd_faces_in_contacts\n \n+    def verts_in_contact(self, vertices, return_idx=False):\n+\n+            # get pairwise distances of vertices\n+            v2v = self.get_pairwise_dists(vertices, vertices, squared=True)\n+\n+            # mask v2v with Euclidean and geodesic distance\n+            euclmask = v2v < self.euclthres**2\n+            mask = euclmask * self.geomask\n+\n+            # find closest vertex in contact\n+            in_contact = mask.sum(1) > 0\n+\n+            if return_idx:\n+                in_contact = torch.where(in_contact)\n+\n+            return in_contact\n+\n \n \n class SelfContactSmall(nn.Module):\n+++ venv/lib/python3.10/site-packages/selfcontact/utils/mesh.py\n@@ -82,7 +82,7 @@\n     if valid_vals > 0:\n         loss = (mask * dists).sum() / valid_vals\n     else:\n-        loss = torch.Tensor([0]).cuda()\n+        loss = mask.new_tensor([0])\n     return loss\n \n def batch_index_select(inp, dim, index):\n@@ -103,6 +103,7 @@\n     xx = torch.bmm(x, x.transpose(2, 1))\n     yy = torch.bmm(y, y.transpose(2, 1))\n     zz = torch.bmm(x, y.transpose(2, 1))\n+    use_cuda = x.device.type == \"cuda\"\n     if use_cuda:\n         dtype = torch.cuda.LongTensor\n     else:\n"
  },
  {
    "path": "patches/smplx.diff",
    "content": "+++ venv/lib/python3.10/site-packages/smplx/body_models.py\n@@ -366,7 +366,7 @@\n             num_repeats = int(batch_size / betas.shape[0])\n             betas = betas.expand(num_repeats, -1)\n \n-        vertices, joints = lbs(betas, full_pose, self.v_template,\n+        vertices, joints, _ = lbs(betas, full_pose, self.v_template,\n                                self.shapedirs, self.posedirs,\n                                self.J_regressor, self.parents,\n                                self.lbs_weights, pose2rot=pose2rot)\n@@ -1228,7 +1228,7 @@\n \n         shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1)\n \n-        vertices, joints = lbs(shape_components, full_pose, self.v_template,\n+        vertices, joints, A = lbs(shape_components, full_pose, self.v_template,\n                                shapedirs, self.posedirs,\n                                self.J_regressor, self.parents,\n                                self.lbs_weights, pose2rot=pose2rot,\n@@ -1283,7 +1283,9 @@\n                              right_hand_pose=right_hand_pose,\n                              jaw_pose=jaw_pose,\n                              v_shaped=v_shaped,\n-                             full_pose=full_pose if return_full_pose else None)\n+                             full_pose=full_pose if return_full_pose else None,\n+                             A=A,\n+                             )\n         return output\n \n \n+++ venv/lib/python3.10/site-packages/smplx/lbs.py\n@@ -245,7 +245,7 @@\n \n     verts = v_homo[:, :, :3, 0]\n \n-    return verts, J_transformed\n+    return verts, J_transformed, (A, J)\n \n \n def vertices2joints(J_regressor: Tensor, vertices: Tensor) -> Tensor:\n+++ venv/lib/python3.10/site-packages/smplx/utils.py\n@@ -71,6 +71,7 @@\n class SMPLXOutput(SMPLHOutput):\n     expression: Optional[Tensor] = None\n     jaw_pose: Optional[Tensor] = None\n+    A: Optional[Tensor] = None\n \n \n @dataclass\n"
  },
  {
    "path": "patches/torchgeometry.diff",
    "content": "+++ venv/lib/python3.10/site-packages/torchgeometry/core/conversions.py\n@@ -298,6 +298,9 @@\n                       rmat_t[:, 0, 1] - rmat_t[:, 1, 0]], -1)\n     t3_rep = t3.repeat(4, 1).t()\n \n+    mask_d2 = mask_d2.float()\n+    mask_d0_d1 = mask_d0_d1.float()\n+    mask_d0_nd1 = mask_d0_nd1.float()\n     mask_c0 = mask_d2 * mask_d0_d1\n     mask_c1 = mask_d2 * (1 - mask_d0_d1)\n     mask_c2 = (1 - mask_d2) * mask_d0_nd1\n"
  },
  {
    "path": "requirements.txt",
    "content": "matplotlib>=3.5.1\nnumpy>=1.22.3\nopencv_python>=4.5.5.64\nPillow>=9.1.0\nplotly>=5.7.0\npyrender>=0.1.45\nscikit_image>=0.19.2\nscipy>=1.8.0\nShapely>=1.8.1.post1\nscikit-image>=0.19.2\ntensorboard>=2.8.0\ntorchgeometry>=0.1.2\ntqdm>=4.64.0\ntrimesh>=3.10.8\ngit+https://github.com/muelea/selfcontact.git@08da422526419c24736c0616bca49623e442c26a\ngit+https://github.com/vchoutas/smplx.git@5fa20519735cceda19afed0beeabd53caef711cd\n"
  },
  {
    "path": "scripts/download.sh",
    "content": "#!/usr/bin/env sh\n\n\nset -euo pipefail\n\n\nasset_dir=\"./assets\"\n\n[ ! -e \"${asset_dir}\"/models_smplx_v1_1.zip ] \\\n    && echo Error: Download SMPL-X body model from https://smpl-x.is.tue.mpg.de \\\n    and save zip archive to \"${asset_dir}\" \\\n    && exit 1 \\\n    && :\n\nasset_urls=(\n    # Download constants (SPIN)\n    http://visiondata.cis.upenn.edu/spin/data.tar.gz\n\n    # Download essentials (SMPLify-XMC)\n    https://download.is.tue.mpg.de/tuch/smplify-xmc-essentials.zip\n\n    # Download sketch2pose models\n    http://www-labs.iro.umontreal.ca/~bmpix/sketch2pose/models.zip\n\n    # Download test images\n    http://www-labs.iro.umontreal.ca/~bmpix/sketch2pose/images.zip\n)\nfor asset_url in \"${asset_urls[@]}\"; do\n    wget \\\n        -nc \\\n        -c \\\n        --directory-prefix \"${asset_dir}\" \\\n        \"${asset_url}\"\ndone\n\nmodels_dir=\"./models\"\nmkdir -p \"${models_dir}\"\n\nmodel_files=(\n    # Unzip smplx models\n    models_smplx_v1_1.zip\n\n    # Unzip essentials (SMPLifu-XMC)\n    smplify-xmc-essentials.zip\n\n    # Unzip sketch2pose models\n    models.zip\n)\n\nfor model_file in \"${model_files[@]}\"; do\n    unzip \\\n        -u \\\n        -d \"${models_dir}\" \\\n        \"${asset_dir}\"/\"${model_file}\"\ndone\n\n# Unzip constants (SPIN)\ntar \\\n    --skip-old-files \\\n    -xvf \"${asset_dir}\"/data.tar.gz \\\n    -C \"${models_dir}\" \\\n    data/smpl_mean_params.npz\n\ndata_dir=\"./data\"\nmkdir -p \"${data_dir}\"\n\n# Unzip test images\nunzip \\\n    -u \\\n    -d \"${data_dir}\" \\\n    \"${asset_dir}\"/images.zip\n"
  },
  {
    "path": "scripts/prepare.sh",
    "content": "#!/usr/bin/env sh\n\n\nset -euo pipefail\n\nvenv_dir=venv\npython -m venv --clear \"${venv_dir}\"\n\n. \"${venv_dir}\"/bin/activate\n\npip install -U pip setuptools\n\nextra=\"cpu\"\n[ -x \"$(command -v nvcc)\" ] && extra=\"cu113\"\npip install \\\n    torch \\\n    torchvision \\\n    --extra-index-url https://download.pytorch.org/whl/\"${extra}\"\n\npip install -r requirements.txt\n\nv=$(python -c 'import sys; v = sys.version_info; print(f\"{v.major}.{v.minor}\")')\nfor p in patches/*.diff; do\n    patch -p0 < <(sed \"s/python3.10/python${v}/\" \"${p}\")\ndone\n"
  },
  {
    "path": "scripts/run.sh",
    "content": "#!/usr/bin/env sh\n\n\nset -euo pipefail\n\n\nimg_dir=\"./data/images\"\nout_dir=\"./output\"\n\nfind \"${img_dir}\" -mindepth 1 -maxdepth 1 -type f -print0 \\\n    | xargs -0 -I \"{}\" python src/pose.py \\\n        --save-path \"${out_dir}\" \\\n        --img-path \"{}\" \\\n        --use-contacts \\\n        --use-natural \\\n        --use-cos \\\n        --use-angle-transf \\\n\nexit\n\n# baseline (SMPLify-XMC)\n\nfind \"${img_dir}\" -mindepth 1 -maxdepth 1 -type f -print0 \\\n    | xargs -0 -I \"{}\" python src/pose.py \\\n        --save-path \"${out_dir}_baseline\" \\\n        --img-path \"{}\" \\\n        --c-mse 1 \\\n        --c-par 0 \\\n        --use-contacts \\\n        --use-cos \\\n        --use-angle-transf \\\n\n# ablation\n\nfind \"${img_dir}\" -mindepth 1 -maxdepth 1 -type f -print0 \\\n    | xargs -0 -I \"{}\" python src/pose.py \\\n        --save-path \"${out_dir}_wocostransform\" \\\n        --img-path \"{}\" \\\n        --use-contacts \\\n        --use-natural \\\n        --use-cos \\\n\n\nfind \"${img_dir}\" -mindepth 1 -maxdepth 1 -type f -print0 \\\n    | xargs -0 -I \"{}\" python src/pose.py \\\n        --save-path \"${out_dir}_wocontacts\" \\\n        --img-path \"{}\" \\\n        --use-msc \\\n        --use-natural \\\n        --use-cos \\\n        --use-angle-transf \\\n\n\nfind \"${img_dir}\" -mindepth 1 -maxdepth 1 -type f -print0 \\\n    | xargs -0 -I \"{}\" python src/pose.py \\\n        --save-path \"${out_dir}_wonatural\" \\\n        --img-path \"{}\" \\\n        --use-contacts \\\n        --use-cos \\\n        --use-angle-transf \\\n"
  },
  {
    "path": "src/fist_pose.py",
    "content": "left_fist = [\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.4183167815208435, 0.10645648092031479, -1.6593892574310303,\n    0.15252035856246948, -0.14700782299041748, -1.3719955682754517,\n    -0.04432843625545502, -0.15799851715564728, -0.938068151473999,\n    -0.12218914180994034, 0.073341965675354, -1.6415189504623413,\n    -0.14376045763492584, 0.1927780956029892, -1.3593589067459106,\n    -0.0851994976401329, 0.01652289740741253, -0.7474589347839355,\n    -0.9881719946861267, -0.3987707793712616, -1.3535722494125366,\n    -0.6686224937438965, 0.1261960119009018, -1.080643892288208,\n    -0.8101894855499268, -0.1306752860546112, -0.8412265777587891,\n    -0.3495230972766876, -0.17784251272678375, -1.4433038234710693,\n    -0.46278536319732666, 0.13677796721458435, -1.467200517654419,\n    -0.3681888282299042, 0.003404417773708701, -0.7764251232147217,\n    0.850964367389679, 0.2769227623939514, -0.09154807031154633,\n    0.14500413835048676, 0.09604815393686295, 0.219278022646904,\n    1.0451993942260742, 0.16911321878433228, -0.2426234930753708,\n    0.11167845129966736, -0.04289207234978676, 0.41644084453582764,\n    0.10881128907203674, 0.06598565727472305, 0.756219744682312,\n    -0.0963931530714035, 0.09091583639383316, 0.18845966458320618,\n    -0.11809506267309189, -0.050943851470947266, 0.5295845866203308,\n    -0.14369848370552063, -0.055241718888282776, 0.704857349395752,\n    -0.019182899966835976, 0.0923367589712143, 0.3379131853580475,\n    -0.45703303813934326, 
0.1962839663028717, 0.6254575848579407,\n    -0.21465237438678741, 0.06599827855825424, 0.5068942308425903,\n    -0.36972442269325256, 0.0603446289896965, 0.07949023693799973,\n    -0.14186954498291016, 0.08585254102945328, 0.6355276107788086,\n    -0.3033415675163269, 0.05788097903132439, 0.6313892006874084,\n    -0.17612087726593018, 0.13209305703639984, 0.3733545243740082,\n    0.850964367389679, -0.2769227623939514, 0.09154807031154633,\n    -0.4998386800289154, -0.026556432247161865, -0.052880801260471344,\n    0.5355585217475891, -0.045960985124111176, 0.27735769748687744,\n]\n\nleft_right_fist = [\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, -0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.4183167815208435, 0.10645648092031479, -1.6593892574310303,\n    0.15252035856246948, -0.14700782299041748, -1.3719955682754517,\n    -0.04432843625545502, -0.15799851715564728, -0.938068151473999,\n    -0.12218914180994034, 0.073341965675354, -1.6415189504623413,\n    -0.14376045763492584, 0.1927780956029892, -1.3593589067459106,\n    -0.0851994976401329, 0.01652289740741253, -0.7474589347839355,\n    -0.9881719946861267, -0.3987707793712616, -1.3535722494125366,\n    -0.6686224937438965, 0.1261960119009018, -1.080643892288208,\n    -0.8101894855499268, -0.1306752860546112, -0.8412265777587891,\n    -0.3495230972766876, -0.17784251272678375, -1.4433038234710693,\n    -0.46278536319732666, 0.13677796721458435, -1.467200517654419,\n    -0.3681888282299042, 0.003404417773708701, -0.7764251232147217,\n    0.850964367389679, 0.2769227623939514, -0.09154807031154633,\n    
0.14500413835048676, 0.09604815393686295, 0.219278022646904,\n    1.0451993942260742, 0.16911321878433228, -0.2426234930753708,\n    0.4183167815208435, -0.10645647346973419, 1.6593892574310303,\n    0.15252038836479187, 0.14700786769390106, 1.3719956874847412,\n    -0.04432841017842293, 0.15799842774868011, 0.9380677938461304,\n    -0.12218913435935974, -0.0733419880270958, 1.6415191888809204,\n    -0.14376048743724823, -0.19277812540531158, 1.3593589067459106,\n    -0.08519953489303589, -0.016522908583283424, 0.7474592328071594,\n    -0.9881719350814819, 0.3987707495689392, 1.3535723686218262,\n    -0.6686226725578308, -0.12619605660438538, 1.080644130706787,\n    -0.8101896643638611, 0.1306752860546112, 0.8412266373634338,\n    -0.34952324628829956, 0.17784248292446136, 1.443304181098938,\n    -0.46278542280197144, -0.13677802681922913, 1.467200517654419,\n    -0.36818885803222656, -0.0034044249914586544, 0.7764251232147217,\n    0.8509642481803894, -0.2769228219985962, 0.09154807776212692,\n    0.14500458538532257, -0.09604845196008682, -0.21927869319915771,\n    1.0451991558074951, -0.1691131889820099, 0.242623433470726,\n]\n\nright_fist = []\nfor lf, lrf in zip(left_fist, left_right_fist):\n    if lf != lrf:\n        right_fist.append(lrf)\n    else:\n        right_fist.append(0)\n\n\nleft_flat_up = [\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0, 1.5129635334014893,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 
0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n]\n\nleft_flat_down = [\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0, -1.4648663997650146,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n]\n\nright_flat_up = [\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 
0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0, -1.5021973848342896,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n]\n\nright_flat_down = [\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0, 0, 1.494218111038208,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n]\n\nrelaxed = [\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n 
   0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.0, 0.0, 0.0,\n    0.11167845129966736, 0.04289207234978676, -0.41644084453582764,\n    0.10881128907203674, -0.06598565727472305, -0.756219744682312,\n    -0.0963931530714035, -0.09091583639383316, -0.18845966458320618,\n    -0.11809506267309189, 0.050943851470947266, -0.5295845866203308,\n    -0.14369848370552063, 0.055241718888282776, -0.704857349395752,\n    -0.019182899966835976, -0.0923367589712143, -0.3379131853580475,\n    -0.45703303813934326, -0.1962839663028717, -0.6254575848579407,\n    -0.21465237438678741, -0.06599827855825424, -0.5068942308425903,\n    -0.36972442269325256, -0.0603446289896965, -0.07949023693799973,\n    -0.14186954498291016, -0.08585254102945328, -0.6355276107788086,\n    -0.3033415675163269, -0.05788097903132439, -0.6313892006874084,\n    -0.17612087726593018, -0.13209305703639984, -0.3733545243740082,\n    0.850964367389679, 0.2769227623939514, -0.09154807031154633,\n    -0.4998386800289154, 0.026556432247161865, 0.052880801260471344,\n    0.5355585217475891, 0.045960985124111176, -0.27735769748687744,\n    0.11167845129966736, -0.04289207234978676, 0.41644084453582764,\n    0.10881128907203674, 0.06598565727472305, 0.756219744682312,\n    -0.0963931530714035, 0.09091583639383316, 0.18845966458320618,\n    -0.11809506267309189, -0.050943851470947266, 0.5295845866203308,\n    -0.14369848370552063, -0.055241718888282776, 0.704857349395752,\n    -0.019182899966835976, 0.0923367589712143, 0.3379131853580475,\n    -0.45703303813934326, 0.1962839663028717, 0.6254575848579407,\n    -0.21465237438678741, 0.06599827855825424, 0.5068942308425903,\n    -0.36972442269325256, 
0.0603446289896965, 0.07949023693799973,\n    -0.14186954498291016, 0.08585254102945328, 0.6355276107788086,\n    -0.3033415675163269, 0.05788097903132439, 0.6313892006874084,\n    -0.17612087726593018, 0.13209305703639984, 0.3733545243740082,\n    0.850964367389679, -0.2769227623939514, 0.09154807031154633,\n    -0.4998386800289154, -0.026556432247161865, -0.052880801260471344,\n    0.5355585217475891, -0.045960985124111176, 0.27735769748687744,\n]\n\n# body joints + left arm + right arm\n# 25 + 15 + 15\n# smpl(left_hand_pose, right_hand_pose)\n\nleft_start = 25 * 3\nleft_end = left_start + 15 * 3\nright_end = left_end + 15 * 3\n\nLEFT_FIST = left_fist[left_start:left_end]\nRIGHT_FIST = right_fist[left_end:right_end]\n\nLEFT_FLAT_UP = left_flat_up[20 * 3 : 20 * 3 + 3]\nLEFT_FLAT_DOWN = left_flat_down[20 * 3 : 20 * 3 + 3]\n\nRIGHT_FLAT_UP = right_flat_up[21 * 3 : 21 * 3 + 3]\nRIGHT_FLAT_DOWN = right_flat_down[21 * 3 : 21 * 3 + 3]\n\nLEFT_RELAXED = relaxed[left_start:left_end]\nRIGHT_RELAXED = relaxed[left_end:right_end]\n\nINT_TO_FIST = {\n    \"lfl\": None,\n    \"lf\": LEFT_FIST,\n    \"lu\": LEFT_FLAT_UP,\n    \"ld\": LEFT_FLAT_DOWN,\n    \"rfl\": None,\n    \"rf\": RIGHT_FIST,\n    \"ru\": RIGHT_FLAT_UP,\n    \"rd\": RIGHT_FLAT_DOWN,\n}\n"
  },
  {
    "path": "src/hist_cub.py",
    "content": "import itertools\nimport functools\nimport math\nimport multiprocessing\nfrom pathlib import Path\n\nimport matplotlib\nmatplotlib.rcParams.update({'font.size': 24})\nmatplotlib.rcParams.update({\n  \"text.usetex\": True,\n  \"text.latex.preamble\": r\"\\usepackage{biolinum} \\usepackage{libertineRoman} \\usepackage{libertineMono} \\usepackage{biolinum} \\usepackage[libertine]{newtxmath}\",\n  'ps.usedistiller': \"xpdf\",\n})\n\nimport matplotlib.pyplot as plt\nimport matplotlib.gridspec as gridspec\nimport numpy as np\nimport tqdm\nfrom scipy.stats import wasserstein_distance\n\nimport pose_estimation\n\n\ndef cub(x, a, b, c):\n    x2 = x * x\n    x3 = x2 * x\n\n    y = a * x3 + b * x2 + c * x\n\n    return y\n\n\ndef subsample(a, p=0.0005, seed=0):\n    np.random.seed(seed)\n    N = len(a)\n    inds = np.random.choice(range(N), size=int(p * N))\n    a = a[inds].copy()\n\n    return a\n\n\ndef read_cos_opt(path, fname=\"cos_hist.npy\"):\n    cos_opt = []\n    for p in Path(path).rglob(fname):\n        d = np.load(p)\n        cos_opt.append(d)\n\n    cos_opt = np.array(cos_opt)\n\n    return cos_opt\n\n\ndef plot_hist(cos_opt_dir, hist_smpl_fpath, params, out_dir, bins=10, xy=None):\n    cos_opt = read_cos_opt(cos_opt_dir)\n    angle_opt = np.arccos(cos_opt)\n    angle_opt2 = cub(angle_opt, *params)\n\n    cos_opt2 = np.cos(angle_opt2)\n    cos_smpl = np.load(hist_smpl_fpath)\n    # cos_smpl = subsample(cos_smpl)\n    print(cos_smpl.shape)\n\n    cos_smpl = np.clip(cos_smpl, -1, 1)\n\n    cos_opt = angle_opt\n    cos_opt2 = angle_opt2\n    cos_smpl = np.arccos(cos_smpl)\n\n    cos_opt = 180 / math.pi * cos_opt\n    cos_opt2 = 180 / math.pi * cos_opt2\n    cos_smpl = 180 / math.pi * cos_smpl\n    max_range = 90  # math.pi / 2\n\n    xticks = [0, 15, 30, 45, 60, 75, 90]\n    for idx, bone in enumerate(pose_estimation.SKELETON):\n        i, j = bone\n        i_name = pose_estimation.KPS[i]\n        j_name = pose_estimation.KPS[j]\n        if i_name != 
\"Left Upper Leg\":\n            continue\n\n        name = f\"{i_name}_{j_name}\"\n\n        gs = gridspec.GridSpec(2, 4)\n        fig = plt.figure(tight_layout=True, figsize=(16, 8), dpi=300)\n\n        ax0 = fig.add_subplot(gs[0, 0])\n        ax0.hist(cos_smpl[:, idx], bins=bins, range=(0, max_range), density=True)\n        ax0.set_xticks(xticks)\n        ax0.tick_params(labelbottom=False, labelleft=True)\n\n        ax1 = fig.add_subplot(gs[1, 0], sharex=ax0)\n        ax1.hist(cos_opt[:, idx], bins=bins, range=(0, max_range), density=True)\n        ax1.set_xticks(xticks)\n\n        if xy is not None:\n            ax2 = fig.add_subplot(gs[:, 1:3])\n            ax2.plot(xy[0], xy[1], linewidth=8)\n            ax2.plot(xy[0], xy[0], linewidth=4, linestyle=\"dashed\")\n            ax2.set_xticks(xticks)\n            ax2.set_yticks(xticks)\n\n        ax3 = fig.add_subplot(gs[0, 3], sharey=ax0)\n        ax3.hist(cos_opt2[:, idx], bins=bins, range=(0, max_range), density=True)\n        ax3.set_xticks(xticks)\n        ax3.tick_params(labelbottom=False, labelleft=False)\n\n        ax4 = fig.add_subplot(gs[1, 3], sharex=ax3, sharey=ax1)\n        alpha = 0.5\n        ax4.hist(cos_opt[:, idx], bins=bins, range=(0, max_range), density=True, label=r\"$\\mathcal{B}_i$\", alpha=alpha)\n        ax4.hist(cos_opt2[:, idx], bins=bins, range=(0, max_range), density=True, label=r\"$f(\\mathcal{B}_i)$\", alpha=alpha)\n        ax4.hist(cos_smpl[:, idx], bins=bins, range=(0, max_range), density=True, label=r\"$\\mathcal{A}_i$\", alpha=alpha)\n        ax4.set_xticks(xticks)\n        ax4.tick_params(labelbottom=True, labelleft=False)\n        ax4.legend()\n\n        fig.savefig(out_dir / f\"hist_{name}.png\")\n        plt.close()\n\n\ndef kldiv(p_hist, q_hist):\n    wd = wasserstein_distance(p_hist, q_hist)\n\n    return wd\n\n\ndef calc_histogram(x, bins=10, range=(0, 1)):\n    h, _ = np.histogram(x, bins=bins, range=range, density=True)\n\n    return h\n\ndef step(params, angles_opt, 
p_hist, bone_idx=None):\n    if sum(params) > 1:\n        return math.inf, params\n\n    kl = 0\n    for i, _ in enumerate(pose_estimation.SKELETON):\n        if bone_idx is not None and i != bone_idx:\n            continue\n\n        angles_opt2 = cub(angles_opt[:, i], *params)\n        if angles_opt2.max() > 1 or angles_opt2.min() < 0:\n            kl = math.inf\n\n            break\n\n        q_hist = calc_histogram(angles_opt2)\n\n        kl += kldiv(p_hist[i], q_hist)\n\n    return kl, params\n\n\ndef optimize(cos_opt_dir, hist_smpl_fpath, bone_idx=None):\n    cos_opt = read_cos_opt(cos_opt_dir)\n    angles_opt = np.arccos(cos_opt) / (math.pi / 2)\n    cos_smpl = np.load(hist_smpl_fpath)\n    # cos_smpl = subsample(cos_smpl)\n    print(cos_smpl.shape)\n    cos_smpl = np.clip(cos_smpl, -1, 1)\n    mask = cos_smpl <= 1\n    assert np.all(mask), (~mask).mean()\n    mask = cos_smpl >= 0\n    assert np.all(mask), (~mask).mean()\n    angles_smpl = np.arccos(cos_smpl) / (math.pi / 2)\n    p_hist = [\n        calc_histogram(angles_smpl[:, i])\n        for i, _ in enumerate(pose_estimation.SKELETON)\n    ]\n\n    with multiprocessing.Pool(8) as p:\n        results = list(\n            tqdm.tqdm(\n                p.imap_unordered(\n                    functools.partial(step, angles_opt=angles_opt, p_hist=p_hist, bone_idx=bone_idx),\n                    itertools.product(\n                        np.linspace(0, 20, 100),\n                        np.linspace(-20, 20, 200),\n                        np.linspace(-20, 1, 100),\n                    ),\n                ),\n                total=(100 * 200 * 100),\n            )\n        )\n\n    kls, params = zip(*results)\n    ind = np.argmin(kls)\n    best_params = params[ind]\n\n    print(kls[ind], best_params)\n\n    inds = np.argsort(kls)\n    for i in inds[:10]:\n        print(kls[i])\n        print(params[i])\n        print()\n\n    return best_params\n\n\ndef main():\n    cos_opt_dir = \"paper_single2_150mse\"\n    
hist_smpl_fpath = \"./data/hist_smpl.npy\"\n    # hist_smpl_fpath = \"./testtest.npy\"\n    params = optimize(cos_opt_dir, hist_smpl_fpath)\n    # params = (1.2121212121212122, -1.105527638190953, 0.787878787878789)\n    # params = (0.20202020202020202, 0.30150753768844396, 0.3636363636363633)\n    print(params)\n\n    x = np.linspace(0, math.pi / 2, 100)\n    y = cub(x / (math.pi / 2), *params) * (math.pi / 2)\n    x = x * 180 / math.pi\n    y = y * 180 / math.pi\n\n    out_dir = Path(\"hists\")\n    out_dir.mkdir(parents=True, exist_ok=True)\n    plot_hist(cos_opt_dir, hist_smpl_fpath, params, out_dir, xy=(x, y))\n\n    plt.figure(figsize=(4, 4), dpi=300)\n    plt.plot(x, y, linewidth=6)\n    plt.plot(x, x, linewidth=2, linestyle=\"dashed\")\n    xticks = [0, 15, 30, 45, 60, 75, 90]\n    plt.xticks(xticks)\n    plt.yticks(xticks)\n    plt.axis(\"equal\")\n    plt.tight_layout()\n    plt.savefig(out_dir / \"new_out.png\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/losses.py",
    "content": "import itertools\n\nimport torch\nimport torch.nn as nn\n\nimport pose_estimation\n\n\nclass MSE(nn.Module):\n    def __init__(self, ignore=None):\n        super().__init__()\n\n        self.mse = torch.nn.MSELoss(reduction=\"none\")\n        self.ignore = ignore if ignore is not None else []\n\n    def forward(self, y_pred, y_data):\n        loss = self.mse(y_pred, y_data)\n\n        if len(self.ignore) > 0:\n            loss[self.ignore] *= 0\n\n        return loss.sum() / (len(loss) - len(self.ignore))\n\n\nclass Parallel(nn.Module):\n    def __init__(self, skeleton, ignore=None, ground_parallel=None):\n        super().__init__()\n\n        self.skeleton = skeleton\n        if ignore is not None:\n            self.ignore = set(ignore)\n        else:\n            self.ignore = set()\n\n        self.ground_parallel = ground_parallel if ground_parallel is not None else []\n        self.parallel_in_3d = []\n\n        self.cos = None\n\n    def forward(self, y_pred3d, y_data, z, spine_j, writer=None, global_step=0):\n        y_pred = y_pred3d[:, :2]\n        rleg, lleg = spine_j\n\n        Lcon2d = Lcount = 0\n        if hasattr(self, \"contact_2d\"):\n            for c2d in self.contact_2d:\n                for (\n                    (src_1, dst_1, t_1),\n                    (src_2, dst_2, t_2),\n                ) in itertools.combinations(c2d, 2):\n\n                    a_1 = torch.lerp(y_data[src_1], y_data[dst_1], t_1)\n                    a_2 = torch.lerp(y_data[src_2], y_data[dst_2], t_2)\n                    a = a_2 - a_1\n\n                    b_1 = torch.lerp(y_pred[src_1], y_pred[dst_1], t_1)\n                    b_2 = torch.lerp(y_pred[src_2], y_pred[dst_2], t_2)\n                    b = b_2 - b_1\n\n                    lcon2d = ((a - b) ** 2).sum()\n                    Lcon2d = Lcon2d + lcon2d\n                    Lcount += 1\n\n        if Lcount > 0:\n            Lcon2d = Lcon2d / Lcount\n\n        Ltan = Lpar = Lcos = Lcount = 0\n        
Lspine = 0\n        for i, bone in enumerate(self.skeleton):\n            if bone in self.ignore:\n                continue\n\n            src, dst = bone\n\n            b = y_data[dst] - y_data[src]\n            t = nn.functional.normalize(b, dim=0)\n            n = torch.stack([-t[1], t[0]])\n\n            if src == 10 and dst == 11:  # right leg\n                a = rleg\n            elif src == 13 and dst == 14:  # left leg\n                a = lleg\n            else:\n                a = y_pred[dst] - y_pred[src]\n\n            bone_name = f\"{pose_estimation.KPS[src]}_{pose_estimation.KPS[dst]}\"\n            c = a - b\n            lcos_loc = ltan_loc = lpar_loc = 0\n            if self.cos is not None:\n                if bone not in [\n                    (1, 2),  # Neck + Right Shoulder\n                    (1, 5),  # Neck + Left Shoulder\n                    (9, 10),  # Hips + Right Upper Leg\n                    (9, 13),  # Hips + Left Upper Leg\n                ]:\n                    a = y_pred[dst] - y_pred[src]\n                    l2d = torch.norm(a, dim=0)\n                    l3d = torch.norm(y_pred3d[dst] - y_pred3d[src], dim=0)\n                    lcos = self.cos[i]\n\n                    lcos_loc = (l2d / l3d - lcos) ** 2\n                    Lcos = Lcos + lcos_loc\n                    lpar_loc = ((a / l2d) * n).sum() ** 2\n                    Lpar = Lpar + lpar_loc\n            else:\n                ltan_loc = ((c * t).sum()) ** 2\n                Ltan = Ltan + ltan_loc\n                lpar_loc = (c * n).sum() ** 2\n                Lpar = Lpar + lpar_loc\n\n            if writer is not None:\n                writer.add_scalar(f\"tan/{bone_name}\", ltan_loc, global_step=global_step)\n                writer.add_scalar(f\"cos/{bone_name}\", lcos_loc, global_step=global_step)\n                writer.add_scalar(f\"par/{bone_name}\", lpar_loc, global_step=global_step)\n\n            Lcount += 1\n\n        if Lcount > 0:\n            Ltan = Ltan / 
Lcount\n            Lcos = Lcos / Lcount\n            Lpar = Lpar / Lcount\n            Lspine = Lspine / Lcount\n\n        Lgr = Lcount = 0\n        for (src, dst), value in self.ground_parallel:\n            bone = y_pred[dst] - y_pred[src]\n            bone = nn.functional.normalize(bone, dim=0)\n            l = (torch.abs(bone[0]) - value) ** 2\n\n            Lgr = Lgr + l\n            Lcount += 1\n\n        if Lcount > 0:\n            Lgr = Lgr / Lcount\n\n        Lstraight3d = Lcount = 0\n        for (i, j), (k, l) in self.parallel_in_3d:\n            a = z[j] - z[i]\n            a = nn.functional.normalize(a, dim=0)\n            b = z[l] - z[k]\n            b = nn.functional.normalize(b, dim=0)\n            lo = (((a * b).sum() - 1) ** 2).sum()\n            Lstraight3d = Lstraight3d + lo\n            Lcount += 1\n\n            b = y_data[1] - y_data[8]\n            b = nn.functional.normalize(b, dim=0)\n\n        if Lcount > 0:\n            Lstraight3d = Lstraight3d / Lcount\n\n        return Ltan, Lcos, Lpar, Lspine, Lgr, Lstraight3d, Lcon2d\n\n\nclass MimickedSelfContactLoss(nn.Module):\n    def __init__(self, geodesics_mask):\n        super().__init__()\n        \"\"\"\n        Loss that lets vertices in contact on presented mesh attract vertices that are close.\n        \"\"\"\n        # geodesic distance mask\n        self.register_buffer(\"geomask\", geodesics_mask)\n\n    def forward(\n        self,\n        presented_contact,\n        vertices,\n        v2v=None,\n        contact_mode=\"dist_tanh\",\n        contact_thresh=1,\n    ):\n\n        contactloss = 0.0\n\n        if v2v is None:\n            # compute pairwise distances\n            verts = vertices.contiguous()\n            nv = verts.shape[1]\n            v2v = verts.squeeze().unsqueeze(1).expand(\n                nv, nv, 3\n            ) - verts.squeeze().unsqueeze(0).expand(nv, nv, 3)\n            v2v = torch.norm(v2v, 2, 2)\n\n        # loss for self-contact from mimic'ed pose\n        
if len(presented_contact) > 0:\n            # without geodesic distance mask, compute distances\n            # between each pair of verts in contact\n            with torch.no_grad():\n                cvertstobody = v2v[presented_contact, :]\n                cvertstobody = cvertstobody[:, presented_contact]\n                maskgeo = self.geomask[presented_contact, :]\n                maskgeo = maskgeo[:, presented_contact]\n                # use v2v.device: `verts` is undefined when v2v is passed in by the caller\n                weights = torch.ones_like(cvertstobody).to(v2v.device)\n                weights[~maskgeo] = float(\"inf\")\n                min_idx = torch.min((cvertstobody + 1) * weights, 1)[1]\n                min_idx = presented_contact[min_idx.cpu().numpy()]\n\n            v2v_min = v2v[presented_contact, min_idx]\n\n            # tanh will not pull vertices that are ~more than contact_thresh far apart\n            if contact_mode == \"dist_tanh\":\n                contactloss = contact_thresh * torch.tanh(v2v_min / contact_thresh)\n                contactloss = contactloss.mean()\n            else:\n                contactloss = v2v_min.mean()\n\n        return contactloss\n"
  },
  {
    "path": "src/pose.py",
    "content": "import argparse\nimport math\nfrom pathlib import Path\n\nimport cv2\nimport numpy as np\nimport PIL.Image as Image\nimport selfcontact\nimport selfcontact.losses\nimport shapely.geometry\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport torchgeometry\nimport tqdm\nimport trimesh\nfrom skimage import measure\nfrom torch.utils.tensorboard.writer import SummaryWriter\n\nimport fist_pose\nimport hist_cub\nimport losses\nimport pose_estimation\nimport spin\nimport utils\n\nPE_KSP_TO_SPIN = {\n    \"Head\": \"Head\",\n    \"Neck\": \"Neck\",\n    \"Right Shoulder\": \"Right ForeArm\",\n    \"Right Arm\": \"Right Arm\",\n    \"Right Hand\": \"Right Hand\",\n    \"Left Shoulder\": \"Left ForeArm\",\n    \"Left Arm\": \"Left Arm\",\n    \"Left Hand\": \"Left Hand\",\n    \"Spine\": \"Spine1\",\n    \"Hips\": \"Hips\",\n    \"Right Upper Leg\": \"Right Upper Leg\",\n    \"Right Leg\": \"Right Leg\",\n    \"Right Foot\": \"Right Foot\",\n    \"Left Upper Leg\": \"Left Upper Leg\",\n    \"Left Leg\": \"Left Leg\",\n    \"Left Foot\": \"Left Foot\",\n    \"Left Toe\": \"Left Toe\",\n    \"Right Toe\": \"Right Toe\",\n}\nMODELS_DIR = \"models\"\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument(\n        \"--pose-estimation-model-path\",\n        type=str,\n        default=f\"./{MODELS_DIR}/hrn_w48_384x288.onnx\",\n        help=\"Pose Estimation model\",\n    )\n\n    parser.add_argument(\n        \"--contact-model-path\",\n        type=str,\n        default=f\"./{MODELS_DIR}/contact_hrn_w32_256x192.onnx\",\n        help=\"Contact model\",\n    )\n\n    parser.add_argument(\n        \"--device\",\n        type=str,\n        default=\"cuda\",\n        choices=[\"cpu\", \"cuda\"],\n        help=\"Torch device\",\n    )\n\n    parser.add_argument(\n        \"--spin-model-path\",\n        type=str,\n        default=f\"./{MODELS_DIR}/spin_model_smplx_eft_18.pt\",\n        help=\"SPIN model path\",\n    
)\n\n    parser.add_argument(\n        \"--smpl-type\",\n        type=str,\n        default=\"smplx\",\n        choices=[\"smplx\"],\n        help=\"SMPL model type\",\n    )\n    parser.add_argument(\n        \"--smpl-model-dir\",\n        type=str,\n        default=f\"./{MODELS_DIR}/models/smplx\",\n        help=\"SMPL model dir\",\n    )\n    parser.add_argument(\n        \"--smpl-mean-params-path\",\n        type=str,\n        default=f\"./{MODELS_DIR}/data/smpl_mean_params.npz\",\n        help=\"SMPL mean params\",\n    )\n    parser.add_argument(\n        \"--essentials-dir\",\n        type=str,\n        default=f\"./{MODELS_DIR}/smplify-xmc-essentials\",\n        help=\"SMPL Essentials folder for contacts\",\n    )\n\n    parser.add_argument(\n        \"--parametrization-path\",\n        type=str,\n        default=f\"./{MODELS_DIR}/smplx_parametrization/parametrization.npy\",\n        help=\"Parametrization path\",\n    )\n    parser.add_argument(\n        \"--bone-parametrization-path\",\n        type=str,\n        default=f\"./{MODELS_DIR}/smplx_parametrization/bone_to_param2.npy\",\n        help=\"Bone parametrization path\",\n    )\n    parser.add_argument(\n        \"--foot-inds-path\",\n        type=str,\n        default=f\"./{MODELS_DIR}/smplx_parametrization/foot_inds.npy\",\n        help=\"Foot indices\",\n    )\n\n    parser.add_argument(\n        \"--save-path\",\n        type=str,\n        required=True,\n        help=\"Path to save the results\",\n    )\n\n    parser.add_argument(\n        \"--img-path\",\n        type=str,\n        required=True,\n        help=\"Path to img to test\",\n    )\n\n    parser.add_argument(\n        \"--use-contacts\",\n        action=\"store_true\",\n        help=\"Use contact model\",\n    )\n    parser.add_argument(\n        \"--use-msc\",\n        action=\"store_true\",\n        help=\"Use MSC loss\",\n    )\n    parser.add_argument(\n        \"--use-natural\",\n        action=\"store_true\",\n        
help=\"Use regularity\",\n    )\n    parser.add_argument(\n        \"--use-cos\",\n        action=\"store_true\",\n        help=\"Use cos model\",\n    )\n    parser.add_argument(\n        \"--use-angle-transf\",\n        action=\"store_true\",\n        help=\"Use cube foreshortening transformation\",\n    )\n\n    parser.add_argument(\n        \"--c-mse\",\n        type=float,\n        default=0,\n        help=\"MSE weight\",\n    )\n    parser.add_argument(\n        \"--c-par\",\n        type=float,\n        default=10,\n        help=\"Parallel weight\",\n    )\n\n    parser.add_argument(\n        \"--c-f\",\n        type=float,\n        default=1000,\n        help=\"Cos coef\",\n    )\n    parser.add_argument(\n        \"--c-parallel\",\n        type=float,\n        default=100,\n        help=\"Parallel weight\",\n    )\n    parser.add_argument(\n        \"--c-reg\",\n        type=float,\n        default=1000,\n        help=\"Regularity weight\",\n    )\n    parser.add_argument(\n        \"--c-cont2d\",\n        type=float,\n        default=1,\n        help=\"Contact 2D weight\",\n    )\n    parser.add_argument(\n        \"--c-msc\",\n        type=float,\n        default=17_500,\n        help=\"MSC weight\",\n    )\n\n    parser.add_argument(\n        \"--fist\",\n        nargs=\"+\",\n        type=str,\n        choices=list(fist_pose.INT_TO_FIST),\n    )\n\n    args = parser.parse_args()\n\n    return args\n\n\ndef freeze_layers(model):\n    for module in model.modules():\n        if type(module) is False:\n            continue\n\n        if isinstance(module, nn.modules.batchnorm._BatchNorm):\n            module.eval()\n            for m in module.parameters():\n                m.requires_grad = False\n\n        if isinstance(module, nn.Dropout):\n            module.eval()\n            for m in module.parameters():\n                m.requires_grad = False\n\n\ndef project_and_normalize_to_spin(vertices_3d, camera):\n    vertices_2d = vertices_3d  # [:, :2]\n\n 
   scale, translate = camera[0], camera[1:]\n    translate = scale.new_zeros(3)\n    translate[:2] = camera[1:]\n\n    vertices_2d = vertices_2d + translate\n    vertices_2d = scale * vertices_2d + 1\n    vertices_2d = spin.constants.IMG_RES / 2 * vertices_2d\n\n    return vertices_2d\n\n\ndef project_and_normalize_to_spin_legs(vertices_3d, A, camera):\n    A, J = A\n    A = A[0]\n    J = J[0]\n    L = vertices_3d.new_tensor(\n        [\n            [0.98619063, 0.16560926, 0.00127302],\n            [-0.16560601, 0.98603675, 0.01749799],\n            [0.00164258, -0.01746717, 0.99984609],\n        ]\n    )\n    R = vertices_3d.new_tensor(\n        [\n            [0.9910211, -0.13368178, -0.0025208],\n            [0.13367888, 0.99027076, 0.03864949],\n            [-0.00267045, -0.03863944, 0.99924965],\n        ]\n    )\n    scale = camera[0]\n    R = A[2, :3, :3] @ R  # 2 - right\n    L = A[1, :3, :3] @ L  # 1 - left\n    r = J[5] - J[2]\n    l = J[4] - J[1]\n\n    rleg = scale * spin.constants.IMG_RES / 2 * R @ r\n    lleg = scale * spin.constants.IMG_RES / 2 * L @ l\n\n    rleg = rleg[:2]\n    lleg = lleg[:2]\n\n    return rleg, lleg\n\n\ndef rotation_matrix_to_angle_axis(rotmat):\n    bs, n_joints, *_ = rotmat.size()\n    rotmat = torch.cat(\n        [\n            rotmat.view(-1, 3, 3),\n            rotmat.new_tensor([0, 0, 1], dtype=torch.float32)\n            .view(bs, 3, 1)\n            .expand(n_joints, -1, -1),\n        ],\n        dim=-1,\n    )\n    aa = torchgeometry.rotation_matrix_to_angle_axis(rotmat)\n    aa = aa.reshape(bs, 3 * n_joints)\n\n    return aa\n\n\ndef get_smpl_output(smpl, rotmat, betas, use_betas=True, zero_hands=False):\n    if smpl.name() == \"SMPL\":\n        smpl_output = smpl(\n            betas=betas if use_betas else None,\n            body_pose=rotmat[:, 1:],\n            global_orient=rotmat[:, 0].unsqueeze(1),\n            pose2rot=False,\n        )\n    elif smpl.name() == \"SMPL-X\":\n        rotmat = 
rotation_matrix_to_angle_axis(rotmat)\n        if zero_hands:\n            for i in [20, 21]:\n                rotmat[:, 3 * i : 3 * (i + 1)] = 0\n\n            for i in [12, 15]:  # neck, head\n                rotmat[:, 3 * i + 1] = 0  # y\n        smpl_output = smpl(\n            betas=betas if use_betas else None,\n            body_pose=rotmat[:, 3:],\n            global_orient=rotmat[:, :3],\n            pose2rot=True,\n        )\n    else:\n        raise NotImplementedError\n\n    return smpl_output, rotmat\n\n\ndef get_predictions(model_hmr, smpl, input_img, use_betas=True, zero_hands=False):\n    input_img = input_img.unsqueeze(0)\n    rotmat, betas, camera = model_hmr(input_img)\n\n    smpl_output, rotmat = get_smpl_output(\n        smpl, rotmat, betas, use_betas=use_betas, zero_hands=zero_hands\n    )\n\n    rotmat = rotmat.squeeze(0)\n    betas = betas.squeeze(0)\n    camera = camera.squeeze(0)\n    z = smpl_output.joints\n    z = z.squeeze(0)\n\n    return rotmat, betas, camera, smpl_output, z\n\n\ndef get_pred_and_data(\n    model_hmr, smpl, selector, input_img, use_betas=True, zero_hands=False\n):\n    rotmat, betas, camera, smpl_output, zz = get_predictions(\n        model_hmr, smpl, input_img, use_betas=use_betas, zero_hands=zero_hands\n    )\n\n    joints = smpl_output.joints.squeeze(0)\n    joints_2d = project_and_normalize_to_spin(joints, camera)\n    rleg, lleg = project_and_normalize_to_spin_legs(joints, smpl_output.A, camera)\n    joints_2d_orig = joints_2d\n    joints_2d = joints_2d[selector]\n\n    vertices = smpl_output.vertices.squeeze(0)\n    vertices_2d = project_and_normalize_to_spin(vertices, camera)\n\n    zz = zz[selector]\n\n    return (\n        rotmat,\n        betas,\n        camera,\n        joints_2d,\n        zz,\n        vertices_2d,\n        smpl_output,\n        (rleg, lleg),\n        joints_2d_orig,\n    )\n\n\ndef normalize_keypoints_to_spin(keypoints_2d, img_size):\n    h, w = img_size\n    if h > w:  # vertically\n       
 ax1 = 1\n        ax2 = 0\n    else:  # horizontal\n        ax1 = 0\n        ax2 = 1\n\n    shift = (img_size[ax1] - img_size[ax2]) / 2\n    scale = spin.constants.IMG_RES / img_size[ax2]\n    keypoints_2d_normalized = np.copy(keypoints_2d)\n    keypoints_2d_normalized[:, ax2] -= shift\n    keypoints_2d_normalized *= scale\n\n    return keypoints_2d_normalized, shift, scale, ax2\n\n\ndef unnormalize_keypoints_from_spin(keypoints_2d, shift, scale, ax2):\n    keypoints_2d_normalized = np.copy(keypoints_2d)\n    keypoints_2d_normalized /= scale\n    keypoints_2d_normalized[:, ax2] += shift\n\n    return keypoints_2d_normalized\n\n\ndef get_vertices_in_heatmap(contact_heatmap):\n    contact_heatmap_size = contact_heatmap.shape[:2]\n    label = measure.label(contact_heatmap)\n\n    y_data_conts = []\n    for i in range(1, label.max() + 1):\n        predicted_kps_contact = np.vstack(np.nonzero(label == i)[::-1]).T.astype(\n            \"float\"\n        )\n        predicted_kps_contact_scaled, *_ = normalize_keypoints_to_spin(\n            predicted_kps_contact, contact_heatmap_size\n        )\n        y_data_cont = torch.from_numpy(predicted_kps_contact_scaled).int().tolist()\n        y_data_cont = shapely.geometry.MultiPoint(y_data_cont).convex_hull\n        y_data_conts.append(y_data_cont)\n\n    return y_data_conts\n\n\ndef get_contact_heatmap(model_contact, img_path, thresh=0.5):\n    contact_heatmap = pose_estimation.infer_single_image(\n        model_contact,\n        img_path,\n        input_img_size=(192, 256),\n        return_kps=False,\n    )\n    contact_heatmap = contact_heatmap.squeeze(0)\n    contact_heatmap_orig = contact_heatmap.copy()\n\n    mi = contact_heatmap.min()\n    ma = contact_heatmap.max()\n    contact_heatmap = (contact_heatmap - mi) / (ma - mi)\n    contact_heatmap_ = ((contact_heatmap > thresh) * 255).astype(\"uint8\")\n\n    contact_heatmap = np.repeat(contact_heatmap[..., None], repeats=3, axis=-1)\n    contact_heatmap = (contact_heatmap 
* 255).astype(\"uint8\")\n\n    return contact_heatmap_, contact_heatmap, contact_heatmap_orig\n\n\ndef discretize(parametrization, n_bins=100):\n    bins = np.linspace(0, 1, n_bins + 1)\n    inds = np.digitize(parametrization, bins)\n    disc_parametrization = bins[inds - 1]\n\n    return disc_parametrization\n\n\ndef get_mapping_from_params_to_verts(verts, params):\n    mapping = {}\n    for v, t in zip(verts, params):\n        mapping.setdefault(t, []).append(v)\n\n    return mapping\n\n\ndef find_contacts(y_data_conts, keypoints_2d, bone_to_params, thresh=12, step=0.0072246375):\n    n_bins = int(math.ceil(1 / step)) - 1  # mean face's circumradius\n    contact = []\n    contact_2d = []\n    for_mask = []\n    for y_data_cont in y_data_conts:\n        contact_loc = []\n        contact_2d_loc = []\n        buffer = y_data_cont.buffer(thresh)\n        mask_add = False\n        for i, j in pose_estimation.SKELETON:\n            verts, t3d = bone_to_params[(i, j)]\n            if len(verts) == 0:\n                continue\n\n            t3d = discretize(t3d, n_bins=n_bins)\n            t3d_to_verts = get_mapping_from_params_to_verts(verts, t3d)\n            t3d_to_verts_sorted = sorted(t3d_to_verts.items(), key=lambda x: x[0])\n            t3d_sorted_np = np.array([x for x, _ in t3d_to_verts_sorted])\n\n            line = shapely.geometry.LineString([keypoints_2d[i], keypoints_2d[j]])\n            lint = buffer.intersection(line)\n            if len(lint.boundary.geoms) < 2:\n                continue\n\n            t2d_start = line.project(lint.boundary.geoms[0], normalized=True)\n            t2d_end = line.project(lint.boundary.geoms[1], normalized=True)\n            assert t2d_start <= t2d_end\n\n            t2ds = discretize(\n                np.linspace(t2d_start, t2d_end, n_bins + 1), n_bins=n_bins\n            )\n            to_add = False\n            for t2d in t2ds:\n                if t2d < t3d_sorted_np[0] or t2d > t3d_sorted_np[-1]:\n                    
continue\n\n                t2d_ind = np.searchsorted(t3d_sorted_np, t2d)\n                c = t3d_to_verts_sorted[t2d_ind][1]\n\n                contact_loc.extend(c)\n                to_add = True\n                mask_add = True\n\n                if t2d_ind + 1 < len(t3d_to_verts_sorted):\n                    c = t3d_to_verts_sorted[t2d_ind + 1][1]\n                    contact_loc.extend(c)\n\n                if t2d_ind > 0:\n                    c = t3d_to_verts_sorted[t2d_ind - 1][1]\n                    contact_loc.extend(c)\n\n            if to_add:\n                contact_2d_loc.append((i, j, t2d_start + 0.5 * (t2d_end - t2d_start)))\n\n        if mask_add:\n            for_mask.append(buffer.exterior.coords.xy)\n\n        contact_loc = sorted(set(contact_loc))\n        contact_loc = np.array(contact_loc, dtype=\"int\")\n        contact.append(contact_loc)\n        contact_2d.append(contact_2d_loc)\n\n    for_mask = [np.stack((x, y), axis=0).T[:, None].astype(\"int\") for x, y in for_mask]\n\n    return contact, contact_2d, for_mask\n\n\ndef optimize(\n    model_hmr,\n    smpl,\n    selector,\n    input_img,\n    keypoints_2d,\n    optimizer,\n    args,\n    loss_mse=None,\n    loss_parallel=None,\n    c_mse=0.0,\n    c_new_mse=1.0,\n    c_beta=1e-3,\n    sc_crit=None,\n    msc_crit=None,\n    contact=None,\n    n_steps=60,\n    save_path=None,\n    writer=None,\n    i_ini=0,\n):\n    to_save = False\n    if save_path is not None:\n        (\n            img_original,\n            predicted_keypoints_2d,\n            save_path,\n            shift,\n            scale,\n            ax2,\n            prefix,\n        ) = save_path\n        to_save = True\n\n    mean_zfoot_val = {}\n    with tqdm.trange(n_steps) as pbar:\n        for i in pbar:\n            global_step = i + i_ini\n            optimizer.zero_grad()\n\n            (\n                rotmat_pred,\n                betas_pred,\n                camera_pred,\n                keypoints_3d_pred,\n     
           z,\n                vertices_2d_pred,\n                smpl_output,\n                (rleg, lleg),\n                joints_2d_orig,\n            ) = get_pred_and_data(\n                model_hmr,\n                smpl,\n                selector,\n                input_img,\n            )\n            keypoints_2d_pred = keypoints_3d_pred[:, :2]\n            if to_save:\n                utils.save_results_image(\n                    camera=camera_pred.detach().cpu().numpy(),\n                    focal_length_x=spin.constants.FOCAL_LENGTH,\n                    focal_length_y=spin.constants.FOCAL_LENGTH,\n                    vertices=smpl_output.vertices.detach()[0].cpu().numpy(),\n                    input_img=img_original,\n                    faces=smpl.faces,\n                    keypoints=predicted_keypoints_2d,\n                    keypoints_2=unnormalize_keypoints_from_spin(\n                        keypoints_2d_pred.detach().cpu().numpy(), shift, scale, ax2\n                    ),\n                    # keypoints_2=unnormalize_keypoints_from_spin(joints_2d_orig.detach().cpu().numpy(), shift, scale, ax2),\n                    # heatmap=predicted_contact_heatmap_raw,\n                    filename=save_path / f\"{prefix}_{i:0>4}.png\",\n                    contactlist=contact,\n                    user_study=False,\n                )\n\n            loss = l2 = 0.0\n            if c_mse > 0 and loss_mse is not None:\n                l2 = loss_mse(keypoints_2d_pred, keypoints_2d)\n                loss = loss + c_mse * l2\n\n                if writer is not None:\n                    writer.add_scalar(\"mse\", l2, global_step=global_step)\n\n            vertices_pred = smpl_output.vertices\n\n            lpar = z_loss = loss_sh = 0.0\n            if c_new_mse > 0 and loss_parallel is not None:\n                Ltan, Lcos, Lpar, Lspine, Lgr, Lstraight3d, Lcon2d = loss_parallel(\n                    keypoints_3d_pred,\n                    keypoints_2d,\n    
                z,\n                    (rleg, lleg),\n                    writer=writer,\n                    global_step=global_step,\n                )\n                lpar = (\n                    Ltan\n                    + c_new_mse * (args.c_f * Lcos + args.c_parallel * Lpar)\n                    + Lspine\n                    + args.c_reg * Lgr\n                    + args.c_reg * Lstraight3d\n                    + args.c_cont2d * Lcon2d\n                )\n                loss = loss + 300 * lpar\n\n                if writer is not None:\n                    writer.add_scalar(\"tan\", Ltan, global_step=global_step)\n                    writer.add_scalar(\"cos\", Lcos, global_step=global_step)\n                    writer.add_scalar(\"par\", Lpar, global_step=global_step)\n                    writer.add_scalar(\"spine\", Lspine, global_step=global_step)\n                    writer.add_scalar(\"ground/chain\", Lgr, global_step=global_step)\n                    writer.add_scalar(\n                        \"straight_in_3d\", Lstraight3d, global_step=global_step\n                    )\n                    writer.add_scalar(\"contact/con2d\", Lcon2d, global_step=global_step)\n\n                for side in [\"left\", \"right\"]:\n                    attr = f\"{side}_foot_inds\"\n                    if hasattr(loss_parallel, attr):\n                        foot_inds = getattr(loss_parallel, attr)\n                        zind = 1\n                        if attr not in mean_zfoot_val:\n                            with torch.no_grad():\n                                mean_zfoot_val[attr] = torch.median(\n                                    vertices_pred[0, foot_inds, zind], dim=0\n                                ).values\n\n                        loss_foot = (\n                            (vertices_pred[0, foot_inds, zind] - mean_zfoot_val[attr])\n                            ** 2\n                        ).sum()\n                        loss = loss + args.c_reg * 
loss_foot\n\n                        if writer is not None:\n                            writer.add_scalar(\n                                f\"ground/{side} foot\",\n                                loss_foot,\n                                global_step=global_step,\n                            )\n\n                if hasattr(loss_parallel, \"silhuette_vertices_inds\"):\n                    inds = loss_parallel.silhuette_vertices_inds\n                    loss_sh = (\n                        (vertices_pred[0, inds, 1] - loss_parallel.ground) ** 2\n                    ).sum()\n                    loss = loss + args.c_reg * loss_sh\n\n                    if writer is not None:\n                        writer.add_scalar(\n                            \"ground/silhuette\", loss_sh, global_step=global_step\n                        )\n\n            lbeta = (betas_pred**2).mean()\n            lcam = ((torch.exp(-camera_pred[0] * 10)) ** 2).mean()\n            loss = loss + c_beta * lbeta + lcam\n\n            if writer is not None:\n                writer.add_scalar(\"loss/beta\", lbeta, global_step=global_step)\n                writer.add_scalar(\"loss/cam\", lcam, global_step=global_step)\n\n            lgsc_a = gsc_contact_loss = faces_angle_loss = 0.0\n            if sc_crit is not None:\n                gsc_contact_loss, faces_angle_loss = sc_crit(\n                    vertices_pred,\n                )\n                lgsc_a = 1000 * gsc_contact_loss + 0.1 * faces_angle_loss\n                loss = loss + lgsc_a\n\n                if writer is not None:\n                    writer.add_scalar(\n                        \"contact/gsc\", gsc_contact_loss, global_step=global_step\n                    )\n                    writer.add_scalar(\n                        \"contact/faces_angle\", faces_angle_loss, global_step=global_step\n                    )\n\n            msc_loss = 0.0\n            if contact is not None and len(contact) > 0 and msc_crit is not None:\n     
           if not isinstance(contact, list):\n                    contact = [contact]\n\n                for cntct in contact:\n                    msc_loss = msc_crit(\n                        cntct,\n                        vertices_pred,\n                    )\n                    loss = loss + args.c_msc * msc_loss\n\n                    if writer is not None:\n                        writer.add_scalar(\n                            \"contact/msc\", msc_loss, global_step=global_step\n                        )\n\n            loss.backward()\n            optimizer.step()\n\n            epoch_loss = loss.item()\n            pbar.set_postfix(\n                **{\n                    \"l\": f\"{epoch_loss:.3}\",\n                    \"l2\": f\"{l2:.3}\",\n                    \"par\": f\"{lpar:.3}\",\n                    \"beta\": f\"{lbeta:.3}\",\n                    \"cam\": f\"{lcam:.3}\",\n                    \"z\": f\"{z_loss:.3}\",\n                    \"gsc_contact\": f\"{float(gsc_contact_loss):.3}\",\n                    \"faces_angle\": f\"{float(faces_angle_loss):.3}\",\n                    \"msc\": f\"{float(msc_loss):.3}\",\n                }\n            )\n\n    with torch.no_grad():\n        (\n            rotmat_pred,\n            betas_pred,\n            camera_pred,\n            keypoints_3d_pred,\n            z,\n            vertices_2d_pred,\n            smpl_output,\n            (rleg, lleg),\n            joints_2d_orig,\n        ) = get_pred_and_data(\n            model_hmr,\n            smpl,\n            selector,\n            input_img,\n            zero_hands=True,\n        )\n\n    return (\n        rotmat_pred,\n        betas_pred,\n        camera_pred,\n        keypoints_3d_pred,\n        vertices_2d_pred,\n        smpl_output,\n        z,\n        joints_2d_orig,\n    )\n\n\ndef optimize_ft(\n    theta,\n    camera,\n    smpl,\n    selector,\n    input_img,\n    keypoints_2d,\n    args,\n    loss_mse=None,\n    loss_parallel=None,\n    
c_mse=0.0,\n    c_new_mse=1.0,\n    sc_crit=None,\n    msc_crit=None,\n    contact=None,\n    n_steps=60,\n    save_path=None,\n    writer=None,\n    i_ini=0,\n    zero_hands=False,\n    fist=None,\n):\n    to_save = False\n    if save_path is not None:\n        (\n            img_original,\n            predicted_keypoints_2d,\n            save_path,\n            shift,\n            scale,\n            ax2,\n            prefix,\n        ) = save_path\n        to_save = True\n\n    mean_zfoot_val = {}\n\n    theta = theta.detach().clone()\n    camera = camera.detach().clone()\n    rotmat_pred = nn.Parameter(theta)\n    camera_pred = nn.Parameter(camera)\n    optimizer = torch.optim.Adam(\n        [\n            rotmat_pred,\n            camera_pred,\n        ],\n        lr=1e-3,\n    )\n    global_step = i_ini\n\n    with tqdm.trange(n_steps) as pbar:\n        for i in pbar:\n            global_step = i + i_ini\n            optimizer.zero_grad()\n\n            global_orient = rotmat_pred[:3]\n            body_pose = rotmat_pred[3:]\n            smpl_output = smpl(\n                global_orient=global_orient.unsqueeze(0),\n                body_pose=body_pose.unsqueeze(0),\n                pose2rot=True,\n            )\n\n            z = smpl_output.joints\n            z = z.squeeze(0)\n\n            joints = smpl_output.joints.squeeze(0)\n            joints_2d = project_and_normalize_to_spin(joints, camera_pred)\n            rleg, lleg = project_and_normalize_to_spin_legs(\n                joints, smpl_output.A, camera_pred\n            )\n            joints_2d = joints_2d[selector]\n            z = z[selector]\n            keypoints_3d_pred = joints_2d\n\n            keypoints_2d_pred = keypoints_3d_pred[:, :2]\n            if to_save:\n                utils.save_results_image(\n                    camera=camera_pred.detach().cpu().numpy(),\n                    focal_length_x=spin.constants.FOCAL_LENGTH,\n                    
focal_length_y=spin.constants.FOCAL_LENGTH,\n                    vertices=smpl_output.vertices.detach()[0].cpu().numpy(),\n                    input_img=img_original,\n                    faces=smpl.faces,\n                    keypoints=predicted_keypoints_2d,\n                    keypoints_2=unnormalize_keypoints_from_spin(\n                        keypoints_2d_pred.detach().cpu().numpy(), shift, scale, ax2\n                    ),\n                    # keypoints_2=unnormalize_keypoints_from_spin(joints_2d_orig.detach().cpu().numpy(), shift, scale, ax2),\n                    # heatmap=predicted_contact_heatmap_raw,\n                    filename=save_path / f\"{prefix}_{i:0>4}.png\",\n                    contactlist=contact,\n                    user_study=False,\n                )\n\n            lprior = ((rotmat_pred - theta) ** 2).sum() + (\n                (camera_pred - camera) ** 2\n            ).sum()\n            loss = lprior\n\n            l2 = 0.0\n            if c_mse > 0 and loss_mse is not None:\n                l2 = loss_mse(keypoints_2d_pred, keypoints_2d)\n                loss = loss + c_mse * l2\n\n                if writer is not None:\n                    writer.add_scalar(\"mse\", l2, global_step=global_step)\n\n            vertices_pred = smpl_output.vertices\n\n            lpar = z_loss = loss_sh = 0.0\n            if c_new_mse > 0 and loss_parallel is not None:\n                Ltan, Lcos, Lpar, Lspine, Lgr, Lstraight3d, Lcon2d = loss_parallel(\n                    keypoints_3d_pred,\n                    keypoints_2d,\n                    z,\n                    (rleg, lleg),\n                    writer=writer,\n                    global_step=global_step,\n                )\n                lpar = (\n                    Ltan\n                    + c_new_mse * (args.c_f * Lcos + args.c_parallel * Lpar)\n                    + Lspine\n                    + args.c_reg * Lgr\n                    + args.c_reg * Lstraight3d\n                    + 
args.c_cont2d * Lcon2d\n                )\n                loss = loss + 300 * lpar\n\n                if writer is not None:\n                    writer.add_scalar(\"tan\", Ltan, global_step=global_step)\n                    writer.add_scalar(\"cos\", Lcos, global_step=global_step)\n                    writer.add_scalar(\"par\", Lpar, global_step=global_step)\n                    writer.add_scalar(\"spine\", Lspine, global_step=global_step)\n                    writer.add_scalar(\"ground/chain\", Lgr, global_step=global_step)\n                    writer.add_scalar(\n                        \"straight_in_3d\", Lstraight3d, global_step=global_step\n                    )\n                    writer.add_scalar(\"contact/con2d\", Lcon2d, global_step=global_step)\n\n                for side in [\"left\", \"right\"]:\n                    attr = f\"{side}_foot_inds\"\n                    if hasattr(loss_parallel, attr):\n                        foot_inds = getattr(loss_parallel, attr)\n                        zind = 1\n                        if attr not in mean_zfoot_val:\n                            with torch.no_grad():\n                                mean_zfoot_val[attr] = torch.median(\n                                    vertices_pred[0, foot_inds, zind], dim=0\n                                ).values\n\n                        loss_foot = (\n                            (vertices_pred[0, foot_inds, zind] - mean_zfoot_val[attr])\n                            ** 2\n                        ).sum()\n                        loss = loss + args.c_reg * loss_foot\n\n                        if writer is not None:\n                            writer.add_scalar(\n                                f\"ground/{side} foot\",\n                                loss_foot,\n                                global_step=global_step,\n                            )\n\n                if hasattr(loss_parallel, \"silhuette_vertices_inds\"):\n                    inds = 
loss_parallel.silhuette_vertices_inds\n                    loss_sh = (\n                        (vertices_pred[0, inds, 1] - loss_parallel.ground) ** 2\n                    ).sum()\n                    loss = loss + args.c_reg * loss_sh\n\n                    if writer is not None:\n                        writer.add_scalar(\n                            \"ground/silhuette\", loss_sh, global_step=global_step\n                        )\n\n            lgsc_a = gsc_contact_loss = faces_angle_loss = 0.0\n            if sc_crit is not None:\n                gsc_contact_loss, faces_angle_loss = sc_crit(vertices_pred)\n                lgsc_a = 1000 * gsc_contact_loss + 0.1 * faces_angle_loss\n                loss = loss + lgsc_a\n\n                if writer is not None:\n                    writer.add_scalar(\n                        \"contact/gsc\", gsc_contact_loss, global_step=global_step\n                    )\n                    writer.add_scalar(\n                        \"contact/faces_angle\", faces_angle_loss, global_step=global_step\n                    )\n\n            msc_loss = 0.0\n            if contact is not None and len(contact) > 0 and msc_crit is not None:\n                if not isinstance(contact, list):\n                    contact = [contact]\n\n                for cntct in contact:\n                    msc_loss = msc_crit(\n                        cntct,\n                        vertices_pred,\n                    )\n                    loss = loss + args.c_msc * msc_loss\n\n                    if writer is not None:\n                        writer.add_scalar(\n                            \"contact/msc\", msc_loss, global_step=global_step\n                        )\n\n            loss.backward()\n            optimizer.step()\n\n            epoch_loss = loss.item()\n            pbar.set_postfix(\n                **{\n                    \"l\": f\"{epoch_loss:.3}\",\n                    \"l2\": f\"{l2:.3}\",\n                    \"par\": 
f\"{lpar:.3}\",\n                    \"z\": f\"{z_loss:.3}\",\n                    \"gsc_contact\": f\"{float(gsc_contact_loss):.3}\",\n                    \"faces_angle\": f\"{float(faces_angle_loss):.3}\",\n                    \"msc\": f\"{float(msc_loss):.3}\",\n                }\n            )\n\n    rotmat_pred = rotmat_pred.detach()\n\n    if zero_hands:\n        for i in [20, 21]:\n            rotmat_pred[3 * i : 3 * (i + 1)] = 0\n\n        for i in [12, 15]:  # neck, head\n            rotmat_pred[3 * i + 1] = 0  # y\n\n    global_orient = rotmat_pred[:3]\n    body_pose = rotmat_pred[3:]\n    left_hand_pose = None\n    right_hand_pose = None\n    if fist is not None:\n        left_hand_pose = rotmat_pred.new_tensor(fist_pose.LEFT_RELAXED).unsqueeze(0)\n        right_hand_pose = rotmat_pred.new_tensor(fist_pose.RIGHT_RELAXED).unsqueeze(0)\n        for f in fist:\n            pp = fist_pose.INT_TO_FIST[f]\n            if pp is not None:\n                pp = rotmat_pred.new_tensor(pp).unsqueeze(0)\n\n            if f.startswith(\"lf\"):\n                left_hand_pose = pp\n            elif f.startswith(\"rf\"):\n                right_hand_pose = pp\n            elif f.startswith(\"l\"):\n                body_pose[19 * 3 : 19 * 3 + 3] = pp\n                left_hand_pose = None\n            elif f.startswith(\"r\"):\n                body_pose[20 * 3 : 20 * 3 + 3] = pp\n                right_hand_pose = None\n            else:\n                raise RuntimeError(f\"No such hand pose: {f}\")\n\n    with torch.no_grad():\n        smpl_output = smpl(\n            global_orient=global_orient.unsqueeze(0),\n            body_pose=body_pose.unsqueeze(0),\n            left_hand_pose=left_hand_pose,\n            right_hand_pose=right_hand_pose,\n            pose2rot=True,\n        )\n\n    return rotmat_pred, smpl_output\n\n\ndef create_bone(i, j, keypoints_2d):\n    a = keypoints_2d[i]\n    b = keypoints_2d[j]\n    ab = b - a\n    ab = torch.nn.functional.normalize(ab, 
dim=0)\n\n    return ab\n\n\ndef is_parallel_to_plane(bone, thresh=21):\n    return abs(bone[0]) > math.cos(math.radians(thresh))\n\n\ndef is_close_to_plane(bone, plane, thresh):\n    dist = abs(bone[0] - plane)\n\n    return dist < thresh\n\n\ndef get_selector():\n    selector = []\n    for kp in pose_estimation.KPS:\n        tmp = spin.JOINT_NAMES.index(PE_KSP_TO_SPIN[kp])\n        selector.append(tmp)\n\n    return selector\n\n\ndef calc_cos(joints_2d, joints_3d):\n    cos = []\n    for i, j in pose_estimation.SKELETON:\n        a = joints_2d[i] - joints_2d[j]\n        a = nn.functional.normalize(a, dim=0)\n\n        b = joints_3d[i] - joints_3d[j]\n        b = nn.functional.normalize(b, dim=0)[:2]\n\n        c = (a * b).sum()\n        cos.append(c)\n\n    cos = torch.stack(cos, dim=0)\n\n    return cos\n\n\ndef get_natural(keypoints_2d, vertices, right_foot_inds, left_foot_inds, loss_parallel, smpl):\n    height_2d = (\n        keypoints_2d.max(dim=0).values[0] - keypoints_2d.min(dim=0).values[0]\n    ).item()\n    plane_2d = keypoints_2d.max(dim=0).values[0].item()\n\n    ground_parallel = []\n    parallel_in_3d = []\n    parallel3d_bones = set()\n\n    # parallel chains\n    for i, j, k in [\n        (\"Right Upper Leg\", \"Right Leg\", \"Right Foot\"),\n        (\"Right Leg\", \"Right Foot\", \"Right Toe\"),  # to remove?\n        (\"Left Upper Leg\", \"Left Leg\", \"Left Foot\"),\n        (\"Left Leg\", \"Left Foot\", \"Left Toe\"),  # to remove?\n        (\"Right Shoulder\", \"Right Arm\", \"Right Hand\"),\n        (\"Left Shoulder\", \"Left Arm\", \"Left Hand\"),\n        # (\"Hips\", \"Spine\", \"Neck\"),\n        # (\"Spine\", \"Neck\", \"Head\"),\n    ]:\n        i = pose_estimation.KPS.index(i)\n        j = pose_estimation.KPS.index(j)\n        k = pose_estimation.KPS.index(k)\n        upleg_leg = create_bone(i, j, keypoints_2d)\n        leg_foot = create_bone(j, k, keypoints_2d)\n\n        if is_parallel_to_plane(upleg_leg) and 
is_parallel_to_plane(leg_foot):\n            if is_close_to_plane(\n                upleg_leg, plane_2d, thresh=0.1 * height_2d\n            ) or is_close_to_plane(leg_foot, plane_2d, thresh=0.1 * height_2d):\n                ground_parallel.append(((i, j), 1))\n                ground_parallel.append(((j, k), 1))\n\n        if (upleg_leg * leg_foot).sum() > math.cos(math.radians(21)):\n            parallel_in_3d.append(((i, j), (j, k)))\n            parallel3d_bones.add((i, j))\n            parallel3d_bones.add((j, k))\n\n    # parallel feets\n    for i, j in [\n        (\"Right Foot\", \"Right Toe\"),\n        (\"Left Foot\", \"Left Toe\"),\n    ]:\n        i = pose_estimation.KPS.index(i)\n        j = pose_estimation.KPS.index(j)\n        if (i, j) in parallel3d_bones:\n            continue\n\n        foot_toe = create_bone(i, j, keypoints_2d)\n        if is_parallel_to_plane(foot_toe, thresh=25):\n            if \"Right\" in pose_estimation.KPS[i]:\n                loss_parallel.right_foot_inds = right_foot_inds\n            else:\n                loss_parallel.left_foot_inds = left_foot_inds\n\n    loss_parallel.ground_parallel = ground_parallel\n    loss_parallel.parallel_in_3d = parallel_in_3d\n\n    vertices_np = vertices[0].cpu().numpy()\n    if len(ground_parallel) > 0:\n        # Silhuette veritices\n        mesh = trimesh.Trimesh(vertices=vertices_np, faces=smpl.faces, process=False)\n        silhuette_vertices_mask_1 = np.abs(mesh.vertex_normals[..., 2]) < 2e-1\n        height_3d = vertices_np[:, 1].max() - vertices_np[:, 1].min()\n        plane_3d = vertices_np[:, 1].max()\n        silhuette_vertices_mask_2 = (\n            np.abs(vertices_np[:, 1] - plane_3d) < 0.15 * height_3d\n        )\n        silhuette_vertices_mask = np.logical_and(\n            silhuette_vertices_mask_1, silhuette_vertices_mask_2\n        )\n        (silhuette_vertices_inds,) = np.where(silhuette_vertices_mask)\n        if len(silhuette_vertices_inds) > 0:\n            
loss_parallel.silhuette_vertices_inds = silhuette_vertices_inds\n            loss_parallel.ground = plane_3d\n\n\ndef get_cos(keypoints_3d_pred, use_angle_transf, loss_parallel):\n    keypoints_2d_pred = keypoints_3d_pred[:, :2]\n    with torch.no_grad():\n        cos_r = calc_cos(keypoints_2d_pred, keypoints_3d_pred)\n\n    alpha = torch.acos(cos_r)\n    if use_angle_transf:\n        leg_inds = [\n            5,\n            6,  # right leg\n            7,\n            8,  # left leg\n        ]\n        foot_inds = [15, 16]\n        nleg_inds = sorted(\n            set(range(len(pose_estimation.SKELETON))) - set(leg_inds) - set(foot_inds)\n        )\n        alpha[nleg_inds] = alpha[nleg_inds] - alpha[nleg_inds].min()\n\n        amli = alpha[leg_inds].min()\n        leg_inds.extend(foot_inds)\n        alpha[leg_inds] = alpha[leg_inds] - amli\n\n        angles = alpha.detach().cpu().numpy()\n        angles = hist_cub.cub(\n            angles / (math.pi / 2),\n            a=1.2121212121212122,\n            b=-1.105527638190953,\n            c=0.787878787878789,\n        ) * (math.pi / 2)\n        alpha = alpha.new_tensor(angles)\n\n    loss_parallel.cos = torch.cos(alpha)\n\n    return cos_r\n\n\ndef save_mesh_with_winding_numbers(sc_module, vertices, smpl, save_path):\n    triangles = sc_module.triangles(vertices)\n    exterior = sc_module.get_intersection_mask(vertices, triangles, test_segments=False)\n    exterior = exterior.cpu().numpy().squeeze(0)\n    utils.save_mesh_with_colors(\n        vertices[0].cpu().numpy(),\n        smpl.faces,\n        save_path / \"winding_numbers.ply\",\n        mask=exterior,\n    )\n\n    exterior = sc_module.get_intersection_mask(vertices, triangles)\n    exterior = exterior.cpu().numpy().squeeze(0)\n    utils.save_mesh_with_colors(\n        vertices[0].cpu().numpy(),\n        smpl.faces,\n        save_path / \"winding_numbers_filtered.ply\",\n        mask=exterior,\n    )\n\n\ndef get_contacts(\n    args,\n    sc_module,\n    
y_data_conts,\n    keypoints_2d,\n    vertices,\n    bone_to_params,\n    loss_parallel,\n    img_size_original,\n    save_path,\n):\n    use_contacts = args.use_contacts\n    use_msc = args.use_msc\n    c_mse = args.c_mse\n\n    if use_contacts:\n        assert c_mse == 0\n        contact, contact_2d, for_mask = find_contacts(\n            y_data_conts, keypoints_2d, bone_to_params\n        )\n        if len(contact_2d) > 0:\n            loss_parallel.contact_2d = contact_2d\n\n            mask = np.zeros((spin.constants.IMG_RES, spin.constants.IMG_RES), dtype=\"uint8\")\n            mask += 255\n            cv2.drawContours(mask, for_mask, -1, 0, 2)\n            mask = cv2.resize(mask, img_size_original[::-1])\n            cv2.imwrite(str(save_path / \"mask.png\"), mask)\n\n        if len(contact) == 0:\n            _, contact = sc_module.verts_in_contact(vertices, return_idx=True)\n            contact = contact.cpu().numpy().ravel()\n    elif use_msc:\n        _, contact = sc_module.verts_in_contact(vertices, return_idx=True)\n        contact = contact.cpu().numpy().ravel()\n    else:\n        contact = np.array([])\n\n    return contact\n\n\ndef save_all(\n    keypoints_3d_pred,\n    rotmat_pred,\n    camera_pred,\n    betas_pred,\n    smpl,\n    contact,\n    img_original,\n    predicted_keypoints_2d,\n    predicted_contact_heatmap_raw,\n    loss_parallel,\n    smpl_output,\n    shift,\n    scale,\n    ax2,\n    summary_writer,\n    save_path,\n    fname,\n):\n    keypoints_2d_pred = keypoints_3d_pred[:, :2]\n\n    vertices = smpl_output.vertices.detach()\n    betas_pred = betas_pred.detach().cpu().numpy()\n\n    utils.save_pose_params(\n        rotmat_pred,\n        camera_pred,\n        betas_pred,\n        vertices,\n        smpl,\n        contact,\n        save_path / f\"{fname}.pkl\",\n    )\n\n    if hasattr(loss_parallel, \"silhuette_vertices_inds\"):\n        contact.append(loss_parallel.silhuette_vertices_inds)\n\n    img_sw = 
utils.save_results_image(\n        camera=camera_pred.detach().cpu().numpy(),\n        focal_length_x=spin.constants.FOCAL_LENGTH,\n        focal_length_y=spin.constants.FOCAL_LENGTH,\n        vertices=vertices[0].cpu().numpy(),\n        input_img=img_original,\n        faces=smpl.faces,\n        keypoints=predicted_keypoints_2d,\n        keypoints_2=unnormalize_keypoints_from_spin(\n            keypoints_2d_pred.cpu().numpy(), shift, scale, ax2\n        )\n        if shift is not None\n        else None,\n        # keypoints_2=unnormalize_keypoints_from_spin(joints_2d_orig.detach().cpu().numpy(), shift, scale, ax2) if shift is not None else None,\n        heatmap=predicted_contact_heatmap_raw,\n        filename=save_path / f\"{fname}.png\",\n        contactlist=contact,\n        contact2dlist=loss_parallel.contact_2d\n        if hasattr(loss_parallel, \"contact_2d\")\n        else None,\n        cos=loss_parallel.cos.tolist() if loss_parallel.cos is not None else None,\n    )\n\n    utils.save_mesh_with_colors(\n        smpl_output.vertices[0].cpu().numpy(),\n        smpl.faces,\n        save_path / f\"{fname}.ply\",\n        inds=contact,\n    )\n\n    joints = smpl_output.joints.squeeze(0).cpu().numpy()\n    fig = utils.plot_3D(joints, vertices.squeeze(0).cpu().numpy(), smpl.faces)\n    fig.write_html(save_path / f\"{fname}.html\")\n\n    summary_writer.add_image(\n        fname, np.array(img_sw).astype(\"float32\") / 255, dataformats=\"HWC\"\n    )\n    summary_writer.add_mesh(\n        fname,\n        vertices=(\n            vertices.cpu().float()[0]\n            @ torch.tensor([[1, 0, 0], [0, -1, 0], [0, 0, -1.0]])\n        ).unsqueeze(0),\n        faces=torch.from_numpy(smpl.faces[None].astype(\"int64\")),\n    )\n\n\ndef spin_step(\n    model_hmr,\n    smpl,\n    selector,\n    input_img,\n    img_original,\n    predicted_keypoints_2d,\n    predicted_contact_heatmap_raw,\n    loss_parallel,\n    shift,\n    scale,\n    ax2,\n    summary_writer,\n    
save_path,\n):\n    with torch.no_grad():\n        (\n            rotmat_pred,\n            betas_pred,\n            camera_pred,\n            keypoints_3d_pred,\n            _,\n            _,\n            smpl_output,\n            _,\n            _,\n        ) = get_pred_and_data(\n            model_hmr,\n            smpl,\n            selector,\n            input_img,\n            zero_hands=True,\n        )\n\n    save_all(\n        keypoints_3d_pred,\n        rotmat_pred,\n        camera_pred,\n        betas_pred,\n        smpl,\n        None,\n        img_original,\n        predicted_keypoints_2d,\n        predicted_contact_heatmap_raw,\n        loss_parallel,\n        smpl_output,\n        shift,\n        scale,\n        ax2,\n        summary_writer,\n        save_path,\n        \"spin\",\n    )\n\n\ndef eft_step(\n    model_hmr,\n    smpl,\n    selector,\n    input_img,\n    keypoints_2d,\n    optimizer,\n    args,\n    loss_mse,\n    loss_parallel,\n    c_beta,\n    sc_module,\n    y_data_conts,\n    bone_to_params,\n    img_original,\n    predicted_keypoints_2d,\n    predicted_contact_heatmap_raw,\n    shift,\n    scale,\n    ax2,\n    summary_writer,\n    save_path,\n):\n    img_size_original = img_original.shape[:2]\n    (\n        rotmat_pred,\n        betas_pred,\n        camera_pred,\n        keypoints_3d_pred,\n        _,\n        smpl_output,\n        _,\n        _,\n    ) = optimize(\n        model_hmr,\n        smpl,\n        selector,\n        input_img,\n        keypoints_2d,\n        optimizer,\n        args,\n        loss_mse=loss_mse,\n        loss_parallel=loss_parallel,\n        c_mse=1,\n        c_new_mse=0,\n        c_beta=c_beta,\n        sc_crit=None,\n        msc_crit=None,\n        contact=None,\n        n_steps=60 + 90,\n        writer=summary_writer,\n    )\n\n    # find contacts\n    vertices = smpl_output.vertices.detach()\n    contact = get_contacts(\n        args,\n        sc_module,\n        y_data_conts,\n        
keypoints_2d,\n        vertices,\n        bone_to_params,\n        loss_parallel,\n        img_size_original,\n        save_path,\n    )\n\n    save_all(\n        keypoints_3d_pred,\n        rotmat_pred,\n        camera_pred,\n        betas_pred,\n        smpl,\n        contact,\n        img_original,\n        predicted_keypoints_2d,\n        predicted_contact_heatmap_raw,\n        loss_parallel,\n        smpl_output,\n        shift,\n        scale,\n        ax2,\n        summary_writer,\n        save_path,\n        \"eft\",\n    )\n\n    if sc_module is not None:\n        save_mesh_with_winding_numbers(sc_module, vertices, smpl, save_path)\n\n    return vertices, keypoints_3d_pred, contact\n\n\ndef dc_step(\n    model_hmr,\n    smpl,\n    selector,\n    input_img,\n    keypoints_2d,\n    optimizer,\n    args,\n    loss_mse,\n    loss_parallel,\n    c_mse,\n    c_new_mse,\n    c_beta,\n    sc_crit,\n    msc_crit,\n    contact,\n    use_contacts,\n    use_msc,\n    img_original,\n    predicted_keypoints_2d,\n    predicted_contact_heatmap_raw,\n    shift,\n    scale,\n    ax2,\n    summary_writer,\n    save_path,\n):\n    (\n        rotmat_pred,\n        betas_pred,\n        camera_pred,\n        keypoints_3d_pred,\n        _,\n        smpl_output,\n        _,\n        _,\n    ) = optimize(\n        model_hmr,\n        smpl,\n        selector,\n        input_img,\n        keypoints_2d,\n        optimizer,\n        args,\n        loss_mse=loss_mse,\n        loss_parallel=loss_parallel,\n        c_mse=c_mse,\n        c_new_mse=c_new_mse,\n        c_beta=c_beta,\n        sc_crit=sc_crit,\n        msc_crit=msc_crit if use_contacts or use_msc else None,\n        contact=contact if use_contacts or use_msc else None,\n        n_steps=60 if use_contacts or use_msc else 0,  # + 60,\n        # save_path=(img_original, predicted_keypoints_2d, save_path, shift, scale, ax2, \"dc\"),\n        writer=summary_writer,\n        i_ini=60 + 90,\n    )\n\n    save_all(\n        
keypoints_3d_pred,\n        rotmat_pred,\n        camera_pred,\n        betas_pred,\n        smpl,\n        contact,\n        img_original,\n        predicted_keypoints_2d,\n        predicted_contact_heatmap_raw,\n        loss_parallel,\n        smpl_output,\n        shift,\n        scale,\n        ax2,\n        summary_writer,\n        save_path,\n        \"dc\",\n    )\n\n    return rotmat_pred\n\n\ndef us_step(\n    model_hmr,\n    smpl,\n    selector,\n    input_img,\n    rotmat_pred,\n    keypoints_2d,\n    args,\n    loss_mse,\n    loss_parallel,\n    c_mse,\n    c_new_mse,\n    sc_crit,\n    msc_crit,\n    contact,\n    use_contacts,\n    use_msc,\n    img_original,\n    keypoints_3d_pred,\n    summary_writer,\n    save_path,\n):\n    (_, _, camera_pred_us, _, _, _, smpl_output_us, _, _,) = get_pred_and_data(\n        model_hmr,\n        smpl,\n        selector,\n        input_img,\n        use_betas=False,\n        zero_hands=True,\n    )\n\n    rotmat_pred_us, smpl_output_us = optimize_ft(\n        rotmat_pred,\n        camera_pred_us,\n        smpl,\n        selector,\n        input_img,\n        keypoints_2d,\n        args,\n        loss_mse=loss_mse,\n        loss_parallel=loss_parallel,\n        c_mse=c_mse,\n        c_new_mse=c_new_mse,\n        sc_crit=sc_crit,\n        msc_crit=msc_crit if use_contacts or use_msc else None,\n        contact=contact if use_contacts or use_msc else None,\n        n_steps=60 if use_contacts or use_msc else 0,  # + 60,\n        # save_path=(img_original, predicted_keypoints_2d, save_path, shift, scale, ax2, \"dc\"),\n        writer=summary_writer,\n        i_ini=60 + 90 + 60,\n        zero_hands=True,\n        fist=args.fist,\n    )\n\n    save_all(\n        keypoints_3d_pred,\n        rotmat_pred_us,\n        camera_pred_us,\n        torch.zeros(1, 10, dtype=torch.float32),\n        smpl,\n        None,\n        img_original,\n        None,\n        None,\n        loss_parallel,\n        smpl_output_us,\n        
None,\n        None,\n        None,\n        summary_writer,\n        save_path,\n        \"us\",\n    )\n\n\ndef main():\n    args = parse_args()\n    print(args)\n\n    # models\n    model_pose = cv2.dnn.readNetFromONNX(\n        args.pose_estimation_model_path\n    )  # \"hrn_w48_384x288.onnx\"\n    model_contact = cv2.dnn.readNetFromONNX(\n        args.contact_model_path\n    )  # \"contact_hrn_w32_256x192.onnx\"\n\n    device = (\n        torch.device(args.device) if torch.cuda.is_available() else torch.device(\"cpu\")\n    )\n    model_hmr = spin.hmr(args.smpl_mean_params_path)  # \"smpl_mean_params.npz\"\n    model_hmr.to(device)\n    checkpoint = torch.load(\n        args.spin_model_path,  # \"spin_model_smplx_eft_18.pt\"\n        map_location=\"cpu\"\n    )\n\n    smpl = spin.SMPLX(\n        args.smpl_model_dir,  # \"models/smplx\"\n        batch_size=1,\n        create_transl=False,\n        use_pca=False,\n        flat_hand_mean=args.fist is not None,\n    )\n    smpl.to(device)\n\n    selector = get_selector()\n\n    use_contacts = args.use_contacts\n    use_msc = args.use_msc\n\n    bone_to_params = np.load(args.bone_parametrization_path, allow_pickle=True).item()\n    foot_inds = np.load(args.foot_inds_path, allow_pickle=True).item()\n    left_foot_inds = foot_inds[\"left_foot_inds\"]\n    right_foot_inds = foot_inds[\"right_foot_inds\"]\n\n    if use_contacts:\n        model_type = args.smpl_type\n        sc_module = selfcontact.SelfContact(\n            essentials_folder=args.essentials_dir,  # \"smplify-xmc-essentials\"\n            geothres=0.3,\n            euclthres=0.02,\n            test_segments=True,\n            compute_hd=True,\n            model_type=model_type,\n            device=device,\n        )\n        sc_module.to(device)\n\n        sc_crit = selfcontact.losses.SelfContactLoss(\n            contact_module=sc_module,\n            inside_loss_weight=0.5,\n            outside_loss_weight=0.0,\n            contact_loss_weight=0.5,\n   
         align_faces=True,\n            use_hd=True,\n            test_segments=True,\n            device=device,\n            model_type=model_type,\n        )\n        sc_crit.to(device)\n\n        msc_crit = losses.MimickedSelfContactLoss(geodesics_mask=sc_module.geomask)\n        msc_crit.to(device)\n    else:\n        sc_module = None\n        sc_crit = None\n        msc_crit = None\n\n    loss_mse = losses.MSE([1, 10, 13])  # Neck + Right Upper Leg + Left Upper Leg\n\n    ignore = (\n        (1, 2),  # Neck + Right Shoulder\n        (1, 5),  # Neck + Left Shoulder\n        (9, 10),  # Hips + Right Upper Leg\n        (9, 13),  # Hips + Left Upper Leg\n    )\n    loss_parallel = losses.Parallel(\n        skeleton=pose_estimation.SKELETON,\n        ignore=ignore,\n    )\n\n    c_mse = args.c_mse\n    c_new_mse = args.c_par\n    c_beta = 1e-3\n\n    if c_mse > 0:\n        assert c_new_mse == 0\n    elif c_mse == 0:\n        assert c_new_mse > 0\n\n    root_path = Path(args.save_path)\n    root_path.mkdir(exist_ok=True, parents=True)\n\n    path_to_imgs = Path(args.img_path)\n    if path_to_imgs.is_dir():\n        path_to_imgs = path_to_imgs.iterdir()\n    else:\n        path_to_imgs = [path_to_imgs]\n\n    for img_path in path_to_imgs:\n        if not any(\n            img_path.name.lower().endswith(ext) for ext in [\".jpg\", \".png\", \".jpeg\"]\n        ):\n            continue\n\n        img_name = img_path.stem\n\n        # use 2d keypoints detection\n        (\n            img_original,\n            predicted_keypoints_2d,\n            _,\n            _,\n        ) = pose_estimation.infer_single_image(\n            model_pose,\n            img_path,\n            input_img_size=pose_estimation.IMG_SIZE,\n            return_kps=True,\n        )\n\n        save_path = root_path / img_name\n        save_path.mkdir(exist_ok=True, parents=True)\n        # if (save_path / \"us_orig.png\").is_file():\n        #     return\n\n        summary_writer = 
SummaryWriter(log_dir=save_path / f\"runDoknc2_{c_new_mse}\")\n\n        img_original = cv2.cvtColor(img_original, cv2.COLOR_BGR2RGB)\n        img_size_original = img_original.shape[:2]\n        keypoints_2d, shift, scale, ax2 = normalize_keypoints_to_spin(\n            predicted_keypoints_2d, img_size_original\n        )\n        keypoints_2d = torch.from_numpy(keypoints_2d)\n        keypoints_2d = keypoints_2d.to(device)\n\n        (\n            predicted_contact_heatmap,\n            predicted_contact_heatmap_raw,\n            very_hm_raw,\n        ) = get_contact_heatmap(model_contact, img_path)\n        predicted_contact_heatmap_raw = Image.fromarray(\n            predicted_contact_heatmap_raw\n        ).resize(img_size_original[::-1])\n        predicted_contact_heatmap_raw = cv2.resize(very_hm_raw, img_size_original[::-1])\n\n        if c_new_mse == 0:\n            predicted_contact_heatmap_raw = None\n\n        y_data_conts = get_vertices_in_heatmap(predicted_contact_heatmap)\n\n        model_hmr.load_state_dict(checkpoint[\"model\"], strict=True)\n        model_hmr.train()\n        freeze_layers(model_hmr)\n\n        _, input_img = spin.process_image(img_path, input_res=spin.constants.IMG_RES)\n        input_img = input_img.to(device)\n\n        spin_step(\n            model_hmr,\n            smpl,\n            selector,\n            input_img,\n            img_original,\n            predicted_keypoints_2d,\n            predicted_contact_heatmap_raw,\n            loss_parallel,\n            shift,\n            scale,\n            ax2,\n            summary_writer,\n            save_path,\n        )\n\n        optimizer = optim.Adam(\n            filter(lambda p: p.requires_grad, model_hmr.parameters()),\n            lr=1e-6,\n        )\n\n        vertices, keypoints_3d_pred, contact = eft_step(\n            model_hmr,\n            smpl,\n            selector,\n            input_img,\n            keypoints_2d,\n            optimizer,\n            args,\n     
       loss_mse,\n            loss_parallel,\n            c_beta,\n            sc_module,\n            y_data_conts,\n            bone_to_params,\n            img_original,\n            predicted_keypoints_2d,\n            predicted_contact_heatmap_raw,\n            shift,\n            scale,\n            ax2,\n            summary_writer,\n            save_path,\n        )\n\n        if args.use_natural:\n            get_natural(\n                keypoints_2d, vertices, right_foot_inds, left_foot_inds, loss_parallel, smpl,\n            )\n\n        if args.use_cos:\n            cos_r = get_cos(keypoints_3d_pred, args.use_angle_transf, loss_parallel)\n            np.save(save_path / \"cos_hist\", cos_r.cpu().numpy())\n\n        rotmat_pred = dc_step(\n            model_hmr,\n            smpl,\n            selector,\n            input_img,\n            keypoints_2d,\n            optimizer,\n            args,\n            loss_mse,\n            loss_parallel,\n            c_mse,\n            c_new_mse,\n            c_beta,\n            sc_crit,\n            msc_crit,\n            contact,\n            use_contacts,\n            use_msc,\n            img_original,\n            predicted_keypoints_2d,\n            predicted_contact_heatmap_raw,\n            shift,\n            scale,\n            ax2,\n            summary_writer,\n            save_path,\n        )\n\n        us_step(\n            model_hmr,\n            smpl,\n            selector,\n            input_img,\n            rotmat_pred,\n            keypoints_2d,\n            args,\n            loss_mse,\n            loss_parallel,\n            c_mse,\n            c_new_mse,\n            sc_crit,\n            msc_crit,\n            contact,\n            use_contacts,\n            use_msc,\n            img_original,\n            keypoints_3d_pred,\n            summary_writer,\n            save_path,\n        )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/pose_estimation.py",
    "content": "import math\n\nimport cv2\nimport numpy as np\n\nIMG_SIZE = (288, 384)\nMEAN = np.array([0.485, 0.456, 0.406])\nSTD = np.array([0.229, 0.224, 0.225])\n\nKPS = (\n    \"Head\",\n    \"Neck\",\n    \"Right Shoulder\",\n    \"Right Arm\",\n    \"Right Hand\",\n    \"Left Shoulder\",\n    \"Left Arm\",\n    \"Left Hand\",\n    \"Spine\",\n    \"Hips\",\n    \"Right Upper Leg\",\n    \"Right Leg\",\n    \"Right Foot\",\n    \"Left Upper Leg\",\n    \"Left Leg\",\n    \"Left Foot\",\n    \"Left Toe\",\n    \"Right Toe\",\n)\n\nSKELETON = (\n    (0, 1),\n    (1, 8),\n    (8, 9),\n    (9, 10),\n    (9, 13),\n    (10, 11),\n    (11, 12),\n    (13, 14),\n    (14, 15),\n    (1, 2),\n    (2, 3),\n    (3, 4),\n    (1, 5),\n    (5, 6),\n    (6, 7),\n    (15, 16),\n    (12, 17),\n)\n\n\nOPENPOSE_TO_GESTURE = (\n    0,  # 0 Head\n    1,  # 1 Neck\n    2,  # 2 Right Shoulder\n    3,  # 3 Right Arm\n    4,  # 4 Right Hand\n    5,  # 5 Left Shoulder\n    6,  # 6 Left Arm\n    7,  # 7 Left Hand\n    9,  # 8 Hips\n    10,  # 9 Right Upper Leg\n    11,  # 10 Right Leg\n    12,  # 11 Right Foot\n    13,  # 12 Left Upper Leg\n    14,  # 13 Left Leg\n    15,  # 14 Left Foot\n    -1,  # 15\n    -1,  # 16\n    -1,  # 17\n    -1,  # 18\n    16,  # 19 Left Toe\n    -1,  # 20\n    -1,  # 21\n    17,  # 22 Right Toe\n    -1,  # 23\n    -1,  # 24\n)\n\n\ndef transform(img):\n    img = img.astype(\"float32\") / 255\n\n    img = (img - MEAN) / STD\n\n    return np.transpose(img, axes=(2, 0, 1))\n\n\ndef get_affine_transform(\n    center,\n    scale,\n    rot,\n    output_size,\n    shift=np.array([0, 0], dtype=np.float32),\n    inv=0,\n    pixel_std=200,\n):\n    if not isinstance(scale, np.ndarray) and not isinstance(scale, list):\n        scale = np.array([scale, scale])\n\n    scale_tmp = scale * pixel_std\n    src_w = scale_tmp[0]\n    dst_w = 
output_size[0]\n    dst_h = output_size[1]\n\n    rot_rad = np.pi * rot / 180\n    src_dir = get_dir([0, src_w * -0.5], rot_rad)\n    dst_dir = np.array([0, dst_w * -0.5], np.float32)\n    src = np.zeros((3, 2), dtype=np.float32)\n    dst = np.zeros((3, 2), dtype=np.float32)\n    src[0, :] = center + scale_tmp * shift\n    src[1, :] = center + src_dir + scale_tmp * shift\n    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]\n    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir\n\n    src[2:, :] = get_3rd_point(src[0, :], src[1, :])\n    dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])\n\n    if inv:\n        trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))\n    else:\n        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))\n\n    return trans\n\n\ndef get_3rd_point(a, b):\n    direct = a - b\n    return b + np.array([-direct[1], direct[0]], dtype=np.float32)\n\n\ndef get_dir(src_point, rot_rad):\n    sn, cs = np.sin(rot_rad), np.cos(rot_rad)\n\n    src_result = [0, 0]\n    src_result[0] = src_point[0] * cs - src_point[1] * sn\n    src_result[1] = src_point[0] * sn + src_point[1] * cs\n\n    return src_result\n\n\ndef process_image(path, input_img_size, pixel_std=200):\n    data_numpy = cv2.imread(path, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)\n    # BUG HERE. 
Must be uncommented\n    # data_numpy = cv2.cvtColor(data_numpy, cv2.COLOR_BGR2RGB)\n\n    h, w = data_numpy.shape[:2]\n    c = np.array([w / 2, h / 2], dtype=np.float32)\n\n    aspect_ratio = input_img_size[0] / input_img_size[1]\n    if w > aspect_ratio * h:\n        h = w * 1.0 / aspect_ratio\n    elif w < aspect_ratio * h:\n        w = h * aspect_ratio\n\n    s = np.array([w / pixel_std, h / pixel_std], dtype=np.float32) * 1.25\n    r = 0\n    trans = get_affine_transform(c, s, r, input_img_size, pixel_std=pixel_std)\n    input = cv2.warpAffine(data_numpy, trans, input_img_size, flags=cv2.INTER_LINEAR)\n\n    input = transform(input)\n\n    return input, data_numpy, c, s\n\n\ndef get_final_preds(batch_heatmaps, center, scale, post_process=False):\n    coords, maxvals = get_max_preds(batch_heatmaps)\n\n    heatmap_height = batch_heatmaps.shape[2]\n    heatmap_width = batch_heatmaps.shape[3]\n\n    # post-processing\n    if post_process:\n        for n in range(coords.shape[0]):\n            for p in range(coords.shape[1]):\n                hm = batch_heatmaps[n][p]\n                px = int(math.floor(coords[n][p][0] + 0.5))\n                py = int(math.floor(coords[n][p][1] + 0.5))\n                if 1 < px < heatmap_width - 1 and 1 < py < heatmap_height - 1:\n                    diff = np.array(\n                        [\n                            hm[py][px + 1] - hm[py][px - 1],\n                            hm[py + 1][px] - hm[py - 1][px],\n                        ]\n                    )\n                    coords[n][p] += np.sign(diff) * 0.25\n\n    preds = coords.copy()\n\n    # Transform back\n    for i in range(coords.shape[0]):\n        preds[i] = transform_preds(\n            coords[i], center[i], scale[i], [heatmap_width, heatmap_height]\n        )\n\n    return preds, maxvals\n\n\ndef transform_preds(coords, center, scale, output_size):\n    target_coords = np.zeros(coords.shape)\n    trans = get_affine_transform(center, scale, 0, output_size, 
inv=1)\n    for p in range(coords.shape[0]):\n        target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans)\n    return target_coords\n\n\ndef affine_transform(pt, t):\n    new_pt = np.array([pt[0], pt[1], 1.0]).T\n    new_pt = np.dot(t, new_pt)\n    return new_pt[:2]\n\n\ndef get_max_preds(batch_heatmaps):\n    \"\"\"\n    Get predictions from score maps.\n    batch_heatmaps: numpy.ndarray of shape (batch_size, num_joints, height, width)\n    \"\"\"\n    assert isinstance(\n        batch_heatmaps, np.ndarray\n    ), \"batch_heatmaps should be numpy.ndarray\"\n    assert batch_heatmaps.ndim == 4, \"batch_heatmaps should be 4-dimensional\"\n\n    batch_size = batch_heatmaps.shape[0]\n    num_joints = batch_heatmaps.shape[1]\n    width = batch_heatmaps.shape[3]\n    heatmaps_reshaped = batch_heatmaps.reshape((batch_size, num_joints, -1))\n    idx = np.argmax(heatmaps_reshaped, 2)\n    maxvals = np.amax(heatmaps_reshaped, 2)\n\n    maxvals = maxvals.reshape((batch_size, num_joints, 1))\n    idx = idx.reshape((batch_size, num_joints, 1))\n\n    preds = np.tile(idx, (1, 1, 2)).astype(np.float32)\n\n    preds[:, :, 0] = (preds[:, :, 0]) % width\n    preds[:, :, 1] = np.floor((preds[:, :, 1]) / width)\n\n    pred_mask = np.tile(np.greater(maxvals, 0.0), (1, 1, 2))\n    pred_mask = pred_mask.astype(np.float32)\n\n    preds *= pred_mask\n    return preds, maxvals\n\n\ndef infer_single_image(model, img_path, input_img_size=(288, 384), return_kps=True):\n    img_path = str(img_path)\n    pose_input, img, center, scale = process_image(\n        img_path, input_img_size=input_img_size\n    )\n    model.setInput(pose_input[None])\n    predicted_heatmap = model.forward()\n\n    if not return_kps:\n        return predicted_heatmap.squeeze(0)\n\n    predicted_keypoints, confidence = get_final_preds(\n        predicted_heatmap, center[None], scale[None], post_process=True\n    )\n\n    (predicted_keypoints, confidence, predicted_heatmap,) = (\n        predicted_keypoints.squeeze(0),\n        
confidence.squeeze(0),\n        predicted_heatmap.squeeze(0),\n    )\n\n    return img, predicted_keypoints, confidence, predicted_heatmap\n"
  },
  {
    "path": "src/renderer.py",
    "content": "import numpy as np\nimport pyrender\nimport torch\nimport trimesh\nfrom torchvision.utils import make_grid\n\n\nclass Renderer:\n    \"\"\"\n    Renderer used for visualizing the SMPL model\n    Code adapted from https://github.com/vchoutas/smplify-x\n    \"\"\"\n\n    def __init__(self, focal_length=5000, img_res=224, faces=None):\n        self.renderer = pyrender.OffscreenRenderer(\n            viewport_width=img_res, viewport_height=img_res, point_size=1.0\n        )\n        self.focal_length = focal_length\n        self.camera_center = [img_res // 2, img_res // 2]\n        self.faces = faces\n\n    def visualize_tb(self, vertices, camera_translation, images):\n        vertices = vertices.cpu().numpy()\n        camera_translation = camera_translation.cpu().numpy()\n        images = images.cpu()\n        images_np = np.transpose(images.numpy(), (0, 2, 3, 1))\n        rend_imgs = []\n        for i in range(vertices.shape[0]):\n            rend_img = torch.from_numpy(\n                np.transpose(\n                    self.__call__(vertices[i], camera_translation[i], images_np[i]),\n                    (2, 0, 1),\n                )\n            ).float()\n            rend_imgs.append(images[i])\n            rend_imgs.append(rend_img)\n        rend_imgs = make_grid(rend_imgs, nrow=2)\n        return rend_imgs\n\n    def __call__(self, vertices, camera_translation, image):\n        material = pyrender.MetallicRoughnessMaterial(\n            metallicFactor=0.2, alphaMode=\"OPAQUE\", baseColorFactor=(0.8, 0.3, 0.3, 1.0)\n        )\n\n        camera_translation[0] *= -1.0\n\n        mesh = trimesh.Trimesh(vertices, self.faces)\n        rot = trimesh.transformations.rotation_matrix(np.radians(180), [1, 0, 0])\n        mesh.apply_transform(rot)\n        mesh = pyrender.Mesh.from_trimesh(mesh, material=material)\n\n        scene = pyrender.Scene(ambient_light=(0.5, 0.5, 0.5))\n        scene.add(mesh, \"mesh\")\n\n        camera_pose = np.eye(4)\n        
camera_pose[:3, 3] = camera_translation\n        camera = pyrender.IntrinsicsCamera(\n            fx=self.focal_length,\n            fy=self.focal_length,\n            cx=self.camera_center[0],\n            cy=self.camera_center[1],\n        )\n        scene.add(camera, pose=camera_pose)\n\n        light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=1)\n        light_pose = np.eye(4)\n\n        light_pose[:3, 3] = np.array([0, -1, 1])\n        scene.add(light, pose=light_pose)\n\n        light_pose[:3, 3] = np.array([0, 1, 1])\n        scene.add(light, pose=light_pose)\n\n        light_pose[:3, 3] = np.array([1, 1, 2])\n        scene.add(light, pose=light_pose)\n\n        color, rend_depth = self.renderer.render(scene, flags=pyrender.RenderFlags.RGBA)\n        color = color.astype(np.float32) / 255.0\n        valid_mask = (rend_depth > 0)[:, :, None]\n        output_img = color[:, :, :3] * valid_mask + (1 - valid_mask) * image\n        return output_img\n\n\ndef overlay_mesh(\n    verts,\n    faces,\n    camera_transl,\n    focal_length_x,\n    focal_length_y,\n    camera_center,\n    H,\n    W,\n    img,\n    camera_rotation=None,\n    rotaround=None,\n    contactlist=None,\n    color=False,\n    scale=1,\n):\n\n    material = pyrender.MetallicRoughnessMaterial(\n        metallicFactor=0.0, alphaMode=\"OPAQUE\", baseColorFactor=(1.0, 1.0, 0.9, 1.0)\n    )\n    out_mesh = trimesh.Trimesh(verts, faces, process=False)\n    out_mesh_col = np.array(out_mesh.visual.vertex_colors)\n\n    if contactlist is not None and len(contactlist) > 0:\n        color = [255, 0, 0, 255]\n        out_mesh_col[contactlist] = color\n        out_mesh.visual.vertex_colors = out_mesh_col\n\n    if camera_rotation is None:\n        camera_rotation = np.eye(3)\n    else:\n        camera_rotation = camera_rotation[0]\n\n    # rotate mesh and stack output images\n    if rotaround is None:\n        out_mesh.vertices = np.matmul(verts, camera_rotation.T) + camera_transl\n    
else:\n        base_mesh = trimesh.Trimesh(verts, faces, process=False)\n        # rot_center = (base_mesh.vertices[5615] + base_mesh.vertices[5614] ) / 2\n        rot = trimesh.transformations.rotation_matrix(\n            np.radians(rotaround), [0, 1, 0], base_mesh.vertices[4297]\n        )\n        base_mesh.apply_transform(rot)\n        out_mesh.vertices = (\n            np.matmul(base_mesh.vertices, camera_rotation.T) + camera_transl\n        )\n\n    out_mesh.vertices += np.array([0, 0, 50])\n    # add mesh to scene\n    mesh = pyrender.Mesh.from_trimesh(\n        out_mesh,\n        material=material,\n        smooth=False,\n    )\n    if img is not None:\n        scene = pyrender.Scene(\n            bg_color=[0.0, 0.0, 0.0, 0.0],\n            ambient_light=(0.3, 0.3, 0.3, 1.0),\n        )\n    else:\n        scene = pyrender.Scene(\n            bg_color=[1.0, 1.0, 1.0, 1.0],\n            ambient_light=(0.3, 0.3, 0.3, 1.0),\n        )\n    scene.add(mesh, \"mesh\")\n\n    # create and add camera\n    camera_pose = np.eye(4)\n    camera_pose[1, :] = -camera_pose[1, :]\n    camera_pose[2, :] = -camera_pose[2, :]\n    pyrencamera = pyrender.camera.OrthographicCamera(\n        camera_transl[2],\n        camera_transl[2],\n        znear=1e-6,\n        zfar=1000000,\n    )\n    scene.add(pyrencamera, pose=camera_pose)\n\n    # create and add light\n    light = pyrender.PointLight(\n        color=[1.0, 1.0, 1.0],\n        intensity=1,\n    )\n    light_pose = np.eye(4)\n    for lp in [[1, 1, -1], [-1, 1, -1], [1, -1, -1], [-1, -1, -1]]:\n        light_pose[:3, 3] = out_mesh.vertices.mean(0) + np.array(lp)\n        scene.add(light, pose=light_pose)\n\n    r = pyrender.OffscreenRenderer(\n        viewport_width=int(scale * W),\n        viewport_height=int(scale * H),\n        point_size=1.0,\n    )\n    color, _ = r.render(scene, flags=pyrender.RenderFlags.RGBA)\n    color = color.astype(np.float32) / 255.0\n\n    if img is not None:\n        valid_mask = (color[:, :, 
-1] > 0)[:, :, np.newaxis]\n        output_img = color[:, :, :-1] * valid_mask + (1 - valid_mask) * img\n    else:\n        output_img = color\n\n    output_img = (output_img * 255).astype(np.uint8)[..., :3]\n\n    return output_img\n"
  },
  {
    "path": "src/spin/__init__.py",
    "content": "from .constants import JOINT_NAMES\nfrom .hmr import hmr\nfrom .smpl import SMPLX\nfrom .utils import process_image\n\n__all__ = [\n    \"hmr\",\n    \"SMPLX\",\n    \"process_image\",\n    \"JOINT_NAMES\",\n]\n"
  },
  {
    "path": "src/spin/constants.py",
    "content": "FOCAL_LENGTH = 5000.0\nIMG_RES = 224\n\n# Mean and standard deviation for normalizing input image\nIMG_NORM_MEAN = [0.485, 0.456, 0.406]\nIMG_NORM_STD = [0.229, 0.224, 0.225]\n\n\"\"\"\nWe create a superset of joints containing the OpenPose joints together with the ones that each dataset provides.\nWe keep a superset of 24 joints such that we include all joints from every dataset.\nIf a dataset doesn't provide annotations for a specific joint, we simply ignore it.\nThe joints used here are the following:\n\"\"\"\nJOINT_NAMES = (\n    \"Hips\",\n    \"Left Upper Leg\",\n    \"Right Upper Leg\",\n    \"Spine\",\n    \"Left Leg\",\n    \"Right Leg\",\n    \"Spine1\",\n    \"Left Foot\",\n    \"Right Foot\",\n    \"Thorax\",\n    \"Left Toe\",\n    \"Right Toe\",\n    \"Neck\",\n    \"Left Shoulder\",\n    \"Right Shoulder\",\n    \"Head\",\n    \"Left ForeArm\",\n    \"Right ForeArm\",\n    \"Left Arm\",\n    \"Right Arm\",\n    \"Left Hand\",\n    \"Right Hand\",\n    # 25 OpenPose joints (in the order provided by OpenPose)\n    # \"OP Nose\",\n    # \"OP Neck\",\n    # \"OP RShoulder\",\n    # \"OP RElbow\",\n    # \"OP RWrist\",\n    # \"OP LShoulder\",\n    # \"OP LElbow\",\n    # \"OP LWrist\",\n    # \"OP MidHip\",\n    # \"OP RHip\",\n    # \"OP RKnee\",\n    # \"OP RAnkle\",\n    # \"OP LHip\",\n    # \"OP LKnee\",\n    # \"OP LAnkle\",\n    # \"OP REye\",\n    # \"OP LEye\",\n    # \"OP REar\",\n    # \"OP LEar\",\n    # \"OP LBigToe\",\n    # \"OP LSmallToe\",\n    # \"OP LHeel\",\n    # \"OP RBigToe\",\n    # \"OP RSmallToe\",\n    # \"OP RHeel\",\n    ## 24 Ground Truth joints (superset of joints from different datasets)\n    # \"Right Ankle\",\n    # \"Right Knee\",\n    # \"Right Hip\",\n    # \"Left Hip\",\n    # \"Left Knee\",\n    # \"Left Ankle\",\n    # \"Right Wrist\",\n    # \"Right Elbow\",\n    # \"Right Shoulder\",\n    # \"Left Shoulder\",\n    # \"Left Elbow\",\n    # \"Left Wrist\",\n    # \"Neck (LSP)\",\n    # \"Top of Head 
(LSP)\",\n    # \"Pelvis (MPII)\",\n    # \"Thorax (MPII)\",\n    # \"Spine (H36M)\",\n    # \"Jaw (H36M)\",\n    # \"Head (H36M)\",\n    # \"Nose\",\n    # \"Left Eye\",\n    # \"Right Eye\",\n    # \"Left Ear\",\n    # \"Right Ear\",\n    # \"OP MidHip\",\n    # \"Spine1\",\n    # \"Spine2\",\n    # \"Spine3\",\n    # \"OP Neck\",\n    # \"Head\",\n)\n\n# Dict containing the joints in numerical order\nJOINT_IDS = {JOINT_NAMES[i]: i for i in range(len(JOINT_NAMES))}\n\n# Map joints to SMPL joints\nJOINT_MAP = {\n    \"Hips\": 0,\n    \"Left Upper Leg\": 1,\n    \"Right Upper Leg\": 2,\n    \"Spine\": 3,\n    \"Left Leg\": 4,\n    \"Right Leg\": 5,\n    \"Spine1\": 6,\n    \"Left Foot\": 7,\n    \"Right Foot\": 8,\n    \"Thorax\": 9,\n    \"Left Toe\": 10,\n    \"Right Toe\": 11,\n    \"Neck\": 12,\n    \"Left Shoulder\": 13,\n    \"Right Shoulder\": 14,\n    \"Head\": 15,\n    \"Left ForeArm\": 16,\n    \"Right ForeArm\": 17,\n    \"Left Arm\": 18,\n    \"Right Arm\": 19,\n    \"Left Hand\": 20,\n    \"Right Hand\": 21,\n    # \"OP Nose\": 24,\n    # \"OP Neck\": 12,\n    # \"OP RShoulder\": 17,\n    # \"OP RElbow\": 19,\n    # \"OP RWrist\": 21,\n    # \"OP LShoulder\": 16,\n    # \"OP LElbow\": 18,\n    # \"OP LWrist\": 20,\n    # \"OP MidHip\": 0,\n    # \"OP RHip\": 2,\n    # \"OP RKnee\": 5,\n    # \"OP RAnkle\": 8,\n    # \"OP LHip\": 1,\n    # \"OP LKnee\": 4,\n    # \"OP LAnkle\": 7,\n    # \"OP REye\": 25,\n    # \"OP LEye\": 26,\n    # \"OP REar\": 27,\n    # \"OP LEar\": 28,\n    # \"OP LBigToe\": 29,\n    # \"OP LSmallToe\": 30,\n    # \"OP LHeel\": 31,\n    # \"OP RBigToe\": 32,\n    # \"OP RSmallToe\": 33,\n    # \"OP RHeel\": 34,\n    # \"Right Ankle\": 8,\n    # \"Right Knee\": 5,\n    # \"Right Hip\": 45,\n    # \"Left Hip\": 46,\n    # \"Left Knee\": 4,\n    # \"Left Ankle\": 7,\n    # \"Right Wrist\": 21,\n    # \"Right Elbow\": 19,\n    # \"Right Shoulder\": 17,\n    # \"Left Shoulder\": 16,\n    # \"Left Elbow\": 18,\n    # \"Left Wrist\": 
20,\n    # \"Neck (LSP)\": 47,\n    # \"Top of Head (LSP)\": 15, # 48,\n    # \"Pelvis (MPII)\": 49,\n    # \"Thorax (MPII)\": 50,\n    # \"Spine (H36M)\": 51,\n    # \"Jaw (H36M)\": 52,\n    # \"Head (H36M)\": 15, # 53,\n    # \"Nose\": 24,\n    # \"Left Eye\": 26,\n    # \"Right Eye\": 25,\n    # \"Left Ear\": 28,\n    # \"Right Ear\": 27,\n    # \"Spine1\": 3,\n    # \"Spine2\": 6,\n    # \"Spine3\": 9,\n    # \"Head\": 15,\n}\n\n# Joint selectors\n# Indices to get the 14 LSP joints from the 17 H36M joints\nH36M_TO_J17 = [6, 5, 4, 1, 2, 3, 16, 15, 14, 11, 12, 13, 8, 10, 0, 7, 9]\nH36M_TO_J14 = H36M_TO_J17[:14]\n# Indices to get the 14 LSP joints from the ground truth joints\nJ24_TO_J17 = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18, 14, 16, 17]\nJ24_TO_J14 = J24_TO_J17[:14]\n\n# Permutation of SMPL pose parameters when flipping the shape\nSMPL_JOINTS_FLIP_PERM = [\n    0,\n    2,\n    1,\n    3,\n    5,\n    4,\n    6,\n    8,\n    7,\n    9,\n    11,\n    10,\n    12,\n    14,\n    13,\n    15,\n    17,\n    16,\n    19,\n    18,\n    21,\n    20,\n    23,\n    22,\n]\nSMPL_POSE_FLIP_PERM = []\nfor i in SMPL_JOINTS_FLIP_PERM:\n    SMPL_POSE_FLIP_PERM.append(3 * i)\n    SMPL_POSE_FLIP_PERM.append(3 * i + 1)\n    SMPL_POSE_FLIP_PERM.append(3 * i + 2)\n# Permutation indices for the 24 ground truth joints\nJ24_FLIP_PERM = [\n    5,\n    4,\n    3,\n    2,\n    1,\n    0,\n    11,\n    10,\n    9,\n    8,\n    7,\n    6,\n    12,\n    13,\n    14,\n    15,\n    16,\n    17,\n    18,\n    19,\n    21,\n    20,\n    23,\n    22,\n]\n# Permutation indices for the full set of 49 joints\nJ49_FLIP_PERM = [\n    0,\n    1,\n    5,\n    6,\n    7,\n    2,\n    3,\n    4,\n    8,\n    12,\n    13,\n    14,\n    9,\n    10,\n    11,\n    16,\n    15,\n    18,\n    17,\n    22,\n    23,\n    24,\n    19,\n    20,\n    21,\n] + [25 + i for i in J24_FLIP_PERM]\n"
  },
  {
    "path": "src/spin/hmr.py",
    "content": "import math\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torchvision.models.resnet as resnet\n\n\ndef rot6d_to_rotmat(x):\n    \"\"\"Convert 6D rotation representation to 3x3 rotation matrix.\n    Based on Zhou et al., \"On the Continuity of Rotation Representations in Neural Networks\", CVPR 2019\n    Input:\n        (B,6) Batch of 6-D rotation representations\n    Output:\n        (B,3,3) Batch of corresponding rotation matrices\n    \"\"\"\n\n    x = x.view(-1, 3, 2)\n    a1 = x[:, :, 0]\n    a2 = x[:, :, 1]\n    b1 = nn.functional.normalize(a1)\n    b2 = nn.functional.normalize(\n        a2 - torch.einsum(\"bi,bi->b\", b1, a2).unsqueeze(-1) * b1\n    )\n\n    b3 = torch.cross(b1, b2)\n\n    return torch.stack((b1, b2, b3), dim=-1)\n\n\nclass Bottleneck(nn.Module):\n    \"\"\"Redefinition of Bottleneck residual block\n    Adapted from the official PyTorch implementation\n    \"\"\"\n\n    expansion = 4\n\n    def __init__(self, inplanes, planes, stride=1, downsample=None):\n        super(Bottleneck, self).__init__()\n        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)\n        self.bn1 = nn.BatchNorm2d(planes)\n        self.conv2 = nn.Conv2d(\n            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False\n        )\n        self.bn2 = nn.BatchNorm2d(planes)\n        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)\n        self.bn3 = nn.BatchNorm2d(planes * 4)\n        self.relu = nn.ReLU(inplace=True)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        out += 
residual\n        out = self.relu(out)\n\n        return out\n\n\nclass HMR(nn.Module):\n    \"\"\"SMPL Iterative Regressor with ResNet50 backbone\"\"\"\n\n    def __init__(self, block, layers, smpl_mean_params):\n        self.inplanes = 64\n        super(HMR, self).__init__()\n        self.n_shape = 10\n        self.n_cam = 3\n        self.n_joints = 24\n        npose = self.n_joints * 6\n        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)\n        self.bn1 = nn.BatchNorm2d(64)\n        self.relu = nn.ReLU(inplace=True)\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n        self.layer1 = self._make_layer(block, 64, layers[0])\n        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)\n        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)\n        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)\n        self.avgpool = nn.AvgPool2d(7, stride=1)\n        self.fc1 = nn.Linear(512 * block.expansion + npose + self.n_shape + self.n_cam, 1024)\n        self.drop1 = nn.Dropout()\n        self.fc2 = nn.Linear(1024, 1024)\n        self.drop2 = nn.Dropout()\n        self.decpose = nn.Linear(1024, npose)\n        self.decshape = nn.Linear(1024, self.n_shape)\n        self.deccam = nn.Linear(1024, self.n_cam)\n        nn.init.xavier_uniform_(self.decpose.weight, gain=0.01)\n        nn.init.xavier_uniform_(self.decshape.weight, gain=0.01)\n        nn.init.xavier_uniform_(self.deccam.weight, gain=0.01)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n                m.weight.data.normal_(0, math.sqrt(2.0 / n))\n            elif isinstance(m, nn.BatchNorm2d):\n                m.weight.data.fill_(1)\n                m.bias.data.zero_()\n\n        mean_params = np.load(smpl_mean_params)\n        init_pose = torch.from_numpy(mean_params[\"pose\"][:]).unsqueeze(0)\n        
init_shape = torch.from_numpy(\n            mean_params[\"shape\"][:].astype(\"float32\")\n        ).unsqueeze(0)\n        init_cam = torch.from_numpy(mean_params[\"cam\"]).unsqueeze(0)\n        self.register_buffer(\"init_pose\", init_pose)\n        self.register_buffer(\"init_shape\", init_shape)\n        self.register_buffer(\"init_cam\", init_cam)\n\n    def _make_layer(self, block, planes, blocks, stride=1):\n        downsample = None\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                nn.Conv2d(\n                    self.inplanes,\n                    planes * block.expansion,\n                    kernel_size=1,\n                    stride=stride,\n                    bias=False,\n                ),\n                nn.BatchNorm2d(planes * block.expansion),\n            )\n\n        layers = []\n        layers.append(block(self.inplanes, planes, stride, downsample))\n        self.inplanes = planes * block.expansion\n        for _ in range(1, blocks):\n            layers.append(block(self.inplanes, planes))\n\n        return nn.Sequential(*layers)\n\n    def forward(self, x, init_pose=None, init_shape=None, init_cam=None, n_iter=3):\n\n        batch_size = x.shape[0]\n\n        if init_pose is None:\n            init_pose = self.init_pose.expand(batch_size, -1)\n        if init_shape is None:\n            init_shape = self.init_shape.expand(batch_size, -1)\n        if init_cam is None:\n            init_cam = self.init_cam.expand(batch_size, -1)\n\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu(x)\n        x = self.maxpool(x)\n\n        x1 = self.layer1(x)\n        x2 = self.layer2(x1)\n        x3 = self.layer3(x2)\n        x4 = self.layer4(x3)\n\n        xf = self.avgpool(x4)\n        xf = xf.view(xf.size(0), -1)\n\n        pred_pose = init_pose\n        pred_shape = init_shape\n        pred_cam = init_cam\n        for _ in range(n_iter):\n            xc = 
torch.cat([xf, pred_pose, pred_shape, pred_cam], 1)\n            xc = self.fc1(xc)\n            xc = self.drop1(xc)\n            xc = self.fc2(xc)\n            xc = self.drop2(xc)\n            pred_pose = self.decpose(xc) + pred_pose\n            pred_shape = self.decshape(xc) + pred_shape\n            pred_cam = self.deccam(xc) + pred_cam\n\n        pred_rotmat = rot6d_to_rotmat(pred_pose).view(batch_size, self.n_joints, 3, 3)\n\n        return pred_rotmat, pred_shape, pred_cam\n\n\ndef hmr(smpl_mean_params, pretrained=True, **kwargs):\n    \"\"\"Constructs an HMR model with ResNet50 backbone.\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n    \"\"\"\n    model = HMR(Bottleneck, [3, 4, 6, 3], smpl_mean_params, **kwargs)\n    if pretrained:\n        resnet_imagenet = resnet.resnet50(pretrained=True)\n        model.load_state_dict(resnet_imagenet.state_dict(), strict=False)\n    return model\n"
  },
  {
    "path": "src/spin/smpl.py",
    "content": "import numpy as np\nimport torch\nfrom smplx import SMPL as _SMPL\nfrom smplx import SMPLX as _SMPLX\nfrom smplx.body_models import SMPLOutput, SMPLXOutput\nfrom smplx.lbs import vertices2joints\n\nfrom .constants import JOINT_MAP, JOINT_NAMES\n\n# Hand joints\nSMPLX_HAND_TO_PANOPTIC = [\n    0,\n    13,\n    14,\n    15,\n    16,\n    1,\n    2,\n    3,\n    17,\n    4,\n    5,\n    6,\n    18,\n    10,\n    11,\n    12,\n    19,\n    7,\n    8,\n    9,\n    20,\n]  # Wrist Thumb to Pinky\n\n\n\nclass SMPL(_SMPL):\n    \"\"\"Extension of the official SMPL implementation to support more joints\"\"\"\n\n    JOINTS = (\n        \"Hips\",\n        \"Left Upper Leg\",\n        \"Right Upper Leg\",\n        \"Spine\",\n        \"Left Leg\",\n        \"Right Leg\",\n        \"Spine1\",\n        \"Left Foot\",\n        \"Right Foot\",\n        \"Thorax\",\n        \"Left Toe\",\n        \"Right Toe\",\n        \"Neck\",\n        \"Left Shoulder\",\n        \"Right Shoulder\",\n        \"Head\",\n        \"Left ForeArm\",\n        \"Right ForeArm\",\n        \"Left Arm\",\n        \"Right Arm\",\n        \"Left Hand\",\n        \"Right Hand\",\n        \"Left Finger\",\n        \"Right Finger\",\n    )\n\n    SKELETON = (\n        (0, 1),\n        (0, 2),\n        (0, 3),\n        (1, 4),\n        (2, 5),\n        (3, 6),\n        (4, 7),\n        (5, 8),\n        (6, 9),\n        (7, 10),\n        (8, 11),\n        (9, 12),\n        (12, 13),\n        (12, 14),\n        (12, 15),\n        (13, 16),\n        (14, 17),\n        (16, 18),\n        (17, 19),\n        (18, 20),\n        (19, 21),\n        (20, 22),\n        (21, 23),\n    )\n\n    def __init__(self, *args, **kwargs):\n        super(SMPL, self).__init__(*args, **kwargs)\n        joints = [JOINT_MAP[i] for i in JOINT_NAMES]\n        joint_regressor_extra = kwargs[\"joint_regressor_extra_path\"]\n        J_regressor_extra = np.load(joint_regressor_extra)\n        self.register_buffer(\n            
\"J_regressor_extra\", torch.tensor(J_regressor_extra, dtype=torch.float32)\n        )\n        self.joint_map = torch.tensor(joints, dtype=torch.long)\n\n    def forward(self, *args, **kwargs):\n        kwargs[\"get_skin\"] = True\n        smpl_output = super(SMPL, self).forward(*args, **kwargs)\n        extra_joints = vertices2joints(\n            self.J_regressor_extra, smpl_output.vertices\n        )  # Additional 9 joints #Check doc/J_regressor_extra.png\n        joints = torch.cat(\n            [smpl_output.joints, extra_joints], dim=1\n        )  # [N, 24 + 21, 3]  + [N, 9, 3]\n        joints = joints[:, self.joint_map, :]\n        output = SMPLOutput(\n            vertices=smpl_output.vertices,\n            global_orient=smpl_output.global_orient,\n            body_pose=smpl_output.body_pose,\n            joints=joints,\n            betas=smpl_output.betas,\n            full_pose=smpl_output.full_pose,\n        )\n        return output\n\n\nclass SMPLX(_SMPLX):\n    \"\"\"Extension of the official SMPL implementation to support more joints\"\"\"\n\n    JOINTS = (\n        \"Hips\",\n        \"Left Upper Leg\",\n        \"Right Upper Leg\",\n        \"Spine\",\n        \"Left Leg\",\n        \"Right Leg\",\n        \"Spine1\",\n        \"Left Foot\",\n        \"Right Foot\",\n        \"Thorax\",\n        \"Left Toe\",\n        \"Right Toe\",\n        \"Neck\",\n        \"Left Shoulder\",\n        \"Right Shoulder\",\n        \"Head\",\n        \"Left ForeArm\",\n        \"Right ForeArm\",\n        \"Left Arm\",\n        \"Right Arm\",\n        \"Left Hand\",\n        \"Right Hand\",\n    )\n\n    SKELETON = (\n        (0, 1),\n        (0, 2),\n        (0, 3),\n        (1, 4),\n        (2, 5),\n        (3, 6),\n        (4, 7),\n        (5, 8),\n        (6, 9),\n        (7, 10),\n        (8, 11),\n        (9, 12),\n        (12, 13),\n        (12, 14),\n        (12, 15),\n        (13, 16),\n        (14, 17),\n        (16, 18),\n        (17, 19),\n        (18, 
20),\n        (19, 21),\n    )\n\n    def __init__(self, *args, **kwargs):\n        kwargs[\"ext\"] = \"pkl\"  # We have pkl file\n        super(SMPLX, self).__init__(*args, **kwargs)\n        joints = [JOINT_MAP[i] for i in JOINT_NAMES]\n        self.joint_map = torch.tensor(joints, dtype=torch.long)\n\n    def forward(self, *args, **kwargs):\n        kwargs[\"get_skin\"] = True\n\n        # If body_pose uses the SMPL layout (23 body joints, root excluded),\n        # drop the last two joints to match the 21 body joints of SMPL-X.\n        try:\n            if kwargs[\"body_pose\"].shape[1] == 69:  # flattened: 69 = 23 * 3\n                kwargs[\"body_pose\"] = kwargs[\"body_pose\"][\n                    :, : -2 * 3\n                ]  # Drop the last two joints (on the palms; not used)\n\n            if kwargs[\"body_pose\"].shape[1] == 23:  # one entry per joint\n                kwargs[\"body_pose\"] = kwargs[\"body_pose\"][\n                    :, :-2\n                ]  # Drop the last two joints (on the palms; not used)\n        except (KeyError, AttributeError, IndexError):\n            # body_pose is missing, None, or not 2D+; leave kwargs unchanged\n            pass\n\n        smpl_output = super(SMPLX, self).forward(*args, **kwargs)\n\n        # SMPL-X Joint order: https://docs.google.com/spreadsheets/d/1_1dLdaX-sbMkCKr_JzJW_RZCpwBwd7rcKkWT_VgAQ_0/edit#gid=0\n        smplx_to_smpl = (\n            list(range(0, 22)) + [28, 43] + list(range(55, 76))\n        )  # 28: left middle finger 1, 43: right middle finger 1\n        smpl_joints = smpl_output.joints[\n            :, smplx_to_smpl, :\n        ]  # Convert SMPL-X to SMPL: 127 -> 45 joints\n        joints = smpl_joints\n        joints = joints[:, self.joint_map, :]\n\n        smplx_lhand = (\n            [20] + list(range(25, 40)) + list(range(66, 71))\n        )  # 20 for left wrist. 
20 finger joints\n        lhand_joints = smpl_output.joints[:, smplx_lhand, :]  # (N,21,3)\n        lhand_joints = lhand_joints[\n            :, SMPLX_HAND_TO_PANOPTIC, :\n        ]  # Convert SMPL-X hand order to panoptic hand order\n\n        smplx_rhand = (\n            [21] + list(range(40, 55)) + list(range(71, 76))\n        )  # 21 for right wrist. 20 finger joints\n        rhand_joints = smpl_output.joints[:, smplx_rhand, :]  # (N,21,3)\n        rhand_joints = rhand_joints[\n            :, SMPLX_HAND_TO_PANOPTIC, :\n        ]  # Convert SMPL-X hand order to panoptic hand order\n\n        output = SMPLXOutput(\n            vertices=smpl_output.vertices,\n            global_orient=smpl_output.global_orient,\n            body_pose=smpl_output.body_pose,\n            joints=joints,\n            right_hand_pose=rhand_joints,  # N,21,3\n            left_hand_pose=lhand_joints,  # N,21,3\n            betas=smpl_output.betas,\n            full_pose=smpl_output.full_pose,\n            A=smpl_output.A,\n        )\n        return 
output\n\n\n\"\"\"\n0\tpelvis',\n1\tleft_hip',\n2\tright_hip',\n3\tspine1',\n4\tleft_knee',\n5\tright_knee',\n6\tspine2',\n7\tleft_ankle',\n8\tright_ankle',\n9\tspine3',\n10\tleft_foot',\n11\tright_foot',\n12\tneck',\n13\tleft_collar',\n14\tright_collar',\n15\thead',\n16\tleft_shoulder',\n17\tright_shoulder',\n18\tleft_elbow',\n19\tright_elbow',\n20\tleft_wrist',\n21\tright_wrist',\n22\tjaw',\n23\tleft_eye_smplhf',\n24\tright_eye_smplhf',\n25\tleft_index1',\n26\tleft_index2',\n27\tleft_index3',\n28\tleft_middle1',\n29\tleft_middle2',\n30\tleft_middle3',\n31\tleft_pinky1',\n32\tleft_pinky2',\n33\tleft_pinky3',\n34\tleft_ring1',\n35\tleft_ring2',\n36\tleft_ring3',\n37\tleft_thumb1',\n38\tleft_thumb2',\n39\tleft_thumb3',\n40\tright_index1',\n41\tright_index2',\n42\tright_index3',\n43\tright_middle1',\n44\tright_middle2',\n45\tright_middle3',\n46\tright_pinky1',\n47\tright_pinky2',\n48\tright_pinky3',\n49\tright_ring1',\n50\tright_ring2',\n51\tright_ring3',\n52\tright_thumb1',\n53\tright_thumb2',\n54\tright_thumb3',\n55\tnose',\n56\tright_eye',\n57\tleft_eye',\n58\tright_ear',\n59\tleft_ear',\n60\tleft_big_toe',\n61\tleft_small_toe',\n62\tleft_heel',\n63\tright_big_toe',\n64\tright_small_toe',\n65\tright_heel',\n66\tleft_thumb',\n67\tleft_index',\n68\tleft_middle',\n69\tleft_ring',\n70\tleft_pinky',\n71\tright_thumb',\n72\tright_index',\n73\tright_middle',\n74\tright_ring',\n75\tright_pinky',\n76\tright_eye_brow1',\n77\tright_eye_brow2',\n78\tright_eye_brow3',\n79\tright_eye_brow4',\n80\tright_eye_brow5',\n81\tleft_eye_brow5',\n82\tleft_eye_brow4',\n83\tleft_eye_brow3',\n84\tleft_eye_brow2',\n85\tleft_eye_brow1',\n86\tnose1',\n87\tnose2',\n88\tnose3',\n89\tnose4',\n90\tright_nose_2',\n91\tright_nose_1',\n92\tnose_middle',\n93\tleft_nose_1',\n94\tleft_nose_2',\n95\tright_eye1',\n96\tright_eye2',\n97\tright_eye3',\n98\tright_eye4',\n99\tright_eye5',\n100\tright_eye6',\n101\tleft_eye4',\n102\tleft_eye3',\n103\tleft_eye2',\n104\tleft_eye1',\n105\tleft_eye6',\n106\tleft_eye5
',\n107\tright_mouth_1',\n108\tright_mouth_2',\n109\tright_mouth_3',\n110\tmouth_top',\n111\tleft_mouth_3',\n112\tleft_mouth_2',\n113\tleft_mouth_1',\n114\tleft_mouth_5', # 59 in OpenPose output\n115\tleft_mouth_4', # 58 in OpenPose output\n116\tmouth_bottom',\n117\tright_mouth_4',\n118\tright_mouth_5',\n119\tright_lip_1',\n120\tright_lip_2',\n121\tlip_top',\n122\tleft_lip_2',\n123\tleft_lip_1',\n124\tleft_lip_3',\n125\tlip_bottom',\n126\tright_lip_3',\n127\tright_contour_1',\n128\tright_contour_2',\n129\tright_contour_3',\n130\tright_contour_4',\n131\tright_contour_5',\n132\tright_contour_6',\n133\tright_contour_7',\n134\tright_contour_8',\n135\tcontour_middle',\n136\tleft_contour_8',\n137\tleft_contour_7',\n138\tleft_contour_6',\n139\tleft_contour_5',\n140\tleft_contour_4',\n141\tleft_contour_3',\n142\tleft_contour_2',\n143\tleft_contour_1'\n\"\"\"\n\n\n# SMPL Joints:\n\"\"\"\n0\tpelvis',\n1\tleft_hip',\n2\tright_hip',\n3\tspine1',\n4\tleft_knee',\n5\tright_knee',\n6\tspine2',\n7\tleft_ankle',\n8\tright_ankle',\n9\tspine3',\n10\tleft_foot',\n11\tright_foot',\n12\tneck',\n13\tleft_collar',\n14\tright_collar',\n15\thead',\n16\tleft_shoulder',\n17\tright_shoulder',\n18\tleft_elbow',\n19\tright_elbow',\n20\tleft_wrist',\n21\tright_wrist',\n22\t\n23\t\n24\tnose',\n25\tright_eye',\n26\tleft_eye',\n27\tright_ear',\n28\tleft_ear',\n29\tleft_big_toe',\n30\tleft_small_toe',\n31\tleft_heel',\n32\tright_big_toe',\n33\tright_small_toe',\n34\tright_heel',\n35\tleft_thumb',\n36\tleft_index',\n37\tleft_middle',\n38\tleft_ring',\n39\tleft_pinky',\n40\tright_thumb',\n41\tright_index',\n42\tright_middle',\n43\tright_ring',\n44\tright_pinky',\n\"\"\"\n"
  },
  {
    "path": "src/spin/utils.py",
    "content": "import json\n\nimport cv2\nimport numpy as np\nimport torch\nfrom skimage.transform import resize, rotate\nfrom torchvision.transforms import Normalize\n\nfrom .constants import IMG_NORM_MEAN, IMG_NORM_STD, IMG_RES\n\n\ndef get_transform(center, scale, res, rot=0):\n    \"\"\"Generate transformation matrix.\"\"\"\n    h = 200 * scale\n    t = np.zeros((3, 3))\n    t[0, 0] = float(res[1]) / h\n    t[1, 1] = float(res[0]) / h\n    t[0, 2] = res[1] * (-float(center[0]) / h + 0.5)\n    t[1, 2] = res[0] * (-float(center[1]) / h + 0.5)\n    t[2, 2] = 1\n    if not rot == 0:\n        rot = -rot  # To match direction of rotation from cropping\n        rot_mat = np.zeros((3, 3))\n        rot_rad = rot * np.pi / 180\n        sn, cs = np.sin(rot_rad), np.cos(rot_rad)\n        rot_mat[0, :2] = [cs, -sn]\n        rot_mat[1, :2] = [sn, cs]\n        rot_mat[2, 2] = 1\n        # Need to rotate around center\n        t_mat = np.eye(3)\n        t_mat[0, 2] = -res[1] / 2\n        t_mat[1, 2] = -res[0] / 2\n        t_inv = t_mat.copy()\n        t_inv[:2, 2] *= -1\n        t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))\n\n    return t\n\n\ndef transform(pt, center, scale, res, invert=0, rot=0):\n    \"\"\"Transform pixel location to different reference.\"\"\"\n    t = get_transform(center, scale, res, rot=rot)\n    if invert:\n        t = np.linalg.inv(t)\n    new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.0]).T\n    new_pt = np.dot(t, new_pt)\n\n    return new_pt[:2].astype(int) + 1\n\n\ndef crop(img, center, scale, res, rot=0):\n    \"\"\"Crop image according to the supplied bounding box.\"\"\"\n    # Upper left point\n    ul = np.array(transform([1, 1], center, scale, res, invert=1)) - 1\n    # Bottom right point\n    br = np.array(transform([res[0] + 1, res[1] + 1], center, scale, res, invert=1)) - 1\n\n    # Padding so that when rotated proper amount of context is included\n    pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)\n    if not rot == 
0:\n        ul -= pad\n        br += pad\n\n    new_shape = [br[1] - ul[1], br[0] - ul[0]]\n    if len(img.shape) > 2:\n        new_shape += [img.shape[2]]\n    new_img = np.zeros(new_shape)\n\n    # Range to fill new array\n    new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]\n    new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]\n    # Range to sample from original image\n    old_x = max(0, ul[0]), min(len(img[0]), br[0])\n    old_y = max(0, ul[1]), min(len(img), br[1])\n    new_img[new_y[0] : new_y[1], new_x[0] : new_x[1]] = img[\n        old_y[0] : old_y[1], old_x[0] : old_x[1]\n    ]\n\n    if not rot == 0:\n        # Rotate the padded crop, then remove the padding\n        new_img = rotate(new_img, rot)\n        new_img = new_img[pad:-pad, pad:-pad]\n\n    new_img = resize(new_img, res)\n\n    return new_img\n\n\ndef bbox_from_openpose(openpose_file, rescale=1.2, detection_thresh=0.2):\n    \"\"\"Get center and scale for bounding box from OpenPose detections.\"\"\"\n    with open(openpose_file, \"r\") as f:\n        keypoints = json.load(f)[\"people\"][0][\"pose_keypoints_2d\"]\n    keypoints = np.reshape(np.array(keypoints), (-1, 3))\n    valid = keypoints[:, -1] > detection_thresh\n    valid_keypoints = keypoints[valid][:, :-1]\n    center = valid_keypoints.mean(axis=0)\n    bbox_size = (valid_keypoints.max(axis=0) - valid_keypoints.min(axis=0)).max()\n    # adjust bounding box tightness\n    scale = bbox_size / 200.0\n    scale *= rescale\n\n    return center, scale\n\n\ndef bbox_from_json(bbox_file):\n    \"\"\"Get center and scale of bounding box from bounding box annotations.\n    The expected format is [top_left(x), top_left(y), width, height].\n    \"\"\"\n    with open(bbox_file, \"r\") as f:\n        bbox = np.array(json.load(f)[\"bbox\"]).astype(np.float32)\n    ul_corner = bbox[:2]\n    center = ul_corner + 0.5 * bbox[2:]\n    # use the longer side so that the bounding box is square\n    width = max(bbox[2], bbox[3])\n    scale = width / 200.0\n    return center, scale\n\n\ndef 
process_image(img_file, bbox_file=None, openpose_file=None, input_res=IMG_RES):\n    \"\"\"Read an image, preprocess it, and crop it according to the bounding box.\n    If bounding box annotations are given, use them to crop the image.\n    If no bounding box is specified but OpenPose detections are available, use them to get the bounding box.\n    Otherwise, assume that the person is centered in the image.\n    \"\"\"\n    img_file = str(img_file)\n    normalize_img = Normalize(mean=IMG_NORM_MEAN, std=IMG_NORM_STD)\n    img = cv2.imread(img_file)[\n        :, :, ::-1\n    ].copy()  # BGR -> RGB; copy because PyTorch does not support negative strides\n    if bbox_file is None and openpose_file is None:\n        # Assume that the person is centered in the image\n        height = img.shape[0]\n        width = img.shape[1]\n        center = np.array([width // 2, height // 2])\n        scale = max(height, width) / 200\n    else:\n        if bbox_file is not None:\n            center, scale = bbox_from_json(bbox_file)\n        elif openpose_file is not None:\n            center, scale = bbox_from_openpose(openpose_file)\n\n    img = crop(img, center, scale, (input_res, input_res))\n    img = img.astype(np.float32) / 255.0\n    img = torch.from_numpy(img).permute(2, 0, 1)\n    norm_img = normalize_img(img.clone())\n\n    return img, norm_img\n"
  },
  {
    "path": "src/utils.py",
    "content": "import colorsys\nimport itertools\nimport json\nimport pickle\n\nimport cv2\nimport plotly.graph_objects as go\nimport trimesh\nimport torch\nimport numpy as np\nimport PIL.Image as pil_img\nimport PIL.ImageDraw as ImageDraw\nfrom PIL import Image, ImageChops\nfrom skimage import exposure\n\nimport spin\nimport renderer\nimport pose_estimation\n\n\ndef load_json(path):\n    with open(path) as f:\n        return json.load(f)\n\n\ndef save_json(o, path):\n    with open(path, \"w\") as f:\n        json.dump(o, f)\n\n\ndef load_pkl(path):\n    with open(path, \"rb\") as f:\n        return pickle.load(f)\n\n\ndef save_pkl(o, path):\n    with open(path, \"wb\") as f:\n        pickle.dump(o, f)\n\n\ndef plot_3D(joints, vertices, faces):\n    x, y, z = joints.T\n    x1, y1, z1 = vertices.T\n    i, j, k = faces.T\n\n    data = [\n        go.Mesh3d(\n            x=x1,\n            y=y1,\n            z=z1,\n            i=i,\n            j=j,\n            k=k,\n        ),\n        go.Scatter3d(\n            x=x,\n            y=y,\n            z=z,\n            mode=\"markers\",\n            marker_size=5,\n        ),\n    ]\n\n    fig = go.Figure(\n        data=data,\n    )\n\n    return fig\n\n\ndef draw_keypoints(\n    input_img_kp,\n    keypoints,\n    skeleton,\n    r,\n    color,\n    contact2dlist=None,\n    contact2dlist_color=\"green\",\n    cos=None,\n):\n    if keypoints is not None:\n        draw = ImageDraw.Draw(input_img_kp)\n\n        for skidx, (i, j) in enumerate(skeleton):\n            a = keypoints[i]\n            b = keypoints[j]\n            ln = np.linalg.norm(b - a)\n\n            xy = [a[0], a[1], b[0], b[1]]\n            if cos is not None:\n                c = colorsys.hsv_to_rgb(cos[skidx] ** 8, 1, 1)\n                c = tuple(int(c_ * 255) for c_ in c)\n                draw.line(xy, fill=c, width=r)\n            else:\n                draw.line(xy, fill=color, width=r)\n\n        draw_kpts = [(p[0] - r, p[1] - r, p[0] + r, p[1] + r) 
for p in keypoints]\n        for ellipse in draw_kpts:\n            draw.ellipse(ellipse, fill=\"black\", outline=\"black\")\n\n        if contact2dlist is not None:\n            keypoints_torch = torch.from_numpy(keypoints)\n            for c2d in contact2dlist:\n                for (src_1, dst_1, t_1), (src_2, dst_2, t_2) in itertools.combinations(\n                    c2d, 2\n                ):\n                    a = torch.lerp(\n                        keypoints_torch[src_1], keypoints_torch[dst_1], t_1\n                    ).tolist()\n                    b = torch.lerp(\n                        keypoints_torch[src_2], keypoints_torch[dst_2], t_2\n                    ).tolist()\n\n                    xy = [a[0], a[1], b[0], b[1]]\n                    draw.line(xy, fill=contact2dlist_color, width=max(r // 3, 10))\n\n    return input_img_kp\n\n\ndef save_results_image(\n    camera,\n    focal_length_x,\n    focal_length_y,\n    input_img,\n    vertices,\n    faces,\n    filename,\n    keypoints=None,\n    keypoints_2=None,\n    heatmap=None,\n    cvt_camera=True,\n    contactlist=None,\n    contact2dlist=None,\n    user_study=True,\n    cos=None,\n):\n    if isinstance(contactlist, list) and len(contactlist) > 0:\n        contactlist = np.concatenate(contactlist)\n\n    H, W, _ = input_img.shape\n    HW = max(H, W)\n    camera_center = np.array([W // 2, H // 2])\n    if not cvt_camera:\n        camera_transl = camera.copy()\n    else:\n        camera_transl = np.stack(\n            [\n                camera[1],\n                camera[2],\n                1 / camera[0],\n            ],\n        )\n\n    # draw keypoints\n    input_img_kp = pil_img.fromarray(input_img)\n    if keypoints is not None:\n        draw_keypoints(\n            input_img_kp,\n            keypoints,\n            pose_estimation.SKELETON,\n            r=int(HW * 0.01),\n            color=(255, 0, 0, 255),\n            # contact2dlist=contact2dlist,\n            # 
contact2dlist_color=\"orange\",\n        )\n\n    if cos is not None:\n        input_img_kp_cos = pil_img.fromarray(input_img)\n        if keypoints is not None:\n            draw_keypoints(\n                input_img_kp_cos,\n                keypoints,\n                pose_estimation.SKELETON,\n                r=int(HW * 0.01),\n                color=(255, 0, 0, 255),\n                # contact2dlist=contact2dlist,\n                # contact2dlist_color=\"orange\",\n                cos=cos,\n            )\n\n    input_img_kp_2 = input_img_kp.copy()\n    if keypoints_2 is not None:\n        draw_keypoints(\n            input_img_kp_2,\n            keypoints_2,\n            # spin.SMPLX.SKELETON if \"eft\" in str(filename) else pose_estimation.SKELETON,\n            pose_estimation.SKELETON,\n            r=int(HW * 0.01),\n            color=(0, 0, 255, 255),\n            contact2dlist=contact2dlist,\n            contact2dlist_color=\"purple\",\n        )\n\n    # heatmap = ImageOps.invert(heatmap)\n    if heatmap is not None:\n        # input_img_kp_2 = pil_img.blend(input_img_kp_2, heatmap, 0.5)\n\n        hm = np.copy(input_img)\n        gray_img = exposure.rescale_intensity(heatmap, out_range=(0, 255))\n        gray_img = gray_img.astype(np.uint8)\n        heatmap_img = cv2.applyColorMap(gray_img, cv2.COLORMAP_JET)\n        hm = pil_img.fromarray(cv2.cvtColor(heatmap_img, cv2.COLOR_BGR2RGB))\n        hm.save(filename.with_stem(f\"{filename.stem}_heatmap\"))\n        input_img_kp.save(filename.with_stem(f\"{filename.stem}_2dkps\"))\n        heatmap = pil_img.fromarray(heatmap)\n\n        if cos is not None:\n            input_img_kp_cos.save(filename.with_stem(f\"{filename.stem}_2dkpscos\"))\n\n    # render fitted mesh from different views\n    overlay_fit_img = renderer.overlay_mesh(\n        vertices,\n        faces,\n        camera_transl,\n        focal_length_x,\n        focal_length_y,\n        camera_center,\n        H,\n        W,\n        
input_img.astype(\"float32\") / 255,\n        None,\n        rotaround=None,\n    )\n\n    # overlay_fit_img = pil_img.fromarray(overlay_fit_img)\n    # draw_keypoints(overlay_fit_img, keypoints_2, r=int(HW * 0.01), color=(0, 0, 255, 255))\n\n    # camera_transl[-1] *= 1\n    view1_fit = renderer.overlay_mesh(\n        vertices,\n        faces,\n        camera_transl.astype(np.float32),\n        focal_length_x,\n        focal_length_y,\n        camera_center,\n        H,\n        W,\n        None,\n        None,\n        rotaround=-45,\n        contactlist=contactlist,\n    )\n    view2_fit = renderer.overlay_mesh(\n        vertices,\n        faces,\n        camera_transl.astype(np.float32),\n        focal_length_x,\n        focal_length_y,\n        camera_center,\n        H,\n        W,\n        None,\n        None,\n        rotaround=None,\n        contactlist=contactlist,\n    )\n    view3_fit = renderer.overlay_mesh(\n        vertices,\n        faces,\n        camera_transl.astype(np.float32),\n        focal_length_x,\n        focal_length_y,\n        camera_center,\n        H,\n        W,\n        None,\n        None,\n        rotaround=90,\n        contactlist=contactlist,\n        scale=1,\n    )\n\n    IMG = np.vstack(\n        (\n            np.hstack(\n                (\n                    np.asarray(input_img_kp)\n                    if keypoints is not None\n                    else 255 * np.ones_like(np.asarray(input_img_kp)),\n                    np.asarray(input_img_kp_2),\n                    overlay_fit_img,\n                    # np.asanyarray(overlay_fit_img),\n                ),\n            ),\n            np.hstack(\n                (\n                    view1_fit,\n                    view2_fit,\n                    view3_fit,\n                ),\n            ),\n        ),\n    )\n    IMG = pil_img.fromarray(IMG)\n    IMG.thumbnail((2000, 2000))\n\n    IMG.save(filename)\n\n    if user_study:\n        w = 768\n        
input_img_kp.thumbnail((w, w))\n        input_img_kp.save(filename.with_stem(f\"{filename.stem}_orig\"))\n        W, H = input_img_kp.size\n\n        camera_transl[-1] *= 2\n        view2_fit = renderer.overlay_mesh(\n            vertices,\n            faces,\n            camera_transl.astype(np.float32),\n            focal_length_x,\n            focal_length_y,\n            camera_center,\n            H,\n            W,\n            None,\n            None,\n            rotaround=None,\n            contactlist=contactlist,\n            scale=2,\n        )\n        view2_fit = pil_img.fromarray(view2_fit)\n        w *= 2\n        view2_fit.thumbnail((w, w))\n        view2_fit.save(filename.with_stem(f\"{filename.stem}_same\"))\n        view3_fit = renderer.overlay_mesh(\n            vertices,\n            faces,\n            camera_transl.astype(np.float32),\n            focal_length_x,\n            focal_length_y,\n            camera_center,\n            H,\n            W,\n            None,\n            None,\n            rotaround=90,\n            contactlist=contactlist,\n            scale=2,\n        )\n        view3_fit = pil_img.fromarray(view3_fit)\n        view3_fit.thumbnail((w, w))\n        view3_fit.save(filename.with_stem(f\"{filename.stem}_alt\"))\n\n    return IMG\n\n\ndef save_3d_model_on_img(\n    camera,\n    vertices,\n    faces,\n    img,\n    filename,\n    save_path,\n):\n    img_res = max(img.shape[:2])\n    r = renderer.Renderer(\n        focal_length=spin.constants.FOCAL_LENGTH,\n        img_res=img_res,\n        faces=faces,\n    )\n\n    # Calculate camera parameters for rendering\n    camera_translation = np.stack(\n        [\n            camera[1],\n            camera[2],\n            2 * spin.constants.FOCAL_LENGTH / (img_res * camera[0] + 1e-9),\n        ],\n    )\n    # Render parametric shape\n    img_shape = r(vertices, camera_translation, img)\n    img_shape = (255 * img_shape).astype(\"uint8\")\n    img_shape = 
cv2.cvtColor(img_shape, cv2.COLOR_RGB2BGR)\n    cv2.imwrite(str(save_path / f\"shape_{filename}.png\"), img_shape)\n\n    # Render side views\n    aroundy = cv2.Rodrigues(np.array([0, np.radians(90.0), 0]))[0]\n    center = vertices.mean(axis=0)\n    rot_vertices = np.dot((vertices - center), aroundy) + center\n\n    # Render non-parametric shape\n    img_shape = r(rot_vertices, camera_translation, np.ones_like(img))\n    img_shape = (255 * img_shape).astype(\"uint8\")\n    img_shape = cv2.cvtColor(img_shape, cv2.COLOR_RGB2BGR)\n    cv2.imwrite(str(save_path / f\"shape_rot_{filename}.png\"), img_shape)\n\n\ndef save_mesh_with_colors(vertices, faces, save_path, mask=None, inds=None):\n    if inds is not None and isinstance(inds, list):\n        inds = np.concatenate(inds)\n    mesh = trimesh.Trimesh(\n        vertices=vertices,\n        faces=faces,\n        process=False,\n    )\n    color = np.array(mesh.visual.vertex_colors)\n    color[:] = [233, 233, 233, 255]\n    if mask is not None and any(mask):\n        color[~mask] = [255, 0, 0, 255]\n    elif inds is not None and len(inds) > 0:\n        color[inds] = [255, 0, 0, 255]\n    mesh.visual.vertex_colors = color\n    mesh.export(save_path)\n\n\ndef save_pose_params(rotmat, camera, betas, vertices, smpl, contact, save_path):\n    if contact is not None and isinstance(contact, list) and len(contact) > 0:\n        contact = np.concatenate(contact)\n\n    rotmat = rotmat.detach()\n    camera = camera.detach()\n    if smpl.name() == \"SMPL-X\":\n        rotmat = rotmat[: -2 * 3]\n\n    res = {\n        \"camera_s_t\": camera.unsqueeze(0).cpu().numpy(),\n        \"global_orient\": rotmat[:3].unsqueeze(0).cpu().numpy(),\n        \"betas\": betas,\n        \"body_pose\": rotmat[3:].unsqueeze(0).cpu().numpy(),\n        \"left_hand_pose\": smpl.left_hand_pose.unsqueeze(0).detach().cpu().numpy(),\n        \"right_hand_pose\": smpl.right_hand_pose.unsqueeze(0).detach().cpu().numpy(),\n        \"model\": 
smpl.name().replace(\"-\", \"\"),\n        \"gender\": smpl.gender,\n        \"vertices\": vertices[0].cpu().numpy(),\n    }\n\n    if contact is not None and len(contact) > 0:\n        contact = np.array(contact)\n        res[\"contact\"] = contact\n    else:\n        res[\"v\"] = vertices[0].cpu().numpy()\n\n    save_pkl(res, save_path)\n\n    np.savez(save_path, **res)\n"
  }
]