Repository: generalizable-neural-performer/gnr
Branch: main
Commit: 5f670850a013
Files: 37
Total size: 2.4 MB
Directory structure:
gitextract_wh9gfa65/
├── .gitignore
├── .gitmodules
├── README.md
├── apps/
│ ├── render_smpl_depth.py
│ └── run_genebody.py
├── configs/
│ ├── render.txt
│ ├── test.txt
│ └── train.txt
├── docs/
│ ├── Annotation.md
│ └── Dataset.md
├── environment.yml
├── genebody/
│ ├── download_tool.py
│ ├── gender.py
│ ├── genebody.py
│ └── mesh.py
├── lib/
│ ├── data/
│ │ ├── GeneBodyDataset.py
│ │ └── __init__.py
│ ├── geometry.py
│ ├── mesh_util.py
│ ├── metrics.py
│ ├── metrics_torch.py
│ ├── model/
│ │ ├── Embedder.py
│ │ ├── GNR.py
│ │ ├── HGFilters.py
│ │ ├── NeRF.py
│ │ ├── NeRFRenderer.py
│ │ ├── SRFilters.py
│ │ └── __init__.py
│ ├── net_ddp.py
│ ├── net_util.py
│ ├── options.py
│ └── ply_util.py
├── scripts/
│ ├── download_model.sh
│ ├── render_smpl_depth.sh
│ └── train_ddp.sh
└── smpl_t_pose/
├── smpl.obj
└── smplx.obj
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
checkpoints/*
data/*
**.ply
results/*
sample_images/*
logs/*
**/__pycache__/*
================================================
FILE: .gitmodules
================================================
[submodule "benchmarks/ibrnet"]
path = benchmarks/ibrnet
url = https://github.com/generalizable-neural-performer/genebody-benchmarks/
branch = ibrnet
[submodule "benchmarks/nv"]
path = benchmarks/nv
url = https://github.com/generalizable-neural-performer/genebody-benchmarks/
branch = nv
[submodule "benchmarks/nt"]
path = benchmarks/nt
url = https://github.com/generalizable-neural-performer/genebody-benchmarks/
branch = nt
[submodule "benchmarks/nb"]
path = benchmarks/nb
url = https://github.com/generalizable-neural-performer/genebody-benchmarks/
branch = nb
[submodule "benchmarks/anerf"]
path = benchmarks/anerf
url = https://github.com/generalizable-neural-performer/genebody-benchmarks/
branch = A-Nerf
[submodule "benchmarks/nhr"]
path = benchmarks/nhr
url = https://github.com/generalizable-neural-performer/genebody-benchmarks/
branch = nhr
================================================
FILE: README.md
================================================
# Generalizable Neural Performer: Learning Robust Radiance Fields for Human Novel View Synthesis
[arXiv:2204.11798](http://arxiv.org/abs/2204.11798)

> **Abstract:** *This work targets using a general deep learning framework to synthesize free-viewpoint images of arbitrary human performers, only requiring a sparse number of camera views as inputs and skirting per-case fine-tuning. The large variation of geometry and appearance, caused by articulated body poses, shapes, and clothing types, are the key bottlenecks of this task. To overcome these challenges, we present a simple yet powerful framework, named Generalizable Neural Performer (GNR), that learns a generalizable and robust neural body representation over various geometry and appearance. Specifically, we compress the light fields for a novel view of human rendering as conditional implicit neural radiance fields with several designs from both geometry and appearance aspects. We first introduce an Implicit Geometric Body Embedding strategy to enhance the robustness based on both parametric 3D human body model prior and multi-view source images hints. On top of this, we further propose a Screen-Space Occlusion-Aware Appearance Blending technique to preserve the high-quality appearance, through interpolating source view appearance to the radiance fields with a relaxed but approximate geometric guidance.* <br>
[Wei Cheng](mailto:wchengad@connect.ust.hk), [Su Xu](mailto:xusu@sensetime.com), [Jingtan Piao](mailto:piaojingtan@sensetime.com), [Chen Qian](https://scholar.google.com/citations?user=AerkT0YAAAAJ&hl=zh-CN), [Wayne Wu](https://wywu.github.io/), [Kwan-Yee Lin](https://kwanyeelin.github.io/), [Hongsheng Li](https://www.ee.cuhk.edu.hk/~hsli/)<br>
**[[Demo Video]](https://www.youtube.com/watch?v=2COR4u1ZIuk)** | **[[Project Page]](https://generalizable-neural-performer.github.io/)** | **[[Data]](https://generalizable-neural-performer.github.io/genebody.html)** | **[[Paper]](https://arxiv.org/pdf/2204.11798.pdf)**
## Updates
- [14/07/2023] :star2::star2::star2: Check out our sister dataset -- [DNA-Rendering](https://dna-rendering.github.io/)! :star2::star2::star2: It captures humans at up to 4K resolution, and offers more accurate annotations, richer attributes, and greater diversity than GeneBody 1.0.
- [01/09/2022] We also recommend the implementation of our work in the [OpenXRLab](https://github.com/openxrlab/xrnerf).
- [01/09/2022] :exclamation: GeneBody has been reframed. Users who downloaded GeneBody before `2022.09.01` should update to the latest data using our more user-friendly download tool.
- [29/07/2022] GeneBody can be downloaded from [OpenDataLab](https://opendatalab.com/GeneBody).
- [11/07/2022] Code is released.
- [02/05/2022] GeneBody Train40 is released! Apply [here](./docs/Dataset.md#train40)!
- [29/04/2022] SMPLx fitting toolbox and benchmarks are released!
- [26/04/2022] Technical report released.
- [24/04/2022] The codebase and project page are created.
## Data Download
To download and use the GeneBody dataset, please first read the instructions in [Dataset.md](./docs/Dataset.md). We provide a download tool to fetch and update the GeneBody data, including the dataset and pretrained models (in case of any future adjustments), for example:
```
python genebody/download_tool.py --genebody_root ${GENEBODY_ROOT} --subset train40 test10 pretrained_models smpl_depth
```
The tool will fetch and download the subsets you selected and put the data in `${GENEBODY_ROOT}`.
## Annotations
GeneBody provides per-view, per-frame segmentation, using [BackgroundMatting-V2](https://github.com/PeterL1n/BackgroundMattingV2), and registers fitted [SMPLx](https://github.com/vchoutas/smplx) models using our enhanced multi-view SMPLify repo [here](https://github.com/generalizable-neural-performer/bodyfitting).
To use the annotations of GeneBody, please check the document [Annotation.md](./docs/Annotation.md); we provide a reference data fetch module in `genebody`.
## Train and Evaluate GNR
Setup the environment
```
conda env create -f environment.yml
conda activate gnr
pip install git+https://github.com/generalizable-neural-performer/gnr.git@mesh_grid
```
To run GNR on GeneBody:
```
python apps/run_genebody.py --config configs/[train, test, render].txt --dataroot ${GENEBODY_ROOT}
```
If you have multiple machines or multiple GPUs, you can train our model with distributed data parallel:
```
bash scripts/train_ddp.sh
```
## Benchmarks
We also provide benchmarks of state-of-the-art methods on the GeneBody dataset; methods and requirements are listed in [Benchmarks.md](https://github.com/generalizable-neural-performer/genebody-benchmarks).
To test the performance of our released pretrained models or to train them yourself, run:
```
git clone --recurse-submodules https://github.com/generalizable-neural-performer/gnr.git
```
Then `cd benchmarks/`; the released benchmarks are ready to run on GeneBody and other datasets such as V-sense and ZJU-MoCap.
### Case-specific Methods on GeneBody
| Model | PSNR | SSIM |LPIPS| ckpts|
| :--- | :---------------:|:---------------:| :---------------:| :---------------: |
| [NV](https://github.com/generalizable-neural-performer/genebody-benchmarks/tree/nv)| 19.86 |0.774 | 0.267 | [ckpts](https://hkustconnect-my.sharepoint.com/:f:/g/personal/wchengad_connect_ust_hk/EniK9r9UdbtGvYvtJITBGkIBlmxSHqaoEIiIgpYBGddCHQ?e=RbS0sG)|
| [NHR](https://github.com/generalizable-neural-performer/genebody-benchmarks/tree/nhr)| 20.05 |0.800 | 0.155 | [ckpts](https://hkustconnect-my.sharepoint.com/:f:/g/personal/wchengad_connect_ust_hk/EqQDNVch2j5DmyIDnHX0VgkBDdCksmT4Kfq2oPOMn6gfMg?e=dy6yUA)|
| [NT](https://github.com/generalizable-neural-performer/genebody-benchmarks/tree/nt)| 21.68 |0.881 | 0.152 | [ckpts](https://hkustconnect-my.sharepoint.com/:f:/g/personal/wchengad_connect_ust_hk/Etg3LW44m61OjZOgDp-f4TcB_rgm_32ve529z5EZgCmoGw?e=zGUadc)|
| [NB](https://github.com/generalizable-neural-performer/genebody-benchmarks/tree/nb)| 20.73 |0.878 | 0.231 | [ckpts](https://hkustconnect-my.sharepoint.com/personal/wchengad_connect_ust_hk/_layouts/15/onedrive.aspx?ga=1&id=%2Fpersonal%2Fwchengad%5Fconnect%5Fust%5Fhk%2FDocuments%2Fgenebody%2Dbenchmark%2Dpretrained%2Fnb%2Fgenebody)|
| [A-Nerf](https://github.com/generalizable-neural-performer/genebody-benchmarks/tree/A-Nerf)| 15.57 |0.508 | 0.242 | [ckpts](https://hkustconnect-my.sharepoint.com/:f:/g/personal/wchengad_connect_ust_hk/En56nksujH1Fn1qWiUJ-gpIBfzdHqHf66F-RvfzwTe2TBQ?e=Zz0EgX)|
(see this [issue](https://github.com/LemonATsu/A-NeRF/issues/8) for details on why A-NeRF's numbers are unexpectedly low)
### Generalizable Methods on GeneBody
| Model | PSNR | SSIM |LPIPS| ckpts|
| :--- | :---------------:|:---------------:| :---------------:| :---------------: |
| PixelNeRF | 24.15 |0.903 | 0.122 | |
| [IBRNet](https://github.com/generalizable-neural-performer/genebody-benchmarks/tree/ibrnet)| 23.61 |0.836 | 0.177 | [ckpts](https://hkustconnect-my.sharepoint.com/personal/wchengad_connect_ust_hk/_layouts/15/onedrive.aspx?ga=1&id=%2Fpersonal%2Fwchengad%5Fconnect%5Fust%5Fhk%2FDocuments%2Fgenebody%2Dbenchmark%2Dpretrained%2Fibrnet)|
### Open-source Contributions
OpenXRLab/xrnerf: [https://github.com/openxrlab/xrnerf](https://github.com/openxrlab/xrnerf)
## Citation
```
@article{cheng2022generalizable,
title={Generalizable Neural Performer: Learning Robust Radiance Fields for Human Novel View Synthesis},
author={Cheng, Wei and Xu, Su and Piao, Jingtan and Qian, Chen and Wu, Wayne and Lin, Kwan-Yee and Li, Hongsheng},
journal={arXiv preprint arXiv:2204.11798},
year={2022}
}
```
================================================
FILE: apps/render_smpl_depth.py
================================================
from tqdm import tqdm
import numpy as np
import argparse
import struct
import sys
import cv2
import os
import re
from multiprocessing import Queue, Lock, Process
from trimesh import load_mesh
base_dir = os.path.dirname(os.path.abspath(__file__))
parser = argparse.ArgumentParser()
parser.add_argument('--datadir', type = str, required=True)
parser.add_argument('--outdir', type = str, default = '')
parser.add_argument('--annotdir', type = str, default = '')
parser.add_argument('--workers', type = int, default = 8)
def load_obj_mesh(mesh_file, with_normal=False, with_texture=False, with_texture_image=False):
vertex_data = []
norm_data = []
uv_data = []
face_data = []
face_norm_data = []
face_uv_data = []
if isinstance(mesh_file, str):
f = open(mesh_file, "r")
else:
f = mesh_file
for line in f:
if isinstance(line, bytes):
line = line.decode("utf-8")
if line.startswith('#'):
continue
values = line.split()
if not values:
continue
if values[0] == 'v':
v = list(map(float, values[1:4]))
vertex_data.append(v)
elif values[0] == 'vn':
vn = list(map(float, values[1:4]))
norm_data.append(vn)
elif values[0] == 'vt':
vt = list(map(float, values[1:3]))
uv_data.append(vt)
elif values[0] == 'f':
# quad mesh
if len(values) > 4:
f = list(map(lambda x: int(x.split('/')[0]), values[1:4]))
face_data.append(f)
f = list(map(lambda x: int(x.split('/')[0]), [values[3], values[4], values[1]]))
face_data.append(f)
# tri mesh
else:
f = list(map(lambda x: int(x.split('/')[0]), values[1:4]))
face_data.append(f)
# deal with texture
if len(values[1].split('/')) >= 2:
# quad mesh
if len(values) > 4:
f = list(map(lambda x: int(x.split('/')[1]), values[1:4]))
face_uv_data.append(f)
f = list(map(lambda x: int(x.split('/')[1]), [values[3], values[4], values[1]]))
face_uv_data.append(f)
# tri mesh
elif len(values[1].split('/')[1]) != 0:
f = list(map(lambda x: int(x.split('/')[1]), values[1:4]))
face_uv_data.append(f)
# deal with normal
if len(values[1].split('/')) == 3:
# quad mesh
if len(values) > 4:
f = list(map(lambda x: int(x.split('/')[2]), values[1:4]))
face_norm_data.append(f)
f = list(map(lambda x: int(x.split('/')[2]), [values[3], values[4], values[1]]))
face_norm_data.append(f)
# tri mesh
elif len(values[1].split('/')[2]) != 0:
f = list(map(lambda x: int(x.split('/')[2]), values[1:4]))
face_norm_data.append(f)
elif 'mtllib' in line.split():
mtlname = line.split()[-1]
mtlfile = os.path.join(os.path.dirname(mesh_file), mtlname)
with open(mtlfile, 'r') as fmtl:
mtllines = fmtl.readlines()
for mtlline in mtllines:
# if mtlline.startswith('map_Kd'):
if 'map_Kd' in mtlline.split():
texname = mtlline.split()[-1]
texfile = os.path.join(os.path.dirname(mesh_file), texname)
texture_image = cv2.imread(texfile)
texture_image = cv2.cvtColor(texture_image, cv2.COLOR_BGR2RGB)
break
vertices = np.array(vertex_data)
faces = np.array(face_data) - 1
if with_texture and with_normal:
uvs = np.array(uv_data)
face_uvs = np.array(face_uv_data) - 1
norms = np.array(norm_data)
if norms.shape[0] == 0:
norms = compute_normal(vertices, faces)
face_normals = faces
else:
norms = normalize_v3(norms)
face_normals = np.array(face_norm_data) - 1
if with_texture_image:
return vertices, faces, norms, face_normals, uvs, face_uvs, texture_image
else:
return vertices, faces, norms, face_normals, uvs, face_uvs
if with_texture:
uvs = np.array(uv_data)
face_uvs = np.array(face_uv_data) - 1
return vertices, faces, uvs, face_uvs
if with_normal:
# norms = np.array(norm_data)
# norms = normalize_v3(norms)
# face_normals = np.array(face_norm_data) - 1
norms = np.array(norm_data)
if norms.shape[0] == 0:
norms = compute_normal(vertices, faces)
face_normals = faces
else:
norms = normalize_v3(norms)
face_normals = np.array(face_norm_data) - 1
return vertices, faces, norms, face_normals
return vertices, faces
def extract_float(text):
flts = []
for c in re.findall('(-?[0-9]*\.?[0-9]*[eE]?[-\+]?[0-9]+)',text):
if c != '':
try:
flts.append(float(c))
except ValueError as e:
continue
return flts
def natural_sort(files):
return sorted(files, key = lambda text: \
extract_float(os.path.basename(text)) \
if len(extract_float(os.path.basename(text))) > 0 else \
[float(ord(c)) for c in os.path.basename(text)])
def load_ply(file_name):
v = []; tri = []
try:
fid = open(file_name, 'r')
head = fid.readline().strip()
readl= lambda f: f.readline().strip()
except UnicodeDecodeError as e:
fid = open(file_name, 'rb')
readl = (lambda f: str(f.readline().strip())[2:-1]) \
if sys.version_info[0] == 3 else \
(lambda f: str(f.readline().strip()))
head = readl(fid)
if head.lower() != 'ply':
return v, tri
form = readl(fid).split(' ')[1]
line = readl(fid)
vshape = fshape = [0]
while line != 'end_header':
s = [i for i in line.split(' ') if len(i) > 0]
if len(s) > 2 and s[0] == 'element' and s[1] == 'vertex':
vshape = [int(s[2])]
line = readl(fid)
s = [i for i in line.split(' ') if len(i) > 0]
while s[0] == 'property' or s[0][0] == '#':
if s[0][0] != '#':
vshape += [s[1]]
line = readl(fid)
s = [i for i in line.split(' ') if len(i) > 0]
elif len(s) > 2 and s[0] == 'element' and s[1] == 'face':
fshape = [int(s[2])]
line = readl(fid)
s = [i for i in line.split(' ') if len(i) > 0]
while s[0] == 'property' or s[0][0] == '#':
if s[0][0] != '#':
fshape = [fshape[0],s[2],s[3]]
line = readl(fid)
s = [i for i in line.split(' ') if len(i) > 0]
else:
line = readl(fid)
if form.lower() == 'ascii':
for i in range(vshape[0]):
s = [i for i in readl(fid).split(' ') if len(i) > 0]
if len(s) > 0 and s[0][0] != '#':
v += [[float(i) for i in s]]
v = np.array(v, np.float32)
for i in range(fshape[0]):
s = [i for i in readl(fid).split(' ') if len(i) > 0]
if len(s) > 0 and s[0][0] != '#':
tri += [[int(s[1]),int(s[i-1]),int(s[i])] \
for i in range(3,len(s))]
tri = np.array(tri, np.int64)
else:
maps = {'float': ('f',4), 'double':('d',8), \
'uint': ('I',4), 'int': ('i',4), \
'ushort':('H',2), 'short': ('h',2), \
'uchar': ('B',1), 'char': ('b',1)}
if 'little' in form.lower():
fmt = '<' + ''.join([maps[i][0] for i in vshape[1:]]*vshape[0])
else:
fmt = '>' + ''.join([maps[i][0] for i in vshape[1:]]*vshape[0])
l = sum([maps[i][1] for i in vshape[1:]]) * vshape[0]
v = struct.unpack(fmt, fid.read(l))
v = np.array(v).reshape(vshape[0],-1).astype(np.float32)
v = v[:,:3]
tri = []
for i in range(fshape[0]):
l = struct.unpack(fmt[0]+maps[fshape[1]][0], \
fid.read(maps[fshape[1]][1]))[0]
f = struct.unpack(fmt[0]+maps[fshape[2]][0]*l, \
fid.read(l*maps[fshape[2]][1]))
tri += [[f[0],f[i-1],f[i]] for i in range(2,len(f))]
tri = np.array(tri).reshape(fshape[0],-1).astype(np.int64)
fid.close()
return v, tri
def distortPoints(p, dist):
dist = np.reshape(dist,-1) \
if dist is not None else []
k1 = dist[0] if len(dist) > 0 else 0
k2 = dist[1] if len(dist) > 1 else 0
p1 = dist[2] if len(dist) > 2 else 0
p2 = dist[3] if len(dist) > 3 else 0
k3 = dist[4] if len(dist) > 4 else 0
k4 = dist[5] if len(dist) > 5 else 0
k5 = dist[6] if len(dist) > 6 else 0
k6 = dist[7] if len(dist) > 7 else 0
x, y = p[...,0], p[...,1]
x2 = x * x; y2 = y * y; xy = x * y
r2 = x2 + y2  # squared radial distance r^2 = x^2 + y^2
c = (1 + r2 * (k1 + r2 * (k2 + r2 * k3))) / \
(1 + r2 * (k4 + r2 * (k5 + r2 * k6)))
x_ = c*x + p1*2*xy + p2*(r2+2*x2)
y_ = c*y + p2*2*xy + p1*(r2+2*y2)
p[...,0] = x_
p[...,1] = y_
return p
def rasterize(v, tri, size, K = np.identity(3), \
dist = None, persp = True, eps = 1e-6):
h, w = size
zbuf = np.ones([h, w], v.dtype) * float('inf')
if dist is not None:
valid = np.where(v[:,2] >= eps)[0] \
if persp else np.arange(len(v))
v_proj = v[valid,:2] / v[valid,2:]
v_proj = distortPoints(v_proj, dist)
v[valid,:2]= v_proj * v[valid,2:]
v_proj = v.dot(K.T)[:,:2] / np.maximum(v[:,2:], eps) \
if persp else v.dot(K.T)[:,:2]
va = v_proj[tri[:,0],:2]
vb = v_proj[tri[:,1],:2]
vc = v_proj[tri[:,2],:2]
front = np.cross(vc - va, vb - va)
umin = np.maximum(np.ceil (np.vstack((va[:,0],vb[:,0],vc[:,0])).min(0)), 0)
umax = np.minimum(np.floor(np.vstack((va[:,0],vb[:,0],vc[:,0])).max(0)),w-1)
vmin = np.maximum(np.ceil (np.vstack((va[:,1],vb[:,1],vc[:,1])).min(0)), 0)
vmax = np.minimum(np.floor(np.vstack((va[:,1],vb[:,1],vc[:,1])).max(0)),h-1)
umin = umin.astype(np.int32)
umax = umax.astype(np.int32)
vmin = vmin.astype(np.int32)
vmax = vmax.astype(np.int32)
front = np.where(np.logical_and(np.logical_and( \
umin <= umax, vmin <= vmax), front > 0))[0]
for t in front:
A = np.concatenate((vb[t:t+1]-va[t:t+1], vc[t:t+1]-va[t:t+1]),0)
x, y = np.meshgrid( range(umin[t],umax[t]+1), \
range(vmin[t],vmax[t]+1))
u = np.vstack((x.reshape(-1),y.reshape(-1))).T
coeff = (u.astype(v.dtype) - va[t:t+1,:]).dot(np.linalg.pinv(A))
coeff = np.concatenate((1-coeff.sum(1).reshape(-1,1),coeff),1)
if persp:
z = coeff.dot(v[tri[t], 2])
else:
z = 1 / np.maximum((coeff/v[tri[t],2:3].T).sum(1), eps)
for i, (x, y) in enumerate(u):
if coeff[i,0] >= -eps \
and coeff[i,1] >= -eps \
and coeff[i,2] >= -eps \
and zbuf[y,x] > z[i]:
zbuf[y,x] = z[i]
return zbuf
def render_view(intri, dists, c2ws, meshes, view, i):
K = intri[i]
dist = dists[i]
# c2w = np.concatenate([c2ws[i],[[0,0,0,1]]], 0)
w2c = np.linalg.inv(c2ws[i])
# w2c = c2w
out = os.path.join(args.outdir, os.path.basename(view))
if not os.path.isdir(out):
os.makedirs(out, exist_ok=True)
imgs = [os.path.join(view, f) for f in os.listdir(view) \
if f[-4:].lower() in ['.jpg','.png']]
imgs = sorted(imgs) if len(imgs) > 1 else imgs
for i in tqdm(range(len(imgs))):
img = cv2.imread(imgs[i])
try:
if i < len(meshes) and meshes[i][-4:] == '.ply':
v, tri = load_ply(meshes[i])
elif i < len(meshes) and meshes[i][-4:] == '.npy':
v = np.load(meshes[i])
else:
v, tri = load_obj_mesh(meshes[i])
except:
continue
v_= v.dot(w2c[:3,:3].T) + w2c[:3,3:].T
z = rasterize(v_, tri, img.shape[:2], K, dist)
height = v_[:,1].max() - v_[:,1].min()
z[z == float('inf')] = 0
z = np.clip(np.round(z * 1000), 0, 65535).astype(np.uint16)
cv2.imwrite(os.path.join(out, \
'smpl_'+os.path.basename(imgs[i][:-4])+'.png'), z)
class Worker(Process):
def __init__(self, queue, lock):
super(Worker, self).__init__()
self.queue = queue
self.lock = lock
def run(self):
while True:
self.lock.acquire()
if self.queue.empty():
self.lock.release()
break
else:
kwargs = self.queue.get()
queue_len = self.queue.qsize()
self.lock.release()
print("started {}, {} jobs left".format(kwargs["view"], queue_len))
render_view(**kwargs)
if __name__ == '__main__':
args = parser.parse_args()
views = [os.path.join(args.datadir, 'image', f) \
for f in os.listdir(os.path.join(args.datadir, 'image'))]
if args.annotdir == '':
args.annotdir = os.path.join(args.datadir, 'annots.npy')
annot = np.load(os.path.join(args.annotdir), allow_pickle = True).item()['cams']
intri = np.array([annot[view]['K'] for view in annot.keys()], np.float32)
dists = np.array([annot[view]['D'] for view in annot.keys()], intri.dtype)
c2ws = np.array([annot[view]['c2w'] for view in annot.keys()]).astype(intri.dtype)
if args.outdir == '':
args.outdir = os.path.join(args.datadir, 'smpl_depth')
if not os.path.isdir(args.outdir):
os.makedirs(args.outdir)
if os.path.exists(os.path.join(args.datadir,'new_smpl')):
meshes = [os.path.join(args.datadir,'new_smpl',f) \
for f in os.listdir(os.path.join(args.datadir,'new_smpl')) \
if f[-4:] == '.ply' or f[-4:] == '.obj']
meshes = natural_sort(meshes)
elif os.path.exists(os.path.join(args.datadir,'smpl')):
meshes = [os.path.join(args.datadir,'smpl',f) \
for f in os.listdir(os.path.join(args.datadir,'smpl')) \
if f[-4:] == '.ply' or f[-4:] == '.obj']
meshes = natural_sort(meshes)
elif os.path.exists(os.path.join(args.datadir,'new_vertices')):
meshes = [os.path.join(args.datadir,'new_vertices',f) \
for f in os.listdir(os.path.join(args.datadir,'new_vertices')) \
if f[-4:] == '.npy']
tri = np.loadtxt(os.path.join(base_dir,'tri.txt')).astype(np.int64)
meshes = natural_sort(meshes)
elif os.path.exists(os.path.join(args.datadir,'vertices')):
meshes = [os.path.join(args.datadir,'vertices',f) \
for f in os.listdir(os.path.join(args.datadir,'vertices')) \
if f[-4:] == '.npy']
_, tri = load_obj_mesh('./smpl_t_pose/smplx.obj')
tri = tri.astype(np.int64)
meshes = natural_sort(meshes)
queue = Queue()
lock = Lock()
for i, view in enumerate(natural_sort(views)):
queue.put({
'intri': intri,
'dists': dists,
'c2ws': c2ws,
'meshes': meshes,
'view': view,
'i': i,
})
print("num of workers", args.workers, flush=True)
pool = [Worker(queue, lock) for _ in range(args.workers)]
for worker in pool: worker.start()
for worker in pool: worker.join()
================================================
FILE: apps/run_genebody.py
================================================
import sys
import os
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
sys.path = sys.path[:-1]
import time
import json
import numpy as np
import torch
from torch.utils.data import DataLoader
from lib.options import BaseOptions
from lib.data.GeneBodyDataset import GeneBodyDataset as MyDataset
from lib.mesh_util import *
from lib.net_ddp import create_network, worker_init_fn, ddpSampler, ddp_init, synchronize
from lib.model import GNR
from PIL import Image
import torchvision.transforms as transforms
from tqdm import tqdm, trange
from torch.utils.tensorboard import SummaryWriter
import torch.distributed as dist
import logging
import imageio
import lib.metrics_torch as metrics_torch
def loss_string(loss_dict):
string = ''
for key in loss_dict.keys():
string += '| {}: {:.2e} '.format(key, loss_dict[key].item())
return string
def print_write(file, string):
file.write(string)
if string[-1] == '\n': string = string[:-1]
print(string)
def to8b(img):
if isinstance(img, torch.Tensor):
img = img.detach().cpu().numpy()
if img.shape[0] == 3 and img.shape[-1] != 3:
img = np.transpose(img, [1,2,0])
if img.min() < -.2:
img = (img + 1) * 127.5
elif img.max() <= 2.:
img = img * 255.
img = np.clip(img, 0, 255)
return img.astype(np.uint8)
# get options
def prepare_data(opt, data, local_rank=0):
# retrieve the data
image_tensor = data['img'][0].to(device=local_rank)
calib_tensor = data['calib'][0].to(device=local_rank)
mask_tensor = data['mask'][0].to(device=local_rank)
bbox = list(data['bbox'][0].numpy().astype(np.int32))
mesh_param = {'center': data['center'][0].to(device=local_rank),
'body_scale': data['body_scale'][0].cpu().numpy().item()}
if opt.train_shape:
mesh_param['samples'] = data['samples'][0].to(device=local_rank)
mesh_param['labels'] = data['labels'][0].to(device=local_rank)
if any([opt.use_smpl_sdf, opt.use_t_pose]):
smpl = { 'rot': data['smpl_rot'].to(device=local_rank) }
if opt.use_smpl_sdf or opt.use_t_pose:
smpl['verts'] = data['smpl_verts'][0].to(device=local_rank)
smpl['faces'] = data['smpl_faces'][0].to(device=local_rank)
if opt.use_t_pose:
smpl['t_verts'] = data['smpl_t_verts'][0].to(device=local_rank)
smpl['t_faces'] = data['smpl_t_faces'][0].to(device=local_rank)
if opt.use_smpl_depth:
smpl['depth'] = data['smpl_depth'][0].to(device=local_rank)[:,None,...]
else:
smpl = None
if 'scan_verts' in data.keys():
scan = [data['scan_verts'][0].to(device=local_rank), data['scan_faces'][0].to(device=local_rank)]
else:
scan = None
persps = data['persps'][0].to(device=local_rank) if opt.projection_mode == 'perspective' else None
return {
'images': image_tensor,
'calibs': calib_tensor,
'bbox': bbox,
'masks': mask_tensor,
'mesh_param': mesh_param,
'smpl': smpl,
'scan': scan,
'persps': persps
}
def cal_metrics(metrics, rgbs, gts):
x = rgbs.clone().permute((0, 3, 1, 2))
out = {}
for m_key in metrics.keys():
out[m_key] = []
for pred, gt in zip(x, gts):
metric = metrics[m_key]
out[m_key].append(metric(pred, gt))
out[m_key] = torch.stack(out[m_key], dim=0)
return out
def train(opt, rank=0, local_rank = 0):
gpu_num = torch.cuda.device_count()
train_dataset = MyDataset(opt, phase='train')
test_dataset = MyDataset(opt, phase='test', move_cam=0)
render_dataset = MyDataset(opt, phase='render', move_cam=opt.move_cam)
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if opt.ddp else None
# create data loader
shuffle = not opt.ddp
train_data_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=opt.batch_size,
num_workers=opt.num_threads, shuffle=False, worker_init_fn=worker_init_fn)
logging.info(f'train data size: {len(train_data_loader)}')
test_data_loader = DataLoader(test_dataset, batch_size=1)
test_data_iter = iter(test_data_loader)
logging.info(f'test data size: {len(test_data_loader)}')
render_data_loader = DataLoader(render_dataset, batch_size=1)
render_data_iter = iter(render_data_loader)
logging.info(f'render data size: {len(render_data_loader)}')
# create net
net = GNR(opt)
logging.info(f'Using Network: {net.name}')
set_train = net.train
set_eval = net.eval
os.makedirs(opt.basedir, exist_ok=True)
os.makedirs('%s/%s' % (opt.basedir, opt.name), exist_ok=True)
net, start_epoch = create_network(opt, net, local_rank)
global_step = start_epoch * len(train_dataset)
lr = opt.lrate * (0.1 ** (start_epoch / opt.lrate_decay))
# params = net.parameters() if opt.train_encoder else net.module.nerf.parameters()
params_list = []
for name, param in net.named_parameters():
if 'occ_linears' in name:
if opt.train_occlusion:
params_list.append(param)
elif 'image_filter' in name:
if opt.train_encoder:
params_list.append(param)
else:
params_list.append(param)
optimizer = torch.optim.Adam(params=params_list, lr=lr, betas=(0.9, 0.999))
is_summary = not opt.ddp or (opt.ddp and (rank == 0))
if is_summary:
from tqdm import tqdm, trange
if opt.train:
writer = SummaryWriter(os.path.join(opt.basedir, opt.name))
opt_log = os.path.join(opt.basedir, opt.name, 'opt.txt')
config_file = os.path.join(opt.basedir, opt.name, 'config.txt')
with open(opt_log, 'w') as outfile:
outfile.write(json.dumps(vars(opt), indent=2))
os.system(f'cp {opt.config} {config_file}')
else:
tqdm = lambda x: x
trange = range
# evaluate, not demo
# metrics_dict = {'lpips': [], 'psnr': [], 'ssim': []}
metrics_dict = {}
metrics = {'lpips': metrics_torch.LPIPS().to(local_rank), 'psnr': metrics_torch.psnr, 'ssim': metrics_torch.SSIM().to(local_rank)}
# training
if opt.train:
for epoch in trange(start_epoch, opt.num_epoch):
set_train()
if opt.ddp:
train_data_loader.sampler.set_epoch(epoch)
synchronize()
pbar = tqdm(train_data_loader)
if is_summary:
pbar.set_description("epoch {}/{}".format(epoch, opt.num_epoch))
for train_idx, train_data in enumerate(pbar):
data = prepare_data(opt, train_data, local_rank)
train_shape = opt.train_shape and train_idx % opt.train_shape_skips == 0
loss_dict = net(data, train_shape=train_shape)
loss = sum(loss_dict.values())
optimizer.zero_grad()
try:
loss.backward()
except:
print(train_data['name'], train_data['sid'], train_data['vid'], flush=True)
optimizer.step()
if global_step % opt.freq_plot == 0 and is_summary:
tqdm.write(
'[{}] | epoch: {} | step: {:d} | loss:{:.2e} | lr: {:.2e} '.format(
opt.name, epoch, global_step, loss.item(), lr))
tqdm.write(f'[{opt.name}] {loss_string(loss_dict)}')
if is_summary:
writer.add_scalar('loss', loss.item(), global_step)
for key in loss_dict.keys():
writer.add_scalar(key, loss_dict[key].item(), global_step)
pbar.update(1)
global_step += 1 if not opt.ddp else gpu_num
if opt.ddp and (rank == 0):
torch.save({'network_state_dict': net.module.state_dict(), 'epoch': epoch+1},
'%s/%s/%04d.tar' % (opt.basedir, opt.name, epoch+1))
elif not opt.ddp:
torch.save({'network_state_dict': net.state_dict(), 'epoch': epoch+1},
'%s/%s/%04d.tar' % (opt.basedir, opt.name, epoch+1))
# update learning rate
lr = opt.lrate * (0.1 ** ((epoch+1) / opt.lrate_decay))
for param_group in optimizer.param_groups:
param_group['lr'] = lr
if opt.test:
net.eval()
idx = 0
with torch.no_grad():
for test_idx in trange(len(test_dataset)):
## test dataset
test_data = next(test_data_iter)
data = prepare_data(opt, test_data, local_rank)
name = test_data['name'][0]
subject = name.split('_')[0]
# fid = test_data['fid'][0]
fid = test_data['sid'][0]
vid = test_data['vid'][0]
render_gt = test_data['render_gt'][0].to(local_rank)
if opt.ddp:
# distribute query views to different GPUs if multiple GPU or multiple machines are available
src_calibs, tar_calibs = torch.split(data['calibs'], [opt.num_views, data['calibs'].shape[0]-opt.num_views], 0)
total_len = len(tar_calibs)
sampler = ddpSampler(tar_calibs)
indices = sampler.indices()
tar_calibs = tar_calibs[indices]
tar_gt = render_gt[indices]
data['calibs'] = torch.cat([src_calibs, tar_calibs], 0)
if opt.projection_mode == 'perspective':
src_persps, tar_persps = torch.split(data['persps'], [opt.num_views, data['persps'].shape[0]-opt.num_views], 0)
tar_persps = tar_persps[indices]
data['persps'] = torch.cat([src_persps, tar_persps], 0)
rgbs, _ = net.module.render_path(data)
rgbs = sampler.distributed_concat(rgbs, total_len) if opt.ddp else rgbs
if opt.use_attention:
rgbs, att_rgbs = rgbs[...,:3], rgbs[...,3:6]
else:
att_rgbs = rgbs[..., :3]
m_dict = cal_metrics(metrics, att_rgbs, render_gt)
if subject not in metrics_dict.keys():
metrics_dict[subject] = {'lpips': [], 'psnr': [], 'ssim': []}
for key, value in m_dict.items():
# if opt.ddp:
# value = sampler.distributed_concat(value, total_len) if opt.ddp else value
metrics_dict[subject][key].append(torch.mean(value).cpu().numpy())
att_rgbs = [to8b(att_rgb) for att_rgb in att_rgbs.cpu().numpy()]
render_gt = [to8b(gt) for gt in render_gt.permute(0,2,3,1).cpu().numpy()]
target_dir = os.path.join(opt.basedir, opt.name, opt.eval_dir, name)
os.makedirs(target_dir, exist_ok=True)
if is_summary:
for vid, im in enumerate(att_rgbs):
imageio.imwrite(os.path.join(target_dir, f'{vid:02d}.png'), im)
fname = os.path.join(opt.basedir, opt.name, opt.eval_dir, 'eval.txt')
with open(fname, 'a') as file_:
print_write(file_, '******\n%s\n' % (name))
for k, v in metrics_dict[subject].items():
print_write(file_, '%s: %.5f\n'%(k, v[-1]))
file_.write('------')
for k, v in metrics_dict[subject].items():
print_write(file_, '[total] %s: %.5f\n'%(k, sum(v)/len(v)))
if opt.output_mesh:
verts, faces, rgbs = net.module.reconstruct(data)
if is_summary:
save_obj_mesh_with_color(os.path.join(target_dir, "{}.obj".format(name)), verts, faces, rgbs)
idx += 1
if opt.render:
net.eval()
with torch.no_grad():
imgs = []
for ridx in trange(len(render_dataset)):
# render dataset
test_data = next(render_data_iter)
data = prepare_data(opt, test_data, local_rank)
name = test_data['name'][0]
# fid = test_data['fid'][0]
fid = test_data['sid'][0]
vid = test_data['vid'][0]
if opt.ddp:
# distribute query views to different GPUs if multiple GPU or multiple machines are available
src_calibs, tar_calibs = torch.split(data['calibs'], [opt.num_views, data['calibs'].shape[0]-opt.num_views], 0)
total_len = len(tar_calibs)
sampler = ddpSampler(tar_calibs)
indices = sampler.indices()
tar_calibs = tar_calibs[indices]
data['calibs'] = torch.cat([src_calibs, tar_calibs], 0)
if opt.projection_mode == 'perspective':
src_persps, tar_persps = torch.split(data['persps'], [opt.num_views, data['persps'].shape[0]-opt.num_views], 0)
tar_persps = tar_persps[indices]
data['persps'] = torch.cat([src_persps, tar_persps], 0)
target_dir = os.path.join(opt.basedir, opt.name, opt.render_dir, name.split('_')[0])
os.makedirs(target_dir, exist_ok=True)
rgbs, depths = net.module.render_path(data)
rgbs = sampler.distributed_concat(rgbs, total_len) if opt.ddp else rgbs
depths = sampler.distributed_concat(depths, total_len) if opt.ddp else depths
if opt.use_attention:
rgbs, att_rgbs = rgbs[...,:3], rgbs[...,3:6]
else:
att_rgbs = rgbs[..., :3]
att_rgbs = [to8b(att_rgb) for att_rgb in att_rgbs.cpu().numpy()]
imgs += att_rgbs
os.makedirs(target_dir, exist_ok=True)
for vid, im in enumerate(att_rgbs):
imageio.imwrite(os.path.join(target_dir, f'{ridx:03d}_rgb.png'), im)
depth= np.clip(np.round(depths[vid].cpu().numpy()*1000), 0, 65535).astype(np.uint16)
depth[depth == 0] = 65535
imageio.imwrite(os.path.join(target_dir, f'{ridx:03d}_depth.png'), depth)
if is_summary:
imageio.mimwrite(os.path.join(target_dir, "render_{}.mp4".format(fid)), imgs, quality=8, fps=30)
if __name__ == '__main__':
opt = BaseOptions().parse()
if opt.ddp:
rank, local_rank = ddp_init(opt)
logging.basicConfig(level=logging.INFO if rank in [-1, 0] else logging.WARN)
logging.info(vars(opt))
train(opt, rank, local_rank)
else:
logging.basicConfig(level=logging.INFO)
logging.info(vars(opt))
train(opt)
================================================
FILE: configs/render.txt
================================================
name = genebody
# run phase
train = False
test = False
render = True
# Dataloader
num_threads = 5
output_mesh = True
# Geometric Body Shape Embedding
smpl_type = smplx
use_smpl_sdf = True
use_t_pose = True
use_nml = True
# SSOAB
use_attention = True
weighted_pool = True
use_sh = True
use_viewdirs = True
use_occlusion = True
use_smpl_depth = True
use_occlusion_net = True
# Ray Sampling
use_vh = True
N_rand = 1024
N_rand_infer = 4096
N_samples = 256
chunk = 524288
vh_overhead = 1
# Training
ddp = False
train_encoder = False
projection_mode = perspective
# Evaluation
eval_skip = 1
# Render
move_cam = 150
# Reconstruction
N_grid = 512
laplacian = 5
================================================
FILE: configs/test.txt
================================================
name = genebody
# run phase
train = False
test = True
render = False
# Dataloader
num_threads = 5
output_mesh = True
# Geometric Body Shape Embedding
smpl_type = smplx
use_smpl_sdf = True
use_t_pose = True
use_nml = True
# SSOAB
use_attention = True
weighted_pool = True
use_sh = True
use_viewdirs = True
use_occlusion = True
use_smpl_depth = True
use_occlusion_net = True
# Ray Sampling
use_vh = True
N_rand = 1024
N_rand_infer = 4096
N_samples = 256
chunk = 524288
vh_overhead = 1
# Training
ddp = False
train_encoder = False
projection_mode = perspective
# Evaluation
eval_skip = 15
# Render
move_cam = 1
# Reconstruction
N_grid = 512
laplacian = 5
================================================
FILE: configs/train.txt
================================================
name = genebody
# run phase
train = True
test = False
render = False
# Dataloader
num_threads = 5
output_mesh = True
# Geometric Body Shape Embedding
smpl_type = smplx
use_smpl_sdf = True
use_t_pose = True
use_nml = True
# SSOAB
use_attention = True
weighted_pool = True
use_sh = True
use_viewdirs = True
use_occlusion = True
use_smpl_depth = True
use_occlusion_net = True
# Ray Sampling
use_vh = True
N_rand = 1024
N_rand_infer = 4096 # decrease the batch size if there is an out-of-memory error
N_samples = 256
chunk = 524288
vh_overhead = 1
# Training
ddp = False
train_encoder = True
projection_mode = perspective
# Evaluation
eval_skip = 15
# Render
move_cam = 1
# Reconstruction
N_grid = 512 # decrease the grid size if there is an out-of-memory error
laplacian = 5
================================================
FILE: docs/Annotation.md
================================================
# GeneBody Annotations
## Data Capture
The GeneBody dataset captures each performer in a motion capture studio with 48 synchronized cameras. Each actor is asked to perform a 10-second clip, recorded at 15 fps. The camera locations and capture volume are visualized in the following video.
<p align="center"><img src="./capture_volume.gif" width="90%"></p>
<p align="center">Left: Motion capture studio and performer, cameras are highlighted. Right: Video captured from camera 25.</p>
## Dataset Organization
The processed GeneBody dataset is organized in the following structure:
```
genebody/                    # root of dataset
├── amanda/                  # subject
│   ├── image/               # multiview images
│   │   ├── 00/              # images of 00 view
│   │   └── ...
│   ├── mask/                # multiview masks
│   │   ├── 00/              # masks of 00 view
│   │   └── ...
│   ├── param/               # smpl parameters
│   ├── smpl/                # smpl meshes in OBJ format
│   └── annots.npy           # camera parameters
├── .../
└── genebody_split.npy       # dataset splits
```
You can download the Test10 and Train40 subsets by following the [instructions](./Dataset.md#download-instructions).
## Data Interface
We provide the reference data reader `GeneBodyReader` in `genebody/genebody.py`.
### Source views
The default source view number is 4, and the source views are `[01, 13, 25, 37]`.
### Image cropping
As the human performer may appear at different sizes across views, and the foreground occupies only a very small proportion of the original image plane, directly applying image quality metrics such as PSNR, SSIM, and LPIPS to the raw images may introduce numerical ambiguity. To tackle this, in GNR and the GeneBody benchmarks we crop the performer out and resize the crop to a given resolution, as sketched below. Check the reference interface for more details.
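As a rough illustration, the following is a minimal sketch of this cropping protocol, assuming a binary foreground mask is available; the 512x512 target size is an illustrative choice, not necessarily the exact benchmark setting.
```python
import cv2
import numpy as np

def crop_and_resize(img, mask, size=512):
    """Crop the performer via the mask's bounding box, then resize."""
    ys, xs = np.where(mask > 0)
    top, bottom = ys.min(), ys.max()
    left, right = xs.min(), xs.max()
    crop = img[top:bottom + 1, left:right + 1]
    return cv2.resize(crop, (size, size), interpolation=cv2.INTER_LINEAR)

# Metrics are then computed on crop pairs rather than on raw frames, e.g.
# psnr(crop_and_resize(pred, pred_mask), crop_and_resize(gt, gt_mask)).
```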
## Camera Calibration
GeneBody provides camera calibration for each subject: the intrinsic matrix, distortion coefficients, and extrinsic parameters are stored in `annots.npy` in each subject folder. Note that we use the OpenCV camera convention (x, y, z → right, down, front).
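As a hedged sketch, the snippet below shows one way to unpack these parameters; the `'cams'` dictionary layout mirrors how `apps/render_smpl_depth.py` consumes it, and the subject path is only an example.
```python
import numpy as np

# Load the per-subject calibration file (subject path is an example).
annots = np.load('genebody/amanda/annots.npy', allow_pickle=True).item()
cams = annots['cams']
for view, cam in cams.items():
    K = np.array(cam['K'])      # 3x3 intrinsic matrix
    D = np.array(cam['D'])      # distortion coefficients
    c2w = np.array(cam['c2w'])  # camera-to-world extrinsics (OpenCV convention)
    # apps/render_smpl_depth.py inverts c2w to get world-to-camera,
    # which assumes a square 4x4 matrix.
    w2c = np.linalg.inv(c2w)
```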
## Human Segmentation
### Auto annotation
Before recording, we capture the scene without any performer and use it as the reference image in [BackgroundMatting-V2](https://github.com/PeterL1n/BackgroundMattingV2) to automatically extract the performer's foreground mask.
### Human labeling
We chose 8 camera views to manually check the results of the automatic annotation and manually labeled the bad cases. The 8 camera views are `[01, 07, 13, 19, 25, 31, 37, 43]`.
## SMPLx
GeneBody provides per-frame SMPLx estimations, storing the meshes in the `smpl` subfolder and the SMPLx parameters in the `param` subfolder. The SMPLx fitting toolbox is also provided in this [repo](https://github.com/generalizable-neural-performer/bodyfitting).
### SMPLx parameters
GeneBody provides SMPLx parameters and 3D keypoints in the `param` subfolder. More specifically, the dictionary `'smplx'` can be fed directly to the SMPLX [forward](https://github.com/vchoutas/smplx/blob/master/smplx/body_models.py#L1111) pass as long as all parameters are converted to torch tensors; an example of generating SMPLX output from the parameters is provided in the [data interface](../genebody/genebody.py#L189).
### SMPLx scale
GeneBody has a large variation in performers' ages, while the SMPLx model typically fails to fit well on kids and giants. We therefore introduce an `'smplx_scale'` factor outside the SMPLx model and jointly optimize the scale and the body model parameters during fitting. Thus, to recover the fitted mesh in the `smpl` subfolder from the parameters in the `param` subfolder, you need to multiply the vertices or 3D joints output by the body model by `'smplx_scale'`, as sketched below.
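Below is a hedged sketch of recovering a scaled mesh from the parameters; it assumes a per-frame parameter file holding the `'smplx'` dictionary and the `'smplx_scale'` scalar described above, with a placeholder file name and model path (the subject gender can be looked up in `genebody/gender.py`). The reference implementation is in the [data interface](../genebody/genebody.py).
```python
import numpy as np
import torch
import smplx

# Placeholder frame file; the exact naming inside `param/` may differ.
params = np.load('genebody/amanda/param/0000.npy', allow_pickle=True).item()
# 'amanda' is female according to genebody/gender.py.
model = smplx.create('path/to/smplx/models', model_type='smplx',
                     gender='female', use_pca=False)
# Convert every SMPLx parameter to a torch tensor before the forward pass;
# the fitted parameters are assumed to carry a leading batch dimension.
smplx_input = {k: torch.as_tensor(v, dtype=torch.float32)
               for k, v in params['smplx'].items()}
output = model(**smplx_input)
# The mesh in `smpl/` is recovered only after multiplying by the scale.
verts = output.vertices.detach().numpy()[0] * params['smplx_scale']
```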
================================================
FILE: docs/Dataset.md
================================================
# GeneBody Dataset
<p align="center"><img src="./genebody.gif" width="90%"></p>
<p align="center">Please check the <a href="https://generalizable-neural-performer.github.io/genebody.html">dataset webpage</a> for data preview</p>
## Overview
GeneBody is a new dataset that we collected to evaluate the generalization and robustness of human novel view synthesis. It consists of over *2.95M* frames in total, covering *100* subjects performing *370* sequences captured by *48* multi-view cameras, with a wide variety of pose actions, body shapes, clothing, accessories, and hairdos, whose geometry and appearance range from everyday life to professional occasions.
## Agreement
GeneBody is available for non-commercial research purposes only.
You agree not to reproduce, duplicate, copy, sell, trade, resell or exploit for any commercial purposes, any portion of the images and any portion of derived data.
You agree not to further copy, publish, or distribute any portion of GeneBody, except that making copies of the dataset for internal use at a single site within the same organization is allowed.
SenseTime Research and Shanghai AI Lab reserve the right to terminate your access to GeneBody at any time.
## Download Instructions
### Test10
You can download the *Test10* subset from Onedrive [link](https://hkustconnect-my.sharepoint.com/:f:/g/personal/wchengad_connect_ust_hk/Er9_MPspGvpKlrFbVnmwnbEBloYJN7z9D0rP4Ms8LUcYKA?e=lnjS0N)
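Alternatively, the download tool from the main README can fetch Test10 directly, for example:
```
python genebody/download_tool.py --genebody_root ${GENEBODY_ROOT} --subset test10
```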
### Train40
Please download the GeneBody Dataset Release Agreement from [link](./GeneBody_Dataset_Release_Agreement.pdf).
Read it carefully, complete and sign it appropriately.
Please send the completed form to Wei Cheng (wchengad(at)connect.ust.hk) and cc Kwan-Yee Lin (linjunyi(at)sensetime.com) using an institutional email address, with the subject "GeneBody Agreement". We will verify your request and contact you with the password to unzip the image data.
### Extended370
Coming Soon in OpenDataLab.
================================================
FILE: environment.yml
================================================
name: gnr
channels:
- pytorch
- conda-forge
- https://repo.anaconda.com/pkgs/main
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- blas=1.0=mkl
- bzip2=1.0.8=h516909a_3
- ca-certificates=2021.10.8=ha878542_0
- certifi=2021.5.30=py36h5fab9bb_0
- cffi=1.14.5=py36h261ae71_0
- cudatoolkit=9.0=h13b8566_0
- ffmpeg=4.3.1=h3215721_1
- freetype=2.10.4=h5ab3b9f_0
- gmp=6.2.1=h58526e2_0
- gnutls=3.6.13=h85f3911_1
- intel-openmp=2020.2=254
- jpeg=9b=h024ee3a_2
- lame=3.100=h14c3975_1001
- lcms2=2.11=h396b838_0
- ld_impl_linux-64=2.33.1=h53a641e_7
- libffi=3.3=he6710b0_2
- libgcc-ng=9.1.0=hdf63c60_0
- libiconv=1.16=h516909a_0
- libpng=1.6.37=hbc83047_0
- libstdcxx-ng=9.1.0=hdf63c60_0
- libtiff=4.1.0=h2733197_1
- lz4-c=1.9.3=h2531618_0
- mkl=2020.2=256
- mkl-service=2.3.0=py36he8ac12f_0
- mkl_fft=1.3.0=py36h54f3939_0
- mkl_random=1.1.1=py36h0573a6f_0
- ncurses=6.2=he6710b0_1
- nettle=3.6=he412f7d_0
- ninja=1.10.2=py36hff7bd54_0
- numpy-base=1.19.2=py36hfa32c7d_0
- olefile=0.46=py36_0
- openh264=2.1.1=h8b12597_0
- openssl=1.1.1m=h7f8727e_0
- pillow=8.1.2=py36he98fc37_0
- pip=21.0.1=py36h06a4308_0
- pycparser=2.20=py_2
- python=3.6.13=hdb3f193_0
- python_abi=3.6=2_cp36m
- pytorch=1.1.0=py3.6_cuda9.0.176_cudnn7.5.1_0
- readline=8.1=h27cfd23_0
- setuptools=52.0.0=py36h06a4308_0
- sqlite=3.35.3=hdfb4753_0
- tk=8.6.10=hbc83047_0
- torchvision=0.3.0=py36_cu9.0.176_1
- wheel=0.36.2=pyhd3eb1b0_0
- x264=1!152.20180806=h14c3975_0
- xz=5.2.5=h7b6447c_0
- zlib=1.2.11=h7b6447c_3
- zstd=1.4.9=haebb681_0
- pip:
- absl-py==0.12.0
- cachetools==4.2.1
- chardet==4.0.0
- configargparse==1.4
- cycler==0.10.0
- dataclasses==0.8
- decorator==4.4.1
- future==0.18.2
- google-auth==1.28.0
- google-auth-oauthlib==0.4.4
- grpcio==1.36.1
- idna==2.10
- imageio==2.8.0
- imageio-ffmpeg==0.4.3
- importlib-metadata==3.10.0
- kiwisolver==1.1.0
- lpips==0.1.4
- markdown==3.3.4
- matplotlib==3.1.3
- networkx==2.4
- numpy==1.18.1
- oauthlib==3.1.0
- opencv-python==4.2.0.32
- opencv-python-headless==4.2.0.34
- pathlib==1.0.1
- protobuf==3.15.6
- pyasn1==0.4.8
- pyasn1-modules==0.2.8
- pyopengl==3.1.5
- pyparsing==2.4.6
- python-dateutil==2.8.1
- pywavelets==1.1.1
- pyyaml==6.0
- requests==2.25.1
- requests-oauthlib==1.3.0
- rsa==4.7.2
- rtree==0.9.7
- scikit-image==0.16.2
- scipy==1.4.1
- shapely==1.7.0
- six==1.14.0
- smplx==0.1.28
- tensorboard==2.4.1
- tensorboard-plugin-wit==1.8.0
- tqdm==4.43.0
- trimesh==3.5.23
- typing-extensions==3.7.4.3
- unknown==0.0.0
- urllib3==1.26.4
- video2calibration==0.0.1
- werkzeug==1.0.1
- xxhash==1.4.3
- zipp==3.4.1
- gdown
================================================
FILE: genebody/download_tool.py
================================================
from ast import arg
import json, os, sys, pip, copy
from re import sub
def import_or_install(package):
try:
__import__(package)
except ImportError:
pip.main(['install', package])
import_or_install("urllib")
import_or_install("requests")
import_or_install("quickxorhash")
import_or_install("asyncio")
import_or_install("pyppeteer")
import urllib, requests, quickxorhash, asyncio
import urllib.request
from urllib import parse
from pyppeteer import launch
from tqdm import tqdm
from requests.models import codes
from requests.adapters import HTTPAdapter, Retry
import base64
import numpy as np
import argparse
import chardet
# simulate browser
header = {
'sec-ch-ua-mobile': '?0',
'upgrade-insecure-requests': '1',
'dnt': '1',
'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/90.0.4430.93 Safari/537.36 Edg/90.0.818.51',
'accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9',
'service-worker-navigation-preload': 'true',
'sec-fetch-site': 'same-origin',
'sec-fetch-mode': 'navigate',
'sec-fetch-dest': 'iframe',
'accept-language': 'zh-CN,zh;q=0.9,en;q=0.8,en-GB;q=0.7,en-US;q=0.6',
}
genebody_urls = {
"test10": "https://hkustconnect-my.sharepoint.com/:f:/g/personal/wchengad_connect_ust_hk/Er9_MPspGvpKlrFbVnmwnbEBloYJN7z9D0rP4Ms8LUcYKA",
"train40": "https://hkustconnect-my.sharepoint.com/:f:/g/personal/wchengad_connect_ust_hk/Et-eRSKLr09OjUPUTiJW0zAB1yWo1kx_qr7t1FsG2cuv2g",
"smpl_depth": "https://hkustconnect-my.sharepoint.com/:f:/g/personal/wchengad_connect_ust_hk/Eju1mH_WFLtMknBzuii8mtIBVAzxUQMCWIaLl67AGySoOA?e=MsVwJG",
"pretrained_models": "https://hkustconnect-my.sharepoint.com/:f:/g/personal/wchengad_connect_ust_hk/EpsK3TaBfGBBgtDlcWV6nxABMDT9C1qe2fE6EUxoTkQYkQ"
}
def parse_args():
"""
Args:
genebody_root: path to store download data.
force: whether to force the download.
The default is False, assuming the user just wants to update existing data.
If you have no data in the local path, please set this argument to True.
subset: data subset to download.
Please choose from ['test10', 'train40', 'smpl_depth', 'pretrained_models'].
"""
parse = argparse.ArgumentParser(description='Download Genebody data')
parse.add_argument('--genebody_root', type=str, required=True, help='path to store download data')
parse.add_argument('--force', type=bool, default=False, help='whether to force the download')
parse.add_argument('--subset', type=str, required=True, help='data subset to download')
args = parse.parse_args()
return args
def newSession():
s = requests.session()
retries = Retry(total=5, backoff_factor=0.1)
s.mount('http://', HTTPAdapter(max_retries=retries))
return s
def save_hash(path, code):
with open(path, 'w', encoding='utf-8') as f:
f.write(code)
def read_hash(path):
with open(path, 'r', encoding='utf-8') as f:
str = f.read()
return str
def checkHashes(localfile, cloud_hash, localroot, force):
"""
Synchronize local file with cloud file by hash checking
(For GeneBody, the files are on OneDrive for Business, where only "quickXorHash" is available)
ref: https://docs.microsoft.com/en-us/onedrive/developer/rest-api/resources/hashes?view=odsp-graph-online
Args:
localfile: local file path
cloud_hash: cloud hash of cloud file
Output:
matched: [True, False], whether local file matches with cloud file
"""
if not os.path.exists(os.path.dirname(localfile)):
os.makedirs(os.path.dirname(localfile), exist_ok=True)
if localfile.split('/')[1] == 'pretrained_models':
local_hash_bkup = os.path.join(localroot, '.hash', os.path.join(localfile.split('/')[2], localfile.split('/')[3].split('.')[0]+'.txt'))
else:
local_hash_bkup = os.path.join(localroot, '.hash', os.path.join(localfile.split('/')[-1].split('.')[0]+'.txt'))
if not os.path.exists(os.path.dirname(local_hash_bkup)):
os.makedirs(os.path.dirname(local_hash_bkup), exist_ok=True)
if force:
save_hash(local_hash_bkup, cloud_hash["quickXorHash"])
tqdm.write(f"force to download data")
return False
# if local file is not deleted
if os.path.exists(localfile):
with open(localfile, 'rb') as lf:
content = lf.read()
hash = quickxorhash.quickxorhash()
hash.update(content)
hashoutput = base64.b64encode(hash.digest()).decode('ascii')
# # write local hash backup
save_hash(local_hash_bkup, cloud_hash["quickXorHash"])
if hashoutput == cloud_hash["quickXorHash"]:
tqdm.write(f"[{os.path.relpath(localfile, localroot)}] local file is up-to-date, skip downloading")
return True
else:
tqdm.write(f"[{os.path.relpath(localfile, localroot)}] local file is out-of-date, updating")
return False
else:
if os.path.isfile(local_hash_bkup):
# read hashcode and compare
hashoutput = read_hash(local_hash_bkup)
if hashoutput == cloud_hash["quickXorHash"]:
tqdm.write(f"[{os.path.relpath(localfile, localroot)}] local file is up-to-date, skip downloading")
return True
else:
save_hash(local_hash_bkup, cloud_hash["quickXorHash"])
tqdm.write(f"[{os.path.relpath(localfile, localroot)}] local file is out-of-date, updating")
return False
else:
# write and record hashcode in txt file
save_hash(local_hash_bkup, cloud_hash["quickXorHash"])
tqdm.write(f"[{os.path.basename(localfile)}] no local file or local file is missing, downloading")
# sys.stdout.flush()
return False
def getFiles(originalUrl, download_path, force, download_root=None, req=None, layers=0, _id=0):
"""
Get file from folder share link (with "-my" share point url)
ref: https://docs.microsoft.com/en-us/graph/use-the-api
Args:
originalUrl: share link
download_path: path to download
req: request None
"""
isSharepoint = False
if "-my" not in originalUrl:
isSharepoint = True
if req == None:
req = newSession()
reqf = req.get(originalUrl, headers=header)
if ',"FirstRow"' not in reqf.text:
print("\t"*layers, "No file in this folder")
return 0
if download_root is None:
download_root = download_path
filesData = []
redirectURL = reqf.url
query = dict(urllib.parse.parse_qsl(
urllib.parse.urlsplit(redirectURL).query))
redirectSplitURL = redirectURL.split("/")
relativeFolder = ""
rootFolder = query["id"]
for i in rootFolder.split("/"):
if isSharepoint:
if i != "Shared Documents":
relativeFolder += i+"/"
else:
relativeFolder += i
break
else:
if i != "Documents":
relativeFolder += i+"/"
else:
relativeFolder += i
break
relativeUrl = parse.quote(relativeFolder).replace(
"/", "%2F").replace("_", "%5F").replace("-", "%2D")
rootFolderUrl = parse.quote(rootFolder).replace(
"/", "%2F").replace("_", "%5F").replace("-", "%2D")
graphqlVar = '{"query":"query (\n $listServerRelativeUrl: String!,$renderListDataAsStreamParameters: RenderListDataAsStreamParameters!,$renderListDataAsStreamQueryString: String!\n )\n {\n \n legacy {\n \n renderListDataAsStream(\n listServerRelativeUrl: $listServerRelativeUrl,\n parameters: $renderListDataAsStreamParameters,\n queryString: $renderListDataAsStreamQueryString\n )\n }\n \n \n perf {\n executionTime\n overheadTime\n parsingTime\n queryCount\n validationTime\n resolvers {\n name\n queryCount\n resolveTime\n waitTime\n }\n }\n }","variables":{"listServerRelativeUrl":"%s","renderListDataAsStreamParameters":{"renderOptions":5707527,"allowMultipleValueFilterForTaxonomyFields":true,"addRequiredFields":true,"folderServerRelativeUrl":"%s"},"renderListDataAsStreamQueryString":"@a1=\'%s\'&RootFolder=%s&TryNewExperienceSingle=TRUE"}}' % (relativeFolder, rootFolder, relativeUrl, rootFolderUrl)
s2 = urllib.parse.urlparse(redirectURL)
tempHeader = copy.deepcopy(header)
tempHeader["referer"] = redirectURL
tempHeader["cookie"] = reqf.headers["set-cookie"]
tempHeader["authority"] = s2.netloc
tempHeader["content-type"] = "application/json;odata=verbose"
graphqlReq = req.post(
"/".join(redirectSplitURL[:-3])+"/_api/v2.1/graphql", data=graphqlVar.encode('utf-8'), headers=tempHeader)
graphqlReq = json.loads(graphqlReq.text)
if "NextHref" in graphqlReq["data"]["legacy"]["renderListDataAsStream"]["ListData"]:
nextHref = graphqlReq[
"data"]["legacy"]["renderListDataAsStream"]["ListData"]["NextHref"]+"&@a1=%s&TryNewExperienceSingle=TRUE" % (
"%27"+relativeUrl+"%27")
filesData.extend(graphqlReq[
"data"]["legacy"]["renderListDataAsStream"]["ListData"]["Row"])
listViewXml = graphqlReq[
"data"]["legacy"]["renderListDataAsStream"]["ViewMetadata"]["ListViewXml"]
renderListDataAsStreamVar = '{"parameters":{"__metadata":{"type":"SP.RenderListDataParameters"},"RenderOptions":1216519,"ViewXml":"%s","AllowMultipleValueFilterForTaxonomyFields":true,"AddRequiredFields":true}}' % (
listViewXml).replace('"', '\\"')
graphqlReq = req.post(
"/".join(redirectSplitURL[:-3])+"/_api/web/GetListUsingPath(DecodedUrl=@a1)/RenderListDataAsStream"+nextHref, data=renderListDataAsStreamVar.encode('utf-8'), headers=tempHeader)
graphqlReq = json.loads(graphqlReq.text)
while "NextHref" in graphqlReq["ListData"]:
nextHref = graphqlReq["ListData"]["NextHref"]+"&@a1=%s&TryNewExperienceSingle=TRUE" % (
"%27"+relativeUrl+"%27")
filesData.extend(graphqlReq["ListData"]["Row"])
graphqlReq = req.post(
"/".join(redirectSplitURL[:-3])+"/_api/web/GetListUsingPath(DecodedUrl=@a1)/RenderListDataAsStream"+nextHref, data=renderListDataAsStreamVar.encode('utf-8'), headers=tempHeader)
graphqlReq = json.loads(graphqlReq.text)
filesData.extend(graphqlReq["ListData"]["Row"])
else:
filesData.extend(graphqlReq[
"data"]["legacy"]["renderListDataAsStream"]["ListData"]["Row"])
filesData = sorted(filesData, key=lambda x: x['FileLeafRef'])
for i in filesData:
# if is a folder, download recursively
if i['FSObjType'] == "1":
_query = query.copy()
_query['id'] = os.path.join(_query['id'], i['FileLeafRef']).replace("\\", "/")
if not isSharepoint:
originalPath = "/".join(redirectSplitURL[:-1]) + \
"/onedrive.aspx?" + urllib.parse.urlencode(_query)
else:
originalPath = "/".join(redirectSplitURL[:-1]) + \
"/AllItems.aspx?" + urllib.parse.urlencode(_query)
getFiles(originalPath, os.path.join(download_path, i['FileLeafRef']), force, download_root, req=req, layers=layers+1)
# if is a file, download directly
else:
reqf = req.get(i[".spItemUrl"], headers=header)
filemeta = json.loads(reqf.text)
url, name, hash = filemeta["@content.downloadUrl"], filemeta["name"], filemeta["file"]["hashes"]
r = requests.get(url, stream = True)
total_length = int(r.headers.get('content-length', 0))
local_file = os.path.join(download_path, name)
# Check hash code of local file and cloud file
if not checkHashes(local_file, hash, download_root, force):
with open(os.path.join(download_path, name), 'wb') as f, \
tqdm(desc=os.path.relpath(local_file, download_root),total=total_length,
unit='iB',unit_scale=True,unit_divisor=1024) as bar:
for chunk in r.iter_content(chunk_size = 1024):
if chunk:
size = f.write(chunk)
bar.update(size)
pheader = {}
url = ""
async def fetch_with_pwd(iurl, password):
"""
Fetch data with data password
Args:
iurl: input share folder url
password: password of the share folder
"""
global pheader, url
browser = await launch(options={'args': ['--no-sandbox']})
page = await browser.newPage()
await page.goto(iurl, {'waitUntil': 'networkidle0'})
await page.focus("input[id='txtPassword']")
await page.keyboard.type(password)
verityElem = await page.querySelector("input[id='btnSubmitPassword']")
print("Password input complete, jumping")
await asyncio.gather(
page.waitForNavigation(),
verityElem.click(),
)
url = await page.evaluate('window.location.href', force_expr=True)
await page.screenshot({'path': 'example.png'})
print("Fetching cookies")
_cookie = await page.cookies()
pheader = ""
for __cookie in _cookie:
coo = "{}={};".format(__cookie.get("name"), __cookie.get("value"))
pheader += coo
await browser.close()
def havePwdGetFiles(iurl, password, download_path, force):
global header
asyncio.get_event_loop().run_until_complete(fetch_with_pwd(iurl, password))
header['cookie'] = pheader
getFiles(url, download_path, force)
def extractFiles(path, subset):
"""
Extract download data
Args:
path: path contains data to extract
"""
all_files = os.listdir(path)
for file in all_files:
if not file.endswith('.gz'):
continue
file_path = os.path.join(path, file)
extract_cmd = f'tar -xvf {file_path} -C {path}'
os.system(extract_cmd)
rm_cmd = f'rm {file_path}'
os.system(rm_cmd)
if subset == 'smpl_depth':
src_path = os.path.join(path, 'GeneBody')
mv_cmd = f'rsync -av {src_path}/* {path}'
os.system(mv_cmd)
rm_cmd = f'rm -rf {src_path}'
os.system(rm_cmd)
def moveFiles(root):
"""
Move pretrained models to benchmark directory
Args:
root: path contains download pretrained models
"""
os.makedirs('benchmarks', exist_ok=True)
models = os.listdir(os.path.join(root, 'pretrained_models'))
for model in models:
src_path = os.path.join(root, 'pretrained_models', model)
if model != 'gnr':
cmd = f'mv {src_path} benchmarks'
os.system(cmd)
else:
os.makedirs('./logs/genebody', exist_ok=True)
os.system(f'mv {src_path}/* ./logs/genebody/')
pretrain_path = os.path.join(root, 'pretrained_models')
rm_cmd = f'rm -rf {pretrain_path}'
os.system(rm_cmd)
if __name__ == "__main__":
args = parse_args()
genebody_root = args.genebody_root
force = args.force
    subsets = [set_ for set_ in ['test10', 'train40', 'smpl_depth', 'pretrained_models'] if set_ in args.subset]
for subset in subsets:
if subset == 'train40':
            pwd = input("Please input the Train40 password, or contact the authors for data access: \n")
havePwdGetFiles(genebody_urls[subset], pwd, genebody_root, force)
elif subset == 'pretrained_models':
getFiles(genebody_urls[subset], os.path.join(genebody_root, 'pretrained_models'), force)
else:
getFiles(genebody_urls[subset], genebody_root, force)
if subset == 'pretrained_models':
moveFiles(genebody_root)
else:
extractFiles(genebody_root, subset)
================================================
FILE: genebody/gender.py
================================================
genebody_gender = {
"abror": "male",
"ahha": "female",
"alejandro": "male",
"amanda": "female",
"amaris": "female",
"anastasia": "female",
"aosilan": "male",
"arslan": "male",
"barlas": "male",
"barry": "male",
"camilo": "male",
"dannier": "male",
"dilshod": "male",
"fenghaohan": "male",
"fuzhizhi": "female",
"fuzhizhi2": "female",
"gaoxing": "female",
"huajiangtao3": "male",
"huajiangtao5": "male",
"ivan": "male",
"jinyutong": "female",
"jinyutong2": "female",
"joseph_matanda": "female",
"kamal_ejaz": "male",
"kemal": "male",
"lihongyun": "female",
"mahaoran": "male",
"maria": "female",
"natacha": "female",
"quyuanning": "female",
"rivera": "male",
"shchyerbina_oleksandrsongyujie": "male",
"soufianou_boubacar_moumouni": "male",
"sunyuxing": "male",
"Tichinah_jervier": "female",
"wangxiang": "male",
"wuwenyan": "male",
"xujiarui": "female",
"yaoqibin": "male",
"zhanghao": "male",
"zhanghongwei": "female",
"zhangzixiao": "female",
"zhengxin": "female",
"zhonglantai": "female",
"zhuna": "female",
"zhuna2": "female",
"zhuxuezhi": "male",
"songyujie": "male",
"rabbi": "male",
"zhangziyu": "female"
}
================================================
FILE: genebody/genebody.py
================================================
import os, sys
import numpy as np
import cv2, imageio
from .mesh import load_ply, load_obj_mesh, write_obj_mesh
import torch
from .gender import genebody_gender
def image_cropping(mask, padding=0.1):
"""
    To better evaluate different metrics on rendered images, we crop out the human performer and resize the cropped
    image to the same resolution. This function returns the bounding box of the human performer given the mask.
    mask: np.ndarray of mask
padding: padding of the bounding box
"""
a = np.where(mask != 0)
h, w = list(mask.shape[:2])
if len(a[0]) > 0: # valid mask
top, left, bottom, right = np.min(a[0]), np.min(a[1]), np.max(a[0]), np.max(a[1])
else: # mask failure
return 0,0,mask.shape[0],mask.shape[1]
bbox_h, bbox_w = bottom - top, right - left
    # pad bbox
    bottom = min(int(bbox_h*padding+bottom), h)
    top = max(int(top-bbox_h*padding), 0)
    right = min(int(bbox_w*padding+right), w)
    left = max(int(left-bbox_w*padding), 0)
bbox_h, bbox_w = bottom - top, right - left
bbox_h = min(bbox_h, h, w)
bbox_w = min(bbox_w, h, w)
if bbox_h >= bbox_w:
w_c = (left+right) / 2
size = bbox_h
if w_c - size / 2 < 0:
left = 0
right = size
elif w_c + size / 2 >= w:
left = w - size
right = w
else:
left = int(w_c - size / 2)
right = left + size
h_c = (top+bottom) / 2
top = int(h_c - size / 2)
bottom = top + size
else: # bbox_w >= bbox_h
h_c = (top+bottom) / 2
size = bbox_w
if h_c - size / 2 < 0:
top = 0
bottom = size
elif h_c + size / 2 >= h:
top = h - size
bottom = h
else:
top = int(h_c - size / 2)
bottom = top + size
w_c = (left+right) / 2
left = int(w_c - size / 2)
right = left + size
return top, left, bottom, right
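
def _demo_image_cropping():
    # A minimal sanity-check sketch (not part of the original repo): the
    # returned crop is always square, so it can be resized to a fixed square
    # resolution without stretching the performer.
    mask = np.zeros((100, 80), dtype=np.uint8)
    mask[30:60, 20:50] = 255
    top, left, bottom, right = image_cropping(mask)
    assert bottom - top == right - left  # square crop around the foreground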
class GeneBodyReader():
def __init__(self, rootdir, loadsize=512):
self.rootdir = rootdir
self.split = np.load(os.path.join(rootdir, 'genebody_split.npy'), allow_pickle=True).item()
self.loadsize = loadsize
        # the default setting of GNR is to use these four source views of GeneBody
self.sourceviews = ['01', '13', '25', '37']
self.gender = genebody_gender
self.rawsize = (2448, 2048)
def get_views(self, subject):
"""
Returns valid camera views of each sequence.
        Note that several subjects in GeneBody have missing views. More specifically,
"Tichinah_jervier" misses [32],
"wuwenyan" misses [34, 36],
"joseph_matanda" misses [39, 40, 42, 43, 44, 45, 46, 47]
subject: name of subject
all_views: all valid views of this subject
"""
all_views = sorted(os.listdir(os.path.join(self.rootdir, subject, 'image')))
## alt
# all_views = sorted(np.load(os.path.join(self.rootdir, subject, 'annots.npy'), allow_pickle=True).item()['cams'].keys())
return all_views
def get_frames(self, subject):
frame_list = []
frame_list = os.listdir(os.path.join(self.rootdir, subject, 'image', '00'))
frame_list = sorted([frame[:-4] for frame in frame_list])
return frame_list
def get_cameras(self, subject):
return np.load(os.path.join(self.rootdir, subject, 'annots.npy'), allow_pickle=True).item()['cams']
def get_smpl(self, subject, frame):
"""
Returns the smpl vertices and faces
        frame: one frame of the subject <- self.get_frames(subject)[frame_id]
"""
smpl_path = os.path.join(self.rootdir, subject, 'smpl', frame+'.obj')
vert, face = load_obj_mesh(smpl_path)
return vert, face
def get_smpl_param(self, subject, frame):
"""
Returns the smpl parameters and smpl scale
        frame: one frame of the subject <- self.get_frames(subject)[frame_id]
"""
param_path = os.path.join(self.rootdir, subject, 'param', frame+'.npy')
# global_orient and pose are Rodrigues rotation vector
param = np.load(param_path, allow_pickle=True).item()
        # the smpl_param is a dictionary of smplx parameters which can be directly passed to a SMPLX forward pass
# via SMPLXLayer(**smpl_param) if each value of it is converted to torch.Tensor
smpl_param = param["smplx"]
for key in smpl_param.keys():
if isinstance(smpl_param[key], torch.Tensor):
smpl_param[key] = smpl_param[key].numpy()
        # For GeneBody, we fit human performers in a wide age range, and SMPLx cannot fit well on kids and giants,
        # so we apply an extra smplx_scale outside the SMPLX model via direct scaling.
        # You can recover the smplx mesh in the 'smpl' directory via SMPLX(**smpl_param) * smpl_scale
smpl_scale = param["smplx_scale"]
return smpl_param, smpl_scale
def get_data(self, subject, frame, camera_params, views):
"""
        Fetch one frame of multiview data from the database, with cropping
        subject: name of subject
        frame: one frame of the subject <- self.get_frames(subject)[frame_id]
        camera_params: camera parameters <- self.get_cameras(subject)
        views: list of views to fetch, eg. load source views through self.sourceviews,
               or all views through self.get_views(subject)
"""
subject_dir = os.path.join(self.rootdir, subject)
Ks, c2ws, Ds, images, masks = [], [], [], [], []
for view in views:
img = imageio.imread(os.path.join(subject_dir, 'image', view, frame+'.jpg'))
msk = imageio.imread(os.path.join(subject_dir, 'mask', view, f'mask{frame}.png'))
# crop the human out from raw image
top, left, bottom, right = image_cropping(msk)
img = img * (msk > 128).astype(np.uint8)[...,None]
# resize to uniform resolution
            img = cv2.resize(img[top:bottom, left:right].copy(), (self.loadsize, self.loadsize), interpolation=cv2.INTER_CUBIC)
images.append(img)
            msk = cv2.resize(msk[top:bottom, left:right].copy(), (self.loadsize, self.loadsize), interpolation=cv2.INTER_NEAREST)
masks.append(msk)
            # adjust the camera intrinsic parameters because of the cropping and resizing
            # Note that there is no need to adjust the extrinsics or distortion coefficients
K, c2w, D = camera_params[view]['K'].copy(), camera_params[view]['c2w'].copy(), camera_params[view]['D'].copy()
K[0,2] -= left
K[1,2] -= top
K[0,:] *= self.loadsize / float(right - left)
K[1,:] *= self.loadsize / float(bottom - top)
Ks.append(K)
c2ws.append(c2w)
Ds.append(D)
return images, masks, Ks, c2ws, Ds
def get_near_far(self, verts, c2w, pad=0.5):
"""
        Get the near/far planes of the perspective projection from the SMPL estimation
        verts: SMPLx vertices
        c2w: camera-to-world transformation matrix
        pad: near/far padding around the SMPLx depth range; set smaller for a tighter bound, larger if the accessories are huge.
"""
w2c = np.linalg.inv(c2w)
# Transform SMPLx to camera coordinate
vp = verts.dot(w2c[:3,:3].T) + w2c[:3,3:].T
vmin, vmax = vp.min(0), vp.max(0)
# near far are minmax in z axis
near, far = vmin[2], vmax[2]
near, far = near-(far-near)*pad, far+(far-near)*pad
return near, far
def smpl_from_param(self, model_path, subject, smpl_param, smpl_scale):
import smplx
smpl = smplx.SMPLX(
model_path=model_path,
gender=self.gender[subject],
use_pca=False,
)
smpl_param = smpl_param.copy()
for key in smpl_param.keys():
if isinstance(smpl_param[key], np.ndarray):
smpl_param[key] = torch.from_numpy(smpl_param[key])
output = smpl(**smpl_param)
verts = output['vertices'].numpy().reshape(-1,3)
# To align with keypoints3d saved in param, use the base keypoints only,
# if you want to use the full keypoints3d with extra joints and landmarks,
        # please refer to the definition of joints in
# vertex_joint_selector.py and landmarks in vertices2landmarks in lbs.py
keypoints3d = output['joints'].numpy().reshape(-1,3)[:55]
return verts*smpl_scale, smpl.faces, keypoints3d*smpl_scale
if __name__ == "__main__":
    ## Here is an example
# python genebody/genebody.py path_to_genebody fuzhizhi
root = sys.argv[1]
subject = sys.argv[2]
genebody = GeneBodyReader(root)
    print(subject, 'is in the', 'training set' if subject in genebody.split['train'] else 'test set')
views = genebody.get_views(subject)
frames = genebody.get_frames(subject)
camera_params = genebody.get_cameras(subject)
frame = frames[9]
imgs, msks, Ks, c2ws, Ds = genebody.get_data(subject, frame, camera_params, genebody.sourceviews)
    print(f'loaded {len(imgs)} views of images and masks in size of ', list(imgs[0].shape))
verts, faces = genebody.get_smpl(subject, frame)
smpl_param, smpl_scale = genebody.get_smpl_param(subject, frame)
print('mesh with size ', verts.shape, ' body scale ', smpl_scale)
near, far = genebody.get_near_far(verts, c2ws[0])
print('the near far is ', near, far)
    # to test the smplx-parameter-to-smplx-mesh conversion, please try the following command
# python genebody/genebody.py path_to_genebody fuzhizhi path_to_smplx
if len(sys.argv) == 4:
import trimesh
smplx_path = sys.argv[3]
verts_from_param, faces_from_param, kpts_from_param = genebody.smpl_from_param(smplx_path, subject, smpl_param, smpl_scale)
print('average error of parameter generated smplx is ', np.abs(verts_from_param-verts).mean())
print('average error of parameter generated keypoints is ', np.abs(smpl_param['keypoints3d'].reshape(-1,3)-kpts_from_param).mean())
print(verts.min(0), verts.max(0), kpts_from_param.min(0), kpts_from_param.max(0))
smpl_mesh = trimesh.Trimesh(verts_from_param, faces_from_param)
smpl_mesh.export(f'{subject}.obj')
================================================
FILE: genebody/mesh.py
================================================
import numpy as np
import struct
import sys, os, re
import cv2
if sys.version_info[0] == 3:
from functools import reduce
# type: (regexPattern, structPackChar, printfFormat, byteSize, numpyType, isFloat[, isCanonical])
decode_map = {
'bool': ('([01])', '?', '%d', 1, np.dtype('bool'), False),
'uchar':('([0-9]{1,3})', 'B', '%d', 1, np.uint8, False, 1),
'uint8':('([0-9]{1,3})', 'B', '%d', 1, np.uint8, False),
'byte': ('([0-9]{1,3})', 'B', '%d', 1, np.uint8, False),
'unsigned char':('([0-9]{1,3})', 'B', '%d', 1, np.uint8, False),
'char': ('(-?[0-9]{1,3})', 'b', '%d', 1, np.int8, False, 1),
'int8': ('(-?[0-9]{1,3})', 'b', '%d', 1, np.int8, False),
'ushort': ('([0-9]{1,5})', 'H', '%d', 2, np.uint16, False, 1),
'uint16': ('([0-9]{1,5})', 'H', '%d', 2, np.uint16, False),
'unsigned short': ('([0-9]{1,5})', 'H', '%d', 2, np.uint16, False),
'short': ('(-?[0-9]{1,5})', 'h', '%d', 2, np.int16, False, 1),
'int16': ('(-?[0-9]{1,5})', 'h', '%d', 2, np.int16, False),
'half': ('(-?[0-9]*\.?[0-9]*[eE]?[-\+]?[0-9]*)', 'e', '%f', 2, np.float16, True, 1),
'float16':('(-?[0-9]*\.?[0-9]*[eE]?[-\+]?[0-9]*)', 'e', '%f', 2, np.float16, True),
'uint': ('([0-9]{1,10})', 'I', '%u', 4, np.uint32, False, 1),
'uint32':('([0-9]{1,10})', 'I', '%u', 4, np.uint32, False),
'ulong': ('([0-9]{1,10})', 'I', '%u', 4, np.uint32, False),
'unsigned':('([0-9]{1,10})', 'I', '%u', 4, np.uint32, False),
'unsigned int':('([0-9]{1,10})', 'I', '%u', 4, np.uint32, False),
'unsigned long':('([0-9]{1,10})', 'I', '%u', 4, np.uint32, False),
'int': ('(-?[0-9]{1,10})', 'i', '%d', 4, np.int32, False, 1),
'long': ('(-?[0-9]{1,10})', 'i', '%d', 4, np.int32, False),
'int32':('(-?[0-9]{1,10})', 'i', '%d', 4, np.int32, False),
'float':('(-?[0-9]*\.?[0-9]*[eE]?[-\+]?[0-9]*)', 'f', '%f', 4, np.float32, True, 1),
'single':('(-?[0-9]*\.?[0-9]*[eE]?[-\+]?[0-9]*)', 'f', '%f', 4, np.float32, True),
'float32':('(-?[0-9]*\.?[0-9]*[eE]?[-\+]?[0-9]*)', 'f', '%f', 4, np.float32, True),
'uint64':('([0-9]{1,20})', 'Q', '%lu', 8, np.uint64, False, 1),
'ullong':('([0-9]{1,20})', 'Q', '%lu', 8, np.uint64, False),
'unsigned long long':('([0-9]{1,20})', 'Q', '%lu', 8, np.uint64, False),
'int64':('(-?[0-9]{1,19})', 'q', '%ld', 8, np.int64, False, 1),
'llong':('(-?[0-9]{1,19})', 'q', '%ld', 8, np.int64, False),
    'long long':('(-?[0-9]{1,19})', 'q', '%ld', 8, np.int64, False),
'double':('(-?[0-9]*\.?[0-9]*[eE]?[-\+]?[0-9]*)', 'd', '%f', 8, np.float64, True, 1),
'float64':('(-?[0-9]*\.?[0-9]*[eE]?[-\+]?[0-9]*)', 'd', '%f', 8, np.float64, True),
}
def max_precision(type1, type2):
if decode_map[type1][5]:
if decode_map[type2][5]:
if decode_map[type1][3] < decode_map[type2][3]:
return type2
else:
return type1
else:
if decode_map[type1][3] < decode_map[type2][3]:
for t, c in decode_map.items():
if c[5] and c[3] >= decode_map[type2][3]:
return t
else:
return type1
elif decode_map[type2][5]:
if decode_map[type2][3] < decode_map[type1][3]:
for t, c in decode_map.items():
if c[5] and c[3] >= decode_map[type1][3]:
return t
else:
return type2
else:
if decode_map[type2][3] < decode_map[type1][3]:
return type1
        elif decode_map[type1][3] < decode_map[type2][3]:
            return type2
        elif decode_map[type1][4] == decode_map[type2][4]:
return type1
else:
for t, c in decode_map.items():
if not c[5] and c[3] > decode_map[type1][3]:
return t
return 'unsigned long long'
def decode(content, structure, num, form):
if form.lower() == 'ascii':
l = 0; d = []; lines = content.split('\n')
for i in range(num):
s = [j for j in lines[i].split(' ') if len(j) > 0]
l += len(lines[i]) + 1; k = []; j = 0
while j < len(s) and len(k) < len(structure):
t = structure[len(k)]
if t[:4] == 'list':
n = int(s[j]); t = t.split(':')[-1]
k += [[float(s[i]) if decode_map[t][5] else int(s[i]) \
for i in range(j+1,j+n+1)]]
j += n + 1
else:
k += [float(s[j]) if decode_map[t][5] else int(s[j])]
j += 1
d += [k]
else:
if form.lower() == 'binary_little_endian':
c = '<'
elif form.lower() == 'binary_big_endian':
c = '>'
l = 0; d = []
for i in range(num):
k = []
while len(k) < len(structure) and l < len(content):
t = structure[len(k)]
if t[:4] == 'list':
t = t.split(':')
n = struct.unpack(c+decode_map[t[1]][1], \
content[l:l+decode_map[t[1]][3]])[0]
l += decode_map[t[1]][3]
k += [struct.unpack(c+decode_map[t[2]][1]*n, \
content[l:l+decode_map[t[2]][3]*n])]
l += decode_map[t[2]][3]*n
else:
k += [struct.unpack(c+decode_map[t][1], \
content[l:l+decode_map[t][3]])[0]]
l += decode_map[t][3]
d += [k]
try:
t = reduce(max_precision, [t if t[:4] != 'list' \
else t.split(':')[-1] for t in structure])
d = np.array(d, dtype = decode_map[t][4])
except ValueError:
print('Warning: Not in Matrix')
return d, content[l:]
def load_ply(file_name):
try:
with open(file_name, 'r') as f:
head = f.readline().strip()
if head.lower() != 'ply':
                raise ValueError('Not a valid PLY file')
content = f.read()
i = content.find('end_header\n')
if i < 0:
                raise ValueError('Not a valid PLY file')
info = [[l for l in line.split(' ') if len(l) > 0] \
for line in content[:i].split('\n')]
content = content[i+11:]
except UnicodeDecodeError as e:
with open(file_name, 'rb') as f:
head = f.readline().strip()
if sys.version_info[0] == 3:
head = str(head)[2:-1]
else:
head = str(head)
if head.lower() != 'ply':
                raise ValueError('Not a valid PLY file')
content = f.read()
i = content.find(b'end_header\n')
if i < 0:
                raise ValueError('Not a valid PLY file')
if sys.version_info[0] == 3:
cnt = str(content[:i])[2:-1].replace('\\n', '\n')
else:
cnt = str(content[:i])
info = [[l for l in line.split(' ') if len(l) > 0] \
for line in cnt.split('\n')]
content = content[i+11:]
form = 'ascii'
elem_names = []
elem = {}
for i in info:
if len(i) >= 2 and i[0] == 'format':
form = i[1]
elif len(i) >= 3 and i[0] == 'element':
if len(elem_names) > 0:
elem[elem_names[-1]] = (structure_name, structure)
elem_names += [(i[1], int(i[2]))]
structure_name = []
structure = []
elif len(i) >= 3 and i[0] == 'property' and len(elem_names) > 0:
structure_name += [i[-1]]
if i[1] == 'list' and len(i) >= 5:
structure += [i[1] + ':' + i[2] + ':' + ' '.join(i[3:-1])]
else:
structure += [' '.join(i[1:-1])]
if len(elem_names) > 0:
elem[elem_names[-1]] = (structure_name, structure)
elem_ = {}
for k in elem_names:
d, content = decode(content, elem[k][1], k[1], form)
if 'face' in k[0] and isinstance(d, np.ndarray):
d = d.reshape((k[1], -1))
elem_[k[0]] = d # elem[k] = (elem[k][0], d)
return elem_
def save_ply(file_name, elems, _type = 'binary_little_endian', comments = []):
_type = _type.lower()
types = {}
if isinstance(comments, str):
comments = [comments]
comments = [c for l in comments for c in l.split('\n')]
with open(file_name, 'w') as f:
f.write('ply\nformat %s 1.0\n' % _type)
for comment in comments:
f.write('comment %s\n' % comment)
for key, elem in elems.items():
f.write('element %s %d\n' % (key, len(elem)))
if isinstance(elem, np.ndarray):
for e, c in decode_map.items():
if len(c) > 6 and c[4] == elem.dtype:
c = (e, c[1], c[2]); break
else:
c = ('int', 'i', '%d')
                for e in elem:
                    if hasattr(e, '__len__'):
                        if any(int(v) != v for v in e):
                            c = ('float','f','%f'); break
                    elif int(e) != e:
                        c = ('float','f','%f'); break
if 'face' in key:
tag = 'vertex_index'
max_num = max([len(e) for e in elem])
if max_num < 256:
l = ('uchar', 'B', '%d')
elif max_num < 65536:
l = ('ushort', 'H', '%d')
elif max_num < 4294967296:
l = ('uint', 'I', '%d')
else:
l = ('uint64', 'Q', '%d')
f.write('property list %s %s %s\n' % \
(l[0], c[0], tag))
types[key] = (l, c)
else:
if 'vert' in key:
if len(elem) > 0:
if len(elem[0]) > 4:
tag = ['x', 'y', 'z', 'red', 'green', 'blue', 'alpha']
else:
tag = ['x', 'y', 'z', 'w']
                elif 'norm' in key or 'texcoord' in key:
tag = ['x', 'y', 'z', 'w']
elif 'color' in key:
tag = ['red', 'green', 'blue', 'alpha']
else:
tag = []
if len(elem) > 0:
for j in range(len(elem[0])):
f.write('property %s %s\n' % (c[0], tag[j] \
if len(tag) > j else 'k%d'%(j-len(tag))))
types[key] = c
f.write('end_header\n')
with open(file_name, 'ab' if 'binary' in _type else 'a') as f:
for key, elem in elems.items():
c = types[key]
if len(c) == 2:
for e in elem:
l = len(e)
if 'ascii' in _type:
f.write((c[0][2]+(' '+c[1][2])*l+'\n') % \
tuple([l]+list(e)))
elif 'little' in _type:
f.write(struct.pack('<'+c[0][1], l))
f.write(struct.pack('<'+c[1][1]*l, *e))
elif 'big' in _type:
f.write(struct.pack('>'+c[0][1], l))
f.write(struct.pack('>'+c[1][1]*l, *e))
else:
seg = len(elem[0]) if len(elem) > 0 else 1
elem = [i for l in elem for i in l] \
if isinstance(elem, np.ndarray) else elem.reshape(-1)
if 'ascii' in _type:
for i in range(len(elem)):
f.write((c[2] + '\n' if (i + 1) % seg == 0 \
else c[2] + ' ') % elem[i])
elif 'little' in _type:
f.write(struct.pack('<'+c[1]*len(elem), *elem))
elif 'big' in _type:
f.write(struct.pack('>'+c[1]*len(elem), *elem))
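
def _demo_ply_roundtrip():
    # A minimal sanity-check sketch (not part of the original repo): write a
    # single triangle with save_ply in ascii mode and read it back with
    # load_ply; vertices and faces should survive the round trip.
    import tempfile
    verts = np.array([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.]], dtype=np.float32)
    faces = np.array([[0, 1, 2]], dtype=np.int32)
    path = os.path.join(tempfile.gettempdir(), 'gnr_demo_triangle.ply')
    save_ply(path, {'vertex': verts, 'face': faces}, _type='ascii')
    elems = load_ply(path)
    assert np.allclose(elems['vertex'][:, :3], verts)
    assert (elems['face'] == faces).all()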
def normalize_v3(arr):
''' Normalize a numpy array of 3 component vectors shape=(n,3) '''
lens = np.sqrt(arr[:, 0] ** 2 + arr[:, 1] ** 2 + arr[:, 2] ** 2)
eps = 0.00000001
lens[lens < eps] = eps
arr[:, 0] /= lens
arr[:, 1] /= lens
arr[:, 2] /= lens
return arr
def compute_normal(vertices, faces):
# Create a zeroed array with the same type and shape as our vertices i.e., per vertex normal
norm = np.zeros(vertices.shape, dtype=vertices.dtype)
# Create an indexed view into the vertex array using the array of three indices for triangles
tris = vertices[faces]
# Calculate the normal for all the triangles, by taking the cross product of the vectors v1-v0, and v2-v0 in each triangle
n = np.cross(tris[::, 1] - tris[::, 0], tris[::, 2] - tris[::, 0])
    # n is now an array of normals per triangle. The length of each normal depends on the triangle geometry;
    # we need to normalize these, so that our next step weights each normal equally.
normalize_v3(n)
# now we have a normalized array of normals, one per triangle, i.e., per triangle normals.
# But instead of one per triangle (i.e., flat shading), we add to each vertex in that triangle,
# the triangles' normal. Multiple triangles would then contribute to every vertex, so we need to normalize again afterwards.
# The cool part, we can actually add the normals through an indexed view of our (zeroed) per vertex normal array
norm[faces[:, 0]] += n
norm[faces[:, 1]] += n
norm[faces[:, 2]] += n
normalize_v3(norm)
return norm
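
def _demo_compute_normal():
    # A minimal sanity-check sketch (not part of the original repo): the
    # per-vertex normals of a single triangle lying in the xy-plane all
    # point along +z.
    verts = np.array([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.]])
    faces = np.array([[0, 1, 2]])
    assert np.allclose(compute_normal(verts, faces), [0., 0., 1.])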
def load_obj_mesh(mesh_file, with_normal=False, with_texture=False, with_texture_image=False):
vertex_data = []
norm_data = []
uv_data = []
face_data = []
face_norm_data = []
face_uv_data = []
if isinstance(mesh_file, str):
f = open(mesh_file, "r")
else:
f = mesh_file
for line in f:
if isinstance(line, bytes):
line = line.decode("utf-8")
if line.startswith('#'):
continue
values = line.split()
if not values:
continue
if values[0] == 'v':
v = list(map(float, values[1:4]))
vertex_data.append(v)
elif values[0] == 'vn':
vn = list(map(float, values[1:4]))
norm_data.append(vn)
elif values[0] == 'vt':
vt = list(map(float, values[1:3]))
uv_data.append(vt)
elif values[0] == 'f':
# quad mesh
if len(values) > 4:
f = list(map(lambda x: int(x.split('/')[0]), values[1:4]))
face_data.append(f)
f = list(map(lambda x: int(x.split('/')[0]), [values[3], values[4], values[1]]))
face_data.append(f)
# tri mesh
else:
f = list(map(lambda x: int(x.split('/')[0]), values[1:4]))
face_data.append(f)
# deal with texture
if len(values[1].split('/')) >= 2:
# quad mesh
if len(values) > 4:
f = list(map(lambda x: int(x.split('/')[1]), values[1:4]))
face_uv_data.append(f)
f = list(map(lambda x: int(x.split('/')[1]), [values[3], values[4], values[1]]))
face_uv_data.append(f)
# tri mesh
elif len(values[1].split('/')[1]) != 0:
f = list(map(lambda x: int(x.split('/')[1]), values[1:4]))
face_uv_data.append(f)
# deal with normal
if len(values[1].split('/')) == 3:
# quad mesh
if len(values) > 4:
f = list(map(lambda x: int(x.split('/')[2]), values[1:4]))
face_norm_data.append(f)
f = list(map(lambda x: int(x.split('/')[2]), [values[3], values[4], values[1]]))
face_norm_data.append(f)
# tri mesh
elif len(values[1].split('/')[2]) != 0:
f = list(map(lambda x: int(x.split('/')[2]), values[1:4]))
face_norm_data.append(f)
elif 'mtllib' in line.split():
mtlname = line.split()[-1]
mtlfile = os.path.join(os.path.dirname(mesh_file), mtlname)
with open(mtlfile, 'r') as fmtl:
mtllines = fmtl.readlines()
for mtlline in mtllines:
# if mtlline.startswith('map_Kd'):
if 'map_Kd' in mtlline.split():
texname = mtlline.split()[-1]
texfile = os.path.join(os.path.dirname(mesh_file), texname)
texture_image = cv2.imread(texfile)
texture_image = cv2.cvtColor(texture_image, cv2.COLOR_BGR2RGB)
break
vertices = np.array(vertex_data)
faces = np.array(face_data) - 1
if with_texture and with_normal:
uvs = np.array(uv_data)
face_uvs = np.array(face_uv_data) - 1
norms = np.array(norm_data)
if norms.shape[0] == 0:
norms = compute_normal(vertices, faces)
face_normals = faces
else:
norms = normalize_v3(norms)
face_normals = np.array(face_norm_data) - 1
if with_texture_image:
return vertices, faces, norms, face_normals, uvs, face_uvs, texture_image
else:
return vertices, faces, norms, face_normals, uvs, face_uvs
if with_texture:
uvs = np.array(uv_data)
face_uvs = np.array(face_uv_data) - 1
return vertices, faces, uvs, face_uvs
if with_normal:
# norms = np.array(norm_data)
# norms = normalize_v3(norms)
# face_normals = np.array(face_norm_data) - 1
norms = np.array(norm_data)
if norms.shape[0] == 0:
norms = compute_normal(vertices, faces)
face_normals = faces
else:
norms = normalize_v3(norms)
face_normals = np.array(face_norm_data) - 1
return vertices, faces, norms, face_normals
return vertices, faces
def write_obj_mesh(filename, verts, faces):
with open(filename, 'w') as f:
for vert in verts:
f.write('v %f %f %f\n' % tuple(list(vert)))
for face in faces+1:
f.write('f %d %d %d\n' % tuple(list(face)))
================================================
FILE: lib/data/GeneBodyDataset.py
================================================
from re import sub
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import cv2
import sys
import os
import imageio
from ..mesh_util import save_obj_mesh
base_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(base_dir,'..'))
sys.path = sys.path[:-1]
# import trimesh
from lib.ply_util import load_ply
from genebody.mesh import load_obj_mesh
import scipy.interpolate as interpolater
import random
from genebody.gender import genebody_gender
def mask_padding(mask, border = 5):
kernel = np.ones((border, border), np.uint8)
msk_erode = cv2.erode(mask.copy(), kernel)
msk_dilate = cv2.dilate(mask.copy(), kernel)
    # retain the original hard mask and mark a soft band around the boundary
mask = mask > 0
mask[(msk_dilate - msk_erode) > 0] = 1
return mask
def euler2rot(euler):
sin, cos = np.sin, np.cos
phi, theta, psi = euler[0], euler[1], euler[2]
R1 = np.array([[1, 0, 0],
[0, cos(phi), sin(phi)],
[0, -sin(phi), cos(phi)]])
R2 = np.array([[cos(theta), 0, -sin(theta)],
[0, 1, 0],
[sin(theta), 0, cos(theta)]])
R3 = np.array([[cos(psi), sin(psi), 0],
[-sin(psi), cos(psi), 0],
[0, 0, 1]])
R = R1 @ R2 @ R3
return R
def rot2euler(R):
phi = np.arctan2(R[1,2], R[2,2])
theta = -np.arcsin(R[0,2])
psi = np.arctan2(R[0,1], R[0,0])
return np.array([phi, theta, psi])
def gen_cam_views(transl, z_pitch, viewnum):
def viewmatrix(z, up, translation):
vec3 = z / np.linalg.norm(z)
up = up / np.linalg.norm(up)
vec1 = np.cross(up, vec3)
vec2 = np.cross(vec3, vec1)
view = np.stack([vec1, vec2, vec3, translation], axis=1)
view = np.concatenate([view, np.array([[0,0,0,1]])], axis=0)
return view
cam_poses = []
for i, theta in enumerate(np.linspace(-np.pi/2, 1.5*np.pi, viewnum+1)[:-1]):
theta = -theta
dist = 2.9
z = np.array([np.cos(theta), 0, np.sin(theta)])
t = -z * dist + transl
z = z * np.sqrt(1-z_pitch*z_pitch)
z[1] = z_pitch
z = z * dist
up = np.array([0,1,0])
view = viewmatrix(z, up, t)
cam_poses.append(view)
return cam_poses
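
def _demo_gen_cam_views():
    # A minimal sanity-check sketch (not part of the original repo): the
    # generated free-viewpoint path is a ring of 4x4 camera-to-world poses
    # at the hard-coded orbit radius of 2.9 around `transl`.
    poses = gen_cam_views(np.zeros(3), z_pitch=0.1, viewnum=150)
    assert len(poses) == 150 and poses[0].shape == (4, 4)
    assert np.isclose(np.linalg.norm(poses[0][:3, 3]), 2.9)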
class GeneBodyDataset(Dataset):
@staticmethod
def modify_commandline_options(parser):
return parser
def __init__(self, opt, phase='eval', root=None, move_cam=0):
super(GeneBodyDataset, self).__init__()
self.opt = opt
self.is_train = phase == 'train'
self.is_render = phase == 'render'
self.projection_mode = 'perspective'
self.eval_skip = self.opt.eval_skip
self.train_skip = self.opt.train_skip
self.genebody_seq_len = 150
self.root = root if root is not None else opt.dataroot
self.phase = 'val'
self.load_size = self.opt.loadSize
self.B_MIN = np.array([-128, -28, -128])
self.B_MAX = np.array([128, 228, 128])
self.num_views = self.opt.num_views
self.input_views = [1,13,25,37]
self.test_views = sorted(list(range(48)))
# self.sequences = self.opt.ghr_seq if phase == 'train' else self.opt.ghr_test_seq
self.split = np.load(os.path.join(self.root, 'genebody_split.npy'), allow_pickle=True).item()
        self.sequences = self.split['train'] if self.is_train else self.split['test']
self.frames, self.cam_names, self.subjects, self.frames_id = self.get_frames()
self.load_smpl_param = any([self.opt.use_smpl_sdf, self.opt.use_t_pose])
self.load_smpl_mesh = any([self.opt.use_smpl_sdf, self.opt.use_t_pose])
self.smpl_type = self.opt.smpl_type
self.smpl_t_pose = load_obj_mesh(os.path.join(self.opt.t_pose_path, f'{self.smpl_type}.obj'))
self.use_smpl_depth = opt.use_smpl_depth
# PIL to tensor
self.to_tensor_normal = transforms.Compose([
transforms.Resize(self.load_size),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
self.to_tensor = transforms.Compose([
transforms.Resize(self.load_size),
transforms.ToTensor()
])
self.move_cam = move_cam if not self.is_train else 0
self.use_white_bkgd = self.opt.use_white_bkgd
def get_frames(self, i = 0):
sequences = self.sequences
frames, subjects, cam_names, frames_id = [], [], [], []
i = self.input_views[i % self.num_views]
for seq in sequences:
if os.path.exists(os.path.join(self.root, seq)):
                files = sorted([f for f in os.listdir(os.path.join(self.root, \
                                seq, 'smpl')) \
                                if f[-4:] == '.obj'])
cam_names += ['%02d' for i in range(len(files))]
subjects += [seq for i in range(len(files))]
frames_id += list(range(len(files)))
for f in files:
f = f[:-4]
frames += [f]
return frames, cam_names, subjects, frames_id
def get_render_poses(self, annots, move_cam=150):
height, pitch = [], []
for view in range(1,48,3):
view = '%02d' % view
if view in annots.keys():
height.append(annots[view]['c2w'][1, 3])
z_rodrigous = annots[view]['c2w'][:3,:3]@np.array([[0],[0],[1]])
pitch.append(z_rodrigous[1,0])
transl = np.array([0, np.mean(np.array(height)), 0])
z_pitch = np.mean(np.array(pitch))
render_poses = gen_cam_views(transl, z_pitch, move_cam)
return render_poses
def __len__(self):
if self.is_train:
return len(self.frames) * len(self.test_views) // self.train_skip
else:
return len(self.frames) // self.eval_skip
def image_cropping(self, mask):
a = np.where(mask != 0)
h, w = list(mask.shape[:2])
if len(a[0]) > 0:
top, left, bottom, right = np.min(a[0]), np.min(a[1]), np.max(a[0]), np.max(a[1])
else:
return 0,0,mask.shape[0],mask.shape[1]
bbox_h, bbox_w = bottom - top, right - left
        # pad bbox
        bottom = min(int(bbox_h*0.1+bottom), h)
        top = max(int(top-bbox_h*0.1), 0)
        right = min(int(bbox_w*0.1+right), w)
        left = max(int(left-bbox_w*0.1), 0)
bbox_h, bbox_w = bottom - top, right - left
if bbox_h >= bbox_w:
w_c = (left+right) / 2
size = bbox_h
if w_c - size / 2 < 0:
left = 0
right = size
elif w_c + size / 2 >= w:
left = w - size
right = w
else:
left = int(w_c - size / 2)
right = left + size
else: # bbox_w >= bbox_h
h_c = (top+bottom) / 2
size = bbox_w
if h_c - size / 2 < 0:
top = 0
bottom = size
elif h_c + size / 2 >= h:
top = h - size
bottom = h
else:
top = int(h_c - size / 2)
bottom = top + size
return top, left, bottom, right
def get_near_far(self, smpl_verts, w2c):
vp = smpl_verts.dot(w2c[:3,:3].T) + w2c[:3,3:].T
vmin, vmax = vp.min(0), vp.max(0)
near, far = vmin[2], vmax[2]
near, far = near-(far-near)/2, far+(far-near)/2
return near, far
# def get_realworld_scale(self, smpl_verts, bbox, w2c, K):
# smpl_min, smpl_max = smpl_verts.min(0), smpl_verts.max(0)
# # reprojected smpl verts
# vp = smpl_verts.dot(w2c[:3,:3].T) + w2c[:3,3:].T
# vp = vp.dot(K.T)
# vp = vp[:,:2] / (vp[:,2:]+1e-8)
# vmin, vmax = vp.min(0), vp.max(0)
# # compare with bounding box
# bbox_h = bbox[1][0] - bbox[0][0]
# bbox_w = bbox[1][1] - bbox[0][1]
# long_axis = bbox_h/(vmax[1]-vmin[1])*(smpl_max[1]-smpl_min[1]) if bbox_h > bbox_w else bbox_w/(vmax[0]-vmin[0])*(smpl_max[0]-smpl_min[0])
# spatial_freq = 180/long_axis/0.5
# return spatial_freq
def get_image(self, sid, num_views, view_id=None, random_sample=False, smpl_verts=None):
frame = self.frames[sid]
subject = self.subjects[sid]
# some of the sequence has some view missing
if subject == 'wuwenyan':
test_views = list(set(self.test_views)-set([34, 36]))
elif (subject == 'dannier' or subject == 'Tichinah_jervier'):
test_views = list(set(self.test_views)-set([32]))
elif subject == 'joseph_matanda':
test_views = list(set(self.test_views) - set([39, 40, 42, 43, 44, 45, 46, 47]))
elif subject in ['anastasia', 'aosilan']:
test_views = list(set(self.test_views) - set(range(16,24)))
else:
test_views = self.test_views
test_views = sorted(test_views)
        # Select a random view_id from the valid test views if not given
if self.is_train:
if view_id is None or random_sample:
view_id = test_views[np.random.randint(len(test_views))]
else:
view_id = test_views[view_id % len(test_views)]
            # The loaded view ids are the fixed input source views plus the target view
view_ids = self.input_views + [view_id]
else:
if self.is_render:
view_ids = self.input_views
else:
view_ids = self.input_views + test_views
calib_list = []
image_list = []
mask_list = []
extrinsic_list = []
bbox_list = []
smpl_depth_list = []
spatial_freqs = []
annot_path = os.path.join(self.root, subject, f'annots.npy')
annots = np.load(annot_path, allow_pickle = True).item()['cams']
for i, vid in enumerate(view_ids):
view = '%02d' % vid
mask_folder = 'mask'
mask_path = os.path.join(self.root, subject, mask_folder, self.cam_names[sid] % vid)
mask_path = [os.path.join(mask_path,f) for f in os.listdir(mask_path) \
if frame in f]
image_path = os.path.join(self.root, subject, 'image', self.cam_names[sid] % vid)
image_path = [os.path.join(image_path,f) for f in os.listdir(image_path) \
if frame in f]
image_np = imageio.imread(image_path[0])
mask_np = imageio.imread(mask_path[0])
size = image_np.shape
if self.use_smpl_depth and i < self.num_views:
smpl_depth_path = os.path.join(self.root, subject, 'smpl_depth', self.cam_names[sid] % vid)
smpl_depth_path = [os.path.join(smpl_depth_path,f) for f in os.listdir(smpl_depth_path) \
if frame in f]
smpl_depth = imageio.imread(smpl_depth_path[0])
smpl_depth = smpl_depth.astype(np.float32) / 1000.0
top, left, bottom, right = self.image_cropping(mask_np)
mask_np = mask_np[top:bottom, left:right]
image_crop = image_np[top:bottom, left:right]
mask_np = cv2.resize(mask_np.copy(), (self.load_size,self.load_size), \
interpolation = cv2.INTER_NEAREST)
image_crop = cv2.resize(image_crop.copy(), (self.load_size,self.load_size), \
interpolation = cv2.INTER_CUBIC)
image = Image.fromarray(image_crop)
mask_np = mask_np > 128
if self.use_smpl_depth and i < self.num_views:
smpl_depth = smpl_depth[top:bottom, left:right]
smpl_depth = cv2.resize(smpl_depth, (self.load_size,self.load_size), \
interpolation = cv2.INTER_NEAREST)
mask_np = np.logical_or(mask_np, smpl_depth > 0)
smpl_depth_list.append(torch.from_numpy(smpl_depth))
a = np.where(mask_np != 0)
try:
bbox = [[np.min(a[0]), np.min(a[1])], [np.max(a[0]), np.max(a[1])]] if len(a[0]) > 0 else \
[[0, 0], [self.load_size, self.load_size]]
except:
print(os.path.join(self.root, subject, mask_folder, self.cam_names[sid] % vid))
print(top, left, bottom, right)
print(mask_np)
exit(0)
bbox_list.append(bbox)
mask = torch.from_numpy(mask_np.astype(np.float32)).view(1,self.load_size,self.load_size)
mask_list.append(mask)
image = self.to_tensor(image) if i >= num_views else self.to_tensor_normal(image)
if i >= self.num_views and self.use_white_bkgd:
image = image * mask + (1. - mask)
image = mask.type(image.dtype).expand(3,-1,-1) * image
rgb = image.cpu().numpy().transpose([1,2,0])
K = np.array(annots[view]['K'], dtype=np.float32)
K[0,2] -= left
K[1,2] -= top
K[0,:] *= self.load_size / float(right - left)
K[1,:] *= self.load_size / float(bottom - top)
c2w = np.array(annots[view]['c2w'], dtype=np.float32)
w2c = np.linalg.inv(c2w)
dist = np.array(annots[view]['D'], dtype = np.float32)
# determine near far plane from smpl estimation
near, far = self.get_near_far(smpl_verts, w2c)
# # determine valid body part from smpl and bounding box
# if i < self.num_views:
# spatial_freq = self.get_realworld_scale(smpl_verts, bbox, w2c, K)
# spatial_freqs.append(spatial_freq)
calib = torch.Tensor([K[0,0],K[1,1],K[0,2],K[1,2]]+list(dist.reshape(-1))+[near,far]).float()
extrinsic = torch.from_numpy(w2c)
image_list.append(image)
calib_list.append(calib)
extrinsic_list.append(extrinsic)
# bbox = np.array(bbox_list).reshape(-1,4)
if not self.is_train and self.move_cam > 0:
bboxs = np.array(bbox_list[:self.num_views]).reshape(-1,2)
else:
bboxs = np.array(bbox_list[self.num_views:]).reshape(-1,2)
centroid = np.array([mask_np.shape[0], mask_np.shape[1]]) / 2
bbox = (np.max(np.abs(bboxs - centroid), axis=0) * \
np.array([1, np.sqrt(2)])).astype(np.int32)
bbox = np.array([centroid - bbox, centroid + bbox]).T
bbox = np.clip(bbox.reshape(-1), 0, self.load_size)
# spatial_freq = min(spatial_freqs)
if self.is_render:
# render free view point video on full image resolution
render_id = sid % (self.genebody_seq_len // self.eval_skip)
render_c2ws = self.get_render_poses(annots, self.move_cam)
w2c = np.linalg.inv(render_c2ws[render_id])
K = annots['25']['K']
render_extrinsics = torch.from_numpy(w2c.astype(np.float32))
near, far = self.get_near_far(smpl_verts, w2c)
render_calibs = torch.Tensor([K[0,0],K[1,1],K[0,2],K[1,2]]+list(np.zeros_like(dist))+[near,far]).float()
bbox = np.array([0, size[0], 0, size[1]])
extrinsic_list = extrinsic_list[:self.num_views] + [render_extrinsics]
calib_list = calib_list[:self.num_views] + [render_calibs]
mask_list = mask_list[:self.num_views]
        if not self.is_train and self.move_cam == 0:
gt_list = image_list[num_views:]
image_list = image_list[:num_views]
return {
'img': torch.stack(image_list, dim=0),
'mask': torch.stack(mask_list[:num_views], dim=0),
'persps': torch.stack(calib_list, dim=0),
'calib': torch.stack(extrinsic_list, dim=0),
'bbox': bbox,
            'render_gt': torch.stack(gt_list, dim = 0) if not self.is_train and self.move_cam == 0 else [],
'smpl_depth': torch.stack(smpl_depth_list[:self.num_views], dim=0) if self.opt.use_smpl_depth else [],
'img_i': view_ids,
# 'body_scale': spatial_freq,
'center': torch.from_numpy((smpl_verts.max(0)+smpl_verts.min(0))/2),
}
def smpl_from_param(self, model_path, subject, smpl_param, smpl_scale):
import smplx
smpl = smplx.SMPLX(
model_path=model_path,
gender='NEUTRAL',#genebody_gender[subject],
use_pca=False,
)
smpl_param = smpl_param.copy()
for key in smpl_param.keys():
if isinstance(smpl_param[key], np.ndarray):
smpl_param[key] = torch.from_numpy(smpl_param[key])
output = smpl(**smpl_param)
verts = output['vertices'].numpy().reshape(-1,3)
# To align with keypoints3d saved in param, use the base keypoints only,
# if you want to use the full keypoints3d with extra joints and landmarks,
        # please refer to the definition of joints in
# vertex_joint_selector.py and landmarks in vertices2landmarks in lbs.py
keypoints3d = output['joints'].numpy().reshape(-1,3)[:55]
return verts*smpl_scale, smpl.faces, keypoints3d*smpl_scale
def get_item(self, index):
sid = index % len(self.frames)
        vid = (index // len(self.frames)) % len(self.test_views)
frame = self.frames[sid]
subject = self.subjects[sid]
old = 'old' if os.path.exists(os.path.join(self.root, subject, 'oldsmpl')) else ''
res = {
'name': subject+'_'+frame+'_'+str(vid),
'mesh_path': os.path.join(self.root, subject, 'smpl', '%d.ply' % sid),
'sid': sid,
'vid': vid,
}
# load smpl data
param_dir = os.path.join(self.root, subject, f'{old}param')
param_path = [os.path.join(param_dir,f) for f in os.listdir(param_dir) if frame in f]
param = np.load(os.path.join(param_path[0]), allow_pickle=True).item()
scale, param = param['smplx_scale'], param['smplx']
res['body_scale'] = scale
for key in param.keys():
if isinstance(param[key], torch.Tensor):
param[key] = param[key].numpy()
# param['jaw_pose'] = np.zeros_like(param['jaw_pose'])
# smplx_model_path = '/home/SENSETIME/chengwei/Projects/bodyfitting_release/data/smplx'
# vert, face, kp3d = self.smpl_from_param(smplx_model_path, subject, param, scale)
smpl_dir = os.path.join(self.root, subject, f'{old}smpl')
# os.makedirs(smpl_dir, exist_ok=True)
# save_obj_mesh(os.path.join(smpl_dir, frame+'.obj'), vert, face[...,::-1])
smpl_path = [os.path.join(smpl_dir,f) for f in os.listdir(smpl_dir) if frame in f][0]
if smpl_path[-4:] == '.obj':
vert, face = load_obj_mesh(smpl_path)
else:
smpl = load_ply(smpl_path)
vert, face = smpl['vertex'][:,:3], smpl['face']
vert = vert.astype(np.float32)
# load image data
image_data = self.get_image(sid, num_views=self.num_views, view_id=vid,
random_sample=self.opt.random_multiview, smpl_verts=vert)
res.update(image_data)
T = cv2.Rodrigues(param['global_orient'].reshape(-1, 3)[:1])[0]
res['bbox'] = np.array(res['bbox'])
res['smpl_rot'] = torch.from_numpy(T.astype(np.float32)) \
if self.load_smpl_mesh else []
res['smpl_verts']= torch.from_numpy(vert.astype(np.float32)) \
if self.load_smpl_mesh else []
res['smpl_faces']= torch.from_numpy(face.astype(np.int32)) \
if self.load_smpl_mesh else []
res['smpl_betas']= torch.from_numpy(param['betas'].reshape(-1).astype(np.float32)) \
if self.load_smpl_param else []
if self.load_smpl_param:
t_vert, t_face = self.smpl_t_pose
res['smpl_t_verts'] = t_vert
res['smpl_t_faces'] = t_face
if self.opt.use_t_pose:
res['smpl_t_verts'] = torch.from_numpy(res['smpl_t_verts'].astype(np.float32))
res['smpl_t_faces'] = torch.from_numpy(res['smpl_t_faces'].astype(np.int32))
else:
res['smpl_t_verts'] = []
res['smpl_t_faces'] = []
return res
def __getitem__(self, index):
if not self.is_train:
index *= self.eval_skip
return self.get_item(index)
================================================
FILE: lib/data/__init__.py
================================================
================================================
FILE: lib/geometry.py
================================================
import torch
import numpy as np
def rot2euler(R):
phi = np.arctan2(R[1,2], R[2,2])
theta = -np.arcsin(R[0,2])
psi = np.arctan2(R[0,1], R[0,0])
return np.array([phi, theta, psi])
def euler2rot(euler):
sin, cos = np.sin, np.cos
phi, theta, psi = euler[0], euler[1], euler[2]
R1 = np.array([[1, 0, 0],
[0, cos(phi), sin(phi)],
[0, -sin(phi), cos(phi)]])
R2 = np.array([[cos(theta), 0, -sin(theta)],
[0, 1, 0],
[sin(theta), 0, cos(theta)]])
R3 = np.array([[cos(psi), sin(psi), 0],
[-sin(psi), cos(psi), 0],
[0, 0, 1]])
R = R1 @ R2 @ R3
return R
def batch_rodrigues(theta):
"""Convert axis-angle representation to rotation matrix.
Args:
theta: size = [B, 3]
Returns:
        Rotation matrix corresponding to the axis-angle input -- size = [B, 3, 3]
"""
l1norm = torch.norm(theta + 1e-8, p = 2, dim = 1)
angle = torch.unsqueeze(l1norm, -1)
normalized = torch.div(theta, angle)
angle = angle * 0.5
v_cos = torch.cos(angle)
v_sin = torch.sin(angle)
quat = torch.cat([v_cos, v_sin * normalized], dim = 1)
return quat_to_rotmat(quat)
def quat_to_rotmat(quat):
"""Convert quaternion coefficients to rotation matrix.
Args:
quat: size = [B, 4] 4 <===>(w, x, y, z)
Returns:
Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
"""
norm_quat = quat
norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True)
w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3]
B = quat.size(0)
w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
wx, wy, wz = w*x, w*y, w*z
xy, xz, yz = x*y, x*z, y*z
rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz,
2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx,
2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
return rotMat
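
def _demo_batch_rodrigues():
    # A minimal sanity-check sketch (not part of the original repo): a pi/2
    # axis-angle rotation about z maps the x axis to the y axis.
    theta = torch.tensor([[0.0, 0.0, np.pi / 2]])
    R = batch_rodrigues(theta)[0]
    x = torch.tensor([1.0, 0.0, 0.0])
    assert torch.allclose(R @ x, torch.tensor([0.0, 1.0, 0.0]), atol=1e-6)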
def index(feat, uv, mode='bilinear'):
'''
:param feat: [B, C, H, W] image features
:param uv: [B, 2, N] uv coordinates in the image plane, range [-1, 1]
:return: [B, C, N] image features at the uv coordinates
'''
uv = uv.transpose(1, 2) # [B, N, 2]
uv = uv.unsqueeze(2) # [B, N, 1, 2]
    # NOTE: for newer PyTorch, it seems that training results are degraded due to an implementation diff in F.grid_sample;
    # for old versions, simply remove the align_corners argument.
# if torch.__version__ >= "1.3.0":
# samples = torch.nn.functional.grid_sample(feat, uv, align_corners=True) # [B, C, N, 1]
# else:
samples = torch.nn.functional.grid_sample(feat, uv, mode=mode)
return samples[:, :, :, 0] # [B, C, N]
def orthogonal(points, calibrations, transforms=None):
'''
Compute the orthogonal projections of 3D points into the image plane by given projection matrix
:param points: [B, 3, N] Tensor of 3D points
:param calibrations: [B, 4, 4] Tensor of projection matrix
:param transforms: [B, 2, 3] Tensor of image transform matrix
:return: xyz: [B, 3, N] Tensor of xyz coordinates in the image plane
'''
rot = calibrations[:, :3, :3]
trans = calibrations[:, :3, 3:4]
pts = torch.baddbmm(trans, rot, points) # [B, 3, N]
if transforms is not None:
scale = transforms[:2, :2]
shift = transforms[:2, 2:3]
pts[:, :2, :] = torch.baddbmm(shift, scale, pts[:, :2, :])
return pts
def perspective(points, w2c, camera):
'''
    Compute the perspective projections of 3D points into the image plane by a given projection matrix
    :param points: [B, 3, N] Tensor of 3D points
    :param w2c: [B, 4, 4] Tensor of world-to-camera extrinsics
    :param camera: [B, 4+] Tensor of intrinsics [fx, fy, cx, cy] plus optional distortion coefficients
    :return: xyz: [B, 3, N] Tensor; xy are pixel coordinates in the image plane, z is camera-space depth
'''
rot = w2c[:, :3, :3]
trans = w2c[:, :3, 3:4]
points = torch.baddbmm(trans, rot, points) # [B, 3, N]
xy = points[:,:2, :] / torch.clamp(points[:,2:3,:], 1e-9)
if camera.shape[1] > 6:
x2 = xy[:,0,:]*xy[:,0,:]
y2 = xy[:,1,:]*xy[:,1,:]
xy_= xy[:,0,:]*xy[:,1,:]
r2 = x2 + y2
c = (1 + r2*(camera[:,4:5]+r2*(camera[:,5:6]+r2*camera[:,8:9])))
xy = c.unsqueeze(1)*xy + torch.cat([ \
(camera[:,6:7]*2*xy_+ camera[:,7:8]*(r2+2*x2)).unsqueeze(1),\
(camera[:,7:8]*2*xy_+ camera[:,6:7]*(r2+2*y2)).unsqueeze(1)],1)
xy = camera[:,0:2,None]*xy + camera[:,2:4,None]
points[:,:2, :] = xy
return points
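
def _demo_project_and_index():
    # A minimal sketch (not part of the original repo) of the typical usage
    # of the helpers above: project 3D points into a source view with
    # `perspective`, normalize the pixel coordinates to [-1, 1], then gather
    # per-point image features with `index`. The feature map and camera
    # values are dummies; the camera layout [fx, fy, cx, cy] follows the
    # docstring of `perspective`.
    B, N, H, W = 1, 5, 8, 8
    feat = torch.randn(B, 16, H, W)                       # [B, C, H, W]
    points = torch.randn(B, 3, N) * 0.1                   # [B, 3, N] world points
    points[:, 2, :] += 3.0                                # push in front of camera
    w2c = torch.eye(4).unsqueeze(0)                       # identity extrinsics
    camera = torch.tensor([[W / 2., H / 2., W / 2., H / 2.]])  # fx, fy, cx, cy
    xyz = perspective(points, w2c, camera)                # xy in pixels, z is depth
    uv = xyz[:, :2, :] / torch.tensor([[[W], [H]]]) * 2 - 1    # pixels -> [-1, 1]
    samples = index(feat, uv)                             # [B, C, N]
    assert samples.shape == (B, 16, N)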
================================================
FILE: lib/mesh_util.py
================================================
from skimage import measure
import numpy as np
import torch
def save_obj_mesh(mesh_path, verts, faces):
file = open(mesh_path, 'w')
for v in verts:
file.write('v %.4f %.4f %.4f\n' % (v[0], v[1], v[2]))
for f in faces:
f_plus = f + 1
file.write('f %d %d %d\n' % (f_plus[0], f_plus[2], f_plus[1]))
file.close()
def save_obj_mesh_with_color(mesh_path, verts, faces, colors):
file = open(mesh_path, 'w')
for idx, v in enumerate(verts):
c = colors[idx]
file.write('v %.4f %.4f %.4f %.4f %.4f %.4f\n' % (v[0], v[1], v[2], c[0], c[1], c[2]))
for f in faces:
f_plus = f + 1
file.write('f %d %d %d\n' % (f_plus[0], f_plus[2], f_plus[1]))
file.close()
def save_obj_mesh_with_uv(mesh_path, verts, faces, uvs):
file = open(mesh_path, 'w')
for idx, v in enumerate(verts):
vt = uvs[idx]
file.write('v %.4f %.4f %.4f\n' % (v[0], v[1], v[2]))
file.write('vt %.4f %.4f\n' % (vt[0], vt[1]))
for f in faces:
f_plus = f + 1
file.write('f %d/%d %d/%d %d/%d\n' % (f_plus[0], f_plus[0],
f_plus[2], f_plus[2],
f_plus[1], f_plus[1]))
file.close()
================================================
FILE: lib/metrics.py
================================================
import sys
import os
import torch
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# import torch
# torch.autograd.set_detect_anomaly(True)
# import torch.nn as nn
# import torch.nn.functional as F
import numpy as np
import trimesh
import cv2
from scipy.spatial import KDTree
import lpips
lpips_net = lpips.LPIPS(net='alex')
def chamfer(x_verts, gt_verts, x_normals=None, gt_normals=None):
searcher = KDTree(gt_verts)
dists, inds = searcher.query(x_verts)
if x_normals is None or gt_normals is None:
return dists
elif x_normals is not None and gt_normals is not None:
cosine_dists = 1 - np.sum(x_normals * gt_normals[inds], axis=1)
return dists, cosine_dists
else:
raise Exception("provide normals for both point sets")
def fscore(dist1, dist2, threshold=1.0):
"""
Calculates the F-score between two point clouds with the corresponding threshold value.
:param dist1: N-Points
:param dist2: N-Points
    :param threshold: float
:return: fscore, precision, recall
"""
    # NB: In this repo, dist1 and dist2 are pointcloud euclidean distances (as returned by chamfer), so you should adapt the threshold accordingly.
precision_1 = np.mean((dist1 < threshold).astype(np.float32))
precision_2 = np.mean((dist2 < threshold).astype(np.float32))
fscore = 2 * precision_1 * precision_2 / (precision_1 + precision_2)
return fscore, precision_1, precision_2
def psnr(x, gt):
"""
x: np.uint8, HxWxC, 0 - 255
gt: np.uint8, HxWxC, 0 - 255
"""
x = (x / 255.).astype(np.float32)
gt = (gt / 255.).astype(np.float32)
mse = ((x - gt) ** 2).mean()
psnr = 10. * np.log10(1. / mse)
# return mse, psnr
return psnr
def ssim_channel(x, gt):
C1 = (0.01 * 255) ** 2
C2 = (0.03 * 255) ** 2
x = x.astype(np.float32)
gt = gt.astype(np.float32)
kernel = cv2.getGaussianKernel(11, 1.5)
window = np.outer(kernel, kernel.transpose())
mu1 = cv2.filter2D(x, -1, window)[5:-5, 5:-5] # valid
mu2 = cv2.filter2D(gt, -1, window)[5:-5, 5:-5]
mu1_sq = mu1 ** 2
mu2_sq = mu2 ** 2
mu1_mu2 = mu1 * mu2
sigma1_sq = cv2.filter2D(x ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
sigma2_sq = cv2.filter2D(gt ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
sigma12 = cv2.filter2D(x * gt, -1, window)[5:-5, 5:-5] - mu1_mu2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
(sigma1_sq + sigma2_sq + C2))
return ssim_map.mean()
def ssim(x, gt):
'''calculate SSIM
the same outputs as MATLAB's
x: np.uint8, HxWxC, 0 - 255
gt: np.uint8, HxWxC, 0 - 255
'''
if not x.shape == gt.shape:
raise ValueError('Input images must have the same dimensions.')
if x.ndim == 2:
return ssim_channel(x, gt)
elif x.ndim == 3:
if x.shape[2] == 3:
ssims = []
            for i in range(3):
                ssims.append(ssim_channel(x[:, :, i], gt[:, :, i]))
return np.array(ssims).mean()
elif x.shape[2] == 1:
return ssim_channel(np.squeeze(x), np.squeeze(gt))
else:
raise ValueError('input image dimension mismatch.')
def lpips(x, gt, net=lpips_net):
x = torch.from_numpy(x).float() / 255. * 2 - 1.
gt = torch.from_numpy(gt).float() / 255. * 2 - 1.
x = x.permute([2, 0, 1]).unsqueeze(0)
gt = gt.permute([2, 0, 1]).unsqueeze(0)
with torch.no_grad():
loss = net.forward(x, gt)
return loss.item()
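
def _demo_image_metrics():
    # A minimal sketch (not part of the original repo): run the image
    # metrics above on a dummy prediction/ground-truth pair. All three
    # expect HxWxC uint8 arrays in [0, 255]; lpips downloads AlexNet
    # weights on first use.
    gt = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)
    noise = np.random.randint(-5, 6, gt.shape)
    pred = np.clip(gt.astype(np.int32) + noise, 0, 255).astype(np.uint8)
    print('psnr :', psnr(pred, gt))
    print('ssim :', ssim(pred, gt))
    print('lpips:', lpips(pred, gt))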
if __name__ == "__main__":
import trimesh
x = trimesh.load("./data/human/SMPL/10315_m_John/smpl.obj", process=False)
y = trimesh.load("./data/human/SMPL/10316_m_John/smpl.obj", process=False)
print(chamfer(x.vertices, y.vertices, x.vertex_normals, y.vertex_normals))
================================================
FILE: lib/metrics_torch.py
================================================
import torch
import torch.nn.functional as F
from math import exp
import numpy as np
import lpips
class LPIPS(torch.nn.Module):
def __init__(self):
super(LPIPS, self).__init__()
self.net = lpips.LPIPS(net='alex', verbose=False)
def forward(self, x, gt):
if torch.max(gt) > 128:
# [0, 255]
x = x / 255. * 2 - 1
gt = gt / 255. * 2 - 1
elif torch.min(gt) >= 0 and torch.max(gt) <= 1:
# [0, 1]
x = x * 2 - 1
gt = gt * 2 - 1
with torch.no_grad():
loss = self.net.forward(x, gt)
# return loss.item()
return loss
def psnr(x, gt):
"""
    x: torch.Tensor image, value range auto-detected ([0, 255], [0, 1] or [-1, 1])
    gt: torch.Tensor image, same shape and range as x
"""
if torch.max(gt) > 128:
# [0, 255]
x = x / 255
gt = gt / 255
elif torch.min(gt) < -0.5:
        # [-1, 1]
x = (x+1)/2
gt = (gt+1)/2
mse = torch.mean((x - gt) ** 2)
psnr = -10. * torch.log10(mse)
return psnr
def gaussian(window_size, sigma):
gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
return gauss/gauss.sum()
def create_window(window_size, channel=1):
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
return window
def ssim_(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
    # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
if val_range is None:
if torch.max(img1) > 128:
max_val = 255
else:
max_val = 1
if torch.min(img1) < -0.5:
min_val = -1
else:
min_val = 0
L = max_val - min_val
else:
L = val_range
padd = 0
(_, channel, height, width) = img1.size()
if window is None:
real_size = min(window_size, height, width)
window = create_window(real_size, channel=channel).to(img1.device)
mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
mu2 = F.conv2d(img2, window, padding=padd, groups=channel)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq
sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq
sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2
C1 = (0.01 * L) ** 2
C2 = (0.03 * L) ** 2
v1 = 2.0 * sigma12 + C2
v2 = sigma1_sq + sigma2_sq + C2
cs = v1 / v2 # contrast sensitivity
ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
if size_average:
cs = cs.mean()
ret = ssim_map.mean()
else:
cs = cs.mean(1).mean(1).mean(1)
ret = ssim_map.mean(1).mean(1).mean(1)
if full:
return ret, cs
return ret
# Classes to re-use window
class SSIM(torch.nn.Module):
def __init__(self, window_size=11, size_average=True, val_range=None):
super(SSIM, self).__init__()
self.window_size = window_size
self.size_average = size_average
self.val_range = val_range
# Assume 1 channel for SSIM
self.channel = 1
self.window = create_window(window_size)
def forward(self, img1, img2):
if len(list(img1.shape)) < 4:
img1 = img1.unsqueeze(0)
img2 = img2.unsqueeze(0)
(_, channel, _, _) = img1.size()
if channel == self.channel and self.window.dtype == img1.dtype:
window = self.window
else:
window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
self.window = window
self.channel = channel
return ssim_(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
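
def _demo_torch_metrics():
    # A minimal sketch (not part of the original repo): the torch metrics
    # operate on [B, C, H, W] tensors and auto-detect the value range
    # ([0, 255], [0, 1] or [-1, 1]).
    gt = torch.rand(1, 3, 64, 64)
    pred = (gt + 0.02 * torch.randn_like(gt)).clamp(0, 1)
    print('psnr :', psnr(pred, gt).item())
    print('ssim :', SSIM()(pred, gt).item())
    print('lpips:', LPIPS()(pred, gt).mean().item())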
================================================
FILE: lib/model/Embedder.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
class PositionalEncoding:
"""
    GNR uses NeRF-style positional encoding for coordinate embedding
"""
def __init__(self, d, num_freqs=10, min_freq=None, max_freq=None, freq_type='linear'):
self.num_freqs = num_freqs
self.min_freq = min_freq
self.max_freq = max_freq
self.freq_type = freq_type
self.create_embedding_fn(d)
def create_embedding_fn(self, d):
embed_fns = []
out_dim = 0
embed_fns.append(lambda x : x)
out_dim += d
N_freqs = self.num_freqs
if self.freq_type == 'linear':
min_freq = 0 if self.min_freq is None else self.min_freq
max_freq = 2 ** (self.num_freqs-1) if self.max_freq is None else self.max_freq
freq_bands = torch.linspace(min_freq*math.pi*2, max_freq*math.pi*2, steps=N_freqs) # linear freq band, Fourier expansion
else:
min_freq = 0 if self.min_freq is None else math.log2(self.min_freq)
max_freq = self.num_freqs-1 if self.max_freq is None else math.log2(self.max_freq)
            freq_bands = 2.**torch.linspace(min_freq, max_freq, steps=N_freqs) * math.pi * 2 # log-spaced freq band, scaled by 2*pi to match the linear branch
for freq in freq_bands:
for p_fn in [torch.sin, torch.cos]:
embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq))
out_dim += d
self.embed_fns = embed_fns
self.out_dim = out_dim
def embed(self, inputs):
return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
class SphericalHarmonics:
"""
    GNR uses Spherical Harmonics for view direction embedding
"""
def __init__(self, d = 3, rank = 3):
assert d % 3 == 0
self.rank = max([int(rank),0])
self.out_dim= self.rank*self.rank * (d // 3)
def Lengdre_polynormial(self, x, omx = None):
if omx is None: omx = 1 - x * x
Fml = [[]] *((self.rank+1)*self.rank//2)
Fml[0] = torch.ones_like(x)
for l in range(1, self.rank):
b = (l * l + l) // 2
Fml[b+l] =-Fml[b-1]*(2*l-1)
Fml[b+l-1]= Fml[b-1]*(2*l-1)*x
for m in range(l,1,-1):
Fml[b+m-2] = -(omx * Fml[b+m] + \
2*(m-1)*x * Fml[b+m-1]) / ((l-m+2)*(l+m-1))
return Fml
def SH(self, xyz):
cs = xyz[...,0:1]
sn = xyz[...,1:2]
Fml = self.Legendre_polynomial(xyz[...,2:3], cs*cs + sn*sn)
H = [[]] *(self.rank*self.rank)
for l in range(self.rank):
b = l * l + l
attr = np.sqrt((2*l+1)/math.pi/4)
H[b] = attr * Fml[b//2]
attr = attr * np.sqrt(2)
snM = sn; csM = cs
for m in range(1, l+1):
attr = -attr / np.sqrt((l+m)*(l+1-m))
H[b-m] = attr * Fml[b//2+m] * snM
H[b+m] = attr * Fml[b//2-m] * csM
snM, csM = snM*cs+csM*sn, csM*cs-snM*sn
if len(H) > 0:
return torch.cat(H, -1)
else:
return torch.Tensor([])
def embed(self, inputs):
return self.SH(inputs)
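if __name__ == "__main__":
    # Quick shape-check sketch for both embedders; the sample sizes are illustrative.
    pe = PositionalEncoding(d=3, num_freqs=10)
    pts = torch.randn(5, 3)
    print(pe.embed(pts).shape, pe.out_dim)          # (5, 63): 3 + 2*10*3
    sh = SphericalHarmonics(d=3, rank=3)
    dirs = F.normalize(torch.randn(5, 3), dim=-1)   # SH expects unit directions
    print(sh.embed(dirs).shape, sh.out_dim)         # (5, 9): rank**2 * (d//3)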
================================================
FILE: lib/model/GNR.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from .HGFilters import *
from ..net_util import init_weights
from .NeRF import NeRF
from .NeRFRenderer import NeRFRenderer
from .SRFilters import SRFilters
from ..geometry import index
class GNR(nn.Module):
def __init__(self, opt):
super(GNR, self).__init__()
self.name = 'gnr'
self.opt = opt
self.num_views = self.opt.num_views
self.use_feat_sr = self.opt.use_feat_sr
self.ddp = self.opt.ddp
self.feat_dim = 64 if self.use_feat_sr else 256
self.index = index
self.error_term=nn.MSELoss()
self.image_filter = HGFilter(opt)
if self.use_feat_sr:
self.sr_filter = SRFilters(order=2, out_ch=self.feat_dim)
if not opt.train_encoder:
for param in self.image_filter.parameters():
param.requires_grad = False
self.nerf = NeRF(opt, input_ch_feat=self.feat_dim)
self.nerf_renderer = NeRFRenderer(opt, self.nerf)
init_weights(self)
def image_rescale(self, images, masks):
if images.min() < -0.2:
images = (images + 1) / 2
images = images * (masks > 0).float()
return images
def get_image_feature(self, data):
if 'feats' not in data.keys():
images = data['images']
im_feat = self.image_filter(images[:self.num_views])
if self.use_feat_sr:
im_feat = self.sr_filter(im_feat, images[:self.num_views])
data['images'] = torch.cat([self.image_rescale(images[:self.num_views], data['masks'][:self.num_views]), \
images[self.num_views:]], 0)
data['feats'] = im_feat
return data
def forward(self, data, train_shape=False):
data = self.get_image_feature(data)
if train_shape:
error = self.nerf_renderer.train_shape(**data)
else:
error = self.nerf_renderer.render(**data)
return error
def render_path(self, data):
with torch.no_grad():
rgbs = None
data = self.get_image_feature(data)
rgbs, depths = self.nerf_renderer.render_path(**data)
return rgbs, depths
def reconstruct(self, data):
with torch.no_grad():
data = self.get_image_feature(data)
verts, faces, rgbs = self.nerf_renderer.reconstruct(**data)
return verts, faces, rgbs
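if __name__ == "__main__":
    # Standalone sketch of the image_rescale convention above, without building
    # the full model (which needs a complete opt namespace and the compiled
    # mesh_grid extension). The toy tensors are illustrative.
    images = torch.rand(2, 3, 8, 8) * 2 - 1              # inputs normalized to [-1, 1]
    masks = (torch.rand(2, 1, 8, 8) > 0.5).float()
    rescaled = ((images + 1) / 2) * (masks > 0).float()  # mapped to [0, 1], then masked
    print(rescaled.min().item() >= 0.0, rescaled.max().item() <= 1.0)  # True True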
================================================
FILE: lib/model/HGFilters.py
================================================
"""
This file is directly borrowed from PIFu
GNR uses PIFu's Stacked Hourglass network for image encoding
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..net_util import *
class HourGlass(nn.Module):
def __init__(self, num_modules, depth, num_features, norm='batch'):
super(HourGlass, self).__init__()
self.num_modules = num_modules
self.depth = depth
self.features = num_features
self.norm = norm
self._generate_network(self.depth)
def _generate_network(self, level):
self.add_module('b1_' + str(level), ConvBlock(self.features, self.features, norm=self.norm))
self.add_module('b2_' + str(level), ConvBlock(self.features, self.features, norm=self.norm))
if level > 1:
self._generate_network(level - 1)
else:
self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features, norm=self.norm))
self.add_module('b3_' + str(level), ConvBlock(self.features, self.features, norm=self.norm))
def _forward(self, level, inp):
# Upper branch
up1 = inp
up1 = self._modules['b1_' + str(level)](up1)
# Lower branch
low1 = F.avg_pool2d(inp, 2, stride=2)
low1 = self._modules['b2_' + str(level)](low1)
if level > 1:
low2 = self._forward(level - 1, low1)
else:
low2 = low1
low2 = self._modules['b2_plus_' + str(level)](low2)
low3 = low2
low3 = self._modules['b3_' + str(level)](low3)
# NOTE: for newer PyTorch (1.3~), it seems that training results are degraded due to implementation diff in F.grid_sample
# if the pretrained model behaves weirdly, switch with the commented line.
# NOTE: I also found that "bicubic" works better.
up2 = F.interpolate(low3, scale_factor=2, mode='bicubic', align_corners=True)
# up2 = F.interpolate(low3, scale_factor=2, mode='nearest')
return up1 + up2
def forward(self, x):
return self._forward(self.depth, x)
class HGFilter(nn.Module):
def __init__(self, opt):
super(HGFilter, self).__init__()
self.num_modules = opt.num_stack
self.opt = opt
# Base part
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
if self.opt.norm == 'batch':
self.bn1 = nn.BatchNorm2d(64)
elif self.opt.norm == 'group':
self.bn1 = nn.GroupNorm(32, 64)
if self.opt.hg_down == 'conv64':
self.conv2 = ConvBlock(64, 64, self.opt.norm)
self.down_conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
elif self.opt.hg_down == 'conv128':
self.conv2 = ConvBlock(64, 128, self.opt.norm)
self.down_conv2 = nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1)
elif self.opt.hg_down == 'ave_pool':
self.conv2 = ConvBlock(64, 128, self.opt.norm)
else:
raise NameError('Unknown Fan Filter setting!')
self.conv3 = ConvBlock(128, 128, self.opt.norm)
self.conv4 = ConvBlock(128, 256, self.opt.norm)
# Stacking part
for hg_module in range(self.num_modules):
self.add_module('m' + str(hg_module), HourGlass(1, opt.num_hourglass, 256, self.opt.norm))
self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256, self.opt.norm))
self.add_module('conv_last' + str(hg_module),
nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
if self.opt.norm == 'batch':
self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
elif self.opt.norm == 'group':
self.add_module('bn_end' + str(hg_module), nn.GroupNorm(32, 256))
self.add_module('l' + str(hg_module), nn.Conv2d(256,
opt.hourglass_dim, kernel_size=1, stride=1, padding=0))
if hg_module < self.num_modules - 1:
self.add_module(
'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
self.add_module('al' + str(hg_module), nn.Conv2d(opt.hourglass_dim,
256, kernel_size=1, stride=1, padding=0))
def forward(self, x):
x = F.relu(self.bn1(self.conv1(x)), True)
# tmpx = x
if self.opt.hg_down == 'ave_pool':
x = F.avg_pool2d(self.conv2(x), 2, stride=2)
elif self.opt.hg_down in ['conv64', 'conv128']:
x = self.conv2(x)
x = self.down_conv2(x)
else:
raise NameError('Unknown Fan Filter setting!')
# normx = x
x = self.conv3(x)
x = self.conv4(x)
previous = x
# outputs = []
for i in range(self.num_modules):
hg = self._modules['m' + str(i)](previous)
ll = hg
ll = self._modules['top_m_' + str(i)](ll)
ll = F.relu(self._modules['bn_end' + str(i)]
(self._modules['conv_last' + str(i)](ll)), True)
# Predict heatmaps
tmp_out = self._modules['l' + str(i)](ll)
# outputs.append(tmp_out)
if i < self.num_modules - 1:
ll = self._modules['bl' + str(i)](ll)
tmp_out_ = self._modules['al' + str(i)](tmp_out)
previous = previous + ll + tmp_out_
# return outputs, tmpx.detach(), normx
return tmp_out
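if __name__ == "__main__":
    # Shape-check sketch; run as `python -m lib.model.HGFilters` from the repo
    # root because of the relative import. The opt fields below are the ones
    # HGFilter reads; the chosen values are assumptions, not project defaults.
    from types import SimpleNamespace
    opt = SimpleNamespace(num_stack=4, norm='group', hg_down='ave_pool',
                          num_hourglass=2, hourglass_dim=256)
    net = HGFilter(opt)
    x = torch.randn(1, 3, 512, 512)
    print(net(x).shape)  # (1, 256, 128, 128): stride-2 conv, then 2x avg-pool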
================================================
FILE: lib/model/NeRF.py
================================================
import torch
torch.autograd.set_detect_anomaly(True) # debugging aid: catches NaNs in backward, but slows training
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from .Embedder import SphericalHarmonics, PositionalEncoding
class NeRF(nn.Module):
def __init__(self, opt, D=8, W=256, input_ch=3, input_ch_atts=3, input_ch_feat=128, output_ch=4,
activation="relu", pose_freqs=10, att_freqs=6, spatial_freq=1/256):
"""
"""
super(NeRF, self).__init__()
self.D = D
self.W = W
self.use_smpl_sdf = opt.use_smpl_sdf
self.use_t_pose = opt.use_t_pose
self.angle_diff = opt.angle_diff
self.use_occ_net = opt.use_occlusion_net
self.train_occ = opt.train_occlusion
self.input_ch_pos_enc = input_ch
self.input_ch_smpl = 0
if self.use_smpl_sdf: self.input_ch_smpl += 4
if self.use_t_pose: self.input_ch_smpl += 3
self.use_smpl = self.input_ch_smpl != 0
self.input_ch_feat = input_ch_feat + 3
self.skips = opt.skips
self.use_viewdirs = opt.use_viewdirs and opt.use_attention
self.num_views = opt.num_views
# self.input_ch_atts = input_ch_atts if opt.use_attention else 0
if not opt.use_attention:
self.input_ch_atts = 0
elif self.angle_diff:
self.input_ch_atts = 1
else:
self.input_ch_atts = 3
self.use_sh = opt.use_sh if not self.angle_diff else False
self.use_attention = opt.use_attention
self.use_bn = opt.use_bn
self.spatial_freq = spatial_freq
self.pose_embeder = PositionalEncoding(self.input_ch_pos_enc, num_freqs=pose_freqs, min_freq=spatial_freq*0.1, max_freq=spatial_freq*10)
self.att_embeder = SphericalHarmonics(d = self.input_ch_atts) if self.use_sh else PositionalEncoding(self.input_ch_atts, num_freqs=att_freqs)
self.pose_embed_fn = self.pose_embeder.embed
self.att_embed_fn = self.att_embeder.embed
self.weighted_pool = opt.weighted_pool and self.use_attention
self.alpha_linears = nn.ModuleList(
[nn.Linear(self.pose_embeder.out_dim + self.input_ch_smpl + self.input_ch_feat, W)] +
[nn.Linear(W + self.pose_embeder.out_dim + self.input_ch_smpl, W) if i in self.skips else nn.Linear(W, W) for i in range(0, D-1)])
self.alpha_out_linear = nn.Linear(W, 1)
self.rgb_linears = nn.ModuleList(
[nn.Linear(W + self.pose_embeder.out_dim + self.input_ch_smpl, W//4)] +
[nn.Linear(W//4 + self.att_embeder.out_dim, W//8) if self.use_viewdirs else nn.Linear(W//4, W//8)] +
[nn.Linear(W//8, W//16)] + [nn.Linear(W//16, 3)]
)
if self.use_bn:
self.bn_layer_1 = nn.BatchNorm1d(W)
self.bn_layer_2 = nn.BatchNorm1d(W)
self.bn_layer_3 = nn.BatchNorm1d(W//16)
if self.weighted_pool:
self.s = nn.Parameter(torch.ones(1))
# ### Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105)
if activation == "relu":
self.activation_fn = F.relu
elif activation == "swish":
if torch.__version__ < "1.7.0":
swish = lambda x: x * torch.sigmoid(x)
self.activation_fn = swish
else:
self.activation_fn = torch.nn.SiLU()
if self.use_attention:
self.value_linears = nn.ModuleList([nn.Linear(self.pose_embeder.out_dim+self.att_embeder.out_dim+W, W//4),
nn.Linear(W//4+self.att_embeder.out_dim,W//8), nn.Linear(W//8+self.att_embeder.out_dim,W//16)])
self.key_linears = nn.ModuleList([nn.Linear(self.pose_embeder.out_dim+self.att_embeder.out_dim+W, W//4),
nn.Linear(W//4+self.att_embeder.out_dim,W//8), nn.Linear(W//8+self.att_embeder.out_dim,W//16)])
if self.use_occ_net:
self.occ_linears = nn.ModuleList([nn.Linear(self.input_ch_smpl + 6 + self.input_ch_feat, W//4),
nn.Linear(W//4, W//16),
nn.Linear(W//16+self.input_ch_smpl + 6, 1)])
if not self.train_occ:
for linear in self.occ_linears:
linear.weight.requires_grad_(False)
linear.bias.requires_grad_(False)
def forward(self, x, attdirs=None, alpha_only=False, smpl_vis=None):
# prepare inputs
input_pts, input_smpl, input_feats = torch.split(x, [self.input_ch_pos_enc, self.input_ch_smpl, self.input_ch_feat], dim=-1)
unique_pts = input_pts[:, 0]
unique_smpl = input_smpl[:, 0] if self.use_smpl else torch.zeros([input_pts.shape[0], 0], dtype=torch.float32, device=input_pts.device)
input_pts = input_pts.view([-1, self.input_ch_pos_enc])
input_smpl = input_smpl.view([-1, self.input_ch_smpl]) if self.use_smpl else torch.zeros([input_pts.shape[0], 0], dtype=torch.float32, device=input_pts.device)
input_feats = input_feats.view([-1,self.input_ch_feat])
if self.use_attention and attdirs is not None:
qrydirs, srcdirs = torch.split(attdirs, [1, self.num_views], dim=-2)
if self.use_occ_net and attdirs is not None:
# compute Plucker coordinates of the source rays
d = srcdirs.reshape([-1, 3])
m = torch.cross(input_pts, d, dim=-1)
occ_h = torch.cat([input_smpl, d, m, input_feats], dim=-1)
for i, l in enumerate(self.occ_linears):
occ_h = self.occ_linears[i](occ_h)
if i < len(self.occ_linears) - 1:
occ_h = self.activation_fn(occ_h)
if i == 1:
occ_h = torch.cat([input_smpl, d, m, occ_h], dim=-1)
occ_out = torch.sigmoid(occ_h).view([-1, self.num_views, 1])
# occ = F.softmax(occ_out, dim=1)
# alpha mlp
tmp_h = None
h = torch.cat([self.pose_embed_fn(input_pts), input_smpl, input_feats], dim=-1)
for i, l in enumerate(self.alpha_linears):
h = self.alpha_linears[i](h)
h = self.activation_fn(h)
if i in self.skips:
if i == self.skips[0]:
tmp_h = h.clone()
h = torch.mean(h.view(-1, self.num_views, self.W), dim=1)
h = torch.cat([self.pose_embed_fn(unique_pts), unique_smpl, h], dim=-1)
alpha = self.alpha_out_linear(h)
if alpha_only: return alpha
# rgb mlp
if self.use_attention and self.weighted_pool:
weights = torch.exp(self.s * (torch.sum(srcdirs*qrydirs, dim=-1) -1))
weights = weights / (torch.sum(weights, dim=-1, keepdim=True) + 1e-8) # [N_rand*N_sample, 4]
h = torch.sum(tmp_h.view(-1, self.num_views, self.W) * weights[...,None], dim=1)
h0 = h.clone()
else:
h = torch.mean(tmp_h.view(-1, self.num_views, self.W), dim=1)
h = torch.cat([self.pose_embed_fn(unique_pts), unique_smpl, h], -1)
for i, l in enumerate(self.rgb_linears):
h = self.rgb_linears[i](h)
if i < len(self.rgb_linears) - 1:
h = self.activation_fn(h)
if i == 0 and self.use_viewdirs:
h = torch.cat([self.att_embed_fn(-qrydirs.squeeze(1)), h], dim=-1)
outputs = torch.cat([h, alpha], dim=-1)
# calculate attention
if self.use_attention and attdirs is not None:
attdirs = attdirs.reshape([-1, self.input_ch_atts])
input_pts_ = torch.cat([unique_pts, input_pts], dim=0)
input_h = torch.cat([h0, tmp_h], dim=0)
val = torch.cat([self.pose_embed_fn(input_pts_), self.att_embed_fn(attdirs), input_h], dim=-1)
for i, l in enumerate(self.value_linears):
val = self.value_linears[i](val)
if i < len(self.value_linears) - 1:
val = self.activation_fn(val)
val = torch.cat([self.att_embed_fn(attdirs), val], dim=-1)
key = torch.cat([self.pose_embed_fn(unique_pts), self.att_embed_fn(qrydirs.squeeze(1)), h0], dim=-1)
for i, l in enumerate(self.key_linears):
key = self.key_linears[i](key)
if i < len(self.key_linears) - 1:
key = self.activation_fn(key)
key = torch.cat([self.att_embed_fn(qrydirs.squeeze(1)), key], dim=-1)
# attention key (query direction) and val (source view direction)
key = key.unsqueeze(1)
val = val.view(unique_pts.shape[0], self.num_views+1, -1)
attention = torch.matmul(val, key.permute(0,2,1)).squeeze(-1)
if self.use_occ_net:
attention = self.weighted_softmax(attention, occ_out.squeeze(-1))
elif smpl_vis is not None:
attention = self.weighted_softmax(attention, smpl_vis.float())
else:
attention = F.softmax(attention, dim=-1)
if self.use_attention and attdirs is not None:
outputs = torch.cat([outputs, attention], dim=-1)
if self.use_occ_net:
outputs = torch.cat([outputs, occ_out.squeeze(-1)], dim=-1)
return outputs
def weighted_softmax(self, attention, weight):
exp_att = torch.exp(attention - torch.max(attention, 1, keepdim=True)[0])
# exp_att_src = exp_att[:, 1:].clone() * weight
exp_att = torch.cat([exp_att[:,:1], exp_att[:, 1:] * weight], dim=1)
exp_att_sum = torch.sum(exp_att, dim=-1, keepdim=True)
attention = exp_att / (exp_att_sum + 1e-8)
return attention
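if __name__ == "__main__":
    # Shape-check sketch; run as `python -m lib.model.NeRF` from the repo root.
    # The opt fields are exactly those read by the constructor; the chosen
    # values are assumptions, not project defaults.
    from types import SimpleNamespace
    opt = SimpleNamespace(use_smpl_sdf=True, use_t_pose=True, angle_diff=False,
                          use_occlusion_net=False, train_occlusion=False,
                          skips=[4], use_viewdirs=True, use_attention=True,
                          num_views=4, use_sh=True, use_bn=False,
                          weighted_pool=True)
    net = NeRF(opt, input_ch_feat=64)
    n_pts, v = 16, opt.num_views
    x = torch.randn(n_pts, v, 3 + 7 + 64 + 3)  # [xyz | smpl sdf+t-pose | feat | rgb]
    attdirs = F.normalize(torch.randn(n_pts, v + 1, 3), dim=-1)  # query + source dirs
    out = net(x, attdirs=attdirs)
    print(out.shape)  # (16, 9): rgb(3) + alpha(1) + attention over v+1 directions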
================================================
FILE: lib/model/NeRFRenderer.py
================================================
import logging
import torch
torch.autograd.set_detect_anomaly(True) # debugging aid: catches NaNs in backward, but slows training
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from ..geometry import index, orthogonal, perspective
import trimesh
import imageio
from tqdm import tqdm
from mesh_grid import MeshGridSearcher
from skimage import measure
mse = lambda x, y : torch.mean((x - y) ** 2)
bmse = lambda x, y: torch.sum((x*y - y) ** 2) / torch.sum(y)
l1 = lambda x, y : torch.mean(torch.abs(x-y))
to8b = lambda x: (np.clip(x.detach().cpu().numpy(), 0, 1) * 255).astype(np.uint8)
eikonal = lambda x: torch.mean(x**2)
class NeRFRenderer:
def __init__(self, opt, nerf, nerf_fine=None, projection='perspective', vgg_loss=None, threshold=0.5):
self.opt = opt
self.nerf = nerf
self.nerf_fine = nerf_fine
self.use_fine = self.nerf_fine is not None
self.width = self.opt.loadSize
self.height = self.opt.loadSize
self.N_samples = opt.N_samples
self.num_views = opt.num_views
self.projection_mode = projection
self.projection = orthogonal if projection == 'orthogonal' else perspective
self.N_rand = opt.N_rand
self.N_grid = opt.N_grid+1
self.chunk = opt.chunk
self.N_rand_infer = opt.N_rand_infer
self.mse_loss = nn.MSELoss()
self.alpha_loss = nn.BCELoss()
self.alpha_loss_bmse = bmse
self.alpha_grad_loss = eikonal
self.mesh_searcher = MeshGridSearcher()
self.use_nml = opt.use_nml
self.use_attention = opt.use_attention
self.threshold = threshold
self.debug= opt.debug
self.rgb_ch = 6 if self.use_attention else 3
if self.debug: self.rgb_ch += self.num_views * 3
self.debug_idx = 0
self.use_vgg = opt.use_vgg
self.vgg_loss = vgg_loss
self.use_smpl_sdf = opt.use_smpl_sdf
self.use_t_pose = opt.use_t_pose
self.use_smpl_depth = opt.use_smpl_depth
self.sel_cords = None
self.regularization = opt.regularization
self.angle_diff = opt.angle_diff
self.use_occlusion = opt.use_occlusion and self.use_smpl_depth
self.use_occlusion_net = opt.use_occlusion_net
self.gamma = 1
self.pts_nml = None
self.alpha_grad = None
self.alpha_gt = None
self.alpha_smpl = None
self.alpha = None
self.omega_reg = 0.01
self.nerf_out_ch = 8 if self.use_attention else 4
self.use_vh = opt.use_vh
self.vh_overhead = opt.vh_overhead if self.use_vh else 1
self.use_vh_free = opt.use_vh_free
self.use_white_bkgd = opt.use_white_bkgd
self.default_rgb = torch.zeros if not self.use_white_bkgd else torch.ones
self.occ = None
self.occ_gt = None
def cal_loss(self, rgb, rgb_gt):
loss = {'nerf': self.mse_loss(rgb[:,:3], rgb_gt)}
if self.use_attention:
loss.update({'att': self.mse_loss(rgb[:,3:6], rgb_gt)})
# loss.update({'cmb': self.mse_loss(rgb[:,6:9], rgb_gt)})
if self.alpha_gt is not None and self.alpha is not None:
# alpha loss has three options, self.alpha_loss (binary cross entropy), self.mse_loss (mean square error)
# and self.alpha_loss_bmse (one sided mean square error), we find the last performs the best
loss.update({'alpha': self.mse_loss(self.alpha, self.alpha_gt)})
if self.regularization:
loss.update({'alpha_reg': self.alpha_grad_loss(self.alpha_grad / self.nerf.spatial_freq) * self.omega_reg})
if self.angle_diff and self.angle_diff_grad is not None:
loss.update({'angle_diff': torch.mean(self.angle_diff_grad**2)})
if self.use_occlusion_net and self.occ is not None and self.occ_gt is not None:
loss.update({'occ': self.mse_loss(self.occ, self.occ_gt)*0.1})
return loss
def get_rays_orthogonal(self, bbox, calib):
top, bottom, left, right = bbox
cy, cx, focal = self.height/2, self.width/2, self.height/2
radian = ((right - left) / 2 + 1) / focal
i, j = torch.meshgrid(torch.linspace(top, bottom-1, int(bottom-top), device=calib.device),
torch.linspace(left, right-1, int(right-left), device=calib.device)) # pytorch's meshgrid has indexing='ij'
x = (j - cx) / focal
y = (i - cy) / focal
z = torch.sqrt(radian**2 - x**2)
# z = torch.ones_like(x)
starts = torch.stack([x, y, z], -1)
ends = torch.stack([x, y, -z], -1)
calib = torch.inverse(calib)
R, t = calib[:3,:3], calib[:3, 3]
rays_s = torch.sum(starts[..., None, :] * R, -1) + t
rays_e = torch.sum(ends[..., None, :] * R, -1) + t
return rays_s, rays_e
def get_rays_perspective(self, bbox, w2c, cam):
"""
bbox: bounding box [top, bottom, left, right]
w2c: 4x4 rotation matrix
cam: perspective camera parameters [fx, fy, cx, cy, (if distortion), near, far]
"""
top, bottom, left, right = bbox
near, far = cam[-2], cam[-1]
top, bottom, left, right = int(top),int(bottom),int(left),int(right)
i, j = torch.meshgrid(torch.linspace(top, bottom-1, int(bottom-top), device=w2c.device),
torch.linspace(left, right-1, int(right-left), device=w2c.device))
x = (j - cam[2]) / cam[0]
y = (i - cam[3]) / cam[1]
if len(cam) > 6:
xp, yp = x, y
for _ in range(3): # iter to undistort
x2 = x*x
y2 = y*y
xy = x*y
r2 = x2 + y2
c = (1 + r2*(cam[4]+r2*(cam[5]+r2*cam[8])))
x = (xp - cam[6]*2*xy - cam[7]*(r2+2*x2)) / (c+1e-9)
y = (yp - cam[7]*2*xy - cam[6]*(r2+2*y2)) / (c+1e-9)
z = torch.ones_like(x)
starts = torch.stack([x*near,y*near,z*near],-1)
ends = torch.stack([x*far, y*far, z*far], -1)
c2w = torch.inverse(w2c)
R, t = c2w[:3,:3], c2w[:3, 3]
rays_s = torch.sum(starts[..., None, :]* R, -1) + t
rays_e = torch.sum(ends[..., None, :] * R, -1) + t
# rs = rays_s.cpu().numpy().reshape(-1,3)
# re = rays_e.cpu().numpy().reshape(-1,3)
return rays_s, rays_e
def make_att_input(self, pts, viewdirs, calibs, smpl):
"""
Prepare input for the multiview attention-based SSOAB (screen-space occlusion-aware appearance blending)
"""
if self.projection_mode == 'perspective':
cam_c = torch.inverse(calibs)[:,:3,3]
attdirs = cam_c[None,:,:].expand(pts.shape[0], -1, -1) - pts[:,None,:].expand(-1,self.num_views,-1)
if smpl is not None:
viewdirs = viewdirs @ smpl['rot'][0]
attdirs = (attdirs.view(-1,3) @ smpl['rot'][0]).view(attdirs.shape)
attdirs = torch.cat([viewdirs[:,None,:], attdirs], dim=1)
attdirs = attdirs / torch.clamp(torch.norm(attdirs, dim=-1, keepdim=True),min=1e-9)
if self.angle_diff:
viewdirs = viewdirs / torch.clamp(torch.norm(viewdirs, dim=-1, keepdim=True),min=1e-9)
attdirs = torch.sum(attdirs * viewdirs.unsqueeze(1), dim=-1, keepdim=True)
else:
## c2w @ [0,0,1] is equivalent to c2w[:3, 2]; back-trace the attention direction
attdirs = torch.inverse(calibs)[:, :3, 2] # [num_views, 3]
attdirs = attdirs[None,...].repeat([pts.shape[0], 1, 1])
if smpl is not None:
viewdirs = viewdirs @ smpl['rot'][0]
attdirs = attdirs @ smpl['rot'][0]
attdirs = torch.cat([viewdirs, attdirs], dim=0) # [(num_views+1), 3]
attdirs = attdirs / torch.norm(attdirs, dim=-1, keepdim=True)
return attdirs
def make_nerf_input(self, pts, feats, images, smpl, calibs, mesh_param, persps=None, is_train=True):
"""
Aggregate Geometric Body Shape Embedding for NeRF input
"""
nerf_input, source_rgb = [], None
# Convert query point to normalized body coordinate (normalized scale and body orientation)
center, body_scale = mesh_param['center'], mesh_param['body_scale']
if self.use_nml:
# points normalized to volume [-1,1]^3
self.pts_nml = ((pts - center) / body_scale).requires_grad_()
if self.use_smpl_sdf: self.pts_nml = self.pts_nml @ smpl['rot'][0] # rotate to smpl volume, with smpl root node facing front
nerf_input.append(self.pts_nml)
else:
nerf_input.append(pts)
# Body shape embedding
if self.use_smpl_sdf or self.use_t_pose:
self.mesh_searcher.set_mesh(smpl['verts'], smpl['faces'])
closest_pts, closest_idx = self.mesh_searcher.nearest_points(pts)
if self.use_t_pose:
closest_faces = smpl['faces'][closest_idx.long()]
t_pose_verts = smpl['t_verts'][closest_faces.long()]
t_pose_coords = t_pose_verts.mean(dim=1)
# T-pose correspondence
nerf_input.append(t_pose_coords)
if self.use_smpl_sdf:
reg_vecs = pts - closest_pts
if self.use_nml:
reg_vecs = reg_vecs / body_scale # normalized to volume [-1,1]^3
reg_vecs = reg_vecs @ smpl['rot'][0] # rotate to smpl volume, with smpl root node facing front
signs = self.mesh_searcher.inside_mesh(pts)
self.alpha_smpl = (signs + 1) / 2
norm = torch.norm(reg_vecs, dim=1, keepdim=True) + 1e-8
sdf = norm * signs[...,None]
# Normalized SDF Gradient
nerf_input.append(reg_vecs / norm)
# SDF (scaled by a constant for faster convergence)
nerf_input.append(torch.tanh(sdf*20))
# Multiview image feature
if feats is not None:
xyz = self.projection(pts.permute((1,0))[None,...].expand([calibs.shape[0],-1,-1]), calibs, persps)
xy = xyz[:, :2, :] # [self.num_views, 2, self.N_samples]
if persps is not None:
xy = xy / torch.tensor([[[self.width],[self.height]]], \
dtype = xyz.dtype, device = xyz.device) * 2 - 1
latent = index(feats, xy) # [self.num_views, C, self.N_samples(*2)]
latent = latent.permute((2,0,1)) # [self.N_samples(*2), self.num_views, C]
source_rgb = index(images[:self.num_views], xy)
source_rgb = source_rgb.permute((2,0,1))
latent = torch.cat([latent, source_rgb], -1)
nerf_input = [inp[:,None,:].expand([-1,self.num_views,-1]) for inp in nerf_input] # expand each feature to num_views
nerf_input += [latent]
nerf_input = torch.cat(nerf_input, dim=-1) # [self.N_samples, self.num_views, C]
return nerf_input, source_rgb
def make_nerf_output(self, nerf_output, t_vals, norm, source_rgb, is_train=True):
"""
Renders rays by integrating the sampled points along each ray
"""
dists = t_vals[...,1:] - t_vals[...,:-1]
dists = torch.cat([dists, torch.tensor([1e10], device=dists.device).expand(dists[...,:1].shape)], -1) # [N_rays, N_samples]
dists = dists * norm
N_samples = t_vals.shape[-1]
rgb = torch.sigmoid(nerf_output[...,:3]) # [N_rays, N_samples, 3]
noise = torch.randn(nerf_output[...,3].shape, device=nerf_output.device) if is_train else 0
alpha = 1.-torch.exp(-F.relu((nerf_output[...,3]+noise))) # [N_rays, N_samples]
weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1), device=nerf_output.device), 1.-alpha + 1e-10], -1), -1)[:, :-1]
rgb_map = torch.sum(weights[...,None] * rgb, -2) # [N_rays, 3]
if self.use_attention:
att = nerf_output[...,4:]
source_rgb = source_rgb.reshape([-1, N_samples, self.num_views, 3])
source_rgb = torch.cat([rgb.unsqueeze(-2), source_rgb], dim=-2)
source_rgb_att = torch.sum(source_rgb * att[...,None], dim=-2)
att_rgb_map = torch.sum(weights[...,None] * source_rgb_att, -2) # [N_rays, 3]
rgb_map = torch.cat([rgb_map, att_rgb_map], -1)
if self.debug:
for i in range(self.num_views):
source_rgb_i = source_rgb[:,:,i,:] * att[:, :, i, None]
rgb_map_i = torch.sum(weights[...,None] * source_rgb_i, -2)
rgb_map = torch.cat([rgb_map, rgb_map_i], -1)
acc_map = torch.sum(weights, -1)
if self.use_white_bkgd:
rgb_map = rgb_map + (1.-acc_map[...,None])
return rgb_map, weights
def render_rays(self, ray_batch, feats, images, masks, calibs, smpl, mesh_param,
scan=None, persps=None, q_persps=None, is_train=True):
"""Volumetric rendering.
Args:
ray_batch: array of shape [batch_size, ...]. All information necessary
for sampling along a ray, including: ray origin, ray direction, min
dist, max dist, and unit-magnitude viewing direction.
"""
self.alpha, self.angle_diff_grad = None, None
eps = 1e-9
N_rays = ray_batch.shape[0]
rays_s, rays_e = ray_batch[:,0:3], ray_batch[:,3:6] # [N_rays, 3] each
t_vals = torch.linspace(0., 1., steps=self.N_samples, device=ray_batch.device)
t_vals = t_vals.repeat([N_rays, 1])
# perturb during training
if is_train:
t_rand = (torch.rand(t_vals.shape, device=ray_batch.device) - 0.5) / (self.N_samples-1)
t_vals = t_vals + t_rand
pts = rays_e[:,None,:] * t_vals[...,None] + (1-t_vals[...,None]) * rays_s[:,None,:]
pts = pts.reshape(-1, 3)
# Use visual hull to skip sample points outside the body
inside, smpl_vis, scan_vis = None, None, None
if self.use_vh:
inside, smpl_vis, scan_vis = self.inside_pts_vh(pts, masks, smpl, calibs, persps)
pts = pts[inside]
if len(pts) == 0:
return self.default_rgb([N_rays, self.rgb_ch], dtype=torch.float32, device=ray_batch.device), \
torch.zeros([N_rays], dtype=torch.float32, device=ray_batch.device)
# When training RenderPeople with scan ground truth, prepare 3D supervision
if is_train and scan is not None:
scan_verts, scan_faces = scan
self.mesh_searcher.set_mesh(scan_verts, scan_faces)
self.alpha_gt = (self.mesh_searcher.inside_mesh(pts) + 1) / 2
# Prepare attention-based appearance blending input
viewdirs = (rays_s - rays_e)[:,None,:].expand(-1,self.N_samples,-1)
viewdirs = viewdirs.reshape(-1,3)[inside].requires_grad_()
attdirs = self.make_att_input(pts, viewdirs, calibs, smpl) if self.use_attention else []
# Prepare geometry body shape embedding input for NeRF
nerf_input, source_rgb = self.make_nerf_input(pts, feats, images, smpl, calibs, mesh_param, persps)
# Feed to the network
nerf_output = torch.cat([self.nerf(nerf_input[i:i+self.chunk], attdirs[i:i+self.chunk], smpl_vis=smpl_vis) \
for i in range(0, nerf_input.shape[0], self.chunk)], 0)
self.alpha = torch.sigmoid(nerf_output[...,3]*self.gamma)
# If RenderPeople scan visibility is available, supervise the occlusion
if self.use_occlusion_net:
if is_train and scan is not None:
self.occ_gt = scan_vis.float()
self.occ = nerf_output[:, -self.num_views:]
nerf_output = nerf_output[:, :-self.num_views]
# Regularize the alpha distribution
if self.regularization and is_train:
self.alpha_grad = torch.autograd.grad(self.alpha, self.pts_nml, grad_outputs=torch.ones_like(self.alpha), retain_graph=True)[0]
# use sparse multiplication to aggregate points inside and outside the visual hull for NeRF integration
if self.use_vh:
inside_idx = torch.nonzero(inside)
row_cols = torch.cat([inside_idx.view(1,-1), torch.arange(len(inside_idx), device=pts.device).view(1,-1)], 0)
I = torch.sparse_coo_tensor(row_cols, torch.ones(len(inside_idx),dtype=pts.dtype, device=pts.device), size=(N_rays*self.N_samples, len(inside_idx)))
nerf_output = torch.sparse.mm(I, nerf_output)
nerf_output[~inside, :4] = -1e4
full_source_rgb = torch.zeros([N_rays*self.N_samples, self.num_views, 3], device=pts.device)
full_source_rgb[inside] = source_rgb
source_rgb = full_source_rgb
nerf_output = nerf_output.view(N_rays, self.N_samples, -1)
norm = torch.norm(rays_e - rays_s, dim=-1, keepdim=True)
if self.use_nml:
center, body_scale = mesh_param['center'], mesh_param['body_scale']
norm = norm / body_scale
rgb_map, weights = self.make_nerf_output(nerf_output, t_vals, norm, source_rgb, is_train=is_train)
z_vals = t_vals * q_persps[-2] + (1-t_vals) * q_persps[-1] if persps is not None and q_persps is not None else 2*t_vals - 1
depth = torch.sum(weights * z_vals, -1)
# Regularize the angle difference of appearance
if self.angle_diff and is_train:
self.angle_diff_grad = torch.autograd.grad(rgb_map, viewdirs, grad_outputs=torch.ones_like(rgb_map), retain_graph=True)[0]
return rgb_map, depth
def inside_pts_vh(self, pts, masks, smpl, calibs, persps=None):
"""
Valid sample point selection via visual hull
"""
xyz = self.projection(pts.permute((1,0))[None,...].expand([calibs.shape[0],-1,-1]), calibs, persps)
xy = xyz[:, :2, :]
if persps is not None:
xy = xy / torch.tensor([[[self.width],[self.height]]], \
dtype = xyz.dtype, device = xyz.device) * 2 - 1
inside = index(masks, xy, 'nearest')
inside = torch.prod(inside, dim=0).squeeze(0) > 0
if (inside.sum() < self.chunk * 0.7) and self.use_vh_free:
n_samples = int(inside.sum().item() * 0.3)
idx = torch.randperm(len(inside))[:n_samples]
inside[idx] = True
smpl_vis, scan_vis = None, None
if self.use_occlusion:
smpl_depth = index(smpl['depth'], xy, 'nearest').squeeze(1).permute((1,0))
smpl_depth = smpl_depth[inside]
depth = xyz[:,2,:].permute((1,0))
depth = depth[inside]
smpl_vis = (((depth - smpl_depth) <= 0) * (smpl_depth > 0) + (smpl_depth == 0)) > 0
if self.use_occlusion_net and 'scan_depth' in smpl.keys():
scan_depth = index(smpl['scan_depth'], xy, 'nearest').squeeze(1).permute((1,0))[inside]
depth = xyz[:,2,:].permute((1,0))[inside]
scan_vis = (((depth - scan_depth) <= 0) * (scan_depth > 0) + (scan_depth == 0)) > 0
return inside, smpl_vis, scan_vis
def render(self, feats, images, masks, calibs, bbox, mesh_param, smpl=None, scan=None, persps=None):
"""
Render an image from a given camera pose
"""
self.debug_idx += 1
if persps is None:
rays_s, rays_e = self.get_rays_orthogonal(bbox, calibs[-1]) # (H, W, 3), (H, W, 3)
else:
rays_s, rays_e = self.get_rays_perspective(bbox, calibs[-1], persps[-1])
top, bottom, left, right = bbox
gt = images[-1].permute((1,2,0))[top:bottom, left:right]
coords = torch.stack(torch.meshgrid(torch.linspace(0, bottom-top-1, int(bottom-top), device=calibs.device),
torch.linspace(0, right-left-1, int(right-left), device=calibs.device)), -1) # (H, W, 2)
coords = coords.view(-1,2) # (H * W, 2)
select_inds = np.random.choice(coords.shape[0], size=[self.N_rand*self.vh_overhead], replace=False) # (N_rand,)
select_coords = coords[select_inds].long() # (N_rand, 2)
rays_s = rays_s[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
rays_e = rays_e[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
batch_rays = torch.cat([rays_s, rays_e], 1)
target = gt[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3)
persps = persps[:self.num_views] if persps is not None else None
rgb, _ = self.render_rays(batch_rays, feats, images[:self.num_views], masks[:self.num_views],
calibs[:self.num_views], smpl, mesh_param, scan, persps)
loss = self.cal_loss(rgb, target)
return loss
def render_path(self, feats, images, masks, calibs, bbox, mesh_param, smpl=None, scan=None, persps=None):
"""
Render images along a given camera trajectory
"""
top, bottom, left, right = bbox
height, width = max(self.height, bottom-top), max(self.width, right-left)
calibs_source, calibs_query = calibs[:self.num_views], calibs[self.num_views:]
persps_source = persps[:self.num_views] if persps is not None else None
persps_query = persps[self.num_views:] if persps is not None else None
rgbs, depths = [], []
# inference
for idx, calib in enumerate(tqdm(calibs_query)):
if persps is None:
rays_s, rays_e = self.get_rays_orthogonal(bbox, calib) # (H, W, 3), (H, W, 3)
else:
persp = persps[self.num_views+idx]
rays_s, rays_e = self.get_rays_perspective(bbox, calib, persp)
batch_rays = torch.cat([rays_s.view(-1,3), rays_e.view(-1,3)], 1)
# self.idx = 0
rgb, depth = [], []
for i in range(0, batch_rays.shape[0], self.N_rand_infer):
c, d = self.render_rays(batch_rays[i:i+self.N_rand_infer].detach(), feats, images[:self.num_views], \
masks[:self.num_views], calibs_source, smpl, mesh_param, persps=persps_source, q_persps=persps_query[idx], is_train=False)
rgb.append(c[:, :self.rgb_ch])
depth.append(d)
rgb = torch.cat(rgb, 0).clone()
depth = torch.cat(depth, 0).clone()
img = self.default_rgb((height, width, self.rgb_ch), dtype=torch.float32, device=rgb.device)
dimg = torch.zeros((height, width), dtype=torch.float32, device=rgb.device)
img[top:bottom, left:right] = rgb.view(int(bottom-top), int(right-left), self.rgb_ch)
dimg[top:bottom, left:right] = depth.view(int(bottom-top), int(right-left))
rgbs.append(img)
depths.append(dimg)
rgbs = torch.stack(rgbs, dim=0)
depths = torch.stack(depths, dim=0)
return rgbs, depths
def train_shape(self, feats, images, masks, calibs, bbox, mesh_param, smpl=None, scan=None, persps=None):
"""
Unused
"""
pts, alpha_gt = mesh_param['samples'].transpose(1,0), mesh_param['labels']
persps = persps[:self.num_views] if persps is not None else None
nerf_input, _ = self.make_nerf_input(pts, feats, images[:self.num_views], smpl, calibs[:self.num_views], mesh_param, persps=persps)
nerf_output = torch.cat([self.nerf(nerf_input[i:i+self.chunk], alpha_only=True) \
for i in range(0, pts.shape[0], self.chunk)], 0)
alpha = torch.sigmoid(nerf_output*self.gamma).reshape(alpha_gt.shape)
# loss = self.alpha_loss(alpha, alpha_gt[...,None].detach())
loss = {'alpha': self.mse_loss(alpha, alpha_gt)}
if self.regularization:
alpha_grad = torch.autograd.grad(alpha, self.pts_nml, grad_outputs=torch.ones_like(alpha), retain_graph=True)[0]
loss.update({'alpha_reg': self.alpha_grad_loss(alpha_grad) * self.omega_reg})
return loss
def reconstruct(self, feats, images, masks, calibs, bbox, mesh_param, smpl=None, scan=None, persps=None):
"""
Mesh reconstruction borrowed from PIFu
"""
# Determine the 3D bounding box
center, body_scale = mesh_param['center'].cpu().numpy(), mesh_param['body_scale']
top, bottom, left, right = bbox
left, right = 0, 512
bb_min = [left-self.width/2, top-self.height/2, left-self.width/2]
bb_max = [right-self.width/2, bottom-self.height/2, right-self.width/2]
# Make mesh grid in normalized body coordinates
linspaces = [np.linspace(bb_min[i], bb_max[i], self.N_grid) for i in range(len(bb_min))]
grids = np.stack(np.meshgrid(linspaces[0], linspaces[1], linspaces[2], indexing='ij'), -1)
sh = grids.shape
pts = grids / (self.width/2) * body_scale + center
recon_kwargs = {
'feats': feats, 'images': images, 'smpl': smpl, 'calibs': calibs[:self.num_views],
'mesh_param': mesh_param, 'persps': persps[:self.num_views] if persps is not None else None
}
# Reconstruct using progressive octree reconstruction
sdf = self.octree_reconstruct(pts, masks, **recon_kwargs)
verts, faces, normals, _ = measure.marching_cubes_lewiner(sdf, self.threshold) # requires scikit-image < 0.19, where this function was removed in favor of measure.marching_cubes
# Convert marching cubes coordinate back to world coordinate
verts = (verts - self.N_grid/2) / self.N_grid * np.array([[right-left, bottom-top, right-left]])
verts = verts / (self.width/2) * body_scale + center
# use Laplacian smoothing if the mesh is noisy
if self.opt.laplacian > 0:
mesh = trimesh.Trimesh(verts, faces, process=False)
trimesh.smoothing.filter_laplacian(mesh, iterations=self.opt.laplacian)
verts, faces = mesh.vertices, mesh.faces
pts = torch.tensor(verts, dtype=torch.float32, device=calibs.device)
viewdirs = torch.from_numpy(normals.astype(np.float32)).to(calibs.device)
attdirs = self.make_att_input(pts, viewdirs, calibs[:self.num_views], smpl) if self.use_attention else []
rgbs = []
for i in range(0, pts.shape[0], self.chunk):
nerf_input, source_rgb = self.make_nerf_input(pts[i:i+self.chunk], **recon_kwargs)
nerf_output = self.nerf(nerf_input, attdirs[i:i+self.chunk])
rgb = torch.sigmoid(nerf_output[...,:3])
if self.use_attention:
att = nerf_output[...,4:4+self.num_views+1]
source_rgb = source_rgb.view(-1, self.num_views, 3)
source_rgb = torch.cat([rgb[:,None], source_rgb], dim=-2)
rgb = torch.sum(source_rgb * att[...,None], dim=-2)
rgbs.append(rgb)
rgbs = torch.cat(rgbs, 0).cpu().numpy()
return verts, faces, rgbs
def octree_reconstruct(self, coords, masks, **kwargs):
"""
We use octree reconstruction for higher-resolution meshes, borrowed from PIFu
"""
device = kwargs['calibs'].device
calibs = kwargs['calibs']
persps = kwargs['persps']
resolution = [self.N_grid, self.N_grid, self.N_grid]
sdf = np.zeros(resolution)
notprocessed = np.zeros(resolution, dtype=bool)
notprocessed[:-1,:-1,:-1] = True
# only voxel grids that lie inside the visual hull are processed
if self.use_vh:
dilation_kernel = torch.ones((1,1,5,5), device=device, dtype=torch.float32)
masks = torch.clamp(torch.nn.functional.conv2d(masks, dilation_kernel, padding=(2, 2)), 0, 1)
masks_np = masks.permute([0,2,3,1]).cpu().numpy()
pts = coords.reshape(-1, 3)
notprocessed = notprocessed.reshape(-1)
for i in range(0, pts.shape[0], self.chunk):
inside, _, _ = self.inside_pts_vh(torch.tensor(pts[i:i+self.chunk], dtype=torch.float32, device=device),
masks, kwargs['smpl'], calibs, persps)
inside = inside.cpu().numpy()
outside = np.logical_not(inside.astype(bool))
notprocessed_chunk = notprocessed[i:i+self.chunk].copy()
notprocessed_chunk[outside] = False
notprocessed[i:i+self.chunk] = notprocessed_chunk
notprocessed = notprocessed.reshape(resolution)
grid_mask = np.zeros(resolution, dtype=bool)
reso = self.N_grid // 64
center = kwargs['mesh_param']['center'].cpu().numpy()
while reso > 0:
grid_mask[0:self.N_grid:reso, 0:self.N_grid:reso, 0:self.N_grid:reso] = True
test_mask = np.logical_and(grid_mask, notprocessed)
pts = coords[test_mask, :]
if pts.shape[0] == 0:
print("break")
break
pts_tensor = torch.tensor(pts, dtype=torch.float32, device=device)
nerf_output = []
for i in range(0, pts_tensor.shape[0], self.chunk):
nerf_input, _ = self.make_nerf_input(pts_tensor[i:i+self.chunk], **kwargs)
nerf_output.append(self.nerf(nerf_input, alpha_only=True))
nerf_output = torch.cat(nerf_output, dim=0)
sdf[test_mask] = torch.sigmoid(nerf_output*self.gamma).detach().cpu().numpy().reshape(-1)
notprocessed[test_mask] = False
# do interpolation
if reso <= 1:
break
grid = np.arange(0, self.N_grid, reso)
v = sdf[tuple(np.meshgrid(grid, grid, grid, indexing='ij'))]
vs = [v[:-1,:-1,:-1], v[:-1,:-1,1:], v[:-1,1:,:-1], v[:-1,1:,1:],
v[1:,:-1,:-1], v[1:,:-1,1:], v[1:,1:,:-1], v[1:,1:,1:]]
grid = grid[:-1] + reso//2
nonprocessed_grid = notprocessed[tuple(np.meshgrid(grid, grid, grid, indexing='ij'))]
v = np.stack(vs, 0)
v_min = v.min(0)
v_max = v.max(0)
v = 0.5*(v_min+v_max)
skip_grid = np.logical_and(((v_max - v_min) < 0.01), nonprocessed_grid)
xs, ys, zs = np.where(skip_grid)
for x, y, z in zip(xs*reso, ys*reso, zs*reso):
sdf[x:(x+reso+1), y:(y+reso+1), z:(z+reso+1)] = v[x//reso,y//reso,z//reso]
notprocessed[x:(x+reso+1), y:(y+reso+1), z:(z+reso+1)] = False
reso //= 2
return sdf.reshape(resolution)
def get_plucker_line(self, contour, w2c, cam, mesh_param):
"""
Unused
"""
x = (contour[:,0].float() - cam[2]) / cam[0]
y = (contour[:,1].float() - cam[3]) / cam[1]
near, far = cam[-2], cam[-1]
center, body_scale = mesh_param['center'], mesh_param['body_scale']
if len(cam) > 6:
xp, yp = x, y
for _ in range(3): # iter to undistort
x2 = x*x
y2 = y*y
xy = x*y
r2 = x2 + y2
c = (1 + r2*(cam[4]+r2*(cam[5]+r2*cam[8])))
x = (xp - cam[6]*2*xy - cam[7]*(r2+2*x2)) / (c+1e-9)
y = (yp - cam[7]*2*xy - cam[6]*(r2+2*y2)) / (c+1e-9)
z = torch.ones_like(x)
starts = torch.stack([x*near,y*near,z*near],-1)
ends = torch.stack([x*far, y*far, z*far], -1)
c2w = torch.inverse(w2c)
R, t = c2w[:3,:3], c2w[:3, 3]
rays_s = torch.sum(starts[..., None, :]* R, -1) + t
rays_e = torch.sum(ends[..., None, :] * R, -1) + t
rays_d = rays_e - rays_s
rays_d = rays_d / (torch.norm(rays_d, dim=-1, keepdim=True) + 1e-8)
if self.use_nml:
rays_s = (rays_s - center) / body_scale
return rays_d, torch.cross(rays_s, rays_d, dim=-1)
def warp_mat(self, face_verts):
"""
Unused
"""
# [[x0, x1, x2, nx/sqrt(size)]
# [y0, y1, y2, ny/sqrt(size)]
# [z0, z1, z2, nz/sqrt(size)]
# [1, 1, 1, 0 ]]
a = face_verts[:,1,:] - face_verts[:,0,:]
b = face_verts[:,2,:] - face_verts[:,0,:]
normal = torch.cross(a, b, dim=-1)
# the cross product's norm equals twice the triangle's area
size = torch.norm(normal/2, dim=-1, keepdim=True)
normal = normal / torch.sqrt(size + 1e-10)
mat = torch.cat([face_verts.permute([0,2,1]), normal[...,None]], dim=-1) # [N_samples, 3, 4]
column = torch.tensor([[[1,1,1,0]]], dtype=torch.float32, device=face_verts.device).repeat([mat.shape[0],1,1])
mat = torch.cat([mat, column], dim=1)
return mat
def canonical_warping(self, verts, verts_can, pts):
"""
Unused
"""
warp = self.warp_mat(verts)
warp_can = self.warp_mat(verts_can)
pts_homo = torch.cat([pts, torch.ones([pts.shape[0], 1], dtype=torch.float32, device=pts.device)], dim=-1)
pts_can = warp_can @ torch.inverse(warp) @ pts_homo[...,None]
return pts_can[:,:3,0]
def multiview_consistency(self, pts, feats, calibs, depth, persps=None):
"""
Unused
"""
xyz = self.projection(pts.permute((1,0))[None,...].expand([calibs.shape[0],-1,-1]), calibs, persps)
xy = xyz[:, :2, :] # [self.num_views, 2, self.N_samples]
z = xyz[:, 2:, :].permute((2,0,1))
if persps is not None:
xy = xy / torch.tensor([[[self.width],[self.height]]], \
dtype = xyz.dtype, device = xyz.device) * 2 - 1
latent = index(feats, xy) # [self.num_views, C, self.N_samples(*2)]
latent = latent.permute((2,0,1)) # [self.N_samples(*2), self.num_views, C]
depth_z = index(depth, xy, mode='nearest')
depth_z = depth_z.permute((2,0,1))
return depth_z, z
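if __name__ == "__main__":
    # Standalone sketch of the pinhole ray construction in get_rays_perspective
    # (distortion branch omitted); the camera values below are made up. Building
    # a full NeRFRenderer additionally needs the compiled mesh_grid extension,
    # so this block only re-traces the math. Run as `python -m lib.model.NeRFRenderer`.
    fx = fy = 500.0
    cx = cy = 256.0
    near, far = 0.5, 3.0
    i, j = torch.meshgrid(torch.arange(4).float(), torch.arange(4).float())
    x = (j - cx) / fx
    y = (i - cy) / fy
    z = torch.ones_like(x)
    starts = torch.stack([x * near, y * near, z * near], -1)  # on the near plane
    ends = torch.stack([x * far, y * far, z * far], -1)       # on the far plane
    c2w = torch.eye(4)                                        # camera at the origin
    R, t = c2w[:3, :3], c2w[:3, 3]
    rays_s = torch.sum(starts[..., None, :] * R, -1) + t
    rays_e = torch.sum(ends[..., None, :] * R, -1) + t
    print(rays_s.shape, rays_e.shape)  # (4, 4, 3) each: segments from near to far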
================================================
FILE: lib/model/SRFilters.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
class SRFilters(nn.Module):
"""
Upsample the pixel-aligned feature
"""
def __init__(self, order=2, in_ch=256, out_ch=128):
super(SRFilters, self).__init__()
self.in_ch = in_ch
self.out_ch = out_ch
self.image_factor = [0.5**(order-i) for i in range(0, order+1)]
self.convs = nn.ModuleList([nn.Conv2d(in_ch+3, out_ch, kernel_size=3, padding=1)] +
[nn.Conv2d(out_ch+3, out_ch, kernel_size=3, padding=1) for i in range(order)])
def forward(self, feat, images):
for i, conv in enumerate(self.convs):
im = F.interpolate(images, scale_factor=self.image_factor[i], mode='bicubic', align_corners=True) if self.image_factor[i] != 1 else images
feat = F.interpolate(feat, scale_factor=2, mode='bicubic', align_corners=True) if i != 0 else feat
feat = torch.cat([feat, im], dim=1)
feat = self.convs[i](feat)
return feat
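if __name__ == "__main__":
    # Shape-check sketch; the sizes are illustrative. With order=2 the feature
    # map is upsampled 2x twice and fused with the image pyramid at each scale.
    sr = SRFilters(order=2, in_ch=256, out_ch=64)
    feat = torch.randn(1, 256, 128, 128)   # coarse pixel-aligned features
    images = torch.randn(1, 3, 512, 512)   # full-resolution source image
    print(sr(feat, images).shape)          # (1, 64, 512, 512)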
================================================
FILE: lib/model/__init__.py
================================================
from .GNR import GNR
from .NeRF import NeRF
from .NeRFRenderer import NeRFRenderer
from .Embedder import SphericalHarmonics, PositionalEncoding
================================================
FILE: lib/net_ddp.py
================================================
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import logging
import os
import random
import numpy as np
import math
def worker_init_fn(worker_id):
random.seed(worker_id+100)
np.random.seed(worker_id+100)
torch.manual_seed(worker_id+100)
def ddp_init(args):
local_rank = args.local_rank
# torch.set_default_tensor_type('torch.cuda.FloatTensor') # RuntimeError: Expected a 'N2at13CUDAGeneratorE' but found 'PN2at9GeneratorE'
dist.init_process_group(backend = 'nccl') # 'nccl' for GPU, 'gloo/mpi' for CPU
torch.cuda.set_device(local_rank)
rank = torch.distributed.get_rank()
random.seed(rank)
np.random.seed(rank)
torch.manual_seed(rank)
print(f"local_rank {local_rank} rank {rank} launched...")
return rank, local_rank
def create_network(opt, net, local_rank=0):
def load_network(opt, net, load=True):
# init network from ckpts
start_epoch, global_step = 0, 0
ckpts = [os.path.join(opt.basedir, opt.name, f) for f in sorted(os.listdir(os.path.join(opt.basedir, opt.name))) if 'tar' in f]
if len(ckpts) == 0:
ckpts = [os.path.join(opt.basedir, opt.name, '..', f) for f in sorted(os.listdir(os.path.join(opt.basedir, opt.name, '..'))) if 'tar' in f]
if len(ckpts) > 0:
ckpt_path = ckpts[-1]
logging.info(f'Reloading from {ckpt_path}')
ckpt = torch.load(ckpt_path, map_location=torch.device('cpu')) # load to cpu, otherwise this will occupy rank=0 's gpu memory
if load:
try: # dp or ddp load itself
net.load_state_dict(ckpt['network_state_dict'])
except RuntimeError:
try: # dp_load_ddp
net.load_state_dict({'module.'+k: v for k, v in ckpt['network_state_dict'].items()})
except RuntimeError: # ddp_load_dp
net.load_state_dict({k[7:]: v for k, v in ckpt['network_state_dict'].items()})
start_epoch = ckpt['epoch']
# if no ckpts are found, only init the encoder from PIFu
elif opt.load_netG_checkpoint_path is not None:
logging.info(f'loading for net G ... {opt.load_netG_checkpoint_path}')
pretrained_net = torch.load(opt.load_netG_checkpoint_path, map_location=torch.device('cpu'))
if opt.ddp:
pretrained_image_filter = {k: v for k, v in pretrained_net.items() if k.startswith('image_filter')}
else:
pretrained_image_filter = {'module.'+k: v for k, v in pretrained_net.items() if k.startswith('image_filter')}
if load:
net.load_state_dict(pretrained_image_filter, strict=False)
return start_epoch
# DDP: load parameters first (only on master node), then make ddp model
if opt.ddp:
logging.info("use Distributed Data Parallel...")
net = net.to(local_rank)
start_epoch = load_network(opt, net, load=(dist.get_rank() == 0))
net = DDP(net, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
# DP: make dp model, then load parameters to all devices
else:
if torch.cuda.is_available():
logging.info("use Data Parallel...")
gpu_ids = [i for i in range(torch.cuda.device_count())]
net = net.to(gpu_ids[0])
net = torch.nn.DataParallel(net, device_ids=gpu_ids)
start_epoch = load_network(opt, net)
return net, start_epoch
def synchronize():
if dist.get_world_size() > 1:
dist.barrier()
return
class ddpSampler:
"""
ddp sampler for inference
"""
def __init__(self, dataset, rank=None, num_replicas=None):
if num_replicas is None:
if not torch.distributed.is_available():
raise RuntimeError("Requires distributed package to be available")
num_replicas = torch.distributed.get_world_size()
if rank is None:
if not torch.distributed.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = torch.distributed.get_rank()
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
def indices(self):
indices = list(range(len(self.dataset)))
# add extra samples to make it evenly divisible
indices += [indices[-1]] * (self.total_size - len(indices))
# subsample
indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
return indices
def len(self):
return self.num_samples
def distributed_concat(self, tensor, num_total_examples):
output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(output_tensors, tensor)
concat = torch.cat(output_tensors, dim=0)
# truncate the dummy elements added by SequentialDistributedSampler
return concat[:num_total_examples]
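if __name__ == "__main__":
    # Partitioning sketch for ddpSampler: rank and num_replicas are passed
    # explicitly, so no process group is needed. The toy dataset is illustrative.
    dataset = list(range(10))
    for rank in range(3):
        sampler = ddpSampler(dataset, rank=rank, num_replicas=3)
        print(rank, sampler.indices())
    # rank 0 -> [0, 1, 2, 3]; the last rank is padded with the final index so
    # every replica sees the same number of samples.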
================================================
FILE: lib/net_util.py
================================================
import torch
from torch.nn import init
import torch.nn as nn
import torch.nn.functional as F
import functools
import numpy as np
from .mesh_util import *
from .geometry import index
import cv2
from PIL import Image
from tqdm import tqdm
def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
"3x3 convolution with padding"
return nn.Conv2d(in_planes, out_planes, kernel_size=3,
stride=strd, padding=padding, bias=bias)
def init_weights(net, init_type='normal', init_gain=0.02):
"""Initialize network weights.
Parameters:
net (network) -- network to be initialized
init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
init_gain (float) -- scaling factor for normal, xavier and orthogonal.
We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
work better for some applications. Feel free to try yourself.
"""
def init_func(m): # define the initialization function
classname = m.__class__.__name__
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
if init_type == 'normal':
init.normal_(m.weight.data, 0.0, init_gain)
elif init_type == 'xavier':
init.xavier_normal_(m.weight.data, gain=init_gain)
elif init_type == 'kaiming':
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
init.orthogonal_(m.weight.data, gain=init_gain)
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
if hasattr(m, 'bias') and m.bias is not None:
init.constant_(m.bias.data, 0.0)
elif classname.find(
'BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
init.normal_(m.weight.data, 1.0, init_gain)
init.constant_(m.bias.data, 0.0)
# print('initialize network with %s' % init_type)
net.apply(init_func) # apply the initialization function <init_func>
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
"""Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
Parameters:
net (network) -- the network to be initialized
init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
gain (float) -- scaling factor for normal, xavier and orthogonal.
gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
Return an initialized network.
"""
if len(gpu_ids) > 0:
assert (torch.cuda.is_available())
net.to(gpu_ids[0])
net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
init_weights(net, init_type, init_gain=init_gain)
return net
def imageSpaceRotation(xy, rot):
'''
args:
xy: (B, 2, N) input
rot: (B, 2) x,y axis rotation angles
rotation center will always be the image center (other rotation centers can be represented by an additional z translation)
'''
disp = rot.unsqueeze(2).sin().expand_as(xy)
return (disp * xy).sum(dim=1)
def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
"""Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028
Arguments:
netD (network) -- discriminator network
real_data (tensor array) -- real images
fake_data (tensor array) -- generated images from the generator
device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
type (str) -- if we mix real and fake data or not [real | fake | mixed].
constant (float) -- the constant used in the formula (||gradient||_2 - constant)^2
lambda_gp (float) -- weight for this loss
Returns the gradient penalty loss
"""
if lambda_gp > 0.0:
if type == 'real': # either use real images, fake images, or a linear interpolation of two.
interpolatesv = real_data
elif type == 'fake':
interpolatesv = fake_data
elif type == 'mixed':
alpha = torch.rand(real_data.shape[0], 1)
alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(
*real_data.shape)
alpha = alpha.to(device)
interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
else:
raise NotImplementedError('{} not implemented'.format(type))
interpolatesv.requires_grad_(True)
disc_interpolates = netD(interpolatesv)
gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
grad_outputs=torch.ones(disc_interpolates.size()).to(device),
create_graph=True, retain_graph=True, only_inputs=True)
gradients = gradients[0].view(real_data.size(0), -1) # flatten the data
gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps
return gradient_penalty, gradients
else:
return 0.0, None
def get_norm_layer(norm_type='instance'):
"""Return a normalization layer
Parameters:
norm_type (str) -- the name of the normalization layer: batch | instance | none
For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
"""
if norm_type == 'batch':
norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
elif norm_type == 'instance':
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
elif norm_type == 'group':
norm_layer = functools.partial(nn.GroupNorm, 32)
elif norm_type == 'none':
norm_layer = None
else:
raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
return norm_layer
class Flatten(nn.Module):
def forward(self, input):
return input.view(input.size(0), -1)
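if __name__ == "__main__":
    # init_weights sketch; run as `python -m lib.net_util` because of the
    # relative imports. The toy two-layer net is illustrative.
    toy = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.BatchNorm2d(8))
    init_weights(toy, init_type='xavier')
    print(toy[0].bias.abs().max().item())  # 0.0: init_func zero-initializes biases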
================================================
SYMBOL INDEX (168 symbols across 20 files)
================================================
FILE: apps/render_smpl_depth.py
function load_obj_mesh (line 19) | def load_obj_mesh(mesh_file, with_normal=False, with_texture=False, with...
function extract_float (line 139) | def extract_float(text):
function natural_sort (line 149) | def natural_sort(files):
function load_ply (line 155) | def load_ply(file_name):
function distortPoints (line 230) | def distortPoints(p, dist):
function rasterize (line 252) | def rasterize(v, tri, size, K = np.identity(3), \
function render_view (line 297) | def render_view(intri, dists, c2ws, meshes, view, i):
class Worker (line 328) | class Worker(Process):
method __init__ (line 330) | def __init__(self, queue, lock):
method run (line 335) | def run(self):
FILE: apps/run_genebody.py
function loss_string (line 26) | def loss_string(loss_dict):
function print_write (line 32) | def print_write(file, string):
function to8b (line 37) | def to8b(img):
function prepare_data (line 50) | def prepare_data(opt, data, local_rank=0):
function cal_metrics (line 94) | def cal_metrics(metrics, rgbs, gts):
function train (line 105) | def train(opt, rank=0, local_rank = 0):
FILE: genebody/download_tool.py
function import_or_install (line 4) | def import_or_install(package):
function parse_args (line 47) | def parse_args():
function newSession (line 66) | def newSession():
function save_hash (line 72) | def save_hash(path, code):
function read_hash (line 76) | def read_hash(path):
function checkHashes (line 81) | def checkHashes(localfile, cloud_hash, localroot, force):
function getFiles (line 141) | def getFiles(originalUrl, download_path, force, download_root=None, req=...
function fetch_with_pwd (line 267) | async def fetch_with_pwd(iurl, password):
function havePwdGetFiles (line 298) | def havePwdGetFiles(iurl, password, download_path, force):
function extractFiles (line 304) | def extractFiles(path, subset):
function moveFiles (line 327) | def moveFiles(root):
FILE: genebody/genebody.py
function image_cropping (line 8) | def image_cropping(mask, padding=0.1):
class GeneBodyReader (line 65) | class GeneBodyReader():
method __init__ (line 66) | def __init__(self, rootdir, loadsize=512):
method get_views (line 75) | def get_views(self, subject):
method get_frames (line 93) | def get_frames(self, subject):
method get_cameras (line 99) | def get_cameras(self, subject):
method get_smpl (line 102) | def get_smpl(self, subject, frame):
method get_smpl_param (line 111) | def get_smpl_param(self, subject, frame):
method get_data (line 133) | def get_data(self, subject, frame, camera_params, views):
method get_near_far (line 172) | def get_near_far(self, verts, c2w, pad=0.5):
method smpl_from_param (line 188) | def smpl_from_param(self, model_path, subject, smpl_param, smpl_scale):
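Of this reader's methods, get_near_far(verts, c2w, pad=0.5) is worth seeing spelled out, since per-frame near/far bounds drive the ray sampling. A minimal sketch, assuming the usual approach of transforming SMPL vertices into camera space and padding their depth extent; the body and names below are illustrative, not the repository's exact code:

    import numpy as np

    def get_near_far(verts, c2w, pad=0.5):
        # Invert the 4x4 camera-to-world pose, move the body vertices into
        # camera space, and take their depth range with a safety margin.
        w2c = np.linalg.inv(c2w)
        pts_cam = verts @ w2c[:3, :3].T + w2c[:3, 3]   # [N, 3] in camera space
        near = max(pts_cam[:, 2].min() - pad, 0.05)    # keep the near plane positive
        far = pts_cam[:, 2].max() + pad
        return near, far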
FILE: genebody/mesh.py
function max_precision (line 49) | def max_precision(type1, type2):
function decode (line 83) | def decode(content, structure, num, form):
function load_ply (line 131) | def load_ply(file_name):
function save_ply (line 192) | def save_ply(file_name, elems, _type = 'binary_little_endian', comments ...
function normalize_v3 (line 280) | def normalize_v3(arr):
function compute_normal (line 291) | def compute_normal(vertices, faces):
function load_obj_mesh (line 312) | def load_obj_mesh(mesh_file, with_normal=False, with_texture=False, with...
function write_obj_mesh (line 433) | def write_obj_mesh(filename, verts, faces):
FILE: lib/data/GeneBodyDataset.py
function mask_padding (line 22) | def mask_padding(mask, border = 5):
function euler2rot (line 31) | def euler2rot(euler):
function rot2euler (line 46) | def rot2euler(R):
function gen_cam_views (line 52) | def gen_cam_views(transl, z_pitch, viewnum):
class GeneBodyDataset (line 79) | class GeneBodyDataset(Dataset):
method modify_commandline_options (line 81) | def modify_commandline_options(parser):
method __init__ (line 83) | def __init__(self, opt, phase='eval', root=None, move_cam=0):
method get_frames (line 126) | def get_frames(self, i = 0):
method get_render_poses (line 145) | def get_render_poses(self, annots, move_cam=150):
method __len__ (line 159) | def __len__(self):
method image_cropping (line 165) | def image_cropping(self, mask):
method get_near_far (line 208) | def get_near_far(self, smpl_verts, w2c):
method get_image (line 230) | def get_image(self, sid, num_views, view_id=None, random_sample=False,...
method smpl_from_param (line 387) | def smpl_from_param(self, model_path, subject, smpl_param, smpl_scale):
method get_item (line 410) | def get_item(self, index):
method __getitem__ (line 475) | def __getitem__(self, index):
FILE: lib/geometry.py
function rot2euler (line 4) | def rot2euler(R):
function euler2rot (line 10) | def euler2rot(euler):
function batch_rodrigues (line 25) | def batch_rodrigues(theta):
function quat_to_rotmat (line 41) | def quat_to_rotmat(quat):
function index (line 64) | def index(feat, uv, mode='bilinear'):
function orthogonal (line 81) | def orthogonal(points, calibrations, transforms=None):
function perspective (line 100) | def perspective(points, w2c, camera):
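Two of these helpers, perspective and index, form the pixel-aligned feature pipeline used throughout PIFu-style models: perspective projects 3D query points into each source view, and index bilinearly samples the image feature map at the projected uv. A minimal sketch of index under the common convention that uv is [B, 2, N] in normalized [-1, 1] coordinates (the exact layout in lib/geometry.py may differ):

    import torch
    import torch.nn.functional as F

    def index(feat, uv, mode='bilinear'):
        # feat: [B, C, H, W] feature maps; uv: [B, 2, N] normalized coordinates.
        uv = uv.transpose(1, 2).unsqueeze(2)                  # [B, N, 1, 2]
        samples = F.grid_sample(feat, uv, mode=mode, align_corners=True)
        return samples[:, :, :, 0]                            # [B, C, N] per-point features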
FILE: lib/mesh_util.py
function save_obj_mesh (line 7) | def save_obj_mesh(mesh_path, verts, faces):
function save_obj_mesh_with_color (line 18) | def save_obj_mesh_with_color(mesh_path, verts, faces, colors):
function save_obj_mesh_with_uv (line 30) | def save_obj_mesh_with_uv(mesh_path, verts, faces, uvs):
FILE: lib/metrics.py
function chamfer (line 21) | def chamfer(x_verts, gt_verts, x_normals=None, gt_normals=None):
function fscore (line 36) | def fscore(dist1, dist2, threshold=1.0):
function psnr (line 51) | def psnr(x, gt):
function ssim_channel (line 64) | def ssim_channel(x, gt):
function ssim (line 85) | def ssim(x, gt):
function lpips (line 107) | def lpips(x, gt, net=lpips_net):
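Of these metrics, psnr is simple enough to state inline. A sketch assuming float images in [0, 1], so the peak value is 1 (the module may instead take 8-bit inputs with a 255 peak):

    import numpy as np

    def psnr(x, gt):
        # Peak signal-to-noise ratio in dB: PSNR = -10 * log10(MSE) for unit peak.
        mse = np.mean((x - gt) ** 2)
        return -10.0 * np.log10(mse)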
FILE: lib/metrics_torch.py
class LPIPS (line 8) | class LPIPS(torch.nn.Module):
method __init__ (line 9) | def __init__(self):
method forward (line 13) | def forward(self, x, gt):
function psnr (line 27) | def psnr(x, gt):
function gaussian (line 46) | def gaussian(window_size, sigma):
function create_window (line 51) | def create_window(window_size, channel=1):
function ssim_ (line 58) | def ssim_(img1, img2, window_size=11, window=None, size_average=True, fu...
class SSIM (line 114) | class SSIM(torch.nn.Module):
method __init__ (line 115) | def __init__(self, window_size=11, size_average=True, val_range=None):
method forward (line 125) | def forward(self, img1, img2):
FILE: lib/model/Embedder.py
class PositionalEncoding (line 8) | class PositionalEncoding:
method __init__ (line 12) | def __init__(self, d, num_freqs=10, min_freq=None, max_freq=None, freq...
method create_embedding_fn (line 19) | def create_embedding_fn(self, d):
method embed (line 44) | def embed(self, inputs):
class SphericalHarmonics (line 47) | class SphericalHarmonics:
method __init__ (line 51) | def __init__(self, d = 3, rank = 3):
method Lengdre_polynormial (line 56) | def Lengdre_polynormial(self, x, omx = None):
method SH (line 69) | def SH(self, xyz):
method embed (line 90) | def embed(self, inputs):
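PositionalEncoding here is, in all likelihood, the standard NeRF frequency embedding: each input coordinate is lifted to sines and cosines at exponentially spaced frequencies so the downstream MLP can represent high-frequency detail. A self-contained sketch of that standard embedding (the class's min_freq/max_freq/freq-sampling options generalize it; this shows only the default behavior, not a copy of the class):

    import torch

    def positional_encoding(x, num_freqs=10):
        # Maps x to [x, sin(2^0 x), cos(2^0 x), ..., sin(2^(L-1) x), cos(2^(L-1) x)].
        out = [x]
        for k in range(num_freqs):
            out.append(torch.sin((2.0 ** k) * x))
            out.append(torch.cos((2.0 ** k) * x))
        return torch.cat(out, dim=-1)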
FILE: lib/model/GNR.py
class GNR (line 12) | class GNR(nn.Module):
method __init__ (line 14) | def __init__(self, opt):
method image_rescale (line 40) | def image_rescale(self, images, masks):
method get_image_feature (line 46) | def get_image_feature(self, data):
method forward (line 57) | def forward(self, data, train_shape=False):
method render_path (line 66) | def render_path(self, data):
method reconstruct (line 74) | def reconstruct(self, data):
FILE: lib/model/HGFilters.py
class HourGlass (line 12) | class HourGlass(nn.Module):
method __init__ (line 13) | def __init__(self, num_modules, depth, num_features, norm='batch'):
method _generate_network (line 22) | def _generate_network(self, level):
method _forward (line 34) | def _forward(self, level, inp):
method forward (line 60) | def forward(self, x):
class HGFilter (line 64) | class HGFilter(nn.Module):
method __init__ (line 65) | def __init__(self, opt):
method forward (line 114) | def forward(self, x):
FILE: lib/model/NeRF.py
class NeRF (line 9) | class NeRF(nn.Module):
method __init__ (line 10) | def __init__(self, opt, D=8, W=256, input_ch=3, input_ch_atts=3, input...
method forward (line 98) | def forward(self, x, attdirs=None, alpha_only=False, smpl_vis=None):
method weighted_softmax (line 191) | def weighted_softmax(self, attention, weight):
FILE: lib/model/NeRFRenderer.py
class NeRFRenderer (line 23) | class NeRFRenderer:
method __init__ (line 24) | def __init__(self, opt, nerf, nerf_fine=None, projection='perspective'...
method cal_loss (line 82) | def cal_loss(self, rgb, rgb_gt):
method get_rays_orthogonal (line 99) | def get_rays_orthogonal(self, bbox, calib):
method get_rays_perspective (line 120) | def get_rays_perspective(self, bbox, w2c, cam):
method make_att_input (line 155) | def make_att_input(self, pts, viewdirs, calibs, smpl):
method make_nerf_input (line 182) | def make_nerf_input(self, pts, feats, images, smpl, calibs, mesh_param...
method make_nerf_output (line 243) | def make_nerf_output(self, nerf_output, t_vals, norm, source_rgb, is_t...
method render_rays (line 276) | def render_rays(self, ray_batch, feats, images, masks, calibs, smpl, m...
method inside_pts_vh (line 364) | def inside_pts_vh(self, pts, masks, smpl, calibs, persps=None):
method render (line 393) | def render(self, feats, images, masks, calibs, bbox, mesh_param, smpl=...
method render_path (line 423) | def render_path(self, feats, images, masks, calibs, bbox, mesh_param, ...
method train_shape (line 462) | def train_shape(self, feats, images, masks, calibs, bbox, mesh_param, ...
method reconstruct (line 480) | def reconstruct(self, feats, images, masks, calibs, bbox, mesh_param, ...
method octree_reconstruct (line 534) | def octree_reconstruct(self, coords, masks, **kwargs):
method get_plucker_line (line 608) | def get_plucker_line(self, countour, w2c, cam, mesh_param):
method warp_mat (line 640) | def warp_mat(self, face_verts):
method canonical_warpping (line 660) | def canonical_warpping(self, verts, verts_can, pts):
method multiview_consistency (line 672) | def multiview_consistency(self, pts, feats, calibs, depth, persps=None):
FILE: lib/model/SRFilters.py
class SRFilters (line 5) | class SRFilters(nn.Module):
method __init__ (line 9) | def __init__(self, order=2, in_ch=256, out_ch=128):
method forward (line 17) | def forward(self, feat, images):
FILE: lib/net_ddp.py
function worker_init_fn (line 11) | def worker_init_fn(worker_id):
function ddp_init (line 16) | def ddp_init(args):
function create_network (line 32) | def create_network(opt, net, local_rank=0):
function synchronize (line 81) | def synchronize():
class ddpSampler (line 86) | class ddpSampler:
method __init__ (line 90) | def __init__(self, dataset, rank=None, num_replicas=None):
method indices (line 105) | def indices(self):
method len (line 113) | def len(self):
method distributed_concat (line 116) | def distributed_concat(self, tensor, num_total_examples):
FILE: lib/net_util.py
function conv3x3 (line 14) | def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
function init_weights (line 19) | def init_weights(net, init_type='normal', init_gain=0.02):
function init_net (line 55) | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
function imageSpaceRotation (line 73) | def imageSpaceRotation(xy, rot):
function cal_gradient_penalty (line 85) | def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed...
function get_norm_layer (line 123) | def get_norm_layer(norm_type='instance'):
class Flatten (line 142) | class Flatten(nn.Module):
method forward (line 143) | def forward(self, input):
class ConvBlock (line 146) | class ConvBlock(nn.Module):
method __init__ (line 147) | def __init__(self, in_planes, out_planes, norm='batch'):
method forward (line 174) | def forward(self, x):
FILE: lib/options.py
class BaseOptions (line 6) | class BaseOptions():
method __init__ (line 7) | def __init__(self):
method initialize (line 10) | def initialize(self, parser):
method gather_options (line 136) | def gather_options(self):
method print_options (line 148) | def print_options(self, opt):
method parse (line 160) | def parse(self):
FILE: lib/ply_util.py
function max_precision (line 47) | def max_precision(type1, type2):
function decode (line 80) | def decode(content, structure, num, form):
function load_ply (line 127) | def load_ply(file_name):
function save_ply (line 187) | def save_ply(file_name, elems, _type = 'binary_little_endian', comments ...