Repository: kbrodt/sketch2pose Branch: main Commit: c027b65219b1 Files: 20 Total size: 145.0 KB Directory structure: gitextract__nwva2u6/ ├── README.md ├── patches/ │ ├── selfcontact.diff │ ├── smplx.diff │ └── torchgeometry.diff ├── requirements.txt ├── scripts/ │ ├── download.sh │ ├── prepare.sh │ └── run.sh └── src/ ├── fist_pose.py ├── hist_cub.py ├── losses.py ├── pose.py ├── pose_estimation.py ├── renderer.py ├── spin/ │ ├── __init__.py │ ├── constants.py │ ├── hmr.py │ ├── smpl.py │ └── utils.py └── utils.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: README.md ================================================ # Sketch2Pose: Estimating a 3D Character Pose from a Bitmap Sketch Artists frequently capture character poses via raster sketches, then use these drawings as a reference while posing a 3D character in a specialized 3D software --- a time-consuming process, requiring specialized 3D training and mental effort. We tackle this challenge by proposing the first system for automatically inferring a 3D character pose from a single bitmap sketch, producing poses consistent with viewer expectations. Algorithmically interpreting bitmap sketches is challenging, as they contain significantly distorted proportions and foreshortening. We address this by predicting three key elements of a drawing, necessary to disambiguate the drawn poses: 2D bone tangents, self-contacts, and bone foreshortening. These elements are then leveraged in an optimization inferring the 3D character pose consistent with the artist's intent. Our optimization balances cues derived from artistic literature and perception research to compensate for distorted character proportions. We demonstrate a gallery of results on sketches of numerous styles. We validate our method via numerical evaluations, user studies, and comparisons to manually posed characters and previous work. 
[Project Page](http://www-labs.iro.umontreal.ca/~bmpix/sketch2pose/) # Prerequisites - [GNU/Linux](https://www.gnu.org/gnu/linux-and-gnu.en.html) - [`python`](https://python.org) - [`pytorch`](https://pytorch.org/) - NVIDIA GPU (optional, but highly recommended) ## Download body model (SMPL-X) Download SMPL-X body model from [https://smpl-x.is.tue.mpg.de](https://smpl-x.is.tue.mpg.de) See [`download.sh`](./scripts/download.sh) and run ```bash sh ./scripts/download.sh ``` ## Virtual environment Change [`pytorch`](https://pytorch.org/) version if needed in [`prepare.sh`](./scripts/prepare.sh) and run ```bash sh ./scripts/prepare.sh ``` # Demo Activate virtual environment `. venv/bin/activate` and run ```bash sh ./scripts/run.sh # or python src/pose.py \ --save-path "${out_dir}" \ --img-path "${img_path}" \ --use-contacts \ --use-natural \ --use-cos \ --use-angle-transf \ # or without contacts python src/pose.py \ --save-path "${out_dir}" \ --img-path "${img_path}" \ --use-natural \ --use-cos \ --use-angle-transf \ ``` # Citation ``` @article{brodt2022sketch2pose, author = {Kirill Brodt and Mikhail Bessmeltsev}, title = {Sketch2Pose: Estimating a 3D Character Pose from a Bitmap Sketch}, journal = {ACM Transactions on Graphics}, year = {2022}, month = {7}, volume = {41}, number = {4}, doi = {10.1145/3528223.3530106}, } ``` # Useful links - [Deep High-Resolution Representation Learning for Human Pose Estimation](https://github.com/leoxiaobin/deep-high-resolution-net.pytorch/) - [SMPLify-X](https://github.com/vchoutas/smplify-x) ([project](https://smpl-x.is.tue.mpg.de/)) - [SPIN](https://github.com/nkolot/SPIN) ([project](https://www.seas.upenn.edu/~nkolot/projects/spin/)) - [eft](https://github.com/facebookresearch/eft) - [SMPLify-XMC](https://github.com/muelea/smplify-xmc), [selfcontact](https://github.com/muelea/selfcontact) ([project](https://tuch.is.tue.mpg.de/)) - [Mixamo](https://www.mixamo.com) models with animations and a 
[script](https://forums.unrealengine.com/community/community-content-tools-and-tutorials/1376068-script-mixamo-download-script) to download them - Quaternion-based [Forward Kinematics](https://github.com/facebookresearch/QuaterNet) ================================================ FILE: patches/selfcontact.diff ================================================ +++ venv/lib/python3.10/site-packages/selfcontact/body_segmentation.py @@ -14,6 +14,8 @@ # # Contact: ps-license@tuebingen.mpg.de +from pathlib import Path + import torch import trimesh import torch.nn as nn @@ -22,6 +24,17 @@ from .utils.mesh import winding_numbers + +def load_pkl(path): + with open(path, "rb") as fin: + return pickle.load(fin) + + +def save_pkl(obj, path): + with open(path, "wb") as fout: + pickle.dump(obj, fout) + + class BodySegment(nn.Module): def __init__(self, name, @@ -63,9 +76,17 @@ self.register_buffer('segment_faces', segment_faces) # create vector to select vertices form faces - tri_vidx = [] - for ii in range(faces.max().item()+1): - tri_vidx += [torch.nonzero(faces==ii)[0].tolist()] + segments_folder = Path(segments_folder) + tri_vidx_path = segments_folder / "tri_vidx.pkl" + if not tri_vidx_path.is_file(): + tri_vidx = [] + for ii in range(faces.max().item()+1): + tri_vidx += [torch.nonzero(faces==ii)[0].tolist()] + + save_pkl(tri_vidx, tri_vidx_path) + else: + tri_vidx = load_pkl(tri_vidx_path) + self.register_buffer('tri_vidx', torch.tensor(tri_vidx)) def create_band_faces(self): @@ -149,7 +170,7 @@ self.segmentation = {} for idx, name in enumerate(names): self.segmentation[name] = BodySegment(name, faces, segments_folder, - model_type).to('cuda') + model_type).to(device) def batch_has_self_isec_verts(self, vertices): """ +++ venv/lib/python3.10/site-packages/selfcontact/selfcontact.py @@ -41,6 +41,7 @@ test_segments=True, compute_hd=False, buffer_geodists=False, + device="cuda", ): super().__init__() @@ -95,7 +96,7 @@ if self.test_segments: sxseg = 
pickle.load(open(segments_bounds_path, 'rb')) self.segments = BatchBodySegment( - [x for x in sxseg.keys()], faces, segments_folder, self.model_type + [x for x in sxseg.keys()], faces, segments_folder, self.model_type, device=device, ) # load regressor to get high density mesh @@ -106,7 +107,7 @@ torch.tensor(hd_operator['values']), torch.Size(hd_operator['size'])) self.register_buffer('hd_operator', - torch.tensor(hd_operator).float()) + hd_operator.clone().detach().float()) with open(point_vert_corres_path, 'rb') as f: hd_geovec = pickle.load(f)['faces_vert_is_sampled_from'] @@ -135,9 +136,13 @@ # split because of memory into two chunks exterior = torch.zeros((bs, nv), device=vertices.device, dtype=torch.bool) - exterior[:, :5000] = winding_numbers(vertices[:,:5000,:], + exterior[:, :3000] = winding_numbers(vertices[:,:3000,:], triangles).le(0.99) - exterior[:, 5000:] = winding_numbers(vertices[:,5000:,:], + exterior[:, 3000:6000] = winding_numbers(vertices[:,3000:6000,:], + triangles).le(0.99) + exterior[:, 6000:9000] = winding_numbers(vertices[:,6000:9000,:], + triangles).le(0.99) + exterior[:, 9000:] = winding_numbers(vertices[:,9000:,:], triangles).le(0.99) # check if intersections happen within segments @@ -173,9 +178,13 @@ # split because of memory into two chunks exterior = torch.zeros((bs, np), device=points.device, dtype=torch.bool) - exterior[:, :6000] = winding_numbers(points[:,:6000,:], + exterior[:, :3000] = winding_numbers(points[:,:3000,:], + triangles).le(0.99) + exterior[:, 3000:6000] = winding_numbers(points[:,3000:6000,:], triangles).le(0.99) - exterior[:, 6000:] = winding_numbers(points[:,6000:,:], + exterior[:, 6000:9000] = winding_numbers(points[:,6000:9000,:], + triangles).le(0.99) + exterior[:, 9000:] = winding_numbers(points[:,9000:,:], triangles).le(0.99) return exterior @@ -371,6 +380,23 @@ return hd_v2v_mins, hd_exteriors, hd_points, hd_faces_in_contacts + def verts_in_contact(self, vertices, return_idx=False): + + # get pairwise 
distances of vertices + v2v = self.get_pairwise_dists(vertices, vertices, squared=True) + + # mask v2v with euclidean and geodesic distance + euclmask = v2v < self.euclthres**2 + mask = euclmask * self.geomask + + # find closest vertex in contact + in_contact = mask.sum(1) > 0 + + if return_idx: + in_contact = torch.where(in_contact) + + return in_contact + class SelfContactSmall(nn.Module): +++ venv/lib/python3.10/site-packages/selfcontact/utils/mesh.py @@ -82,7 +82,7 @@ if valid_vals > 0: loss = (mask * dists).sum() / valid_vals else: - loss = torch.Tensor([0]).cuda() + loss = mask.new_tensor([0]) return loss def batch_index_select(inp, dim, index): @@ -103,6 +103,7 @@ xx = torch.bmm(x, x.transpose(2, 1)) yy = torch.bmm(y, y.transpose(2, 1)) zz = torch.bmm(x, y.transpose(2, 1)) + use_cuda = x.device.type == "cuda" if use_cuda: dtype = torch.cuda.LongTensor else: ================================================ FILE: patches/smplx.diff ================================================ +++ venv/lib/python3.10/site-packages/smplx/body_models.py @@ -366,7 +366,7 @@ num_repeats = int(batch_size / betas.shape[0]) betas = betas.expand(num_repeats, -1) - vertices, joints = lbs(betas, full_pose, self.v_template, + vertices, joints, _ = lbs(betas, full_pose, self.v_template, self.shapedirs, self.posedirs, self.J_regressor, self.parents, self.lbs_weights, pose2rot=pose2rot) @@ -1228,7 +1228,7 @@ shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1) - vertices, joints = lbs(shape_components, full_pose, self.v_template, + vertices, joints, A = lbs(shape_components, full_pose, self.v_template, shapedirs, self.posedirs, self.J_regressor, self.parents, self.lbs_weights, pose2rot=pose2rot, @@ -1283,7 +1283,9 @@ right_hand_pose=right_hand_pose, jaw_pose=jaw_pose, v_shaped=v_shaped, - full_pose=full_pose if return_full_pose else None) + full_pose=full_pose if return_full_pose else None, + A=A, + ) return output +++ venv/lib/python3.10/site-packages/smplx/lbs.py @@ -245,7 
+245,7 @@ verts = v_homo[:, :, :3, 0] - return verts, J_transformed + return verts, J_transformed, (A, J) def vertices2joints(J_regressor: Tensor, vertices: Tensor) -> Tensor: +++ venv/lib/python3.10/site-packages/smplx/utils.py @@ -71,6 +71,7 @@ class SMPLXOutput(SMPLHOutput): expression: Optional[Tensor] = None jaw_pose: Optional[Tensor] = None + A: Optional[Tensor] = None @dataclass ================================================ FILE: patches/torchgeometry.diff ================================================ +++ venv/lib/python3.10/site-packages/torchgeometry/core/conversions.py @@ -298,6 +298,9 @@ rmat_t[:, 0, 1] - rmat_t[:, 1, 0]], -1) t3_rep = t3.repeat(4, 1).t() + mask_d2 = mask_d2.float() + mask_d0_d1 = mask_d0_d1.float() + mask_d0_nd1 = mask_d0_nd1.float() mask_c0 = mask_d2 * mask_d0_d1 mask_c1 = mask_d2 * (1 - mask_d0_d1) mask_c2 = (1 - mask_d2) * mask_d0_nd1 ================================================ FILE: requirements.txt ================================================ matplotlib>=3.5.1 numpy>=1.22.3 opencv_python>=4.5.5.64 Pillow>=9.1.0 plotly>=5.7.0 pyrender>=0.1.45 scikit_image>=0.19.2 scipy>=1.8.0 Shapely>=1.8.1.post1 scikit-image>=0.19.2 tensorboard>=2.8.0 torchgeometry>=0.1.2 tqdm>=4.64.0 trimesh>=3.10.8 git+https://github.com/muelea/selfcontact.git@08da422526419c24736c0616bca49623e442c26a git+https://github.com/vchoutas/smplx.git@5fa20519735cceda19afed0beeabd53caef711cd ================================================ FILE: scripts/download.sh ================================================ #!/usr/bin/env sh set -euo pipefail asset_dir="./assets" [ ! 
-e "${asset_dir}"/models_smplx_v1_1.zip ] \ && echo Error: Download SMPL-X body model from https://smpl-x.is.tue.mpg.de \ and save zip archive to "${asset_dir}" \ && exit 1 \ && : asset_urls=( # Download constants (SPIN) http://visiondata.cis.upenn.edu/spin/data.tar.gz # Download essentials (SMPLify-XMC) https://download.is.tue.mpg.de/tuch/smplify-xmc-essentials.zip # Download sketch2pose models http://www-labs.iro.umontreal.ca/~bmpix/sketch2pose/models.zip # Download test images http://www-labs.iro.umontreal.ca/~bmpix/sketch2pose/images.zip ) for asset_url in "${asset_urls[@]}"; do wget \ -nc \ -c \ --directory-prefix "${asset_dir}" \ "${asset_url}" done models_dir="./models" mkdir -p "${models_dir}" model_files=( # Unzip smplx models models_smplx_v1_1.zip # Unzip essentials (SMPLify-XMC) smplify-xmc-essentials.zip # Unzip sketch2pose models models.zip ) for model_file in "${model_files[@]}"; do unzip \ -u \ -d "${models_dir}" \ "${asset_dir}"/"${model_file}" done # Unzip constants (SPIN) tar \ --skip-old-files \ -xvf "${asset_dir}"/data.tar.gz \ -C "${models_dir}" \ data/smpl_mean_params.npz data_dir="./data" mkdir -p "${data_dir}" # Unzip test images unzip \ -u \ -d "${data_dir}" \ "${asset_dir}"/images.zip ================================================ FILE: scripts/prepare.sh ================================================ #!/usr/bin/env sh set -euo pipefail venv_dir=venv python -m venv --clear "${venv_dir}" . 
"${venv_dir}"/bin/activate pip install -U pip setuptools extra="cpu" [ -x "$(command -v nvcc)" ] && extra="cu113" pip install \ torch \ torchvision \ --extra-index-url https://download.pytorch.org/whl/"${extra}" pip install -r requirements.txt v=$(python -c 'import sys; v = sys.version_info; print(f"{v.major}.{v.minor}")') for p in patches/*.diff; do patch -p0 < <(sed "s/python3.10/python${v}/" "${p}") done ================================================ FILE: scripts/run.sh ================================================ #!/usr/bin/env sh set -euo pipefail img_dir="./data/images" out_dir="./output" find "${img_dir}" -mindepth 1 -maxdepth 1 -type f -print0 \ | xargs -0 -I "{}" python src/pose.py \ --save-path "${out_dir}" \ --img-path "{}" \ --use-contacts \ --use-natural \ --use-cos \ --use-angle-transf \ exit # baseline (SMPLify-XMC) find "${img_dir}" -mindepth 1 -maxdepth 1 -type f -print0 \ | xargs -0 -I "{}" python src/pose.py \ --save-path "${out_dir}_baseline" \ --img-path "{}" \ --c-mse 1 \ --c-par 0 \ --use-contacts \ --use-cos \ --use-angle-transf \ # ablation find "${img_dir}" -mindepth 1 -maxdepth 1 -type f -print0 \ | xargs -0 -I "{}" python src/pose.py \ --save-path "${out_dir}_wocostransform" \ --img-path "{}" \ --use-contacts \ --use-natural \ --use-cos \ find "${img_dir}" -mindepth 1 -maxdepth 1 -type f -print0 \ | xargs -0 -I "{}" python src/pose.py \ --save-path "${out_dir}_wocontacts" \ --img-path "{}" \ --use-msc \ --use-natural \ --use-cos \ --use-angle-transf \ find "${img_dir}" -mindepth 1 -maxdepth 1 -type f -print0 \ | xargs -0 -I "{}" python src/pose.py \ --save-path "${out_dir}_wonatural" \ --img-path "{}" \ --use-contacts \ --use-cos \ --use-angle-transf \ ================================================ FILE: src/fist_pose.py ================================================ left_fist = [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4183167815208435, 0.10645648092031479, -1.6593892574310303, 0.15252035856246948, -0.14700782299041748, -1.3719955682754517, -0.04432843625545502, -0.15799851715564728, -0.938068151473999, -0.12218914180994034, 0.073341965675354, -1.6415189504623413, -0.14376045763492584, 0.1927780956029892, -1.3593589067459106, -0.0851994976401329, 0.01652289740741253, -0.7474589347839355, -0.9881719946861267, -0.3987707793712616, -1.3535722494125366, -0.6686224937438965, 0.1261960119009018, -1.080643892288208, -0.8101894855499268, -0.1306752860546112, -0.8412265777587891, -0.3495230972766876, -0.17784251272678375, -1.4433038234710693, -0.46278536319732666, 0.13677796721458435, -1.467200517654419, -0.3681888282299042, 0.003404417773708701, -0.7764251232147217, 0.850964367389679, 0.2769227623939514, -0.09154807031154633, 0.14500413835048676, 0.09604815393686295, 0.219278022646904, 1.0451993942260742, 0.16911321878433228, -0.2426234930753708, 0.11167845129966736, -0.04289207234978676, 0.41644084453582764, 0.10881128907203674, 0.06598565727472305, 0.756219744682312, -0.0963931530714035, 0.09091583639383316, 0.18845966458320618, -0.11809506267309189, -0.050943851470947266, 0.5295845866203308, -0.14369848370552063, -0.055241718888282776, 0.704857349395752, -0.019182899966835976, 0.0923367589712143, 0.3379131853580475, -0.45703303813934326, 0.1962839663028717, 0.6254575848579407, -0.21465237438678741, 0.06599827855825424, 0.5068942308425903, -0.36972442269325256, 0.0603446289896965, 0.07949023693799973, -0.14186954498291016, 0.08585254102945328, 0.6355276107788086, -0.3033415675163269, 0.05788097903132439, 0.6313892006874084, -0.17612087726593018, 0.13209305703639984, 0.3733545243740082, 0.850964367389679, -0.2769227623939514, 0.09154807031154633, 
-0.4998386800289154, -0.026556432247161865, -0.052880801260471344, 0.5355585217475891, -0.045960985124111176, 0.27735769748687744, ] left_right_fist = [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4183167815208435, 0.10645648092031479, -1.6593892574310303, 0.15252035856246948, -0.14700782299041748, -1.3719955682754517, -0.04432843625545502, -0.15799851715564728, -0.938068151473999, -0.12218914180994034, 0.073341965675354, -1.6415189504623413, -0.14376045763492584, 0.1927780956029892, -1.3593589067459106, -0.0851994976401329, 0.01652289740741253, -0.7474589347839355, -0.9881719946861267, -0.3987707793712616, -1.3535722494125366, -0.6686224937438965, 0.1261960119009018, -1.080643892288208, -0.8101894855499268, -0.1306752860546112, -0.8412265777587891, -0.3495230972766876, -0.17784251272678375, -1.4433038234710693, -0.46278536319732666, 0.13677796721458435, -1.467200517654419, -0.3681888282299042, 0.003404417773708701, -0.7764251232147217, 0.850964367389679, 0.2769227623939514, -0.09154807031154633, 0.14500413835048676, 0.09604815393686295, 0.219278022646904, 1.0451993942260742, 0.16911321878433228, -0.2426234930753708, 0.4183167815208435, -0.10645647346973419, 1.6593892574310303, 0.15252038836479187, 0.14700786769390106, 1.3719956874847412, -0.04432841017842293, 0.15799842774868011, 0.9380677938461304, -0.12218913435935974, -0.0733419880270958, 1.6415191888809204, -0.14376048743724823, -0.19277812540531158, 1.3593589067459106, -0.08519953489303589, -0.016522908583283424, 0.7474592328071594, -0.9881719350814819, 0.3987707495689392, 1.3535723686218262, -0.6686226725578308, -0.12619605660438538, 1.080644130706787, -0.8101896643638611, 
0.1306752860546112, 0.8412266373634338, -0.34952324628829956, 0.17784248292446136, 1.443304181098938, -0.46278542280197144, -0.13677802681922913, 1.467200517654419, -0.36818885803222656, -0.0034044249914586544, 0.7764251232147217, 0.8509642481803894, -0.2769228219985962, 0.09154807776212692, 0.14500458538532257, -0.09604845196008682, -0.21927869319915771, 1.0451991558074951, -0.1691131889820099, 0.242623433470726, ] right_fist = [] for lf, lrf in zip(left_fist, left_right_fist): if lf != lrf: right_fist.append(lrf) else: right_fist.append(0) left_flat_up = [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0, 1.5129635334014893, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ] left_flat_down = [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0, -1.4648663997650146, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ] right_flat_up = [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0, -1.5021973848342896, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ] right_flat_down = [ 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0, 0, 1.494218111038208, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ] relaxed = [ 0.0, 
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.11167845129966736, 0.04289207234978676, -0.41644084453582764, 0.10881128907203674, -0.06598565727472305, -0.756219744682312, -0.0963931530714035, -0.09091583639383316, -0.18845966458320618, -0.11809506267309189, 0.050943851470947266, -0.5295845866203308, -0.14369848370552063, 0.055241718888282776, -0.704857349395752, -0.019182899966835976, -0.0923367589712143, -0.3379131853580475, -0.45703303813934326, -0.1962839663028717, -0.6254575848579407, -0.21465237438678741, -0.06599827855825424, -0.5068942308425903, -0.36972442269325256, -0.0603446289896965, -0.07949023693799973, -0.14186954498291016, -0.08585254102945328, -0.6355276107788086, -0.3033415675163269, -0.05788097903132439, -0.6313892006874084, -0.17612087726593018, -0.13209305703639984, -0.3733545243740082, 0.850964367389679, 0.2769227623939514, -0.09154807031154633, -0.4998386800289154, 0.026556432247161865, 0.052880801260471344, 0.5355585217475891, 0.045960985124111176, -0.27735769748687744, 0.11167845129966736, -0.04289207234978676, 0.41644084453582764, 0.10881128907203674, 0.06598565727472305, 0.756219744682312, -0.0963931530714035, 0.09091583639383316, 0.18845966458320618, -0.11809506267309189, -0.050943851470947266, 0.5295845866203308, -0.14369848370552063, -0.055241718888282776, 0.704857349395752, -0.019182899966835976, 0.0923367589712143, 0.3379131853580475, -0.45703303813934326, 0.1962839663028717, 0.6254575848579407, -0.21465237438678741, 0.06599827855825424, 0.5068942308425903, -0.36972442269325256, 0.0603446289896965, 0.07949023693799973, -0.14186954498291016, 0.08585254102945328, 0.6355276107788086, -0.3033415675163269, 
0.05788097903132439, 0.6313892006874084, -0.17612087726593018, 0.13209305703639984, 0.3733545243740082, 0.850964367389679, -0.2769227623939514, 0.09154807031154633, -0.4998386800289154, -0.026556432247161865, -0.052880801260471344, 0.5355585217475891, -0.045960985124111176, 0.27735769748687744, ] # body joints + left arm + right arm # 25 + 15 + 15 # smpl(left_hand_pose, right_hand_pose) left_start = 25 * 3 left_end = left_start + 15 * 3 right_end = left_end + 15 * 3 LEFT_FIST = left_fist[left_start:left_end] RIGHT_FIST = right_fist[left_end:right_end] LEFT_FLAT_UP = left_flat_up[20 * 3 : 20 * 3 + 3] LEFT_FLAT_DOWN = left_flat_down[20 * 3 : 20 * 3 + 3] RIGHT_FLAT_UP = right_flat_up[21 * 3 : 21 * 3 + 3] RIGHT_FLAT_DOWN = right_flat_down[21 * 3 : 21 * 3 + 3] LEFT_RELAXED = relaxed[left_start:left_end] RIGHT_RELAXED = relaxed[left_end:right_end] INT_TO_FIST = { "lfl": None, "lf": LEFT_FIST, "lu": LEFT_FLAT_UP, "ld": LEFT_FLAT_DOWN, "rfl": None, "rf": RIGHT_FIST, "ru": RIGHT_FLAT_UP, "rd": RIGHT_FLAT_DOWN, } ================================================ FILE: src/hist_cub.py ================================================ import itertools import functools import math import multiprocessing from pathlib import Path import matplotlib matplotlib.rcParams.update({'font.size': 24}) matplotlib.rcParams.update({ "text.usetex": True, "text.latex.preamble": r"\usepackage{biolinum} \usepackage{libertineRoman} \usepackage{libertineMono} \usepackage{biolinum} \usepackage[libertine]{newtxmath}", 'ps.usedistiller': "xpdf", }) import matplotlib.pyplot as plt import matplotlib.gridspec as gridspec import numpy as np import tqdm from scipy.stats import wasserstein_distance import pose_estimation def cub(x, a, b, c): x2 = x * x x3 = x2 * x y = a * x3 + b * x2 + c * x return y def subsample(a, p=0.0005, seed=0): np.random.seed(seed) N = len(a) inds = np.random.choice(range(N), size=int(p * N)) a = a[inds].copy() return a def read_cos_opt(path, fname="cos_hist.npy"): cos_opt = [] for p 
in Path(path).rglob(fname): d = np.load(p) cos_opt.append(d) cos_opt = np.array(cos_opt) return cos_opt def plot_hist(cos_opt_dir, hist_smpl_fpath, params, out_dir, bins=10, xy=None): cos_opt = read_cos_opt(cos_opt_dir) angle_opt = np.arccos(cos_opt) angle_opt2 = cub(angle_opt, *params) cos_opt2 = np.cos(angle_opt2) cos_smpl = np.load(hist_smpl_fpath) # cos_smpl = subsample(cos_smpl) print(cos_smpl.shape) cos_smpl = np.clip(cos_smpl, -1, 1) cos_opt = angle_opt cos_opt2 = angle_opt2 cos_smpl = np.arccos(cos_smpl) cos_opt = 180 / math.pi * cos_opt cos_opt2 = 180 / math.pi * cos_opt2 cos_smpl = 180 / math.pi * cos_smpl max_range = 90 # math.pi / 2 xticks = [0, 15, 30, 45, 60, 75, 90] for idx, bone in enumerate(pose_estimation.SKELETON): i, j = bone i_name = pose_estimation.KPS[i] j_name = pose_estimation.KPS[j] if i_name != "Left Upper Leg": continue name = f"{i_name}_{j_name}" gs = gridspec.GridSpec(2, 4) fig = plt.figure(tight_layout=True, figsize=(16, 8), dpi=300) ax0 = fig.add_subplot(gs[0, 0]) ax0.hist(cos_smpl[:, idx], bins=bins, range=(0, max_range), density=True) ax0.set_xticks(xticks) ax0.tick_params(labelbottom=False, labelleft=True) ax1 = fig.add_subplot(gs[1, 0], sharex=ax0) ax1.hist(cos_opt[:, idx], bins=bins, range=(0, max_range), density=True) ax1.set_xticks(xticks) if xy is not None: ax2 = fig.add_subplot(gs[:, 1:3]) ax2.plot(xy[0], xy[1], linewidth=8) ax2.plot(xy[0], xy[0], linewidth=4, linestyle="dashed") ax2.set_xticks(xticks) ax2.set_yticks(xticks) ax3 = fig.add_subplot(gs[0, 3], sharey=ax0) ax3.hist(cos_opt2[:, idx], bins=bins, range=(0, max_range), density=True) ax3.set_xticks(xticks) ax3.tick_params(labelbottom=False, labelleft=False) ax4 = fig.add_subplot(gs[1, 3], sharex=ax3, sharey=ax1) alpha = 0.5 ax4.hist(cos_opt[:, idx], bins=bins, range=(0, max_range), density=True, label=r"$\mathcal{B}_i$", alpha=alpha) ax4.hist(cos_opt2[:, idx], bins=bins, range=(0, max_range), density=True, label=r"$f(\mathcal{B}_i)$", alpha=alpha) ax4.hist(cos_smpl[:, 
idx], bins=bins, range=(0, max_range), density=True, label=r"$\mathcal{A}_i$", alpha=alpha) ax4.set_xticks(xticks) ax4.tick_params(labelbottom=True, labelleft=False) ax4.legend() fig.savefig(out_dir / f"hist_{name}.png") plt.close() def kldiv(p_hist, q_hist): wd = wasserstein_distance(p_hist, q_hist) return wd def calc_histogram(x, bins=10, range=(0, 1)): h, _ = np.histogram(x, bins=bins, range=range, density=True) return h def step(params, angles_opt, p_hist, bone_idx=None): if sum(params) > 1: return math.inf, params kl = 0 for i, _ in enumerate(pose_estimation.SKELETON): if bone_idx is not None and i != bone_idx: continue angles_opt2 = cub(angles_opt[:, i], *params) if angles_opt2.max() > 1 or angles_opt2.min() < 0: kl = math.inf break q_hist = calc_histogram(angles_opt2) kl += kldiv(p_hist[i], q_hist) return kl, params def optimize(cos_opt_dir, hist_smpl_fpath, bone_idx=None): cos_opt = read_cos_opt(cos_opt_dir) angles_opt = np.arccos(cos_opt) / (math.pi / 2) cos_smpl = np.load(hist_smpl_fpath) # cos_smpl = subsample(cos_smpl) print(cos_smpl.shape) cos_smpl = np.clip(cos_smpl, -1, 1) mask = cos_smpl <= 1 assert np.all(mask), (~mask).mean() mask = cos_smpl >= 0 assert np.all(mask), (~mask).mean() angles_smpl = np.arccos(cos_smpl) / (math.pi / 2) p_hist = [ calc_histogram(angles_smpl[:, i]) for i, _ in enumerate(pose_estimation.SKELETON) ] with multiprocessing.Pool(8) as p: results = list( tqdm.tqdm( p.imap_unordered( functools.partial(step, angles_opt=angles_opt, p_hist=p_hist, bone_idx=bone_idx), itertools.product( np.linspace(0, 20, 100), np.linspace(-20, 20, 200), np.linspace(-20, 1, 100), ), ), total=(100 * 200 * 100), ) ) kls, params = zip(*results) ind = np.argmin(kls) best_params = params[ind] print(kls[ind], best_params) inds = np.argsort(kls) for i in inds[:10]: print(kls[i]) print(params[i]) print() return best_params def main(): cos_opt_dir = "paper_single2_150mse" hist_smpl_fpath = "./data/hist_smpl.npy" # hist_smpl_fpath = "./testtest.npy" params = 
optimize(cos_opt_dir, hist_smpl_fpath) # params = (1.2121212121212122, -1.105527638190953, 0.787878787878789) # params = (0.20202020202020202, 0.30150753768844396, 0.3636363636363633) print(params) x = np.linspace(0, math.pi / 2, 100) y = cub(x / (math.pi / 2), *params) * (math.pi / 2) x = x * 180 / math.pi y = y * 180 / math.pi out_dir = Path("hists") out_dir.mkdir(parents=True, exist_ok=True) plot_hist(cos_opt_dir, hist_smpl_fpath, params, out_dir, xy=(x, y)) plt.figure(figsize=(4, 4), dpi=300) plt.plot(x, y, linewidth=6) plt.plot(x, x, linewidth=2, linestyle="dashed") xticks = [0, 15, 30, 45, 60, 75, 90] plt.xticks(xticks) plt.yticks(xticks) plt.axis("equal") plt.tight_layout() plt.savefig(out_dir / "new_out.png") if __name__ == "__main__": main() ================================================ FILE: src/losses.py ================================================ import itertools import torch import torch.nn as nn import pose_estimation class MSE(nn.Module): def __init__(self, ignore=None): super().__init__() self.mse = torch.nn.MSELoss(reduction="none") self.ignore = ignore if ignore is not None else [] def forward(self, y_pred, y_data): loss = self.mse(y_pred, y_data) if len(self.ignore) > 0: loss[self.ignore] *= 0 return loss.sum() / (len(loss) - len(self.ignore)) class Parallel(nn.Module): def __init__(self, skeleton, ignore=None, ground_parallel=None): super().__init__() self.skeleton = skeleton if ignore is not None: self.ignore = set(ignore) else: self.ignore = set() self.ground_parallel = ground_parallel if ground_parallel is not None else [] self.parallel_in_3d = [] self.cos = None def forward(self, y_pred3d, y_data, z, spine_j, writer=None, global_step=0): y_pred = y_pred3d[:, :2] rleg, lleg = spine_j Lcon2d = Lcount = 0 if hasattr(self, "contact_2d"): for c2d in self.contact_2d: for ( (src_1, dst_1, t_1), (src_2, dst_2, t_2), ) in itertools.combinations(c2d, 2): a_1 = torch.lerp(y_data[src_1], y_data[dst_1], t_1) a_2 = torch.lerp(y_data[src_2], 
y_data[dst_2], t_2) a = a_2 - a_1 b_1 = torch.lerp(y_pred[src_1], y_pred[dst_1], t_1) b_2 = torch.lerp(y_pred[src_2], y_pred[dst_2], t_2) b = b_2 - b_1 lcon2d = ((a - b) ** 2).sum() Lcon2d = Lcon2d + lcon2d Lcount += 1 if Lcount > 0: Lcon2d = Lcon2d / Lcount Ltan = Lpar = Lcos = Lcount = 0 Lspine = 0 for i, bone in enumerate(self.skeleton): if bone in self.ignore: continue src, dst = bone b = y_data[dst] - y_data[src] t = nn.functional.normalize(b, dim=0) n = torch.stack([-t[1], t[0]]) if src == 10 and dst == 11: # right leg a = rleg elif src == 13 and dst == 14: # left leg a = lleg else: a = y_pred[dst] - y_pred[src] bone_name = f"{pose_estimation.KPS[src]}_{pose_estimation.KPS[dst]}" c = a - b lcos_loc = ltan_loc = lpar_loc = 0 if self.cos is not None: if bone not in [ (1, 2), # Neck + Right Shoulder (1, 5), # Neck + Left Shoulder (9, 10), # Hips + Right Upper Leg (9, 13), # Hips + Left Upper Leg ]: a = y_pred[dst] - y_pred[src] l2d = torch.norm(a, dim=0) l3d = torch.norm(y_pred3d[dst] - y_pred3d[src], dim=0) lcos = self.cos[i] lcos_loc = (l2d / l3d - lcos) ** 2 Lcos = Lcos + lcos_loc lpar_loc = ((a / l2d) * n).sum() ** 2 Lpar = Lpar + lpar_loc else: ltan_loc = ((c * t).sum()) ** 2 Ltan = Ltan + ltan_loc lpar_loc = (c * n).sum() ** 2 Lpar = Lpar + lpar_loc if writer is not None: writer.add_scalar(f"tan/{bone_name}", ltan_loc, global_step=global_step) writer.add_scalar(f"cos/{bone_name}", lcos_loc, global_step=global_step) writer.add_scalar(f"par/{bone_name}", lpar_loc, global_step=global_step) Lcount += 1 if Lcount > 0: Ltan = Ltan / Lcount Lcos = Lcos / Lcount Lpar = Lpar / Lcount Lspine = Lspine / Lcount Lgr = Lcount = 0 for (src, dst), value in self.ground_parallel: bone = y_pred[dst] - y_pred[src] bone = nn.functional.normalize(bone, dim=0) l = (torch.abs(bone[0]) - value) ** 2 Lgr = Lgr + l Lcount += 1 if Lcount > 0: Lgr = Lgr / Lcount Lstraight3d = Lcount = 0 for (i, j), (k, l) in self.parallel_in_3d: a = z[j] - z[i] a = nn.functional.normalize(a, dim=0) b 
class MimickedSelfContactLoss(nn.Module):
    """Loss that lets vertices in contact on a presented (reference) mesh
    attract nearby vertices on the current mesh.

    For each presented-contact vertex, the nearest other contact vertex that
    passes the geodesic-distance mask is found, and the distance between the
    pair is penalized.
    """

    def __init__(self, geodesics_mask):
        super().__init__()
        # Boolean (nv, nv) buffer; True where two vertices are geodesically
        # far enough apart that their proximity counts as genuine contact.
        self.register_buffer("geomask", geodesics_mask)

    def forward(
        self,
        presented_contact,
        vertices,
        v2v=None,
        contact_mode="dist_tanh",
        contact_thresh=1,
    ):
        """Compute the contact loss.

        presented_contact: integer array of vertex ids in contact on the
            reference mesh (may be empty, in which case 0.0 is returned).
        vertices: (1, nv, 3) vertex tensor of the current mesh.
        v2v: optional precomputed (nv, nv) pairwise-distance matrix.
        contact_mode: "dist_tanh" saturates the penalty via tanh; anything
            else uses the raw mean distance.
        contact_thresh: tanh saturation scale.
        """
        contactloss = 0.0

        if v2v is None:
            # compute all pairwise Euclidean distances between the vertices
            verts = vertices.contiguous()
            nv = verts.shape[1]
            v2v = verts.squeeze().unsqueeze(1).expand(
                nv, nv, 3
            ) - verts.squeeze().unsqueeze(0).expand(nv, nv, 3)
            v2v = torch.norm(v2v, 2, 2)

        # loss for self-contact from the mimicked pose
        if len(presented_contact) > 0:
            # for every presented-contact vertex, find its nearest partner
            # among the other contact vertices, restricted by the geodesic mask
            with torch.no_grad():
                cvertstobody = v2v[presented_contact, :]
                cvertstobody = cvertstobody[:, presented_contact]
                maskgeo = self.geomask[presented_contact, :]
                maskgeo = maskgeo[:, presented_contact]
                # BUG FIX: was `.to(verts.device)`, but `verts` is bound only
                # when v2v is None, raising NameError when v2v is supplied;
                # the `vertices` argument is always available.
                weights = torch.ones_like(cvertstobody).to(vertices.device)
                weights[~maskgeo] = float("inf")
                min_idx = torch.min((cvertstobody + 1) * weights, 1)[1]
                min_idx = presented_contact[min_idx.cpu().numpy()]
            v2v_min = v2v[presented_contact, min_idx]
            # tanh saturates, so vertices more than ~contact_thresh apart
            # are not pulled together
            if contact_mode == "dist_tanh":
                contactloss = contact_thresh * torch.tanh(v2v_min / contact_thresh)
                contactloss = contactloss.mean()
            else:
                contactloss = v2v_min.mean()

        return contactloss
as Image import selfcontact import selfcontact.losses import shapely.geometry import torch import torch.nn as nn import torch.optim as optim import torchgeometry import tqdm import trimesh from skimage import measure from torch.utils.tensorboard.writer import SummaryWriter import fist_pose import hist_cub import losses import pose_estimation import spin import utils PE_KSP_TO_SPIN = { "Head": "Head", "Neck": "Neck", "Right Shoulder": "Right ForeArm", "Right Arm": "Right Arm", "Right Hand": "Right Hand", "Left Shoulder": "Left ForeArm", "Left Arm": "Left Arm", "Left Hand": "Left Hand", "Spine": "Spine1", "Hips": "Hips", "Right Upper Leg": "Right Upper Leg", "Right Leg": "Right Leg", "Right Foot": "Right Foot", "Left Upper Leg": "Left Upper Leg", "Left Leg": "Left Leg", "Left Foot": "Left Foot", "Left Toe": "Left Toe", "Right Toe": "Right Toe", } MODELS_DIR = "models" def parse_args(): parser = argparse.ArgumentParser() parser.add_argument( "--pose-estimation-model-path", type=str, default=f"./{MODELS_DIR}/hrn_w48_384x288.onnx", help="Pose Estimation model", ) parser.add_argument( "--contact-model-path", type=str, default=f"./{MODELS_DIR}/contact_hrn_w32_256x192.onnx", help="Contact model", ) parser.add_argument( "--device", type=str, default="cuda", choices=["cpu", "cuda"], help="Torch device", ) parser.add_argument( "--spin-model-path", type=str, default=f"./{MODELS_DIR}/spin_model_smplx_eft_18.pt", help="SPIN model path", ) parser.add_argument( "--smpl-type", type=str, default="smplx", choices=["smplx"], help="SMPL model type", ) parser.add_argument( "--smpl-model-dir", type=str, default=f"./{MODELS_DIR}/models/smplx", help="SMPL model dir", ) parser.add_argument( "--smpl-mean-params-path", type=str, default=f"./{MODELS_DIR}/data/smpl_mean_params.npz", help="SMPL mean params", ) parser.add_argument( "--essentials-dir", type=str, default=f"./{MODELS_DIR}/smplify-xmc-essentials", help="SMPL Essentials folder for contacts", ) parser.add_argument( 
"--parametrization-path", type=str, default=f"./{MODELS_DIR}/smplx_parametrization/parametrization.npy", help="Parametrization path", ) parser.add_argument( "--bone-parametrization-path", type=str, default=f"./{MODELS_DIR}/smplx_parametrization/bone_to_param2.npy", help="Bone parametrization path", ) parser.add_argument( "--foot-inds-path", type=str, default=f"./{MODELS_DIR}/smplx_parametrization/foot_inds.npy", help="Foot indinces", ) parser.add_argument( "--save-path", type=str, required=True, help="Path to save the results", ) parser.add_argument( "--img-path", type=str, required=True, help="Path to img to test", ) parser.add_argument( "--use-contacts", action="store_true", help="Use contact model", ) parser.add_argument( "--use-msc", action="store_true", help="Use MSC loss", ) parser.add_argument( "--use-natural", action="store_true", help="Use regularity", ) parser.add_argument( "--use-cos", action="store_true", help="Use cos model", ) parser.add_argument( "--use-angle-transf", action="store_true", help="Use cube foreshortening transformation", ) parser.add_argument( "--c-mse", type=float, default=0, help="MSE weight", ) parser.add_argument( "--c-par", type=float, default=10, help="Parallel weight", ) parser.add_argument( "--c-f", type=float, default=1000, help="Cos coef", ) parser.add_argument( "--c-parallel", type=float, default=100, help="Parallel weight", ) parser.add_argument( "--c-reg", type=float, default=1000, help="Regularity weight", ) parser.add_argument( "--c-cont2d", type=float, default=1, help="Contact 2D weight", ) parser.add_argument( "--c-msc", type=float, default=17_500, help="MSC weight", ) parser.add_argument( "--fist", nargs="+", type=str, choices=list(fist_pose.INT_TO_FIST), ) args = parser.parse_args() return args def freeze_layers(model): for module in model.modules(): if type(module) is False: continue if isinstance(module, nn.modules.batchnorm._BatchNorm): module.eval() for m in module.parameters(): m.requires_grad = False if 
isinstance(module, nn.Dropout): module.eval() for m in module.parameters(): m.requires_grad = False def project_and_normalize_to_spin(vertices_3d, camera): vertices_2d = vertices_3d # [:, :2] scale, translate = camera[0], camera[1:] translate = scale.new_zeros(3) translate[:2] = camera[1:] vertices_2d = vertices_2d + translate vertices_2d = scale * vertices_2d + 1 vertices_2d = spin.constants.IMG_RES / 2 * vertices_2d return vertices_2d def project_and_normalize_to_spin_legs(vertices_3d, A, camera): A, J = A A = A[0] J = J[0] L = vertices_3d.new_tensor( [ [0.98619063, 0.16560926, 0.00127302], [-0.16560601, 0.98603675, 0.01749799], [0.00164258, -0.01746717, 0.99984609], ] ) R = vertices_3d.new_tensor( [ [0.9910211, -0.13368178, -0.0025208], [0.13367888, 0.99027076, 0.03864949], [-0.00267045, -0.03863944, 0.99924965], ] ) scale = camera[0] R = A[2, :3, :3] @ R # 2 - right L = A[1, :3, :3] @ L # 1 - left r = J[5] - J[2] l = J[4] - J[1] rleg = scale * spin.constants.IMG_RES / 2 * R @ r lleg = scale * spin.constants.IMG_RES / 2 * L @ l rleg = rleg[:2] lleg = lleg[:2] return rleg, lleg def rotation_matrix_to_angle_axis(rotmat): bs, n_joints, *_ = rotmat.size() rotmat = torch.cat( [ rotmat.view(-1, 3, 3), rotmat.new_tensor([0, 0, 1], dtype=torch.float32) .view(bs, 3, 1) .expand(n_joints, -1, -1), ], dim=-1, ) aa = torchgeometry.rotation_matrix_to_angle_axis(rotmat) aa = aa.reshape(bs, 3 * n_joints) return aa def get_smpl_output(smpl, rotmat, betas, use_betas=True, zero_hands=False): if smpl.name() == "SMPL": smpl_output = smpl( betas=betas if use_betas else None, body_pose=rotmat[:, 1:], global_orient=rotmat[:, 0].unsqueeze(1), pose2rot=False, ) elif smpl.name() == "SMPL-X": rotmat = rotation_matrix_to_angle_axis(rotmat) if zero_hands: for i in [20, 21]: rotmat[:, 3 * i : 3 * (i + 1)] = 0 for i in [12, 15]: # neck, head rotmat[:, 3 * i + 1] = 0 # y smpl_output = smpl( betas=betas if use_betas else None, body_pose=rotmat[:, 3:], global_orient=rotmat[:, :3], pose2rot=True, 
def normalize_keypoints_to_spin(keypoints_2d, img_size):
    """Map image-space keypoints into the square SPIN input frame.

    The shorter image axis (ax2) is centered and both axes are scaled so
    that the shorter side spans ``spin.constants.IMG_RES`` pixels.

    Returns (normalized keypoints, shift, scale, ax2) so the transform can
    be inverted by ``unnormalize_keypoints_from_spin``.
    """
    h, w = img_size
    # ax1 indexes the longer image side, ax2 the shorter one
    ax1, ax2 = (1, 0) if h > w else (0, 1)

    shift = (img_size[ax1] - img_size[ax2]) / 2
    scale = spin.constants.IMG_RES / img_size[ax2]

    normalized = np.copy(keypoints_2d)
    normalized[:, ax2] -= shift
    normalized *= scale
    return normalized, shift, scale, ax2


def unnormalize_keypoints_from_spin(keypoints_2d, shift, scale, ax2):
    """Invert ``normalize_keypoints_to_spin`` (returns a new array)."""
    restored = np.copy(keypoints_2d)
    restored /= scale
    restored[:, ax2] += shift
    return restored
def discretize(parametrization, n_bins=100):
    """Snap values in [0, 1] onto the left edges of ``n_bins`` uniform bins."""
    edges = np.linspace(0, 1, n_bins + 1)
    bin_index = np.digitize(parametrization, edges)
    return edges[bin_index - 1]


def get_mapping_from_params_to_verts(verts, params):
    """Group vertex ids by their (discretized) parameter value.

    Returns a dict mapping each parameter value to the list of vertices
    carrying it, preserving input order within each group.
    """
    mapping = {}
    for vert, param in zip(verts, params):
        mapping.setdefault(param, []).append(vert)
    return mapping
np.array([x for x, _ in t3d_to_verts_sorted]) line = shapely.geometry.LineString([keypoints_2d[i], keypoints_2d[j]]) lint = buffer.intersection(line) if len(lint.boundary.geoms) < 2: continue t2d_start = line.project(lint.boundary.geoms[0], normalized=True) t2d_end = line.project(lint.boundary.geoms[1], normalized=True) assert t2d_start <= t2d_end t2ds = discretize( np.linspace(t2d_start, t2d_end, n_bins + 1), n_bins=n_bins ) to_add = False for t2d in t2ds: if t2d < t3d_sorted_np[0] or t2d > t3d_sorted_np[-1]: continue t2d_ind = np.searchsorted(t3d_sorted_np, t2d) c = t3d_to_verts_sorted[t2d_ind][1] contact_loc.extend(c) to_add = True mask_add = True if t2d_ind + 1 < len(t3d_to_verts_sorted): c = t3d_to_verts_sorted[t2d_ind + 1][1] contact_loc.extend(c) if t2d_ind > 0: c = t3d_to_verts_sorted[t2d_ind - 1][1] contact_loc.extend(c) if to_add: contact_2d_loc.append((i, j, t2d_start + 0.5 * (t2d_end - t2d_start))) if mask_add: for_mask.append(buffer.exterior.coords.xy) contact_loc = sorted(set(contact_loc)) contact_loc = np.array(contact_loc, dtype="int") contact.append(contact_loc) contact_2d.append(contact_2d_loc) for_mask = [np.stack((x, y), axis=0).T[:, None].astype("int") for x, y in for_mask] return contact, contact_2d, for_mask def optimize( model_hmr, smpl, selector, input_img, keypoints_2d, optimizer, args, loss_mse=None, loss_parallel=None, c_mse=0.0, c_new_mse=1.0, c_beta=1e-3, sc_crit=None, msc_crit=None, contact=None, n_steps=60, save_path=None, writer=None, i_ini=0, ): to_save = False if save_path is not None: ( img_original, predicted_keypoints_2d, save_path, shift, scale, ax2, prefix, ) = save_path to_save = True mean_zfoot_val = {} with tqdm.trange(n_steps) as pbar: for i in pbar: global_step = i + i_ini optimizer.zero_grad() ( rotmat_pred, betas_pred, camera_pred, keypoints_3d_pred, z, vertices_2d_pred, smpl_output, (rleg, lleg), joints_2d_orig, ) = get_pred_and_data( model_hmr, smpl, selector, input_img, ) keypoints_2d_pred = keypoints_3d_pred[:, :2] 
if to_save: utils.save_results_image( camera=camera_pred.detach().cpu().numpy(), focal_length_x=spin.constants.FOCAL_LENGTH, focal_length_y=spin.constants.FOCAL_LENGTH, vertices=smpl_output.vertices.detach()[0].cpu().numpy(), input_img=img_original, faces=smpl.faces, keypoints=predicted_keypoints_2d, keypoints_2=unnormalize_keypoints_from_spin( keypoints_2d_pred.detach().cpu().numpy(), shift, scale, ax2 ), # keypoints_2=unnormalize_keypoints_from_spin(joints_2d_orig.detach().cpu().numpy(), shift, scale, ax2), # heatmap=predicted_contact_heatmap_raw, filename=save_path / f"{prefix}_{i:0>4}.png", contactlist=contact, user_study=False, ) loss = l2 = 0.0 if c_mse > 0 and loss_mse is not None: l2 = loss_mse(keypoints_2d_pred, keypoints_2d) loss = loss + c_mse * l2 if writer is not None: writer.add_scalar("mse", l2, global_step=global_step) vertices_pred = smpl_output.vertices lpar = z_loss = loss_sh = 0.0 if c_new_mse > 0 and loss_parallel is not None: Ltan, Lcos, Lpar, Lspine, Lgr, Lstraight3d, Lcon2d = loss_parallel( keypoints_3d_pred, keypoints_2d, z, (rleg, lleg), writer=writer, global_step=global_step, ) lpar = ( Ltan + c_new_mse * (args.c_f * Lcos + args.c_parallel * Lpar) + Lspine + args.c_reg * Lgr + args.c_reg * Lstraight3d + args.c_cont2d * Lcon2d ) loss = loss + 300 * lpar if writer is not None: writer.add_scalar("tan", Ltan, global_step=global_step) writer.add_scalar("cos", Lcos, global_step=global_step) writer.add_scalar("par", Lpar, global_step=global_step) writer.add_scalar("spine", Lspine, global_step=global_step) writer.add_scalar("ground/chain", Lgr, global_step=global_step) writer.add_scalar( "straight_in_3d", Lstraight3d, global_step=global_step ) writer.add_scalar("contact/con2d", Lcon2d, global_step=global_step) for side in ["left", "right"]: attr = f"{side}_foot_inds" if hasattr(loss_parallel, attr): foot_inds = getattr(loss_parallel, attr) zind = 1 if attr not in mean_zfoot_val: with torch.no_grad(): mean_zfoot_val[attr] = torch.median( 
vertices_pred[0, foot_inds, zind], dim=0 ).values loss_foot = ( (vertices_pred[0, foot_inds, zind] - mean_zfoot_val[attr]) ** 2 ).sum() loss = loss + args.c_reg * loss_foot if writer is not None: writer.add_scalar( f"ground/{side} foot", loss_foot, global_step=global_step, ) if hasattr(loss_parallel, "silhuette_vertices_inds"): inds = loss_parallel.silhuette_vertices_inds loss_sh = ( (vertices_pred[0, inds, 1] - loss_parallel.ground) ** 2 ).sum() loss = loss + args.c_reg * loss_sh if writer is not None: writer.add_scalar( "ground/silhuette", loss_sh, global_step=global_step ) lbeta = (betas_pred**2).mean() lcam = ((torch.exp(-camera_pred[0] * 10)) ** 2).mean() loss = loss + c_beta * lbeta + lcam if writer is not None: writer.add_scalar("loss/beta", lbeta, global_step=global_step) writer.add_scalar("loss/cam", lcam, global_step=global_step) lgsc_a = gsc_contact_loss = faces_angle_loss = 0.0 if sc_crit is not None: gsc_contact_loss, faces_angle_loss = sc_crit( vertices_pred, ) lgsc_a = 1000 * gsc_contact_loss + 0.1 * faces_angle_loss loss = loss + lgsc_a if writer is not None: writer.add_scalar( "contact/gsc", gsc_contact_loss, global_step=global_step ) writer.add_scalar( "contact/faces_angle", faces_angle_loss, global_step=global_step ) msc_loss = 0.0 if contact is not None and len(contact) > 0 and msc_crit is not None: if not isinstance(contact, list): contact = [contact] for cntct in contact: msc_loss = msc_crit( cntct, vertices_pred, ) loss = loss + args.c_msc * msc_loss if writer is not None: writer.add_scalar( "contact/msc", msc_loss, global_step=global_step ) loss.backward() optimizer.step() epoch_loss = loss.item() pbar.set_postfix( **{ "l": f"{epoch_loss:.3}", "l2": f"{l2:.3}", "par": f"{lpar:.3}", "beta": f"{lbeta:.3}", "cam": f"{lcam:.3}", "z": f"{z_loss:.3}", "gsc_contact": f"{float(gsc_contact_loss):.3}", "faces_angle": f"{float(faces_angle_loss):.3}", "msc": f"{float(msc_loss):.3}", } ) with torch.no_grad(): ( rotmat_pred, betas_pred, camera_pred, 
keypoints_3d_pred, z, vertices_2d_pred, smpl_output, (rleg, lleg), joints_2d_orig, ) = get_pred_and_data( model_hmr, smpl, selector, input_img, zero_hands=True, ) return ( rotmat_pred, betas_pred, camera_pred, keypoints_3d_pred, vertices_2d_pred, smpl_output, z, joints_2d_orig, ) def optimize_ft( theta, camera, smpl, selector, input_img, keypoints_2d, args, loss_mse=None, loss_parallel=None, c_mse=0.0, c_new_mse=1.0, sc_crit=None, msc_crit=None, contact=None, n_steps=60, save_path=None, writer=None, i_ini=0, zero_hands=False, fist=None, ): to_save = False if save_path is not None: ( img_original, predicted_keypoints_2d, save_path, shift, scale, ax2, prefix, ) = save_path to_save = True mean_zfoot_val = {} theta = theta.detach().clone() camera = camera.detach().clone() rotmat_pred = nn.Parameter(theta) camera_pred = nn.Parameter(camera) optimizer = torch.optim.Adam( [ rotmat_pred, camera_pred, ], lr=1e-3, ) global_step = i_ini with tqdm.trange(n_steps) as pbar: for i in pbar: global_step = i + i_ini optimizer.zero_grad() global_orient = rotmat_pred[:3] body_pose = rotmat_pred[3:] smpl_output = smpl( global_orient=global_orient.unsqueeze(0), body_pose=body_pose.unsqueeze(0), pose2rot=True, ) z = smpl_output.joints z = z.squeeze(0) joints = smpl_output.joints.squeeze(0) joints_2d = project_and_normalize_to_spin(joints, camera_pred) rleg, lleg = project_and_normalize_to_spin_legs( joints, smpl_output.A, camera_pred ) joints_2d = joints_2d[selector] z = z[selector] keypoints_3d_pred = joints_2d keypoints_2d_pred = keypoints_3d_pred[:, :2] if to_save: utils.save_results_image( camera=camera_pred.detach().cpu().numpy(), focal_length_x=spin.constants.FOCAL_LENGTH, focal_length_y=spin.constants.FOCAL_LENGTH, vertices=smpl_output.vertices.detach()[0].cpu().numpy(), input_img=img_original, faces=smpl.faces, keypoints=predicted_keypoints_2d, keypoints_2=unnormalize_keypoints_from_spin( keypoints_2d_pred.detach().cpu().numpy(), shift, scale, ax2 ), # 
keypoints_2=unnormalize_keypoints_from_spin(joints_2d_orig.detach().cpu().numpy(), shift, scale, ax2), # heatmap=predicted_contact_heatmap_raw, filename=save_path / f"{prefix}_{i:0>4}.png", contactlist=contact, user_study=False, ) lprior = ((rotmat_pred - theta) ** 2).sum() + ( (camera_pred - camera) ** 2 ).sum() loss = lprior l2 = 0.0 if c_mse > 0 and loss_mse is not None: l2 = loss_mse(keypoints_2d_pred, keypoints_2d) loss = loss + c_mse * l2 if writer is not None: writer.add_scalar("mse", l2, global_step=global_step) vertices_pred = smpl_output.vertices lpar = z_loss = loss_sh = 0.0 if c_new_mse > 0 and loss_parallel is not None: Ltan, Lcos, Lpar, Lspine, Lgr, Lstraight3d, Lcon2d = loss_parallel( keypoints_3d_pred, keypoints_2d, z, (rleg, lleg), writer=writer, global_step=global_step, ) lpar = ( Ltan + c_new_mse * (args.c_f * Lcos + args.c_parallel * Lpar) + Lspine + args.c_reg * Lgr + args.c_reg * Lstraight3d + args.c_cont2d * Lcon2d ) loss = loss + 300 * lpar if writer is not None: writer.add_scalar("tan", Ltan, global_step=global_step) writer.add_scalar("cos", Lcos, global_step=global_step) writer.add_scalar("par", Lpar, global_step=global_step) writer.add_scalar("spine", Lspine, global_step=global_step) writer.add_scalar("ground/chain", Lgr, global_step=global_step) writer.add_scalar( "straight_in_3d", Lstraight3d, global_step=global_step ) writer.add_scalar("contact/con2d", Lcon2d, global_step=global_step) for side in ["left", "right"]: attr = f"{side}_foot_inds" if hasattr(loss_parallel, attr): foot_inds = getattr(loss_parallel, attr) zind = 1 if attr not in mean_zfoot_val: with torch.no_grad(): mean_zfoot_val[attr] = torch.median( vertices_pred[0, foot_inds, zind], dim=0 ).values loss_foot = ( (vertices_pred[0, foot_inds, zind] - mean_zfoot_val[attr]) ** 2 ).sum() loss = loss + args.c_reg * loss_foot if writer is not None: writer.add_scalar( f"ground/{side} foot", loss_foot, global_step=global_step, ) if hasattr(loss_parallel, "silhuette_vertices_inds"): 
inds = loss_parallel.silhuette_vertices_inds loss_sh = ( (vertices_pred[0, inds, 1] - loss_parallel.ground) ** 2 ).sum() loss = loss + args.c_reg * loss_sh if writer is not None: writer.add_scalar( "ground/silhuette", loss_sh, global_step=global_step ) lgsc_a = gsc_contact_loss = faces_angle_loss = 0.0 if sc_crit is not None: gsc_contact_loss, faces_angle_loss = sc_crit(vertices_pred) lgsc_a = 1000 * gsc_contact_loss + 0.1 * faces_angle_loss loss = loss + lgsc_a if writer is not None: writer.add_scalar( "contact/gsc", gsc_contact_loss, global_step=global_step ) writer.add_scalar( "contact/faces_angle", faces_angle_loss, global_step=global_step ) msc_loss = 0.0 if contact is not None and len(contact) > 0 and msc_crit is not None: if not isinstance(contact, list): contact = [contact] for cntct in contact: msc_loss = msc_crit( cntct, vertices_pred, ) loss = loss + args.c_msc * msc_loss if writer is not None: writer.add_scalar( "contact/msc", msc_loss, global_step=global_step ) loss.backward() optimizer.step() epoch_loss = loss.item() pbar.set_postfix( **{ "l": f"{epoch_loss:.3}", "l2": f"{l2:.3}", "par": f"{lpar:.3}", "z": f"{z_loss:.3}", "gsc_contact": f"{float(gsc_contact_loss):.3}", "faces_angle": f"{float(faces_angle_loss):.3}", "msc": f"{float(msc_loss):.3}", } ) rotmat_pred = rotmat_pred.detach() if zero_hands: for i in [20, 21]: rotmat_pred[3 * i : 3 * (i + 1)] = 0 for i in [12, 15]: # neck, head rotmat_pred[3 * i + 1] = 0 # y global_orient = rotmat_pred[:3] body_pose = rotmat_pred[3:] left_hand_pose = None right_hand_pose = None if fist is not None: left_hand_pose = rotmat_pred.new_tensor(fist_pose.LEFT_RELAXED).unsqueeze(0) right_hand_pose = rotmat_pred.new_tensor(fist_pose.RIGHT_RELAXED).unsqueeze(0) for f in fist: pp = fist_pose.INT_TO_FIST[f] if pp is not None: pp = rotmat_pred.new_tensor(pp).unsqueeze(0) if f.startswith("lf"): left_hand_pose = pp elif f.startswith("rf"): right_hand_pose = pp elif f.startswith("l"): body_pose[19 * 3 : 19 * 3 + 3] = pp 
def create_bone(i, j, keypoints_2d):
    """Unit direction vector of the bone from keypoint ``i`` to keypoint ``j``."""
    direction = keypoints_2d[j] - keypoints_2d[i]
    return torch.nn.functional.normalize(direction, dim=0)


def is_parallel_to_plane(bone, thresh=21):
    """True when the unit bone deviates from the ground axis by under ``thresh`` degrees."""
    return abs(bone[0]) > math.cos(math.radians(thresh))


def is_close_to_plane(bone, plane, thresh):
    """True when the bone's first coordinate lies within ``thresh`` of ``plane``."""
    return abs(bone[0] - plane) < thresh
("Right Shoulder", "Right Arm", "Right Hand"), ("Left Shoulder", "Left Arm", "Left Hand"), # ("Hips", "Spine", "Neck"), # ("Spine", "Neck", "Head"), ]: i = pose_estimation.KPS.index(i) j = pose_estimation.KPS.index(j) k = pose_estimation.KPS.index(k) upleg_leg = create_bone(i, j, keypoints_2d) leg_foot = create_bone(j, k, keypoints_2d) if is_parallel_to_plane(upleg_leg) and is_parallel_to_plane(leg_foot): if is_close_to_plane( upleg_leg, plane_2d, thresh=0.1 * height_2d ) or is_close_to_plane(leg_foot, plane_2d, thresh=0.1 * height_2d): ground_parallel.append(((i, j), 1)) ground_parallel.append(((j, k), 1)) if (upleg_leg * leg_foot).sum() > math.cos(math.radians(21)): parallel_in_3d.append(((i, j), (j, k))) parallel3d_bones.add((i, j)) parallel3d_bones.add((j, k)) # parallel feets for i, j in [ ("Right Foot", "Right Toe"), ("Left Foot", "Left Toe"), ]: i = pose_estimation.KPS.index(i) j = pose_estimation.KPS.index(j) if (i, j) in parallel3d_bones: continue foot_toe = create_bone(i, j, keypoints_2d) if is_parallel_to_plane(foot_toe, thresh=25): if "Right" in pose_estimation.KPS[i]: loss_parallel.right_foot_inds = right_foot_inds else: loss_parallel.left_foot_inds = left_foot_inds loss_parallel.ground_parallel = ground_parallel loss_parallel.parallel_in_3d = parallel_in_3d vertices_np = vertices[0].cpu().numpy() if len(ground_parallel) > 0: # Silhuette veritices mesh = trimesh.Trimesh(vertices=vertices_np, faces=smpl.faces, process=False) silhuette_vertices_mask_1 = np.abs(mesh.vertex_normals[..., 2]) < 2e-1 height_3d = vertices_np[:, 1].max() - vertices_np[:, 1].min() plane_3d = vertices_np[:, 1].max() silhuette_vertices_mask_2 = ( np.abs(vertices_np[:, 1] - plane_3d) < 0.15 * height_3d ) silhuette_vertices_mask = np.logical_and( silhuette_vertices_mask_1, silhuette_vertices_mask_2 ) (silhuette_vertices_inds,) = np.where(silhuette_vertices_mask) if len(silhuette_vertices_inds) > 0: loss_parallel.silhuette_vertices_inds = silhuette_vertices_inds loss_parallel.ground 
def get_cos(keypoints_3d_pred, use_angle_transf, loss_parallel):
    """Estimate per-bone foreshortening cosines from the current prediction.

    Computes, for every bone of pose_estimation.SKELETON, the cosine between
    its 2D-projected direction and its 3D direction, optionally remaps the
    corresponding angles through the fitted cubic foreshortening transform,
    and stores the resulting target cosines on ``loss_parallel.cos``.
    Returns the raw (untransformed) cosines.

    NOTE(review): indentation reconstructed from a flattened source; the
    cubic-transform tail is gated on ``use_angle_transf`` per the CLI flag's
    help text ("Use cube foreshortening transformation") — confirm.
    """
    keypoints_2d_pred = keypoints_3d_pred[:, :2]
    with torch.no_grad():
        # cosine between each bone's 2D direction and its 3D direction
        cos_r = calc_cos(keypoints_2d_pred, keypoints_3d_pred)
        alpha = torch.acos(cos_r)
        if use_angle_transf:
            # indices into pose_estimation.SKELETON
            leg_inds = [
                5,
                6,  # right leg
                7,
                8,  # left leg
            ]
            foot_inds = [15, 16]
            nleg_inds = sorted(
                set(range(len(pose_estimation.SKELETON)))
                - set(leg_inds)
                - set(foot_inds)
            )
            # shift each group so its smallest angle maps to zero
            alpha[nleg_inds] = alpha[nleg_inds] - alpha[nleg_inds].min()
            amli = alpha[leg_inds].min()
            leg_inds.extend(foot_inds)
            alpha[leg_inds] = alpha[leg_inds] - amli
            # cubic foreshortening transform on normalized angles; the
            # coefficients are the best-fit values found by hist_cub.optimize
            angles = alpha.detach().cpu().numpy()
            angles = hist_cub.cub(
                angles / (math.pi / 2),
                a=1.2121212121212122,
                b=-1.105527638190953,
                c=0.787878787878789,
            ) * (math.pi / 2)
            alpha = alpha.new_tensor(angles)
    # target cosines consumed by the Parallel loss
    loss_parallel.cos = torch.cos(alpha)
    return cos_r
img_size_original[::-1]) cv2.imwrite(str(save_path / "mask.png"), mask) if len(contact) == 0: _, contact = sc_module.verts_in_contact(vertices, return_idx=True) contact = contact.cpu().numpy().ravel() elif use_msc: _, contact = sc_module.verts_in_contact(vertices, return_idx=True) contact = contact.cpu().numpy().ravel() else: contact = np.array([]) return contact def save_all( keypoints_3d_pred, rotmat_pred, camera_pred, betas_pred, smpl, contact, img_original, predicted_keypoints_2d, predicted_contact_heatmap_raw, loss_parallel, smpl_output, shift, scale, ax2, summary_writer, save_path, fname, ): keypoints_2d_pred = keypoints_3d_pred[:, :2] vertices = smpl_output.vertices.detach() betas_pred = betas_pred.detach().cpu().numpy() utils.save_pose_params( rotmat_pred, camera_pred, betas_pred, vertices, smpl, contact, save_path / f"{fname}.pkl", ) if hasattr(loss_parallel, "silhuette_vertices_inds"): contact.append(loss_parallel.silhuette_vertices_inds) img_sw = utils.save_results_image( camera=camera_pred.detach().cpu().numpy(), focal_length_x=spin.constants.FOCAL_LENGTH, focal_length_y=spin.constants.FOCAL_LENGTH, vertices=vertices[0].cpu().numpy(), input_img=img_original, faces=smpl.faces, keypoints=predicted_keypoints_2d, keypoints_2=unnormalize_keypoints_from_spin( keypoints_2d_pred.cpu().numpy(), shift, scale, ax2 ) if shift is not None else None, # keypoints_2=unnormalize_keypoints_from_spin(joints_2d_orig.detach().cpu().numpy(), shift, scale, ax2) if shift is not None else None, heatmap=predicted_contact_heatmap_raw, filename=save_path / f"{fname}.png", contactlist=contact, contact2dlist=loss_parallel.contact_2d if hasattr(loss_parallel, "contact_2d") else None, cos=loss_parallel.cos.tolist() if loss_parallel.cos is not None else None, ) utils.save_mesh_with_colors( smpl_output.vertices[0].cpu().numpy(), smpl.faces, save_path / f"{fname}.ply", inds=contact, ) joints = smpl_output.joints.squeeze(0).cpu().numpy() fig = utils.plot_3D(joints, 
vertices.squeeze(0).cpu().numpy(), smpl.faces) fig.write_html(save_path / f"{fname}.html") summary_writer.add_image( fname, np.array(img_sw).astype("float32") / 255, dataformats="HWC" ) summary_writer.add_mesh( fname, vertices=( vertices.cpu().float()[0] @ torch.tensor([[1, 0, 0], [0, -1, 0], [0, 0, -1.0]]) ).unsqueeze(0), faces=torch.from_numpy(smpl.faces[None].astype("int64")), ) def spin_step( model_hmr, smpl, selector, input_img, img_original, predicted_keypoints_2d, predicted_contact_heatmap_raw, loss_parallel, shift, scale, ax2, summary_writer, save_path, ): with torch.no_grad(): ( rotmat_pred, betas_pred, camera_pred, keypoints_3d_pred, _, _, smpl_output, _, _, ) = get_pred_and_data( model_hmr, smpl, selector, input_img, zero_hands=True, ) save_all( keypoints_3d_pred, rotmat_pred, camera_pred, betas_pred, smpl, None, img_original, predicted_keypoints_2d, predicted_contact_heatmap_raw, loss_parallel, smpl_output, shift, scale, ax2, summary_writer, save_path, "spin", ) def eft_step( model_hmr, smpl, selector, input_img, keypoints_2d, optimizer, args, loss_mse, loss_parallel, c_beta, sc_module, y_data_conts, bone_to_params, img_original, predicted_keypoints_2d, predicted_contact_heatmap_raw, shift, scale, ax2, summary_writer, save_path, ): img_size_original = img_original.shape[:2] ( rotmat_pred, betas_pred, camera_pred, keypoints_3d_pred, _, smpl_output, _, _, ) = optimize( model_hmr, smpl, selector, input_img, keypoints_2d, optimizer, args, loss_mse=loss_mse, loss_parallel=loss_parallel, c_mse=1, c_new_mse=0, c_beta=c_beta, sc_crit=None, msc_crit=None, contact=None, n_steps=60 + 90, writer=summary_writer, ) # find contacts vertices = smpl_output.vertices.detach() contact = get_contacts( args, sc_module, y_data_conts, keypoints_2d, vertices, bone_to_params, loss_parallel, img_size_original, save_path, ) save_all( keypoints_3d_pred, rotmat_pred, camera_pred, betas_pred, smpl, contact, img_original, predicted_keypoints_2d, predicted_contact_heatmap_raw, 
loss_parallel, smpl_output, shift, scale, ax2, summary_writer, save_path, "eft", ) if sc_module is not None: save_mesh_with_winding_numbers(sc_module, vertices, smpl, save_path) return vertices, keypoints_3d_pred, contact def dc_step( model_hmr, smpl, selector, input_img, keypoints_2d, optimizer, args, loss_mse, loss_parallel, c_mse, c_new_mse, c_beta, sc_crit, msc_crit, contact, use_contacts, use_msc, img_original, predicted_keypoints_2d, predicted_contact_heatmap_raw, shift, scale, ax2, summary_writer, save_path, ): ( rotmat_pred, betas_pred, camera_pred, keypoints_3d_pred, _, smpl_output, _, _, ) = optimize( model_hmr, smpl, selector, input_img, keypoints_2d, optimizer, args, loss_mse=loss_mse, loss_parallel=loss_parallel, c_mse=c_mse, c_new_mse=c_new_mse, c_beta=c_beta, sc_crit=sc_crit, msc_crit=msc_crit if use_contacts or use_msc else None, contact=contact if use_contacts or use_msc else None, n_steps=60 if use_contacts or use_msc else 0, # + 60, # save_path=(img_original, predicted_keypoints_2d, save_path, shift, scale, ax2, "dc"), writer=summary_writer, i_ini=60 + 90, ) save_all( keypoints_3d_pred, rotmat_pred, camera_pred, betas_pred, smpl, contact, img_original, predicted_keypoints_2d, predicted_contact_heatmap_raw, loss_parallel, smpl_output, shift, scale, ax2, summary_writer, save_path, "dc", ) return rotmat_pred def us_step( model_hmr, smpl, selector, input_img, rotmat_pred, keypoints_2d, args, loss_mse, loss_parallel, c_mse, c_new_mse, sc_crit, msc_crit, contact, use_contacts, use_msc, img_original, keypoints_3d_pred, summary_writer, save_path, ): (_, _, camera_pred_us, _, _, _, smpl_output_us, _, _,) = get_pred_and_data( model_hmr, smpl, selector, input_img, use_betas=False, zero_hands=True, ) rotmat_pred_us, smpl_output_us = optimize_ft( rotmat_pred, camera_pred_us, smpl, selector, input_img, keypoints_2d, args, loss_mse=loss_mse, loss_parallel=loss_parallel, c_mse=c_mse, c_new_mse=c_new_mse, sc_crit=sc_crit, msc_crit=msc_crit if use_contacts or 
use_msc else None, contact=contact if use_contacts or use_msc else None, n_steps=60 if use_contacts or use_msc else 0, # + 60, # save_path=(img_original, predicted_keypoints_2d, save_path, shift, scale, ax2, "dc"), writer=summary_writer, i_ini=60 + 90 + 60, zero_hands=True, fist=args.fist, ) save_all( keypoints_3d_pred, rotmat_pred_us, camera_pred_us, torch.zeros(1, 10, dtype=torch.float32), smpl, None, img_original, None, None, loss_parallel, smpl_output_us, None, None, None, summary_writer, save_path, "us", ) def main(): args = parse_args() print(args) # models model_pose = cv2.dnn.readNetFromONNX( args.pose_estimation_model_path ) # "hrn_w48_384x288.onnx" model_contact = cv2.dnn.readNetFromONNX( args.contact_model_path ) # "contact_hrn_w32_256x192.onnx" device = ( torch.device(args.device) if torch.cuda.is_available() else torch.device("cpu") ) model_hmr = spin.hmr(args.smpl_mean_params_path) # "smpl_mean_params.npz" model_hmr.to(device) checkpoint = torch.load( args.spin_model_path, # "spin_model_smplx_eft_18.pt" map_location="cpu" ) smpl = spin.SMPLX( args.smpl_model_dir, # "models/smplx" batch_size=1, create_transl=False, use_pca=False, flat_hand_mean=args.fist is not None, ) smpl.to(device) selector = get_selector() use_contacts = args.use_contacts use_msc = args.use_msc bone_to_params = np.load(args.bone_parametrization_path, allow_pickle=True).item() foot_inds = np.load(args.foot_inds_path, allow_pickle=True).item() left_foot_inds = foot_inds["left_foot_inds"] right_foot_inds = foot_inds["right_foot_inds"] if use_contacts: model_type = args.smpl_type sc_module = selfcontact.SelfContact( essentials_folder=args.essentials_dir, # "smplify-xmc-essentials" geothres=0.3, euclthres=0.02, test_segments=True, compute_hd=True, model_type=model_type, device=device, ) sc_module.to(device) sc_crit = selfcontact.losses.SelfContactLoss( contact_module=sc_module, inside_loss_weight=0.5, outside_loss_weight=0.0, contact_loss_weight=0.5, align_faces=True, use_hd=True, 
test_segments=True, device=device, model_type=model_type, ) sc_crit.to(device) msc_crit = losses.MimickedSelfContactLoss(geodesics_mask=sc_module.geomask) msc_crit.to(device) else: sc_module = None sc_crit = None msc_crit = None loss_mse = losses.MSE([1, 10, 13]) # Neck + Right Upper Leg + Left Upper Leg ignore = ( (1, 2), # Neck + Right Shoulder (1, 5), # Neck + Left Shoulder (9, 10), # Hips + Right Upper Leg (9, 13), # Hips + Left Upper Leg ) loss_parallel = losses.Parallel( skeleton=pose_estimation.SKELETON, ignore=ignore, ) c_mse = args.c_mse c_new_mse = args.c_par c_beta = 1e-3 if c_mse > 0: assert c_new_mse == 0 elif c_mse == 0: assert c_new_mse > 0 root_path = Path(args.save_path) root_path.mkdir(exist_ok=True, parents=True) path_to_imgs = Path(args.img_path) if path_to_imgs.is_dir(): path_to_imgs = path_to_imgs.iterdir() else: path_to_imgs = [path_to_imgs] for img_path in path_to_imgs: if not any( img_path.name.lower().endswith(ext) for ext in [".jpg", ".png", ".jpeg"] ): continue img_name = img_path.stem # use 2d keypoints detection ( img_original, predicted_keypoints_2d, _, _, ) = pose_estimation.infer_single_image( model_pose, img_path, input_img_size=pose_estimation.IMG_SIZE, return_kps=True, ) save_path = root_path / img_name save_path.mkdir(exist_ok=True, parents=True) # if (save_path / "us_orig.png").is_file(): # return summary_writer = SummaryWriter(log_dir=save_path / f"runDoknc2_{c_new_mse}") img_original = cv2.cvtColor(img_original, cv2.COLOR_BGR2RGB) img_size_original = img_original.shape[:2] keypoints_2d, shift, scale, ax2 = normalize_keypoints_to_spin( predicted_keypoints_2d, img_size_original ) keypoints_2d = torch.from_numpy(keypoints_2d) keypoints_2d = keypoints_2d.to(device) ( predicted_contact_heatmap, predicted_contact_heatmap_raw, very_hm_raw, ) = get_contact_heatmap(model_contact, img_path) predicted_contact_heatmap_raw = Image.fromarray( predicted_contact_heatmap_raw ).resize(img_size_original[::-1]) predicted_contact_heatmap_raw = 
cv2.resize(very_hm_raw, img_size_original[::-1]) if c_new_mse == 0: predicted_contact_heatmap_raw = None y_data_conts = get_vertices_in_heatmap(predicted_contact_heatmap) model_hmr.load_state_dict(checkpoint["model"], strict=True) model_hmr.train() freeze_layers(model_hmr) _, input_img = spin.process_image(img_path, input_res=spin.constants.IMG_RES) input_img = input_img.to(device) spin_step( model_hmr, smpl, selector, input_img, img_original, predicted_keypoints_2d, predicted_contact_heatmap_raw, loss_parallel, shift, scale, ax2, summary_writer, save_path, ) optimizer = optim.Adam( filter(lambda p: p.requires_grad, model_hmr.parameters()), lr=1e-6, ) vertices, keypoints_3d_pred, contact = eft_step( model_hmr, smpl, selector, input_img, keypoints_2d, optimizer, args, loss_mse, loss_parallel, c_beta, sc_module, y_data_conts, bone_to_params, img_original, predicted_keypoints_2d, predicted_contact_heatmap_raw, shift, scale, ax2, summary_writer, save_path, ) if args.use_natural: get_natural( keypoints_2d, vertices, right_foot_inds, left_foot_inds, loss_parallel, smpl, ) if args.use_cos: cos_r = get_cos(keypoints_3d_pred, args.use_angle_transf, loss_parallel) np.save(save_path / "cos_hist", cos_r.cpu().numpy()) rotmat_pred = dc_step( model_hmr, smpl, selector, input_img, keypoints_2d, optimizer, args, loss_mse, loss_parallel, c_mse, c_new_mse, c_beta, sc_crit, msc_crit, contact, use_contacts, use_msc, img_original, predicted_keypoints_2d, predicted_contact_heatmap_raw, shift, scale, ax2, summary_writer, save_path, ) us_step( model_hmr, smpl, selector, input_img, rotmat_pred, keypoints_2d, args, loss_mse, loss_parallel, c_mse, c_new_mse, sc_crit, msc_crit, contact, use_contacts, use_msc, img_original, keypoints_3d_pred, summary_writer, save_path, ) if __name__ == "__main__": main() ================================================ FILE: src/pose_estimation.py ================================================ import math import cv2 import numpy as np IMG_SIZE = (288, 384) 
# ImageNet normalization constants used by the HRNet pose model.
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])

# Keypoint names, in the order produced by the 2D pose model.
KPS = (
    "Head",
    "Neck",
    "Right Shoulder",
    "Right Arm",
    "Right Hand",
    "Left Shoulder",
    "Left Arm",
    "Left Hand",
    "Spine",
    "Hips",
    "Right Upper Leg",
    "Right Leg",
    "Right Foot",
    "Left Upper Leg",
    "Left Leg",
    "Left Foot",
    "Left Toe",
    "Right Toe",
)

# Bone list as (parent, child) index pairs into KPS.
SKELETON = (
    (0, 1),
    (1, 8),
    (8, 9),
    (9, 10),
    (9, 13),
    (10, 11),
    (11, 12),
    (13, 14),
    (14, 15),
    (1, 2),
    (2, 3),
    (3, 4),
    (1, 5),
    (5, 6),
    (6, 7),
    (15, 16),
    (12, 17),
)

# Maps OpenPose joint index -> KPS index above (-1 = no correspondence).
OPENPOSE_TO_GESTURE = (
    0,  # 0 Head
    1,  # 1 Neck
    2,  # 2 Right Shoulder
    3,  # 3 Right Arm
    4,  # 4 Right Hand
    5,  # 5 Left Shoulder
    6,  # 6 Left Arm
    7,  # 7 Left Hand
    9,  # 8 Hips
    10,  # 9 Right Upper Leg
    11,  # 10 Right Leg
    12,  # 11 Right Foot
    13,  # 12 Left Upper Leg
    14,  # 13 Left Leg
    15,  # 14 Left Foot
    -1,  # 15
    -1,  # 16
    -1,  # 17
    -1,  # 18
    16,  # 19 Left Toe
    -1,  # 20
    -1,  # 21
    17,  # 22 Right Toe
    -1,  # 23
    -1,  # 24
)


def transform(img):
    """Normalize an HWC uint8 image to CHW float32 with ImageNet statistics."""
    img = img.astype("float32") / 255
    img = (img - MEAN) / STD
    return np.transpose(img, axes=(2, 0, 1))


def get_affine_transform(
    center,
    scale,
    rot,
    output_size,
    shift=np.array([0, 0], dtype=np.float32),  # NOTE(review): mutable default — never mutated here, but fragile
    inv=0,
    pixel_std=200,
):
    """Build the 2x3 affine matrix mapping a scaled/rotated crop around
    ``center`` to ``output_size`` (or the inverse mapping when ``inv``).

    ``scale`` is in units of ``pixel_std`` pixels, following the HRNet
    convention. The matrix is derived from three point correspondences:
    center, a point above the center, and a perpendicular third point.
    """
    if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
        scale = np.array([scale, scale])

    scale_tmp = scale * pixel_std
    src_w = scale_tmp[0]
    dst_w = output_size[0]
    dst_h = output_size[1]

    rot_rad = np.pi * rot / 180
    src_dir = get_dir([0, src_w * -0.5], rot_rad)
    dst_dir = np.array([0, dst_w * -0.5], np.float32)

    src = np.zeros((3, 2), dtype=np.float32)
    dst = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = center + scale_tmp * shift
    src[1, :] = center + src_dir + scale_tmp * shift
    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir

    # Third point completes the triangle (perpendicular to the first edge).
    src[2:, :] = get_3rd_point(src[0, :], src[1, :])
    dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])

    if inv:
        trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
    else:
        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))

    return trans


def get_3rd_point(a, b):
    """Return the point completing a right triangle with segment a-b
    (b plus the segment rotated by 90 degrees)."""
    direct = a - b
    return b + np.array([-direct[1], direct[0]], dtype=np.float32)


def get_dir(src_point, rot_rad):
    """Rotate a 2D point by ``rot_rad`` radians around the origin."""
    sn, cs = np.sin(rot_rad), np.cos(rot_rad)
    src_result = [0, 0]
    src_result[0] = src_point[0] * cs - src_point[1] * sn
    src_result[1] = src_point[0] * sn + src_point[1] * cs
    return src_result


def process_image(path, input_img_size, pixel_std=200):
    """Load an image and produce the pose model's network input.

    Returns (normalized CHW crop, original BGR image, crop center, crop scale).
    The crop keeps the full image, padded to ``input_img_size``'s aspect
    ratio and enlarged by 1.25x.
    """
    data_numpy = cv2.imread(path, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
    # BUG HERE. Must be uncommented
    # (upstream author's note: the model expects RGB but receives BGR;
    #  left as-is because downstream results were produced this way)
    # data_numpy = cv2.cvtColor(data_numpy, cv2.COLOR_BGR2RGB)
    h, w = data_numpy.shape[:2]
    c = np.array([w / 2, h / 2], dtype=np.float32)

    # Expand the smaller dimension so the crop matches the target aspect ratio.
    aspect_ratio = input_img_size[0] / input_img_size[1]
    if w > aspect_ratio * h:
        h = w * 1.0 / aspect_ratio
    elif w < aspect_ratio * h:
        w = h * aspect_ratio

    s = np.array([w / pixel_std, h / pixel_std], dtype=np.float32) * 1.25
    r = 0

    trans = get_affine_transform(c, s, r, input_img_size, pixel_std=pixel_std)
    # NOTE: ``input`` shadows the builtin; kept for token fidelity.
    input = cv2.warpAffine(data_numpy, trans, input_img_size, flags=cv2.INTER_LINEAR)
    input = transform(input)

    return input, data_numpy, c, s


def get_final_preds(batch_heatmaps, center, scale, post_process=False):
    """Decode heatmaps to keypoint coordinates in the original image frame.

    With ``post_process``, each peak is nudged 0.25px toward the gradient of
    the heatmap for sub-pixel accuracy (standard HRNet decoding).
    Returns (predictions, max heatmap values).
    """
    coords, maxvals = get_max_preds(batch_heatmaps)

    heatmap_height = batch_heatmaps.shape[2]
    heatmap_width = batch_heatmaps.shape[3]

    # post-processing
    if post_process:
        for n in range(coords.shape[0]):
            for p in range(coords.shape[1]):
                hm = batch_heatmaps[n][p]
                px = int(math.floor(coords[n][p][0] + 0.5))
                py = int(math.floor(coords[n][p][1] + 0.5))
                if 1 < px < heatmap_width - 1 and 1 < py < heatmap_height - 1:
                    diff = np.array(
                        [
                            hm[py][px + 1] - hm[py][px - 1],
                            hm[py + 1][px] - hm[py - 1][px],
                        ]
                    )
                    coords[n][p] += np.sign(diff) * 0.25

    preds = coords.copy()

    # Transform back
    for i in range(coords.shape[0]):
        preds[i] = transform_preds(
            coords[i], center[i], scale[i], [heatmap_width, heatmap_height]
        )

    return preds, maxvals


def transform_preds(coords, center, scale, output_size):
    """Map heatmap-space coordinates back to the original image via the
    inverse crop transform."""
    target_coords = np.zeros(coords.shape)
    trans = get_affine_transform(center, scale, 0, output_size, inv=1)
    for p in range(coords.shape[0]):
        target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans)
    return target_coords


def affine_transform(pt, t):
    """Apply a 2x3 affine matrix ``t`` to a 2D point."""
    new_pt = np.array([pt[0], pt[1], 1.0]).T
    new_pt = np.dot(t, new_pt)
    return new_pt[:2]


def get_max_preds(batch_heatmaps):
    """
    get predictions from score maps
    heatmaps: numpy.ndarray([batch_size, num_joints, height, width])

    Returns (coords [B, J, 2] as (x, y), maxvals [B, J, 1]); coordinates of
    non-positive peaks are zeroed out.
    """
    assert isinstance(
        batch_heatmaps, np.ndarray
    ), "batch_heatmaps should be numpy.ndarray"
    assert batch_heatmaps.ndim == 4, "batch_images should be 4-ndim"

    batch_size = batch_heatmaps.shape[0]
    num_joints = batch_heatmaps.shape[1]
    width = batch_heatmaps.shape[3]
    heatmaps_reshaped = batch_heatmaps.reshape((batch_size, num_joints, -1))
    idx = np.argmax(heatmaps_reshaped, 2)
    maxvals = np.amax(heatmaps_reshaped, 2)

    maxvals = maxvals.reshape((batch_size, num_joints, 1))
    idx = idx.reshape((batch_size, num_joints, 1))

    # Convert flat argmax indices back to (x, y).
    preds = np.tile(idx, (1, 1, 2)).astype(np.float32)
    preds[:, :, 0] = (preds[:, :, 0]) % width
    preds[:, :, 1] = np.floor((preds[:, :, 1]) / width)

    pred_mask = np.tile(np.greater(maxvals, 0.0), (1, 1, 2))
    pred_mask = pred_mask.astype(np.float32)

    preds *= pred_mask
    return preds, maxvals


def infer_single_image(model, img_path, input_img_size=(288, 384), return_kps=True):
    """Run the ONNX pose model (via cv2.dnn) on one image.

    Returns the heatmap only when ``return_kps`` is False; otherwise returns
    (original image, keypoints, confidences, heatmap), all squeezed to a
    single sample.
    """
    img_path = str(img_path)
    pose_input, img, center, scale = process_image(
        img_path, input_img_size=input_img_size
    )
    model.setInput(pose_input[None])
    predicted_heatmap = model.forward()
    if not return_kps:
        return predicted_heatmap.squeeze(0)

    predicted_keypoints, confidence = get_final_preds(
        predicted_heatmap, center[None], scale[None], post_process=True
    )
    (predicted_keypoints, confidence, predicted_heatmap,) = (
        predicted_keypoints.squeeze(0),
        confidence.squeeze(0),
        predicted_heatmap.squeeze(0),
    )

    return img, predicted_keypoints, confidence, predicted_heatmap

================================================
FILE: src/renderer.py
================================================

import numpy as np
import pyrender
import torch
import trimesh
from torchvision.utils import make_grid


class Renderer:
    """
    Renderer used for visualizing the SMPL model
    Code adapted from https://github.com/vchoutas/smplify-x
    """

    def __init__(self, focal_length=5000, img_res=224, faces=None):
        # Offscreen renderer sized to the (square) network input resolution.
        self.renderer = pyrender.OffscreenRenderer(
            viewport_width=img_res, viewport_height=img_res, point_size=1.0
        )
        self.focal_length = focal_length
        self.camera_center = [img_res // 2, img_res // 2]
        self.faces = faces

    def visualize_tb(self, vertices, camera_translation, images):
        """Render each mesh over its input image and return a 2-column
        TensorBoard grid of (input, render) pairs."""
        vertices = vertices.cpu().numpy()
        camera_translation = camera_translation.cpu().numpy()
        images = images.cpu()
        images_np = np.transpose(images.numpy(), (0, 2, 3, 1))
        rend_imgs = []
        for i in range(vertices.shape[0]):
            rend_img = torch.from_numpy(
                np.transpose(
                    self.__call__(vertices[i], camera_translation[i], images_np[i]),
                    (2, 0, 1),
                )
            ).float()
            rend_imgs.append(images[i])
            rend_imgs.append(rend_img)
        rend_imgs = make_grid(rend_imgs, nrow=2)
        return rend_imgs

    def __call__(self, vertices, camera_translation, image):
        """Render one mesh composited over ``image``; returns an HWC float
        image where mesh pixels replace the background."""
        material = pyrender.MetallicRoughnessMaterial(
            metallicFactor=0.2, alphaMode="OPAQUE", baseColorFactor=(0.8, 0.3, 0.3, 1.0)
        )
        # NOTE: mutates the caller's array (x-flip of the translation).
        camera_translation[0] *= -1.0
        mesh = trimesh.Trimesh(vertices, self.faces)
        # Flip the mesh upright for pyrender's camera convention.
        rot = trimesh.transformations.rotation_matrix(np.radians(180), [1, 0, 0])
        mesh.apply_transform(rot)
        mesh = pyrender.Mesh.from_trimesh(mesh, material=material)
        scene = pyrender.Scene(ambient_light=(0.5, 0.5, 0.5))
        scene.add(mesh, "mesh")
        camera_pose = np.eye(4)
        camera_pose[:3, 3] = camera_translation
        camera = pyrender.IntrinsicsCamera(
            fx=self.focal_length,
            fy=self.focal_length,
            cx=self.camera_center[0],
            cy=self.camera_center[1],
        )
        scene.add(camera, pose=camera_pose)
        # Three directional lights from different sides.
        light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=1)
        light_pose = np.eye(4)
        light_pose[:3, 3] = np.array([0, -1, 1])
        scene.add(light, pose=light_pose)
        light_pose[:3, 3] = np.array([0, 1, 1])
        scene.add(light, pose=light_pose)
        light_pose[:3, 3] = np.array([1, 1, 2])
        scene.add(light, pose=light_pose)
        color, rend_depth = self.renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
        color = color.astype(np.float32) / 255.0
        # Composite: mesh where depth > 0, original image elsewhere.
        valid_mask = (rend_depth > 0)[:, :, None]
        output_img = color[:, :, :3] * valid_mask + (1 - valid_mask) * image
        return output_img


def overlay_mesh(
    verts,
    faces,
    camera_transl,
    focal_length_x,
    focal_length_y,
    camera_center,
    H,
    W,
    img,
    camera_rotation=None,
    rotaround=None,
    contactlist=None,
    color=False,
    scale=1,
):
    """Render a mesh with an orthographic camera, optionally rotated around
    a fixed vertex (``rotaround`` degrees about Y) and overlaid on ``img``.

    Contact vertices (``contactlist``) are painted red. Returns a uint8 RGB
    image of size (scale*H, scale*W).

    NOTE(review): ``focal_length_x``/``focal_length_y``/``camera_center`` are
    accepted but unused here (orthographic camera) — presumably kept for
    signature compatibility with a perspective variant.
    """
    material = pyrender.MetallicRoughnessMaterial(
        metallicFactor=0.0, alphaMode="OPAQUE", baseColorFactor=(1.0, 1.0, 0.9, 1.0)
    )
    out_mesh = trimesh.Trimesh(verts, faces, process=False)
    out_mesh_col = np.array(out_mesh.visual.vertex_colors)
    if contactlist is not None and len(contactlist) > 0:
        color = [255, 0, 0, 255]  # red for contact vertices
        out_mesh_col[contactlist] = color
        out_mesh.visual.vertex_colors = out_mesh_col

    if camera_rotation is None:
        camera_rotation = np.eye(3)
    else:
        camera_rotation = camera_rotation[0]

    # rotate mesh and stack output images
    if rotaround is None:
        out_mesh.vertices = np.matmul(verts, camera_rotation.T) + camera_transl
    else:
        base_mesh = trimesh.Trimesh(verts, faces, process=False)
        # rot_center = (base_mesh.vertices[5615] + base_mesh.vertices[5614] ) / 2
        # Rotate about vertex 4297 (a fixed body-center vertex).
        rot = trimesh.transformations.rotation_matrix(
            np.radians(rotaround), [0, 1, 0], base_mesh.vertices[4297]
        )
        base_mesh.apply_transform(rot)
        out_mesh.vertices = (
            np.matmul(base_mesh.vertices, camera_rotation.T) + camera_transl
        )
    # Push the mesh far in front of the near plane.
    out_mesh.vertices += np.array([0, 0, 50])

    # add mesh to scene
    mesh = pyrender.Mesh.from_trimesh(
        out_mesh,
        material=material,
        smooth=False,
    )
    if img is not None:
        # Transparent background so the render can be composited over img.
        scene = pyrender.Scene(
            bg_color=[0.0, 0.0, 0.0, 0.0],
            ambient_light=(0.3, 0.3, 0.3, 1.0),
        )
    else:
        scene = pyrender.Scene(
            bg_color=[1.0, 1.0, 1.0, 1.0],
            ambient_light=(0.3, 0.3, 0.3, 1.0),
        )
    scene.add(mesh, "mesh")

    # create and add camera (flip Y and Z to match image coordinates)
    camera_pose = np.eye(4)
    camera_pose[1, :] = -camera_pose[1, :]
    camera_pose[2, :] = -camera_pose[2, :]
    pyrencamera = pyrender.camera.OrthographicCamera(
        camera_transl[2],
        camera_transl[2],
        znear=1e-6,
        zfar=1000000,
    )
    scene.add(pyrencamera, pose=camera_pose)

    # create and add light
    light = pyrender.PointLight(
        color=[1.0, 1.0, 1.0],
        intensity=1,
    )
    light_pose = np.eye(4)
    # Four point lights around the mesh centroid.
    for lp in [[1, 1, -1], [-1, 1, -1], [1, -1, -1], [-1, -1, -1]]:
        light_pose[:3, 3] = out_mesh.vertices.mean(0) + np.array(lp)
        scene.add(light, pose=light_pose)

    r = pyrender.OffscreenRenderer(
        viewport_width=int(scale * W),
        viewport_height=int(scale * H),
        point_size=1.0,
    )
    color, _ = r.render(scene, flags=pyrender.RenderFlags.RGBA)
    color = color.astype(np.float32) / 255.0

    if img is not None:
        # Alpha-composite the render over the input image.
        valid_mask = (color[:, :, -1] > 0)[:, :, np.newaxis]
        output_img = color[:, :, :-1] * valid_mask + (1 - valid_mask) * img
    else:
        output_img = color

    output_img = (output_img * 255).astype(np.uint8)[..., :3]
    return output_img

================================================
FILE: src/spin/__init__.py
================================================

from .constants import JOINT_NAMES
from .hmr import hmr
from .smpl import SMPLX
from .utils import process_image

__all__ = [
    "hmr",
    "SMPLX",
    "process_image",
    "JOINT_NAMES",
]

================================================
FILE: src/spin/constants.py
================================================

FOCAL_LENGTH = 5000.0
IMG_RES = 224

# Mean and standard deviation for normalizing input image
IMG_NORM_MEAN = [0.485, 0.456, 0.406]
IMG_NORM_STD = [0.229, 0.224, 0.225]

"""
We create a superset of joints containing the OpenPose joints together with the ones that each dataset provides. We keep a superset of 24 joints such that we include all joints from every dataset. If a dataset doesn't provide annotations for a specific joint, we simply ignore it.
The joints used here are the following:
"""
JOINT_NAMES = (
    "Hips",
    "Left Upper Leg",
    "Right Upper Leg",
    "Spine",
    "Left Leg",
    "Right Leg",
    "Spine1",
    "Left Foot",
    "Right Foot",
    "Thorax",
    "Left Toe",
    "Right Toe",
    "Neck",
    "Left Shoulder",
    "Right Shoulder",
    "Head",
    "Left ForeArm",
    "Right ForeArm",
    "Left Arm",
    "Right Arm",
    "Left Hand",
    "Right Hand",
    # 25 OpenPose joints (in the order provided by OpenPose)
    # "OP Nose", "OP Neck", "OP RShoulder", "OP RElbow", "OP RWrist",
    # "OP LShoulder", "OP LElbow", "OP LWrist", "OP MidHip", "OP RHip",
    # "OP RKnee", "OP RAnkle", "OP LHip", "OP LKnee", "OP LAnkle",
    # "OP REye", "OP LEye", "OP REar", "OP LEar", "OP LBigToe",
    # "OP LSmallToe", "OP LHeel", "OP RBigToe", "OP RSmallToe", "OP RHeel",
    ## 24 Ground Truth joints (superset of joints from different datasets)
    # "Right Ankle", "Right Knee", "Right Hip", "Left Hip", "Left Knee",
    # "Left Ankle", "Right Wrist", "Right Elbow", "Right Shoulder",
    # "Left Shoulder", "Left Elbow", "Left Wrist", "Neck (LSP)",
    # "Top of Head (LSP)", "Pelvis (MPII)", "Thorax (MPII)", "Spine (H36M)",
    # "Jaw (H36M)", "Head (H36M)", "Nose", "Left Eye", "Right Eye",
    # "Left Ear", "Right Ear",
    # "OP MidHip", "Spine1", "Spine2", "Spine3", "OP Neck", "Head",
)

# Dict containing the joints in numerical order
JOINT_IDS = {JOINT_NAMES[i]: i for i in range(len(JOINT_NAMES))}

# Map joints to SMPL joints
JOINT_MAP = {
    "Hips": 0,
    "Left Upper Leg": 1,
    "Right Upper Leg": 2,
    "Spine": 3,
    "Left Leg": 4,
    "Right Leg": 5,
    "Spine1": 6,
    "Left Foot": 7,
    "Right Foot": 8,
    "Thorax": 9,
    "Left Toe": 10,
    "Right Toe": 11,
    "Neck": 12,
    "Left Shoulder": 13,
    "Right Shoulder": 14,
    "Head": 15,
    "Left ForeArm": 16,
    "Right ForeArm": 17,
    "Left Arm": 18,
    "Right Arm": 19,
    "Left Hand": 20,
    "Right Hand": 21,
    # "OP Nose": 24, "OP Neck": 12, "OP RShoulder": 17, "OP RElbow": 19,
    # "OP RWrist": 21, "OP LShoulder": 16, "OP LElbow": 18, "OP LWrist": 20,
    # "OP MidHip": 0, "OP RHip": 2, "OP RKnee": 5, "OP RAnkle": 8,
    # "OP LHip": 1, "OP LKnee": 4, "OP LAnkle": 7, "OP REye": 25,
    # "OP LEye": 26, "OP REar": 27, "OP LEar": 28, "OP LBigToe": 29,
    # "OP LSmallToe": 30, "OP LHeel": 31, "OP RBigToe": 32,
    # "OP RSmallToe": 33, "OP RHeel": 34,
    # "Right Ankle": 8, "Right Knee": 5, "Right Hip": 45, "Left Hip": 46,
    # "Left Knee": 4, "Left Ankle": 7, "Right Wrist": 21, "Right Elbow": 19,
    # "Right Shoulder": 17, "Left Shoulder": 16, "Left Elbow": 18,
    # "Left Wrist": 20, "Neck (LSP)": 47, "Top of Head (LSP)": 15,  # 48,
    # "Pelvis (MPII)": 49, "Thorax (MPII)": 50, "Spine (H36M)": 51,
    # "Jaw (H36M)": 52, "Head (H36M)": 15,  # 53,
    # "Nose": 24, "Left Eye": 26, "Right Eye": 25, "Left Ear": 28,
    # "Right Ear": 27, "Spine1": 3, "Spine2": 6, "Spine3": 9, "Head": 15,
}

# Joint selectors
# Indices to get the 14 LSP joints from the 17 H36M joints
H36M_TO_J17 = [6, 5, 4, 1, 2, 3, 16, 15, 14, 11, 12, 13, 8, 10, 0, 7, 9]
H36M_TO_J14 = H36M_TO_J17[:14]
# Indices to get the 14 LSP joints from the ground truth joints
J24_TO_J17 = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18, 14, 16, 17]
J24_TO_J14 = J24_TO_J17[:14]

# Permutation of SMPL pose parameters when flipping the shape
SMPL_JOINTS_FLIP_PERM = [
    0,
    2,
    1,
    3,
    5,
    4,
    6,
    8,
    7,
    9,
    11,
    10,
    12,
    14,
    13,
    15,
    17,
    16,
    19,
    18,
    21,
    20,
    23,
    22,
]
# Expand the joint permutation to the per-joint 3-vector pose parameters.
SMPL_POSE_FLIP_PERM = []
for i in SMPL_JOINTS_FLIP_PERM:
    SMPL_POSE_FLIP_PERM.append(3 * i)
    SMPL_POSE_FLIP_PERM.append(3 * i + 1)
    SMPL_POSE_FLIP_PERM.append(3 * i + 2)
# Permutation indices for the 24 ground truth joints
J24_FLIP_PERM = [
    5,
    4,
    3,
    2,
    1,
    0,
    11,
    10,
    9,
    8,
    7,
    6,
    12,
    13,
    14,
    15,
    16,
    17,
    18,
    19,
    21,
    20,
    23,
    22,
]
# Permutation indices for the full set of 49 joints
J49_FLIP_PERM = [
    0,
    1,
    5,
    6,
    7,
    2,
    3,
    4,
    8,
    12,
    13,
    14,
    9,
    10,
    11,
    16,
    15,
    18,
    17,
    22,
    23,
    24,
    19,
    20,
    21,
] + [25 + i for i in J24_FLIP_PERM]

================================================
FILE: src/spin/hmr.py
================================================

import math

import numpy as np
import torch
import torch.nn as nn
import torchvision.models.resnet as resnet


def rot6d_to_rotmat(x):
    """Convert 6D rotation representation to 3x3 rotation matrix.
    Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
    Input: (B,6) Batch of 6-D rotation representations
    Output: (B,3,3) Batch of corresponding rotation matrices
    """
    x = x.view(-1, 3, 2)
    a1 = x[:, :, 0]
    a2 = x[:, :, 1]
    # Gram-Schmidt: b1 = normalized a1, b2 = a2 orthogonalized against b1.
    b1 = nn.functional.normalize(a1)
    b2 = nn.functional.normalize(
        a2 - torch.einsum("bi,bi->b", b1, a2).unsqueeze(-1) * b1
    )
    b3 = torch.cross(b1, b2)
    return torch.stack((b1, b2, b3), dim=-1)


class Bottleneck(nn.Module):
    """Redefinition of Bottleneck residual block
    Adapted from the official PyTorch implementation
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        # Project the residual when shape/stride changes.
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class HMR(nn.Module):
    """SMPL Iterative Regressor with ResNet50 backbone"""

    def __init__(self, block, layers, smpl_mean_params):
        self.inplanes = 64
        super(HMR, self).__init__()
        self.n_shape = 10  # SMPL shape (betas) dimension
        self.n_cam = 3  # weak-perspective camera (s, tx, ty)
        self.n_joints = 24
        npose = self.n_joints * 6  # 6D rotation per joint
        # ResNet-50 stem + stages.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        # Iterative regression head: features + current estimate -> residual.
        self.fc1 = nn.Linear(512 * block.expansion + npose + self.n_shape + self.n_cam, 1024)
        self.drop1 = nn.Dropout()
        self.fc2 = nn.Linear(1024, 1024)
        self.drop2 = nn.Dropout()
        self.decpose = nn.Linear(1024, npose)
        self.decshape = nn.Linear(1024, self.n_shape)
        self.deccam = nn.Linear(1024, self.n_cam)
        # Small init so the first iteration stays near the mean parameters.
        nn.init.xavier_uniform_(self.decpose.weight, gain=0.01)
        nn.init.xavier_uniform_(self.decshape.weight, gain=0.01)
        nn.init.xavier_uniform_(self.deccam.weight, gain=0.01)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.0 / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        # Mean SMPL parameters used as the initial estimate.
        mean_params = np.load(smpl_mean_params)
        init_pose = torch.from_numpy(mean_params["pose"][:]).unsqueeze(0)
        init_shape = torch.from_numpy(
            mean_params["shape"][:].astype("float32")
        ).unsqueeze(0)
        init_cam = torch.from_numpy(mean_params["cam"]).unsqueeze(0)
        self.register_buffer("init_pose", init_pose)
        self.register_buffer("init_shape", init_shape)
        self.register_buffer("init_cam", init_cam)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one ResNet stage of ``blocks`` bottleneck blocks."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x, init_pose=None, init_shape=None, init_cam=None, n_iter=3):
        """Regress (rotmats [B,24,3,3], betas [B,10], cam [B,3]) from an image
        batch, refining the estimate ``n_iter`` times from the mean params."""
        batch_size = x.shape[0]

        if init_pose is None:
            init_pose = self.init_pose.expand(batch_size, -1)
        if init_shape is None:
            init_shape = self.init_shape.expand(batch_size, -1)
        if init_cam is None:
            init_cam = self.init_cam.expand(batch_size, -1)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)

        xf = self.avgpool(x4)
        xf = xf.view(xf.size(0), -1)

        pred_pose = init_pose
        pred_shape = init_shape
        pred_cam = init_cam
        # Iterative error feedback: each pass predicts a residual update.
        for _ in range(n_iter):
            xc = torch.cat([xf, pred_pose, pred_shape, pred_cam], 1)
            xc = self.fc1(xc)
            xc = self.drop1(xc)
            xc = self.fc2(xc)
            xc = self.drop2(xc)
            pred_pose = self.decpose(xc) + pred_pose
            pred_shape = self.decshape(xc) + pred_shape
            pred_cam = self.deccam(xc) + pred_cam

        pred_rotmat = rot6d_to_rotmat(pred_pose).view(batch_size, self.n_joints, 3, 3)

        return pred_rotmat, pred_shape, pred_cam


def hmr(smpl_mean_params, pretrained=True, **kwargs):
    """Constructs an HMR model with ResNet50 backbone.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
            (non-strict load: only the backbone weights match).
    """
    model = HMR(Bottleneck, [3, 4, 6, 3], smpl_mean_params, **kwargs)
    if pretrained:
        resnet_imagenet = resnet.resnet50(pretrained=True)
        model.load_state_dict(resnet_imagenet.state_dict(), strict=False)
    return model

================================================
FILE: src/spin/smpl.py
================================================

import numpy as np
import torch
from smplx import SMPL as _SMPL
from smplx import SMPLX as _SMPLX
from smplx.body_models import SMPLOutput, SMPLXOutput
from smplx.lbs import vertices2joints

from .constants import JOINT_MAP, JOINT_NAMES

# Hand joints: reorder SMPL-X hand joints (wrist, thumb..pinky) into the
# Panoptic hand keypoint order.
SMPLX_HAND_TO_PANOPTIC = [
    0,
    13,
    14,
    15,
    16,
    1,
    2,
    3,
    17,
    4,
    5,
    6,
    18,
    10,
    11,
    12,
    19,
    7,
    8,
    9,
    20,
]  # Wrist Thumb to Pinky


class SMPL(_SMPL):
    """Extension of the official SMPL implementation to support more joints"""

    JOINTS = (
        "Hips",
        "Left Upper Leg",
        "Right Upper Leg",
        "Spine",
        "Left Leg",
        "Right Leg",
        "Spine1",
        "Left Foot",
        "Right Foot",
        "Thorax",
        "Left Toe",
        "Right Toe",
        "Neck",
        "Left Shoulder",
        "Right Shoulder",
        "Head",
        "Left ForeArm",
        "Right ForeArm",
        "Left Arm",
        "Right Arm",
        "Left Hand",
        "Right Hand",
        "Left Finger",
        "Right Finger",
    )
    SKELETON = (
        (0, 1),
        (0, 2),
        (0, 3),
        (1, 4),
        (2, 5),
        (3, 6),
        (4, 7),
        (5, 8),
        (6, 9),
        (7, 10),
        (8, 11),
        (9, 12),
        (12, 13),
        (12, 14),
        (12, 15),
        (13, 16),
        (14, 17),
        (16, 18),
        (17, 19),
        (18, 20),
        (19, 21),
        (20, 22),
        (21, 23),
    )

    def __init__(self, *args, **kwargs):
        super(SMPL, self).__init__(*args, **kwargs)
        joints = [JOINT_MAP[i] for i in JOINT_NAMES]
        # Extra joint regressor (9 additional joints) loaded from disk.
        joint_regressor_extra = kwargs["joint_regressor_extra_path"]
        J_regressor_extra = np.load(joint_regressor_extra)
        self.register_buffer(
            "J_regressor_extra", torch.tensor(J_regressor_extra, dtype=torch.float32)
        )
        self.joint_map = torch.tensor(joints, dtype=torch.long)

    def forward(self, *args, **kwargs):
        """Run SMPL, append the extra regressed joints, and reorder joints
        to the JOINT_NAMES convention."""
        kwargs["get_skin"] = True
        smpl_output = super(SMPL, self).forward(*args, **kwargs)
        extra_joints = vertices2joints(
            self.J_regressor_extra, smpl_output.vertices
        )  # Additional 9 joints #Check doc/J_regressor_extra.png
        joints = torch.cat(
            [smpl_output.joints, extra_joints], dim=1
        )  # [N, 24 + 21, 3] + [N, 9, 3]
        joints = joints[:, self.joint_map, :]
        output = SMPLOutput(
            vertices=smpl_output.vertices,
            global_orient=smpl_output.global_orient,
            body_pose=smpl_output.body_pose,
            joints=joints,
            betas=smpl_output.betas,
            full_pose=smpl_output.full_pose,
        )
        return output


class SMPLX(_SMPLX):
    """Extension of the official SMPL implementation to support more joints"""

    JOINTS = (
        "Hips",
        "Left Upper Leg",
        "Right Upper Leg",
        "Spine",
        "Left Leg",
        "Right Leg",
        "Spine1",
        "Left Foot",
        "Right Foot",
        "Thorax",
        "Left Toe",
        "Right Toe",
        "Neck",
        "Left Shoulder",
        "Right Shoulder",
        "Head",
        "Left ForeArm",
        "Right ForeArm",
        "Left Arm",
        "Right Arm",
        "Left Hand",
        "Right Hand",
    )
    SKELETON = (
        (0, 1),
        (0, 2),
        (0, 3),
        (1, 4),
        (2, 5),
        (3, 6),
        (4, 7),
        (5, 8),
        (6, 9),
        (7, 10),
        (8, 11),
        (9, 12),
        (12, 13),
        (12, 14),
        (12, 15),
        (13, 16),
        (14, 17),
        (16, 18),
        (17, 19),
        (18, 20),
        (19, 21),
    )

    def __init__(self, *args, **kwargs):
        kwargs["ext"] = "pkl"  # We have pkl file
        super(SMPLX, self).__init__(*args, **kwargs)
        joints = [JOINT_MAP[i] for i in JOINT_NAMES]
        self.joint_map = torch.tensor(joints, dtype=torch.long)

    def forward(self, *args, **kwargs):
        """Run SMPL-X and convert its joints to the SMPL ordering, plus
        Panoptic-ordered left/right hand joints in the output's
        ``left_hand_pose``/``right_hand_pose`` fields."""
        kwargs["get_skin"] = True
        # if pose parameter is for SMPL with 21 joints (ignoring root)
        # NOTE(review): bare except silently ignores missing/odd body_pose —
        # intentional best-effort trimming, but it would also hide real errors.
        try:
            if kwargs["body_pose"].shape[1] == 69:
                kwargs["body_pose"] = kwargs["body_pose"][
                    :, : -2 * 3
                ]  # Ignore the last two joints (which are on the palm. Not used)
            if kwargs["body_pose"].shape[1] == 23:
                kwargs["body_pose"] = kwargs["body_pose"][
                    :, :-2
                ]  # Ignore the last two joints (which are on the palm. Not used)
        except:
            pass

        smpl_output = super(SMPLX, self).forward(*args, **kwargs)

        # SMPL-X Joint order: https://docs.google.com/spreadsheets/d/1_1dLdaX-sbMkCKr_JzJW_RZCpwBwd7rcKkWT_VgAQ_0/edit#gid=0
        smplx_to_smpl = (
            list(range(0, 22)) + [28, 43] + list(range(55, 76))
        )  # 28 left middle finger , 43: right middle finger 1
        smpl_joints = smpl_output.joints[
            :, smplx_to_smpl, :
        ]  # Convert SMPL-X to SMPL 127 ->45
        joints = smpl_joints
        joints = joints[:, self.joint_map, :]

        smplx_lhand = (
            [20] + list(range(25, 40)) + list(range(66, 71))
        )  # 20 for left wrist. 20 finger joints
        lhand_joints = smpl_output.joints[:, smplx_lhand, :]  # (N,21,3)
        lhand_joints = lhand_joints[
            :, SMPLX_HAND_TO_PANOPTIC, :
        ]  # Convert SMPL-X hand order to Panoptic hand order

        smplx_rhand = (
            [21] + list(range(40, 55)) + list(range(71, 76))
        )  # 21 for right wrist. 20 finger joints
        rhand_joints = smpl_output.joints[:, smplx_rhand, :]  # (N,21,3)
        rhand_joints = rhand_joints[
            :, SMPLX_HAND_TO_PANOPTIC, :
        ]  # Convert SMPL-X hand order to Panoptic hand order

        output = SMPLXOutput(
            vertices=smpl_output.vertices,
            global_orient=smpl_output.global_orient,
            body_pose=smpl_output.body_pose,
            joints=joints,
            right_hand_pose=rhand_joints,  # N,21,3
            left_hand_pose=lhand_joints,  # N,21,3
            betas=smpl_output.betas,
            full_pose=smpl_output.full_pose,
            A=smpl_output.A,
        )
        return output


# NOTE(review): the SMPL-X joint-name reference string below is truncated by
# the extraction (it continues past index 41 in the repository).
""" 0 pelvis', 1 left_hip', 2 right_hip', 3 spine1', 4 left_knee', 5 right_knee', 6 spine2', 7 left_ankle', 8 right_ankle', 9 spine3', 10 left_foot', 11 right_foot', 12 neck', 13 left_collar', 14 right_collar', 15 head', 16 left_shoulder', 17 right_shoulder', 18 left_elbow', 19 right_elbow', 20 left_wrist', 21 right_wrist', 22 jaw', 23 left_eye_smplhf', 24 right_eye_smplhf', 25 left_index1', 26 left_index2', 27 left_index3', 28 left_middle1', 29 left_middle2', 30 left_middle3', 31 left_pinky1', 32 left_pinky2', 33 left_pinky3', 34 left_ring1', 35 left_ring2', 36 left_ring3', 37 left_thumb1', 38 left_thumb2', 39 left_thumb3', 40 right_index1', 41
right_index2', 42 right_index3', 43 right_middle1', 44 right_middle2', 45 right_middle3', 46 right_pinky1', 47 right_pinky2', 48 right_pinky3', 49 right_ring1', 50 right_ring2', 51 right_ring3', 52 right_thumb1', 53 right_thumb2', 54 right_thumb3', 55 nose', 56 right_eye', 57 left_eye', 58 right_ear', 59 left_ear', 60 left_big_toe', 61 left_small_toe', 62 left_heel', 63 right_big_toe', 64 right_small_toe', 65 right_heel', 66 left_thumb', 67 left_index', 68 left_middle', 69 left_ring', 70 left_pinky', 71 right_thumb', 72 right_index', 73 right_middle', 74 right_ring', 75 right_pinky', 76 right_eye_brow1', 77 right_eye_brow2', 78 right_eye_brow3', 79 right_eye_brow4', 80 right_eye_brow5', 81 left_eye_brow5', 82 left_eye_brow4', 83 left_eye_brow3', 84 left_eye_brow2', 85 left_eye_brow1', 86 nose1', 87 nose2', 88 nose3', 89 nose4', 90 right_nose_2', 91 right_nose_1', 92 nose_middle', 93 left_nose_1', 94 left_nose_2', 95 right_eye1', 96 right_eye2', 97 right_eye3', 98 right_eye4', 99 right_eye5', 100 right_eye6', 101 left_eye4', 102 left_eye3', 103 left_eye2', 104 left_eye1', 105 left_eye6', 106 left_eye5', 107 right_mouth_1', 108 right_mouth_2', 109 right_mouth_3', 110 mouth_top', 111 left_mouth_3', 112 left_mouth_2', 113 left_mouth_1', 114 left_mouth_5', # 59 in OpenPose output 115 left_mouth_4', # 58 in OpenPose output 116 mouth_bottom', 117 right_mouth_4', 118 right_mouth_5', 119 right_lip_1', 120 right_lip_2', 121 lip_top', 122 left_lip_2', 123 left_lip_1', 124 left_lip_3', 125 lip_bottom', 126 right_lip_3', 127 right_contour_1', 128 right_contour_2', 129 right_contour_3', 130 right_contour_4', 131 right_contour_5', 132 right_contour_6', 133 right_contour_7', 134 right_contour_8', 135 contour_middle', 136 left_contour_8', 137 left_contour_7', 138 left_contour_6', 139 left_contour_5', 140 left_contour_4', 141 left_contour_3', 142 left_contour_2', 143 left_contour_1' """ # SMPL Joints: """ 0 pelvis', 1 left_hip', 2 right_hip', 3 spine1', 4 left_knee', 5 right_knee', 
6 spine2', 7 left_ankle', 8 right_ankle', 9 spine3', 10 left_foot', 11 right_foot', 12 neck', 13 left_collar', 14 right_collar', 15 head', 16 left_shoulder', 17 right_shoulder', 18 left_elbow', 19 right_elbow', 20 left_wrist', 21 right_wrist', 22 23 24 nose', 25 right_eye', 26 left_eye', 27 right_ear', 28 left_ear', 29 left_big_toe', 30 left_small_toe', 31 left_heel', 32 right_big_toe', 33 right_small_toe', 34 right_heel', 35 left_thumb', 36 left_index', 37 left_middle', 38 left_ring', 39 left_pinky', 40 right_thumb', 41 right_index', 42 right_middle', 43 right_ring', 44 right_pinky', """ ================================================ FILE: src/spin/utils.py ================================================ import json import cv2 import numpy as np import torch from skimage.transform import resize, rotate from torchvision.transforms import Normalize from .constants import IMG_NORM_MEAN, IMG_NORM_STD, IMG_RES def get_transform(center, scale, res, rot=0): """Generate transformation matrix.""" h = 200 * scale t = np.zeros((3, 3)) t[0, 0] = float(res[1]) / h t[1, 1] = float(res[0]) / h t[0, 2] = res[1] * (-float(center[0]) / h + 0.5) t[1, 2] = res[0] * (-float(center[1]) / h + 0.5) t[2, 2] = 1 if not rot == 0: rot = -rot # To match direction of rotation from cropping rot_mat = np.zeros((3, 3)) rot_rad = rot * np.pi / 180 sn, cs = np.sin(rot_rad), np.cos(rot_rad) rot_mat[0, :2] = [cs, -sn] rot_mat[1, :2] = [sn, cs] rot_mat[2, 2] = 1 # Need to rotate around center t_mat = np.eye(3) t_mat[0, 2] = -res[1] / 2 t_mat[1, 2] = -res[0] / 2 t_inv = t_mat.copy() t_inv[:2, 2] *= -1 t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t))) return t def transform(pt, center, scale, res, invert=0, rot=0): """Transform pixel location to different reference.""" t = get_transform(center, scale, res, rot=rot) if invert: t = np.linalg.inv(t) new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.0]).T new_pt = np.dot(t, new_pt) return new_pt[:2].astype(int) + 1 def crop(img, center, scale, res, rot=0): 
"""Crop image according to the supplied bounding box.""" # Upper left point ul = np.array(transform([1, 1], center, scale, res, invert=1)) - 1 # Bottom right point br = np.array(transform([res[0] + 1, res[1] + 1], center, scale, res, invert=1)) - 1 # Padding so that when rotated proper amount of context is included pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2) if not rot == 0: ul -= pad br += pad new_shape = [br[1] - ul[1], br[0] - ul[0]] if len(img.shape) > 2: new_shape += [img.shape[2]] new_img = np.zeros(new_shape) # Range to fill new array new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0] new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1] # Range to sample from original image old_x = max(0, ul[0]), min(len(img[0]), br[0]) old_y = max(0, ul[1]), min(len(img), br[1]) new_img[new_y[0] : new_y[1], new_x[0] : new_x[1]] = img[ old_y[0] : old_y[1], old_x[0] : old_x[1] ] if not rot == 0: # Remove padding new_img = rotate(new_img, rot) new_img = new_img[pad:-pad, pad:-pad] new_img = resize(new_img, res) return new_img def bbox_from_openpose(openpose_file, rescale=1.2, detection_thresh=0.2): """Get center and scale for bounding box from openpose detections.""" with open(openpose_file, "r") as f: keypoints = json.load(f)["people"][0]["pose_keypoints_2d"] keypoints = np.reshape(np.array(keypoints), (-1, 3)) valid = keypoints[:, -1] > detection_thresh valid_keypoints = keypoints[valid][:, :-1] center = valid_keypoints.mean(axis=0) bbox_size = (valid_keypoints.max(axis=0) - valid_keypoints.min(axis=0)).max() # adjust bounding box tightness scale = bbox_size / 200.0 scale *= rescale return center, scale def bbox_from_json(bbox_file): """Get center and scale of bounding box from bounding box annotations. The expected format is [top_left(x), top_left(y), width, height]. 
""" with open(bbox_file, "r") as f: bbox = np.array(json.load(f)["bbox"]).astype(np.float32) ul_corner = bbox[:2] center = ul_corner + 0.5 * bbox[2:] width = max(bbox[2], bbox[3]) scale = width / 200.0 # make sure the bounding box is rectangular return center, scale def process_image(img_file, bbox_file=None, openpose_file=None, input_res=IMG_RES): """Read image, do preprocessing and possibly crop it according to the bounding box. If there are bounding box annotations, use them to crop the image. If no bounding box is specified but openpose detections are available, use them to get the bounding box. """ img_file = str(img_file) normalize_img = Normalize(mean=IMG_NORM_MEAN, std=IMG_NORM_STD) img = cv2.imread(img_file)[ :, :, ::-1 ].copy() # PyTorch does not support negative stride at the moment if bbox_file is None and openpose_file is None: # Assume that the person is centerered in the image height = img.shape[0] width = img.shape[1] center = np.array([width // 2, height // 2]) scale = max(height, width) / 200 else: if bbox_file is not None: center, scale = bbox_from_json(bbox_file) elif openpose_file is not None: center, scale = bbox_from_openpose(openpose_file) img = crop(img, center, scale, (input_res, input_res)) img = img.astype(np.float32) / 255.0 img = torch.from_numpy(img).permute(2, 0, 1) norm_img = normalize_img(img.clone()) return img, norm_img ================================================ FILE: src/utils.py ================================================ import colorsys import itertools import json import pickle import cv2 import plotly.graph_objects as go import trimesh import torch import numpy as np import PIL.Image as pil_img import PIL.ImageDraw as ImageDraw from PIL import Image, ImageChops from skimage import exposure import spin import renderer import pose_estimation def load_json(path): with open(path) as f: return json.load(f) def save_json(o, path): with open(path, "w") as f: json.dump(o, f) def load_pkl(path): with open(path, "rb") as f: 
def save_pkl(o, path):
    """Pickle *o* to *path*."""
    with open(path, "wb") as f:
        pickle.dump(o, f)


def plot_3D(joints, vertices, faces):
    """Return a plotly figure with the body mesh and the 3D joints overlaid."""
    x, y, z = joints.T
    x1, y1, z1 = vertices.T
    i, j, k = faces.T
    data = [
        go.Mesh3d(x=x1, y=y1, z=z1, i=i, j=j, k=k),
        go.Scatter3d(x=x, y=y, z=z, mode="markers", marker_size=5),
    ]
    fig = go.Figure(data=data)
    return fig


def draw_keypoints(
    input_img_kp,
    keypoints,
    skeleton,
    r,
    color,
    contact2dlist=None,
    contact2dlist_color="green",
    cos=None,
):
    """Draw a 2D skeleton (and optional self-contact links) on a PIL image.

    Args:
        input_img_kp: PIL image, drawn on in place.
        keypoints: (J, 2) array of 2D joint positions, or None to draw nothing.
        skeleton: iterable of (i, j) joint-index pairs (bones).
        r: joint-marker radius / bone width in pixels.
        color: bone color used when ``cos`` is None.
        contact2dlist: optional list of contact groups; each entry is a list of
            (src, dst, t) triples locating a point by interpolation along the
            bone src->dst.  All pairs within a group are connected.
        contact2dlist_color: color of the contact links.
        cos: optional per-bone scalars; when given, each bone is colored by
            mapping cos**8 through the HSV hue wheel.

    Returns:
        The same PIL image.
    """
    if keypoints is not None:
        draw = ImageDraw.Draw(input_img_kp)
        for skidx, (i, j) in enumerate(skeleton):
            a = keypoints[i]
            b = keypoints[j]
            xy = [a[0], a[1], b[0], b[1]]
            if cos is not None:
                c = colorsys.hsv_to_rgb(cos[skidx] ** 8, 1, 1)
                c = tuple(int(c_ * 255) for c_ in c)
                draw.line(xy, fill=c, width=r)
            else:
                draw.line(xy, fill=color, width=r)

        # Joint markers (index was unused in the original enumerate loop).
        draw_kpts = [(p[0] - r, p[1] - r, p[0] + r, p[1] + r) for p in keypoints]
        for ellipse in draw_kpts:
            draw.ellipse(ellipse, fill="black", outline="black")

        if contact2dlist is not None:
            keypoints_torch = torch.from_numpy(keypoints)
            for c2d in contact2dlist:
                for (src_1, dst_1, t_1), (src_2, dst_2, t_2) in itertools.combinations(
                    c2d, 2
                ):
                    a = torch.lerp(
                        keypoints_torch[src_1], keypoints_torch[dst_1], t_1
                    ).tolist()
                    b = torch.lerp(
                        keypoints_torch[src_2], keypoints_torch[dst_2], t_2
                    ).tolist()
                    xy = [a[0], a[1], b[0], b[1]]
                    draw.line(xy, fill=contact2dlist_color, width=max(r // 3, 10))

    return input_img_kp


def save_results_image(
    camera,
    focal_length_x,
    focal_length_y,
    input_img,
    vertices,
    faces,
    filename,
    keypoints=None,
    keypoints_2=None,
    heatmap=None,
    cvt_camera=True,
    contactlist=None,
    contact2dlist=None,
    user_study=True,
    cos=None,
):
    """Compose and save a 2x3 results sheet: 2D keypoints, overlay and renders.

    Args:
        camera: weak-perspective camera; assumed (s, tx, ty) when ``cvt_camera``
            is True, otherwise already a translation vector — TODO confirm.
        focal_length_x, focal_length_y: render focal lengths in pixels.
        input_img: (H, W, 3) uint8 source image.
        vertices, faces: posed mesh geometry.
        filename: pathlib.Path of the output; variants are saved next to it
            via ``Path.with_stem``.
        keypoints, keypoints_2: optional (J, 2) joint arrays (e.g. detected vs.
            re-projected).
        heatmap: optional 2D array saved as a JET-colormapped image.
        cvt_camera: convert ``camera`` from (s, tx, ty) to a translation.
        contactlist: vertex indices (or list of arrays) highlighted on renders.
        contact2dlist: 2D self-contact annotation forwarded to draw_keypoints.
        user_study: also save the larger "_orig"/"_same"/"_alt" images.
        cos: optional per-bone values used to color the skeleton.

    Returns:
        The composed PIL image.
    """
    if isinstance(contactlist, list) and len(contactlist) > 0:
        contactlist = np.concatenate(contactlist)

    H, W, _ = input_img.shape
    HW = max(H, W)
    camera_center = np.array([W // 2, H // 2])
    if not cvt_camera:
        camera_transl = camera.copy()
    else:
        # (s, tx, ty) -> (tx, ty, 1/s): depth is the inverse of the scale.
        camera_transl = np.stack(
            [
                camera[1],
                camera[2],
                1 / camera[0],
            ],
        )

    # draw keypoints
    input_img_kp = pil_img.fromarray(input_img)
    if keypoints is not None:
        draw_keypoints(
            input_img_kp,
            keypoints,
            pose_estimation.SKELETON,
            r=int(HW * 0.01),
            color=(255, 0, 0, 255),
        )

    if cos is not None:
        input_img_kp_cos = pil_img.fromarray(input_img)
        if keypoints is not None:
            draw_keypoints(
                input_img_kp_cos,
                keypoints,
                pose_estimation.SKELETON,
                r=int(HW * 0.01),
                color=(255, 0, 0, 255),
                cos=cos,
            )

    input_img_kp_2 = input_img_kp.copy()
    if keypoints_2 is not None:
        draw_keypoints(
            input_img_kp_2,
            keypoints_2,
            pose_estimation.SKELETON,
            r=int(HW * 0.01),
            color=(0, 0, 255, 255),
            contact2dlist=contact2dlist,
            contact2dlist_color="purple",
        )

    if heatmap is not None:
        # Save the heatmap as a JET colormap image and the keypoint overlay.
        gray_img = exposure.rescale_intensity(heatmap, out_range=(0, 255))
        gray_img = gray_img.astype(np.uint8)
        heatmap_img = cv2.applyColorMap(gray_img, cv2.COLORMAP_JET)
        hm = pil_img.fromarray(cv2.cvtColor(heatmap_img, cv2.COLOR_BGR2RGB))
        hm.save(filename.with_stem(f"{filename.stem}_heatmap"))
        input_img_kp.save(filename.with_stem(f"{filename.stem}_2dkps"))

    if cos is not None:
        input_img_kp_cos.save(filename.with_stem(f"{filename.stem}_2dkpscos"))

    # render fitted mesh from different views
    overlay_fit_img = renderer.overlay_mesh(
        vertices,
        faces,
        camera_transl,
        focal_length_x,
        focal_length_y,
        camera_center,
        H,
        W,
        input_img.astype("float32") / 255,
        None,
        rotaround=None,
    )

    view1_fit = renderer.overlay_mesh(
        vertices,
        faces,
        camera_transl.astype(np.float32),
        focal_length_x,
        focal_length_y,
        camera_center,
        H,
        W,
        None,
        None,
        rotaround=-45,
        contactlist=contactlist,
    )
    view2_fit = renderer.overlay_mesh(
        vertices,
        faces,
        camera_transl.astype(np.float32),
        focal_length_x,
        focal_length_y,
        camera_center,
        H,
        W,
        None,
        None,
        rotaround=None,
        contactlist=contactlist,
    )
    view3_fit = renderer.overlay_mesh(
        vertices,
        faces,
        camera_transl.astype(np.float32),
        focal_length_x,
        focal_length_y,
        camera_center,
        H,
        W,
        None,
        None,
        rotaround=90,
        contactlist=contactlist,
        scale=1,
    )

    # 2x3 sheet: [keypoints | reprojection | overlay] over [-45 | front | 90].
    IMG = np.vstack(
        (
            np.hstack(
                (
                    np.asarray(input_img_kp)
                    if keypoints is not None
                    else 255 * np.ones_like(np.asarray(input_img_kp)),
                    np.asarray(input_img_kp_2),
                    overlay_fit_img,
                ),
            ),
            np.hstack(
                (
                    view1_fit,
                    view2_fit,
                    view3_fit,
                ),
            ),
        ),
    )
    IMG = pil_img.fromarray(IMG)
    IMG.thumbnail((2000, 2000))
    IMG.save(filename)

    if user_study:
        w = 768
        input_img_kp.thumbnail((w, w))
        input_img_kp.save(filename.with_stem(f"{filename.stem}_orig"))
        W, H = input_img_kp.size

        # Double the depth to zoom out for the user-study renders.
        camera_transl[-1] *= 2
        view2_fit = renderer.overlay_mesh(
            vertices,
            faces,
            camera_transl.astype(np.float32),
            focal_length_x,
            focal_length_y,
            camera_center,
            H,
            W,
            None,
            None,
            rotaround=None,
            contactlist=contactlist,
            scale=2,
        )
        view2_fit = pil_img.fromarray(view2_fit)
        w *= 2
        view2_fit.thumbnail((w, w))
        view2_fit.save(filename.with_stem(f"{filename.stem}_same"))

        view3_fit = renderer.overlay_mesh(
            vertices,
            faces,
            camera_transl.astype(np.float32),
            focal_length_x,
            focal_length_y,
            camera_center,
            H,
            W,
            None,
            None,
            rotaround=90,
            contactlist=contactlist,
            scale=2,
        )
        view3_fit = pil_img.fromarray(view3_fit)
        view3_fit.thumbnail((w, w))
        view3_fit.save(filename.with_stem(f"{filename.stem}_alt"))

    return IMG
def save_3d_model_on_img(
    camera,
    vertices,
    faces,
    img,
    filename,
    save_path,
):
    """Render the posed mesh over the input image and save front/side views."""
    # NOTE(review): *filename* is currently unused — presumably it was meant
    # for the output stem; confirm against callers.
    img_res = max(img.shape[:2])
    shape_renderer = renderer.Renderer(
        focal_length=spin.constants.FOCAL_LENGTH,
        img_res=img_res,
        faces=faces,
    )

    # Calculate camera parameters for rendering
    camera_translation = np.stack(
        [
            camera[1],
            camera[2],
            2 * spin.constants.FOCAL_LENGTH / (img_res * camera[0] + 1e-9),
        ],
    )

    # Render parametric shape
    img_shape = shape_renderer(vertices, camera_translation, img)
    img_shape = (255 * img_shape).astype("uint8")
    img_shape = cv2.cvtColor(img_shape, cv2.COLOR_RGB2BGR)
    cv2.imwrite(str(save_path / f"shape_(unknown).png"), img_shape)

    # Render side views
    aroundy = cv2.Rodrigues(np.array([0, np.radians(90.0), 0]))[0]
    center = vertices.mean(axis=0)
    rot_vertices = np.dot((vertices - center), aroundy) + center

    # Render non-parametric shape
    img_shape = shape_renderer(rot_vertices, camera_translation, np.ones_like(img))
    img_shape = (255 * img_shape).astype("uint8")
    img_shape = cv2.cvtColor(img_shape, cv2.COLOR_RGB2BGR)
    cv2.imwrite(str(save_path / f"shape_rot_(unknown).png"), img_shape)


def save_mesh_with_colors(vertices, faces, save_path, mask=None, inds=None):
    """Export the mesh, painting masked / indexed vertices red."""
    if inds is not None and isinstance(inds, list):
        inds = np.concatenate(inds)

    mesh = trimesh.Trimesh(vertices=vertices, faces=faces, process=False)

    # Default: light gray everywhere; highlight contact vertices in red.
    color = np.array(mesh.visual.vertex_colors)
    color[:] = [233, 233, 233, 255]
    if mask is not None and any(mask):
        color[~mask] = [255, 0, 0, 255]
    elif inds is not None and len(inds) > 0:
        color[inds] = [255, 0, 0, 255]

    mesh.visual.vertex_colors = color
    mesh.export(save_path)


def save_pose_params(rotmat, camera, betas, vertices, smpl, contact, save_path):
    """Dump camera, pose, shape and contact data as pickle and as .npz."""
    if contact is not None and isinstance(contact, list) and len(contact) > 0:
        contact = np.concatenate(contact)

    rotmat = rotmat.detach()
    camera = camera.detach()
    if smpl.name() == "SMPL-X":
        # Drop the parameters of the last two (palm) joints, which SMPL-X
        # does not use (mirrors the trimming done in SMPLX.forward).
        rotmat = rotmat[: -2 * 3]

    res = {
        "camera_s_t": camera.unsqueeze(0).cpu().numpy(),
        "global_orient": rotmat[:3].unsqueeze(0).cpu().numpy(),
        "betas": betas,
        "body_pose": rotmat[3:].unsqueeze(0).cpu().numpy(),
        "left_hand_pose": smpl.left_hand_pose.unsqueeze(0).detach().cpu().numpy(),
        "right_hand_pose": smpl.right_hand_pose.unsqueeze(0).detach().cpu().numpy(),
        "model": smpl.name().replace("-", ""),
        "gender": smpl.gender,
        "vertices": vertices[0].cpu().numpy(),
    }
    if contact is not None and len(contact) > 0:
        res["contact"] = np.array(contact)
    else:
        res["v"] = vertices[0].cpu().numpy()

    save_pkl(res, save_path)
    # NOTE(review): np.savez appends ".npz" when the path lacks it, so this
    # writes a second file alongside the pickle — confirm that is intended.
    np.savez(save_path, **res)