Repository: kbrodt/sketch2pose
Branch: main
Commit: c027b65219b1
Files: 20
Total size: 145.0 KB
Directory structure:
gitextract__nwva2u6/
├── README.md
├── patches/
│ ├── selfcontact.diff
│ ├── smplx.diff
│ └── torchgeometry.diff
├── requirements.txt
├── scripts/
│ ├── download.sh
│ ├── prepare.sh
│ └── run.sh
└── src/
├── fist_pose.py
├── hist_cub.py
├── losses.py
├── pose.py
├── pose_estimation.py
├── renderer.py
├── spin/
│ ├── __init__.py
│ ├── constants.py
│ ├── hmr.py
│ ├── smpl.py
│ └── utils.py
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: README.md
================================================
# Sketch2Pose: Estimating a 3D Character Pose from a Bitmap Sketch
Artists frequently capture character poses via raster sketches, then use these
drawings as a reference while posing a 3D character in a specialized 3D
software --- a time-consuming process, requiring specialized 3D training and
mental effort. We tackle this challenge by proposing the first system for
automatically inferring a 3D character pose from a single bitmap sketch,
producing poses consistent with viewer expectations. Algorithmically
interpreting bitmap sketches is challenging, as they contain significantly
distorted proportions and foreshortening. We address this by predicting three
key elements of a drawing, necessary to disambiguate the drawn poses: 2D bone
tangents, self-contacts, and bone foreshortening. These elements are then
leveraged in an optimization inferring the 3D character pose consistent with
the artist's intent. Our optimization balances cues derived from artistic
literature and perception research to compensate for distorted character
proportions. We demonstrate a gallery of results on sketches of numerous
styles. We validate our method via numerical evaluations, user studies, and
comparisons to manually posed characters and previous work.
[Project Page](http://www-labs.iro.umontreal.ca/~bmpix/sketch2pose/)
# Prerequisites
- [GNU/Linux](https://www.gnu.org/gnu/linux-and-gnu.en.html)
- [`python`](https://python.org)
- [`pytorch`](https://pytorch.org/)
- NVIDIA GPU (optional, but highly recommended)
## Download body model (SMPL-X)
Download SMPL-X body model from
[https://smpl-x.is.tue.mpg.de](https://smpl-x.is.tue.mpg.de)
See [`download.sh`](./scripts/download.sh) and run
```bash
sh ./scripts/download.sh
```
## Virtual environment
Change [`pytorch`](https://pytorch.org/) version if needed in
[`prepare.sh`](./scripts/prepare.sh) and run
```bash
sh ./scripts/prepare.sh
```
# Demo
Activate virtual environment `. venv/bin/activate` and run
```bash
sh ./scripts/run.sh
# or
python src/pose.py \
--save-path "${out_dir}" \
--img-path "${img_path}" \
--use-contacts \
--use-natural \
--use-cos \
--use-angle-transf \
# or without contacts
python src/pose.py \
--save-path "${out_dir}" \
--img-path "${img_path}" \
--use-natural \
--use-cos \
--use-angle-transf \
```
# Citation
```
@article{brodt2022sketch2pose,
author = {Kirill Brodt and Mikhail Bessmeltsev},
title = {Sketch2Pose: Estimating a 3D Character Pose from a Bitmap Sketch},
journal = {ACM Transactions on Graphics},
year = {2022},
month = {7},
volume = {41},
number = {4},
doi = {10.1145/3528223.3530106},
}
```
# Useful links
- [Deep High-Resolution Representation Learning for Human Pose Estimation](https://github.com/leoxiaobin/deep-high-resolution-net.pytorch/)
- [SMPLify-X](https://github.com/vchoutas/smplify-x) ([project](https://smpl-x.is.tue.mpg.de/))
- [SPIN](https://github.com/nkolot/SPIN) ([project](https://www.seas.upenn.edu/~nkolot/projects/spin/))
- [eft](https://github.com/facebookresearch/eft)
- [SMPLify-XMC](https://github.com/muelea/smplify-xmc), [selfcontact](https://github.com/muelea/selfcontact) ([project](https://tuch.is.tue.mpg.de/))
- [Mixamo](https://www.mixamo.com) models with animations and a
[script](https://forums.unrealengine.com/community/community-content-tools-and-tutorials/1376068-script-mixamo-download-script)
to download them
- Quaternion-based [Forward
Kinematics](https://github.com/facebookresearch/QuaterNet)
================================================
FILE: patches/selfcontact.diff
================================================
+++ venv/lib/python3.10/site-packages/selfcontact/body_segmentation.py
@@ -14,6 +14,8 @@
#
# Contact: ps-license@tuebingen.mpg.de
+from pathlib import Path
+
import torch
import trimesh
import torch.nn as nn
@@ -22,6 +24,17 @@
from .utils.mesh import winding_numbers
+
+def load_pkl(path):
+ with open(path, "rb") as fin:
+ return pickle.load(fin)
+
+
+def save_pkl(obj, path):
+ with open(path, "wb") as fout:
+ pickle.dump(obj, fout)
+
+
class BodySegment(nn.Module):
def __init__(self,
name,
@@ -63,9 +76,17 @@
self.register_buffer('segment_faces', segment_faces)
# create vector to select vertices form faces
- tri_vidx = []
- for ii in range(faces.max().item()+1):
- tri_vidx += [torch.nonzero(faces==ii)[0].tolist()]
+ segments_folder = Path(segments_folder)
+ tri_vidx_path = segments_folder / "tri_vidx.pkl"
+ if not tri_vidx_path.is_file():
+ tri_vidx = []
+ for ii in range(faces.max().item()+1):
+ tri_vidx += [torch.nonzero(faces==ii)[0].tolist()]
+
+ save_pkl(tri_vidx, tri_vidx_path)
+ else:
+ tri_vidx = load_pkl(tri_vidx_path)
+
self.register_buffer('tri_vidx', torch.tensor(tri_vidx))
def create_band_faces(self):
@@ -149,7 +170,7 @@
self.segmentation = {}
for idx, name in enumerate(names):
self.segmentation[name] = BodySegment(name, faces, segments_folder,
- model_type).to('cuda')
+ model_type).to(device)
def batch_has_self_isec_verts(self, vertices):
"""
+++ venv/lib/python3.10/site-packages/selfcontact/selfcontact.py
@@ -41,6 +41,7 @@
test_segments=True,
compute_hd=False,
buffer_geodists=False,
+ device="cuda",
):
super().__init__()
@@ -95,7 +96,7 @@
if self.test_segments:
sxseg = pickle.load(open(segments_bounds_path, 'rb'))
self.segments = BatchBodySegment(
- [x for x in sxseg.keys()], faces, segments_folder, self.model_type
+ [x for x in sxseg.keys()], faces, segments_folder, self.model_type, device=device,
)
# load regressor to get high density mesh
@@ -106,7 +107,7 @@
torch.tensor(hd_operator['values']),
torch.Size(hd_operator['size']))
self.register_buffer('hd_operator',
- torch.tensor(hd_operator).float())
+ hd_operator.clone().detach().float())
with open(point_vert_corres_path, 'rb') as f:
hd_geovec = pickle.load(f)['faces_vert_is_sampled_from']
@@ -135,9 +136,13 @@
# split because of memory into two chunks
exterior = torch.zeros((bs, nv), device=vertices.device,
dtype=torch.bool)
- exterior[:, :5000] = winding_numbers(vertices[:,:5000,:],
+ exterior[:, :3000] = winding_numbers(vertices[:,:3000,:],
triangles).le(0.99)
- exterior[:, 5000:] = winding_numbers(vertices[:,5000:,:],
+ exterior[:, 3000:6000] = winding_numbers(vertices[:,3000:6000,:],
+ triangles).le(0.99)
+ exterior[:, 6000:9000] = winding_numbers(vertices[:,6000:9000,:],
+ triangles).le(0.99)
+ exterior[:, 9000:] = winding_numbers(vertices[:,9000:,:],
triangles).le(0.99)
# check if intersections happen within segments
@@ -173,9 +178,13 @@
# split because of memory into two chunks
exterior = torch.zeros((bs, np), device=points.device,
dtype=torch.bool)
- exterior[:, :6000] = winding_numbers(points[:,:6000,:],
+ exterior[:, :3000] = winding_numbers(points[:,:3000,:],
+ triangles).le(0.99)
+ exterior[:, 3000:6000] = winding_numbers(points[:,3000:6000,:],
triangles).le(0.99)
- exterior[:, 6000:] = winding_numbers(points[:,6000:,:],
+ exterior[:, 6000:9000] = winding_numbers(points[:,6000:9000,:],
+ triangles).le(0.99)
+ exterior[:, 9000:] = winding_numbers(points[:,9000:,:],
triangles).le(0.99)
return exterior
@@ -371,6 +380,23 @@
return hd_v2v_mins, hd_exteriors, hd_points, hd_faces_in_contacts
+ def verts_in_contact(self, vertices, return_idx=False):
+
+ # get pairwise distances of vertices
+ v2v = self.get_pairwise_dists(vertices, vertices, squared=True)
+
+ # mask v2v with eucledean and geodesic dsitance
+ euclmask = v2v < self.euclthres**2
+ mask = euclmask * self.geomask
+
+ # find closes vertex in contact
+ in_contact = mask.sum(1) > 0
+
+ if return_idx:
+ in_contact = torch.where(in_contact)
+
+ return in_contact
+
class SelfContactSmall(nn.Module):
+++ venv/lib/python3.10/site-packages/selfcontact/utils/mesh.py
@@ -82,7 +82,7 @@
if valid_vals > 0:
loss = (mask * dists).sum() / valid_vals
else:
- loss = torch.Tensor([0]).cuda()
+ loss = mask.new_tensor([0])
return loss
def batch_index_select(inp, dim, index):
@@ -103,6 +103,7 @@
xx = torch.bmm(x, x.transpose(2, 1))
yy = torch.bmm(y, y.transpose(2, 1))
zz = torch.bmm(x, y.transpose(2, 1))
+ use_cuda = x.device.type == "cuda"
if use_cuda:
dtype = torch.cuda.LongTensor
else:
================================================
FILE: patches/smplx.diff
================================================
+++ venv/lib/python3.10/site-packages/smplx/body_models.py
@@ -366,7 +366,7 @@
num_repeats = int(batch_size / betas.shape[0])
betas = betas.expand(num_repeats, -1)
- vertices, joints = lbs(betas, full_pose, self.v_template,
+ vertices, joints, _ = lbs(betas, full_pose, self.v_template,
self.shapedirs, self.posedirs,
self.J_regressor, self.parents,
self.lbs_weights, pose2rot=pose2rot)
@@ -1228,7 +1228,7 @@
shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1)
- vertices, joints = lbs(shape_components, full_pose, self.v_template,
+ vertices, joints, A = lbs(shape_components, full_pose, self.v_template,
shapedirs, self.posedirs,
self.J_regressor, self.parents,
self.lbs_weights, pose2rot=pose2rot,
@@ -1283,7 +1283,9 @@
right_hand_pose=right_hand_pose,
jaw_pose=jaw_pose,
v_shaped=v_shaped,
- full_pose=full_pose if return_full_pose else None)
+ full_pose=full_pose if return_full_pose else None,
+ A=A,
+ )
return output
+++ venv/lib/python3.10/site-packages/smplx/lbs.py
@@ -245,7 +245,7 @@
verts = v_homo[:, :, :3, 0]
- return verts, J_transformed
+ return verts, J_transformed, (A, J)
def vertices2joints(J_regressor: Tensor, vertices: Tensor) -> Tensor:
+++ venv/lib/python3.10/site-packages/smplx/utils.py
@@ -71,6 +71,7 @@
class SMPLXOutput(SMPLHOutput):
expression: Optional[Tensor] = None
jaw_pose: Optional[Tensor] = None
+ A: Optional[Tensor] = None
@dataclass
================================================
FILE: patches/torchgeometry.diff
================================================
+++ venv/lib/python3.10/site-packages/torchgeometry/core/conversions.py
@@ -298,6 +298,9 @@
rmat_t[:, 0, 1] - rmat_t[:, 1, 0]], -1)
t3_rep = t3.repeat(4, 1).t()
+ mask_d2 = mask_d2.float()
+ mask_d0_d1 = mask_d0_d1.float()
+ mask_d0_nd1 = mask_d0_nd1.float()
mask_c0 = mask_d2 * mask_d0_d1
mask_c1 = mask_d2 * (1 - mask_d0_d1)
mask_c2 = (1 - mask_d2) * mask_d0_nd1
================================================
FILE: requirements.txt
================================================
matplotlib>=3.5.1
numpy>=1.22.3
opencv_python>=4.5.5.64
Pillow>=9.1.0
plotly>=5.7.0
pyrender>=0.1.45
scikit_image>=0.19.2
scipy>=1.8.0
Shapely>=1.8.1.post1
scikit-image>=0.19.2
tensorboard>=2.8.0
torchgeometry>=0.1.2
tqdm>=4.64.0
trimesh>=3.10.8
git+https://github.com/muelea/selfcontact.git@08da422526419c24736c0616bca49623e442c26a
git+https://github.com/vchoutas/smplx.git@5fa20519735cceda19afed0beeabd53caef711cd
================================================
FILE: scripts/download.sh
================================================
#!/usr/bin/env bash
# Download and unpack all assets required by sketch2pose.
# NOTE: arrays and `set -o pipefail` are bash features, so the shebang
# must be bash — the previous `sh` shebang failed under dash.
set -euo pipefail

asset_dir="./assets"

# The SMPL-X body model requires registration; it must be fetched manually.
if [ ! -e "${asset_dir}"/models_smplx_v1_1.zip ]; then
    echo Error: Download SMPL-X body model from https://smpl-x.is.tue.mpg.de \
        and save zip archive to "${asset_dir}"
    exit 1
fi

asset_urls=(
    # constants (SPIN)
    http://visiondata.cis.upenn.edu/spin/data.tar.gz
    # essentials (SMPLify-XMC)
    https://download.is.tue.mpg.de/tuch/smplify-xmc-essentials.zip
    # sketch2pose models
    http://www-labs.iro.umontreal.ca/~bmpix/sketch2pose/models.zip
    # test images
    http://www-labs.iro.umontreal.ca/~bmpix/sketch2pose/images.zip
)
for asset_url in "${asset_urls[@]}"; do
    # -nc/-c: skip finished files, resume partial ones
    wget \
        -nc \
        -c \
        --directory-prefix "${asset_dir}" \
        "${asset_url}"
done

models_dir="./models"
mkdir -p "${models_dir}"

model_files=(
    # SMPL-X body models
    models_smplx_v1_1.zip
    # essentials (SMPLify-XMC)
    smplify-xmc-essentials.zip
    # sketch2pose models
    models.zip
)
for model_file in "${model_files[@]}"; do
    unzip \
        -u \
        -d "${models_dir}" \
        "${asset_dir}"/"${model_file}"
done

# Unpack constants (SPIN); only the mean-params file is needed.
tar \
    --skip-old-files \
    -xvf "${asset_dir}"/data.tar.gz \
    -C "${models_dir}" \
    data/smpl_mean_params.npz

data_dir="./data"
mkdir -p "${data_dir}"

# Unpack test images
unzip \
    -u \
    -d "${data_dir}" \
    "${asset_dir}"/images.zip
================================================
FILE: scripts/prepare.sh
================================================
#!/usr/bin/env bash
# Create the project virtualenv, install dependencies, and patch the
# installed third-party packages.
# NOTE: process substitution `< <(...)` and `set -o pipefail` are bash
# features, so the shebang must be bash, not POSIX sh.
set -euo pipefail

venv_dir=venv

python -m venv --clear "${venv_dir}"
. "${venv_dir}"/bin/activate

pip install -U pip setuptools

# Install the CUDA build of torch only when a CUDA toolchain is present.
extra="cpu"
[ -x "$(command -v nvcc)" ] && extra="cu113"

pip install \
    torch \
    torchvision \
    --extra-index-url https://download.pytorch.org/whl/"${extra}"

pip install -r requirements.txt

# The bundled patches hard-code python3.10 site-packages paths; rewrite
# them on the fly for the interpreter version actually in the venv.
v=$(python -c 'import sys; v = sys.version_info; print(f"{v.major}.{v.minor}")')
for p in patches/*.diff; do
    patch -p0 < <(sed "s/python3.10/python${v}/" "${p}")
done
================================================
FILE: scripts/run.sh
================================================
#!/usr/bin/env bash
# Run the sketch2pose demo on every image in ${img_dir}.
# NOTE: `set -o pipefail` is a bash feature, hence the bash shebang.
# BUGFIX: the previous version left a trailing backslash before `exit`,
# which passed the literal word "exit" as an argument to pose.py; later
# trailing backslashes also merged consecutive pipelines into one command.
set -euo pipefail

img_dir="./data/images"
out_dir="./output"

# Full model: contacts + natural + cos + angle transform.
find "${img_dir}" -mindepth 1 -maxdepth 1 -type f -print0 \
    | xargs -0 -I "{}" python src/pose.py \
        --save-path "${out_dir}" \
        --img-path "{}" \
        --use-contacts \
        --use-natural \
        --use-cos \
        --use-angle-transf

# The commands below reproduce the baseline and the ablations.
# Remove this `exit` to run them.
exit

# baseline (SMPLify-XMC)
find "${img_dir}" -mindepth 1 -maxdepth 1 -type f -print0 \
    | xargs -0 -I "{}" python src/pose.py \
        --save-path "${out_dir}_baseline" \
        --img-path "{}" \
        --c-mse 1 \
        --c-par 0 \
        --use-contacts \
        --use-cos \
        --use-angle-transf

# ablation: without the cos/angle transform
find "${img_dir}" -mindepth 1 -maxdepth 1 -type f -print0 \
    | xargs -0 -I "{}" python src/pose.py \
        --save-path "${out_dir}_wocostransform" \
        --img-path "{}" \
        --use-contacts \
        --use-natural \
        --use-cos

# ablation: without contacts
find "${img_dir}" -mindepth 1 -maxdepth 1 -type f -print0 \
    | xargs -0 -I "{}" python src/pose.py \
        --save-path "${out_dir}_wocontacts" \
        --img-path "{}" \
        --use-msc \
        --use-natural \
        --use-cos \
        --use-angle-transf

# ablation: without natural
find "${img_dir}" -mindepth 1 -maxdepth 1 -type f -print0 \
    | xargs -0 -I "{}" python src/pose.py \
        --save-path "${out_dir}_wonatural" \
        --img-path "{}" \
        --use-contacts \
        --use-cos \
        --use-angle-transf
================================================
FILE: src/fist_pose.py
================================================
# Axis-angle pose vectors laid out as 25 body joints + 15 left-hand
# joints + 15 right-hand joints, three components each (165 floats).
# The long runs of hand-written zero triples are replaced by a single
# `[0.0] * (25 * 3)` prefix — values are unchanged.

# Left hand clenched into a fist (plus the accompanying right-hand values).
left_fist = [0.0] * (25 * 3) + [
    0.4183167815208435, 0.10645648092031479, -1.6593892574310303,
    0.15252035856246948, -0.14700782299041748, -1.3719955682754517,
    -0.04432843625545502, -0.15799851715564728, -0.938068151473999,
    -0.12218914180994034, 0.073341965675354, -1.6415189504623413,
    -0.14376045763492584, 0.1927780956029892, -1.3593589067459106,
    -0.0851994976401329, 0.01652289740741253, -0.7474589347839355,
    -0.9881719946861267, -0.3987707793712616, -1.3535722494125366,
    -0.6686224937438965, 0.1261960119009018, -1.080643892288208,
    -0.8101894855499268, -0.1306752860546112, -0.8412265777587891,
    -0.3495230972766876, -0.17784251272678375, -1.4433038234710693,
    -0.46278536319732666, 0.13677796721458435, -1.467200517654419,
    -0.3681888282299042, 0.003404417773708701, -0.7764251232147217,
    0.850964367389679, 0.2769227623939514, -0.09154807031154633,
    0.14500413835048676, 0.09604815393686295, 0.219278022646904,
    1.0451993942260742, 0.16911321878433228, -0.2426234930753708,
    0.11167845129966736, -0.04289207234978676, 0.41644084453582764,
    0.10881128907203674, 0.06598565727472305, 0.756219744682312,
    -0.0963931530714035, 0.09091583639383316, 0.18845966458320618,
    -0.11809506267309189, -0.050943851470947266, 0.5295845866203308,
    -0.14369848370552063, -0.055241718888282776, 0.704857349395752,
    -0.019182899966835976, 0.0923367589712143, 0.3379131853580475,
    -0.45703303813934326, 0.1962839663028717, 0.6254575848579407,
    -0.21465237438678741, 0.06599827855825424, 0.5068942308425903,
    -0.36972442269325256, 0.0603446289896965, 0.07949023693799973,
    -0.14186954498291016, 0.08585254102945328, 0.6355276107788086,
    -0.3033415675163269, 0.05788097903132439, 0.6313892006874084,
    -0.17612087726593018, 0.13209305703639984, 0.3733545243740082,
    0.850964367389679, -0.2769227623939514, 0.09154807031154633,
    -0.4998386800289154, -0.026556432247161865, -0.052880801260471344,
    0.5355585217475891, -0.045960985124111176, 0.27735769748687744,
]

# Both hands clenched: the left-hand triples match `left_fist`, the
# right-hand triples hold the mirrored fist values.
left_right_fist = [0.0] * (25 * 3) + [
    0.4183167815208435, 0.10645648092031479, -1.6593892574310303,
    0.15252035856246948, -0.14700782299041748, -1.3719955682754517,
    -0.04432843625545502, -0.15799851715564728, -0.938068151473999,
    -0.12218914180994034, 0.073341965675354, -1.6415189504623413,
    -0.14376045763492584, 0.1927780956029892, -1.3593589067459106,
    -0.0851994976401329, 0.01652289740741253, -0.7474589347839355,
    -0.9881719946861267, -0.3987707793712616, -1.3535722494125366,
    -0.6686224937438965, 0.1261960119009018, -1.080643892288208,
    -0.8101894855499268, -0.1306752860546112, -0.8412265777587891,
    -0.3495230972766876, -0.17784251272678375, -1.4433038234710693,
    -0.46278536319732666, 0.13677796721458435, -1.467200517654419,
    -0.3681888282299042, 0.003404417773708701, -0.7764251232147217,
    0.850964367389679, 0.2769227623939514, -0.09154807031154633,
    0.14500413835048676, 0.09604815393686295, 0.219278022646904,
    1.0451993942260742, 0.16911321878433228, -0.2426234930753708,
    0.4183167815208435, -0.10645647346973419, 1.6593892574310303,
    0.15252038836479187, 0.14700786769390106, 1.3719956874847412,
    -0.04432841017842293, 0.15799842774868011, 0.9380677938461304,
    -0.12218913435935974, -0.0733419880270958, 1.6415191888809204,
    -0.14376048743724823, -0.19277812540531158, 1.3593589067459106,
    -0.08519953489303589, -0.016522908583283424, 0.7474592328071594,
    -0.9881719350814819, 0.3987707495689392, 1.3535723686218262,
    -0.6686226725578308, -0.12619605660438538, 1.080644130706787,
    -0.8101896643638611, 0.1306752860546112, 0.8412266373634338,
    -0.34952324628829956, 0.17784248292446136, 1.443304181098938,
    -0.46278542280197144, -0.13677802681922913, 1.467200517654419,
    -0.36818885803222656, -0.0034044249914586544, 0.7764251232147217,
    0.8509642481803894, -0.2769228219985962, 0.09154807776212692,
    0.14500458538532257, -0.09604845196008682, -0.21927869319915771,
    1.0451991558074951, -0.1691131889820099, 0.242623433470726,
]

# Right-fist-only pose: keep components where the two poses above
# disagree (the right-hand fist values) and zero everything else.
right_fist = [
    lrf if lf != lrf else 0
    for lf, lrf in zip(left_fist, left_right_fist)
]
# Flat left hand bent up: every component is zero except the z rotation
# of joint 20 (presumably the left wrist, given the `LEFT_FLAT_UP` slice
# taken below — TODO confirm against the SMPL-X joint table).
left_flat_up = [0.0] * (55 * 3)
left_flat_up[20 * 3 + 2] = 1.5129635334014893
# Flat left hand bent down: same layout as `left_flat_up`, only the
# z rotation of joint 20 is non-zero (negative bend direction).
left_flat_down = [0.0] * (55 * 3)
left_flat_down[20 * 3 + 2] = -1.4648663997650146
# Flat right hand bent up: only the z rotation of joint 21 (presumably
# the right wrist — see the `RIGHT_FLAT_UP` slice below) is non-zero.
right_flat_up = [0.0] * (55 * 3)
right_flat_up[21 * 3 + 2] = -1.5021973848342896
# Flat right hand bent down: only the z rotation of joint 21 is
# non-zero (opposite sign to `right_flat_up`).
right_flat_down = [0.0] * (55 * 3)
right_flat_down[21 * 3 + 2] = 1.494218111038208
# Relaxed hand pose: zero body joints followed by gently curled finger
# rotations; the right-hand triples mirror the left-hand ones with
# flipped y/z signs.
relaxed = [0.0] * (25 * 3) + [
    0.11167845129966736, 0.04289207234978676, -0.41644084453582764,
    0.10881128907203674, -0.06598565727472305, -0.756219744682312,
    -0.0963931530714035, -0.09091583639383316, -0.18845966458320618,
    -0.11809506267309189, 0.050943851470947266, -0.5295845866203308,
    -0.14369848370552063, 0.055241718888282776, -0.704857349395752,
    -0.019182899966835976, -0.0923367589712143, -0.3379131853580475,
    -0.45703303813934326, -0.1962839663028717, -0.6254575848579407,
    -0.21465237438678741, -0.06599827855825424, -0.5068942308425903,
    -0.36972442269325256, -0.0603446289896965, -0.07949023693799973,
    -0.14186954498291016, -0.08585254102945328, -0.6355276107788086,
    -0.3033415675163269, -0.05788097903132439, -0.6313892006874084,
    -0.17612087726593018, -0.13209305703639984, -0.3733545243740082,
    0.850964367389679, 0.2769227623939514, -0.09154807031154633,
    -0.4998386800289154, 0.026556432247161865, 0.052880801260471344,
    0.5355585217475891, 0.045960985124111176, -0.27735769748687744,
    0.11167845129966736, -0.04289207234978676, 0.41644084453582764,
    0.10881128907203674, 0.06598565727472305, 0.756219744682312,
    -0.0963931530714035, 0.09091583639383316, 0.18845966458320618,
    -0.11809506267309189, -0.050943851470947266, 0.5295845866203308,
    -0.14369848370552063, -0.055241718888282776, 0.704857349395752,
    -0.019182899966835976, 0.0923367589712143, 0.3379131853580475,
    -0.45703303813934326, 0.1962839663028717, 0.6254575848579407,
    -0.21465237438678741, 0.06599827855825424, 0.5068942308425903,
    -0.36972442269325256, 0.0603446289896965, 0.07949023693799973,
    -0.14186954498291016, 0.08585254102945328, 0.6355276107788086,
    -0.3033415675163269, 0.05788097903132439, 0.6313892006874084,
    -0.17612087726593018, 0.13209305703639984, 0.3733545243740082,
    0.850964367389679, -0.2769227623939514, 0.09154807031154633,
    -0.4998386800289154, -0.026556432247161865, -0.052880801260471344,
    0.5355585217475891, -0.045960985124111176, 0.27735769748687744,
]
# Pose vectors concatenate 25 body joints, then 15 left-hand and
# 15 right-hand joints, three axis-angle components per joint.
# smpl(left_hand_pose, right_hand_pose)


def _joint_slice(pose, start_joint, n_joints=1):
    # Components of `n_joints` consecutive joints starting at `start_joint`.
    return pose[3 * start_joint : 3 * (start_joint + n_joints)]


left_start = 25 * 3
left_end = left_start + 15 * 3
right_end = left_end + 15 * 3

# Hand sub-poses: joints 25..39 are the left hand, 40..54 the right hand.
LEFT_FIST = _joint_slice(left_fist, 25, 15)
RIGHT_FIST = _joint_slice(right_fist, 40, 15)
LEFT_RELAXED = _joint_slice(relaxed, 25, 15)
RIGHT_RELAXED = _joint_slice(relaxed, 40, 15)

# Wrist-only poses: joint 20 is the left wrist, joint 21 the right.
LEFT_FLAT_UP = _joint_slice(left_flat_up, 20)
LEFT_FLAT_DOWN = _joint_slice(left_flat_down, 20)
RIGHT_FLAT_UP = _joint_slice(right_flat_up, 21)
RIGHT_FLAT_DOWN = _joint_slice(right_flat_down, 21)

# Short-code lookup used when selecting a hand pose; `None` = leave flexed.
INT_TO_FIST = {
    "lfl": None,
    "lf": LEFT_FIST,
    "lu": LEFT_FLAT_UP,
    "ld": LEFT_FLAT_DOWN,
    "rfl": None,
    "rf": RIGHT_FIST,
    "ru": RIGHT_FLAT_UP,
    "rd": RIGHT_FLAT_DOWN,
}
================================================
FILE: src/hist_cub.py
================================================
import itertools
import functools
import math
import multiprocessing
from pathlib import Path
import matplotlib
# Global matplotlib configuration for the paper figures: large fonts and
# LaTeX text rendering with the Libertine/Biolinum packages (usetex=True
# presumably requires a local TeX installation — confirm before running).
matplotlib.rcParams.update({'font.size': 24})
matplotlib.rcParams.update({
    "text.usetex": True,
    "text.latex.preamble": r"\usepackage{biolinum} \usepackage{libertineRoman} \usepackage{libertineMono} \usepackage{biolinum} \usepackage[libertine]{newtxmath}",
    'ps.usedistiller': "xpdf",
})
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import tqdm
from scipy.stats import wasserstein_distance
import pose_estimation
def cub(x, a, b, c):
    """Evaluate the cubic a*x**3 + b*x**2 + c*x (no constant term)."""
    squared = x * x
    cubed = squared * x
    return a * cubed + b * squared + c * x
def subsample(a, p=0.0005, seed=0):
    """Return a random fraction ``p`` of ``a`` (sampled with replacement).

    Seeds the global numpy RNG, so the result is deterministic per seed.
    """
    np.random.seed(seed)
    n_total = len(a)
    picked = np.random.choice(range(n_total), size=int(p * n_total))
    return a[picked].copy()
def read_cos_opt(path, fname="cos_hist.npy"):
    """Load and stack every ``fname`` array found under ``path``.

    Files are loaded in sorted path order so the stacking order is
    deterministic — ``rglob`` alone yields filesystem-dependent order.
    """
    cos_opt = [np.load(p) for p in sorted(Path(path).rglob(fname))]
    return np.array(cos_opt)
def plot_hist(cos_opt_dir, hist_smpl_fpath, params, out_dir, bins=10, xy=None):
    """Render per-bone angle histograms before/after the cubic remapping.

    cos_opt_dir: directory scanned recursively for ``cos_hist.npy`` dumps.
    hist_smpl_fpath: ``.npy`` file of reference bone-angle cosines.
    params: (a, b, c) coefficients passed to :func:`cub`.
    out_dir: Path-like directory receiving ``hist_<bone>.png`` images.
    bins: histogram bin count.
    xy: optional (x, y) curve of the remapping function for the middle panel.
    """
    cos_opt = read_cos_opt(cos_opt_dir)
    angle_opt = np.arccos(cos_opt)
    # Remap the angles through the fitted cubic.
    angle_opt2 = cub(angle_opt, *params)
    cos_opt2 = np.cos(angle_opt2)
    cos_smpl = np.load(hist_smpl_fpath)
    # cos_smpl = subsample(cos_smpl)
    print(cos_smpl.shape)
    cos_smpl = np.clip(cos_smpl, -1, 1)
    # NOTE: from here on the "cos_*" names actually hold angles (degrees).
    cos_opt = angle_opt
    cos_opt2 = angle_opt2
    cos_smpl = np.arccos(cos_smpl)
    cos_opt = 180 / math.pi * cos_opt
    cos_opt2 = 180 / math.pi * cos_opt2
    cos_smpl = 180 / math.pi * cos_smpl
    max_range = 90  # math.pi / 2
    xticks = [0, 15, 30, 45, 60, 75, 90]
    for idx, bone in enumerate(pose_estimation.SKELETON):
        i, j = bone
        i_name = pose_estimation.KPS[i]
        j_name = pose_estimation.KPS[j]
        # Only one bone is plotted; skip everything else.
        if i_name != "Left Upper Leg":
            continue
        name = f"{i_name}_{j_name}"
        gs = gridspec.GridSpec(2, 4)
        fig = plt.figure(tight_layout=True, figsize=(16, 8), dpi=300)
        # Left column: reference histogram (top) vs optimized (bottom).
        ax0 = fig.add_subplot(gs[0, 0])
        ax0.hist(cos_smpl[:, idx], bins=bins, range=(0, max_range), density=True)
        ax0.set_xticks(xticks)
        ax0.tick_params(labelbottom=False, labelleft=True)
        ax1 = fig.add_subplot(gs[1, 0], sharex=ax0)
        ax1.hist(cos_opt[:, idx], bins=bins, range=(0, max_range), density=True)
        ax1.set_xticks(xticks)
        if xy is not None:
            # Middle panel: remapping curve against the identity line.
            ax2 = fig.add_subplot(gs[:, 1:3])
            ax2.plot(xy[0], xy[1], linewidth=8)
            ax2.plot(xy[0], xy[0], linewidth=4, linestyle="dashed")
            ax2.set_xticks(xticks)
            ax2.set_yticks(xticks)
        # Right column: remapped histogram (top) and all three overlaid.
        ax3 = fig.add_subplot(gs[0, 3], sharey=ax0)
        ax3.hist(cos_opt2[:, idx], bins=bins, range=(0, max_range), density=True)
        ax3.set_xticks(xticks)
        ax3.tick_params(labelbottom=False, labelleft=False)
        ax4 = fig.add_subplot(gs[1, 3], sharex=ax3, sharey=ax1)
        alpha = 0.5
        ax4.hist(cos_opt[:, idx], bins=bins, range=(0, max_range), density=True, label=r"$\mathcal{B}_i$", alpha=alpha)
        ax4.hist(cos_opt2[:, idx], bins=bins, range=(0, max_range), density=True, label=r"$f(\mathcal{B}_i)$", alpha=alpha)
        ax4.hist(cos_smpl[:, idx], bins=bins, range=(0, max_range), density=True, label=r"$\mathcal{A}_i$", alpha=alpha)
        ax4.set_xticks(xticks)
        ax4.tick_params(labelbottom=True, labelleft=False)
        ax4.legend()
        fig.savefig(out_dir / f"hist_{name}.png")
        plt.close()
def kldiv(p_hist, q_hist):
    """Distance between two histograms.

    NOTE(review): despite the name, this computes the Wasserstein-1
    (earth-mover) distance, not a KL divergence.
    """
    return wasserstein_distance(p_hist, q_hist)
def calc_histogram(x, bins=10, range=(0, 1)):
    """Density-normalized histogram counts of ``x`` (bin edges dropped)."""
    counts, _edges = np.histogram(x, bins=bins, range=range, density=True)
    return counts
def step(params, angles_opt, p_hist, bone_idx=None):
    """Score one (a, b, c) cubic candidate against reference histograms.

    Returns ``(score, params)``; the score is the summed histogram
    distance over bones, or ``math.inf`` when the cubic is invalid
    (coefficients sum above 1, or some angle maps outside [0, 1]).
    """
    # A valid remapping must satisfy f(1) = a + b + c <= 1.
    if sum(params) > 1:
        return math.inf, params

    score = 0
    for i, _ in enumerate(pose_estimation.SKELETON):
        # Optionally restrict the fit to a single bone.
        if bone_idx is not None and i != bone_idx:
            continue
        remapped = cub(angles_opt[:, i], *params)
        # Reject cubics that push normalized angles out of [0, 1].
        if remapped.max() > 1 or remapped.min() < 0:
            score = math.inf
            break
        score += kldiv(p_hist[i], calc_histogram(remapped))
    return score, params
def optimize(cos_opt_dir, hist_smpl_fpath, bone_idx=None):
    """Grid-search the cubic (a, b, c) minimizing histogram distance.

    Evaluates :func:`step` over a dense 100x200x100 parameter grid using
    a process pool and returns the best (lowest-score) parameter triple.
    """
    cos_opt = read_cos_opt(cos_opt_dir)
    # Normalize angles to [0, 1] (fraction of a right angle).
    angles_opt = np.arccos(cos_opt) / (math.pi / 2)
    cos_smpl = np.load(hist_smpl_fpath)
    # cos_smpl = subsample(cos_smpl)
    print(cos_smpl.shape)
    cos_smpl = np.clip(cos_smpl, -1, 1)
    # Sanity checks: cosines must lie in [0, 1] after clipping.
    mask = cos_smpl <= 1
    assert np.all(mask), (~mask).mean()
    mask = cos_smpl >= 0
    assert np.all(mask), (~mask).mean()
    angles_smpl = np.arccos(cos_smpl) / (math.pi / 2)
    # Reference histogram per skeleton bone.
    p_hist = [
        calc_histogram(angles_smpl[:, i])
        for i, _ in enumerate(pose_estimation.SKELETON)
    ]
    with multiprocessing.Pool(8) as p:
        results = list(
            tqdm.tqdm(
                p.imap_unordered(
                    functools.partial(step, angles_opt=angles_opt, p_hist=p_hist, bone_idx=bone_idx),
                    itertools.product(
                        np.linspace(0, 20, 100),
                        np.linspace(-20, 20, 200),
                        np.linspace(-20, 1, 100),
                    ),
                ),
                total=(100 * 200 * 100),
            )
        )
    # imap_unordered loses ordering, but each result carries its params.
    kls, params = zip(*results)
    ind = np.argmin(kls)
    best_params = params[ind]
    print(kls[ind], best_params)
    # Log the ten best candidates for inspection.
    inds = np.argsort(kls)
    for i in inds[:10]:
        print(kls[i])
        print(params[i])
        print()
    return best_params
def main():
    """Fit the cubic angle remapping and render the histogram figures."""
    # Directory of optimization dumps and the reference SMPL cosines.
    cos_opt_dir = "paper_single2_150mse"
    hist_smpl_fpath = "./data/hist_smpl.npy"
    # hist_smpl_fpath = "./testtest.npy"
    params = optimize(cos_opt_dir, hist_smpl_fpath)
    # params = (1.2121212121212122, -1.105527638190953, 0.787878787878789)
    # params = (0.20202020202020202, 0.30150753768844396, 0.3636363636363633)
    print(params)
    # Sample the fitted remapping on [0, pi/2] and convert to degrees.
    x = np.linspace(0, math.pi / 2, 100)
    y = cub(x / (math.pi / 2), *params) * (math.pi / 2)
    x = x * 180 / math.pi
    y = y * 180 / math.pi
    out_dir = Path("hists")
    out_dir.mkdir(parents=True, exist_ok=True)
    plot_hist(cos_opt_dir, hist_smpl_fpath, params, out_dir, xy=(x, y))
    # Standalone plot of the remapping curve vs the identity.
    plt.figure(figsize=(4, 4), dpi=300)
    plt.plot(x, y, linewidth=6)
    plt.plot(x, x, linewidth=2, linestyle="dashed")
    xticks = [0, 15, 30, 45, 60, 75, 90]
    plt.xticks(xticks)
    plt.yticks(xticks)
    plt.axis("equal")
    plt.tight_layout()
    plt.savefig(out_dir / "new_out.png")
# Script entry point.
if __name__ == "__main__":
    main()
================================================
FILE: src/losses.py
================================================
import itertools
import torch
import torch.nn as nn
import pose_estimation
class MSE(nn.Module):
    """Squared-error loss over keypoints with optional ignored rows.

    The loss is the total squared error of the non-ignored entries
    divided by the number of kept rows (first-dimension entries).
    """

    def __init__(self, ignore=None):
        super().__init__()
        # Element-wise squared error; reduction handled in forward().
        self.mse = torch.nn.MSELoss(reduction="none")
        self.ignore = [] if ignore is None else ignore

    def forward(self, y_pred, y_data):
        per_elem = self.mse(y_pred, y_data)
        # Zero out ignored rows so they contribute nothing to the sum.
        if len(self.ignore) > 0:
            per_elem[self.ignore] *= 0
        kept = len(per_elem) - len(self.ignore)
        return per_elem.sum() / kept
class Parallel(nn.Module):
    """Aggregates the 2D/3D bone-alignment losses used during pose optimization.

    Several attributes (``contact_2d``, ``cos``, ``ground_parallel``,
    ``parallel_in_3d``, foot-index attributes) are assigned externally
    between optimization stages; ``forward`` checks their presence/None-ness
    before using them.
    """

    def __init__(self, skeleton, ignore=None, ground_parallel=None):
        super().__init__()
        # List of (src, dst) joint-index pairs defining the bones.
        self.skeleton = skeleton
        if ignore is not None:
            self.ignore = set(ignore)
        else:
            self.ignore = set()
        # Bones that should stay parallel to the ground: ((src, dst), target).
        self.ground_parallel = ground_parallel if ground_parallel is not None else []
        # Pairs of bones that should be collinear in 3D.
        self.parallel_in_3d = []
        # Per-bone target foreshortening cosines (set externally; may stay None).
        self.cos = None

    def forward(self, y_pred3d, y_data, z, spine_j, writer=None, global_step=0):
        """Compute (Ltan, Lcos, Lpar, Lspine, Lgr, Lstraight3d, Lcon2d).

        y_pred3d: predicted joints (x, y used as 2D projection); y_data:
        target 2D keypoints; z: 3D joints; spine_j: (rleg, lleg) projected
        upper-leg vectors that replace the naive 2D bone for the legs.
        """
        y_pred = y_pred3d[:, :2]
        rleg, lleg = spine_j
        # --- 2D contact loss: matched contact midpoints should move together.
        Lcon2d = Lcount = 0
        if hasattr(self, "contact_2d"):
            for c2d in self.contact_2d:
                for (
                    (src_1, dst_1, t_1),
                    (src_2, dst_2, t_2),
                ) in itertools.combinations(c2d, 2):
                    # Points at parameter t along each bone, in data and pred.
                    a_1 = torch.lerp(y_data[src_1], y_data[dst_1], t_1)
                    a_2 = torch.lerp(y_data[src_2], y_data[dst_2], t_2)
                    a = a_2 - a_1
                    b_1 = torch.lerp(y_pred[src_1], y_pred[dst_1], t_1)
                    b_2 = torch.lerp(y_pred[src_2], y_pred[dst_2], t_2)
                    b = b_2 - b_1
                    lcon2d = ((a - b) ** 2).sum()
                    Lcon2d = Lcon2d + lcon2d
                    Lcount += 1
        if Lcount > 0:
            Lcon2d = Lcon2d / Lcount
        # --- Per-bone tangent / foreshortening / normal losses.
        Ltan = Lpar = Lcos = Lcount = 0
        Lspine = 0
        for i, bone in enumerate(self.skeleton):
            if bone in self.ignore:
                continue
            src, dst = bone
            b = y_data[dst] - y_data[src]
            t = nn.functional.normalize(b, dim=0)
            # 2D normal of the data bone (90-degree rotation of the tangent).
            n = torch.stack([-t[1], t[0]])
            if src == 10 and dst == 11:  # right leg
                a = rleg
            elif src == 13 and dst == 14:  # left leg
                a = lleg
            else:
                a = y_pred[dst] - y_pred[src]
            bone_name = f"{pose_estimation.KPS[src]}_{pose_estimation.KPS[dst]}"
            c = a - b
            lcos_loc = ltan_loc = lpar_loc = 0
            if self.cos is not None:
                # Foreshortening mode: match the 2D/3D length ratio to the
                # target cosine, except for shoulder/hip bones.
                if bone not in [
                    (1, 2),  # Neck + Right Shoulder
                    (1, 5),  # Neck + Left Shoulder
                    (9, 10),  # Hips + Right Upper Leg
                    (9, 13),  # Hips + Left Upper Leg
                ]:
                    a = y_pred[dst] - y_pred[src]
                    l2d = torch.norm(a, dim=0)
                    l3d = torch.norm(y_pred3d[dst] - y_pred3d[src], dim=0)
                    lcos = self.cos[i]
                    lcos_loc = (l2d / l3d - lcos) ** 2
                    Lcos = Lcos + lcos_loc
                    # Keep the 2D bone direction aligned with the data bone.
                    lpar_loc = ((a / l2d) * n).sum() ** 2
                    Lpar = Lpar + lpar_loc
            else:
                # Tangent mode: penalize displacement along and across the bone.
                ltan_loc = ((c * t).sum()) ** 2
                Ltan = Ltan + ltan_loc
                lpar_loc = (c * n).sum() ** 2
                Lpar = Lpar + lpar_loc
            if writer is not None:
                writer.add_scalar(f"tan/{bone_name}", ltan_loc, global_step=global_step)
                writer.add_scalar(f"cos/{bone_name}", lcos_loc, global_step=global_step)
                writer.add_scalar(f"par/{bone_name}", lpar_loc, global_step=global_step)
            Lcount += 1
        if Lcount > 0:
            Ltan = Ltan / Lcount
            Lcos = Lcos / Lcount
            Lpar = Lpar / Lcount
            Lspine = Lspine / Lcount
        # --- Ground-parallel bones: |x-component| should match the target value.
        Lgr = Lcount = 0
        for (src, dst), value in self.ground_parallel:
            bone = y_pred[dst] - y_pred[src]
            bone = nn.functional.normalize(bone, dim=0)
            l = (torch.abs(bone[0]) - value) ** 2
            Lgr = Lgr + l
            Lcount += 1
        if Lcount > 0:
            Lgr = Lgr / Lcount
        # --- Collinearity of bone pairs in 3D (dot product driven to 1).
        Lstraight3d = Lcount = 0
        for (i, j), (k, l) in self.parallel_in_3d:
            a = z[j] - z[i]
            a = nn.functional.normalize(a, dim=0)
            b = z[l] - z[k]
            b = nn.functional.normalize(b, dim=0)
            lo = (((a * b).sum() - 1) ** 2).sum()
            Lstraight3d = Lstraight3d + lo
            Lcount += 1
        # NOTE(review): this spine direction (joints 1 and 8) is computed but
        # unused below — presumably a leftover from the Lspine term; confirm.
        b = y_data[1] - y_data[8]
        b = nn.functional.normalize(b, dim=0)
        if Lcount > 0:
            Lstraight3d = Lstraight3d / Lcount
        return Ltan, Lcos, Lpar, Lspine, Lgr, Lstraight3d, Lcon2d
class MimickedSelfContactLoss(nn.Module):
    """Loss that lets vertices in contact on a presented mesh attract
    vertices that are close.

    Note: this docstring was originally a stray string statement placed
    after ``super().__init__()`` inside ``__init__`` (a no-op); it is now a
    proper class docstring.
    """

    def __init__(self, geodesics_mask):
        """
        Args:
            geodesics_mask: (nv, nv) boolean tensor, True where a vertex
                pair is geodesically far enough apart to count as contact.
        """
        super().__init__()
        # geodesic distance mask
        self.register_buffer("geomask", geodesics_mask)

    def forward(
        self,
        presented_contact,
        vertices,
        v2v=None,
        contact_mode="dist_tanh",
        contact_thresh=1,
    ):
        """Pull each presented-contact vertex towards its nearest valid partner.

        Args:
            presented_contact: integer index array of vertices in contact.
            vertices: (1, nv, 3) vertex positions.
            v2v: optional precomputed (nv, nv) pairwise distance matrix.
            contact_mode: "dist_tanh" saturates distances beyond
                ``contact_thresh``; any other value uses raw mean distance.
            contact_thresh: tanh saturation scale.

        Returns:
            Scalar loss (0.0 when ``presented_contact`` is empty).
        """
        contactloss = 0.0
        if v2v is None:
            # compute pairwise distances
            verts = vertices.contiguous()
            nv = verts.shape[1]
            v2v = verts.squeeze().unsqueeze(1).expand(
                nv, nv, 3
            ) - verts.squeeze().unsqueeze(0).expand(nv, nv, 3)
            v2v = torch.norm(v2v, 2, 2)
        # loss for self-contact from mimic'ed pose
        if len(presented_contact) > 0:
            # without geodesic distance mask, compute distances
            # between each pair of verts in contact
            with torch.no_grad():
                cvertstobody = v2v[presented_contact, :]
                cvertstobody = cvertstobody[:, presented_contact]
                maskgeo = self.geomask[presented_contact, :]
                maskgeo = maskgeo[:, presented_contact]
                # BUG FIX: was `torch.ones_like(...).to(verts.device)`, which
                # raised NameError whenever a precomputed v2v was passed in
                # (`verts` is only bound in the `v2v is None` branch).
                # ones_like already matches device and dtype.
                weights = torch.ones_like(cvertstobody)
                # Geodesically-close pairs (e.g. neighbors) are invalid
                # partners: make them infinitely expensive.
                weights[~maskgeo] = float("inf")
                min_idx = torch.min((cvertstobody + 1) * weights, 1)[1]
                min_idx = presented_contact[min_idx.cpu().numpy()]
            v2v_min = v2v[presented_contact, min_idx]
            # tanh will not pull vertices that are ~more than contact_thresh far apart
            if contact_mode == "dist_tanh":
                contactloss = contact_thresh * torch.tanh(v2v_min / contact_thresh)
                contactloss = contactloss.mean()
            else:
                contactloss = v2v_min.mean()
        return contactloss
================================================
FILE: src/pose.py
================================================
import argparse
import math
from pathlib import Path
import cv2
import numpy as np
import PIL.Image as Image
import selfcontact
import selfcontact.losses
import shapely.geometry
import torch
import torch.nn as nn
import torch.optim as optim
import torchgeometry
import tqdm
import trimesh
from skimage import measure
from torch.utils.tensorboard.writer import SummaryWriter
import fist_pose
import hist_cub
import losses
import pose_estimation
import spin
import utils
# Maps pose-estimation keypoint names (pose_estimation.KPS) to joint names in
# spin.JOINT_NAMES. Note the offset convention: "Right Shoulder" maps to
# "Right ForeArm" (and likewise on the left) — presumably because the
# shoulder *bone* ends at the forearm joint in this skeleton; confirm
# against spin.JOINT_NAMES.
PE_KSP_TO_SPIN = {
    "Head": "Head",
    "Neck": "Neck",
    "Right Shoulder": "Right ForeArm",
    "Right Arm": "Right Arm",
    "Right Hand": "Right Hand",
    "Left Shoulder": "Left ForeArm",
    "Left Arm": "Left Arm",
    "Left Hand": "Left Hand",
    "Spine": "Spine1",
    "Hips": "Hips",
    "Right Upper Leg": "Right Upper Leg",
    "Right Leg": "Right Leg",
    "Right Foot": "Right Foot",
    "Left Upper Leg": "Left Upper Leg",
    "Left Leg": "Left Leg",
    "Left Foot": "Left Foot",
    "Left Toe": "Left Toe",
    "Right Toe": "Right Toe",
}
# Directory holding all downloaded model files (see scripts/download.sh).
MODELS_DIR = "models"
def parse_args():
    """Build and parse the command-line arguments for pose optimization.

    Groups: model paths (pose estimation, contact, SPIN, SMPL-X,
    parametrization), I/O paths (--save-path, --img-path, both required),
    feature toggles (--use-*), and loss weights (--c-*).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pose-estimation-model-path",
        type=str,
        default=f"./{MODELS_DIR}/hrn_w48_384x288.onnx",
        help="Pose Estimation model",
    )
    parser.add_argument(
        "--contact-model-path",
        type=str,
        default=f"./{MODELS_DIR}/contact_hrn_w32_256x192.onnx",
        help="Contact model",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cpu", "cuda"],
        help="Torch device",
    )
    parser.add_argument(
        "--spin-model-path",
        type=str,
        default=f"./{MODELS_DIR}/spin_model_smplx_eft_18.pt",
        help="SPIN model path",
    )
    parser.add_argument(
        "--smpl-type",
        type=str,
        default="smplx",
        choices=["smplx"],
        help="SMPL model type",
    )
    parser.add_argument(
        "--smpl-model-dir",
        type=str,
        default=f"./{MODELS_DIR}/models/smplx",
        help="SMPL model dir",
    )
    parser.add_argument(
        "--smpl-mean-params-path",
        type=str,
        default=f"./{MODELS_DIR}/data/smpl_mean_params.npz",
        help="SMPL mean params",
    )
    parser.add_argument(
        "--essentials-dir",
        type=str,
        default=f"./{MODELS_DIR}/smplify-xmc-essentials",
        help="SMPL Essentials folder for contacts",
    )
    parser.add_argument(
        "--parametrization-path",
        type=str,
        default=f"./{MODELS_DIR}/smplx_parametrization/parametrization.npy",
        help="Parametrization path",
    )
    parser.add_argument(
        "--bone-parametrization-path",
        type=str,
        default=f"./{MODELS_DIR}/smplx_parametrization/bone_to_param2.npy",
        help="Bone parametrization path",
    )
    parser.add_argument(
        "--foot-inds-path",
        type=str,
        default=f"./{MODELS_DIR}/smplx_parametrization/foot_inds.npy",
        help="Foot indinces",
    )
    parser.add_argument(
        "--save-path",
        type=str,
        required=True,
        help="Path to save the results",
    )
    parser.add_argument(
        "--img-path",
        type=str,
        required=True,
        help="Path to img to test",
    )
    parser.add_argument(
        "--use-contacts",
        action="store_true",
        help="Use contact model",
    )
    parser.add_argument(
        "--use-msc",
        action="store_true",
        help="Use MSC loss",
    )
    parser.add_argument(
        "--use-natural",
        action="store_true",
        help="Use regularity",
    )
    parser.add_argument(
        "--use-cos",
        action="store_true",
        help="Use cos model",
    )
    parser.add_argument(
        "--use-angle-transf",
        action="store_true",
        help="Use cube foreshortening transformation",
    )
    parser.add_argument(
        "--c-mse",
        type=float,
        default=0,
        help="MSE weight",
    )
    # NOTE(review): --c-par and --c-parallel share the same help text
    # ("Parallel weight"); --c-par appears unused in this file — confirm.
    parser.add_argument(
        "--c-par",
        type=float,
        default=10,
        help="Parallel weight",
    )
    parser.add_argument(
        "--c-f",
        type=float,
        default=1000,
        help="Cos coef",
    )
    parser.add_argument(
        "--c-parallel",
        type=float,
        default=100,
        help="Parallel weight",
    )
    parser.add_argument(
        "--c-reg",
        type=float,
        default=1000,
        help="Regularity weight",
    )
    parser.add_argument(
        "--c-cont2d",
        type=float,
        default=1,
        help="Contact 2D weight",
    )
    parser.add_argument(
        "--c-msc",
        type=float,
        default=17_500,
        help="MSC weight",
    )
    parser.add_argument(
        "--fist",
        nargs="+",
        type=str,
        choices=list(fist_pose.INT_TO_FIST),
    )
    args = parser.parse_args()
    return args
def freeze_layers(model):
    """Put all batch-norm and dropout submodules of ``model`` into eval mode
    and freeze their parameters.

    Other layers keep their training mode and gradients; only normalization
    statistics updates and dropout randomness are disabled.

    BUG FIX: the original began each iteration with ``if type(module) is
    False: continue`` — ``type()`` never returns ``False``, so the guard was
    dead code and has been removed. The two identical isinstance branches
    are merged (nn.Dropout has no parameters, so the loop is a no-op there).
    """
    for module in model.modules():
        if isinstance(module, (nn.modules.batchnorm._BatchNorm, nn.Dropout)):
            module.eval()
            for p in module.parameters():
                p.requires_grad = False
def project_and_normalize_to_spin(vertices_3d, camera):
    """Apply the weak-perspective camera to 3D points and map them into
    SPIN's normalized image frame.

    Args:
        vertices_3d: (N, 3) points.
        camera: (3,) tensor [scale, tx, ty].

    Returns:
        (N, 3) tensor whose x/y are pixel coordinates in [0, IMG_RES]; the
        z channel goes through the same affine map (only x/y are meaningful
        in 2D).

    FIX: the original bound ``scale, translate = camera[0], camera[1:]`` and
    then immediately rebound ``translate`` — the first binding was dead code
    and has been removed.
    """
    scale = camera[0]
    # Pad the 2D camera translation with z = 0 so it broadcasts over (N, 3).
    translate = scale.new_zeros(3)
    translate[:2] = camera[1:]
    projected = vertices_3d + translate
    # Map from [-1, 1] normalized device coordinates to [0, IMG_RES] pixels.
    projected = scale * projected + 1
    projected = spin.constants.IMG_RES / 2 * projected
    return projected
def project_and_normalize_to_spin_legs(vertices_3d, A, camera):
    """Project the two upper-leg bone vectors into SPIN image space.

    Args:
        vertices_3d: joints tensor (used only for tensor construction).
        A: (joint transforms, joint positions) pair, each batched; A[i][0]
           is taken (batch size 1 assumed).
        camera: (3,) weak-perspective camera [scale, tx, ty]; only the
           scale is used here (bone *vectors* are translation-invariant).

    Returns:
        (rleg, lleg): 2D (x, y) scaled bone vectors for the right and left
        upper legs.
    """
    A, J = A
    A = A[0]
    J = J[0]
    # NOTE(review): hard-coded rotations, presumably fixed corrections of
    # the SMPL-X leg rest axes (left/right) — confirm their provenance.
    L = vertices_3d.new_tensor(
        [
            [0.98619063, 0.16560926, 0.00127302],
            [-0.16560601, 0.98603675, 0.01749799],
            [0.00164258, -0.01746717, 0.99984609],
        ]
    )
    R = vertices_3d.new_tensor(
        [
            [0.9910211, -0.13368178, -0.0025208],
            [0.13367888, 0.99027076, 0.03864949],
            [-0.00267045, -0.03863944, 0.99924965],
        ]
    )
    scale = camera[0]
    # Rotate the corrections into the posed frame of each hip joint.
    R = A[2, :3, :3] @ R  # 2 - right
    L = A[1, :3, :3] @ L  # 1 - left
    # Rest-pose bone vectors: knee minus hip, per side.
    r = J[5] - J[2]
    l = J[4] - J[1]
    rleg = scale * spin.constants.IMG_RES / 2 * R @ r
    lleg = scale * spin.constants.IMG_RES / 2 * L @ l
    # Keep only the image-plane components.
    rleg = rleg[:2]
    lleg = lleg[:2]
    return rleg, lleg
def rotation_matrix_to_angle_axis(rotmat):
    """Convert batched per-joint rotation matrices to a flat axis-angle vector.

    Args:
        rotmat: (bs, n_joints, 3, 3) rotation matrices.

    Returns:
        (bs, 3 * n_joints) axis-angle representation.

    NOTE(review): the homogeneous column is built by viewing a 3-element
    tensor as (bs, 3, 1), which only works when bs == 1 — this function
    appears to assume a batch size of one; confirm before reuse.
    """
    bs, n_joints, *_ = rotmat.size()
    # torchgeometry expects 3x4 matrices; append a dummy translation column.
    rotmat = torch.cat(
        [
            rotmat.view(-1, 3, 3),
            rotmat.new_tensor([0, 0, 1], dtype=torch.float32)
            .view(bs, 3, 1)
            .expand(n_joints, -1, -1),
        ],
        dim=-1,
    )
    aa = torchgeometry.rotation_matrix_to_angle_axis(rotmat)
    aa = aa.reshape(bs, 3 * n_joints)
    return aa
def get_smpl_output(smpl, rotmat, betas, use_betas=True, zero_hands=False):
    """Run the SMPL/SMPL-X body model on predicted rotations.

    For SMPL the rotation matrices are fed directly (pose2rot=False); for
    SMPL-X they are first converted to axis-angle, optionally zeroing the
    wrist rotations and the y-rotation of neck/head.

    Returns:
        (smpl_output, rotmat) where rotmat is either the original matrices
        (SMPL) or the flattened axis-angle vector (SMPL-X).

    Raises:
        NotImplementedError: for any other body-model name.
    """
    if smpl.name() == "SMPL":
        smpl_output = smpl(
            betas=betas if use_betas else None,
            body_pose=rotmat[:, 1:],
            global_orient=rotmat[:, 0].unsqueeze(1),
            pose2rot=False,
        )
    elif smpl.name() == "SMPL-X":
        rotmat = rotation_matrix_to_angle_axis(rotmat)
        if zero_hands:
            # Joints 20/21 are the wrists in the SMPL-X body pose layout.
            for i in [20, 21]:
                rotmat[:, 3 * i : 3 * (i + 1)] = 0
            for i in [12, 15]:  # neck, head
                rotmat[:, 3 * i + 1] = 0  # y
        smpl_output = smpl(
            betas=betas if use_betas else None,
            body_pose=rotmat[:, 3:],
            global_orient=rotmat[:, :3],
            pose2rot=True,
        )
    else:
        raise NotImplementedError
    return smpl_output, rotmat
def get_predictions(model_hmr, smpl, input_img, use_betas=True, zero_hands=False):
    """Run the HMR network on a single image and decode it through the body
    model.

    Returns:
        (rotmat, betas, camera, smpl_output, joints) with the batch
        dimension stripped from everything except smpl_output.
    """
    batch = input_img.unsqueeze(0)  # the network expects a batch axis
    rotmat, betas, camera = model_hmr(batch)
    smpl_output, rotmat = get_smpl_output(
        smpl, rotmat, betas, use_betas=use_betas, zero_hands=zero_hands
    )
    joints = smpl_output.joints.squeeze(0)
    return (
        rotmat.squeeze(0),
        betas.squeeze(0),
        camera.squeeze(0),
        smpl_output,
        joints,
    )
def get_pred_and_data(
    model_hmr, smpl, selector, input_img, use_betas=True, zero_hands=False
):
    """Run prediction and project joints/vertices into SPIN image space.

    ``selector`` is a list of joint indices mapping SPIN joints to the
    pose-estimation keypoint order (see get_selector).

    Returns a 9-tuple: (rotmat, betas, camera, selected 2D joints,
    selected 3D joints, 2D vertices, smpl_output, (rleg, lleg) projected
    upper-leg vectors, unselected 2D joints).
    """
    rotmat, betas, camera, smpl_output, zz = get_predictions(
        model_hmr, smpl, input_img, use_betas=use_betas, zero_hands=zero_hands
    )
    joints = smpl_output.joints.squeeze(0)
    joints_2d = project_and_normalize_to_spin(joints, camera)
    rleg, lleg = project_and_normalize_to_spin_legs(joints, smpl_output.A, camera)
    # Keep the full projection before selecting the tracked keypoints.
    joints_2d_orig = joints_2d
    joints_2d = joints_2d[selector]
    vertices = smpl_output.vertices.squeeze(0)
    vertices_2d = project_and_normalize_to_spin(vertices, camera)
    zz = zz[selector]
    return (
        rotmat,
        betas,
        camera,
        joints_2d,
        zz,
        vertices_2d,
        smpl_output,
        (rleg, lleg),
        joints_2d_orig,
    )
def normalize_keypoints_to_spin(keypoints_2d, img_size):
    """Center and rescale 2D keypoints into SPIN's square image frame.

    Returns:
        (normalized_keypoints, shift, scale, ax2) — the last three allow
        the inverse mapping (see unnormalize_keypoints_from_spin).
    """
    h, w = img_size
    # ax1 is the long image axis; ax2 the short one, which gets re-centered.
    ax1, ax2 = (1, 0) if h > w else (0, 1)
    shift = (img_size[ax1] - img_size[ax2]) / 2
    scale = spin.constants.IMG_RES / img_size[ax2]
    normalized = np.copy(keypoints_2d)  # do not mutate the caller's array
    normalized[:, ax2] -= shift
    normalized *= scale
    return normalized, shift, scale, ax2
def unnormalize_keypoints_from_spin(keypoints_2d, shift, scale, ax2):
    """Invert normalize_keypoints_to_spin: undo scaling, then the crop shift."""
    restored = np.copy(keypoints_2d)  # leave the caller's array untouched
    restored /= scale
    restored[:, ax2] += shift
    return restored
def get_vertices_in_heatmap(contact_heatmap):
    """Extract connected components of a binary contact heatmap as convex
    hulls in SPIN-normalized coordinates.

    Returns:
        List of shapely geometries (one convex hull per connected blob).
    """
    contact_heatmap_size = contact_heatmap.shape[:2]
    label = measure.label(contact_heatmap)
    y_data_conts = []
    for i in range(1, label.max() + 1):
        # (row, col) nonzero indices reversed to (x, y) pixel coordinates.
        predicted_kps_contact = np.vstack(np.nonzero(label == i)[::-1]).T.astype(
            "float"
        )
        predicted_kps_contact_scaled, *_ = normalize_keypoints_to_spin(
            predicted_kps_contact, contact_heatmap_size
        )
        y_data_cont = torch.from_numpy(predicted_kps_contact_scaled).int().tolist()
        y_data_cont = shapely.geometry.MultiPoint(y_data_cont).convex_hull
        y_data_conts.append(y_data_cont)
    return y_data_conts
def get_contact_heatmap(model_contact, img_path, thresh=0.5):
    """Infer the self-contact heatmap for an image.

    Returns:
        (binary_mask_uint8, rgb_heatmap_uint8, raw_heatmap): the binary
        mask thresholds the min-max-normalized heatmap at ``thresh``; the
        RGB image is the normalized heatmap replicated over 3 channels.
    """
    contact_heatmap = pose_estimation.infer_single_image(
        model_contact,
        img_path,
        input_img_size=(192, 256),
        return_kps=False,
    )
    contact_heatmap = contact_heatmap.squeeze(0)
    # Keep the un-normalized response for downstream use.
    contact_heatmap_orig = contact_heatmap.copy()
    mi = contact_heatmap.min()
    ma = contact_heatmap.max()
    # Min-max normalize to [0, 1] before thresholding.
    contact_heatmap = (contact_heatmap - mi) / (ma - mi)
    contact_heatmap_ = ((contact_heatmap > thresh) * 255).astype("uint8")
    contact_heatmap = np.repeat(contact_heatmap[..., None], repeats=3, axis=-1)
    contact_heatmap = (contact_heatmap * 255).astype("uint8")
    return contact_heatmap_, contact_heatmap, contact_heatmap_orig
def discretize(parametrization, n_bins=100):
    """Snap values in [0, 1] to the left edge of a uniform n_bins grid."""
    edges = np.linspace(0, 1, n_bins + 1)
    # np.digitize returns 1-based bin indices; shift to the left edge.
    bin_idx = np.digitize(parametrization, edges) - 1
    return edges[bin_idx]
def get_mapping_from_params_to_verts(verts, params):
    """Group vertex ids by their (discretized) parameter value.

    Returns a dict {param_value: [vertex ids with that value]}, preserving
    the input order within each group.
    """
    grouped = {}
    for vert, param in zip(verts, params):
        grouped.setdefault(param, []).append(vert)
    return grouped
def find_contacts(y_data_conts, keypoints_2d, bone_to_params, thresh=12, step=0.0072246375):
    """Map 2D contact regions to mesh vertices via bone parametrizations.

    For every contact hull, intersect each skeleton bone segment with the
    hull (buffered by ``thresh`` pixels) and collect the mesh vertices whose
    bone parameter falls inside the intersected parameter range.

    Args:
        y_data_conts: shapely contact-region geometries in SPIN coordinates.
        keypoints_2d: 2D keypoints indexed by pose_estimation.SKELETON.
        bone_to_params: {(i, j): (vertex ids, per-vertex bone parameter)}.
        thresh: buffer radius (pixels) around each contact region.
        step: parameter discretization step (per the original comment, the
            mean face's circumradius).

    Returns:
        (contact, contact_2d, for_mask): per-region vertex-index arrays,
        per-region (i, j, t-midpoint) bone anchors, and buffered-region
        outlines as int contours for mask drawing.
    """
    n_bins = int(math.ceil(1 / step)) - 1  # mean face's circumradius
    contact = []
    contact_2d = []
    for_mask = []
    for y_data_cont in y_data_conts:
        contact_loc = []
        contact_2d_loc = []
        buffer = y_data_cont.buffer(thresh)
        mask_add = False
        for i, j in pose_estimation.SKELETON:
            verts, t3d = bone_to_params[(i, j)]
            if len(verts) == 0:
                continue
            # Discretize bone parameters so 2D and 3D live on the same grid.
            t3d = discretize(t3d, n_bins=n_bins)
            t3d_to_verts = get_mapping_from_params_to_verts(verts, t3d)
            t3d_to_verts_sorted = sorted(t3d_to_verts.items(), key=lambda x: x[0])
            t3d_sorted_np = np.array([x for x, _ in t3d_to_verts_sorted])
            line = shapely.geometry.LineString([keypoints_2d[i], keypoints_2d[j]])
            lint = buffer.intersection(line)
            # Need an entry and exit point on the buffered region boundary.
            if len(lint.boundary.geoms) < 2:
                continue
            t2d_start = line.project(lint.boundary.geoms[0], normalized=True)
            t2d_end = line.project(lint.boundary.geoms[1], normalized=True)
            assert t2d_start <= t2d_end
            t2ds = discretize(
                np.linspace(t2d_start, t2d_end, n_bins + 1), n_bins=n_bins
            )
            to_add = False
            for t2d in t2ds:
                # Skip 2D parameters outside the bone's parametrized range.
                if t2d < t3d_sorted_np[0] or t2d > t3d_sorted_np[-1]:
                    continue
                t2d_ind = np.searchsorted(t3d_sorted_np, t2d)
                c = t3d_to_verts_sorted[t2d_ind][1]
                contact_loc.extend(c)
                to_add = True
                mask_add = True
                # Also include the neighboring parameter bins for robustness.
                if t2d_ind + 1 < len(t3d_to_verts_sorted):
                    c = t3d_to_verts_sorted[t2d_ind + 1][1]
                    contact_loc.extend(c)
                if t2d_ind > 0:
                    c = t3d_to_verts_sorted[t2d_ind - 1][1]
                    contact_loc.extend(c)
            if to_add:
                # Anchor the 2D contact at the middle of the intersected span.
                contact_2d_loc.append((i, j, t2d_start + 0.5 * (t2d_end - t2d_start)))
        if mask_add:
            for_mask.append(buffer.exterior.coords.xy)
        contact_loc = sorted(set(contact_loc))
        contact_loc = np.array(contact_loc, dtype="int")
        contact.append(contact_loc)
        contact_2d.append(contact_2d_loc)
    for_mask = [np.stack((x, y), axis=0).T[:, None].astype("int") for x, y in for_mask]
    return contact, contact_2d, for_mask
def optimize(
    model_hmr,
    smpl,
    selector,
    input_img,
    keypoints_2d,
    optimizer,
    args,
    loss_mse=None,
    loss_parallel=None,
    c_mse=0.0,
    c_new_mse=1.0,
    c_beta=1e-3,
    sc_crit=None,
    msc_crit=None,
    contact=None,
    n_steps=60,
    save_path=None,
    writer=None,
    i_ini=0,
):
    """Optimize the HMR network weights so the posed body matches the
    sketch keypoints and contact constraints.

    Each step re-predicts the pose, accumulates the enabled loss terms
    (2D MSE, bone tangent/foreshortening/parallel losses, foot/silhouette
    ground terms, shape and camera regularizers, self-contact losses),
    backprops and steps ``optimizer``. When ``save_path`` is given it is a
    7-tuple used to render per-iteration debug images.

    Returns the final prediction tuple (with zero_hands=True), computed
    without gradients.
    """
    to_save = False
    if save_path is not None:
        # save_path doubles as a bundle of rendering context.
        (
            img_original,
            predicted_keypoints_2d,
            save_path,
            shift,
            scale,
            ax2,
            prefix,
        ) = save_path
        to_save = True
    # Per-foot median height, cached at the first step it is needed.
    mean_zfoot_val = {}
    with tqdm.trange(n_steps) as pbar:
        for i in pbar:
            global_step = i + i_ini
            optimizer.zero_grad()
            (
                rotmat_pred,
                betas_pred,
                camera_pred,
                keypoints_3d_pred,
                z,
                vertices_2d_pred,
                smpl_output,
                (rleg, lleg),
                joints_2d_orig,
            ) = get_pred_and_data(
                model_hmr,
                smpl,
                selector,
                input_img,
            )
            keypoints_2d_pred = keypoints_3d_pred[:, :2]
            if to_save:
                utils.save_results_image(
                    camera=camera_pred.detach().cpu().numpy(),
                    focal_length_x=spin.constants.FOCAL_LENGTH,
                    focal_length_y=spin.constants.FOCAL_LENGTH,
                    vertices=smpl_output.vertices.detach()[0].cpu().numpy(),
                    input_img=img_original,
                    faces=smpl.faces,
                    keypoints=predicted_keypoints_2d,
                    keypoints_2=unnormalize_keypoints_from_spin(
                        keypoints_2d_pred.detach().cpu().numpy(), shift, scale, ax2
                    ),
                    # keypoints_2=unnormalize_keypoints_from_spin(joints_2d_orig.detach().cpu().numpy(), shift, scale, ax2),
                    # heatmap=predicted_contact_heatmap_raw,
                    filename=save_path / f"{prefix}_{i:0>4}.png",
                    contactlist=contact,
                    user_study=False,
                )
            loss = l2 = 0.0
            # --- 2D keypoint MSE.
            if c_mse > 0 and loss_mse is not None:
                l2 = loss_mse(keypoints_2d_pred, keypoints_2d)
                loss = loss + c_mse * l2
                if writer is not None:
                    writer.add_scalar("mse", l2, global_step=global_step)
            vertices_pred = smpl_output.vertices
            lpar = z_loss = loss_sh = 0.0
            # --- Bone-alignment losses (see losses.Parallel).
            if c_new_mse > 0 and loss_parallel is not None:
                Ltan, Lcos, Lpar, Lspine, Lgr, Lstraight3d, Lcon2d = loss_parallel(
                    keypoints_3d_pred,
                    keypoints_2d,
                    z,
                    (rleg, lleg),
                    writer=writer,
                    global_step=global_step,
                )
                lpar = (
                    Ltan
                    + c_new_mse * (args.c_f * Lcos + args.c_parallel * Lpar)
                    + Lspine
                    + args.c_reg * Lgr
                    + args.c_reg * Lstraight3d
                    + args.c_cont2d * Lcon2d
                )
                loss = loss + 300 * lpar
                if writer is not None:
                    writer.add_scalar("tan", Ltan, global_step=global_step)
                    writer.add_scalar("cos", Lcos, global_step=global_step)
                    writer.add_scalar("par", Lpar, global_step=global_step)
                    writer.add_scalar("spine", Lspine, global_step=global_step)
                    writer.add_scalar("ground/chain", Lgr, global_step=global_step)
                    writer.add_scalar(
                        "straight_in_3d", Lstraight3d, global_step=global_step
                    )
                    writer.add_scalar("contact/con2d", Lcon2d, global_step=global_step)
                # --- Keep each grounded foot flat at its cached median height.
                for side in ["left", "right"]:
                    attr = f"{side}_foot_inds"
                    if hasattr(loss_parallel, attr):
                        foot_inds = getattr(loss_parallel, attr)
                        zind = 1
                        if attr not in mean_zfoot_val:
                            with torch.no_grad():
                                mean_zfoot_val[attr] = torch.median(
                                    vertices_pred[0, foot_inds, zind], dim=0
                                ).values
                        loss_foot = (
                            (vertices_pred[0, foot_inds, zind] - mean_zfoot_val[attr])
                            ** 2
                        ).sum()
                        loss = loss + args.c_reg * loss_foot
                        if writer is not None:
                            writer.add_scalar(
                                f"ground/{side} foot",
                                loss_foot,
                                global_step=global_step,
                            )
                # --- Keep silhouette vertices on the ground plane.
                if hasattr(loss_parallel, "silhuette_vertices_inds"):
                    inds = loss_parallel.silhuette_vertices_inds
                    loss_sh = (
                        (vertices_pred[0, inds, 1] - loss_parallel.ground) ** 2
                    ).sum()
                    loss = loss + args.c_reg * loss_sh
                    if writer is not None:
                        writer.add_scalar(
                            "ground/silhuette", loss_sh, global_step=global_step
                        )
            # --- Shape prior and a camera-scale barrier (penalizes scale -> 0).
            lbeta = (betas_pred**2).mean()
            lcam = ((torch.exp(-camera_pred[0] * 10)) ** 2).mean()
            loss = loss + c_beta * lbeta + lcam
            if writer is not None:
                writer.add_scalar("loss/beta", lbeta, global_step=global_step)
                writer.add_scalar("loss/cam", lcam, global_step=global_step)
            # --- Generic self-contact loss (interpenetration resolution).
            lgsc_a = gsc_contact_loss = faces_angle_loss = 0.0
            if sc_crit is not None:
                gsc_contact_loss, faces_angle_loss = sc_crit(
                    vertices_pred,
                )
                lgsc_a = 1000 * gsc_contact_loss + 0.1 * faces_angle_loss
                loss = loss + lgsc_a
                if writer is not None:
                    writer.add_scalar(
                        "contact/gsc", gsc_contact_loss, global_step=global_step
                    )
                    writer.add_scalar(
                        "contact/faces_angle", faces_angle_loss, global_step=global_step
                    )
            # --- Mimicked self-contact loss over the detected contact sets.
            msc_loss = 0.0
            if contact is not None and len(contact) > 0 and msc_crit is not None:
                if not isinstance(contact, list):
                    contact = [contact]
                for cntct in contact:
                    msc_loss = msc_crit(
                        cntct,
                        vertices_pred,
                    )
                    loss = loss + args.c_msc * msc_loss
                    if writer is not None:
                        writer.add_scalar(
                            "contact/msc", msc_loss, global_step=global_step
                        )
            loss.backward()
            optimizer.step()
            epoch_loss = loss.item()
            pbar.set_postfix(
                **{
                    "l": f"{epoch_loss:.3}",
                    "l2": f"{l2:.3}",
                    "par": f"{lpar:.3}",
                    "beta": f"{lbeta:.3}",
                    "cam": f"{lcam:.3}",
                    "z": f"{z_loss:.3}",
                    "gsc_contact": f"{float(gsc_contact_loss):.3}",
                    "faces_angle": f"{float(faces_angle_loss):.3}",
                    "msc": f"{float(msc_loss):.3}",
                }
            )
    # Final clean prediction with neutral hands, no gradients.
    with torch.no_grad():
        (
            rotmat_pred,
            betas_pred,
            camera_pred,
            keypoints_3d_pred,
            z,
            vertices_2d_pred,
            smpl_output,
            (rleg, lleg),
            joints_2d_orig,
        ) = get_pred_and_data(
            model_hmr,
            smpl,
            selector,
            input_img,
            zero_hands=True,
        )
    return (
        rotmat_pred,
        betas_pred,
        camera_pred,
        keypoints_3d_pred,
        vertices_2d_pred,
        smpl_output,
        z,
        joints_2d_orig,
    )
def optimize_ft(
    theta,
    camera,
    smpl,
    selector,
    input_img,
    keypoints_2d,
    args,
    loss_mse=None,
    loss_parallel=None,
    c_mse=0.0,
    c_new_mse=1.0,
    sc_crit=None,
    msc_crit=None,
    contact=None,
    n_steps=60,
    save_path=None,
    writer=None,
    i_ini=0,
    zero_hands=False,
    fist=None,
):
    """Fine-tune pose parameters (axis-angle ``theta``) and the camera
    directly, rather than the network weights.

    Same loss terms as ``optimize`` plus a quadratic prior pulling theta and
    camera towards their initial values. Afterwards, optionally zero the
    wrists/neck-y/head-y and apply requested hand (fist) poses before the
    final body-model evaluation.

    Returns:
        (rotmat_pred, smpl_output): the detached optimized pose vector and
        the final body-model output.
    """
    to_save = False
    if save_path is not None:
        # save_path doubles as a bundle of rendering context.
        (
            img_original,
            predicted_keypoints_2d,
            save_path,
            shift,
            scale,
            ax2,
            prefix,
        ) = save_path
        to_save = True
    mean_zfoot_val = {}
    # Optimize copies; the originals serve as the prior targets.
    theta = theta.detach().clone()
    camera = camera.detach().clone()
    rotmat_pred = nn.Parameter(theta)
    camera_pred = nn.Parameter(camera)
    optimizer = torch.optim.Adam(
        [
            rotmat_pred,
            camera_pred,
        ],
        lr=1e-3,
    )
    global_step = i_ini
    with tqdm.trange(n_steps) as pbar:
        for i in pbar:
            global_step = i + i_ini
            optimizer.zero_grad()
            global_orient = rotmat_pred[:3]
            body_pose = rotmat_pred[3:]
            smpl_output = smpl(
                global_orient=global_orient.unsqueeze(0),
                body_pose=body_pose.unsqueeze(0),
                pose2rot=True,
            )
            z = smpl_output.joints
            z = z.squeeze(0)
            joints = smpl_output.joints.squeeze(0)
            joints_2d = project_and_normalize_to_spin(joints, camera_pred)
            rleg, lleg = project_and_normalize_to_spin_legs(
                joints, smpl_output.A, camera_pred
            )
            joints_2d = joints_2d[selector]
            z = z[selector]
            keypoints_3d_pred = joints_2d
            keypoints_2d_pred = keypoints_3d_pred[:, :2]
            if to_save:
                utils.save_results_image(
                    camera=camera_pred.detach().cpu().numpy(),
                    focal_length_x=spin.constants.FOCAL_LENGTH,
                    focal_length_y=spin.constants.FOCAL_LENGTH,
                    vertices=smpl_output.vertices.detach()[0].cpu().numpy(),
                    input_img=img_original,
                    faces=smpl.faces,
                    keypoints=predicted_keypoints_2d,
                    keypoints_2=unnormalize_keypoints_from_spin(
                        keypoints_2d_pred.detach().cpu().numpy(), shift, scale, ax2
                    ),
                    # keypoints_2=unnormalize_keypoints_from_spin(joints_2d_orig.detach().cpu().numpy(), shift, scale, ax2),
                    # heatmap=predicted_contact_heatmap_raw,
                    filename=save_path / f"{prefix}_{i:0>4}.png",
                    contactlist=contact,
                    user_study=False,
                )
            # Quadratic prior keeping the solution near the initialization.
            lprior = ((rotmat_pred - theta) ** 2).sum() + (
                (camera_pred - camera) ** 2
            ).sum()
            loss = lprior
            l2 = 0.0
            # --- 2D keypoint MSE.
            if c_mse > 0 and loss_mse is not None:
                l2 = loss_mse(keypoints_2d_pred, keypoints_2d)
                loss = loss + c_mse * l2
                if writer is not None:
                    writer.add_scalar("mse", l2, global_step=global_step)
            vertices_pred = smpl_output.vertices
            lpar = z_loss = loss_sh = 0.0
            # --- Bone-alignment losses (see losses.Parallel).
            if c_new_mse > 0 and loss_parallel is not None:
                Ltan, Lcos, Lpar, Lspine, Lgr, Lstraight3d, Lcon2d = loss_parallel(
                    keypoints_3d_pred,
                    keypoints_2d,
                    z,
                    (rleg, lleg),
                    writer=writer,
                    global_step=global_step,
                )
                lpar = (
                    Ltan
                    + c_new_mse * (args.c_f * Lcos + args.c_parallel * Lpar)
                    + Lspine
                    + args.c_reg * Lgr
                    + args.c_reg * Lstraight3d
                    + args.c_cont2d * Lcon2d
                )
                loss = loss + 300 * lpar
                if writer is not None:
                    writer.add_scalar("tan", Ltan, global_step=global_step)
                    writer.add_scalar("cos", Lcos, global_step=global_step)
                    writer.add_scalar("par", Lpar, global_step=global_step)
                    writer.add_scalar("spine", Lspine, global_step=global_step)
                    writer.add_scalar("ground/chain", Lgr, global_step=global_step)
                    writer.add_scalar(
                        "straight_in_3d", Lstraight3d, global_step=global_step
                    )
                    writer.add_scalar("contact/con2d", Lcon2d, global_step=global_step)
                # --- Keep each grounded foot flat at its cached median height.
                for side in ["left", "right"]:
                    attr = f"{side}_foot_inds"
                    if hasattr(loss_parallel, attr):
                        foot_inds = getattr(loss_parallel, attr)
                        zind = 1
                        if attr not in mean_zfoot_val:
                            with torch.no_grad():
                                mean_zfoot_val[attr] = torch.median(
                                    vertices_pred[0, foot_inds, zind], dim=0
                                ).values
                        loss_foot = (
                            (vertices_pred[0, foot_inds, zind] - mean_zfoot_val[attr])
                            ** 2
                        ).sum()
                        loss = loss + args.c_reg * loss_foot
                        if writer is not None:
                            writer.add_scalar(
                                f"ground/{side} foot",
                                loss_foot,
                                global_step=global_step,
                            )
                # --- Keep silhouette vertices on the ground plane.
                if hasattr(loss_parallel, "silhuette_vertices_inds"):
                    inds = loss_parallel.silhuette_vertices_inds
                    loss_sh = (
                        (vertices_pred[0, inds, 1] - loss_parallel.ground) ** 2
                    ).sum()
                    loss = loss + args.c_reg * loss_sh
                    if writer is not None:
                        writer.add_scalar(
                            "ground/silhuette", loss_sh, global_step=global_step
                        )
            # --- Generic self-contact loss (interpenetration resolution).
            lgsc_a = gsc_contact_loss = faces_angle_loss = 0.0
            if sc_crit is not None:
                gsc_contact_loss, faces_angle_loss = sc_crit(vertices_pred)
                lgsc_a = 1000 * gsc_contact_loss + 0.1 * faces_angle_loss
                loss = loss + lgsc_a
                if writer is not None:
                    writer.add_scalar(
                        "contact/gsc", gsc_contact_loss, global_step=global_step
                    )
                    writer.add_scalar(
                        "contact/faces_angle", faces_angle_loss, global_step=global_step
                    )
            # --- Mimicked self-contact loss over the detected contact sets.
            msc_loss = 0.0
            if contact is not None and len(contact) > 0 and msc_crit is not None:
                if not isinstance(contact, list):
                    contact = [contact]
                for cntct in contact:
                    msc_loss = msc_crit(
                        cntct,
                        vertices_pred,
                    )
                    loss = loss + args.c_msc * msc_loss
                    if writer is not None:
                        writer.add_scalar(
                            "contact/msc", msc_loss, global_step=global_step
                        )
            loss.backward()
            optimizer.step()
            epoch_loss = loss.item()
            pbar.set_postfix(
                **{
                    "l": f"{epoch_loss:.3}",
                    "l2": f"{l2:.3}",
                    "par": f"{lpar:.3}",
                    "z": f"{z_loss:.3}",
                    "gsc_contact": f"{float(gsc_contact_loss):.3}",
                    "faces_angle": f"{float(faces_angle_loss):.3}",
                    "msc": f"{float(msc_loss):.3}",
                }
            )
    rotmat_pred = rotmat_pred.detach()
    if zero_hands:
        # Joints 20/21 are the wrists in the SMPL-X body pose layout.
        for i in [20, 21]:
            rotmat_pred[3 * i : 3 * (i + 1)] = 0
        for i in [12, 15]:  # neck, head
            rotmat_pred[3 * i + 1] = 0  # y
    global_orient = rotmat_pred[:3]
    body_pose = rotmat_pred[3:]
    left_hand_pose = None
    right_hand_pose = None
    if fist is not None:
        # Start from relaxed hands, then override per requested fist code.
        left_hand_pose = rotmat_pred.new_tensor(fist_pose.LEFT_RELAXED).unsqueeze(0)
        right_hand_pose = rotmat_pred.new_tensor(fist_pose.RIGHT_RELAXED).unsqueeze(0)
        for f in fist:
            pp = fist_pose.INT_TO_FIST[f]
            if pp is not None:
                pp = rotmat_pred.new_tensor(pp).unsqueeze(0)
            # Prefix convention: lf/rf set the hand pose; bare l/r set the
            # wrist joint in body_pose instead.
            if f.startswith("lf"):
                left_hand_pose = pp
            elif f.startswith("rf"):
                right_hand_pose = pp
            elif f.startswith("l"):
                body_pose[19 * 3 : 19 * 3 + 3] = pp
                left_hand_pose = None
            elif f.startswith("r"):
                body_pose[20 * 3 : 20 * 3 + 3] = pp
                right_hand_pose = None
            else:
                raise RuntimeError(f"No such hand pose: {f}")
    with torch.no_grad():
        smpl_output = smpl(
            global_orient=global_orient.unsqueeze(0),
            body_pose=body_pose.unsqueeze(0),
            left_hand_pose=left_hand_pose,
            right_hand_pose=right_hand_pose,
            pose2rot=True,
        )
    return rotmat_pred, smpl_output
def create_bone(i, j, keypoints_2d):
    """Return the unit direction vector of the bone from keypoint i to j."""
    direction = keypoints_2d[j] - keypoints_2d[i]
    return torch.nn.functional.normalize(direction, dim=0)
def is_parallel_to_plane(bone, thresh=21):
    """True when the bone's x-component magnitude exceeds cos(thresh degrees),
    i.e. the bone is within ``thresh`` degrees of horizontal."""
    limit = math.cos(math.radians(thresh))
    return abs(bone[0]) > limit
def is_close_to_plane(bone, plane, thresh):
dist = abs(bone[0] - plane)
return dist < thresh
def get_selector():
    """Indices into spin.JOINT_NAMES for each pose-estimation keypoint,
    in pose_estimation.KPS order."""
    return [
        spin.JOINT_NAMES.index(PE_KSP_TO_SPIN[kp]) for kp in pose_estimation.KPS
    ]
def calc_cos(joints_2d, joints_3d):
    """Per-bone cosine between the 2D bone direction and the xy-projection
    of the unit 3D bone direction — a foreshortening measure (1 means the
    bone lies in the image plane).
    """
    cos = []
    for i, j in pose_estimation.SKELETON:
        a = joints_2d[i] - joints_2d[j]
        a = nn.functional.normalize(a, dim=0)
        b = joints_3d[i] - joints_3d[j]
        # Normalize in 3D first, then drop z: |b_xy| <= 1 encodes foreshortening.
        b = nn.functional.normalize(b, dim=0)[:2]
        c = (a * b).sum()
        cos.append(c)
    cos = torch.stack(cos, dim=0)
    return cos
def get_natural(keypoints_2d, vertices, right_foot_inds, left_foot_inds, loss_parallel, smpl):
    """Detect "natural pose" regularities from the 2D sketch keypoints and
    install the matching constraints on ``loss_parallel``.

    Detects: limb chains drawn horizontal near the bottom of the figure
    (ground-parallel bones), nearly-collinear limb chains (3D straightness),
    horizontal feet (flat-foot constraints), and silhouette vertices close
    to the ground plane. Mutates ``loss_parallel`` in place; returns None.
    """
    # 2D body extent and the lowest keypoint coordinate ("ground" level).
    height_2d = (
        keypoints_2d.max(dim=0).values[0] - keypoints_2d.min(dim=0).values[0]
    ).item()
    plane_2d = keypoints_2d.max(dim=0).values[0].item()
    ground_parallel = []
    parallel_in_3d = []
    parallel3d_bones = set()
    # parallel chains
    for i, j, k in [
        ("Right Upper Leg", "Right Leg", "Right Foot"),
        ("Right Leg", "Right Foot", "Right Toe"),  # to remove?
        ("Left Upper Leg", "Left Leg", "Left Foot"),
        ("Left Leg", "Left Foot", "Left Toe"),  # to remove?
        ("Right Shoulder", "Right Arm", "Right Hand"),
        ("Left Shoulder", "Left Arm", "Left Hand"),
        # ("Hips", "Spine", "Neck"),
        # ("Spine", "Neck", "Head"),
    ]:
        i = pose_estimation.KPS.index(i)
        j = pose_estimation.KPS.index(j)
        k = pose_estimation.KPS.index(k)
        upleg_leg = create_bone(i, j, keypoints_2d)
        leg_foot = create_bone(j, k, keypoints_2d)
        if is_parallel_to_plane(upleg_leg) and is_parallel_to_plane(leg_foot):
            # Only ground the chain if it is drawn near the figure's bottom.
            if is_close_to_plane(
                upleg_leg, plane_2d, thresh=0.1 * height_2d
            ) or is_close_to_plane(leg_foot, plane_2d, thresh=0.1 * height_2d):
                ground_parallel.append(((i, j), 1))
                ground_parallel.append(((j, k), 1))
            # If the two bones are nearly collinear in 2D, keep them straight in 3D.
            if (upleg_leg * leg_foot).sum() > math.cos(math.radians(21)):
                parallel_in_3d.append(((i, j), (j, k)))
                parallel3d_bones.add((i, j))
                parallel3d_bones.add((j, k))
    # parallel feets
    for i, j in [
        ("Right Foot", "Right Toe"),
        ("Left Foot", "Left Toe"),
    ]:
        i = pose_estimation.KPS.index(i)
        j = pose_estimation.KPS.index(j)
        # Skip feet already constrained via a straight chain above.
        if (i, j) in parallel3d_bones:
            continue
        foot_toe = create_bone(i, j, keypoints_2d)
        if is_parallel_to_plane(foot_toe, thresh=25):
            if "Right" in pose_estimation.KPS[i]:
                loss_parallel.right_foot_inds = right_foot_inds
            else:
                loss_parallel.left_foot_inds = left_foot_inds
    loss_parallel.ground_parallel = ground_parallel
    loss_parallel.parallel_in_3d = parallel_in_3d
    vertices_np = vertices[0].cpu().numpy()
    if len(ground_parallel) > 0:
        # Silhuette veritices
        mesh = trimesh.Trimesh(vertices=vertices_np, faces=smpl.faces, process=False)
        # Vertices whose normal is roughly perpendicular to the view axis...
        silhuette_vertices_mask_1 = np.abs(mesh.vertex_normals[..., 2]) < 2e-1
        height_3d = vertices_np[:, 1].max() - vertices_np[:, 1].min()
        plane_3d = vertices_np[:, 1].max()
        # ...and that lie near the lowest 15% of the body.
        silhuette_vertices_mask_2 = (
            np.abs(vertices_np[:, 1] - plane_3d) < 0.15 * height_3d
        )
        silhuette_vertices_mask = np.logical_and(
            silhuette_vertices_mask_1, silhuette_vertices_mask_2
        )
        (silhuette_vertices_inds,) = np.where(silhuette_vertices_mask)
        if len(silhuette_vertices_inds) > 0:
            loss_parallel.silhuette_vertices_inds = silhuette_vertices_inds
            loss_parallel.ground = plane_3d
def get_cos(keypoints_3d_pred, use_angle_transf, loss_parallel):
    """Compute per-bone foreshortening targets and install them on
    ``loss_parallel.cos``.

    When ``use_angle_transf`` is set, the raw foreshortening angles are
    re-centered per limb group and passed through the fitted cubic
    transform from hist_cub (hard-coded parameters below).

    Returns the raw per-bone cosines (before any transform).
    """
    keypoints_2d_pred = keypoints_3d_pred[:, :2]
    with torch.no_grad():
        cos_r = calc_cos(keypoints_2d_pred, keypoints_3d_pred)
        alpha = torch.acos(cos_r)
        if use_angle_transf:
            # Bone indices into pose_estimation.SKELETON.
            leg_inds = [
                5,
                6,  # right leg
                7,
                8,  # left leg
            ]
            foot_inds = [15, 16]
            nleg_inds = sorted(
                set(range(len(pose_estimation.SKELETON))) - set(leg_inds) - set(foot_inds)
            )
            # Re-center each group so its smallest angle maps to zero.
            alpha[nleg_inds] = alpha[nleg_inds] - alpha[nleg_inds].min()
            amli = alpha[leg_inds].min()
            leg_inds.extend(foot_inds)
            alpha[leg_inds] = alpha[leg_inds] - amli
            angles = alpha.detach().cpu().numpy()
            # Cubic foreshortening correction; parameters fitted offline
            # (see hist_cub.main).
            angles = hist_cub.cub(
                angles / (math.pi / 2),
                a=1.2121212121212122,
                b=-1.105527638190953,
                c=0.787878787878789,
            ) * (math.pi / 2)
            alpha = alpha.new_tensor(angles)
        loss_parallel.cos = torch.cos(alpha)
    return cos_r
def save_mesh_with_winding_numbers(sc_module, vertices, smpl, save_path):
    """Save two colored debug meshes of the intersection (winding-number)
    test: one without and one with segment filtering.
    """
    triangles = sc_module.triangles(vertices)
    # Raw inside/outside classification, no segment test.
    exterior = sc_module.get_intersection_mask(vertices, triangles, test_segments=False)
    exterior = exterior.cpu().numpy().squeeze(0)
    utils.save_mesh_with_colors(
        vertices[0].cpu().numpy(),
        smpl.faces,
        save_path / "winding_numbers.ply",
        mask=exterior,
    )
    # Same test with the module's default segment filtering enabled.
    exterior = sc_module.get_intersection_mask(vertices, triangles)
    exterior = exterior.cpu().numpy().squeeze(0)
    utils.save_mesh_with_colors(
        vertices[0].cpu().numpy(),
        smpl.faces,
        save_path / "winding_numbers_filtered.ply",
        mask=exterior,
    )
def get_contacts(
    args,
    sc_module,
    y_data_conts,
    keypoints_2d,
    vertices,
    bone_to_params,
    loss_parallel,
    img_size_original,
    save_path,
):
    """Collect mesh vertex indices that should be in self-contact.

    Returns a (possibly empty) flat numpy array of vertex indices. When 2D
    contacts are found, also stores them on ``loss_parallel.contact_2d`` and
    writes a contour mask image into ``save_path``.
    """
    if not args.use_contacts:
        if args.use_msc:
            # No drawn contacts requested: fall back to geometric detection.
            _, idx = sc_module.verts_in_contact(vertices, return_idx=True)
            return idx.cpu().numpy().ravel()
        return np.array([])

    assert args.c_mse == 0
    contact, contact_2d, mask_contours = find_contacts(
        y_data_conts, keypoints_2d, bone_to_params
    )
    if len(contact_2d) > 0:
        loss_parallel.contact_2d = contact_2d
        # White image with the contact contours drawn in black.
        mask = np.full(
            (spin.constants.IMG_RES, spin.constants.IMG_RES), 255, dtype="uint8"
        )
        cv2.drawContours(mask, mask_contours, -1, 0, 2)
        mask = cv2.resize(mask, img_size_original[::-1])
        cv2.imwrite(str(save_path / "mask.png"), mask)
    if len(contact) == 0:
        # Heatmap produced nothing usable: use the current mesh geometry.
        _, idx = sc_module.verts_in_contact(vertices, return_idx=True)
        contact = idx.cpu().numpy().ravel()
    return contact
def save_all(
    keypoints_3d_pred,
    rotmat_pred,
    camera_pred,
    betas_pred,
    smpl,
    contact,
    img_original,
    predicted_keypoints_2d,
    predicted_contact_heatmap_raw,
    loss_parallel,
    smpl_output,
    shift,
    scale,
    ax2,
    summary_writer,
    save_path,
    fname,
):
    """Persist one optimization stage under ``save_path`` with tag ``fname``:
    pose parameters (.pkl), overlay image (.png), colored mesh (.ply),
    interactive 3D plot (.html), plus TensorBoard image and mesh summaries.
    """
    keypoints_2d_pred = keypoints_3d_pred[:, :2]
    vertices = smpl_output.vertices.detach()
    betas_pred = betas_pred.detach().cpu().numpy()
    utils.save_pose_params(
        rotmat_pred,
        camera_pred,
        betas_pred,
        vertices,
        smpl,
        contact,
        save_path / f"{fname}.pkl",
    )
    # NOTE(review): at the call sites `contact` is either None ("spin"/"us")
    # or a flat numpy array ("eft"/"dc"); neither has `.append`, so this
    # branch looks like it raises once silhouette indices exist — confirm.
    if hasattr(loss_parallel, "silhuette_vertices_inds"):
        contact.append(loss_parallel.silhuette_vertices_inds)
    img_sw = utils.save_results_image(
        camera=camera_pred.detach().cpu().numpy(),
        focal_length_x=spin.constants.FOCAL_LENGTH,
        focal_length_y=spin.constants.FOCAL_LENGTH,
        vertices=vertices[0].cpu().numpy(),
        input_img=img_original,
        faces=smpl.faces,
        keypoints=predicted_keypoints_2d,
        # Predicted 2D joints mapped back to the original image frame.
        keypoints_2=unnormalize_keypoints_from_spin(
            keypoints_2d_pred.cpu().numpy(), shift, scale, ax2
        )
        if shift is not None
        else None,
        # keypoints_2=unnormalize_keypoints_from_spin(joints_2d_orig.detach().cpu().numpy(), shift, scale, ax2) if shift is not None else None,
        heatmap=predicted_contact_heatmap_raw,
        filename=save_path / f"{fname}.png",
        contactlist=contact,
        contact2dlist=loss_parallel.contact_2d
        if hasattr(loss_parallel, "contact_2d")
        else None,
        cos=loss_parallel.cos.tolist() if loss_parallel.cos is not None else None,
    )
    utils.save_mesh_with_colors(
        smpl_output.vertices[0].cpu().numpy(),
        smpl.faces,
        save_path / f"{fname}.ply",
        inds=contact,
    )
    joints = smpl_output.joints.squeeze(0).cpu().numpy()
    fig = utils.plot_3D(joints, vertices.squeeze(0).cpu().numpy(), smpl.faces)
    fig.write_html(save_path / f"{fname}.html")
    summary_writer.add_image(
        fname, np.array(img_sw).astype("float32") / 255, dataformats="HWC"
    )
    # Flip Y/Z so the mesh appears upright in TensorBoard's viewer.
    summary_writer.add_mesh(
        fname,
        vertices=(
            vertices.cpu().float()[0]
            @ torch.tensor([[1, 0, 0], [0, -1, 0], [0, 0, -1.0]])
        ).unsqueeze(0),
        faces=torch.from_numpy(smpl.faces[None].astype("int64")),
    )
def spin_step(
    model_hmr,
    smpl,
    selector,
    input_img,
    img_original,
    predicted_keypoints_2d,
    predicted_contact_heatmap_raw,
    loss_parallel,
    shift,
    scale,
    ax2,
    summary_writer,
    save_path,
):
    """Stage 0: run the (frozen) SPIN/HMR regressor once, no optimization,
    and save its raw prediction under the "spin" tag."""
    with torch.no_grad():
        (
            rotmat_pred,
            betas_pred,
            camera_pred,
            keypoints_3d_pred,
            _,
            _,
            smpl_output,
            _,
            _,
        ) = get_pred_and_data(
            model_hmr,
            smpl,
            selector,
            input_img,
            zero_hands=True,
        )
    save_all(
        keypoints_3d_pred,
        rotmat_pred,
        camera_pred,
        betas_pred,
        smpl,
        None,  # no contacts at this stage
        img_original,
        predicted_keypoints_2d,
        predicted_contact_heatmap_raw,
        loss_parallel,
        smpl_output,
        shift,
        scale,
        ax2,
        summary_writer,
        save_path,
        "spin",
    )
def eft_step(
    model_hmr,
    smpl,
    selector,
    input_img,
    keypoints_2d,
    optimizer,
    args,
    loss_mse,
    loss_parallel,
    c_beta,
    sc_module,
    y_data_conts,
    bone_to_params,
    img_original,
    predicted_keypoints_2d,
    predicted_contact_heatmap_raw,
    shift,
    scale,
    ax2,
    summary_writer,
    save_path,
):
    """Stage 1 (EFT): fine-tune the network on this single image with a plain
    2D keypoint reprojection loss (c_mse=1), then detect self-contacts on the
    resulting mesh and save everything under the "eft" tag.

    Returns (vertices, keypoints_3d_pred, contact).
    """
    img_size_original = img_original.shape[:2]
    (
        rotmat_pred,
        betas_pred,
        camera_pred,
        keypoints_3d_pred,
        _,
        smpl_output,
        _,
        _,
    ) = optimize(
        model_hmr,
        smpl,
        selector,
        input_img,
        keypoints_2d,
        optimizer,
        args,
        loss_mse=loss_mse,
        loss_parallel=loss_parallel,
        c_mse=1,
        c_new_mse=0,
        c_beta=c_beta,
        sc_crit=None,
        msc_crit=None,
        contact=None,
        n_steps=60 + 90,
        writer=summary_writer,
    )
    # find contacts
    vertices = smpl_output.vertices.detach()
    contact = get_contacts(
        args,
        sc_module,
        y_data_conts,
        keypoints_2d,
        vertices,
        bone_to_params,
        loss_parallel,
        img_size_original,
        save_path,
    )
    save_all(
        keypoints_3d_pred,
        rotmat_pred,
        camera_pred,
        betas_pred,
        smpl,
        contact,
        img_original,
        predicted_keypoints_2d,
        predicted_contact_heatmap_raw,
        loss_parallel,
        smpl_output,
        shift,
        scale,
        ax2,
        summary_writer,
        save_path,
        "eft",
    )
    if sc_module is not None:
        save_mesh_with_winding_numbers(sc_module, vertices, smpl, save_path)
    return vertices, keypoints_3d_pred, contact
def dc_step(
    model_hmr,
    smpl,
    selector,
    input_img,
    keypoints_2d,
    optimizer,
    args,
    loss_mse,
    loss_parallel,
    c_mse,
    c_new_mse,
    c_beta,
    sc_crit,
    msc_crit,
    contact,
    use_contacts,
    use_msc,
    img_original,
    predicted_keypoints_2d,
    predicted_contact_heatmap_raw,
    shift,
    scale,
    ax2,
    summary_writer,
    save_path,
):
    """Stage 2 (draw-contacts): continue optimizing with the full loss mix —
    tangent/foreshortening terms plus self-contact losses when contacts are
    available — and save the result under the "dc" tag.

    Returns the final rotation matrices (fed to the subsequent "us" stage).
    """
    (
        rotmat_pred,
        betas_pred,
        camera_pred,
        keypoints_3d_pred,
        _,
        smpl_output,
        _,
        _,
    ) = optimize(
        model_hmr,
        smpl,
        selector,
        input_img,
        keypoints_2d,
        optimizer,
        args,
        loss_mse=loss_mse,
        loss_parallel=loss_parallel,
        c_mse=c_mse,
        c_new_mse=c_new_mse,
        c_beta=c_beta,
        sc_crit=sc_crit,
        # Contact terms are only active when contacts were requested.
        msc_crit=msc_crit if use_contacts or use_msc else None,
        contact=contact if use_contacts or use_msc else None,
        n_steps=60 if use_contacts or use_msc else 0,  # + 60,
        # save_path=(img_original, predicted_keypoints_2d, save_path, shift, scale, ax2, "dc"),
        writer=summary_writer,
        i_ini=60 + 90,  # global step offset continuing after the EFT stage
    )
    save_all(
        keypoints_3d_pred,
        rotmat_pred,
        camera_pred,
        betas_pred,
        smpl,
        contact,
        img_original,
        predicted_keypoints_2d,
        predicted_contact_heatmap_raw,
        loss_parallel,
        smpl_output,
        shift,
        scale,
        ax2,
        summary_writer,
        save_path,
        "dc",
    )
    return rotmat_pred
def us_step(
    model_hmr,
    smpl,
    selector,
    input_img,
    rotmat_pred,
    keypoints_2d,
    args,
    loss_mse,
    loss_parallel,
    c_mse,
    c_new_mse,
    sc_crit,
    msc_crit,
    contact,
    use_contacts,
    use_msc,
    img_original,
    keypoints_3d_pred,
    summary_writer,
    save_path,
):
    """Stage 3 ("us"): refine the pose parameters directly (optimize_ft)
    starting from the "dc" rotations, with the default (zero-beta) body shape,
    and save the result under the "us" tag."""
    # Fresh camera/SMPL output with use_betas=False — the final result is
    # posed on the mean body shape.
    (_, _, camera_pred_us, _, _, _, smpl_output_us, _, _,) = get_pred_and_data(
        model_hmr,
        smpl,
        selector,
        input_img,
        use_betas=False,
        zero_hands=True,
    )
    rotmat_pred_us, smpl_output_us = optimize_ft(
        rotmat_pred,
        camera_pred_us,
        smpl,
        selector,
        input_img,
        keypoints_2d,
        args,
        loss_mse=loss_mse,
        loss_parallel=loss_parallel,
        c_mse=c_mse,
        c_new_mse=c_new_mse,
        sc_crit=sc_crit,
        msc_crit=msc_crit if use_contacts or use_msc else None,
        contact=contact if use_contacts or use_msc else None,
        n_steps=60 if use_contacts or use_msc else 0,  # + 60,
        # save_path=(img_original, predicted_keypoints_2d, save_path, shift, scale, ax2, "dc"),
        writer=summary_writer,
        i_ini=60 + 90 + 60,  # global step offset after EFT + DC stages
        zero_hands=True,
        fist=args.fist,
    )
    save_all(
        keypoints_3d_pred,
        rotmat_pred_us,
        camera_pred_us,
        torch.zeros(1, 10, dtype=torch.float32),  # mean-shape betas
        smpl,
        None,
        img_original,
        None,
        None,
        loss_parallel,
        smpl_output_us,
        None,
        None,
        None,
        summary_writer,
        save_path,
        "us",
    )
def main():
    """End-to-end pipeline: load the 2D pose / contact / SPIN models once,
    then for every input image run the four stages (spin → eft → dc → us),
    writing all artifacts under ``args.save_path/<image name>/``."""
    args = parse_args()
    print(args)

    # models
    model_pose = cv2.dnn.readNetFromONNX(
        args.pose_estimation_model_path
    )  # "hrn_w48_384x288.onnx"
    model_contact = cv2.dnn.readNetFromONNX(
        args.contact_model_path
    )  # "contact_hrn_w32_256x192.onnx"
    device = (
        torch.device(args.device) if torch.cuda.is_available() else torch.device("cpu")
    )
    model_hmr = spin.hmr(args.smpl_mean_params_path)  # "smpl_mean_params.npz"
    model_hmr.to(device)
    checkpoint = torch.load(
        args.spin_model_path,  # "spin_model_smplx_eft_18.pt"
        map_location="cpu"
    )
    smpl = spin.SMPLX(
        args.smpl_model_dir,  # "models/smplx"
        batch_size=1,
        create_transl=False,
        use_pca=False,
        flat_hand_mean=args.fist is not None,
    )
    smpl.to(device)
    selector = get_selector()
    use_contacts = args.use_contacts
    use_msc = args.use_msc
    bone_to_params = np.load(args.bone_parametrization_path, allow_pickle=True).item()
    foot_inds = np.load(args.foot_inds_path, allow_pickle=True).item()
    left_foot_inds = foot_inds["left_foot_inds"]
    right_foot_inds = foot_inds["right_foot_inds"]
    # Self-contact modules are only built when contacts are requested.
    if use_contacts:
        model_type = args.smpl_type
        sc_module = selfcontact.SelfContact(
            essentials_folder=args.essentials_dir,  # "smplify-xmc-essentials"
            geothres=0.3,
            euclthres=0.02,
            test_segments=True,
            compute_hd=True,
            model_type=model_type,
            device=device,
        )
        sc_module.to(device)
        sc_crit = selfcontact.losses.SelfContactLoss(
            contact_module=sc_module,
            inside_loss_weight=0.5,
            outside_loss_weight=0.0,
            contact_loss_weight=0.5,
            align_faces=True,
            use_hd=True,
            test_segments=True,
            device=device,
        )
        sc_crit.to(device)
        msc_crit = losses.MimickedSelfContactLoss(geodesics_mask=sc_module.geomask)
        msc_crit.to(device)
    else:
        sc_module = None
        sc_crit = None
        msc_crit = None
    loss_mse = losses.MSE([1, 10, 13])  # Neck + Right Upper Leg + Left Upper Leg
    # Bone pairs excluded from the parallelism (tangent) loss.
    ignore = (
        (1, 2),  # Neck + Right Shoulder
        (1, 5),  # Neck + Left Shoulder
        (9, 10),  # Hips + Right Upper Leg
        (9, 13),  # Hips + Left Upper Leg
    )
    loss_parallel = losses.Parallel(
        skeleton=pose_estimation.SKELETON,
        ignore=ignore,
    )
    c_mse = args.c_mse
    c_new_mse = args.c_par
    c_beta = 1e-3
    # Exactly one of the two reprojection weights must be active.
    if c_mse > 0:
        assert c_new_mse == 0
    elif c_mse == 0:
        assert c_new_mse > 0
    root_path = Path(args.save_path)
    root_path.mkdir(exist_ok=True, parents=True)
    path_to_imgs = Path(args.img_path)
    if path_to_imgs.is_dir():
        path_to_imgs = path_to_imgs.iterdir()
    else:
        path_to_imgs = [path_to_imgs]
    for img_path in path_to_imgs:
        if not any(
            img_path.name.lower().endswith(ext) for ext in [".jpg", ".png", ".jpeg"]
        ):
            continue
        img_name = img_path.stem

        # use 2d keypoints detection
        (
            img_original,
            predicted_keypoints_2d,
            _,
            _,
        ) = pose_estimation.infer_single_image(
            model_pose,
            img_path,
            input_img_size=pose_estimation.IMG_SIZE,
            return_kps=True,
        )
        save_path = root_path / img_name
        save_path.mkdir(exist_ok=True, parents=True)
        # if (save_path / "us_orig.png").is_file():
        #     return
        summary_writer = SummaryWriter(log_dir=save_path / f"runDoknc2_{c_new_mse}")
        img_original = cv2.cvtColor(img_original, cv2.COLOR_BGR2RGB)
        img_size_original = img_original.shape[:2]
        # Map the detected keypoints into SPIN's normalized crop frame.
        keypoints_2d, shift, scale, ax2 = normalize_keypoints_to_spin(
            predicted_keypoints_2d, img_size_original
        )
        keypoints_2d = torch.from_numpy(keypoints_2d)
        keypoints_2d = keypoints_2d.to(device)
        (
            predicted_contact_heatmap,
            predicted_contact_heatmap_raw,
            very_hm_raw,
        ) = get_contact_heatmap(model_contact, img_path)
        predicted_contact_heatmap_raw = Image.fromarray(
            predicted_contact_heatmap_raw
        ).resize(img_size_original[::-1])
        # NOTE(review): the PIL resize above is immediately overwritten by the
        # cv2.resize below — one of the two looks redundant; confirm intent.
        predicted_contact_heatmap_raw = cv2.resize(very_hm_raw, img_size_original[::-1])
        if c_new_mse == 0:
            predicted_contact_heatmap_raw = None
        y_data_conts = get_vertices_in_heatmap(predicted_contact_heatmap)
        # Reset the network to the pretrained checkpoint for every image; kept
        # in train mode so the per-image (EFT) optimization can update weights.
        model_hmr.load_state_dict(checkpoint["model"], strict=True)
        model_hmr.train()
        freeze_layers(model_hmr)
        _, input_img = spin.process_image(img_path, input_res=spin.constants.IMG_RES)
        input_img = input_img.to(device)
        spin_step(
            model_hmr,
            smpl,
            selector,
            input_img,
            img_original,
            predicted_keypoints_2d,
            predicted_contact_heatmap_raw,
            loss_parallel,
            shift,
            scale,
            ax2,
            summary_writer,
            save_path,
        )
        optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, model_hmr.parameters()),
            lr=1e-6,
        )
        vertices, keypoints_3d_pred, contact = eft_step(
            model_hmr,
            smpl,
            selector,
            input_img,
            keypoints_2d,
            optimizer,
            args,
            loss_mse,
            loss_parallel,
            c_beta,
            sc_module,
            y_data_conts,
            bone_to_params,
            img_original,
            predicted_keypoints_2d,
            predicted_contact_heatmap_raw,
            shift,
            scale,
            ax2,
            summary_writer,
            save_path,
        )
        if args.use_natural:
            get_natural(
                keypoints_2d, vertices, right_foot_inds, left_foot_inds, loss_parallel, smpl,
            )
        if args.use_cos:
            cos_r = get_cos(keypoints_3d_pred, args.use_angle_transf, loss_parallel)
            np.save(save_path / "cos_hist", cos_r.cpu().numpy())
        rotmat_pred = dc_step(
            model_hmr,
            smpl,
            selector,
            input_img,
            keypoints_2d,
            optimizer,
            args,
            loss_mse,
            loss_parallel,
            c_mse,
            c_new_mse,
            c_beta,
            sc_crit,
            msc_crit,
            contact,
            use_contacts,
            use_msc,
            img_original,
            predicted_keypoints_2d,
            predicted_contact_heatmap_raw,
            shift,
            scale,
            ax2,
            summary_writer,
            save_path,
        )
        us_step(
            model_hmr,
            smpl,
            selector,
            input_img,
            rotmat_pred,
            keypoints_2d,
            args,
            loss_mse,
            loss_parallel,
            c_mse,
            c_new_mse,
            sc_crit,
            msc_crit,
            contact,
            use_contacts,
            use_msc,
            img_original,
            keypoints_3d_pred,
            summary_writer,
            save_path,
        )


if __name__ == "__main__":
    main()
================================================
FILE: src/pose_estimation.py
================================================
import math
import cv2
import numpy as np
# Network input size (width, height) of the HRNet pose model.
IMG_SIZE = (288, 384)
# ImageNet normalization statistics (RGB order).
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])
# Keypoint names, in prediction order.
KPS = (
    "Head",
    "Neck",
    "Right Shoulder",
    "Right Arm",
    "Right Hand",
    "Left Shoulder",
    "Left Arm",
    "Left Hand",
    "Spine",
    "Hips",
    "Right Upper Leg",
    "Right Leg",
    "Right Foot",
    "Left Upper Leg",
    "Left Leg",
    "Left Foot",
    "Left Toe",
    "Right Toe",
)
# Bones as (parent, child) index pairs into KPS.
SKELETON = (
    (0, 1),
    (1, 8),
    (8, 9),
    (9, 10),
    (9, 13),
    (10, 11),
    (11, 12),
    (13, 14),
    (14, 15),
    (1, 2),
    (2, 3),
    (3, 4),
    (1, 5),
    (5, 6),
    (6, 7),
    (15, 16),
    (12, 17),
)
# Maps OpenPose joint index -> KPS index (-1 = no counterpart).
OPENPOSE_TO_GESTURE = (
    0,  # 0 Head\n",
    1,  # Neck\n",
    2,  # 2 Right Shoulder\n",
    3,  # Right Arm\n",
    4,  # 4 Right Hand\n",
    5,  # Left Shoulder\n",
    6,  # 6 Left Arm\n",
    7,  # Left Hand\n",
    9,  # 8 Hips\n",
    10,  # Right Upper Leg\n",
    11,  # 10Right Leg\n",
    12,  # Right Foot\n",
    13,  # 12Left Upper Leg\n",
    14,  # Left Leg\n",
    15,  # 14Left Foot\n",
    -1,  # \n",
    -1,  # 16\n",
    -1,  # \n",
    -1,  # 18\n",
    16,  # Left Toe\n",
    -1,  # 20\n",
    -1,  # \n",
    17,  # 22Right Toe\n",
    -1,  # \n",
    -1,  # 24\n",
)
def transform(img):
    """Normalize an HWC image with the module's MEAN/STD and return it as
    CHW float32, ready for the ONNX pose network."""
    scaled = img.astype("float32") / 255
    normalized = (scaled - MEAN) / STD
    return normalized.transpose(2, 0, 1)
def get_affine_transform(
    center,
    scale,
    rot,
    output_size,
    shift=None,
    inv=0,
    pixel_std=200,
):
    """Build the 2x3 affine matrix mapping the scaled/rotated person box to
    ``output_size`` (or the inverse mapping when ``inv`` is truthy).

    Args:
        center: (x, y) box center in source-image pixels.
        scale: box scale (scalar or 2-vector) in units of ``pixel_std`` pixels.
        rot: rotation in degrees.
        output_size: (width, height) of the destination crop.
        shift: optional relative (x, y) shift of the box; defaults to (0, 0).
        inv: if truthy, return the destination->source transform instead.
        pixel_std: pixels per scale unit.
    """
    # BUGFIX(idiom): the default was a module-level np.array created once at
    # definition time (shared mutable default); use a None sentinel instead.
    if shift is None:
        shift = np.array([0, 0], dtype=np.float32)
    if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
        scale = np.array([scale, scale])
    scale_tmp = scale * pixel_std
    src_w = scale_tmp[0]
    dst_w = output_size[0]
    dst_h = output_size[1]
    rot_rad = np.pi * rot / 180
    # Two corresponding point pairs (center, rotated "up" direction); the
    # third point is derived perpendicular to them so cv2 gets 3 pairs.
    src_dir = get_dir([0, src_w * -0.5], rot_rad)
    dst_dir = np.array([0, dst_w * -0.5], np.float32)
    src = np.zeros((3, 2), dtype=np.float32)
    dst = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = center + scale_tmp * shift
    src[1, :] = center + src_dir + scale_tmp * shift
    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
    src[2:, :] = get_3rd_point(src[0, :], src[1, :])
    dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])
    if inv:
        trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
    else:
        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
    return trans
def get_3rd_point(a, b):
    """Return the third point of the triangle: (a - b) rotated 90 degrees
    counter-clockwise about b."""
    delta = a - b
    return b + np.array([-delta[1], delta[0]], dtype=np.float32)
def get_dir(src_point, rot_rad):
    """Rotate the 2D point ``src_point`` by ``rot_rad`` radians (CCW)."""
    sn, cs = np.sin(rot_rad), np.cos(rot_rad)
    return [
        src_point[0] * cs - src_point[1] * sn,
        src_point[0] * sn + src_point[1] * cs,
    ]
def process_image(path, input_img_size, pixel_std=200):
    """Load an image and warp it into the pose network's input crop.

    Returns (normalized CHW network input, raw loaded image, box center,
    box scale).
    """
    img = cv2.imread(path, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
    # BUG HERE. Must be uncommented
    # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    h, w = img.shape[:2]
    center = np.array([w / 2, h / 2], dtype=np.float32)
    # Grow the box so it matches the network aspect ratio, then pad by 25%.
    aspect_ratio = input_img_size[0] / input_img_size[1]
    if w > aspect_ratio * h:
        h = w * 1.0 / aspect_ratio
    elif w < aspect_ratio * h:
        w = h * aspect_ratio
    scale = np.array([w / pixel_std, h / pixel_std], dtype=np.float32) * 1.25
    trans = get_affine_transform(center, scale, 0, input_img_size, pixel_std=pixel_std)
    warped = cv2.warpAffine(img, trans, input_img_size, flags=cv2.INTER_LINEAR)
    return transform(warped), img, center, scale
def get_final_preds(batch_heatmaps, center, scale, post_process=False):
    """Turn heatmaps into keypoint coordinates in the original image frame.

    Returns (predictions, peak confidences)."""
    coords, maxvals = get_max_preds(batch_heatmaps)
    hm_h = batch_heatmaps.shape[2]
    hm_w = batch_heatmaps.shape[3]
    if post_process:
        # Quarter-pixel refinement toward the larger neighboring activation.
        for n in range(coords.shape[0]):
            for p in range(coords.shape[1]):
                hm = batch_heatmaps[n][p]
                px = int(math.floor(coords[n][p][0] + 0.5))
                py = int(math.floor(coords[n][p][1] + 0.5))
                if 1 < px < hm_w - 1 and 1 < py < hm_h - 1:
                    diff = np.array(
                        [
                            hm[py][px + 1] - hm[py][px - 1],
                            hm[py + 1][px] - hm[py - 1][px],
                        ]
                    )
                    coords[n][p] += np.sign(diff) * 0.25
    # Map each sample's heatmap coordinates back through the crop transform.
    preds = coords.copy()
    for i in range(coords.shape[0]):
        preds[i] = transform_preds(coords[i], center[i], scale[i], [hm_w, hm_h])
    return preds, maxvals
def transform_preds(coords, center, scale, output_size):
    """Map heatmap-space coordinates back to the source image using the
    inverse crop transform."""
    trans = get_affine_transform(center, scale, 0, output_size, inv=1)
    target_coords = np.zeros(coords.shape)
    for p, pt in enumerate(coords):
        target_coords[p, 0:2] = affine_transform(pt[0:2], trans)
    return target_coords
def affine_transform(pt, t):
    """Apply the 2x3 affine matrix ``t`` to the 2D point ``pt``."""
    homogeneous = np.array([pt[0], pt[1], 1.0])
    return (t @ homogeneous)[:2]
def get_max_preds(batch_heatmaps):
    """
    get predictions from score maps
    heatmaps: numpy.ndarray([batch_size, num_joints, height, width])
    """
    assert isinstance(
        batch_heatmaps, np.ndarray
    ), "batch_heatmaps should be numpy.ndarray"
    assert batch_heatmaps.ndim == 4, "batch_images should be 4-ndim"
    n_batch, n_joints = batch_heatmaps.shape[:2]
    width = batch_heatmaps.shape[3]
    flat = batch_heatmaps.reshape((n_batch, n_joints, -1))
    idx = np.argmax(flat, 2).reshape((n_batch, n_joints, 1))
    maxvals = np.amax(flat, 2).reshape((n_batch, n_joints, 1))
    # Unravel the flat argmax into (x, y) heatmap coordinates.
    preds = np.tile(idx, (1, 1, 2)).astype(np.float32)
    preds[:, :, 0] = preds[:, :, 0] % width
    preds[:, :, 1] = np.floor(preds[:, :, 1] / width)
    # Zero out joints whose peak activation is not positive.
    pred_mask = np.tile(np.greater(maxvals, 0.0), (1, 1, 2)).astype(np.float32)
    preds *= pred_mask
    return preds, maxvals
def infer_single_image(model, img_path, input_img_size=(288, 384), return_kps=True):
    """Run the 2D pose network (cv2.dnn) on one image.

    Returns the heatmap alone when ``return_kps`` is False, otherwise
    (raw image, keypoints, confidences, heatmap), all squeezed to batch-free
    shapes."""
    pose_input, img, center, scale = process_image(
        str(img_path), input_img_size=input_img_size
    )
    model.setInput(pose_input[None])
    heatmap = model.forward()
    if not return_kps:
        return heatmap.squeeze(0)
    keypoints, confidence = get_final_preds(
        heatmap, center[None], scale[None], post_process=True
    )
    return img, keypoints.squeeze(0), confidence.squeeze(0), heatmap.squeeze(0)
================================================
FILE: src/renderer.py
================================================
import numpy as np
import pyrender
import torch
import trimesh
from torchvision.utils import make_grid
class Renderer:
    """
    Renderer used for visualizing the SMPL model
    Code adapted from https://github.com/vchoutas/smplify-x
    """

    def __init__(self, focal_length=5000, img_res=224, faces=None):
        # Offscreen GL renderer with a square viewport of img_res pixels.
        self.renderer = pyrender.OffscreenRenderer(
            viewport_width=img_res, viewport_height=img_res, point_size=1.0
        )
        self.focal_length = focal_length
        self.camera_center = [img_res // 2, img_res // 2]
        self.faces = faces

    def visualize_tb(self, vertices, camera_translation, images):
        """Render each mesh over its input image and return an interleaved
        (input, render) grid tensor suitable for TensorBoard."""
        vertices = vertices.cpu().numpy()
        camera_translation = camera_translation.cpu().numpy()
        images = images.cpu()
        images_np = np.transpose(images.numpy(), (0, 2, 3, 1))
        rend_imgs = []
        for i in range(vertices.shape[0]):
            rend_img = torch.from_numpy(
                np.transpose(
                    self.__call__(vertices[i], camera_translation[i], images_np[i]),
                    (2, 0, 1),
                )
            ).float()
            rend_imgs.append(images[i])
            rend_imgs.append(rend_img)
        rend_imgs = make_grid(rend_imgs, nrow=2)
        return rend_imgs

    def __call__(self, vertices, camera_translation, image):
        """Render one mesh composited over ``image`` (HWC float in [0, 1])."""
        material = pyrender.MetallicRoughnessMaterial(
            metallicFactor=0.2, alphaMode="OPAQUE", baseColorFactor=(0.8, 0.3, 0.3, 1.0)
        )
        # Mirror X; NOTE: mutates the caller's array in place.
        camera_translation[0] *= -1.0
        mesh = trimesh.Trimesh(vertices, self.faces)
        # Flip the mesh 180 degrees about X to match the camera convention.
        rot = trimesh.transformations.rotation_matrix(np.radians(180), [1, 0, 0])
        mesh.apply_transform(rot)
        mesh = pyrender.Mesh.from_trimesh(mesh, material=material)
        scene = pyrender.Scene(ambient_light=(0.5, 0.5, 0.5))
        scene.add(mesh, "mesh")
        camera_pose = np.eye(4)
        camera_pose[:3, 3] = camera_translation
        camera = pyrender.IntrinsicsCamera(
            fx=self.focal_length,
            fy=self.focal_length,
            cx=self.camera_center[0],
            cy=self.camera_center[1],
        )
        scene.add(camera, pose=camera_pose)
        # Three directional lights around the subject.
        light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=1)
        light_pose = np.eye(4)
        light_pose[:3, 3] = np.array([0, -1, 1])
        scene.add(light, pose=light_pose)
        light_pose[:3, 3] = np.array([0, 1, 1])
        scene.add(light, pose=light_pose)
        light_pose[:3, 3] = np.array([1, 1, 2])
        scene.add(light, pose=light_pose)
        color, rend_depth = self.renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
        color = color.astype(np.float32) / 255.0
        # Composite: rendered pixels where depth is valid, input image elsewhere.
        valid_mask = (rend_depth > 0)[:, :, None]
        output_img = color[:, :, :3] * valid_mask + (1 - valid_mask) * image
        return output_img
def overlay_mesh(
    verts,
    faces,
    camera_transl,
    focal_length_x,
    focal_length_y,
    camera_center,
    H,
    W,
    img,
    camera_rotation=None,
    rotaround=None,
    contactlist=None,
    color=False,
    scale=1,
):
    """Render a mesh with an orthographic camera, optionally rotated by
    ``rotaround`` degrees about the Y axis, composited over ``img`` when one
    is given. Contact vertices (``contactlist``) are painted red.

    Returns a uint8 HWC RGB image of size (scale*H, scale*W).
    """
    material = pyrender.MetallicRoughnessMaterial(
        metallicFactor=0.0, alphaMode="OPAQUE", baseColorFactor=(1.0, 1.0, 0.9, 1.0)
    )
    out_mesh = trimesh.Trimesh(verts, faces, process=False)
    out_mesh_col = np.array(out_mesh.visual.vertex_colors)
    if contactlist is not None and len(contactlist) > 0:
        # Paint contact vertices red (overrides the `color` argument).
        color = [255, 0, 0, 255]
        out_mesh_col[contactlist] = color
        out_mesh.visual.vertex_colors = out_mesh_col
    if camera_rotation is None:
        camera_rotation = np.eye(3)
    else:
        camera_rotation = camera_rotation[0]
    # rotate mesh and stack output images
    if rotaround is None:
        out_mesh.vertices = np.matmul(verts, camera_rotation.T) + camera_transl
    else:
        base_mesh = trimesh.Trimesh(verts, faces, process=False)
        # rot_center = (base_mesh.vertices[5615] + base_mesh.vertices[5614] ) / 2
        # Rotate about a fixed vertex (index 4297) of the body mesh.
        rot = trimesh.transformations.rotation_matrix(
            np.radians(rotaround), [0, 1, 0], base_mesh.vertices[4297]
        )
        base_mesh.apply_transform(rot)
        out_mesh.vertices = (
            np.matmul(base_mesh.vertices, camera_rotation.T) + camera_transl
        )
    # Push the mesh away from the camera along +Z.
    out_mesh.vertices += np.array([0, 0, 50])
    # add mesh to scene
    mesh = pyrender.Mesh.from_trimesh(
        out_mesh,
        material=material,
        smooth=False,
    )
    # Transparent background when compositing over an image, white otherwise.
    if img is not None:
        scene = pyrender.Scene(
            bg_color=[0.0, 0.0, 0.0, 0.0],
            ambient_light=(0.3, 0.3, 0.3, 1.0),
        )
    else:
        scene = pyrender.Scene(
            bg_color=[1.0, 1.0, 1.0, 1.0],
            ambient_light=(0.3, 0.3, 0.3, 1.0),
        )
    scene.add(mesh, "mesh")
    # create and add camera
    camera_pose = np.eye(4)
    camera_pose[1, :] = -camera_pose[1, :]
    camera_pose[2, :] = -camera_pose[2, :]
    # Orthographic magnification taken from the Z camera translation.
    pyrencamera = pyrender.camera.OrthographicCamera(
        camera_transl[2],
        camera_transl[2],
        znear=1e-6,
        zfar=1000000,
    )
    scene.add(pyrencamera, pose=camera_pose)
    # create and add light
    light = pyrender.PointLight(
        color=[1.0, 1.0, 1.0],
        intensity=1,
    )
    light_pose = np.eye(4)
    # Four point lights placed around the mesh centroid.
    for lp in [[1, 1, -1], [-1, 1, -1], [1, -1, -1], [-1, -1, -1]]:
        light_pose[:3, 3] = out_mesh.vertices.mean(0) + np.array(lp)
        scene.add(light, pose=light_pose)
    r = pyrender.OffscreenRenderer(
        viewport_width=int(scale * W),
        viewport_height=int(scale * H),
        point_size=1.0,
    )
    color, _ = r.render(scene, flags=pyrender.RenderFlags.RGBA)
    color = color.astype(np.float32) / 255.0
    if img is not None:
        # Alpha-composite the render over the background image.
        valid_mask = (color[:, :, -1] > 0)[:, :, np.newaxis]
        output_img = color[:, :, :-1] * valid_mask + (1 - valid_mask) * img
    else:
        output_img = color
    output_img = (output_img * 255).astype(np.uint8)[..., :3]
    return output_img
================================================
FILE: src/spin/__init__.py
================================================
from .constants import JOINT_NAMES
from .hmr import hmr
from .smpl import SMPLX
from .utils import process_image
# Public API of the `spin` package.
__all__ = [
    "hmr",
    "SMPLX",
    "process_image",
    "JOINT_NAMES",
]
================================================
FILE: src/spin/constants.py
================================================
# Focal length (pixels) of the fixed perspective camera used by SPIN.
FOCAL_LENGTH = 5000.0
# Side length (pixels) of the square crop fed to the HMR network.
IMG_RES = 224

# Mean and standard deviation for normalizing input image
IMG_NORM_MEAN = [0.485, 0.456, 0.406]
IMG_NORM_STD = [0.229, 0.224, 0.225]

"""
We create a superset of joints containing the OpenPose joints together with the ones that each dataset provides.
We keep a superset of 24 joints such that we include all joints from every dataset.
If a dataset doesn't provide annotations for a specific joint, we simply ignore it.
The joints used here are the following:
"""
# Joint names, in the order produced by the SMPL-X regressor used here.
JOINT_NAMES = (
    "Hips",
    "Left Upper Leg",
    "Right Upper Leg",
    "Spine",
    "Left Leg",
    "Right Leg",
    "Spine1",
    "Left Foot",
    "Right Foot",
    "Thorax",
    "Left Toe",
    "Right Toe",
    "Neck",
    "Left Shoulder",
    "Right Shoulder",
    "Head",
    "Left ForeArm",
    "Right ForeArm",
    "Left Arm",
    "Right Arm",
    "Left Hand",
    "Right Hand",
    # 25 OpenPose joints (in the order provided by OpenPose)
    # "OP Nose",
    # "OP Neck",
    # "OP RShoulder",
    # "OP RElbow",
    # "OP RWrist",
    # "OP LShoulder",
    # "OP LElbow",
    # "OP LWrist",
    # "OP MidHip",
    # "OP RHip",
    # "OP RKnee",
    # "OP RAnkle",
    # "OP LHip",
    # "OP LKnee",
    # "OP LAnkle",
    # "OP REye",
    # "OP LEye",
    # "OP REar",
    # "OP LEar",
    # "OP LBigToe",
    # "OP LSmallToe",
    # "OP LHeel",
    # "OP RBigToe",
    # "OP RSmallToe",
    # "OP RHeel",
    ## 24 Ground Truth joints (superset of joints from different datasets)
    # "Right Ankle",
    # "Right Knee",
    # "Right Hip",
    # "Left Hip",
    # "Left Knee",
    # "Left Ankle",
    # "Right Wrist",
    # "Right Elbow",
    # "Right Shoulder",
    # "Left Shoulder",
    # "Left Elbow",
    # "Left Wrist",
    # "Neck (LSP)",
    # "Top of Head (LSP)",
    # "Pelvis (MPII)",
    # "Thorax (MPII)",
    # "Spine (H36M)",
    # "Jaw (H36M)",
    # "Head (H36M)",
    # "Nose",
    # "Left Eye",
    # "Right Eye",
    # "Left Ear",
    # "Right Ear",
    # "OP MidHip",
    # "Spine1",
    # "Spine2",
    # "Spine3",
    # "OP Neck",
    # "Head",
)
# Dict containing the joints in numerical order
JOINT_IDS = {JOINT_NAMES[i]: i for i in range(len(JOINT_NAMES))}
# Map joints to SMPL joints
JOINT_MAP = {
    "Hips": 0,
    "Left Upper Leg": 1,
    "Right Upper Leg": 2,
    "Spine": 3,
    "Left Leg": 4,
    "Right Leg": 5,
    "Spine1": 6,
    "Left Foot": 7,
    "Right Foot": 8,
    "Thorax": 9,
    "Left Toe": 10,
    "Right Toe": 11,
    "Neck": 12,
    "Left Shoulder": 13,
    "Right Shoulder": 14,
    "Head": 15,
    "Left ForeArm": 16,
    "Right ForeArm": 17,
    "Left Arm": 18,
    "Right Arm": 19,
    "Left Hand": 20,
    "Right Hand": 21,
    # "OP Nose": 24,
    # "OP Neck": 12,
    # "OP RShoulder": 17,
    # "OP RElbow": 19,
    # "OP RWrist": 21,
    # "OP LShoulder": 16,
    # "OP LElbow": 18,
    # "OP LWrist": 20,
    # "OP MidHip": 0,
    # "OP RHip": 2,
    # "OP RKnee": 5,
    # "OP RAnkle": 8,
    # "OP LHip": 1,
    # "OP LKnee": 4,
    # "OP LAnkle": 7,
    # "OP REye": 25,
    # "OP LEye": 26,
    # "OP REar": 27,
    # "OP LEar": 28,
    # "OP LBigToe": 29,
    # "OP LSmallToe": 30,
    # "OP LHeel": 31,
    # "OP RBigToe": 32,
    # "OP RSmallToe": 33,
    # "OP RHeel": 34,
    # "Right Ankle": 8,
    # "Right Knee": 5,
    # "Right Hip": 45,
    # "Left Hip": 46,
    # "Left Knee": 4,
    # "Left Ankle": 7,
    # "Right Wrist": 21,
    # "Right Elbow": 19,
    # "Right Shoulder": 17,
    # "Left Shoulder": 16,
    # "Left Elbow": 18,
    # "Left Wrist": 20,
    # "Neck (LSP)": 47,
    # "Top of Head (LSP)": 15,  # 48,
    # "Pelvis (MPII)": 49,
    # "Thorax (MPII)": 50,
    # "Spine (H36M)": 51,
    # "Jaw (H36M)": 52,
    # "Head (H36M)": 15,  # 53,
    # "Nose": 24,
    # "Left Eye": 26,
    # "Right Eye": 25,
    # "Left Ear": 28,
    # "Right Ear": 27,
    # "Spine1": 3,
    # "Spine2": 6,
    # "Spine3": 9,
    # "Head": 15,
}
# Joint selectors
# Indices to get the 14 LSP joints from the 17 H36M joints
H36M_TO_J17 = [6, 5, 4, 1, 2, 3, 16, 15, 14, 11, 12, 13, 8, 10, 0, 7, 9]
H36M_TO_J14 = H36M_TO_J17[:14]
# Indices to get the 14 LSP joints from the ground truth joints
J24_TO_J17 = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18, 14, 16, 17]
J24_TO_J14 = J24_TO_J17[:14]
# Permutation of SMPL pose parameters when flipping the shape
# (left/right joint pairs are swapped).
SMPL_JOINTS_FLIP_PERM = [
    0,
    2,
    1,
    3,
    5,
    4,
    6,
    8,
    7,
    9,
    11,
    10,
    12,
    14,
    13,
    15,
    17,
    16,
    19,
    18,
    21,
    20,
    23,
    22,
]
# Expand the joint permutation to the 3 axis-angle components per joint.
SMPL_POSE_FLIP_PERM = []
for i in SMPL_JOINTS_FLIP_PERM:
    SMPL_POSE_FLIP_PERM.append(3 * i)
    SMPL_POSE_FLIP_PERM.append(3 * i + 1)
    SMPL_POSE_FLIP_PERM.append(3 * i + 2)
# Permutation indices for the 24 ground truth joints
J24_FLIP_PERM = [
    5,
    4,
    3,
    2,
    1,
    0,
    11,
    10,
    9,
    8,
    7,
    6,
    12,
    13,
    14,
    15,
    16,
    17,
    18,
    19,
    21,
    20,
    23,
    22,
]
# Permutation indices for the full set of 49 joints
J49_FLIP_PERM = [
    0,
    1,
    5,
    6,
    7,
    2,
    3,
    4,
    8,
    12,
    13,
    14,
    9,
    10,
    11,
    16,
    15,
    18,
    17,
    22,
    23,
    24,
    19,
    20,
    21,
] + [25 + i for i in J24_FLIP_PERM]
================================================
FILE: src/spin/hmr.py
================================================
import math
import numpy as np
import torch
import torch.nn as nn
import torchvision.models.resnet as resnet
def rot6d_to_rotmat(x):
    """Convert 6D rotation representation to 3x3 rotation matrix.
    Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
    Input:
        (B,6) Batch of 6-D rotation representations
    Output:
        (B,3,3) Batch of corresponding rotation matrices
    """
    x = x.view(-1, 3, 2)
    a1 = x[:, :, 0]
    a2 = x[:, :, 1]
    # Gram-Schmidt: b1 is a1 normalized; b2 is a2 with its b1 component
    # removed, then normalized; b3 completes the right-handed basis.
    b1 = nn.functional.normalize(a1)
    b2 = nn.functional.normalize(
        a2 - torch.einsum("bi,bi->b", b1, a2).unsqueeze(-1) * b1
    )
    # BUGFIX: without an explicit `dim`, torch.cross picks the *first*
    # dimension of size 3 — which is the batch dimension when B == 3,
    # silently producing wrong matrices. Always cross along the vector dim.
    b3 = torch.cross(b1, b2, dim=1)
    # Basis vectors become the columns of the rotation matrix.
    return torch.stack((b1, b2, b3), dim=-1)
class Bottleneck(nn.Module):
    """Redefinition of Bottleneck residual block
    Adapted from the official PyTorch implementation
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        # 1x1 reduce -> 3x3 (possibly strided) -> 1x1 expand, each with BN.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Identity shortcut, projected when a downsample module is given.
        shortcut = x if self.downsample is None else self.downsample(x)
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        y += shortcut
        return self.relu(y)
class HMR(nn.Module):
    """SMPL Iterative Regressor with ResNet50 backbone"""

    def __init__(self, block, layers, smpl_mean_params):
        self.inplanes = 64
        super(HMR, self).__init__()
        # Output sizes: 10 shape (beta) params, 3 weak-perspective camera
        # params, 24 joints x 6D rotation representation.
        self.n_shape = 10
        self.n_cam = 3
        self.n_joints = 24
        npose = self.n_joints * 6
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        # Regressor head: image feature concatenated with the current
        # pose/shape/camera estimate (iterative error feedback).
        self.fc1 = nn.Linear(512 * block.expansion + npose + self.n_shape + self.n_cam, 1024)
        self.drop1 = nn.Dropout()
        self.fc2 = nn.Linear(1024, 1024)
        self.drop2 = nn.Dropout()
        self.decpose = nn.Linear(1024, npose)
        self.decshape = nn.Linear(1024, self.n_shape)
        self.deccam = nn.Linear(1024, self.n_cam)
        nn.init.xavier_uniform_(self.decpose.weight, gain=0.01)
        nn.init.xavier_uniform_(self.decshape.weight, gain=0.01)
        nn.init.xavier_uniform_(self.deccam.weight, gain=0.01)
        # He-style init for convolutions, constant init for batch norms.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.0 / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
        # Mean SMPL parameters (from the npz file) serve as the initial
        # estimate refined by `forward`.
        mean_params = np.load(smpl_mean_params)
        init_pose = torch.from_numpy(mean_params["pose"][:]).unsqueeze(0)
        init_shape = torch.from_numpy(
            mean_params["shape"][:].astype("float32")
        ).unsqueeze(0)
        init_cam = torch.from_numpy(mean_params["cam"]).unsqueeze(0)
        self.register_buffer("init_pose", init_pose)
        self.register_buffer("init_shape", init_shape)
        self.register_buffer("init_cam", init_cam)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks, projecting the shortcut of the
        first one when the resolution or channel count changes."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x, init_pose=None, init_shape=None, init_cam=None, n_iter=3):
        """Regress (rotation matrices, shape, camera) from an image batch,
        refining the estimate ``n_iter`` times from the mean parameters."""
        batch_size = x.shape[0]
        if init_pose is None:
            init_pose = self.init_pose.expand(batch_size, -1)
        if init_shape is None:
            init_shape = self.init_shape.expand(batch_size, -1)
        if init_cam is None:
            init_cam = self.init_cam.expand(batch_size, -1)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)
        xf = self.avgpool(x4)
        xf = xf.view(xf.size(0), -1)
        pred_pose = init_pose
        pred_shape = init_shape
        pred_cam = init_cam
        # Iterative error feedback: each pass predicts a correction that is
        # added to the current estimate.
        for _ in range(n_iter):
            xc = torch.cat([xf, pred_pose, pred_shape, pred_cam], 1)
            xc = self.fc1(xc)
            xc = self.drop1(xc)
            xc = self.fc2(xc)
            xc = self.drop2(xc)
            pred_pose = self.decpose(xc) + pred_pose
            pred_shape = self.decshape(xc) + pred_shape
            pred_cam = self.deccam(xc) + pred_cam
        pred_rotmat = rot6d_to_rotmat(pred_pose).view(batch_size, self.n_joints, 3, 3)
        return pred_rotmat, pred_shape, pred_cam
def hmr(smpl_mean_params, pretrained=True, **kwargs):
    """Build an HMR regressor on a ResNet50 trunk.

    Args:
        smpl_mean_params: path to the npz file with mean SMPL parameters.
        pretrained (bool): If True, returns a model pre-trained on ImageNet.
    """
    hmr_model = HMR(Bottleneck, [3, 4, 6, 3], smpl_mean_params, **kwargs)
    if not pretrained:
        return hmr_model
    # Copy matching ResNet50 weights into the trunk; strict=False leaves the
    # HMR-specific head parameters at their fresh initialization.
    backbone = resnet.resnet50(pretrained=True)
    hmr_model.load_state_dict(backbone.state_dict(), strict=False)
    return hmr_model
================================================
FILE: src/spin/smpl.py
================================================
import numpy as np
import torch
from smplx import SMPL as _SMPL
from smplx import SMPLX as _SMPLX
from smplx.body_models import SMPLOutput, SMPLXOutput
from smplx.lbs import vertices2joints
from .constants import JOINT_MAP, JOINT_NAMES
# Hand joints
# Index remap from the SMPL-X finger-joint ordering into the Panoptic
# (OpenPose-style) hand ordering. Applied to a (N, 21, 3) hand-joint
# tensor via fancy indexing in SMPLX.forward below.
SMPLX_HAND_TO_PANOPTIC = [
    0,
    13,
    14,
    15,
    16,
    1,
    2,
    3,
    17,
    4,
    5,
    6,
    18,
    10,
    11,
    12,
    19,
    7,
    8,
    9,
    20,
]  # Wrist Thumb to Pinky
class SMPL(_SMPL):
    """Extension of the official SMPL implementation to support more joints."""

    # Human-readable names for the 24 SMPL kinematic-tree joints
    # (indices referenced by SKELETON below).
    JOINTS = (
        "Hips",
        "Left Upper Leg",
        "Right Upper Leg",
        "Spine",
        "Left Leg",
        "Right Leg",
        "Spine1",
        "Left Foot",
        "Right Foot",
        "Thorax",
        "Left Toe",
        "Right Toe",
        "Neck",
        "Left Shoulder",
        "Right Shoulder",
        "Head",
        "Left ForeArm",
        "Right ForeArm",
        "Left Arm",
        "Right Arm",
        "Left Hand",
        "Right Hand",
        "Left Finger",
        "Right Finger",
    )
    # Bones as (parent, child) index pairs into JOINTS.
    SKELETON = (
        (0, 1),
        (0, 2),
        (0, 3),
        (1, 4),
        (2, 5),
        (3, 6),
        (4, 7),
        (5, 8),
        (6, 9),
        (7, 10),
        (8, 11),
        (9, 12),
        (12, 13),
        (12, 14),
        (12, 15),
        (13, 16),
        (14, 17),
        (16, 18),
        (17, 19),
        (18, 20),
        (19, 21),
        (20, 22),
        (21, 23),
    )

    def __init__(self, *args, **kwargs):
        """Load the base SMPL model plus an extra joint regressor.

        Requires kwargs["joint_regressor_extra_path"]: path to an .npy
        regressor matrix producing additional joints from mesh vertices.
        """
        super(SMPL, self).__init__(*args, **kwargs)
        # Output joint order resolved through the SPIN joint mapping.
        joints = [JOINT_MAP[i] for i in JOINT_NAMES]
        joint_regressor_extra = kwargs["joint_regressor_extra_path"]
        J_regressor_extra = np.load(joint_regressor_extra)
        self.register_buffer(
            "J_regressor_extra", torch.tensor(J_regressor_extra, dtype=torch.float32)
        )
        self.joint_map = torch.tensor(joints, dtype=torch.long)

    def forward(self, *args, **kwargs):
        """Run SMPL and return its output with joints in the SPIN order."""
        # get_skin is consumed by the (patched) smplx forward — TODO confirm
        kwargs["get_skin"] = True
        smpl_output = super(SMPL, self).forward(*args, **kwargs)
        extra_joints = vertices2joints(
            self.J_regressor_extra, smpl_output.vertices
        )  # Additional 9 joints #Check doc/J_regressor_extra.png
        joints = torch.cat(
            [smpl_output.joints, extra_joints], dim=1
        )  # [N, 24 + 21, 3] + [N, 9, 3]
        # select/reorder into the SPIN joint convention
        joints = joints[:, self.joint_map, :]
        output = SMPLOutput(
            vertices=smpl_output.vertices,
            global_orient=smpl_output.global_orient,
            body_pose=smpl_output.body_pose,
            joints=joints,
            betas=smpl_output.betas,
            full_pose=smpl_output.full_pose,
        )
        return output
class SMPLX(_SMPLX):
    """Extension of the official SMPL-X implementation to support more joints."""

    # Human-readable names for the SMPL-X kinematic-tree joints
    # (indices referenced by SKELETON below; no finger-tip joints here,
    # unlike the SMPL class above).
    JOINTS = (
        "Hips",
        "Left Upper Leg",
        "Right Upper Leg",
        "Spine",
        "Left Leg",
        "Right Leg",
        "Spine1",
        "Left Foot",
        "Right Foot",
        "Thorax",
        "Left Toe",
        "Right Toe",
        "Neck",
        "Left Shoulder",
        "Right Shoulder",
        "Head",
        "Left ForeArm",
        "Right ForeArm",
        "Left Arm",
        "Right Arm",
        "Left Hand",
        "Right Hand",
    )
    # Bones as (parent, child) index pairs into JOINTS.
    SKELETON = (
        (0, 1),
        (0, 2),
        (0, 3),
        (1, 4),
        (2, 5),
        (3, 6),
        (4, 7),
        (5, 8),
        (6, 9),
        (7, 10),
        (8, 11),
        (9, 12),
        (12, 13),
        (12, 14),
        (12, 15),
        (13, 16),
        (14, 17),
        (16, 18),
        (17, 19),
        (18, 20),
        (19, 21),
    )

    def __init__(self, *args, **kwargs):
        """Load the base SMPL-X model (from .pkl) and the SPIN joint map."""
        kwargs["ext"] = "pkl"  # We have pkl file
        super(SMPLX, self).__init__(*args, **kwargs)
        joints = [JOINT_MAP[i] for i in JOINT_NAMES]
        self.joint_map = torch.tensor(joints, dtype=torch.long)

    def forward(self, *args, **kwargs):
        """Run SMPL-X; return SPIN-ordered body joints plus panoptic-order hands."""
        # get_skin is consumed by the (patched) smplx forward — TODO confirm
        kwargs["get_skin"] = True
        # If the pose parameter was produced for SMPL (23 body joints, or
        # 69 = 23 * 3 axis-angle values), drop the last two joints, which
        # sit on the palm and are not used by SMPL-X.
        # (Was a bare `try/except: pass`, which silently swallowed every
        # error; now guarded explicitly instead.)
        body_pose = kwargs.get("body_pose")
        if body_pose is not None and getattr(body_pose, "ndim", 0) >= 2:
            if body_pose.shape[1] == 69:
                body_pose = body_pose[
                    :, : -2 * 3
                ]  # Ignore the last two joints (which are on the palm. Not used)
            if body_pose.shape[1] == 23:
                body_pose = body_pose[
                    :, :-2
                ]  # Ignore the last two joints (which are on the palm. Not used)
            kwargs["body_pose"] = body_pose
        smpl_output = super(SMPLX, self).forward(*args, **kwargs)
        # SMPL-X Joint order: https://docs.google.com/spreadsheets/d/1_1dLdaX-sbMkCKr_JzJW_RZCpwBwd7rcKkWT_VgAQ_0/edit#gid=0
        smplx_to_smpl = (
            list(range(0, 22)) + [28, 43] + list(range(55, 76))
        )  # 28 left middle finger , 43: right middle finger 1
        smpl_joints = smpl_output.joints[
            :, smplx_to_smpl, :
        ]  # Convert SMPL-X to SMPL 127 ->45
        joints = smpl_joints
        joints = joints[:, self.joint_map, :]
        smplx_lhand = (
            [20] + list(range(25, 40)) + list(range(66, 71))
        )  # 20 for left wrist. 20 finger joints
        lhand_joints = smpl_output.joints[:, smplx_lhand, :]  # (N,21,3)
        lhand_joints = lhand_joints[
            :, SMPLX_HAND_TO_PANOPTIC, :
        ]  # Convert SMPL-X hand order to panoptic hand order
        smplx_rhand = (
            [21] + list(range(40, 55)) + list(range(71, 76))
        )  # 21 for right wrist. 20 finger joints
        rhand_joints = smpl_output.joints[:, smplx_rhand, :]  # (N,21,3)
        rhand_joints = rhand_joints[
            :, SMPLX_HAND_TO_PANOPTIC, :
        ]  # Convert SMPL-X hand order to panoptic hand order
        output = SMPLXOutput(
            vertices=smpl_output.vertices,
            global_orient=smpl_output.global_orient,
            body_pose=smpl_output.body_pose,
            joints=joints,
            right_hand_pose=rhand_joints,  # N,21,3
            left_hand_pose=lhand_joints,  # N,21,3
            betas=smpl_output.betas,
            full_pose=smpl_output.full_pose,
            # `A` (per-joint transforms) is exposed by the patched smplx
            # (see patches/smplx.diff) — TODO confirm
            A=smpl_output.A,
        )
        return output
"""
0 pelvis',
1 left_hip',
2 right_hip',
3 spine1',
4 left_knee',
5 right_knee',
6 spine2',
7 left_ankle',
8 right_ankle',
9 spine3',
10 left_foot',
11 right_foot',
12 neck',
13 left_collar',
14 right_collar',
15 head',
16 left_shoulder',
17 right_shoulder',
18 left_elbow',
19 right_elbow',
20 left_wrist',
21 right_wrist',
22 jaw',
23 left_eye_smplhf',
24 right_eye_smplhf',
25 left_index1',
26 left_index2',
27 left_index3',
28 left_middle1',
29 left_middle2',
30 left_middle3',
31 left_pinky1',
32 left_pinky2',
33 left_pinky3',
34 left_ring1',
35 left_ring2',
36 left_ring3',
37 left_thumb1',
38 left_thumb2',
39 left_thumb3',
40 right_index1',
41 right_index2',
42 right_index3',
43 right_middle1',
44 right_middle2',
45 right_middle3',
46 right_pinky1',
47 right_pinky2',
48 right_pinky3',
49 right_ring1',
50 right_ring2',
51 right_ring3',
52 right_thumb1',
53 right_thumb2',
54 right_thumb3',
55 nose',
56 right_eye',
57 left_eye',
58 right_ear',
59 left_ear',
60 left_big_toe',
61 left_small_toe',
62 left_heel',
63 right_big_toe',
64 right_small_toe',
65 right_heel',
66 left_thumb',
67 left_index',
68 left_middle',
69 left_ring',
70 left_pinky',
71 right_thumb',
72 right_index',
73 right_middle',
74 right_ring',
75 right_pinky',
76 right_eye_brow1',
77 right_eye_brow2',
78 right_eye_brow3',
79 right_eye_brow4',
80 right_eye_brow5',
81 left_eye_brow5',
82 left_eye_brow4',
83 left_eye_brow3',
84 left_eye_brow2',
85 left_eye_brow1',
86 nose1',
87 nose2',
88 nose3',
89 nose4',
90 right_nose_2',
91 right_nose_1',
92 nose_middle',
93 left_nose_1',
94 left_nose_2',
95 right_eye1',
96 right_eye2',
97 right_eye3',
98 right_eye4',
99 right_eye5',
100 right_eye6',
101 left_eye4',
102 left_eye3',
103 left_eye2',
104 left_eye1',
105 left_eye6',
106 left_eye5',
107 right_mouth_1',
108 right_mouth_2',
109 right_mouth_3',
110 mouth_top',
111 left_mouth_3',
112 left_mouth_2',
113 left_mouth_1',
114 left_mouth_5', # 59 in OpenPose output
115 left_mouth_4', # 58 in OpenPose output
116 mouth_bottom',
117 right_mouth_4',
118 right_mouth_5',
119 right_lip_1',
120 right_lip_2',
121 lip_top',
122 left_lip_2',
123 left_lip_1',
124 left_lip_3',
125 lip_bottom',
126 right_lip_3',
127 right_contour_1',
128 right_contour_2',
129 right_contour_3',
130 right_contour_4',
131 right_contour_5',
132 right_contour_6',
133 right_contour_7',
134 right_contour_8',
135 contour_middle',
136 left_contour_8',
137 left_contour_7',
138 left_contour_6',
139 left_contour_5',
140 left_contour_4',
141 left_contour_3',
142 left_contour_2',
143 left_contour_1'
"""
# SMPL Joints:
"""
0 pelvis',
1 left_hip',
2 right_hip',
3 spine1',
4 left_knee',
5 right_knee',
6 spine2',
7 left_ankle',
8 right_ankle',
9 spine3',
10 left_foot',
11 right_foot',
12 neck',
13 left_collar',
14 right_collar',
15 head',
16 left_shoulder',
17 right_shoulder',
18 left_elbow',
19 right_elbow',
20 left_wrist',
21 right_wrist',
22
23
24 nose',
25 right_eye',
26 left_eye',
27 right_ear',
28 left_ear',
29 left_big_toe',
30 left_small_toe',
31 left_heel',
32 right_big_toe',
33 right_small_toe',
34 right_heel',
35 left_thumb',
36 left_index',
37 left_middle',
38 left_ring',
39 left_pinky',
40 right_thumb',
41 right_index',
42 right_middle',
43 right_ring',
44 right_pinky',
"""
================================================
FILE: src/spin/utils.py
================================================
import json
import cv2
import numpy as np
import torch
from skimage.transform import resize, rotate
from torchvision.transforms import Normalize
from .constants import IMG_NORM_MEAN, IMG_NORM_STD, IMG_RES
def get_transform(center, scale, res, rot=0):
    """Build the 3x3 homogeneous matrix mapping original-image pixels into a
    `res`-sized crop centred on `center`, where the person box spans
    200 * scale pixels. An optional rotation (degrees) is applied about the
    crop centre.
    """
    box = 200 * scale
    mat = np.zeros((3, 3))
    mat[0, 0] = float(res[1]) / box
    mat[1, 1] = float(res[0]) / box
    mat[0, 2] = res[1] * (0.5 - float(center[0]) / box)
    mat[1, 2] = res[0] * (0.5 - float(center[1]) / box)
    mat[2, 2] = 1
    if rot != 0:
        ang = np.radians(-rot)  # negated to match the cropping direction
        sn, cs = np.sin(ang), np.cos(ang)
        rot_mat = np.zeros((3, 3))
        rot_mat[0, :2] = [cs, -sn]
        rot_mat[1, :2] = [sn, cs]
        rot_mat[2, 2] = 1
        # rotate about the crop centre: shift, rotate, shift back
        shift = np.eye(3)
        shift[0, 2] = -res[1] / 2
        shift[1, 2] = -res[0] / 2
        unshift = shift.copy()
        unshift[:2, 2] *= -1
        mat = unshift @ rot_mat @ shift @ mat
    return mat
def transform(pt, center, scale, res, invert=0, rot=0):
    """Map a single (1-based) pixel coordinate through the crop transform.

    With invert=1 the mapping runs from crop coordinates back to the
    original image. Returns integer coordinates, 1-based again.
    """
    mat = get_transform(center, scale, res, rot=rot)
    if invert:
        mat = np.linalg.inv(mat)
    # homogeneous coordinates, 0-based for the matrix multiply
    homo = np.array([pt[0] - 1, pt[1] - 1, 1.0]).T
    mapped = np.dot(mat, homo)
    return mapped[:2].astype(int) + 1
def crop(img, center, scale, res, rot=0):
    """Crop image according to the supplied bounding box.

    Maps the crop corners back into the original image via the inverse
    transform, copies the overlapping region into a zero-padded buffer,
    optionally rotates it, and resizes to `res`.
    """
    # Upper left point
    ul = np.array(transform([1, 1], center, scale, res, invert=1)) - 1
    # Bottom right point
    br = np.array(transform([res[0] + 1, res[1] + 1], center, scale, res, invert=1)) - 1
    # Padding so that when rotated proper amount of context is included
    pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
    if not rot == 0:
        ul -= pad
        br += pad
    new_shape = [br[1] - ul[1], br[0] - ul[0]]
    if len(img.shape) > 2:
        # keep the channel dimension for color images
        new_shape += [img.shape[2]]
    new_img = np.zeros(new_shape)
    # Range to fill new array (clipped so out-of-image areas stay zero)
    new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
    new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
    # Range to sample from original image
    old_x = max(0, ul[0]), min(len(img[0]), br[0])
    old_y = max(0, ul[1]), min(len(img), br[1])
    new_img[new_y[0] : new_y[1], new_x[0] : new_x[1]] = img[
        old_y[0] : old_y[1], old_x[0] : old_x[1]
    ]
    if not rot == 0:
        new_img = rotate(new_img, rot)
        # Remove padding
        new_img = new_img[pad:-pad, pad:-pad]
    new_img = resize(new_img, res)
    return new_img
def bbox_from_openpose(openpose_file, rescale=1.2, detection_thresh=0.2):
    """Derive a person bounding box (center, scale) from OpenPose output.

    Only keypoints whose confidence exceeds `detection_thresh` contribute.
    `scale` is the box size in units of 200 px, inflated by `rescale`.
    """
    with open(openpose_file, "r") as f:
        people = json.load(f)["people"]
    kps = np.reshape(np.array(people[0]["pose_keypoints_2d"]), (-1, 3))
    # keep only confidently-detected keypoints, dropping the confidence column
    confident = kps[kps[:, -1] > detection_thresh][:, :-1]
    center = confident.mean(axis=0)
    extent = (confident.max(axis=0) - confident.min(axis=0)).max()
    # adjust bounding box tightness
    return center, extent / 200.0 * rescale
def bbox_from_json(bbox_file):
    """Get center and scale of bounding box from bounding box annotations.

    The expected format is [top_left(x), top_left(y), width, height].
    `scale` is the longer side in units of 200 px, so the crop stays square.
    """
    with open(bbox_file, "r") as f:
        box = np.array(json.load(f)["bbox"]).astype(np.float32)
    top_left = box[:2]
    center = top_left + 0.5 * box[2:]
    # make sure the bounding box is rectangular: use the longer side
    longer = max(box[2], box[3])
    return center, longer / 200.0
def process_image(img_file, bbox_file=None, openpose_file=None, input_res=IMG_RES):
    """Read image, do preprocessing and possibly crop it according to the bounding box.

    If there are bounding box annotations, use them to crop the image.
    If no bounding box is specified but openpose detections are available,
    use them to get the bounding box. Otherwise assume the person is
    centered in the image. Returns (img, norm_img) as CHW float tensors.
    """
    img_file = str(img_file)
    normalize_img = Normalize(mean=IMG_NORM_MEAN, std=IMG_NORM_STD)
    # OpenCV loads BGR; flip to RGB. copy() because PyTorch does not
    # support negative strides at the moment.
    img = cv2.imread(img_file)[:, :, ::-1].copy()
    if bbox_file is not None:
        center, scale = bbox_from_json(bbox_file)
    elif openpose_file is not None:
        center, scale = bbox_from_openpose(openpose_file)
    else:
        # Assume that the person is centerered in the image
        height = img.shape[0]
        width = img.shape[1]
        center = np.array([width // 2, height // 2])
        scale = max(height, width) / 200
    img = crop(img, center, scale, (input_res, input_res))
    img = torch.from_numpy(img.astype(np.float32) / 255.0).permute(2, 0, 1)
    norm_img = normalize_img(img.clone())
    return img, norm_img
================================================
FILE: src/utils.py
================================================
import colorsys
import itertools
import json
import pickle
import cv2
import plotly.graph_objects as go
import trimesh
import torch
import numpy as np
import PIL.Image as pil_img
import PIL.ImageDraw as ImageDraw
from PIL import Image, ImageChops
from skimage import exposure
import spin
import renderer
import pose_estimation
def load_json(path):
    """Load and parse a JSON document from `path`."""
    with open(path) as fp:
        text = fp.read()
    return json.loads(text)
def save_json(o, path):
    """Serialize `o` as JSON and write it to `path`."""
    with open(path, "w") as fp:
        fp.write(json.dumps(o))
def load_pkl(path):
    """Unpickle and return the object stored at `path`."""
    with open(path, "rb") as fp:
        obj = pickle.load(fp)
    return obj
def save_pkl(o, path):
    """Pickle `o` and write the bytes to `path`."""
    with open(path, "wb") as fp:
        fp.write(pickle.dumps(o))
def plot_3D(joints, vertices, faces):
    """Build a plotly figure showing the mesh plus joint markers.

    Args:
        joints: (J, 3) joint positions, drawn as markers.
        vertices: (V, 3) mesh vertices.
        faces: (F, 3) triangle vertex indices.
    """
    jx, jy, jz = joints.T
    vx, vy, vz = vertices.T
    fi, fj, fk = faces.T
    mesh_trace = go.Mesh3d(
        x=vx,
        y=vy,
        z=vz,
        i=fi,
        j=fj,
        k=fk,
    )
    joint_trace = go.Scatter3d(
        x=jx,
        y=jy,
        z=jz,
        mode="markers",
        marker_size=5,
    )
    return go.Figure(data=[mesh_trace, joint_trace])
def draw_keypoints(
    input_img_kp,
    keypoints,
    skeleton,
    r,
    color,
    contact2dlist=None,
    contact2dlist_color="green",
    cos=None,
):
    """Draw a 2D skeleton (bones + joint dots) onto a PIL image, in place.

    Args:
        input_img_kp: PIL image to draw on; modified in place and returned.
        keypoints: (J, 2) array of 2D joint positions; nothing is drawn when None.
        skeleton: iterable of (i, j) joint-index pairs defining the bones.
        r: line width / joint-dot radius in pixels.
        color: bone color used when `cos` is None.
        contact2dlist: optional list of contact groups; each group holds
            (src, dst, t) triples (a point interpolated along bone src->dst);
            all pairwise links within a group are drawn.
        contact2dlist_color: color for the contact links.
        cos: optional per-bone scalar in [0, 1]; when given, each bone is
            colored by hue cos[k] ** 8 instead of `color`.

    Returns:
        The same `input_img_kp` image.
    """
    # (Fixed: removed the unused bone-length local `ln` and a pointless
    # enumerate over the joint-dot list.)
    if keypoints is not None:
        draw = ImageDraw.Draw(input_img_kp)
        for skidx, (i, j) in enumerate(skeleton):
            a = keypoints[i]
            b = keypoints[j]
            xy = [a[0], a[1], b[0], b[1]]
            if cos is not None:
                # ** 8 sharpens the hue response so only near-1 cosines stand out
                c = colorsys.hsv_to_rgb(cos[skidx] ** 8, 1, 1)
                c = tuple(int(c_ * 255) for c_ in c)
                draw.line(xy, fill=c, width=r)
            else:
                draw.line(xy, fill=color, width=r)
        # joint dots
        for p in keypoints:
            draw.ellipse((p[0] - r, p[1] - r, p[0] + r, p[1] + r), fill="black", outline="black")
        if contact2dlist is not None:
            keypoints_torch = torch.from_numpy(keypoints)
            for c2d in contact2dlist:
                for (src_1, dst_1, t_1), (src_2, dst_2, t_2) in itertools.combinations(
                    c2d, 2
                ):
                    a = torch.lerp(
                        keypoints_torch[src_1], keypoints_torch[dst_1], t_1
                    ).tolist()
                    b = torch.lerp(
                        keypoints_torch[src_2], keypoints_torch[dst_2], t_2
                    ).tolist()
                    xy = [a[0], a[1], b[0], b[1]]
                    draw.line(xy, fill=contact2dlist_color, width=max(r // 3, 10))
    return input_img_kp
def save_results_image(
    camera,
    focal_length_x,
    focal_length_y,
    input_img,
    vertices,
    faces,
    filename,
    keypoints=None,
    keypoints_2=None,
    heatmap=None,
    cvt_camera=True,
    contactlist=None,
    contact2dlist=None,
    user_study=True,
    cos=None,
):
    """Render the fitted mesh over the sketch and save a 2x3 summary grid.

    Top row: input keypoints, predicted keypoints, mesh overlay; bottom row:
    three rendered views (-45 deg, frontal, 90 deg). Also writes optional
    side files next to `filename` (`*_heatmap`, `*_2dkps`, `*_2dkpscos`,
    `*_orig`, `*_same`, `*_alt`). `filename` must be a pathlib.Path (uses
    `.with_stem`). Returns the composite PIL image.

    NOTE(review): when cvt_camera is True, `camera` is assumed to be
    (s, tx, ty) and is turned into a translation (tx, ty, 1 / s) — confirm
    against the caller.
    """
    if isinstance(contactlist, list) and len(contactlist) > 0:
        # flatten per-region contact vertex-index arrays into one array
        contactlist = np.concatenate(contactlist)
    H, W, _ = input_img.shape
    HW = max(H, W)
    camera_center = np.array([W // 2, H // 2])
    if not cvt_camera:
        camera_transl = camera.copy()
    else:
        # weak-perspective (s, tx, ty) -> translation with depth ~ 1 / s
        camera_transl = np.stack(
            [
                camera[1],
                camera[2],
                1 / camera[0],
            ],
        )
    # draw keypoints
    input_img_kp = pil_img.fromarray(input_img)
    if keypoints is not None:
        draw_keypoints(
            input_img_kp,
            keypoints,
            pose_estimation.SKELETON,
            r=int(HW * 0.01),
            color=(255, 0, 0, 255),
            # contact2dlist=contact2dlist,
            # contact2dlist_color="orange",
        )
    if cos is not None:
        # second copy of the sketch with bones colored by foreshortening cosine
        input_img_kp_cos = pil_img.fromarray(input_img)
        if keypoints is not None:
            draw_keypoints(
                input_img_kp_cos,
                keypoints,
                pose_estimation.SKELETON,
                r=int(HW * 0.01),
                color=(255, 0, 0, 255),
                # contact2dlist=contact2dlist,
                # contact2dlist_color="orange",
                cos=cos,
            )
    input_img_kp_2 = input_img_kp.copy()
    if keypoints_2 is not None:
        # predicted keypoints drawn in blue, contacts in purple
        draw_keypoints(
            input_img_kp_2,
            keypoints_2,
            # spin.SMPLX.SKELETON if "eft" in str(filename) else pose_estimation.SKELETON,
            pose_estimation.SKELETON,
            r=int(HW * 0.01),
            color=(0, 0, 255, 255),
            contact2dlist=contact2dlist,
            contact2dlist_color="purple",
        )
    # heatmap = ImageOps.invert(heatmap)
    if heatmap is not None:
        # input_img_kp_2 = pil_img.blend(input_img_kp_2, heatmap, 0.5)
        hm = np.copy(input_img)
        # stretch to 0..255 and colorize with a jet colormap
        gray_img = exposure.rescale_intensity(heatmap, out_range=(0, 255))
        gray_img = gray_img.astype(np.uint8)
        heatmap_img = cv2.applyColorMap(gray_img, cv2.COLORMAP_JET)
        hm = pil_img.fromarray(cv2.cvtColor(heatmap_img, cv2.COLOR_BGR2RGB))
        hm.save(filename.with_stem(f"{filename.stem}_heatmap"))
        input_img_kp.save(filename.with_stem(f"{filename.stem}_2dkps"))
        heatmap = pil_img.fromarray(heatmap)
    if cos is not None:
        input_img_kp_cos.save(filename.with_stem(f"{filename.stem}_2dkpscos"))
    # render fitted mesh from different views
    overlay_fit_img = renderer.overlay_mesh(
        vertices,
        faces,
        camera_transl,
        focal_length_x,
        focal_length_y,
        camera_center,
        H,
        W,
        input_img.astype("float32") / 255,
        None,
        rotaround=None,
    )
    # overlay_fit_img = pil_img.fromarray(overlay_fit_img)
    # draw_keypoints(overlay_fit_img, keypoints_2, r=int(HW * 0.01), color=(0, 0, 255, 255))
    # camera_transl[-1] *= 1
    view1_fit = renderer.overlay_mesh(
        vertices,
        faces,
        camera_transl.astype(np.float32),
        focal_length_x,
        focal_length_y,
        camera_center,
        H,
        W,
        None,
        None,
        rotaround=-45,
        contactlist=contactlist,
    )
    view2_fit = renderer.overlay_mesh(
        vertices,
        faces,
        camera_transl.astype(np.float32),
        focal_length_x,
        focal_length_y,
        camera_center,
        H,
        W,
        None,
        None,
        rotaround=None,
        contactlist=contactlist,
    )
    view3_fit = renderer.overlay_mesh(
        vertices,
        faces,
        camera_transl.astype(np.float32),
        focal_length_x,
        focal_length_y,
        camera_center,
        H,
        W,
        None,
        None,
        rotaround=90,
        contactlist=contactlist,
        scale=1,
    )
    # assemble the 2x3 grid; a white placeholder stands in when no input keypoints
    IMG = np.vstack(
        (
            np.hstack(
                (
                    np.asarray(input_img_kp)
                    if keypoints is not None
                    else 255 * np.ones_like(np.asarray(input_img_kp)),
                    np.asarray(input_img_kp_2),
                    overlay_fit_img,
                    # np.asanyarray(overlay_fit_img),
                ),
            ),
            np.hstack(
                (
                    view1_fit,
                    view2_fit,
                    view3_fit,
                ),
            ),
        ),
    )
    IMG = pil_img.fromarray(IMG)
    IMG.thumbnail((2000, 2000))
    IMG.save(filename)
    if user_study:
        # extra, larger renders used for the user study
        w = 768
        input_img_kp.thumbnail((w, w))
        input_img_kp.save(filename.with_stem(f"{filename.stem}_orig"))
        W, H = input_img_kp.size
        camera_transl[-1] *= 2  # move the camera back for the bigger renders
        view2_fit = renderer.overlay_mesh(
            vertices,
            faces,
            camera_transl.astype(np.float32),
            focal_length_x,
            focal_length_y,
            camera_center,
            H,
            W,
            None,
            None,
            rotaround=None,
            contactlist=contactlist,
            scale=2,
        )
        view2_fit = pil_img.fromarray(view2_fit)
        w *= 2
        view2_fit.thumbnail((w, w))
        view2_fit.save(filename.with_stem(f"{filename.stem}_same"))
        view3_fit = renderer.overlay_mesh(
            vertices,
            faces,
            camera_transl.astype(np.float32),
            focal_length_x,
            focal_length_y,
            camera_center,
            H,
            W,
            None,
            None,
            rotaround=90,
            contactlist=contactlist,
            scale=2,
        )
        view3_fit = pil_img.fromarray(view3_fit)
        view3_fit.thumbnail((w, w))
        view3_fit.save(filename.with_stem(f"{filename.stem}_alt"))
    return IMG
def save_3d_model_on_img(
    camera,
    vertices,
    faces,
    img,
    filename,
    save_path,
):
    """Render the mesh over `img` (frontal and 90-degree view) and save PNGs.

    Args:
        camera: weak-perspective camera (s, tx, ty); depth is recovered as
            2 * focal / (img_res * s).
        vertices: (V, 3) mesh vertices.
        faces: (F, 3) triangle indices.
        img: image the frontal render is composited over.
        filename: unused in this body — NOTE(review): the saved names contain
            the literal "(unknown)"; presumably a placeholder where
            `filename` was meant to appear — confirm against the original.
        save_path: directory (pathlib.Path) receiving the PNGs.
    """
    img_res = max(img.shape[:2])
    r = renderer.Renderer(
        focal_length=spin.constants.FOCAL_LENGTH,
        img_res=img_res,
        faces=faces,
    )
    # Calculate camera parameters for rendering
    camera_translation = np.stack(
        [
            camera[1],
            camera[2],
            # epsilon guards against a zero camera scale
            2 * spin.constants.FOCAL_LENGTH / (img_res * camera[0] + 1e-9),
        ],
    )
    # Render parametric shape
    img_shape = r(vertices, camera_translation, img)
    img_shape = (255 * img_shape).astype("uint8")
    img_shape = cv2.cvtColor(img_shape, cv2.COLOR_RGB2BGR)
    cv2.imwrite(str(save_path / f"shape_(unknown).png"), img_shape)
    # Render side views: rotate the mesh 90 degrees about Y, around its centroid
    aroundy = cv2.Rodrigues(np.array([0, np.radians(90.0), 0]))[0]
    center = vertices.mean(axis=0)
    rot_vertices = np.dot((vertices - center), aroundy) + center
    # Render non-parametric shape
    img_shape = r(rot_vertices, camera_translation, np.ones_like(img))
    img_shape = (255 * img_shape).astype("uint8")
    img_shape = cv2.cvtColor(img_shape, cv2.COLOR_RGB2BGR)
    cv2.imwrite(str(save_path / f"shape_rot_(unknown).png"), img_shape)
def save_mesh_with_colors(vertices, faces, save_path, mask=None, inds=None):
    """Export the mesh with highlighted vertices.

    Every vertex is light grey; vertices outside `mask` — or, when no mask
    is given, those listed in `inds` — are painted red.
    """
    if inds is not None and isinstance(inds, list):
        inds = np.concatenate(inds)
    body = trimesh.Trimesh(
        vertices=vertices,
        faces=faces,
        process=False,
    )
    grey = [233, 233, 233, 255]
    red = [255, 0, 0, 255]
    rgba = np.array(body.visual.vertex_colors)
    rgba[:] = grey
    if mask is not None and any(mask):
        rgba[~mask] = red
    elif inds is not None and len(inds) > 0:
        rgba[inds] = red
    body.visual.vertex_colors = rgba
    body.export(save_path)
def save_pose_params(rotmat, camera, betas, vertices, smpl, contact, save_path):
    """Serialize the optimized pose (camera, rotations, betas, mesh) to disk.

    Writes the same dict twice: pickled to `save_path` and as a NumPy archive
    (np.savez writes to the same path, appending ".npz" when missing).

    NOTE(review): `rotmat` is assumed to be a flat per-joint rotation
    parameter tensor with the global orientation in the first 3 entries —
    confirm against the optimizer output.
    """
    if contact is not None and isinstance(contact, list) and len(contact) > 0:
        contact = np.concatenate(contact)
    rotmat = rotmat.detach()
    camera = camera.detach()
    if smpl.name() == "SMPL-X":
        # drop the last two (palm) joints, not used by SMPL-X
        rotmat = rotmat[: -2 * 3]
    res = {
        "camera_s_t": camera.unsqueeze(0).cpu().numpy(),
        "global_orient": rotmat[:3].unsqueeze(0).cpu().numpy(),
        "betas": betas,
        "body_pose": rotmat[3:].unsqueeze(0).cpu().numpy(),
        "left_hand_pose": smpl.left_hand_pose.unsqueeze(0).detach().cpu().numpy(),
        "right_hand_pose": smpl.right_hand_pose.unsqueeze(0).detach().cpu().numpy(),
        "model": smpl.name().replace("-", ""),
        "gender": smpl.gender,
        "vertices": vertices[0].cpu().numpy(),
    }
    if contact is not None and len(contact) > 0:
        contact = np.array(contact)
        res["contact"] = contact
    else:
        # presumably a legacy key kept for older consumers — verify
        res["v"] = vertices[0].cpu().numpy()
    save_pkl(res, save_path)
    np.savez(save_path, **res)
gitextract__nwva2u6/
├── README.md
├── patches/
│ ├── selfcontact.diff
│ ├── smplx.diff
│ └── torchgeometry.diff
├── requirements.txt
├── scripts/
│ ├── download.sh
│ ├── prepare.sh
│ └── run.sh
└── src/
├── fist_pose.py
├── hist_cub.py
├── losses.py
├── pose.py
├── pose_estimation.py
├── renderer.py
├── spin/
│ ├── __init__.py
│ ├── constants.py
│ ├── hmr.py
│ ├── smpl.py
│ └── utils.py
└── utils.py
SYMBOL INDEX (96 symbols across 9 files)
FILE: src/hist_cub.py
function cub (line 24) | def cub(x, a, b, c):
function subsample (line 33) | def subsample(a, p=0.0005, seed=0):
function read_cos_opt (line 42) | def read_cos_opt(path, fname="cos_hist.npy"):
function plot_hist (line 53) | def plot_hist(cos_opt_dir, hist_smpl_fpath, params, out_dir, bins=10, xy...
function kldiv (line 121) | def kldiv(p_hist, q_hist):
function calc_histogram (line 127) | def calc_histogram(x, bins=10, range=(0, 1)):
function step (line 132) | def step(params, angles_opt, p_hist, bone_idx=None):
function optimize (line 154) | def optimize(cos_opt_dir, hist_smpl_fpath, bone_idx=None):
function main (line 201) | def main():
FILE: src/losses.py
class MSE (line 9) | class MSE(nn.Module):
method __init__ (line 10) | def __init__(self, ignore=None):
method forward (line 16) | def forward(self, y_pred, y_data):
class Parallel (line 25) | class Parallel(nn.Module):
method __init__ (line 26) | def __init__(self, skeleton, ignore=None, ground_parallel=None):
method forward (line 40) | def forward(self, y_pred3d, y_data, z, spine_j, writer=None, global_st...
class MimickedSelfContactLoss (line 155) | class MimickedSelfContactLoss(nn.Module):
method __init__ (line 156) | def __init__(self, geodesics_mask):
method forward (line 164) | def forward(
FILE: src/pose.py
function parse_args (line 50) | def parse_args():
function freeze_layers (line 223) | def freeze_layers(model):
function project_and_normalize_to_spin (line 239) | def project_and_normalize_to_spin(vertices_3d, camera):
function project_and_normalize_to_spin_legs (line 253) | def project_and_normalize_to_spin_legs(vertices_3d, A, camera):
function rotation_matrix_to_angle_axis (line 286) | def rotation_matrix_to_angle_axis(rotmat):
function get_smpl_output (line 303) | def get_smpl_output(smpl, rotmat, betas, use_betas=True, zero_hands=False):
function get_predictions (line 331) | def get_predictions(model_hmr, smpl, input_img, use_betas=True, zero_han...
function get_pred_and_data (line 348) | def get_pred_and_data(
function normalize_keypoints_to_spin (line 379) | def normalize_keypoints_to_spin(keypoints_2d, img_size):
function unnormalize_keypoints_from_spin (line 397) | def unnormalize_keypoints_from_spin(keypoints_2d, shift, scale, ax2):
function get_vertices_in_heatmap (line 405) | def get_vertices_in_heatmap(contact_heatmap):
function get_contact_heatmap (line 424) | def get_contact_heatmap(model_contact, img_path, thresh=0.5):
function discretize (line 445) | def discretize(parametrization, n_bins=100):
function get_mapping_from_params_to_verts (line 453) | def get_mapping_from_params_to_verts(verts, params):
function find_contacts (line 461) | def find_contacts(y_data_conts, keypoints_2d, bone_to_params, thresh=12,...
function optimize (line 529) | def optimize(
function optimize_ft (line 772) | def optimize_ft(
function create_bone (line 1042) | def create_bone(i, j, keypoints_2d):
function is_parallel_to_plane (line 1051) | def is_parallel_to_plane(bone, thresh=21):
function is_close_to_plane (line 1055) | def is_close_to_plane(bone, plane, thresh):
function get_selector (line 1061) | def get_selector():
function calc_cos (line 1070) | def calc_cos(joints_2d, joints_3d):
function get_natural (line 1087) | def get_natural(keypoints_2d, vertices, right_foot_inds, left_foot_inds,...
function get_cos (line 1165) | def get_cos(keypoints_3d_pred, use_angle_transf, loss_parallel):
function save_mesh_with_winding_numbers (line 1202) | def save_mesh_with_winding_numbers(sc_module, vertices, smpl, save_path):
function get_contacts (line 1223) | def get_contacts(
function save_all (line 1264) | def save_all(
function spin_step (line 1348) | def spin_step(
function eft_step (line 1403) | def eft_step(
function dc_step (line 1496) | def dc_step(
function us_step (line 1577) | def us_step(
function main (line 1652) | def main():
FILE: src/pose_estimation.py
function transform (line 81) | def transform(img):
function get_affine_transform (line 89) | def get_affine_transform(
function get_3rd_point (line 127) | def get_3rd_point(a, b):
function get_dir (line 132) | def get_dir(src_point, rot_rad):
function process_image (line 142) | def process_image(path, input_img_size, pixel_std=200):
function get_final_preds (line 166) | def get_final_preds(batch_heatmaps, center, scale, post_process=False):
function transform_preds (line 199) | def transform_preds(coords, center, scale, output_size):
function affine_transform (line 207) | def affine_transform(pt, t):
function get_max_preds (line 213) | def get_max_preds(batch_heatmaps):
function infer_single_image (line 245) | def infer_single_image(model, img_path, input_img_size=(288, 384), retur...
FILE: src/renderer.py
class Renderer (line 8) | class Renderer:
method __init__ (line 14) | def __init__(self, focal_length=5000, img_res=224, faces=None):
method visualize_tb (line 22) | def visualize_tb(self, vertices, camera_translation, images):
method __call__ (line 40) | def __call__(self, vertices, camera_translation, image):
function overlay_mesh (line 84) | def overlay_mesh(
FILE: src/spin/hmr.py
function rot6d_to_rotmat (line 9) | def rot6d_to_rotmat(x):
class Bottleneck (line 31) | class Bottleneck(nn.Module):
method __init__ (line 38) | def __init__(self, inplanes, planes, stride=1, downsample=None):
method forward (line 52) | def forward(self, x):
class HMR (line 75) | class HMR(nn.Module):
method __init__ (line 78) | def __init__(self, block, layers, smpl_mean_params):
method _make_layer (line 123) | def _make_layer(self, block, planes, blocks, stride=1):
method forward (line 145) | def forward(self, x, init_pose=None, init_shape=None, init_cam=None, n...
function hmr (line 187) | def hmr(smpl_mean_params, pretrained=True, **kwargs):
FILE: src/spin/smpl.py
class SMPL (line 37) | class SMPL(_SMPL):
method __init__ (line 93) | def __init__(self, *args, **kwargs):
method forward (line 103) | def forward(self, *args, **kwargs):
class SMPLX (line 124) | class SMPLX(_SMPLX):
method __init__ (line 176) | def __init__(self, *args, **kwargs):
method forward (line 182) | def forward(self, *args, **kwargs):
FILE: src/spin/utils.py
function get_transform (line 12) | def get_transform(center, scale, res, rot=0):
function transform (line 40) | def transform(pt, center, scale, res, invert=0, rot=0):
function crop (line 51) | def crop(img, center, scale, res, rot=0):
function bbox_from_openpose (line 89) | def bbox_from_openpose(openpose_file, rescale=1.2, detection_thresh=0.2):
function bbox_from_json (line 105) | def bbox_from_json(bbox_file):
function process_image (line 119) | def process_image(img_file, bbox_file=None, openpose_file=None, input_re...
FILE: src/utils.py
function load_json (line 21) | def load_json(path):
function save_json (line 26) | def save_json(o, path):
function load_pkl (line 31) | def load_pkl(path):
function save_pkl (line 36) | def save_pkl(o, path):
function plot_3D (line 41) | def plot_3D(joints, vertices, faces):
function draw_keypoints (line 71) | def draw_keypoints(
function save_results_image (line 120) | def save_results_image(
function save_3d_model_on_img (line 347) | def save_3d_model_on_img(
function save_mesh_with_colors (line 388) | def save_mesh_with_colors(vertices, faces, save_path, mask=None, inds=No...
function save_pose_params (line 406) | def save_pose_params(rotmat, camera, betas, vertices, smpl, contact, sav...
Condensed preview — 20 files, each showing path, character count, and a content snippet. Download the .json file or copy it to obtain the full structured content (156K chars).
[
{
"path": "README.md",
"chars": 3592,
"preview": "# Sketch2Pose: Estimating a 3D Character Pose from a Bitmap Sketch\n\nArtists frequently capture character poses via raste"
},
{
"path": "patches/selfcontact.diff",
"chars": 5486,
"preview": "+++ venv/lib/python3.10/site-packages/selfcontact/body_segmentation.py\n@@ -14,6 +14,8 @@\n #\n # Contact: ps-license@tuebi"
},
{
"path": "patches/smplx.diff",
"chars": 1902,
"preview": "+++ venv/lib/python3.10/site-packages/smplx/body_models.py\n@@ -366,7 +366,7 @@\n num_repeats = int(batch_size"
},
{
"path": "patches/torchgeometry.diff",
"chars": 419,
"preview": "+++ venv/lib/python3.10/site-packages/torchgeometry/core/conversions.py\n@@ -298,6 +298,9 @@\n rmat_"
},
{
"path": "requirements.txt",
"chars": 416,
"preview": "matplotlib>=3.5.1\nnumpy>=1.22.3\nopencv_python>=4.5.5.64\nPillow>=9.1.0\nplotly>=5.7.0\npyrender>=0.1.45\nscikit_image>=0.19."
},
{
"path": "scripts/download.sh",
"chars": 1472,
"preview": "#!/usr/bin/env sh\n\n\nset -euo pipefail\n\n\nasset_dir=\"./assets\"\n\n[ ! -e \"${asset_dir}\"/models_smplx_v1_1.zip ] \\\n && ech"
},
{
"path": "scripts/prepare.sh",
"chars": 524,
"preview": "#!/usr/bin/env sh\n\n\nset -euo pipefail\n\nvenv_dir=venv\npython -m venv --clear \"${venv_dir}\"\n\n. \"${venv_dir}\"/bin/activate\n"
},
{
"path": "scripts/run.sh",
"chars": 1452,
"preview": "#!/usr/bin/env sh\n\n\nset -euo pipefail\n\n\nimg_dir=\"./data/images\"\nout_dir=\"./output\"\n\nfind \"${img_dir}\" -mindepth 1 -maxde"
},
{
"path": "src/fist_pose.py",
"chars": 12752,
"preview": "left_fist = [\n 0.0, 0.0, 0.0,\n 0.0, 0.0, 0.0,\n 0.0, 0.0, 0.0,\n 0.0, 0.0, 0.0,\n 0.0, 0.0, 0.0,\n 0.0, 0."
},
{
"path": "src/hist_cub.py",
"chars": 6680,
"preview": "import itertools\nimport functools\nimport math\nimport multiprocessing\nfrom pathlib import Path\n\nimport matplotlib\nmatplot"
},
{
"path": "src/losses.py",
"chars": 6870,
"preview": "import itertools\n\nimport torch\nimport torch.nn as nn\n\nimport pose_estimation\n\n\nclass MSE(nn.Module):\n def __init__(se"
},
{
"path": "src/pose.py",
"chars": 55260,
"preview": "import argparse\nimport math\nfrom pathlib import Path\n\nimport cv2\nimport numpy as np\nimport PIL.Image as Image\nimport sel"
},
{
"path": "src/pose_estimation.py",
"chars": 7074,
"preview": "import math\n\nimport cv2\nimport numpy as np\n\nIMG_SIZE = (288, 384)\nMEAN = np.array([0.485, 0.456, 0.406])\nSTD = np.array("
},
{
"path": "src/renderer.py",
"chars": 6007,
"preview": "import numpy as np\nimport pyrender\nimport torch\nimport trimesh\nfrom torchvision.utils import make_grid\n\n\nclass Renderer:"
},
{
"path": "src/spin/__init__.py",
"chars": 192,
"preview": "from .constants import JOINT_NAMES\nfrom .hmr import hmr\nfrom .smpl import SMPLX\nfrom .utils import process_image\n\n__all_"
},
{
"path": "src/spin/constants.py",
"chars": 5279,
"preview": "FOCAL_LENGTH = 5000.0\nIMG_RES = 224\n\n# Mean and standard deviation for normalizing input image\nIMG_NORM_MEAN = [0.485, 0"
},
{
"path": "src/spin/hmr.py",
"chars": 6685,
"preview": "import math\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torchvision.models.resnet as resnet\n\n\ndef rot6"
},
{
"path": "src/spin/smpl.py",
"chars": 9570,
"preview": "import numpy as np\nimport torch\nfrom smplx import SMPL as _SMPL\nfrom smplx import SMPLX as _SMPLX\nfrom smplx.body_models"
},
{
"path": "src/spin/utils.py",
"chars": 5077,
"preview": "import json\n\nimport cv2\nimport numpy as np\nimport torch\nfrom skimage.transform import resize, rotate\nfrom torchvision.tr"
},
{
"path": "src/utils.py",
"chars": 11760,
"preview": "import colorsys\nimport itertools\nimport json\nimport pickle\n\nimport cv2\nimport plotly.graph_objects as go\nimport trimesh\n"
}
]
About this extraction
This page contains the full source code of the kbrodt/sketch2pose GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 20 files (145.0 KB), approximately 45.8k tokens, and a symbol index with 96 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.