Repository: generalizable-neural-performer/gnr
Branch: main
Commit: 5f670850a013
Files: 37
Total size: 2.4 MB

Directory structure:
gnr/

├── .gitignore
├── .gitmodules
├── README.md
├── apps/
│   ├── render_smpl_depth.py
│   └── run_genebody.py
├── configs/
│   ├── render.txt
│   ├── test.txt
│   └── train.txt
├── docs/
│   ├── Annotation.md
│   └── Dataset.md
├── environment.yml
├── genebody/
│   ├── download_tool.py
│   ├── gender.py
│   ├── genebody.py
│   └── mesh.py
├── lib/
│   ├── data/
│   │   ├── GeneBodyDataset.py
│   │   └── __init__.py
│   ├── geometry.py
│   ├── mesh_util.py
│   ├── metrics.py
│   ├── metrics_torch.py
│   ├── model/
│   │   ├── Embedder.py
│   │   ├── GNR.py
│   │   ├── HGFilters.py
│   │   ├── NeRF.py
│   │   ├── NeRFRenderer.py
│   │   ├── SRFilters.py
│   │   └── __init__.py
│   ├── net_ddp.py
│   ├── net_util.py
│   ├── options.py
│   └── ply_util.py
├── scripts/
│   ├── download_model.sh
│   ├── render_smpl_depth.sh
│   └── train_ddp.sh
└── smpl_t_pose/
    ├── smpl.obj
    └── smplx.obj

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
checkpoints/*
data/*
**.ply
results/*
sample_images/*
logs/*
**/__pycache__/*


================================================
FILE: .gitmodules
================================================
[submodule "benchmarks/ibrnet"]
	path = benchmarks/ibrnet
	url = https://github.com/generalizable-neural-performer/genebody-benchmarks/
	branch = ibrnet
[submodule "benchmarks/nv"]
	path = benchmarks/nv
	url = https://github.com/generalizable-neural-performer/genebody-benchmarks/
	branch = nv
[submodule "benchmarks/nt"]
	path = benchmarks/nt
	url = https://github.com/generalizable-neural-performer/genebody-benchmarks/
	branch = nt
[submodule "benchmarks/nb"]
	path = benchmarks/nb
	url = https://github.com/generalizable-neural-performer/genebody-benchmarks/
	branch = nb
[submodule "benchmarks/anerf"]
	path = benchmarks/anerf
	url = https://github.com/generalizable-neural-performer/genebody-benchmarks/
	branch = A-Nerf
[submodule "benchmarks/nhr"]
	path = benchmarks/nhr
	url = https://github.com/generalizable-neural-performer/genebody-benchmarks/
	branch = nhr


================================================
FILE: README.md
================================================
# Generalizable Neural Performer: Learning Robust Radiance Fields for Human Novel View Synthesis
[![report](https://img.shields.io/badge/arxiv-report-red)](http://arxiv.org/abs/2204.11798) 
<!-- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)]() -->

![Teaser image](./docs/teaser.png)

> **Abstract:** *This work targets using a general deep learning framework to synthesize free-viewpoint images of arbitrary human performers, only requiring a sparse number of camera views as inputs and skirting per-case fine-tuning. The large variation of geometry and appearance, caused by articulated body poses, shapes, and clothing types, are the key bottlenecks of this task. To overcome these challenges, we present a simple yet powerful framework, named Generalizable Neural Performer (GNR), that learns a generalizable and robust neural body representation over various geometry and appearance. Specifically, we compress the light fields for a novel view of human rendering as conditional implicit neural radiance fields with several designs from both geometry and appearance aspects. We first introduce an Implicit Geometric Body Embedding strategy to enhance the robustness based on both parametric 3D human body model prior and multi-view source images hints. On top of this, we further propose a Screen-Space Occlusion-Aware Appearance Blending technique to preserve the high-quality appearance, through interpolating source view appearance to the radiance fields with a relaxed but approximate geometric guidance.* <br>

[Wei Cheng](mailto:wchengad@connect.ust.hk), [Su Xu](mailto:xusu@sensetime.com), [Jingtan Piao](mailto:piaojingtan@sensetime.com), [Chen Qian](https://scholar.google.com/citations?user=AerkT0YAAAAJ&hl=zh-CN), [Wayne Wu](https://wywu.github.io/), [Kwan-Yee Lin](https://kwanyeelin.github.io/), [Hongsheng Li](https://www.ee.cuhk.edu.hk/~hsli/)<br>
**[[Demo Video]](https://www.youtube.com/watch?v=2COR4u1ZIuk)** | **[[Project Page]](https://generalizable-neural-performer.github.io/)** | **[[Data]](https://generalizable-neural-performer.github.io/genebody.html)** | **[[Paper]](https://arxiv.org/pdf/2204.11798.pdf)**

## Updates
- [14/07/2023] :star2::star2::star2:Check out our sister dataset -- [DNA-Rendering](https://dna-rendering.github.io/)! :star2::star2::star2: It captures humans at up to 4K resolution, and offers more accurate annotations, richer attributes, and greater diversity than GeneBody 1.0.
- [01/09/2022] We also recommend the implementation of our work in the [OpenXRLab](https://github.com/openxrlab/xrnerf). 
- [01/09/2022] :exclamation: GeneBody has been reframed. Users who downloaded GeneBody before `2022.09.01` should update to the latest data using our more user-friendly download tool.
- [29/07/2022] GeneBody can be downloaded from [OpenDataLab](https://opendatalab.com/GeneBody).
- [11/07/2022] Code is released.
- [02/05/2022] GeneBody Train40 is released! Apply [here](./docs/Dataset.md#train40)! 
- [29/04/2022] SMPLx fitting toolbox and benchmarks are released!
- [26/04/2022] Technical report released.
- [24/04/2022] The codebase and project page are created.

## Data Download
To download and use the GeneBody dataset, please first read the instructions in [Dataset.md](./docs/Dataset.md). We provide a download tool to fetch and update the GeneBody data, including the dataset and pretrained models (in case of future adjustments). For example:
```
python genebody/download_tool.py --genebody_root ${GENEBODY_ROOT} --subset train40 test10 pretrained_models smpl_depth
```
The tool will fetch and download the subsets you selected and put the data in `${GENEBODY_ROOT}`.

## Annotations
GeneBody provides per-view, per-frame segmentation computed with [BackgroundMatting-V2](https://github.com/PeterL1n/BackgroundMattingV2), and registers fitted [SMPLx](https://github.com/vchoutas/smplx) models using our enhanced multi-view SMPLify repo [here](https://github.com/generalizable-neural-performer/bodyfitting).

To use the annotations of GeneBody, please check the document [Annotation.md](./docs/Annotation.md). We provide a reference data fetch module in `genebody`.
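For programmatic access, here is a minimal sketch mirroring how `apps/run_genebody.py` constructs its dataset (the config and dataroot values are placeholders; the sample keys listed in the comment are those consumed by `prepare_data` in that script):
```
from lib.options import BaseOptions
from lib.data.GeneBodyDataset import GeneBodyDataset

# Parses command-line flags, e.g.
#   python demo.py --config configs/test.txt --dataroot /path/to/genebody
opt = BaseOptions().parse()
dataset = GeneBodyDataset(opt, phase='test')
sample = dataset[0]  # dict with 'img', 'calib', 'mask', 'bbox', SMPLx data, ...
```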

## Train and Evaluate GNR

Set up the environment:
```
conda env create -f environment.yml
conda activate gnr
pip install git+https://github.com/generalizable-neural-performer/gnr.git@mesh_grid
```

To run GNR on GeneBody:
```
python apps/run_genebody.py --config configs/[train, test, render].txt --dataroot ${GENEBODY_ROOT}
```
If you have multiple machines or multiple GPUs, you can train our model with distributed data parallel:
```
bash scripts/train_ddp.sh
```


## Benchmarks
We also provide benchmarks of state-of-the-art methods on the GeneBody dataset; methods and requirements are listed in [Benchmarks.md](https://github.com/generalizable-neural-performer/genebody-benchmarks).

To test the performance of our released pretrained models, or to train them yourself, run:
```
git clone --recurse-submodules https://github.com/generalizable-neural-performer/gnr.git
```
Then `cd benchmarks/`; the released benchmarks are ready to run on GeneBody and other datasets such as V-sense and ZJU-MoCap.

### Case-specific Methods on Genebody
| Model  | PSNR | SSIM |LPIPS| ckpts|
| :--- | :---------------:|:---------------:| :---------------:| :---------------:  |
| [NV](https://github.com/generalizable-neural-performer/genebody-benchmarks/tree/nv)| 19.86 |0.774 |  0.267 | [ckpts](https://hkustconnect-my.sharepoint.com/:f:/g/personal/wchengad_connect_ust_hk/EniK9r9UdbtGvYvtJITBGkIBlmxSHqaoEIiIgpYBGddCHQ?e=RbS0sG)|
| [NHR](https://github.com/generalizable-neural-performer/genebody-benchmarks/tree/nhr)| 20.05  |0.800 |  0.155 | [ckpts](https://hkustconnect-my.sharepoint.com/:f:/g/personal/wchengad_connect_ust_hk/EqQDNVch2j5DmyIDnHX0VgkBDdCksmT4Kfq2oPOMn6gfMg?e=dy6yUA)|
| [NT](https://github.com/generalizable-neural-performer/genebody-benchmarks/tree/nt)| 21.68  |0.881 |   0.152 | [ckpts](https://hkustconnect-my.sharepoint.com/:f:/g/personal/wchengad_connect_ust_hk/Etg3LW44m61OjZOgDp-f4TcB_rgm_32ve529z5EZgCmoGw?e=zGUadc)|
| [NB](https://github.com/generalizable-neural-performer/genebody-benchmarks/tree/nb)| 20.73  |0.878 |  0.231 | [ckpts](https://hkustconnect-my.sharepoint.com/personal/wchengad_connect_ust_hk/_layouts/15/onedrive.aspx?ga=1&id=%2Fpersonal%2Fwchengad%5Fconnect%5Fust%5Fhk%2FDocuments%2Fgenebody%2Dbenchmark%2Dpretrained%2Fnb%2Fgenebody)|
| [A-Nerf](https://github.com/generalizable-neural-performer/genebody-benchmarks/tree/A-Nerf)| 15.57 |0.508 |  0.242 | [ckpts](https://hkustconnect-my.sharepoint.com/:f:/g/personal/wchengad_connect_ust_hk/En56nksujH1Fn1qWiUJ-gpIBfzdHqHf66F-RvfzwTe2TBQ?e=Zz0EgX)|

(see the discussion of why A-NeRF's performance is unexpectedly low in this [issue](https://github.com/LemonATsu/A-NeRF/issues/8))
### Generalizable Methods on Genebody
| Model  | PSNR | SSIM |LPIPS| ckpts|
| :--- | :---------------:|:---------------:| :---------------:| :---------------:  |
| PixelNeRF | 24.15   |0.903 | 0.122 | |
| [IBRNet](https://github.com/generalizable-neural-performer/genebody-benchmarks/tree/ibrnet)| 23.61    |0.836 |  0.177 | [ckpts](https://hkustconnect-my.sharepoint.com/personal/wchengad_connect_ust_hk/_layouts/15/onedrive.aspx?ga=1&id=%2Fpersonal%2Fwchengad%5Fconnect%5Fust%5Fhk%2FDocuments%2Fgenebody%2Dbenchmark%2Dpretrained%2Fibrnet)|

### Opensource contributions
OpenXRLab/xrnerf: [https://github.com/openxrlab/xrnerf](https://github.com/openxrlab/xrnerf)


## Citation

```
@article{cheng2022generalizable,
    title={Generalizable Neural Performer: Learning Robust Radiance Fields for Human Novel View Synthesis},
    author={Cheng, Wei and Xu, Su and Piao, Jingtan and Qian, Chen and Wu, Wayne and Lin, Kwan-Yee and Li, Hongsheng},
    journal={arXiv preprint arXiv:2204.11798},
    year={2022}
}
```


================================================
FILE: apps/render_smpl_depth.py
================================================
from tqdm import tqdm
import numpy as np
import argparse
import struct
import sys
import cv2
import os
import re
from multiprocessing import Queue, Lock, Process

from trimesh import load_mesh
base_dir = os.path.dirname(os.path.abspath(__file__))
parser = argparse.ArgumentParser()
parser.add_argument('--datadir', type = str, required=True)
parser.add_argument('--outdir', type = str, default = '')
parser.add_argument('--annotdir', type = str, default = '')
parser.add_argument('--workers', type = int, default = 8)

def load_obj_mesh(mesh_file, with_normal=False, with_texture=False, with_texture_image=False):
	vertex_data = []
	norm_data = []
	uv_data = []

	face_data = []
	face_norm_data = []
	face_uv_data = []

	if isinstance(mesh_file, str):
		f = open(mesh_file, "r")
	else:
		f = mesh_file
	for line in f:
		if isinstance(line, bytes):
			line = line.decode("utf-8")
		if line.startswith('#'):
			continue
		values = line.split()
		if not values:
			continue

		if values[0] == 'v':
			v = list(map(float, values[1:4]))
			vertex_data.append(v)
		elif values[0] == 'vn':
			vn = list(map(float, values[1:4]))
			norm_data.append(vn)
		elif values[0] == 'vt':
			vt = list(map(float, values[1:3]))
			uv_data.append(vt)

		elif values[0] == 'f':
			# quad mesh
			if len(values) > 4:
				f = list(map(lambda x: int(x.split('/')[0]), values[1:4]))
				face_data.append(f)
				f = list(map(lambda x: int(x.split('/')[0]), [values[3], values[4], values[1]]))
				face_data.append(f)
			# tri mesh
			else:
				f = list(map(lambda x: int(x.split('/')[0]), values[1:4]))
				face_data.append(f)
			
			# deal with texture
			if len(values[1].split('/')) >= 2:
				# quad mesh
				if len(values) > 4:
					f = list(map(lambda x: int(x.split('/')[1]), values[1:4]))
					face_uv_data.append(f)
					f = list(map(lambda x: int(x.split('/')[1]), [values[3], values[4], values[1]]))
					face_uv_data.append(f)
				# tri mesh
				elif len(values[1].split('/')[1]) != 0:
					f = list(map(lambda x: int(x.split('/')[1]), values[1:4]))
					face_uv_data.append(f)
			# deal with normal
			if len(values[1].split('/')) == 3:
				# quad mesh
				if len(values) > 4:
					f = list(map(lambda x: int(x.split('/')[2]), values[1:4]))
					face_norm_data.append(f)
					f = list(map(lambda x: int(x.split('/')[2]), [values[3], values[4], values[1]]))
					face_norm_data.append(f)
				# tri mesh
				elif len(values[1].split('/')[2]) != 0:
					f = list(map(lambda x: int(x.split('/')[2]), values[1:4]))
					face_norm_data.append(f)
		elif 'mtllib' in line.split():
			mtlname = line.split()[-1]
			mtlfile = os.path.join(os.path.dirname(mesh_file), mtlname)
			with open(mtlfile, 'r') as fmtl:
				mtllines = fmtl.readlines()
				for mtlline in mtllines:
					# if mtlline.startswith('map_Kd'):
					if 'map_Kd' in mtlline.split():
						texname = mtlline.split()[-1]
						texfile = os.path.join(os.path.dirname(mesh_file), texname)
						texture_image = cv2.imread(texfile)
						texture_image = cv2.cvtColor(texture_image, cv2.COLOR_BGR2RGB)
						break

	vertices = np.array(vertex_data)
	faces = np.array(face_data) - 1

	if with_texture and with_normal:
		uvs = np.array(uv_data)
		face_uvs = np.array(face_uv_data) - 1
		norms = np.array(norm_data)
		if norms.shape[0] == 0:
			norms = compute_normal(vertices, faces)
			face_normals = faces
		else:
			norms = normalize_v3(norms)
			face_normals = np.array(face_norm_data) - 1
		if with_texture_image:
			return vertices, faces, norms, face_normals, uvs, face_uvs, texture_image
		else:
			return vertices, faces, norms, face_normals, uvs, face_uvs

	if with_texture:
		uvs = np.array(uv_data)
		face_uvs = np.array(face_uv_data) - 1
		return vertices, faces, uvs, face_uvs

	if with_normal:
		# norms = np.array(norm_data)
		# norms = normalize_v3(norms)
		# face_normals = np.array(face_norm_data) - 1
		norms = np.array(norm_data)
		if norms.shape[0] == 0:
			norms = compute_normal(vertices, faces)
			face_normals = faces
		else:
			norms = normalize_v3(norms)
			face_normals = np.array(face_norm_data) - 1
		return vertices, faces, norms, face_normals

	return vertices, faces

def extract_float(text):
	flts = []
	for c in re.findall('(-?[0-9]*\.?[0-9]*[eE]?[-\+]?[0-9]+)',text):
		if c != '':
			try:
				flts.append(float(c))
			except ValueError as e:
				continue
	return flts

def natural_sort(files):
	return	sorted(files, key = lambda text: \
		extract_float(os.path.basename(text)) \
		if len(extract_float(os.path.basename(text))) > 0 else \
		[float(ord(c)) for c in os.path.basename(text)])
		
def load_ply(file_name):
	v = []; tri = []
	try:
		fid = open(file_name, 'r')
		head = fid.readline().strip()
		readl= lambda f: f.readline().strip()
	except UnicodeDecodeError as e:
		fid = open(file_name, 'rb')
		readl =	(lambda f: str(f.readline().strip())[2:-1]) \
			if sys.version_info[0] == 3 else \
			(lambda f: str(f.readline().strip()))
		head = readl(fid)
	if head.lower() != 'ply':
		return	v, tri
	form = readl(fid).split(' ')[1]
	line = readl(fid)
	vshape = fshape = [0]
	while line != 'end_header':
		s = [i for i in line.split(' ') if len(i) > 0]
		if len(s) > 2 and s[0] == 'element' and s[1] == 'vertex':
			vshape = [int(s[2])]
			line = readl(fid)
			s = [i for i in line.split(' ') if len(i) > 0]
			while s[0] == 'property' or s[0][0] == '#':
				if s[0][0] != '#':
					vshape += [s[1]]
				line = readl(fid)
				s = [i for i in line.split(' ') if len(i) > 0]
		elif len(s) > 2 and s[0] == 'element' and s[1] == 'face':
			fshape = [int(s[2])]
			line = readl(fid)
			s = [i for i in line.split(' ') if len(i) > 0]
			while s[0] == 'property' or s[0][0] == '#':
				if s[0][0] != '#':
					fshape = [fshape[0],s[2],s[3]]
				line = readl(fid)
				s = [i for i in line.split(' ') if len(i) > 0]
		else:
			line = readl(fid)
	if form.lower() == 'ascii':
		for i in range(vshape[0]):
			s = [i for i in readl(fid).split(' ') if len(i) > 0]
			if len(s) > 0 and s[0][0] != '#':
				v += [[float(i) for i in s]]
		v = np.array(v, np.float32)
		for i in range(fshape[0]):
			s = [i for i in readl(fid).split(' ') if len(i) > 0]
			if len(s) > 0 and s[0][0] != '#':
				tri += [[int(s[1]),int(s[i-1]),int(s[i])] \
					for i in range(3,len(s))]
		tri = np.array(tri, np.int64)
	else:
		maps = {'float': ('f',4), 'double':('d',8), \
			'uint':  ('I',4), 'int':   ('i',4), \
			'ushort':('H',2), 'short': ('h',2), \
			'uchar': ('B',1), 'char':  ('b',1)}
		if 'little' in form.lower():
			fmt = '<' + ''.join([maps[i][0] for i in vshape[1:]]*vshape[0])
		else:
			fmt = '>' + ''.join([maps[i][0] for i in vshape[1:]]*vshape[0])
		l = sum([maps[i][1] for i in vshape[1:]]) * vshape[0]
		v = struct.unpack(fmt, fid.read(l))
		v = np.array(v).reshape(vshape[0],-1).astype(np.float32)
		v = v[:,:3]
		tri = []
		for i in range(fshape[0]):
			l = struct.unpack(fmt[0]+maps[fshape[1]][0], \
				fid.read(maps[fshape[1]][1]))[0]
			f = struct.unpack(fmt[0]+maps[fshape[2]][0]*l, \
				fid.read(l*maps[fshape[2]][1]))
			tri += [[f[0],f[i-1],f[i]] for i in range(2,len(f))]
		tri = np.array(tri).reshape(fshape[0],-1).astype(np.int64)
	fid.close()
	return	v, tri

def distortPoints(p, dist):
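	# Brown-Conrady lens distortion, applied in place to normalized image
	# coordinates: rational radial terms k1..k6 and tangential terms p1, p2.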
	dist = np.reshape(dist,-1) \
		if dist is not None else []
	k1 = dist[0] if len(dist) > 0 else 0
	k2 = dist[1] if len(dist) > 1 else 0
	p1 = dist[2] if len(dist) > 2 else 0
	p2 = dist[3] if len(dist) > 3 else 0
	k3 = dist[4] if len(dist) > 4 else 0
	k4 = dist[5] if len(dist) > 5 else 0
	k5 = dist[6] if len(dist) > 6 else 0
	k6 = dist[7] if len(dist) > 7 else 0
	x, y = p[...,0], p[...,1]
	x2 = x * x; y2 = y * y; xy = x * y
	r2 = x2 + y2  # squared radial distance r^2 = x^2 + y^2
	c =	(1 + r2 * (k1 + r2 * (k2 + r2 * k3))) / \
		(1 + r2 * (k4 + r2 * (k5 + r2 * k6)))
	x_ = c*x + p1*2*xy + p2*(r2+2*x2)
	y_ = c*y + p2*2*xy + p1*(r2+2*y2)
	p[...,0] = x_
	p[...,1] = y_
	return p

def rasterize(v, tri, size, K = np.identity(3), \
		dist = None, persp = True, eps = 1e-6):
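	# Software z-buffer rasterizer: project vertices with K (perspective
	# divide when persp=True), cull triangles by the sign of the projected
	# 2D cross product, then fill each remaining triangle's bounding box
	# via barycentric coordinates, keeping the nearest depth per pixel.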
	h, w = size
	zbuf = np.ones([h, w], v.dtype) * float('inf')
	if dist is not None:
		valid = np.where(v[:,2] >= eps)[0] \
			if persp else np.arange(len(v))
		v_proj = v[valid,:2] / v[valid,2:]
		v_proj = distortPoints(v_proj, dist)
		v[valid,:2]= v_proj * v[valid,2:]
	v_proj = v.dot(K.T)[:,:2] / np.maximum(v[:,2:], eps) \
		if persp else v.dot(K.T)[:,:2]
	va = v_proj[tri[:,0],:2]
	vb = v_proj[tri[:,1],:2]
	vc = v_proj[tri[:,2],:2]
	front = np.cross(vc - va, vb - va)
	umin = np.maximum(np.ceil (np.vstack((va[:,0],vb[:,0],vc[:,0])).min(0)), 0)
	umax = np.minimum(np.floor(np.vstack((va[:,0],vb[:,0],vc[:,0])).max(0)),w-1)
	vmin = np.maximum(np.ceil (np.vstack((va[:,1],vb[:,1],vc[:,1])).min(0)), 0)
	vmax = np.minimum(np.floor(np.vstack((va[:,1],vb[:,1],vc[:,1])).max(0)),h-1)
	umin = umin.astype(np.int32)
	umax = umax.astype(np.int32)
	vmin = vmin.astype(np.int32)
	vmax = vmax.astype(np.int32)
	front = np.where(np.logical_and(np.logical_and( \
			umin <= umax, vmin <= vmax), front > 0))[0]
	for t in front:
		A = np.concatenate((vb[t:t+1]-va[t:t+1], vc[t:t+1]-va[t:t+1]),0)
		x, y = np.meshgrid(	range(umin[t],umax[t]+1), \
					range(vmin[t],vmax[t]+1))
		u = np.vstack((x.reshape(-1),y.reshape(-1))).T
		coeff = (u.astype(v.dtype) - va[t:t+1,:]).dot(np.linalg.pinv(A))
		coeff = np.concatenate((1-coeff.sum(1).reshape(-1,1),coeff),1)
		if persp:
			z = coeff.dot(v[tri[t], 2])
		else:
			z = 1 / np.maximum((coeff/v[tri[t],2:3].T).sum(1), eps)
		for i, (x, y) in enumerate(u):
			if  coeff[i,0] >= -eps \
			and coeff[i,1] >= -eps \
			and coeff[i,2] >= -eps \
			and zbuf[y,x] > z[i]:
				zbuf[y,x] = z[i]
	return	zbuf

def render_view(intri, dists, c2ws, meshes, view, i):
	K = intri[i]
	dist = dists[i]
	# c2w  = np.concatenate([c2ws[i],[[0,0,0,1]]], 0)
	w2c = np.linalg.inv(c2ws[i])
	# w2c = c2w
	out = os.path.join(args.outdir, os.path.basename(view))
	if not os.path.isdir(out):
		os.makedirs(out, exist_ok=True)
	imgs = [os.path.join(view, f) for f in os.listdir(view) \
		if f[-4:].lower() in ['.jpg','.png']]
	imgs = sorted(imgs) if len(imgs) > 1 else imgs
	for i in tqdm(range(len(imgs))):
		img = cv2.imread(imgs[i])
		try:
			if i < len(meshes) and meshes[i][-4:] == '.ply':
				v, tri = load_ply(meshes[i])
			elif i < len(meshes) and meshes[i][-4:] == '.npy':
				v = np.load(meshes[i])
			else:
				v, tri = load_obj_mesh(meshes[i])
		except:
			continue
		v_= v.dot(w2c[:3,:3].T) + w2c[:3,3:].T
		z = rasterize(v_, tri, img.shape[:2], K, dist)
		height = v_[:,1].max() - v_[:,1].min()
		z[z == float('inf')] = 0
		z = np.clip(np.round(z * 1000), 0, 65535).astype(np.uint16)
		cv2.imwrite(os.path.join(out, \
			'smpl_'+os.path.basename(imgs[i][:-4])+'.png'), z) 

class Worker(Process):

	def __init__(self, queue, lock):
		super(Worker, self).__init__()
		self.queue = queue
		self.lock = lock

	def run(self):
		while True:
			self.lock.acquire()
			if self.queue.empty():
				self.lock.release()
				break
			else:
				kwargs = self.queue.get()
				queue_len = self.queue.qsize()
				self.lock.release()
				print("started {}, {} jobs left".format(kwargs["view"], queue_len))
				render_view(**kwargs)

if __name__ == '__main__':
	args  = parser.parse_args()
	views = [os.path.join(args.datadir, 'image', f) \
			for f in os.listdir(os.path.join(args.datadir, 'image'))]
	if args.annotdir == '':
		args.annotdir = os.path.join(args.datadir, 'annots.npy')
	annot = np.load(os.path.join(args.annotdir), allow_pickle = True).item()['cams']
	intri = np.array([annot[view]['K'] for view in annot.keys()], np.float32)
	dists = np.array([annot[view]['D'] for view in annot.keys()], intri.dtype)
	c2ws  = np.array([annot[view]['c2w'] for view in annot.keys()]).astype(intri.dtype)
	if args.outdir == '': 
		args.outdir = os.path.join(args.datadir, 'smpl_depth')
	if not os.path.isdir(args.outdir):
		os.makedirs(args.outdir)
	if os.path.exists(os.path.join(args.datadir,'new_smpl')):
		meshes = [os.path.join(args.datadir,'new_smpl',f) \
			for f in os.listdir(os.path.join(args.datadir,'new_smpl')) \
			if f[-4:] == '.ply' or f[-4:] == '.obj']
		meshes = natural_sort(meshes)
	elif os.path.exists(os.path.join(args.datadir,'smpl')):
		meshes = [os.path.join(args.datadir,'smpl',f) \
			for f in os.listdir(os.path.join(args.datadir,'smpl')) \
			if f[-4:] == '.ply' or f[-4:] == '.obj']
		meshes = natural_sort(meshes)
	elif os.path.exists(os.path.join(args.datadir,'new_vertices')):
		meshes = [os.path.join(args.datadir,'new_vertices',f) \
			for f in os.listdir(os.path.join(args.datadir,'new_vertices')) \
			if f[-4:] == '.npy']
		tri = np.loadtxt(os.path.join(base_dir,'tri.txt')).astype(np.int64)
		meshes = natural_sort(meshes)
	elif os.path.exists(os.path.join(args.datadir,'vertices')):
		meshes = [os.path.join(args.datadir,'vertices',f) \
			for f in os.listdir(os.path.join(args.datadir,'vertices')) \
			if f[-4:] == '.npy']
		_, tri = load_obj_mesh('./smpl_t_pose/smplx.obj')
		tri = tri.astype(np.int64)
		meshes = natural_sort(meshes)

	queue = Queue()
	lock = Lock()

	for i, view in enumerate(natural_sort(views)):

		queue.put({
			'intri': intri, 
			'dists': dists, 
			'c2ws': c2ws, 
			'meshes': meshes, 
			'view': view, 
			'i': i,
		})

	print("num of workers", args.workers, flush=True)
	pool = [Worker(queue, lock) for _ in range(args.workers)]
	for worker in pool:  worker.start()
	for worker in pool:  worker.join()


================================================
FILE: apps/run_genebody.py
================================================
import sys
import os
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
sys.path = sys.path[:-1]
import time
import json
import numpy as np
import torch
from torch.utils.data import DataLoader

from lib.options import BaseOptions
from lib.data.GeneBodyDataset import GeneBodyDataset as MyDataset
from lib.mesh_util import *
from lib.net_ddp import create_network, worker_init_fn, ddpSampler, ddp_init, synchronize
from lib.model import GNR
from PIL import Image
import torchvision.transforms as transforms
from tqdm import tqdm, trange
from torch.utils.tensorboard import SummaryWriter
import torch.distributed as dist
import logging
import imageio
import lib.metrics_torch as metrics_torch

def loss_string(loss_dict):
    string = ''
    for key in loss_dict.keys():
        string += '| {}: {:.2e} '.format(key, loss_dict[key].item())
    return string

def print_write(file, string):
    file.write(string)
    if string[-1] == '\n': string = string[:-1]
    print(string)

def to8b(img):
    if isinstance(img, torch.Tensor):
        img = img.detach().cpu().numpy()
    if img.shape[0] == 3 and img.shape[-1] != 3:
        img = np.transpose(img, [1,2,0])
    if img.min() < -.2:
        img = (img + 1) * 127.5
    elif img.max() <= 2.:
        img = img * 255.
    img = np.clip(img, 0, 255)
    return img.astype(np.uint8)
# get options

def prepare_data(opt, data, local_rank=0):
    # retrieve the data
    image_tensor = data['img'][0].to(device=local_rank)
    calib_tensor = data['calib'][0].to(device=local_rank)
    mask_tensor = data['mask'][0].to(device=local_rank)
    bbox = list(data['bbox'][0].numpy().astype(np.int32))
    mesh_param = {'center': data['center'][0].to(device=local_rank), 
                    'body_scale': data['body_scale'][0].cpu().numpy().item()}
    if opt.train_shape:
        mesh_param['samples'] = data['samples'][0].to(device=local_rank)
        mesh_param['labels'] = data['labels'][0].to(device=local_rank)

    if any([opt.use_smpl_sdf, opt.use_t_pose]):
        smpl = { 'rot': data['smpl_rot'].to(device=local_rank) }
        if opt.use_smpl_sdf or opt.use_t_pose:
            smpl['verts'] = data['smpl_verts'][0].to(device=local_rank)
            smpl['faces'] = data['smpl_faces'][0].to(device=local_rank)
        if opt.use_t_pose:
            smpl['t_verts'] = data['smpl_t_verts'][0].to(device=local_rank)
            smpl['t_faces'] = data['smpl_t_faces'][0].to(device=local_rank)
        if opt.use_smpl_depth:
            smpl['depth'] = data['smpl_depth'][0].to(device=local_rank)[:,None,...]
            
    else:
        smpl = None

    if 'scan_verts' in data.keys():
        scan = [data['scan_verts'][0].to(device=local_rank), data['scan_faces'][0].to(device=local_rank)]
    else:
        scan = None

    persps = data['persps'][0].to(device=local_rank) if opt.projection_mode == 'perspective' else None
    
    return {
        'images': image_tensor,
        'calibs': calib_tensor,
        'bbox': bbox,
        'masks': mask_tensor,
        'mesh_param': mesh_param,
        'smpl': smpl,
        'scan': scan,
        'persps': persps
    }

def cal_metrics(metrics, rgbs, gts):
    x = rgbs.clone().permute((0, 3, 1, 2))
    out = {}
    for m_key in metrics.keys():
        out[m_key] = []
        for pred, gt in zip(x, gts):
            metric = metrics[m_key]
            out[m_key].append(metric(pred, gt))
        out[m_key] = torch.stack(out[m_key], dim=0)
    return out

def train(opt, rank=0, local_rank = 0):
    gpu_num = torch.cuda.device_count()
    train_dataset = MyDataset(opt, phase='train')
    test_dataset = MyDataset(opt, phase='test', move_cam=0)
    render_dataset = MyDataset(opt, phase='render', move_cam=opt.move_cam)
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if opt.ddp else None

    # create data loader
    shuffle = not opt.ddp
    train_data_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=opt.batch_size,
                                    num_workers=opt.num_threads, shuffle=shuffle, worker_init_fn=worker_init_fn)
    logging.info(f'train data size: {len(train_data_loader)}')

    test_data_loader = DataLoader(test_dataset, batch_size=1)
    test_data_iter = iter(test_data_loader)
    logging.info(f'test data size: {len(test_data_loader)}')

    render_data_loader = DataLoader(render_dataset, batch_size=1)
    render_data_iter = iter(render_data_loader)
    logging.info(f'render data size: {len(render_data_loader)}')
    # create net
    net = GNR(opt)
    
    logging.info(f'Using Network: {net.name}')
    
    set_train = net.train
    set_eval = net.eval

    os.makedirs(opt.basedir, exist_ok=True)
    os.makedirs('%s/%s' % (opt.basedir, opt.name), exist_ok=True)

    net, start_epoch = create_network(opt, net, local_rank)
    global_step = start_epoch * len(train_dataset)

    lr = opt.lrate * (0.1 ** (start_epoch / opt.lrate_decay))
    # params = net.parameters() if opt.train_encoder else net.module.nerf.parameters()
    params_list = []
    for name, param in net.named_parameters():
        if 'occ_linears' in name:
            if opt.train_occlusion:
                params_list.append(param)
        elif 'image_filter' in name:
            if opt.train_encoder:
                params_list.append(param)
        else:
            params_list.append(param)

    optimizer = torch.optim.Adam(params=params_list, lr=lr, betas=(0.9, 0.999))

    is_summary = not opt.ddp or (opt.ddp and (rank == 0))
    if is_summary:
        from tqdm import tqdm, trange
        if opt.train:
            writer = SummaryWriter(os.path.join(opt.basedir, opt.name))
            opt_log = os.path.join(opt.basedir, opt.name, 'opt.txt')
            config_file = os.path.join(opt.basedir, opt.name, 'config.txt')
            with open(opt_log, 'w') as outfile:
                outfile.write(json.dumps(vars(opt), indent=2))
            os.system(f'cp {opt.config} {config_file}')
    else:
        tqdm = lambda x: x
        trange = range

    # evaluate, not demo
    # metrics_dict = {'lpips': [], 'psnr': [], 'ssim': []}
    metrics_dict = {}
    metrics = {'lpips': metrics_torch.LPIPS().to(local_rank), 'psnr': metrics_torch.psnr, 'ssim': metrics_torch.SSIM().to(local_rank)}
    
    # training
    if opt.train:
        for epoch in trange(start_epoch, opt.num_epoch):        

            set_train()
            if opt.ddp:
                train_data_loader.sampler.set_epoch(epoch)
                synchronize()
            
            pbar = tqdm(train_data_loader)
            if is_summary:
                pbar.set_description("epoch {}/{}".format(epoch, opt.num_epoch))

            for train_idx, train_data in enumerate(pbar):
                data = prepare_data(opt, train_data, local_rank)
                train_shape = opt.train_shape and train_idx % opt.train_shape_skips == 0
                loss_dict = net(data, train_shape=train_shape)
                loss = sum(loss_dict.values())

                optimizer.zero_grad()
                try:
                    loss.backward()
                except:
                    print(train_data['name'], train_data['sid'], train_data['vid'], flush=True)
                optimizer.step()

                if global_step % opt.freq_plot == 0 and is_summary:
                    tqdm.write(
                        '[{}] | epoch: {} | step: {:d} | loss:{:.2e} | lr: {:.2e} '.format(
                            opt.name, epoch, global_step, loss.item(), lr))
                    tqdm.write(f'[{opt.name}] {loss_string(loss_dict)}')

                if is_summary:
                    writer.add_scalar('loss', loss.item(), global_step)
                    for key in loss_dict.keys():
                        writer.add_scalar(key, loss_dict[key].item(), global_step)
                    pbar.update(1)
                global_step += 1 if not opt.ddp else gpu_num

            if opt.ddp and (rank == 0):
                torch.save({'network_state_dict': net.module.state_dict(), 'epoch': epoch+1}, 
                            '%s/%s/%04d.tar' % (opt.basedir, opt.name, epoch+1))
            elif not opt.ddp:
                torch.save({'network_state_dict': net.state_dict(), 'epoch': epoch+1}, 
                            '%s/%s/%04d.tar' % (opt.basedir, opt.name, epoch+1))

            # update learning rate
            lr = opt.lrate * (0.1 ** ((epoch+1) / opt.lrate_decay))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    if opt.test:
        net.eval()
        idx = 0
        with torch.no_grad():
            for test_idx in trange(len(test_dataset)):
                ## test dataset
                test_data = next(test_data_iter)
                data = prepare_data(opt, test_data, local_rank)
                name = test_data['name'][0]
                subject = name.split('_')[0]
                # fid = test_data['fid'][0]
                fid = test_data['sid'][0]
                vid = test_data['vid'][0]
                render_gt = test_data['render_gt'][0].to(local_rank)
                
                if opt.ddp:
                    # distribute query views to different GPUs if multiple GPU or multiple machines are available
                    src_calibs, tar_calibs = torch.split(data['calibs'], [opt.num_views, data['calibs'].shape[0]-opt.num_views], 0)
                    total_len = len(tar_calibs)
                    sampler = ddpSampler(tar_calibs)
                    indices = sampler.indices()
                    tar_calibs = tar_calibs[indices]
                    tar_gt = render_gt[indices]
                    data['calibs'] = torch.cat([src_calibs, tar_calibs], 0)
                    if opt.projection_mode == 'perspective':
                        src_persps, tar_persps = torch.split(data['persps'], [opt.num_views, data['persps'].shape[0]-opt.num_views], 0)
                        tar_persps = tar_persps[indices]
                        data['persps'] = torch.cat([src_persps, tar_persps], 0)

                rgbs, _ = net.module.render_path(data)
                rgbs = sampler.distributed_concat(rgbs, total_len) if opt.ddp else rgbs
                if opt.use_attention:
                    rgbs, att_rgbs = rgbs[...,:3], rgbs[...,3:6]
                else:
                    att_rgbs = rgbs[..., :3]
                m_dict = cal_metrics(metrics, att_rgbs, render_gt)
                
                if subject not in metrics_dict.keys():
                    metrics_dict[subject] = {'lpips': [], 'psnr': [], 'ssim': []}
                for key, value in m_dict.items():
                    # if opt.ddp:
                    #     value = sampler.distributed_concat(value, total_len) if opt.ddp else value
                    metrics_dict[subject][key].append(torch.mean(value).cpu().numpy())

                att_rgbs = [to8b(att_rgb) for att_rgb in att_rgbs.cpu().numpy()]
                render_gt = [to8b(gt) for gt in render_gt.permute(0,2,3,1).cpu().numpy()]
                target_dir = os.path.join(opt.basedir, opt.name, opt.eval_dir, name)
                os.makedirs(target_dir, exist_ok=True)
                if is_summary:
                    for vid, im in enumerate(att_rgbs):
                        imageio.imwrite(os.path.join(target_dir, f'{vid:02d}.png'), im)

                    fname = os.path.join(opt.basedir, opt.name, opt.eval_dir, 'eval.txt')
                    with open(fname, 'a') as file_:
                        print_write(file_, '******\n%s\n' % (name))
                        for k, v in metrics_dict[subject].items():
                            print_write(file_, '%s: %.5f\n'%(k, v[-1]))
                        file_.write('------')
                        for k, v in metrics_dict[subject].items():
                            print_write(file_, '[total] %s: %.5f\n'%(k, sum(v)/len(v)))

                if opt.output_mesh:
                    verts, faces, rgbs = net.module.reconstruct(data)
                    if is_summary:
                        save_obj_mesh_with_color(os.path.join(target_dir, "{}.obj".format(name)), verts, faces, rgbs)
                idx += 1

    if opt.render:
        net.eval()
        with torch.no_grad():
            imgs = []
            for ridx in trange(len(render_dataset)):
                # render dataset
                test_data = next(render_data_iter)
                data = prepare_data(opt, test_data, local_rank)
                name = test_data['name'][0]
                # fid = test_data['fid'][0]
                fid = test_data['sid'][0]
                vid = test_data['vid'][0]
                
                if opt.ddp:
                    # distribute query views to different GPUs if multiple GPU or multiple machines are available
                    src_calibs, tar_calibs = torch.split(data['calibs'], [opt.num_views, data['calibs'].shape[0]-opt.num_views], 0)
                    total_len = len(tar_calibs)
                    sampler = ddpSampler(tar_calibs)
                    indices = sampler.indices()
                    tar_calibs = tar_calibs[indices]
                    data['calibs'] = torch.cat([src_calibs, tar_calibs], 0)
                    if opt.projection_mode == 'perspective':
                        src_persps, tar_persps = torch.split(data['persps'], [opt.num_views, data['persps'].shape[0]-opt.num_views], 0)
                        tar_persps = tar_persps[indices]
                        data['persps'] = torch.cat([src_persps, tar_persps], 0)

                target_dir = os.path.join(opt.basedir, opt.name, opt.render_dir, name.split('_')[0])
                os.makedirs(target_dir, exist_ok=True)
                rgbs, depths = net.module.render_path(data)
                rgbs = sampler.distributed_concat(rgbs, total_len) if opt.ddp else rgbs
                depths = sampler.distributed_concat(depths, total_len) if opt.ddp else depths
                if opt.use_attention:
                    rgbs, att_rgbs = rgbs[...,:3], rgbs[...,3:6]
                else:
                    att_rgbs = rgbs[..., :3]
                att_rgbs = [to8b(att_rgb) for att_rgb in att_rgbs.cpu().numpy()]
                imgs += att_rgbs
                os.makedirs(target_dir, exist_ok=True)
                for vid, im in enumerate(att_rgbs):
                    imageio.imwrite(os.path.join(target_dir, f'{ridx:03d}_rgb.png'), im)
                    depth= np.clip(np.round(depths[vid].cpu().numpy()*1000), 0, 65535).astype(np.uint16)
                    depth[depth == 0] = 65535
                    imageio.imwrite(os.path.join(target_dir, f'{ridx:03d}_depth.png'), depth)
            if is_summary:
                imageio.mimwrite(os.path.join(target_dir, "render_{}.mp4".format(fid)), imgs, quality=8, fps=30)

            

if __name__ == '__main__':
    opt = BaseOptions().parse()
    if opt.ddp:
        rank, local_rank = ddp_init(opt)
        logging.basicConfig(level=logging.INFO if rank in [-1, 0] else logging.WARN)
        logging.info(vars(opt))
        train(opt, rank, local_rank)
    else:
        logging.basicConfig(level=logging.INFO)
        logging.info(vars(opt))
        train(opt)



================================================
FILE: configs/render.txt
================================================
name = genebody

# run phase
train = False
test = False
render = True

# Dataloader
num_threads = 5
output_mesh = True

# Geometric Body Shape Embedding
smpl_type = smplx
use_smpl_sdf = True
use_t_pose = True
use_nml = True

# SSOAB
use_attention = True
weighted_pool = True
use_sh = True
use_viewdirs = True
use_occlusion = True
use_smpl_depth = True
use_occlusion_net = True

# Ray Sampling
use_vh = True
N_rand = 1024
N_rand_infer = 4096
N_samples = 256
chunk = 524288
vh_overhead = 1

# Training
ddp = False
train_encoder = False
projection_mode = perspective

# Evaluation
eval_skip = 1

# Render
move_cam = 150

# Reconstruction
N_grid = 512
laplacian = 5



================================================
FILE: configs/test.txt
================================================
name = genebody

# run phase
train = False
test = True
render = False

# Dataloader
num_threads = 5
output_mesh = True

# Geometric Body Shape Embedding
smpl_type = smplx
use_smpl_sdf = True
use_t_pose = True
use_nml = True

# SSOAB
use_attention = True
weighted_pool = True
use_sh = True
use_viewdirs = True
use_occlusion = True
use_smpl_depth = True
use_occlusion_net = True

# Ray Sampling
use_vh = True
N_rand = 1024
N_rand_infer = 4096
N_samples = 256
chunk = 524288
vh_overhead = 1

# Training
ddp = False
train_encoder = False
projection_mode = perspective

# Evaluation
eval_skip = 15

# Render
move_cam = 1

# Reconstruction
N_grid = 512
laplacian = 5



================================================
FILE: configs/train.txt
================================================
name = genebody

# run phase
train = True
test = False
render = False

# Dataloader
num_threads = 5
output_mesh = True

# Geometric Body Shape Embedding
smpl_type = smplx
use_smpl_sdf = True
use_t_pose = True
use_nml = True

# SSOAB
use_attention = True
weighted_pool = True
use_sh = True
use_viewdirs = True
use_occlusion = True
use_smpl_depth = True
use_occlusion_net = True

# Ray Sampling
use_vh = True
N_rand = 1024
N_rand_infer = 4096	# decrease the batch size if there is an out-of-memory error
N_samples = 256
chunk = 524288
vh_overhead = 1

# Training
ddp = False
train_encoder = True
projection_mode = perspective

# Evaluation
eval_skip = 15

# Render
move_cam = 1

# Reconstruction
N_grid = 512		# decrease the grid size if there is an out-of-memory error
laplacian = 5



================================================
FILE: docs/Annotation.md
================================================
# GeneBody Annotations

## Data Capture
The GeneBody dataset captures performers in a motion capture studio with 48 synchronized cameras. Each actor is asked to perform 10-second clips, recorded at 15 fps. The camera locations and capture volume are visualized in the following video.
<!-- ![Teaser image](./genebody.gif#center) -->
<p align="center"><img src="./capture_volume.gif" width="90%"></p>
<p align="center">Left: Motion capture studio and performer, cameras are highlighted. Right: Video captured from camera 25.</p>


## Dataset Organization
The processed GeneBody dataset is organized in the following structure:
```
├──genebody/            # root of dataset
  ├──amanda/            # subject
    ├──image/           # multiview images
        ├──00/          # images of 00 view
        ├──...
    ├──mask/            # multiview masks
        ├──00/          # masks of 00 view
        ├──...
    ├──param/           # smpl parameters
    ├──smpl/            # smpl meshes in OBJ format
    ├──annots.npy       # camera parameters
  ├──.../
  ├──genebody_split.npy # dataset splits
```
You can download the Test10 and Train40 subsets following the [instructions](./Dataset.md#download-instructions).
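A minimal sketch of walking this layout (the root, subject, and view names are placeholders taken from the tree above; mask filenames are assumed to mirror image filenames):
```
import os
import cv2

root, subject, view = 'genebody', 'amanda', '00'  # placeholders
img_dir = os.path.join(root, subject, 'image', view)
mask_dir = os.path.join(root, subject, 'mask', view)
for fname in sorted(os.listdir(img_dir)):
    img = cv2.imread(os.path.join(img_dir, fname))
    # assumed: masks share the image filenames
    mask = cv2.imread(os.path.join(mask_dir, fname), cv2.IMREAD_GRAYSCALE)
```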


## Data Interface
We provide the reference data reader `GeneBodyReader` in `genebody/genebody.py`.
### Source views
The default number of source views is 4, and the source views are `[01, 13, 25, 37]`.
### Image cropping
As the human performer may appear at different sizes across views, and the foreground occupies only a small proportion of the original image plane, directly applying image quality metrics such as PSNR, SSIM, and LPIPS to the raw images may introduce numerical ambiguity. To tackle this, in GNR and the GeneBody benchmarks we crop the performer out and resize the crop to a given resolution. Check the reference interface for more details.
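A minimal sketch of the idea (the padding and output resolution below are illustrative, not the exact values used by the benchmarks):
```
import numpy as np
import cv2

def crop_by_mask(img, mask, size=512, pad=20):
    # Tight bounding box of the foreground mask, with a small margin.
    ys, xs = np.where(mask > 0)
    y0, y1 = max(ys.min() - pad, 0), min(ys.max() + pad, img.shape[0])
    x0, x1 = max(xs.min() - pad, 0), min(xs.max() + pad, img.shape[1])
    crop = img[y0:y1, x0:x1]
    # Resize so metrics are computed on a comparable foreground region.
    return cv2.resize(crop, (size, size))
```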

## Camera Calibration
GeneBody provides the camera calibration for each subject; the intrinsic matrix, distortion coefficients, and extrinsic parameters are stored in `annots.npy` in each subject folder. Note that we use an OpenCV camera convention (x, y, z -> right, down, front).
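A minimal sketch of reading the calibration, mirroring how `apps/render_smpl_depth.py` consumes `annots.npy` (the subject path is a placeholder):
```
import numpy as np

annots = np.load('genebody/amanda/annots.npy', allow_pickle=True).item()['cams']
for view, cam in annots.items():
    K = np.array(cam['K'])      # 3x3 intrinsics
    D = np.array(cam['D'])      # distortion coefficients
    c2w = np.array(cam['c2w'])  # camera-to-world extrinsics
    w2c = np.linalg.inv(c2w)    # world-to-camera, as used for depth rendering
```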

## Human Segmentation
### Auto annotation
Before recording, we capture the scene without any performer and use it as the reference image in [BackgroundMatting-V2](https://github.com/PeterL1n/BackgroundMattingV2) to automatically extract the performer's foreground mask.
### Human labeling
We choose 8 camera views to manually check the results of the automatic annotation and manually label the bad cases. The 8 camera views are `[01, 07, 13, 19, 25, 31, 37, 43]`.

## SMPLx
GeneBody provides per-frame SMPLx estimations, storing the meshes in the `smpl` subfolder and the SMPLx parameters in the `param` subfolder. The SMPLx fitting toolbox is also provided in this [repo](https://github.com/generalizable-neural-performer/bodyfitting).

### SMPLx parameters
GeneBody provides SMPLx parameters and 3D keypoints in the `param` subfolder. More specifically, the dictionary `smplx` can be fed directly to the SMPLX [forward](https://github.com/vchoutas/smplx/blob/master/smplx/body_models.py#L1111) pass as long as all the parameters are converted to torch tensors; an example of generating SMPLX output from the parameters is provided in the [data interface](../genebody/genebody.py#L189), and a combined sketch with scaling is given at the end of this section.

### SMPLx scale
GeneBody has a large variation in performers' ages, and the SMPLx model typically fails to fit well on children and unusually large body shapes. We therefore introduce an `smplx_scale` factor outside the SMPLx model and jointly optimize the scale and the body model parameters during fitting. Thus, to recover the fitted mesh in the `smpl` subfolder from the parameters in the `param` subfolder, you need to multiply the vertices or 3D joints output by the body model by `smplx_scale`.
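
A minimal sketch with the [smplx](https://github.com/vchoutas/smplx) package (the parameter file path, its exact keys, and the model path are illustrative; see the data interface above for the real layout):
```
import numpy as np
import torch
import smplx

param = np.load('genebody/amanda/param/0000.npy', allow_pickle=True).item()  # placeholder path
model = smplx.create('path/to/smplx_models', model_type='smplx', gender='neutral', use_pca=False)
# Convert every stored parameter to a torch tensor before the forward pass.
smplx_input = {k: torch.as_tensor(v, dtype=torch.float32) for k, v in param['smplx'].items()}
output = model(**smplx_input)
# Apply the jointly optimized global scale to recover the fitted mesh.
verts = output.vertices[0] * param['smplx_scale']
```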

================================================
FILE: docs/Dataset.md
================================================
# GeneBody Dataset

<!-- ![Teaser image](./genebody.gif#center) -->
<p align="center"><img src="./genebody.gif" width="90%"></p>
<p align="center">Please check the <a href="https://generalizable-neural-performer.github.io/genebody.html">dataset webpage</a> for data preview</p>

## Overview
GeneBody is a new dataset that we collected to evaluate the generalization and robustness of human novel view synthesis. It consists of over *2.95M* frames in total, of *100* subjects performing *370* sequences captured by *48* multi-view cameras, with a variety of pose actions, body shapes, clothing, accessories, and hairdos; the geometry and appearance range from everyday life to professional occasions.


## Agreement
The GeneBody dataset is available for non-commercial research purposes only.

You agree not to reproduce, duplicate, copy, sell, trade, resell, or exploit for any commercial purpose any portion of the images or any portion of the derived data.
You agree not to further copy, publish, or distribute any portion of GeneBody, except that making copies of the dataset for internal use at a single site within the same organization is allowed.

SenseTime Research and Shanghai AI Lab reserve the right to terminate your access to GeneBody at any time.

## Download Instructions
### Test10
You can download the *Test10* subset via this OneDrive [link](https://hkustconnect-my.sharepoint.com/:f:/g/personal/wchengad_connect_ust_hk/Er9_MPspGvpKlrFbVnmwnbEBloYJN7z9D0rP4Ms8LUcYKA?e=lnjS0N).

### Train40
Please download the GeneBody Dataset Release Agreement from this [link](./GeneBody_Dataset_Release_Agreement.pdf).
Read it carefully, then complete and sign it appropriately.

Please send the completed form to Wei Cheng (wchengad(at)connect.ust.hk) and cc Kwan-Yee Lin (linjunyi(at)sensetime.com) using an institutional email address, with the subject "GeneBody Agreement". We will verify your request and contact you with the passwords to unzip the image data.

### Extended370

Coming soon on OpenDataLab.

================================================
FILE: environment.yml
================================================
name: gnr
channels:
  - pytorch
  - conda-forge
  - https://repo.anaconda.com/pkgs/main
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - blas=1.0=mkl
  - bzip2=1.0.8=h516909a_3
  - ca-certificates=2021.10.8=ha878542_0
  - certifi=2021.5.30=py36h5fab9bb_0
  - cffi=1.14.5=py36h261ae71_0
  - cudatoolkit=9.0=h13b8566_0
  - ffmpeg=4.3.1=h3215721_1
  - freetype=2.10.4=h5ab3b9f_0
  - gmp=6.2.1=h58526e2_0
  - gnutls=3.6.13=h85f3911_1
  - intel-openmp=2020.2=254
  - jpeg=9b=h024ee3a_2
  - lame=3.100=h14c3975_1001
  - lcms2=2.11=h396b838_0
  - ld_impl_linux-64=2.33.1=h53a641e_7
  - libffi=3.3=he6710b0_2
  - libgcc-ng=9.1.0=hdf63c60_0
  - libiconv=1.16=h516909a_0
  - libpng=1.6.37=hbc83047_0
  - libstdcxx-ng=9.1.0=hdf63c60_0
  - libtiff=4.1.0=h2733197_1
  - lz4-c=1.9.3=h2531618_0
  - mkl=2020.2=256
  - mkl-service=2.3.0=py36he8ac12f_0
  - mkl_fft=1.3.0=py36h54f3939_0
  - mkl_random=1.1.1=py36h0573a6f_0
  - ncurses=6.2=he6710b0_1
  - nettle=3.6=he412f7d_0
  - ninja=1.10.2=py36hff7bd54_0
  - numpy-base=1.19.2=py36hfa32c7d_0
  - olefile=0.46=py36_0
  - openh264=2.1.1=h8b12597_0
  - openssl=1.1.1m=h7f8727e_0
  - pillow=8.1.2=py36he98fc37_0
  - pip=21.0.1=py36h06a4308_0
  - pycparser=2.20=py_2
  - python=3.6.13=hdb3f193_0
  - python_abi=3.6=2_cp36m
  - pytorch=1.1.0=py3.6_cuda9.0.176_cudnn7.5.1_0
  - readline=8.1=h27cfd23_0
  - setuptools=52.0.0=py36h06a4308_0
  - sqlite=3.35.3=hdfb4753_0
  - tk=8.6.10=hbc83047_0
  - torchvision=0.3.0=py36_cu9.0.176_1
  - wheel=0.36.2=pyhd3eb1b0_0
  - x264=1!152.20180806=h14c3975_0
  - xz=5.2.5=h7b6447c_0
  - zlib=1.2.11=h7b6447c_3
  - zstd=1.4.9=haebb681_0
  - pip:
    - absl-py==0.12.0
    - cachetools==4.2.1
    - chardet==4.0.0
    - configargparse==1.4
    - cycler==0.10.0
    - dataclasses==0.8
    - decorator==4.4.1
    - future==0.18.2
    - google-auth==1.28.0
    - google-auth-oauthlib==0.4.4
    - grpcio==1.36.1
    - idna==2.10
    - imageio==2.8.0
    - imageio-ffmpeg==0.4.3
    - importlib-metadata==3.10.0
    - kiwisolver==1.1.0
    - lpips==0.1.4
    - markdown==3.3.4
    - matplotlib==3.1.3
    - networkx==2.4
    - numpy==1.18.1
    - oauthlib==3.1.0
    - opencv-python==4.2.0.32
    - opencv-python-headless==4.2.0.34
    - pathlib==1.0.1
    - protobuf==3.15.6
    - pyasn1==0.4.8
    - pyasn1-modules==0.2.8
    - pyopengl==3.1.5
    - pyparsing==2.4.6
    - python-dateutil==2.8.1
    - pywavelets==1.1.1
    - pyyaml==6.0
    - requests==2.25.1
    - requests-oauthlib==1.3.0
    - rsa==4.7.2
    - rtree==0.9.7
    - scikit-image==0.16.2
    - scipy==1.4.1
    - shapely==1.7.0
    - six==1.14.0
    - smplx==0.1.28
    - tensorboard==2.4.1
    - tensorboard-plugin-wit==1.8.0
    - tqdm==4.43.0
    - trimesh==3.5.23
    - typing-extensions==3.7.4.3
    - urllib3==1.26.4
    - video2calibration==0.0.1
    - werkzeug==1.0.1
    - xxhash==1.4.3
    - zipp==3.4.1
    - gdown


================================================
FILE: genebody/download_tool.py
================================================
from ast import arg
import json, os, sys, pip, copy
from re import sub
def import_or_install(package):
    try:
        __import__(package)
    except ImportError:
        pip.main(['install', package])
import_or_install("urllib")
import_or_install("requests")
import_or_install("quickxorhash")
import_or_install("asyncio")
import_or_install("pyppeteer")
import urllib, requests, quickxorhash, asyncio
import urllib.request
from urllib import parse
from pyppeteer import launch
from tqdm import tqdm
from requests.models import codes
from requests.adapters import HTTPAdapter, Retry
import base64
import numpy as np
import argparse
import chardet

# simulate browser
header = {
    'sec-ch-ua-mobile': '?0',
    'upgrade-insecure-requests': '1',
    'dnt': '1',
    'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/90.0.4430.93 Safari/537.36 Edg/90.0.818.51',
    'accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9',
    'service-worker-navigation-preload': 'true',
    'sec-fetch-site': 'same-origin',
    'sec-fetch-mode': 'navigate',
    'sec-fetch-dest': 'iframe',
    'accept-language': 'zh-CN,zh;q=0.9,en;q=0.8,en-GB;q=0.7,en-US;q=0.6',
}

genebody_urls = {
    "test10": "https://hkustconnect-my.sharepoint.com/:f:/g/personal/wchengad_connect_ust_hk/Er9_MPspGvpKlrFbVnmwnbEBloYJN7z9D0rP4Ms8LUcYKA",
    "train40": "https://hkustconnect-my.sharepoint.com/:f:/g/personal/wchengad_connect_ust_hk/Et-eRSKLr09OjUPUTiJW0zAB1yWo1kx_qr7t1FsG2cuv2g",
    "smpl_depth": "https://hkustconnect-my.sharepoint.com/:f:/g/personal/wchengad_connect_ust_hk/Eju1mH_WFLtMknBzuii8mtIBVAzxUQMCWIaLl67AGySoOA?e=MsVwJG",
    "pretrained_models": "https://hkustconnect-my.sharepoint.com/:f:/g/personal/wchengad_connect_ust_hk/EpsK3TaBfGBBgtDlcWV6nxABMDT9C1qe2fE6EUxoTkQYkQ"
}

def parse_args():
    """
    Args:
        genebody_root: path to store download data.
        force: wheather force to download data, 
            we set default value to False, assume that user just want to update data.
            if you have no data in local path, pls set this argument to True.
        subset: data subset to download.
            pls choose type in ['test10', 'train40', 'smpl_depth', 'pretrained_models'].
    """
    parse = argparse.ArgumentParser(description='Download Genebody data')
    parse.add_argument('--genebody_root', type=str, required=True, help='path to store download data')
    parse.add_argument('--force', action='store_true', help='force download even if local data exists')
    parse.add_argument('--subset', type=str, nargs='+', required=True, help='data subsets to download')

    args = parse.parse_args()
    
    return args

def newSession():
    s = requests.session()
    retries = Retry(total=5, backoff_factor=0.1)
    s.mount('http://', HTTPAdapter(max_retries=retries))
    return s

def save_hash(path, code):
    with open(path, 'w', encoding='utf-8') as f:
        f.write(code)

def read_hash(path):
    with open(path, 'r', encoding='utf-8') as f:
        str = f.read()
    return str

def checkHashes(localfile, cloud_hash, localroot, force):
    """
    Synchronize local file with cloud file by hash checking
    (For GeneBody, the file is on Onedrive for Bussiness, only "quickXorHash" is avalible)
    ref: https://docs.microsoft.com/en-us/onedrive/developer/rest-api/resources/hashes?view=odsp-graph-online
    Args:
        localfile: local file path
        cloud_hash: cloud hash of cloud file 
    Output:
        matched: [True, False], whether local file matches with cloud file
    """
    if not os.path.exists(os.path.dirname(localfile)):
        os.makedirs(os.path.dirname(localfile), exist_ok=True)
    if localfile.split('/')[1] == 'pretrained_models':
        local_hash_bkup = os.path.join(localroot, '.hash', os.path.join(localfile.split('/')[2], localfile.split('/')[3].split('.')[0]+'.txt'))
    else:
        local_hash_bkup = os.path.join(localroot, '.hash', os.path.join(localfile.split('/')[-1].split('.')[0]+'.txt'))
    if not os.path.exists(os.path.dirname(local_hash_bkup)):
        os.makedirs(os.path.dirname(local_hash_bkup), exist_ok=True)    

    if force:
        save_hash(local_hash_bkup, cloud_hash["quickXorHash"])
        tqdm.write(f"force to download data")
        return False

    # if local file is not deleted
    if os.path.exists(localfile):
        with open(localfile, 'rb') as lf:
            content = lf.read()
            hash = quickxorhash.quickxorhash()
            hash.update(content)
            hashoutput = base64.b64encode(hash.digest()).decode('ascii')
            # # write local hash backup
            save_hash(local_hash_bkup, cloud_hash["quickXorHash"])
            if hashoutput == cloud_hash["quickXorHash"]:
                tqdm.write(f"[{os.path.relpath(localfile, localroot)}] local file is up-to-date, skip downloading")
                return True
            else:
                tqdm.write(f"[{os.path.relpath(localfile, localroot)}] local file is out-of-date, updating")
                return False
    else:
        if os.path.isfile(local_hash_bkup):
            # read hashcode and compare
            hashoutput = read_hash(local_hash_bkup)
            if hashoutput == cloud_hash["quickXorHash"]:
                tqdm.write(f"[{os.path.relpath(localfile, localroot)}] local file is up-to-date, skip downloading")
                return True
            else:
                save_hash(local_hash_bkup, cloud_hash["quickXorHash"])
                tqdm.write(f"[{os.path.relpath(localfile, localroot)}] local file is out-of-date, updating")
                return False

        else:
            # write and record hashcode in txt file
            save_hash(local_hash_bkup, cloud_hash["quickXorHash"])
            tqdm.write(f"[{os.path.basename(localfile)}] no local file or local file is missing, downloading")
            # sys.stdout.flush()
            return False
        

def getFiles(originalUrl, download_path, force, download_root=None, req=None, layers=0, _id=0):
    """
    Get file from folder share link (with "-my" share point url)
    ref: https://docs.microsoft.com/en-us/graph/use-the-api

    Args:
        originalUrl: share link
        download_path: path to download
        req: request None
    """
    isSharepoint = False
    if "-my" not in originalUrl:
        isSharepoint = True
    if req is None:
        req = newSession()
    reqf = req.get(originalUrl, headers=header)
    if ',"FirstRow"' not in reqf.text:
        print("\t"*layers, "No file in this folder")
        return 0
    if download_root is None:
        download_root = download_path

    filesData = []
    redirectURL = reqf.url

    query = dict(urllib.parse.parse_qsl(
        urllib.parse.urlsplit(redirectURL).query))
    redirectSplitURL = redirectURL.split("/")

    relativeFolder = ""
    rootFolder = query["id"]
    for i in rootFolder.split("/"):
        if isSharepoint:
            if i != "Shared Documents":
                relativeFolder += i+"/"
            else:
                relativeFolder += i
                break
        else:
            if i != "Documents":
                relativeFolder += i+"/"
            else:
                relativeFolder += i
                break
    relativeUrl = parse.quote(relativeFolder).replace(
        "/", "%2F").replace("_", "%5F").replace("-", "%2D")
    rootFolderUrl = parse.quote(rootFolder).replace(
        "/", "%2F").replace("_", "%5F").replace("-", "%2D")

    graphqlVar = '{"query":"query (\n        $listServerRelativeUrl: String!,$renderListDataAsStreamParameters: RenderListDataAsStreamParameters!,$renderListDataAsStreamQueryString: String!\n        )\n      {\n      \n      legacy {\n      \n      renderListDataAsStream(\n      listServerRelativeUrl: $listServerRelativeUrl,\n      parameters: $renderListDataAsStreamParameters,\n      queryString: $renderListDataAsStreamQueryString\n      )\n    }\n      \n      \n  perf {\n    executionTime\n    overheadTime\n    parsingTime\n    queryCount\n    validationTime\n    resolvers {\n      name\n      queryCount\n      resolveTime\n      waitTime\n    }\n  }\n    }","variables":{"listServerRelativeUrl":"%s","renderListDataAsStreamParameters":{"renderOptions":5707527,"allowMultipleValueFilterForTaxonomyFields":true,"addRequiredFields":true,"folderServerRelativeUrl":"%s"},"renderListDataAsStreamQueryString":"@a1=\'%s\'&RootFolder=%s&TryNewExperienceSingle=TRUE"}}' % (relativeFolder, rootFolder, relativeUrl, rootFolderUrl)

    s2 = urllib.parse.urlparse(redirectURL)
    tempHeader = copy.deepcopy(header)
    tempHeader["referer"] = redirectURL
    tempHeader["cookie"] = reqf.headers["set-cookie"]
    tempHeader["authority"] = s2.netloc
    tempHeader["content-type"] = "application/json;odata=verbose"

    graphqlReq = req.post(
        "/".join(redirectSplitURL[:-3])+"/_api/v2.1/graphql", data=graphqlVar.encode('utf-8'), headers=tempHeader)
    graphqlReq = json.loads(graphqlReq.text)
    if "NextHref" in graphqlReq["data"]["legacy"]["renderListDataAsStream"]["ListData"]:
        nextHref = graphqlReq[
            "data"]["legacy"]["renderListDataAsStream"]["ListData"]["NextHref"]+"&@a1=%s&TryNewExperienceSingle=TRUE" % (
            "%27"+relativeUrl+"%27")
        filesData.extend(graphqlReq[
            "data"]["legacy"]["renderListDataAsStream"]["ListData"]["Row"])

        listViewXml = graphqlReq[
            "data"]["legacy"]["renderListDataAsStream"]["ViewMetadata"]["ListViewXml"]
        renderListDataAsStreamVar = '{"parameters":{"__metadata":{"type":"SP.RenderListDataParameters"},"RenderOptions":1216519,"ViewXml":"%s","AllowMultipleValueFilterForTaxonomyFields":true,"AddRequiredFields":true}}' % (
            listViewXml).replace('"', '\\"')

        graphqlReq = req.post(
            "/".join(redirectSplitURL[:-3])+"/_api/web/GetListUsingPath(DecodedUrl=@a1)/RenderListDataAsStream"+nextHref, data=renderListDataAsStreamVar.encode('utf-8'), headers=tempHeader)
        graphqlReq = json.loads(graphqlReq.text)

        while "NextHref" in graphqlReq["ListData"]:
            nextHref = graphqlReq["ListData"]["NextHref"]+"&@a1=%s&TryNewExperienceSingle=TRUE" % (
                "%27"+relativeUrl+"%27")
            filesData.extend(graphqlReq["ListData"]["Row"])
            graphqlReq = req.post(
                "/".join(redirectSplitURL[:-3])+"/_api/web/GetListUsingPath(DecodedUrl=@a1)/RenderListDataAsStream"+nextHref, data=renderListDataAsStreamVar.encode('utf-8'), headers=tempHeader)
            graphqlReq = json.loads(graphqlReq.text)
        filesData.extend(graphqlReq["ListData"]["Row"])
    else:
        filesData.extend(graphqlReq[
            "data"]["legacy"]["renderListDataAsStream"]["ListData"]["Row"])

    filesData = sorted(filesData, key=lambda x: x['FileLeafRef'])
    for i in filesData:
        # if is a folder, download recursively
        if i['FSObjType'] == "1":
            _query = query.copy()
            _query['id'] = os.path.join(_query['id'],  i['FileLeafRef']).replace("\\", "/")
            if not isSharepoint:
                originalPath = "/".join(redirectSplitURL[:-1]) + \
                    "/onedrive.aspx?" + urllib.parse.urlencode(_query)
            else:
                originalPath = "/".join(redirectSplitURL[:-1]) + \
                    "/AllItems.aspx?" + urllib.parse.urlencode(_query)
            getFiles(originalPath, os.path.join(download_path, i['FileLeafRef']), force, download_root, req=req, layers=layers+1)
        # if is a file, download directly
        else:
            reqf = req.get(i[".spItemUrl"], headers=header)
            filemeta = json.loads(reqf.text)

            url, name, hash = filemeta["@content.downloadUrl"], filemeta["name"], filemeta["file"]["hashes"]
            r = requests.get(url, stream = True) 
            total_length = int(r.headers.get('content-length', 0))
            local_file = os.path.join(download_path, name)
            
            # Check hash code of local file and cloud file
            if not checkHashes(local_file, hash, download_root, force):            
                with open(os.path.join(download_path, name), 'wb') as f, \
                    tqdm(desc=os.path.relpath(local_file, download_root),total=total_length,
                         unit='iB',unit_scale=True,unit_divisor=1024) as bar:
                    for chunk in r.iter_content(chunk_size = 1024): 
                        if chunk: 
                            size = f.write(chunk)
                            bar.update(size)


pheader = ""
url = ""

async def fetch_with_pwd(iurl, password):
    """
    Fetch data with data password

    Args:
        iurl: input share folder url
        password: password of the share folder
    """
    global pheader, url
    browser = await launch(options={'args': ['--no-sandbox']})
    page = await browser.newPage()
    await page.goto(iurl, {'waitUntil': 'networkidle0'})
    await page.focus("input[id='txtPassword']")
    await page.keyboard.type(password)
    verityElem = await page.querySelector("input[id='btnSubmitPassword']")
    print("Password input complete, jumping")

    await asyncio.gather(
        page.waitForNavigation(),
        verityElem.click(),
    )
    url = await page.evaluate('window.location.href', force_expr=True)
    await page.screenshot({'path': 'example.png'})
    print("Fetching cookies")
    _cookie = await page.cookies()
    pheader = ""
    for __cookie in _cookie:
        coo = "{}={};".format(__cookie.get("name"), __cookie.get("value"))
        pheader += coo
    await browser.close()

def havePwdGetFiles(iurl, password, download_path, force):
    global header
    asyncio.get_event_loop().run_until_complete(fetch_with_pwd(iurl, password))
    header['cookie'] = pheader
    getFiles(url, download_path, force)

def extractFiles(path, subset):
    """
    Extract download data

    Args:
        path: path contains data to extract
    """
    all_files = os.listdir(path)
    for file in all_files:
        if not file.endswith('.gz'):
            continue
        file_path = os.path.join(path, file)
        extract_cmd = f'tar -xvf {file_path} -C {path}'
        os.system(extract_cmd)
        rm_cmd = f'rm {file_path}'
        os.system(rm_cmd)
    if subset == 'smpl_depth':
        src_path = os.path.join(path, 'GeneBody')
        mv_cmd = f'rsync -av {src_path}/* {path}'
        os.system(mv_cmd)
        rm_cmd = f'rm -rf {src_path}'
        os.system(rm_cmd)

def moveFiles(root):
    """
    Move pretrained models to benchmark directory

    Args:
        root: path contains download pretrained models
    """
    os.makedirs('benchmarks', exist_ok=True)
    models = os.listdir(os.path.join(root, 'pretrained_models'))
    for model in models:
        src_path = os.path.join(root, 'pretrained_models', model)
        if model != 'gnr':
            cmd = f'mv {src_path} benchmarks'
            os.system(cmd)
        else:
            os.makedirs('./logs/genebody', exist_ok=True)
            os.system(f'mv {src_path}/* ./logs/genebody/')
    pretrain_path = os.path.join(root, 'pretrained_models')
    rm_cmd = f'rm -rf {pretrain_path}'
    os.system(rm_cmd)

if __name__ == "__main__":
    args = parse_args()
    genebody_root = args.genebody_root
    force = args.force
    subsets = [set_ for set_ in ['test10', 'train40', 'smpl_depth', 'pretrained_models'] if set_ in args.subset]

    for subset in subsets:
        if subset == 'train40':
            pwd = input("Please input Train40 password, or contact with the author for data access: \n")
            havePwdGetFiles(genebody_urls[subset], pwd, genebody_root, force)
        elif subset == 'pretrained_models':
            getFiles(genebody_urls[subset], os.path.join(genebody_root, 'pretrained_models'), force)
        else:
            getFiles(genebody_urls[subset], genebody_root, force)
        if subset == 'pretrained_models':
            moveFiles(genebody_root)
        else:
            extractFiles(genebody_root, subset)
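
    # Example invocations (paths are illustrative; subset names come from the list above):
    #     python genebody/download_tool.py --genebody_root data/genebody --subset test10
    #     python genebody/download_tool.py --genebody_root data/genebody --subset smpl_depth --force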


================================================
FILE: genebody/gender.py
================================================

genebody_gender = {
    "abror": "male", 
    "ahha": "female", 
    "alejandro": "male", 
    "amanda": "female", 
    "amaris": "female", 
    "anastasia": "female", 
    "aosilan": "male", 
    "arslan": "male", 
    "barlas": "male", 
    "barry": "male", 
    "camilo": "male", 
    "dannier": "male", 
    "dilshod": "male", 
    "fenghaohan": "male", 
    "fuzhizhi": "female", 
    "fuzhizhi2": "female", 
    "gaoxing": "female", 
    "huajiangtao3": "male", 
    "huajiangtao5": "male", 
    "ivan": "male", 
    "jinyutong": "female", 
    "jinyutong2": "female", 
    "joseph_matanda": "female", 
    "kamal_ejaz": "male", 
    "kemal": "male", 
    "lihongyun": "female", 
    "mahaoran": "male", 
    "maria": "female", 
    "natacha": "female", 
    "quyuanning": "female", 
    "rivera": "male", 
    "shchyerbina_oleksandrsongyujie": "male", 
    "soufianou_boubacar_moumouni": "male", 
    "sunyuxing": "male", 
    "Tichinah_jervier": "female", 
    "wangxiang": "male", 
    "wuwenyan": "male", 
    "xujiarui": "female", 
    "yaoqibin": "male", 
    "zhanghao": "male", 
    "zhanghongwei": "female", 
    "zhangzixiao": "female", 
    "zhengxin": "female", 
    "zhonglantai": "female", 
    "zhuna": "female", 
    "zhuna2": "female", 
    "zhuxuezhi": "male", 
    "songyujie": "male", 
    "rabbi": "male",
    "zhangziyu": "female"
}

================================================
FILE: genebody/genebody.py
================================================
import os, sys
import numpy as np 
import cv2, imageio
from .mesh import load_ply, load_obj_mesh, write_obj_mesh
import torch
from .gender import genebody_gender

def image_cropping(mask, padding=0.1):
    """
    To better evaluate different metric on rendered images, we crop out the human performer and resize the cropped 
    image to the same resolution. This function provides returns the bound box of human performer given the mask.
    mask: np.ndarry of mask 
    padding: padding of the bounding box
    """
    a = np.where(mask != 0)
    h, w = list(mask.shape[:2])
    if len(a[0]) > 0:   # valid mask
        top, left, bottom, right = np.min(a[0]), np.min(a[1]), np.max(a[0]), np.max(a[1])
    else:               # mask failure
        return 0,0,mask.shape[0],mask.shape[1]
    bbox_h, bbox_w = bottom - top, right - left

    # pad bbox
    bottom = min(int(bbox_h*padding+bottom), h)
    top = max(int(top-bbox_h*padding), 0)
    right = min(int(bbox_w*padding+right), w)
    left = max(int(left-bbox_w*padding), 0)
    bbox_h, bbox_w = bottom - top, right - left
    bbox_h = min(bbox_h, h, w)
    bbox_w = min(bbox_w, h, w)

    if bbox_h >= bbox_w:
        w_c = (left+right) / 2
        size = bbox_h
        if w_c - size / 2 < 0:
            left = 0
            right = size
        elif w_c + size / 2 >= w:
            left = w - size
            right = w
        else:
            left = int(w_c - size / 2)
            right = left + size
        h_c = (top+bottom) / 2
        top = int(h_c - size / 2)
        bottom = top + size
    else:   # bbox_w >= bbox_h
        h_c = (top+bottom) / 2
        size = bbox_w
        if h_c - size / 2 < 0:
            top = 0
            bottom = size
        elif h_c + size / 2 >= h:
            top = h - size
            bottom = h
        else:
            top = int(h_c - size / 2)
            bottom = top + size
        w_c = (left+right) / 2
        left = int(w_c - size / 2)
        right = left + size

    return top, left, bottom, right
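
# A minimal usage sketch (file paths are hypothetical), mirroring what
# GeneBodyReader.get_data below does with the returned bounding box:
#
#     img = imageio.imread('image/01/0000.jpg')
#     msk = imageio.imread('mask/01/mask0000.png')
#     top, left, bottom, right = image_cropping(msk)
#     crop = cv2.resize(img[top:bottom, left:right], (512, 512), cv2.INTER_CUBIC)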

class GeneBodyReader():
    def __init__(self, rootdir, loadsize=512):
        self.rootdir = rootdir
        self.split = np.load(os.path.join(rootdir, 'genebody_split.npy'), allow_pickle=True).item()
        self.loadsize = loadsize
        # the default setting of GNR is to use these four source views of GeneBody
        self.sourceviews = ['01', '13', '25', '37']
        self.gender = genebody_gender
        self.rawsize = (2448, 2048)

    def get_views(self, subject):
        """
        Returns valid camera views of each sequence. 
        Note that there are several view missing subjects in GeneBody. More specifically, 
        "Tichinah_jervier" misses [32], 
        "wuwenyan" misses [34, 36], 
        "joseph_matanda" misses [39, 40, 42, 43, 44, 45, 46, 47]

        subject: name of subject
        all_views: all valid views of this subject
        """

        all_views = sorted(os.listdir(os.path.join(self.rootdir, subject, 'image')))
        ## alt
        # all_views = sorted(np.load(os.path.join(self.rootdir, subject, 'annots.npy'), allow_pickle=True).item()['cams'].keys())

        return all_views

    def get_frames(self, subject):
        frame_list = []
        frame_list = os.listdir(os.path.join(self.rootdir, subject, 'image', '00'))
        frame_list = sorted([frame[:-4] for frame in frame_list])
        return frame_list

    def get_cameras(self, subject):
        return np.load(os.path.join(self.rootdir, subject, 'annots.npy'), allow_pickle=True).item()['cams']

    def get_smpl(self, subject, frame):
        """
        Returns the smpl vertices and faces
        frame: all frames of the subject <- self.get_frames(subject)
        """
        smpl_path = os.path.join(self.rootdir, subject, 'smpl', frame+'.obj')
        vert, face = load_obj_mesh(smpl_path)
        return vert, face

    def get_smpl_param(self, subject, frame):
        """
        Returns the smpl parameters and smpl scale
        frame_list: all frames of the subject <- self.get_frames(subject)
        """
        param_path = os.path.join(self.rootdir, subject, 'param', frame+'.npy')
        # global_orient and pose are Rodrigues rotation vector
        param = np.load(param_path, allow_pickle=True).item()

        # the smpl_param is a dictionary of smplx parameters which can be directly passed to a SMPLX forward pass
        # via SMPLXLayer(**smpl_param) if each value of it is converted to torch.Tensor
        smpl_param = param["smplx"]
        for key in smpl_param.keys():
            if isinstance(smpl_param[key], torch.Tensor):
                smpl_param[key] = smpl_param[key].numpy()
        # GeneBody fits human performers over a wide age range, and SMPL-X alone cannot fit kids and giants well,
        # so we apply an extra smplx_scale outside the SMPLX model via direct scaling.
        # You can recover the smplx mesh in the 'smpl' directory via SMPLX(**smpl_param) * smpl_scale
        smpl_scale = param["smplx_scale"]

        return smpl_param, smpl_scale

    def get_data(self, subject, frame, camera_params, views):
        """
        Fetch one frame of multiview data from database with cropping
        subject: name of subject
        frame: frame of the subject <- self.get_frames(subject)[frameid]
        all_views: all views of subject <- self.get_views(subject)
        camera_params: camera parameters <- self.get_annot(subject)
        frame_id: eg. 1
        views: list of views to fetch, eg. load sourceviews through self.sourceviews,
               or all view through self.get_views(subject)
        """
        subject_dir = os.path.join(self.rootdir, subject)
        
        Ks, c2ws, Ds, images, masks = [], [], [], [], []
        for view in views:
            img = imageio.imread(os.path.join(subject_dir, 'image', view, frame+'.jpg'))
            msk = imageio.imread(os.path.join(subject_dir, 'mask', view, f'mask{frame}.png'))
            # crop the human out from raw image            
            top, left, bottom, right = image_cropping(msk)
            img = img * (msk > 128).astype(np.uint8)[...,None]
            # resize to uniform resolution
            img = cv2.resize(img[top:bottom, left:right].copy(), (self.loadsize, self.loadsize), cv2.INTER_CUBIC)
            images.append(img)
            msk = cv2.resize(msk[top:bottom, left:right].copy(), (self.loadsize, self.loadsize), cv2.INTER_NEAREST)
            masks.append(msk)

            # adjust the camera intrinsic parameters because of the cropping and resizing
            # Note that there is no need to adjust the extrinsics or distortion coefficients
            K, c2w, D = camera_params[view]['K'].copy(), camera_params[view]['c2w'].copy(), camera_params[view]['D'].copy()
            K[0,2] -= left
            K[1,2] -= top
            K[0,:] *= self.loadsize / float(right - left)
            K[1,:] *= self.loadsize / float(bottom - top)
            Ks.append(K)
            c2ws.append(c2w)
            Ds.append(D)
        
        return images, masks, Ks, c2ws, Ds

    def get_near_far(self, verts, c2w, pad=0.5):
        """
        Get near far plane of perspective project from SMPL estimation
        verts: SMPLx vertices
        c2w: Camera to world roation matrix
        pad: near far padding from SMPLx near far, set smaller if you want tighter bound, set larger if the accessory is huge.
        """
        w2c = np.linalg.inv(c2w)
        # Transform SMPLx to camera coordinate
        vp = verts.dot(w2c[:3,:3].T) + w2c[:3,3:].T
        vmin, vmax = vp.min(0), vp.max(0)
        # near far are minmax in z axis
        near, far = vmin[2], vmax[2]
        near, far = near-(far-near)*pad, far+(far-near)*pad
        return near, far

    def smpl_from_param(self, model_path, subject, smpl_param, smpl_scale):
        import smplx
        smpl = smplx.SMPLX(
            model_path=model_path,
            gender=self.gender[subject],
            use_pca=False,
        )
        smpl_param = smpl_param.copy()
        for key in smpl_param.keys():
            if isinstance(smpl_param[key], np.ndarray):
                smpl_param[key] = torch.from_numpy(smpl_param[key])
        
        output = smpl(**smpl_param)
        verts = output['vertices'].numpy().reshape(-1,3)

        # To align with keypoints3d saved in param, use the base keypoints only;
        # if you want to use the full keypoints3d with extra joints and landmarks,
        # please refer to the definition of joints in
        # vertex_joint_selector.py and landmarks in vertices2landmarks in lbs.py
        keypoints3d = output['joints'].numpy().reshape(-1,3)[:55]
        
        return verts*smpl_scale, smpl.faces, keypoints3d*smpl_scale


if __name__ == "__main__":
    ## Here is an example
    # python genebody/genebody.py path_to_genebody fuzhizhi
    root = sys.argv[1]
    subject = sys.argv[2]

    genebody = GeneBodyReader(root)
    print(subject, 'is in the', 'training set' if subject in genebody.split['train'] else 'test set')

    views = genebody.get_views(subject)
    frames = genebody.get_frames(subject)
    camera_params = genebody.get_cameras(subject)
    frame = frames[9]

    imgs, msks, Ks, c2ws, Ds = genebody.get_data(subject, frame, camera_params, genebody.sourceviews)
    print(f'loaded {len(imgs)} frames of images and masks in size of ', list(imgs[0].shape))

    verts, faces = genebody.get_smpl(subject, frame)
    smpl_param, smpl_scale = genebody.get_smpl_param(subject, frame)
    print('mesh with size ', verts.shape, ' body scale ', smpl_scale)

    near, far = genebody.get_near_far(verts, c2ws[0])
    print('the near far is ', near, far)

    # to convert the smplx parameters to an smplx mesh, please try the following command
    # python genebody/genebody.py path_to_genebody fuzhizhi path_to_smplx
    if len(sys.argv) == 4:
        import trimesh
        smplx_path = sys.argv[3]
        verts_from_param, faces_from_param, kpts_from_param = genebody.smpl_from_param(smplx_path, subject, smpl_param, smpl_scale)
        print('average error of parameter generated smplx is ', np.abs(verts_from_param-verts).mean())
        print('average error of parameter generated keypoints is ', np.abs(smpl_param['keypoints3d'].reshape(-1,3)-kpts_from_param).mean())
        print(verts.min(0), verts.max(0), kpts_from_param.min(0), kpts_from_param.max(0))
        smpl_mesh = trimesh.Trimesh(verts_from_param, faces_from_param)
        smpl_mesh.export(f'{subject}.obj')



================================================
FILE: genebody/mesh.py
================================================
import numpy as np
import struct
import sys, os, re
import cv2
if sys.version_info[0] == 3:
    from functools import reduce
# each entry: (regex, structPackFormat, printfFormat, byteSize, numpyType, isFloat[, baseEntryMarker])

decode_map = {
    'bool': ('([01])', '?', '%d', 1, np.dtype('bool'), False),
    'uchar':('([0-9]{1,3})', 'B', '%d', 1, np.uint8, False, 1),
    'uint8':('([0-9]{1,3})', 'B', '%d', 1, np.uint8, False),
    'byte': ('([0-9]{1,3})', 'B', '%d', 1, np.uint8, False),
    'unsigned char':('([0-9]{1,3})', 'B', '%d', 1, np.uint8, False),
    'char': ('(-?[0-9]{1,3})', 'b', '%d', 1, np.int8, False, 1),
    'int8': ('(-?[0-9]{1,3})', 'b', '%d', 1, np.int8, False),

    'ushort': ('([0-9]{1,5})', 'H', '%d', 2, np.uint16, False, 1),
    'uint16': ('([0-9]{1,5})', 'H', '%d', 2, np.uint16, False),
    'unsigned short': ('([0-9]{1,5})', 'H', '%d', 2, np.uint16, False),
    'short': ('(-?[0-9]{1,5})', 'h', '%d', 2, np.int16, False, 1),
    'int16': ('(-?[0-9]{1,5})', 'h', '%d', 2, np.int16, False),
    'half': ('(-?[0-9]*\.?[0-9]*[eE]?[-\+]?[0-9]*)', 'e', '%f', 2, np.float16, True, 1),
    'float16':('(-?[0-9]*\.?[0-9]*[eE]?[-\+]?[0-9]*)', 'e', '%f', 2, np.float16, True),

    'uint':  ('([0-9]{1,10})', 'I', '%u', 4, np.uint32, False, 1),
    'uint32':('([0-9]{1,10})', 'I', '%u', 4, np.uint32, False),
    'ulong': ('([0-9]{1,10})', 'I', '%u', 4, np.uint32, False),
    'unsigned':('([0-9]{1,10})', 'I', '%u', 4, np.uint32, False),
    'unsigned int':('([0-9]{1,10})', 'I', '%u', 4, np.uint32, False),
    'unsigned long':('([0-9]{1,10})', 'I', '%u', 4, np.uint32, False),
    'int':  ('(-?[0-9]{1,10})', 'i', '%d', 4, np.int32, False, 1),
    'long': ('(-?[0-9]{1,10})', 'i', '%d', 4, np.int32, False),
    'int32':('(-?[0-9]{1,10})', 'i', '%d', 4, np.int32, False),
    'float':('(-?[0-9]*\.?[0-9]*[eE]?[-\+]?[0-9]*)', 'f', '%f', 4, np.float32, True, 1),
    'single':('(-?[0-9]*\.?[0-9]*[eE]?[-\+]?[0-9]*)', 'f', '%f', 4, np.float32, True),
    'float32':('(-?[0-9]*\.?[0-9]*[eE]?[-\+]?[0-9]*)', 'f', '%f', 4, np.float32, True),

    'uint64':('([0-9]{1,20})', 'Q', '%lu', 8, np.uint64, False, 1),
    'ullong':('([0-9]{1,20})', 'Q', '%lu', 8, np.uint64, False),
    'unsigned long long':('([0-9]{1,20})', 'Q', '%lu', 8, np.uint64, False),
    'int64':('(-?[0-9]{1,19})', 'q', '%ld', 8, np.int64, False, 1),
    'llong':('(-?[0-9]{1,19})', 'q', '%ld', 8, np.int64, False),
    'long long':('(-?[0-9]{1,19})', 'q', '%ld', 8, np.int64, False),
    'double':('(-?[0-9]*\.?[0-9]*[eE]?[-\+]?[0-9]*)', 'd', '%f', 8, np.float64, True, 1),
    'float64':('(-?[0-9]*\.?[0-9]*[eE]?[-\+]?[0-9]*)', 'd', '%f', 8, np.float64, True),
}
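
# For instance, decode_map['float'][1:] == ('f', '%f', 4, np.float32, True, 1):
# struct pack format, printf format, byte size, numpy dtype, floating-point flag,
# and a trailing marker that save_ply uses (via its len(c) > 6 check) to pick the
# canonical entry for each numpy dtype.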

def max_precision(type1, type2):
    if decode_map[type1][5]:
        if decode_map[type2][5]:
            if decode_map[type1][3] < decode_map[type2][3]:
                return type2
            else:
                return type1
        else:
            if decode_map[type1][3] < decode_map[type2][3]:
                for t, c in decode_map.items():
                    if c[5] and c[3] >= decode_map[type2][3]:
                        return t
            else:
                return type1
    elif decode_map[type2][5]:
        if decode_map[type2][3] < decode_map[type1][3]:
            for t, c in decode_map.items():
                if c[5] and c[3] >= decode_map[type1][3]:
                    return t
        else:
            return type2
    else:
        if decode_map[type2][3] < decode_map[type1][3]:
            return type1
        elif decode_map[type1][3] < decode_map[type2][3]:
            return type2
        elif decode_map[type1][4] == decode_map[type2][4]:
            return type1
        else:
            for t, c in decode_map.items():
                if not c[5] and c[3] > decode_map[type1][3]:
                    return t
            return 'unsigned long long'
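
# Illustrative behavior: the promoted type must represent both inputs, e.g.
#     max_precision('int16', 'float')   # -> 'float'
#     max_precision('float', 'double')  # -> 'double'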

def decode(content, structure, num, form):
    if form.lower() == 'ascii':
        l = 0; d = []; lines = content.split('\n')
        for i in range(num):
            s = [j for j in lines[i].split(' ') if len(j) > 0]
            l += len(lines[i]) + 1; k = []; j = 0
            while j < len(s) and len(k) < len(structure):
                t = structure[len(k)]
                if t[:4] == 'list':
                    n = int(s[j]); t = t.split(':')[-1]
                    k += [[float(s[i]) if decode_map[t][5] else int(s[i]) \
                        for i in range(j+1,j+n+1)]]
                    j += n + 1
                else:
                    k += [float(s[j]) if decode_map[t][5] else int(s[j])]
                    j += 1
            d += [k]
    else:
        if form.lower() == 'binary_little_endian':
            c = '<'
        elif form.lower() == 'binary_big_endian':
            c = '>'
        l = 0; d = []
        for i in range(num):
            k = []
            while len(k) < len(structure) and l < len(content):
                t = structure[len(k)]
                if t[:4] == 'list':
                    t = t.split(':')
                    n = struct.unpack(c+decode_map[t[1]][1], \
                        content[l:l+decode_map[t[1]][3]])[0]
                    l += decode_map[t[1]][3]
                    k += [struct.unpack(c+decode_map[t[2]][1]*n, \
                        content[l:l+decode_map[t[2]][3]*n])]
                    l += decode_map[t[2]][3]*n
                else:
                    k += [struct.unpack(c+decode_map[t][1], \
                        content[l:l+decode_map[t][3]])[0]]
                    l += decode_map[t][3]
            d += [k]
    try:
        t = reduce(max_precision, [t if t[:4] != 'list' \
            else t.split(':')[-1] for t in structure])
        d = np.array(d, dtype = decode_map[t][4])
    except ValueError:
        print('Warning: ragged element data, returning a Python list instead of a numpy matrix')
    return d, content[l:]

def load_ply(file_name):
    try:
        with open(file_name, 'r') as f:
            head = f.readline().strip()
            if head.lower() != 'ply':
                raise ValueError('Error: Not a valid PLY file')
            content = f.read()
            i = content.find('end_header\n')
            if i < 0:
                raise ValueError('Error: Not a valid PLY file')
            info = [[l for l in line.split(' ') if len(l) > 0] \
                for line in content[:i].split('\n')]
            content = content[i+11:]
    except UnicodeDecodeError as e:
        with open(file_name, 'rb') as f:
            head = f.readline().strip()
            if sys.version_info[0] == 3:
                head = str(head)[2:-1]
            else:
                head = str(head)
            if head.lower() != 'ply':
                raise ValueError('Error: Not a valid PLY file')
            content = f.read()
            i = content.find(b'end_header\n')
            if i < 0:
                raise ValueError('Error: Not a valid PLY file')
            if sys.version_info[0] == 3:
                cnt = str(content[:i])[2:-1].replace('\\n', '\n')
            else:
                cnt = str(content[:i])
            info = [[l for l in line.split(' ') if len(l) > 0] \
                for line in cnt.split('\n')]
            content = content[i+11:]
    form = 'ascii'
    elem_names = []
    elem = {}
    for i in info:
        if len(i) >= 2 and i[0] == 'format':
            form = i[1]
        elif len(i) >= 3 and i[0] == 'element':
            if len(elem_names) > 0:
                elem[elem_names[-1]] = (structure_name, structure)
            elem_names += [(i[1], int(i[2]))]
            structure_name = []
            structure = []
        elif len(i) >= 3 and i[0] == 'property' and len(elem_names) > 0:
            structure_name += [i[-1]]
            if i[1] == 'list' and len(i) >= 5:
                structure += [i[1] + ':' + i[2] + ':' + ' '.join(i[3:-1])]
            else:
                structure += [' '.join(i[1:-1])]
    if len(elem_names) > 0:
        elem[elem_names[-1]] = (structure_name, structure)
    elem_ = {}
    for k in elem_names:
        d, content = decode(content, elem[k][1], k[1], form)
        if 'face' in k[0] and isinstance(d, np.ndarray):
            d = d.reshape((k[1], -1))
        elem_[k[0]] = d # elem[k] = (elem[k][0], d)
    return elem_
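
# Usage sketch (hypothetical file path), matching how lib/data/GeneBodyDataset.py
# consumes the returned element dictionary:
#
#     elem = load_ply('smpl/0000.ply')
#     verts, faces = elem['vertex'][:, :3], elem['face']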

def save_ply(file_name, elems, _type = 'binary_little_endian', comments = []):
    _type = _type.lower()
    types = {}
    if isinstance(comments, str):
        comments = [comments]
    comments = [c for l in comments for c in l.split('\n')]
    with open(file_name, 'w') as f:
        f.write('ply\nformat %s 1.0\n' % _type)
        for comment in comments:
            f.write('comment %s\n' % comment)
        for key, elem in elems.items():
            f.write('element %s %d\n' % (key, len(elem)))
            if isinstance(elem, np.ndarray):
                for e, c in decode_map.items():
                    if len(c) > 6 and c[4] == elem.dtype:
                        c = (e, c[1], c[2]); break
            else:
                c = ('int', 'i', '%d')
                for e in elem:
                    if hasattr(e, '__len__'):
                        for i in range(len(e)):
                            if int(e[i]) != e[i]:
                                c = ('float','f','%f'); break
                        if c[0] == 'float': break  # stop once a float element is found
                    elif int(e) != e:
                        c = ('float','f','%f'); break
            if 'face' in key:
                tag = 'vertex_index'
                max_num = max([len(e) for e in elem])
                if max_num < 256:
                    l = ('uchar', 'B', '%d')
                elif max_num < 65536:
                    l = ('ushort', 'H', '%d')
                elif max_num < 4294967296:
                    l = ('uint', 'I', '%d')
                else:
                    l = ('uint64', 'Q', '%d')
                f.write('property list %s %s %s\n' % \
                    (l[0], c[0], tag))
                types[key] = (l, c)
            else:
                if 'vert' in key:
                    if len(elem) > 0:
                        if len(elem[0]) > 4:
                            tag = ['x', 'y', 'z', 'red', 'green', 'blue', 'alpha']
                        else:
                            tag = ['x', 'y', 'z', 'w']
                elif 'norm' in key or 'texcoord' in key:
                    tag = ['x', 'y', 'z', 'w']
                elif 'color' in key:
                    tag = ['red', 'green', 'blue', 'alpha']
                else:
                    tag = []
                if len(elem) > 0:
                    for j in range(len(elem[0])):
                        f.write('property %s %s\n' % (c[0], tag[j] \
                            if len(tag) > j else 'k%d'%(j-len(tag))))
                types[key] = c
        f.write('end_header\n')
    with open(file_name, 'ab' if 'binary' in _type else 'a') as f:
        for key, elem in elems.items():
            c = types[key]
            if len(c) == 2:
                for e in elem:
                    l = len(e)
                    if 'ascii' in _type:
                        f.write((c[0][2]+(' '+c[1][2])*l+'\n') % \
                            tuple([l]+list(e)))
                    elif 'little' in _type:
                        f.write(struct.pack('<'+c[0][1], l))
                        f.write(struct.pack('<'+c[1][1]*l, *e))
                    elif 'big' in _type:
                        f.write(struct.pack('>'+c[0][1], l))
                        f.write(struct.pack('>'+c[1][1]*l, *e))
            else:
                seg = len(elem[0]) if len(elem) > 0 else 1
                elem = [i for l in elem for i in l] \
                    if isinstance(elem, np.ndarray) else elem.reshape(-1)
                if 'ascii' in _type:
                    for i in range(len(elem)):
                        f.write((c[2] + '\n' if (i + 1) % seg == 0 \
                            else c[2] + ' ') % elem[i])
                elif 'little' in _type:
                    f.write(struct.pack('<'+c[1]*len(elem), *elem))
                elif 'big' in _type:
                    f.write(struct.pack('>'+c[1]*len(elem), *elem))


def normalize_v3(arr):
    ''' Normalize a numpy array of 3 component vectors shape=(n,3) '''
    lens = np.sqrt(arr[:, 0] ** 2 + arr[:, 1] ** 2 + arr[:, 2] ** 2)
    eps = 0.00000001
    lens[lens < eps] = eps
    arr[:, 0] /= lens
    arr[:, 1] /= lens
    arr[:, 2] /= lens
    return arr


def compute_normal(vertices, faces):
    # Create a zeroed array with the same type and shape as our vertices i.e., per vertex normal
    norm = np.zeros(vertices.shape, dtype=vertices.dtype)
    # Create an indexed view into the vertex array using the array of three indices for triangles
    tris = vertices[faces]
    # Calculate the normal for all the triangles, by taking the cross product of the vectors v1-v0, and v2-v0 in each triangle
    n = np.cross(tris[::, 1] - tris[::, 0], tris[::, 2] - tris[::, 0])
    # n is now an array of normals per triangle. The length of each normal depends on the vertex positions,
    # we need to normalize these, so that our next step weights each normal equally.
    normalize_v3(n)
    # now we have a normalized array of normals, one per triangle, i.e., per triangle normals.
    # But instead of one per triangle (i.e., flat shading), we add to each vertex in that triangle,
    # the triangles' normal. Multiple triangles would then contribute to every vertex, so we need to normalize again afterwards.
    # The cool part, we can actually add the normals through an indexed view of our (zeroed) per vertex normal array
    norm[faces[:, 0]] += n
    norm[faces[:, 1]] += n
    norm[faces[:, 2]] += n
    normalize_v3(norm)

    return norm

def load_obj_mesh(mesh_file, with_normal=False, with_texture=False, with_texture_image=False):
	vertex_data = []
	norm_data = []
	uv_data = []

	face_data = []
	face_norm_data = []
	face_uv_data = []

	if isinstance(mesh_file, str):
		f = open(mesh_file, "r")
	else:
		f = mesh_file
	for line in f:
		if isinstance(line, bytes):
			line = line.decode("utf-8")
		if line.startswith('#'):
			continue
		values = line.split()
		if not values:
			continue

		if values[0] == 'v':
			v = list(map(float, values[1:4]))
			vertex_data.append(v)
		elif values[0] == 'vn':
			vn = list(map(float, values[1:4]))
			norm_data.append(vn)
		elif values[0] == 'vt':
			vt = list(map(float, values[1:3]))
			uv_data.append(vt)

		elif values[0] == 'f':
			# quad mesh
			if len(values) > 4:
				f = list(map(lambda x: int(x.split('/')[0]), values[1:4]))
				face_data.append(f)
				f = list(map(lambda x: int(x.split('/')[0]), [values[3], values[4], values[1]]))
				face_data.append(f)
			# tri mesh
			else:
				f = list(map(lambda x: int(x.split('/')[0]), values[1:4]))
				face_data.append(f)
			
			# deal with texture
			if len(values[1].split('/')) >= 2:
				# quad mesh
				if len(values) > 4:
					f = list(map(lambda x: int(x.split('/')[1]), values[1:4]))
					face_uv_data.append(f)
					f = list(map(lambda x: int(x.split('/')[1]), [values[3], values[4], values[1]]))
					face_uv_data.append(f)
				# tri mesh
				elif len(values[1].split('/')[1]) != 0:
					f = list(map(lambda x: int(x.split('/')[1]), values[1:4]))
					face_uv_data.append(f)
			# deal with normal
			if len(values[1].split('/')) == 3:
				# quad mesh
				if len(values) > 4:
					f = list(map(lambda x: int(x.split('/')[2]), values[1:4]))
					face_norm_data.append(f)
					f = list(map(lambda x: int(x.split('/')[2]), [values[3], values[4], values[1]]))
					face_norm_data.append(f)
				# tri mesh
				elif len(values[1].split('/')[2]) != 0:
					f = list(map(lambda x: int(x.split('/')[2]), values[1:4]))
					face_norm_data.append(f)
		elif 'mtllib' in line.split():
			mtlname = line.split()[-1]
			mtlfile = os.path.join(os.path.dirname(mesh_file), mtlname)
			with open(mtlfile, 'r') as fmtl:
				mtllines = fmtl.readlines()
				for mtlline in mtllines:
					# if mtlline.startswith('map_Kd'):
					if 'map_Kd' in mtlline.split():
						texname = mtlline.split()[-1]
						texfile = os.path.join(os.path.dirname(mesh_file), texname)
						texture_image = cv2.imread(texfile)
						texture_image = cv2.cvtColor(texture_image, cv2.COLOR_BGR2RGB)
						break

	vertices = np.array(vertex_data)
	faces = np.array(face_data) - 1

	if with_texture and with_normal:
		uvs = np.array(uv_data)
		face_uvs = np.array(face_uv_data) - 1
		norms = np.array(norm_data)
		if norms.shape[0] == 0:
			norms = compute_normal(vertices, faces)
			face_normals = faces
		else:
			norms = normalize_v3(norms)
			face_normals = np.array(face_norm_data) - 1
		if with_texture_image:
			return vertices, faces, norms, face_normals, uvs, face_uvs, texture_image
		else:
			return vertices, faces, norms, face_normals, uvs, face_uvs

	if with_texture:
		uvs = np.array(uv_data)
		face_uvs = np.array(face_uv_data) - 1
		return vertices, faces, uvs, face_uvs

	if with_normal:
		# norms = np.array(norm_data)
		# norms = normalize_v3(norms)
		# face_normals = np.array(face_norm_data) - 1
		norms = np.array(norm_data)
		if norms.shape[0] == 0:
			norms = compute_normal(vertices, faces)
			face_normals = faces
		else:
			norms = normalize_v3(norms)
			face_normals = np.array(face_norm_data) - 1
		return vertices, faces, norms, face_normals

	return vertices, faces
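
# Usage sketch, using the T-pose template shipped in this repository:
#
#     verts, faces = load_obj_mesh('smpl_t_pose/smplx.obj')
#     # with normals (computed from the geometry if the file stores no 'vn' lines):
#     verts, faces, norms, face_normals = load_obj_mesh('smpl_t_pose/smplx.obj', with_normal=True)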


def write_obj_mesh(filename, verts, faces):
    with open(filename, 'w') as f:
        for vert in verts:
            f.write('v %f %f %f\n' % tuple(list(vert)))
        for face in faces+1:
            f.write('f %d %d %d\n' % tuple(list(face)))

================================================
FILE: lib/data/GeneBodyDataset.py
================================================
from re import sub
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import cv2
import sys
import os
import imageio
from ..mesh_util import save_obj_mesh
base_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(base_dir,'..'))
sys.path = sys.path[:-1]
# import trimesh
from lib.ply_util import load_ply
from genebody.mesh import load_obj_mesh
import scipy.interpolate as interpolater
import random
from genebody.gender import genebody_gender

def mask_padding(mask, border = 5):
    kernel = np.ones((border, border), np.uint8)
    msk_erode = cv2.erode(mask.copy(), kernel)
    msk_dilate = cv2.dilate(mask.copy(), kernel)
    # retain the original hard mask and create a soft padding
    mask = mask > 0
    mask[(msk_dilate - msk_erode) > 0] = 1
    return mask
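
# A small usage sketch (hypothetical mask file): threshold a raw mask as done
# elsewhere in this file, then grow a soft border around the hard silhouette.
#
#     hard = (imageio.imread('mask0000.png') > 128).astype(np.uint8)
#     soft = mask_padding(hard, border=5)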

def euler2rot(euler):
    sin, cos = np.sin, np.cos
    phi, theta, psi = euler[0], euler[1], euler[2]
    R1 = np.array([[1, 0, 0],
                   [0, cos(phi), sin(phi)],
                   [0, -sin(phi), cos(phi)]])
    R2 = np.array([[cos(theta), 0, -sin(theta)],
                   [0, 1, 0],
                   [sin(theta), 0, cos(theta)]])
    R3 = np.array([[cos(psi), sin(psi), 0],
                   [-sin(psi), cos(psi), 0],
                   [0, 0, 1]])
    R = R1 @ R2 @ R3
    return R

def rot2euler(R):
    phi = np.arctan2(R[1,2], R[2,2])
    theta = -np.arcsin(R[0,2])
    psi = np.arctan2(R[0,1], R[0,0])
    return np.array([phi, theta, psi])
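
# euler2rot and rot2euler are mutual inverses for |theta| < pi/2, e.g.
#     rot2euler(euler2rot(np.array([0.1, 0.2, 0.3])))  # ~= [0.1, 0.2, 0.3]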

def gen_cam_views(transl, z_pitch, viewnum):
    def viewmatrix(z, up, translation):
        vec3 = z / np.linalg.norm(z)
        up = up / np.linalg.norm(up)
        vec1 = np.cross(up, vec3)
        vec2 = np.cross(vec3, vec1)
        view = np.stack([vec1, vec2, vec3, translation], axis=1)
        view = np.concatenate([view, np.array([[0,0,0,1]])], axis=0)
        return view

    cam_poses = []
    for i, theta in enumerate(np.linspace(-np.pi/2, 1.5*np.pi, viewnum+1)[:-1]):
        theta = -theta
        dist = 2.9
        
        z = np.array([np.cos(theta), 0, np.sin(theta)])
        t = -z * dist + transl

        z = z * np.sqrt(1-z_pitch*z_pitch)
        z[1] = z_pitch
        z = z * dist
        up = np.array([0,1,0])
        view = viewmatrix(z, up, t)
        cam_poses.append(view)
    return cam_poses
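
# Usage sketch (illustrative values): generate 150 camera-to-world poses orbiting
# the subject at a fixed distance, as get_render_poses below does.
#
#     poses = gen_cam_views(np.array([0., 1.2, 0.]), z_pitch=0.1, viewnum=150)
#     # each poses[i] is a 4x4 c2w matrix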


class GeneBodyDataset(Dataset):
    @staticmethod
    def modify_commandline_options(parser):
        return parser
    def __init__(self, opt, phase='eval', root=None, move_cam=0):
        super(GeneBodyDataset, self).__init__()
        self.opt = opt
        self.is_train = phase == 'train'
        self.is_render = phase == 'render'
        self.projection_mode = 'perspective'
        self.eval_skip = self.opt.eval_skip
        self.train_skip = self.opt.train_skip
        self.genebody_seq_len = 150

        self.root = root if root is not None else opt.dataroot
        self.phase = 'val'
        self.load_size = self.opt.loadSize

        self.B_MIN = np.array([-128, -28, -128])
        self.B_MAX = np.array([128, 228, 128])

        self.num_views = self.opt.num_views
        self.input_views = [1,13,25,37]
        self.test_views = sorted(list(range(48)))
        # self.sequences = self.opt.ghr_seq if phase == 'train' else self.opt.ghr_test_seq
        self.split = np.load(os.path.join(self.root, 'genebody_split.npy'), allow_pickle=True).item()
        self.sequences = self.split['train'] if self.is_train else self.split['test']
        self.frames, self.cam_names, self.subjects, self.frames_id = self.get_frames()
        self.load_smpl_param = any([self.opt.use_smpl_sdf, self.opt.use_t_pose])
        self.load_smpl_mesh = any([self.opt.use_smpl_sdf, self.opt.use_t_pose])
        self.smpl_type = self.opt.smpl_type
        self.smpl_t_pose = load_obj_mesh(os.path.join(self.opt.t_pose_path, f'{self.smpl_type}.obj'))
        self.use_smpl_depth = opt.use_smpl_depth
        # PIL to tensor
        self.to_tensor_normal = transforms.Compose([
            transforms.Resize(self.load_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        self.to_tensor = transforms.Compose([
            transforms.Resize(self.load_size),
            transforms.ToTensor()
        ])
        
        self.move_cam = move_cam if not self.is_train else 0
        self.use_white_bkgd = self.opt.use_white_bkgd

    def get_frames(self, i = 0):
        sequences = self.sequences
        frames, subjects, cam_names, frames_id = [], [], [], []
        i = self.input_views[i % self.num_views]
        for seq in sequences:
            if os.path.exists(os.path.join(self.root, seq)):
                files = sorted([f for f in os.listdir(os.path.join(self.root, \
                        seq, 'smpl')) \
                        if f[-4:] == '.obj'])
                cam_names += ['%02d' for i in range(len(files))]
                subjects += [seq for i in range(len(files))]
                frames_id += list(range(len(files)))

                for f in files:
                    f = f[:-4]
                    frames += [f]
        return frames, cam_names, subjects, frames_id

    def get_render_poses(self, annots, move_cam=150):
        height, pitch = [], []
        for view in range(1,48,3):
            view = '%02d' % view
            if view in annots.keys():
                height.append(annots[view]['c2w'][1, 3])
                z_rodrigous = annots[view]['c2w'][:3,:3]@np.array([[0],[0],[1]])
                pitch.append(z_rodrigous[1,0])
        transl = np.array([0, np.mean(np.array(height)), 0])
        z_pitch = np.mean(np.array(pitch))

        render_poses = gen_cam_views(transl, z_pitch, move_cam)
        return render_poses

    def __len__(self):
        if self.is_train:
            return len(self.frames) * len(self.test_views) // self.train_skip
        else:
            return len(self.frames) // self.eval_skip

    def image_cropping(self, mask):
        a = np.where(mask != 0)
        h, w = list(mask.shape[:2])
        if len(a[0]) > 0:
            top, left, bottom, right = np.min(a[0]), np.min(a[1]), np.max(a[0]), np.max(a[1])
        else:
            return 0,0,mask.shape[0],mask.shape[1]
        bbox_h, bbox_w = bottom - top, right - left

        # pad bbox
        bottom = min(int(bbox_h*0.1+bottom), h)
        top = max(int(top-bbox_h*0.1), 0)
        right = min(int(bbox_w*0.1+right), w)
        left = max(int(left-bbox_w*0.1), 0)
        bbox_h, bbox_w = bottom - top, right - left

        if bbox_h >= bbox_w:
            w_c = (left+right) / 2
            size = bbox_h
            if w_c - size / 2 < 0:
                left = 0
                right = size
            elif w_c + size / 2 >= w:
                left = w - size
                right = w
            else:
                left = int(w_c - size / 2)
                right = left + size
        else:   # bbox_w >= bbox_h
            h_c = (top+bottom) / 2
            size = bbox_w
            if h_c - size / 2 < 0:
                top = 0
                bottom = size
            elif h_c + size / 2 >= h:
                top = h - size
                bottom = h
            else:
                top = int(h_c - size / 2)
                bottom = top + size
        
        return top, left, bottom, right

    def get_near_far(self, smpl_verts, w2c):
        vp = smpl_verts.dot(w2c[:3,:3].T) + w2c[:3,3:].T
        vmin, vmax = vp.min(0), vp.max(0)
        near, far = vmin[2], vmax[2]
        near, far = near-(far-near)/2, far+(far-near)/2
        return near, far

    # def get_realworld_scale(self, smpl_verts, bbox, w2c, K):
    #     smpl_min, smpl_max = smpl_verts.min(0), smpl_verts.max(0)
    #     # reprojected smpl verts
    #     vp = smpl_verts.dot(w2c[:3,:3].T) + w2c[:3,3:].T
    #     vp = vp.dot(K.T)
    #     vp = vp[:,:2] / (vp[:,2:]+1e-8)
    #     vmin, vmax = vp.min(0), vp.max(0)
        
    #     # compare with bounding box
    #     bbox_h = bbox[1][0] - bbox[0][0]
    #     bbox_w = bbox[1][1] - bbox[0][1]
    #     long_axis = bbox_h/(vmax[1]-vmin[1])*(smpl_max[1]-smpl_min[1]) if bbox_h > bbox_w else bbox_w/(vmax[0]-vmin[0])*(smpl_max[0]-smpl_min[0])
    #     spatial_freq = 180/long_axis/0.5
    #     return spatial_freq

    def get_image(self, sid, num_views, view_id=None, random_sample=False, smpl_verts=None):
        frame = self.frames[sid]
        subject = self.subjects[sid]
        # some of the sequences have missing views
        if subject == 'wuwenyan':
            test_views = list(set(self.test_views)-set([34, 36]))
        elif (subject == 'dannier' or subject == 'Tichinah_jervier'):
            test_views = list(set(self.test_views)-set([32]))
        elif subject == 'joseph_matanda':
            test_views = list(set(self.test_views) - set([39, 40, 42, 43, 44, 45, 46, 47]))
        elif subject in ['anastasia', 'aosilan']:
            test_views = list(set(self.test_views) - set(range(16,24)))
        else:
            test_views = self.test_views
        test_views = sorted(test_views)

        # Select a random target view from the valid test views if not given
        if self.is_train:
            if view_id is None or random_sample:
                view_id = test_views[np.random.randint(len(test_views))]
            else:
                view_id = test_views[view_id % len(test_views)]
            # The view ids are the fixed input source views plus the sampled target view
            view_ids = self.input_views + [view_id]
        else:
            if self.is_render:
                view_ids = self.input_views
            else:
                view_ids = self.input_views + test_views

        calib_list = []
        image_list = []
        mask_list = []
        extrinsic_list = []
        bbox_list = []
        smpl_depth_list = []
        spatial_freqs = []
        annot_path = os.path.join(self.root, subject, f'annots.npy')
        annots = np.load(annot_path, allow_pickle = True).item()['cams']
        for i, vid in enumerate(view_ids): 
            view = '%02d' % vid
            mask_folder = 'mask'
            mask_path = os.path.join(self.root, subject, mask_folder, self.cam_names[sid] % vid)
            mask_path = [os.path.join(mask_path,f) for f in os.listdir(mask_path) \
                        if frame in f]
            image_path = os.path.join(self.root, subject, 'image', self.cam_names[sid] % vid)
            image_path = [os.path.join(image_path,f) for f in os.listdir(image_path) \
                          if frame in f]
            image_np = imageio.imread(image_path[0])
            mask_np = imageio.imread(mask_path[0])
            size = image_np.shape
            if self.use_smpl_depth and i < self.num_views:
                smpl_depth_path = os.path.join(self.root, subject, 'smpl_depth', self.cam_names[sid] % vid)
                smpl_depth_path = [os.path.join(smpl_depth_path,f) for f in os.listdir(smpl_depth_path) \
                                    if frame in f]
                smpl_depth = imageio.imread(smpl_depth_path[0])
                smpl_depth = smpl_depth.astype(np.float32) / 1000.0
            top, left, bottom, right = self.image_cropping(mask_np)
            
            mask_np = mask_np[top:bottom, left:right]
            image_crop = image_np[top:bottom, left:right]
            
            mask_np = cv2.resize(mask_np.copy(), (self.load_size,self.load_size), \
                              interpolation = cv2.INTER_NEAREST)
            image_crop = cv2.resize(image_crop.copy(), (self.load_size,self.load_size), \
                              interpolation = cv2.INTER_CUBIC)
            image = Image.fromarray(image_crop)
            mask_np = mask_np > 128
            if self.use_smpl_depth and i < self.num_views:
                smpl_depth = smpl_depth[top:bottom, left:right]
                smpl_depth = cv2.resize(smpl_depth, (self.load_size,self.load_size), \
                                interpolation = cv2.INTER_NEAREST)
                mask_np = np.logical_or(mask_np, smpl_depth > 0)
                smpl_depth_list.append(torch.from_numpy(smpl_depth))

            a = np.where(mask_np != 0)
            try:
                bbox = [[np.min(a[0]), np.min(a[1])], [np.max(a[0]), np.max(a[1])]] if len(a[0]) > 0 else \
                    [[0, 0], [self.load_size, self.load_size]]
            except:
                print(os.path.join(self.root, subject, mask_folder, self.cam_names[sid] % vid))
                print(top, left, bottom, right)
                print(mask_np)
                exit(0)
            bbox_list.append(bbox)
            mask = torch.from_numpy(mask_np.astype(np.float32)).view(1,self.load_size,self.load_size)
            mask_list.append(mask)
            image = self.to_tensor(image) if i >= num_views else self.to_tensor_normal(image)
            if i >= self.num_views and self.use_white_bkgd:
                image = image * mask + (1. - mask)
            image = mask.type(image.dtype).expand(3,-1,-1) * image
            rgb = image.cpu().numpy().transpose([1,2,0])
            
            K = np.array(annots[view]['K'], dtype=np.float32)
            K[0,2] -= left
            K[1,2] -= top
            K[0,:] *= self.load_size / float(right - left)
            K[1,:] *= self.load_size / float(bottom - top)
            c2w = np.array(annots[view]['c2w'], dtype=np.float32)
            w2c = np.linalg.inv(c2w)
            dist = np.array(annots[view]['D'], dtype = np.float32)
            # determine near far plane from smpl estimation
            near, far = self.get_near_far(smpl_verts, w2c)

            # # determine valid body part from smpl and bounding box
            # if i < self.num_views:
            #     spatial_freq = self.get_realworld_scale(smpl_verts, bbox, w2c, K)
            #     spatial_freqs.append(spatial_freq)

            calib = torch.Tensor([K[0,0],K[1,1],K[0,2],K[1,2]]+list(dist.reshape(-1))+[near,far]).float()
            extrinsic = torch.from_numpy(w2c)
            image_list.append(image)
            calib_list.append(calib)
            extrinsic_list.append(extrinsic)
        # bbox = np.array(bbox_list).reshape(-1,4)
        if not self.is_train and self.move_cam > 0:
            bboxs = np.array(bbox_list[:self.num_views]).reshape(-1,2)
        else:
            bboxs = np.array(bbox_list[self.num_views:]).reshape(-1,2)
        centroid = np.array([mask_np.shape[0], mask_np.shape[1]]) / 2
        bbox = (np.max(np.abs(bboxs - centroid), axis=0) * \
                np.array([1, np.sqrt(2)])).astype(np.int32)
        bbox = np.array([centroid - bbox, centroid + bbox]).T
        bbox = np.clip(bbox.reshape(-1), 0, self.load_size)
        # spatial_freq = min(spatial_freqs)

        if self.is_render:
            # render a free-viewpoint video at full image resolution
            render_id = sid % (self.genebody_seq_len // self.eval_skip)
            render_c2ws = self.get_render_poses(annots, self.move_cam)
            w2c = np.linalg.inv(render_c2ws[render_id])
            K = annots['25']['K']

            render_extrinsics = torch.from_numpy(w2c.astype(np.float32))
            near, far = self.get_near_far(smpl_verts, w2c)
            render_calibs = torch.Tensor([K[0,0],K[1,1],K[0,2],K[1,2]]+list(np.zeros_like(dist))+[near,far]).float()
            bbox = np.array([0, size[0], 0, size[1]])
            extrinsic_list = extrinsic_list[:self.num_views] + [render_extrinsics]
            calib_list = calib_list[:self.num_views] + [render_calibs]
            mask_list = mask_list[:self.num_views]

        if not self.is_train and self.move_cam == 0:
            gt_list = image_list[num_views:]
            image_list = image_list[:num_views]
        return {
            'img': torch.stack(image_list, dim=0),
            'mask': torch.stack(mask_list[:num_views], dim=0),
            'persps': torch.stack(calib_list, dim=0),
            'calib': torch.stack(extrinsic_list, dim=0),
            'bbox': bbox,
            'render_gt': torch.stack(gt_list, dim = 0) if not self.is_train and self.move_cam == 0 else [],
            'smpl_depth': torch.stack(smpl_depth_list[:self.num_views], dim=0) if self.opt.use_smpl_depth else [],
            'img_i': view_ids,
            # 'body_scale': spatial_freq,
            'center': torch.from_numpy((smpl_verts.max(0)+smpl_verts.min(0))/2),
        }

    def smpl_from_param(self, model_path, subject, smpl_param, smpl_scale):
        import smplx
        smpl = smplx.SMPLX(
            model_path=model_path,
            gender='NEUTRAL',#genebody_gender[subject],
            use_pca=False,
        )
        smpl_param = smpl_param.copy()
        for key in smpl_param.keys():
            if isinstance(smpl_param[key], np.ndarray):
                smpl_param[key] = torch.from_numpy(smpl_param[key])
        
        output = smpl(**smpl_param)
        verts = output['vertices'].numpy().reshape(-1,3)

        # To align with the keypoints3d saved in param, use the base keypoints only.
        # If you want the full keypoints3d with extra joints and landmarks,
        # please refer to the definition of joints in
        # vertex_joint_selector.py and of landmarks in vertices2landmarks in lbs.py
        keypoints3d = output['joints'].numpy().reshape(-1,3)[:55]
        
        return verts*smpl_scale, smpl.faces, keypoints3d*smpl_scale

    def get_item(self, index):
        sid = index % len(self.frames)
        vid = (index // len(self.frames)) % len(self.test_views)
        frame = self.frames[sid]
        subject = self.subjects[sid]
        old = 'old' if os.path.exists(os.path.join(self.root, subject, 'oldsmpl')) else ''
        res = {
            'name': subject+'_'+frame+'_'+str(vid),
            'mesh_path': os.path.join(self.root, subject, 'smpl', '%d.ply' % sid),
            'sid': sid,
            'vid': vid,
        }

        # load smpl data
        param_dir = os.path.join(self.root, subject, f'{old}param')
        param_path = [os.path.join(param_dir,f) for f in os.listdir(param_dir) if frame in f]
        param = np.load(os.path.join(param_path[0]), allow_pickle=True).item()
        scale, param = param['smplx_scale'], param['smplx']
        res['body_scale'] = scale
        for key in param.keys():
            if isinstance(param[key], torch.Tensor):
                param[key] = param[key].numpy()
        # param['jaw_pose'] = np.zeros_like(param['jaw_pose'])

        # smplx_model_path = '/home/SENSETIME/chengwei/Projects/bodyfitting_release/data/smplx'
        # vert, face, kp3d = self.smpl_from_param(smplx_model_path, subject, param, scale)

        smpl_dir = os.path.join(self.root, subject, f'{old}smpl')
        # os.makedirs(smpl_dir, exist_ok=True)
        # save_obj_mesh(os.path.join(smpl_dir, frame+'.obj'), vert, face[...,::-1])
        smpl_path = [os.path.join(smpl_dir,f) for f in os.listdir(smpl_dir) if frame in f][0]
        if smpl_path[-4:] == '.obj':
            vert, face = load_obj_mesh(smpl_path)
        else:
            smpl = load_ply(smpl_path)
            vert, face = smpl['vertex'][:,:3], smpl['face']
        vert = vert.astype(np.float32)
        # load image data
        image_data = self.get_image(sid, num_views=self.num_views, view_id=vid,
                                      random_sample=self.opt.random_multiview, smpl_verts=vert)
        res.update(image_data)

        T = cv2.Rodrigues(param['global_orient'].reshape(-1, 3)[:1])[0]
        res['bbox'] = np.array(res['bbox'])
        res['smpl_rot'] = torch.from_numpy(T.astype(np.float32)) \
                          if self.load_smpl_mesh else []
        res['smpl_verts']= torch.from_numpy(vert.astype(np.float32)) \
                          if self.load_smpl_mesh else []
        res['smpl_faces']= torch.from_numpy(face.astype(np.int32)) \
                          if self.load_smpl_mesh else []
        res['smpl_betas']= torch.from_numpy(param['betas'].reshape(-1).astype(np.float32)) \
                          if self.load_smpl_param else []
        if self.load_smpl_param:
            t_vert, t_face = self.smpl_t_pose
            res['smpl_t_verts'] = t_vert
            res['smpl_t_faces'] = t_face
        if self.opt.use_t_pose:
            res['smpl_t_verts'] = torch.from_numpy(res['smpl_t_verts'].astype(np.float32))
            res['smpl_t_faces'] = torch.from_numpy(res['smpl_t_faces'].astype(np.int32))
        else:
            res['smpl_t_verts'] = []
            res['smpl_t_faces'] = []

        return res

    def __getitem__(self, index):
        if not self.is_train:
            index *= self.eval_skip
        return self.get_item(index)


================================================
FILE: lib/data/__init__.py
================================================


================================================
FILE: lib/geometry.py
================================================
import torch
import numpy as np

def rot2euler(R):
    phi = np.arctan2(R[1,2], R[2,2])
    theta = -np.arcsin(R[0,2])
    psi = np.arctan2(R[0,1], R[0,0])
    return np.array([phi, theta, psi])

def euler2rot(euler):
    sin, cos = np.sin, np.cos
    phi, theta, psi = euler[0], euler[1], euler[2]
    R1 = np.array([[1, 0, 0],
                [0, cos(phi), sin(phi)],
                [0, -sin(phi), cos(phi)]])
    R2 = np.array([[cos(theta), 0, -sin(theta)],
                [0, 1, 0],
                [sin(theta), 0, cos(theta)]])
    R3 = np.array([[cos(psi), sin(psi), 0],
                [-sin(psi), cos(psi), 0],
                [0, 0, 1]])
    R = R1 @ R2 @ R3
    return R
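
# A minimal round-trip sketch for the two helpers above (illustrative only):
# the Euler convention and sign pattern are baked into euler2rot, so
# rot2euler is its inverse for non-degenerate angles (|theta| < pi/2).
#   R = euler2rot(np.array([0.1, -0.2, 0.3]))
#   assert np.allclose(euler2rot(rot2euler(R)), R)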

def batch_rodrigues(theta):
    """Convert axis-angle representation to rotation matrix.
    Args:
        theta: size = [B, 3]
    Returns:
        Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
    """
    l1norm = torch.norm(theta + 1e-8, p = 2, dim = 1)
    angle = torch.unsqueeze(l1norm, -1)
    normalized = torch.div(theta, angle)
    angle = angle * 0.5
    v_cos = torch.cos(angle)
    v_sin = torch.sin(angle)
    quat = torch.cat([v_cos, v_sin * normalized], dim = 1)
    return quat_to_rotmat(quat)

def quat_to_rotmat(quat):
    """Convert quaternion coefficients to rotation matrix.
    Args:
        quat: size = [B, 4] 4 <===>(w, x, y, z)
    Returns:
        Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
    """
    norm_quat = quat
    norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True)
    w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3]

    B = quat.size(0)

    w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
    wx, wy, wz = w*x, w*y, w*z
    xy, xz, yz = x*y, x*z, y*z

    rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz,
                          2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx,
                          2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
    return rotMat
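
# Illustrative usage sketch for the two conversions above: a rotation of
# pi/2 about the z-axis in axis-angle form should yield the familiar
# 2D rotation block (values shown are exact up to float error).
#   theta = torch.tensor([[0., 0., np.pi / 2]])   # [B=1, 3]
#   R = batch_rodrigues(theta)                    # [1, 3, 3]
#   torch.allclose(R[0, 0, 1], torch.tensor(-1.), atol=1e-6)  # True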


def index(feat, uv, mode='bilinear'):
    '''
    :param feat: [B, C, H, W] image features
    :param uv: [B, 2, N] uv coordinates in the image plane, range [-1, 1]
    :return: [B, C, N] image features at the uv coordinates
    '''
    uv = uv.transpose(1, 2)  # [B, N, 2]
    uv = uv.unsqueeze(2)  # [B, N, 1, 2]
    # NOTE: for newer PyTorch, training results seem degraded due to an implementation difference in F.grid_sample;
    # for old versions, simply remove the align_corners argument.
    # if torch.__version__ >= "1.3.0":
    #     samples = torch.nn.functional.grid_sample(feat, uv, align_corners=True)  # [B, C, N, 1]
    # else:
    samples = torch.nn.functional.grid_sample(feat, uv, mode=mode)
    return samples[:, :, :, 0]  # [B, C, N]
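
# Sketch of the sampling convention used by index() (shapes illustrative):
# uv lives in the [-1, 1] normalized range expected by F.grid_sample,
# with u indexing width and v indexing height.
#   feat = torch.randn(1, 32, 128, 128)   # [B, C, H, W]
#   uv = torch.zeros(1, 2, 10)            # 10 queries at the image center
#   sampled = index(feat, uv)             # [1, 32, 10]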


def orthogonal(points, calibrations, transforms=None):
    '''
    Compute the orthogonal projections of 3D points into the image plane by given projection matrix
    :param points: [B, 3, N] Tensor of 3D points
    :param calibrations: [B, 4, 4] Tensor of projection matrix
    :param transforms: [B, 2, 3] Tensor of image transform matrix
    :return: xyz: [B, 3, N] Tensor of xyz coordinates in the image plane
    '''
    rot = calibrations[:, :3, :3]
    trans = calibrations[:, :3, 3:4]
    pts = torch.baddbmm(trans, rot, points)  # [B, 3, N]
    if transforms is not None:
        scale = transforms[:2, :2]
        shift = transforms[:2, 2:3]
        pts[:, :2, :] = torch.baddbmm(shift, scale, pts[:, :2, :])
    return pts



def perspective(points, w2c, camera):
    '''
    Compute the perspective projections of 3D points into the image plane by given projection matrix
    :param points: [Bx3xN] Tensor of 3D points
    :param w2c: [Bx4x4] Tensor of world-to-camera extrinsic matrices
    :param camera: [Bx4/9+] Tensor of intrinsics [fx, fy, cx, cy, (distortion,) near, far]
    :return: xyz: [Bx3xN] Tensor; xy in pixel coordinates, z the camera-space depth
    '''
    rot = w2c[:, :3, :3]
    trans = w2c[:, :3, 3:4]
    points = torch.baddbmm(trans, rot, points)  # [B, 3, N]
    xy  = points[:,:2, :] / torch.clamp(points[:,2:3,:], 1e-9)
    if camera.shape[1] > 6:
        x2 = xy[:,0,:]*xy[:,0,:]
        y2 = xy[:,1,:]*xy[:,1,:]
        xy_= xy[:,0,:]*xy[:,1,:]
        r2 = x2 + y2
        c = (1 + r2*(camera[:,4:5]+r2*(camera[:,5:6]+r2*camera[:,8:9])))
        xy = c.unsqueeze(1)*xy + torch.cat([ \
                  (camera[:,6:7]*2*xy_+ camera[:,7:8]*(r2+2*x2)).unsqueeze(1),\
                  (camera[:,7:8]*2*xy_+ camera[:,6:7]*(r2+2*y2)).unsqueeze(1)],1)
    xy = camera[:,0:2,None]*xy + camera[:,2:4,None]
    points[:,:2, :] = xy
    return points
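
if __name__ == "__main__":
    # Minimal self-check sketch for perspective(): project a few random
    # points in front of an identity camera. The 6-entry camera vector
    # [fx, fy, cx, cy, near, far] skips the distortion branch; all values
    # are illustrative, not taken from any GeneBody calibration.
    pts = torch.rand(1, 3, 4) + torch.tensor([0., 0., 2.]).view(1, 3, 1)
    w2c = torch.eye(4).unsqueeze(0)
    cam = torch.tensor([[500., 500., 256., 256., 0.1, 10.]])
    xyz = perspective(pts, w2c, cam)
    print(xyz.shape, xyz[0, :2, 0])  # [1, 3, 4]; pixel coords of first point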


================================================
FILE: lib/mesh_util.py
================================================
import numpy as np
import torch
from skimage import measure


def save_obj_mesh(mesh_path, verts, faces):
    file = open(mesh_path, 'w')

    for v in verts:
        file.write('v %.4f %.4f %.4f\n' % (v[0], v[1], v[2]))
    for f in faces:
        f_plus = f + 1
        file.write('f %d %d %d\n' % (f_plus[0], f_plus[2], f_plus[1]))
    file.close()


def save_obj_mesh_with_color(mesh_path, verts, faces, colors):
    file = open(mesh_path, 'w')

    for idx, v in enumerate(verts):
        c = colors[idx]
        file.write('v %.4f %.4f %.4f %.4f %.4f %.4f\n' % (v[0], v[1], v[2], c[0], c[1], c[2]))
    for f in faces:
        f_plus = f + 1
        file.write('f %d %d %d\n' % (f_plus[0], f_plus[2], f_plus[1]))
    file.close()


def save_obj_mesh_with_uv(mesh_path, verts, faces, uvs):
    file = open(mesh_path, 'w')

    for idx, v in enumerate(verts):
        vt = uvs[idx]
        file.write('v %.4f %.4f %.4f\n' % (v[0], v[1], v[2]))
        file.write('vt %.4f %.4f\n' % (vt[0], vt[1]))

    for f in faces:
        f_plus = f + 1
        file.write('f %d/%d %d/%d %d/%d\n' % (f_plus[0], f_plus[0],
                                              f_plus[2], f_plus[2],
                                              f_plus[1], f_plus[1]))
    file.close()
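

if __name__ == "__main__":
    # Minimal sketch: write a single triangle with the saver above.
    # The output path is an illustrative temp file, not a repo asset.
    import os
    import tempfile
    verts = np.array([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.]])
    faces = np.array([[0, 1, 2]])
    out = os.path.join(tempfile.gettempdir(), 'tri.obj')
    save_obj_mesh(out, verts, faces)
    print('wrote', out)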


================================================
FILE: lib/metrics.py
================================================
import sys
import os

import torch

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# import torch
# torch.autograd.set_detect_anomaly(True)
# import torch.nn as nn
# import torch.nn.functional as F

import numpy as np
import trimesh
import cv2
from scipy.spatial import KDTree
import lpips

lpips_net = lpips.LPIPS(net='alex')
def chamfer(x_verts, gt_verts, x_normals=None, gt_normals=None):

    searcher = KDTree(gt_verts)
    dists, inds = searcher.query(x_verts)

    if x_normals is not None and gt_normals is not None:
        cosine_dists = 1 - np.sum(x_normals * gt_normals[inds], axis=1)
        return dists, cosine_dists
    elif x_normals is not None or gt_normals is not None:
        raise ValueError("provide normals for both point sets")
    return dists



def fscore(dist1, dist2, threshold=1.0):
    """
    Calculates the F-score between two point clouds with the corresponding threshold value.
    :param dist1: N-Points
    :param dist2: N-Points
    :param th: float
    :return: fscore, precision, recall
    """
    # NB : In this depo, dist1 and dist2 are squared pointcloud euclidean distances, so you should adapt the threshold accordingly.
    precision_1 = np.mean((dist1 < threshold).astype(np.float32))
    precision_2 = np.mean((dist2 < threshold).astype(np.float32))
    fscore = 2 * precision_1 * precision_2 / (precision_1 + precision_2)
    return fscore, precision_1, precision_2
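
# Illustrative pipeline sketch (variable names hypothetical): feed the two
# directed chamfer distances into fscore, querying each cloud against the other.
#   d1 = chamfer(pred_verts, gt_verts)   # pred -> gt distances
#   d2 = chamfer(gt_verts, pred_verts)   # gt -> pred distances
#   f, prec, rec = fscore(d1, d2, threshold=1.0)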


def psnr(x, gt):
    """
    x: np.uint8, HxWxC, 0 - 255
    gt: np.uint8, HxWxC, 0 - 255
    """
    x = (x / 255.).astype(np.float32)
    gt = (gt / 255.).astype(np.float32)
    mse = ((x - gt) ** 2).mean()
    psnr = 10. * np.log10(1. / mse)

    # return mse, psnr
    return psnr

def ssim_channel(x, gt):

    C1 = (0.01 * 255) ** 2
    C2 = (0.03 * 255) ** 2
    x = x.astype(np.float32)
    gt = gt.astype(np.float32)
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())
    mu1 = cv2.filter2D(x, -1, window)[5:-5, 5:-5]  # valid
    mu2 = cv2.filter2D(gt, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1 ** 2
    mu2_sq = mu2 ** 2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(x ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(gt ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(x * gt, -1, window)[5:-5, 5:-5] - mu1_mu2
    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                                            (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()


def ssim(x, gt):
    '''calculate SSIM
    the same outputs as MATLAB's
    x: np.uint8, HxWxC, 0 - 255
    gt: np.uint8, HxWxC, 0 - 255
    '''
    if not x.shape == gt.shape:
        raise ValueError('Input images must have the same dimensions.')
    if x.ndim == 2:
        return ssim_channel(x, gt)
    elif x.ndim == 3:
        if x.shape[2] == 3:
            ssims = []
            for i in range(3):
                ssims.append(ssim_channel(x[:, :, i], gt[:, :, i]))
            return np.array(ssims).mean()
        elif x.shape[2] == 1:
            return ssim_channel(np.squeeze(x), np.squeeze(gt))
    else:
        raise ValueError('input image dimension mismatch.')


def lpips(x, gt, net=lpips_net):
    x = torch.from_numpy(x).float() / 255. * 2 - 1.
    gt = torch.from_numpy(gt).float() / 255. * 2 - 1.
    x = x.permute([2, 0, 1]).unsqueeze(0)
    gt = gt.permute([2, 0, 1]).unsqueeze(0)
    with torch.no_grad():
        loss = net.forward(x, gt)
    return loss.item()
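
# Usage sketch for the image metrics above, assuming two aligned uint8
# renders of the same size (the arrays here are random placeholders):
#   a = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
#   b = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
#   print(psnr(a, b), ssim(a, b), lpips(a, b))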




if __name__ == "__main__":
    import trimesh
    x = trimesh.load("./data/human/SMPL/10315_m_John/smpl.obj", process=False)
    y = trimesh.load("./data/human/SMPL/10316_m_John/smpl.obj", process=False)

    print(chamfer(x.vertices, y.vertices, x.vertex_normals, y.vertex_normals))


================================================
FILE: lib/metrics_torch.py
================================================
import torch
import torch.nn.functional as F
from math import exp
import numpy as np
import lpips


class LPIPS(torch.nn.Module):
    def __init__(self):
        super(LPIPS, self).__init__()
        self.net = lpips.LPIPS(net='alex', verbose=False)

    def forward(self, x, gt):
        if torch.max(gt) > 128:
            # [0, 255]
            x = x / 255. * 2 - 1
            gt = gt / 255. * 2 - 1
        elif torch.min(gt) >= 0 and torch.max(gt) <= 1:
            # [0, 1]
            x = x * 2 - 1
            gt = gt * 2 - 1
        with torch.no_grad():
            loss = self.net.forward(x, gt)
        # return loss.item()
        return loss

def psnr(x, gt):
    """
    x: np.uint8, HxWxC, 0 - 255
    gt: np.uint8, HxWxC, 0 - 255
    """
    if torch.max(gt) > 128:
        # [0, 255]
        x = x / 255
        gt = gt / 255
    elif torch.min(gt) < -0.5:
        # [0, 1]
        x = (x+1)/2
        gt = (gt+1)/2

    mse = torch.mean((x - gt) ** 2)
    psnr = -10. * torch.log10(mse)
    return psnr


def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
    return gauss/gauss.sum()


def create_window(window_size, channel=1):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
    return window


def ssim_(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
    # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).

    if val_range is None:
        if torch.max(img1) > 128:
            max_val = 255
        else:
            max_val = 1

        if torch.min(img1) < -0.5:
            min_val = -1
        else:
            min_val = 0
        L = max_val - min_val
    else:
        L = val_range

    padd = 0
    (_, channel, height, width) = img1.size()
    if window is None:
        real_size = min(window_size, height, width)
        window = create_window(real_size, channel=channel).to(img1.device)

    mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
    mu2 = F.conv2d(img2, window, padding=padd, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2

    C1 = (0.01 * L) ** 2
    C2 = (0.03 * L) ** 2

    v1 = 2.0 * sigma12 + C2
    v2 = sigma1_sq + sigma2_sq + C2
    cs = v1 / v2  # contrast sensitivity

    ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)

    if size_average:
        cs = cs.mean()
        ret = ssim_map.mean()
    else:
        cs = cs.mean(1).mean(1).mean(1)
        ret = ssim_map.mean(1).mean(1).mean(1)

    if full:
        return ret, cs
    return ret


# Classes to re-use window
class SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True, val_range=None):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.val_range = val_range

        # Assume 1 channel for SSIM
        self.channel = 1
        self.window = create_window(window_size)

    def forward(self, img1, img2):
        if len(list(img1.shape)) < 4:
            img1 = img1.unsqueeze(0)
            img2 = img2.unsqueeze(0)
        (_, channel, _, _) = img1.size()

        if channel == self.channel and self.window.dtype == img1.dtype:
            window = self.window
        else:
            window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
            self.window = window
            self.channel = channel

        return ssim_(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
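

if __name__ == "__main__":
    # Minimal smoke-test sketch for the torch metrics; inputs are random
    # [0, 1] images, so the values themselves are meaningless.
    # LPIPS() would also work here, but it downloads pretrained weights.
    img1 = torch.rand(1, 3, 64, 64)
    img2 = torch.rand(1, 3, 64, 64)
    print('psnr:', psnr(img1, img2).item())
    print('ssim:', SSIM()(img1, img2).item())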


================================================
FILE: lib/model/Embedder.py
================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np

class PositionalEncoding:
    """
    GNR uses NeRF-style positional encoding for coordinate embedding
    """
    def __init__(self, d, num_freqs=10, min_freq=None, max_freq=None, freq_type='linear'):
        self.num_freqs = num_freqs
        self.min_freq = min_freq
        self.max_freq = max_freq
        self.freq_type = freq_type
        self.create_embedding_fn(d)
        
    def create_embedding_fn(self, d):
        embed_fns = []
        out_dim = 0
        embed_fns.append(lambda x : x)
        out_dim += d
            
        N_freqs = self.num_freqs
        
        if self.freq_type == 'linear':
            min_freq = 0 if self.min_freq is None else self.min_freq
            max_freq = 2 ** (self.num_freqs-1) if self.max_freq is None else self.max_freq
            freq_bands = torch.linspace(min_freq*math.pi*2, max_freq*math.pi*2, steps=N_freqs) # linear freq band, Fourier expansion
        else:
            min_freq = 0 if self.min_freq is None else math.log2(self.min_freq)
            max_freq = self.num_freqs-1 if self.max_freq is None else math.log2(self.max_freq)
            freq_bands = 2.**torch.linspace(min_freq*math.pi*2, max_freq*math.pi*2, steps=N_freqs)    # log expansion

        for freq in freq_bands:
            for p_fn in [torch.sin, torch.cos]:
                embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq))
                out_dim += d
                    
        self.embed_fns = embed_fns
        self.out_dim = out_dim
        
    def embed(self, inputs):
        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
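
# Dimensional sanity sketch: with d=3 inputs and num_freqs=10, the output is
# the identity plus a sin/cos pair per frequency, i.e. 3 + 3*2*10 = 63.
#   pe = PositionalEncoding(3, num_freqs=10)
#   x = torch.rand(1024, 3)
#   assert pe.embed(x).shape[-1] == pe.out_dim == 63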

class SphericalHarmonics:
    """
    GNR uses Spherical Harmonics for view direction embedding
    """
    def __init__(self, d = 3, rank = 3):
        assert d % 3 == 0
        self.rank = max([int(rank),0])
        self.out_dim= self.rank*self.rank * (d // 3)

    def Legendre_polynomial(self, x, omx=None):
        if omx is None: omx = 1 - x * x
        Fml = [[]] *((self.rank+1)*self.rank//2)
        Fml[0] = torch.ones_like(x)
        for l in range(1, self.rank):
            b = (l * l + l) // 2
            Fml[b+l]  =-Fml[b-1]*(2*l-1)
            Fml[b+l-1]= Fml[b-1]*(2*l-1)*x
            for m in range(l,1,-1):
                Fml[b+m-2] = -(omx * Fml[b+m] + \
                         2*(m-1)*x * Fml[b+m-1]) / ((l-m+2)*(l+m-1))
        return Fml

    def SH(self, xyz):
        cs  = xyz[...,0:1]
        sn  = xyz[...,1:2]
        Fml = self.Legendre_polynomial(xyz[...,2:3], cs*cs + sn*sn)
        H = [[]] *(self.rank*self.rank)
        for l in range(self.rank):
            b = l * l + l
            attr = np.sqrt((2*l+1)/math.pi/4)
            H[b] = attr * Fml[b//2]
            attr = attr * np.sqrt(2)
            snM = sn; csM = cs
            for m in range(1, l+1):
                attr = -attr / np.sqrt((l+m)*(l+1-m))
                H[b-m] = attr * Fml[b//2+m] * snM
                H[b+m] = attr * Fml[b//2-m] * csM
                snM, csM = snM*cs+csM*sn, csM*cs-snM*sn
        if len(H) > 0:
            return torch.cat(H, -1)
        else:
            return torch.Tensor([])

    def embed(self, inputs):
        return self.SH(inputs)
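
# Sketch for the SH embedder: rank-3 harmonics over unit direction vectors
# give rank**2 = 9 coefficients per 3 input channels.
#   sh = SphericalHarmonics(d=3, rank=3)
#   dirs = F.normalize(torch.randn(1024, 3), dim=-1)
#   assert sh.embed(dirs).shape[-1] == sh.out_dim == 9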


================================================
FILE: lib/model/GNR.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from .HGFilters import *
from ..net_util import init_weights
from .NeRF import NeRF
from .NeRFRenderer import NeRFRenderer
from .SRFilters import SRFilters
from ..geometry import index


class GNR(nn.Module):

    def __init__(self, opt):

        super(GNR, self).__init__()
        self.name = 'gnr'

        self.opt = opt
        self.num_views = self.opt.num_views
        self.use_feat_sr = self.opt.use_feat_sr
        self.ddp = self.opt.ddp
        self.feat_dim = 64 if self.use_feat_sr else 256
        self.index = index
        self.error_term=nn.MSELoss()

        self.image_filter = HGFilter(opt)
        if self.use_feat_sr:
            self.sr_filter = SRFilters(order=2, out_ch=self.feat_dim)

        if not opt.train_encoder:
            for param in self.image_filter.parameters():
                param.requires_grad = False

        self.nerf = NeRF(opt, input_ch_feat=self.feat_dim)
        self.nerf_renderer = NeRFRenderer(opt, self.nerf)

        init_weights(self)
    
    def image_rescale(self, images, masks):
        if images.min() < -0.2:
            images = (images + 1) / 2
            images = images * (masks > 0).float()
        return images

    def get_image_feature(self, data):
        if 'feats' not in data.keys():
            images = data['images']
            im_feat = self.image_filter(images[:self.num_views])
            if self.use_feat_sr:
                im_feat = self.sr_filter(im_feat, images[:self.num_views])
            data['images'] = torch.cat([self.image_rescale(images[:self.num_views], data['masks'][:self.num_views]), \
                                        images[self.num_views:]], 0)
            data['feats'] = im_feat
        return data

    def forward(self, data, train_shape=False):
        data = self.get_image_feature(data)
        if train_shape:
            error = self.nerf_renderer.train_shape(**data)
        else:
            error = self.nerf_renderer.render(**data)

        return error

    def render_path(self, data):
        with torch.no_grad():
            rgbs = None
            data = self.get_image_feature(data)
            rgbs, depths = self.nerf_renderer.render_path(**data)

        return rgbs, depths

    def reconstruct(self, data):
        with torch.no_grad():
            data = self.get_image_feature(data)
            verts, faces, rgbs = self.nerf_renderer.reconstruct(**data)

        return verts, faces, rgbs
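
# Construction sketch (hedged): GNR is driven entirely by the options object
# from lib/options.py; the attributes touched in __init__ alone include
# num_views, use_feat_sr, ddp and train_encoder, plus whatever NeRF and
# NeRFRenderer consume. A typical entry point therefore looks like:
#   from lib.options import BaseOptions   # hypothetical import path
#   opt = BaseOptions().parse()
#   model = GNR(opt)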

================================================
FILE: lib/model/HGFilters.py
================================================
"""
    This file is directly borrowed from PIFu
    GNR uses PIFu's Stacked-Hour-Glass for image encoding
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from ..net_util import *


class HourGlass(nn.Module):
    def __init__(self, num_modules, depth, num_features, norm='batch'):
        super(HourGlass, self).__init__()
        self.num_modules = num_modules
        self.depth = depth
        self.features = num_features
        self.norm = norm

        self._generate_network(self.depth)

    def _generate_network(self, level):
        self.add_module('b1_' + str(level), ConvBlock(self.features, self.features, norm=self.norm))

        self.add_module('b2_' + str(level), ConvBlock(self.features, self.features, norm=self.norm))

        if level > 1:
            self._generate_network(level - 1)
        else:
            self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features, norm=self.norm))

        self.add_module('b3_' + str(level), ConvBlock(self.features, self.features, norm=self.norm))

    def _forward(self, level, inp):
        # Upper branch
        up1 = inp
        up1 = self._modules['b1_' + str(level)](up1)

        # Lower branch
        low1 = F.avg_pool2d(inp, 2, stride=2)
        low1 = self._modules['b2_' + str(level)](low1)

        if level > 1:
            low2 = self._forward(level - 1, low1)
        else:
            low2 = low1
            low2 = self._modules['b2_plus_' + str(level)](low2)

        low3 = low2
        low3 = self._modules['b3_' + str(level)](low3)

        # NOTE: for newer PyTorch (1.3~), it seems that training results are degraded due to implementation diff in F.grid_sample
        # if the pretrained model behaves weirdly, switch with the commented line.
        # NOTE: I also found that "bicubic" works better.
        up2 = F.interpolate(low3, scale_factor=2, mode='bicubic', align_corners=True)
        # up2 = F.interpolate(low3, scale_factor=2, mode='nearest')

        return up1 + up2

    def forward(self, x):
        return self._forward(self.depth, x)


class HGFilter(nn.Module):
    def __init__(self, opt):
        super(HGFilter, self).__init__()
        self.num_modules = opt.num_stack

        self.opt = opt

        # Base part
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)

        if self.opt.norm == 'batch':
            self.bn1 = nn.BatchNorm2d(64)
        elif self.opt.norm == 'group':
            self.bn1 = nn.GroupNorm(32, 64)

        if self.opt.hg_down == 'conv64':
            self.conv2 = ConvBlock(64, 64, self.opt.norm)
            self.down_conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        elif self.opt.hg_down == 'conv128':
            self.conv2 = ConvBlock(64, 128, self.opt.norm)
            self.down_conv2 = nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1)
        elif self.opt.hg_down == 'ave_pool':
            self.conv2 = ConvBlock(64, 128, self.opt.norm)
        else:
            raise ValueError('Unknown Fan Filter setting!')

        self.conv3 = ConvBlock(128, 128, self.opt.norm)
        self.conv4 = ConvBlock(128, 256, self.opt.norm)

        # Stacking part
        for hg_module in range(self.num_modules):
            self.add_module('m' + str(hg_module), HourGlass(1, opt.num_hourglass, 256, self.opt.norm))

            self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256, self.opt.norm))
            self.add_module('conv_last' + str(hg_module),
                            nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
            if self.opt.norm == 'batch':
                self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
            elif self.opt.norm == 'group':
                self.add_module('bn_end' + str(hg_module), nn.GroupNorm(32, 256))
                
            self.add_module('l' + str(hg_module), nn.Conv2d(256,
                                                            opt.hourglass_dim, kernel_size=1, stride=1, padding=0))

            if hg_module < self.num_modules - 1:
                self.add_module(
                    'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
                self.add_module('al' + str(hg_module), nn.Conv2d(opt.hourglass_dim,
                                                                 256, kernel_size=1, stride=1, padding=0))

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)), True)
        # tmpx = x
        if self.opt.hg_down == 'ave_pool':
            x = F.avg_pool2d(self.conv2(x), 2, stride=2)
        elif self.opt.hg_down in ['conv64', 'conv128']:
            x = self.conv2(x)
            x = self.down_conv2(x)
        else:
            raise ValueError('Unknown Fan Filter setting!')

        # normx = x

        x = self.conv3(x)
        x = self.conv4(x)

        previous = x

        # outputs = []
        for i in range(self.num_modules):
            hg = self._modules['m' + str(i)](previous)

            ll = hg
            ll = self._modules['top_m_' + str(i)](ll)

            ll = F.relu(self._modules['bn_end' + str(i)]
                        (self._modules['conv_last' + str(i)](ll)), True)

            # Predict heatmaps
            tmp_out = self._modules['l' + str(i)](ll)
            # outputs.append(tmp_out)

            if i < self.num_modules - 1:
                ll = self._modules['bl' + str(i)](ll)
                tmp_out_ = self._modules['al' + str(i)](tmp_out)
                previous = previous + ll + tmp_out_

        # return outputs, tmpx.detach(), normx
        return tmp_out


================================================
FILE: lib/model/NeRF.py
================================================
import torch
torch.autograd.set_detect_anomaly(True)
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from .Embedder import SphericalHarmonics, PositionalEncoding

class NeRF(nn.Module):
    def __init__(self, opt, D=8, W=256, input_ch=3, input_ch_atts=3, input_ch_feat=128, output_ch=4,
                 activation="relu", pose_freqs=10, att_freqs=6, spatial_freq=1/256):

        """ 
        """
        super(NeRF, self).__init__()
        self.D = D
        self.W = W

        self.use_smpl_sdf = opt.use_smpl_sdf
        self.use_t_pose = opt.use_t_pose
        self.angle_diff = opt.angle_diff
        self.use_occ_net = opt.use_occlusion_net
        self.train_occ = opt.train_occlusion

        self.input_ch_pos_enc = input_ch
        self.input_ch_smpl = 0
        if self.use_smpl_sdf: self.input_ch_smpl += 4
        if self.use_t_pose: self.input_ch_smpl += 3
        self.use_smpl = self.input_ch_smpl != 0

        self.input_ch_feat = input_ch_feat + 3

        self.skips = opt.skips
        self.use_viewdirs = opt.use_viewdirs and opt.use_attention
        self.num_views = opt.num_views
        # self.input_ch_atts = input_ch_atts if opt.use_attention else 0
        if not opt.use_attention:
            self.input_ch_atts = 0
        elif self.angle_diff:
            self.input_ch_atts = 1
        else:
            self.input_ch_atts = 3
        self.use_sh = opt.use_sh if not self.angle_diff else False
        
        self.use_attention = opt.use_attention
        self.use_bn = opt.use_bn
        self.spatial_freq = spatial_freq
        self.pose_embeder = PositionalEncoding(self.input_ch_pos_enc, num_freqs=pose_freqs, min_freq=spatial_freq*0.1, max_freq=spatial_freq*10)
        self.att_embeder = SphericalHarmonics(d = self.input_ch_atts) if self.use_sh else PositionalEncoding(self.input_ch_atts, num_freqs=att_freqs)

        self.pose_embed_fn = self.pose_embeder.embed
        self.att_embed_fn = self.att_embeder.embed
        self.weighted_pool = opt.weighted_pool and self.use_attention

        self.alpha_linears = nn.ModuleList(
            [nn.Linear(self.pose_embeder.out_dim + self.input_ch_smpl + self.input_ch_feat, W)] +
            [nn.Linear(W + self.pose_embeder.out_dim + self.input_ch_smpl, W) if i in self.skips else nn.Linear(W, W) for i in range(0, D-1)])
        self.alpha_out_linear = nn.Linear(W, 1)
        
        self.rgb_linears = nn.ModuleList(
            [nn.Linear(W + self.pose_embeder.out_dim + self.input_ch_smpl, W//4)] + 
            [nn.Linear(W//4 + self.att_embeder.out_dim, W//8) if self.use_viewdirs else nn.Linear(W//4, W//8)] + 
            [nn.Linear(W//8, W//16)] + [nn.Linear(W//16, 3)]
        )
        if self.use_bn:
            self.bn_layer_1 = nn.BatchNorm1d(W)
            self.bn_layer_2 = nn.BatchNorm1d(W)
            self.bn_layer_3 = nn.BatchNorm1d(W//16)

        if self.weighted_pool:
            self.s = nn.Parameter(torch.ones(1))
        # ### Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105)

        if activation == "relu":
            self.activation_fn = F.relu
        elif activation == "swish":
            if torch.__version__ < "1.7.0":
                swish = lambda x: x * torch.sigmoid(x)
                self.activation_fn = swish
            else:
                self.activation_fn = torch.nn.SiLU

        if self.use_attention:
            self.value_linears = nn.ModuleList([nn.Linear(self.pose_embeder.out_dim+self.att_embeder.out_dim+W, W//4), 
                                                nn.Linear(W//4+self.att_embeder.out_dim,W//8), nn.Linear(W//8+self.att_embeder.out_dim,W//16)])
            self.key_linears = nn.ModuleList([nn.Linear(self.pose_embeder.out_dim+self.att_embeder.out_dim+W, W//4), 
                                                nn.Linear(W//4+self.att_embeder.out_dim,W//8), nn.Linear(W//8+self.att_embeder.out_dim,W//16)])

        if self.use_occ_net:
            self.occ_linears = nn.ModuleList([nn.Linear(self.input_ch_smpl + 6 + self.input_ch_feat, W//4), 
                                              nn.Linear(W//4, W//16),
                                              nn.Linear(W//16+self.input_ch_smpl + 6, 1)])
            if not self.train_occ:
                for linear in self.occ_linears:
                    linear.weight.requires_grad_(False)
                    linear.bias.requires_grad_(False)

    def forward(self, x, attdirs=None, alpha_only=False, smpl_vis=None):
        # prepare inputs
        input_pts, input_smpl, input_feats = torch.split(x, [self.input_ch_pos_enc, self.input_ch_smpl, self.input_ch_feat], dim=-1)
        unique_pts = input_pts[:, 0]
        unique_smpl = input_smpl[:, 0] if self.use_smpl else torch.zeros([input_pts.shape[0], 0], dtype=torch.float32, device=input_pts.device)
        input_pts = input_pts.view([-1, self.input_ch_pos_enc])
        input_smpl = input_smpl.view([-1, self.input_ch_smpl]) if self.use_smpl else torch.zeros([input_pts.shape[0], 0], dtype=torch.float32, device=input_pts.device)
        input_feats = input_feats.view([-1,self.input_ch_feat])
        if self.use_attention and attdirs is not None:
            qrydirs, srcdirs = torch.split(attdirs, [1, self.num_views], dim=-2)

        if self.use_occ_net and attdirs is not None:
            # compute plucker coord
            d = srcdirs.reshape([-1, 3])
            m = torch.cross(input_pts, d, dim=-1)
            occ_h = torch.cat([input_smpl, d, m, input_feats], dim=-1)
            for i, l in enumerate(self.occ_linears):
                occ_h = self.occ_linears[i](occ_h)
                if i < len(self.occ_linears) - 1:
                    occ_h = self.activation_fn(occ_h)
                if i == 1:
                    occ_h = torch.cat([input_smpl, d, m, occ_h], dim=-1)
            occ_out = torch.sigmoid(occ_h).view([-1, self.num_views, 1])
            # occ = F.softmax(occ_out, dim=1)

        # alpha mlp
        tmp_h = None
        h = torch.cat([self.pose_embed_fn(input_pts), input_smpl, input_feats], dim=-1)
        for i, l in enumerate(self.alpha_linears):
            h = self.alpha_linears[i](h)
            h = self.activation_fn(h)
            if i in self.skips:
                if i == self.skips[0]:
                    tmp_h = h.clone()
                    h = torch.mean(h.view(-1, self.num_views, self.W), dim=1)
                h = torch.cat([self.pose_embed_fn(unique_pts), unique_smpl, h], dim=-1)
        alpha = self.alpha_out_linear(h)
        if alpha_only: return alpha

        # rgb mlp
        if self.use_attention and self.weighted_pool:
            weights = torch.exp(self.s * (torch.sum(srcdirs*qrydirs, dim=-1) -1))
            weights = weights / (torch.sum(weights, dim=-1, keepdim=True) + 1e-8)  # [N_rand*N_sample, 4]
            h = torch.sum(tmp_h.view(-1, self.num_views, self.W) * weights[...,None], dim=1)
            h0 = h.clone()
        else:
            h = torch.mean(tmp_h.view(-1, self.num_views, self.W), dim=1)
            h0 = h.clone()  # keep pooled feature for the attention key/value branch

        h = torch.cat([self.pose_embed_fn(unique_pts), unique_smpl, h], -1)
        for i, l in enumerate(self.rgb_linears): 
            h = self.rgb_linears[i](h)
            if i < len(self.rgb_linears) - 1:
                h = self.activation_fn(h)
            if i == 0 and self.use_viewdirs:
                h = torch.cat([self.att_embed_fn(-qrydirs.squeeze(1)), h], dim=-1)
        outputs = torch.cat([h, alpha], dim=-1)

        # calculate attention
        if self.use_attention and attdirs is not None:
            attdirs = attdirs.reshape([-1, self.input_ch_atts])
            input_pts_ = torch.cat([unique_pts, input_pts], dim=0)
            input_h = torch.cat([h0, tmp_h], dim=0)
            val = torch.cat([self.pose_embed_fn(input_pts_), self.att_embed_fn(attdirs), input_h], dim=-1)
            for i, l in enumerate(self.value_linears):
                val = self.value_linears[i](val)
                if i < len(self.value_linears) - 1:
                    val = self.activation_fn(val)
                    val = torch.cat([self.att_embed_fn(attdirs), val], dim=-1)
            key = torch.cat([self.pose_embed_fn(unique_pts), self.att_embed_fn(qrydirs.squeeze(1)), h0], dim=-1)
            for i, l in enumerate(self.key_linears):
                key = self.key_linears[i](key)
                if i < len(self.key_linears) - 1:
                    key = self.activation_fn(key)
                    key = torch.cat([self.att_embed_fn(qrydirs.squeeze(1)), key], dim=-1)
            # attention key (query direction) and val (source view direction)
            key = key.unsqueeze(1)
            val = val.view(unique_pts.shape[0], self.num_views+1, -1)
            attention = torch.matmul(val, key.permute(0,2,1)).squeeze(-1)
            if self.use_occ_net:
                attention = self.weighted_softmax(attention, occ_out.squeeze(-1))
            elif smpl_vis is not None:
                attention = self.weighted_softmax(attention, smpl_vis.float())
            else:
                attention = F.softmax(attention, dim=-1)

        if self.use_attention and attdirs is not None:
            outputs = torch.cat([outputs, attention], dim=-1)

        if self.use_occ_net:
            outputs = torch.cat([outputs, occ_out.squeeze(-1)], dim=-1)

        return outputs

    def weighted_softmax(self, attention, weight):
        exp_att = torch.exp(attention - torch.max(attention, 1, keepdim=True)[0])
        # exp_att_src = exp_att[:, 1:].clone() * weight
        exp_att = torch.cat([exp_att[:,:1], exp_att[:, 1:] * weight], dim=1)
        exp_att_sum = torch.sum(exp_att, dim=-1, keepdim=True)
        attention = exp_att / (exp_att_sum + 1e-8)
        return attention
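
# Numeric sketch of weighted_softmax: the first column (the query view) is
# always kept at full weight, while source-view logits are scaled by the
# occlusion/visibility weight before normalization, so a fully occluded
# source view (weight 0) is excluded from the blend. self is unused:
#   att = torch.zeros(1, 3)        # equal logits: query + 2 source views
#   w = torch.tensor([[1., 0.]])   # second source view fully occluded
#   NeRF.weighted_softmax(None, att, w)  # ~[[0.5, 0.5, 0.0]]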

================================================
FILE: lib/model/NeRFRenderer.py
================================================
import logging

import torch
torch.autograd.set_detect_anomaly(True)
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from ..geometry import index, orthogonal, perspective
import trimesh
import imageio
from tqdm import tqdm
from mesh_grid import MeshGridSearcher
from skimage import measure

mse = lambda x, y : torch.mean((x - y) ** 2)
bmse = lambda x, y: torch.sum((x*y - y) ** 2) / torch.sum(y)
l1 = lambda x, y : torch.mean(torch.abs(x-y))
to8b = lambda x: (np.clip(x.detach().cpu().numpy(), 0, 1) * 255).astype(np.uint8)
eikonal = lambda x: torch.mean(x**2)


class NeRFRenderer:
    def __init__(self, opt, nerf, nerf_fine=None, projection='perspective', vgg_loss=None, threshold=0.5):
        self.opt = opt
        self.nerf = nerf
        self.nerf_fine = nerf_fine
        self.use_fine = self.nerf_fine is not None
        self.width = self.opt.loadSize
        self.height = self.opt.loadSize
        self.N_samples = opt.N_samples
        self.num_views = opt.num_views
        self.projection_mode = projection
        self.projection = orthogonal if projection == 'orthogonal' else perspective
        self.N_rand = opt.N_rand
        self.N_grid = opt.N_grid+1
        self.chunk = opt.chunk
        self.N_rand_infer = opt.N_rand_infer
        self.mse_loss = nn.MSELoss()
        self.alpha_loss = nn.BCELoss()
        self.alpha_loss_bmse = bmse
        self.alpha_grad_loss = eikonal

        self.mesh_searcher = MeshGridSearcher()
        self.use_nml = opt.use_nml
        self.use_attention = opt.use_attention
        self.threshold = threshold
        self.debug= opt.debug

        self.rgb_ch = 6 if self.use_attention else 3
        if self.debug: self.rgb_ch += self.num_views * 3
        self.debug_idx = 0

        self.use_vgg = opt.use_vgg
        self.vgg_loss = vgg_loss
        self.use_smpl_sdf = opt.use_smpl_sdf
        self.use_t_pose = opt.use_t_pose
        self.use_smpl_depth = opt.use_smpl_depth
        self.sel_cords = None
        self.regularization = opt.regularization
        self.angle_diff = opt.angle_diff
        self.use_occlusion = opt.use_occlusion and self.use_smpl_depth
        self.use_occlusion_net = opt.use_occlusion_net

        self.gamma = 1
        self.pts_nml = None
        self.alpha_grad = None
        self.alpha_gt = None
        self.alpha_smpl = None
        self.alpha = None
        self.omega_reg = 0.01

        self.nerf_out_ch = 8 if self.use_attention else 4
        self.use_vh = opt.use_vh
        self.vh_overhead = opt.vh_overhead if self.use_vh else 1
        self.use_vh_free = opt.use_vh_free
        self.use_white_bkgd = opt.use_white_bkgd
        self.default_rgb = torch.zeros if not self.use_white_bkgd else torch.ones
        self.occ = None
        self.occ_gt = None

    def cal_loss(self, rgb, rgb_gt):
        loss = {'nerf': self.mse_loss(rgb[:,:3], rgb_gt)}
        if self.use_attention:
            loss.update({'att': self.mse_loss(rgb[:,3:6], rgb_gt)})
            # loss.update({'cmb': self.mse_loss(rgb[:,6:9], rgb_gt)})
        if self.alpha_gt is not None and self.alpha is not None:
            # alpha loss has three options, self.alpha_loss (binary cross entropy), self.mse_loss (mean square error)
            # and self.alpha_loss_bmse (one sided mean square error), we find the last performs the best
            loss.update({'alpha': self.mse_loss(self.alpha, self.alpha_gt)})
        if self.regularization:
            loss.update({'alpha_reg': self.alpha_grad_loss(self.alpha_grad / self.nerf.spatial_freq) * self.omega_reg})
        if self.angle_diff and self.angle_diff_grad is not None:
            loss.update({'angle_diff': torch.mean(self.angle_diff_grad**2)})
        if self.use_occlusion_net and self.occ is not None and self.occ_gt is not None:
            loss.update({'occ': self.mse_loss(self.occ, self.occ_gt)*0.1})
        return loss

    def get_rays_orthogonal(self, bbox, calib):
        top, bottom, left, right = bbox
        cy, cx, focal = self.height/2, self.width/2, self.height/2
        radian = ((right - left) / 2 + 1) / focal
        i, j = torch.meshgrid(torch.linspace(top, bottom-1, int(bottom-top), device=calib.device), 
                                torch.linspace(left, right-1, int(right-left), device=calib.device))  # pytorch's meshgrid has indexing='ij'

        x = (j - cx) / focal
        y = (i - cy) / focal
        z = torch.sqrt(radian**2 - x**2)
        # z = torch.ones_like(x)
        starts = torch.stack([x, y, z], -1)
        ends = torch.stack([x, y, -z], -1)
        calib = torch.inverse(calib)
        R, t = calib[:3,:3], calib[:3, 3]

        rays_s = torch.sum(starts[..., None, :] * R, -1) + t
        rays_e = torch.sum(ends[..., None, :] * R, -1) + t

        return rays_s, rays_e
        
    def get_rays_perspective(self, bbox, w2c, cam):
        """
        bbox: bounding box [top, bottom, left, right]
        w2c: 4x4 rotation matrix
        cam: perspective camera parameters [fx, fy, cx, cy, (if distortion), near, far]
        """
        top, bottom, left, right = bbox
        near, far = cam[-2], cam[-1]
        top, bottom, left, right = int(top),int(bottom),int(left),int(right)
        i, j = torch.meshgrid(torch.linspace(top, bottom-1, int(bottom-top), device=w2c.device), 
                              torch.linspace(left, right-1, int(right-left), device=w2c.device))
        x = (j - cam[2]) / cam[0]
        y = (i - cam[3]) / cam[1]
        if len(cam) > 6:
            xp, yp = x, y
            for _ in range(3): # iter to undistort
                x2 = x*x
                y2 = y*y
                xy = x*y
                r2 = x2 + y2
                c = (1 + r2*(cam[4]+r2*(cam[5]+r2*cam[8])))
                x = (xp - cam[6]*2*xy - cam[7]*(r2+2*x2)) / (c+1e-9)
                y = (yp - cam[7]*2*xy - cam[6]*(r2+2*y2)) / (c+1e-9)
        z = torch.ones_like(x)
        starts = torch.stack([x*near,y*near,z*near],-1)
        ends   = torch.stack([x*far, y*far, z*far], -1)
        c2w = torch.inverse(w2c)
        R, t = c2w[:3,:3], c2w[:3, 3]

        rays_s = torch.sum(starts[..., None, :]* R, -1) + t
        rays_e = torch.sum(ends[..., None, :]  * R, -1) + t
        # rs = rays_s.cpu().numpy().reshape(-1,3)
        # re = rays_e.cpu().numpy().reshape(-1,3)
        return rays_s, rays_e
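
    # Note on the loop above: undistortion is done as a fixed-point iteration.
    # Starting from the distorted coords (xp, yp), each pass solves
    # x = (xp - tangential_terms) / radial_factor so that re-applying the
    # distortion model to (x, y) reproduces (xp, yp); three iterations are
    # usually sufficient at moderate lens distortion magnitudes.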

    def make_att_input(self, pts, viewdirs, calibs, smpl):
        """
        Prepare input for the multiview attention based SSOAB (screen-space occlusion-aware appearance blending)
        """
        if self.projection_mode == 'perspective':
            cam_c = torch.inverse(calibs)[:,:3,3]
            attdirs = cam_c[None,:,:].expand(pts.shape[0], -1, -1) - pts[:,None,:].expand(-1,self.num_views,-1)
            if smpl is not None:
                viewdirs = viewdirs @ smpl['rot'][0]
                attdirs = (attdirs.view(-1,3) @ smpl['rot'][0]).view(attdirs.shape)
            attdirs = torch.cat([viewdirs[:,None,:], attdirs], dim=1)
            attdirs = attdirs / torch.clamp(torch.norm(attdirs, dim=-1, keepdim=True),min=1e-9)
            if self.angle_diff:
                viewdirs = viewdirs / torch.clamp(torch.norm(viewdirs, dim=-1, keepdim=True),min=1e-9)
                attdirs = torch.sum(attdirs * viewdirs.unsqueeze(1), dim=-1, keepdim=True)
        else:
            ## c2w @ [0,0,1] is equivalent to c2w[:3, 2], i.e. back-tracing the attention direction
            attdirs = torch.inverse(calibs)[:, :3, 2]                                   # [num_views, 3]
            attdirs = attdirs[None,...].repeat([pts.shape[0], 1, 1])
            if smpl is not None:
                viewdirs = viewdirs @ smpl['rot'][0]
                attdirs = attdirs @ smpl['rot'][0]
            attdirs = torch.cat([viewdirs, attdirs], dim=0)                            # [(num_views+1), 3]
            attdirs = attdirs / torch.norm(attdirs, dim=-1, keepdim=True)

        return attdirs

    def make_nerf_input(self, pts, feats, images, smpl, calibs, mesh_param, persps=None, is_train=True):
        """
        Aggregate Geometric Body Shape Embedding for NeRF input
        """
        nerf_input, source_rgb = [], None

        # Convert query point to normalized body coordinate (normalized scale and body orientation)
        center, body_scale = mesh_param['center'], mesh_param['body_scale']
        if self.use_nml:
            # points normalized to volume [-1,1]^3
            self.pts_nml = ((pts - center) / body_scale).requires_grad_()
            if self.use_smpl_sdf: self.pts_nml = self.pts_nml @ smpl['rot'][0]  # rotate to smpl volume, with smpl root node facing front
            nerf_input.append(self.pts_nml)
        else:
            nerf_input.append(pts)

        # Body shape embedding
        if self.use_smpl_sdf or self.use_t_pose:
            self.mesh_searcher.set_mesh(smpl['verts'], smpl['faces'])
            closest_pts, closest_idx = self.mesh_searcher.nearest_points(pts)

            if self.use_t_pose:
                closest_faces = smpl['faces'][closest_idx.long()]
                t_pose_verts = smpl['t_verts'][closest_faces.long()]
                t_pose_coords = t_pose_verts.mean(dim=1)
                # T-pose correspondence
                nerf_input.append(t_pose_coords)

            if self.use_smpl_sdf:
                reg_vecs = pts - closest_pts
                if self.use_nml:
                    reg_vecs = reg_vecs / body_scale     # normalized to volume [-1,1]^3
                    reg_vecs = reg_vecs @ smpl['rot'][0]                       # rotate to smpl volume, with smpl root node facing front
                signs = self.mesh_searcher.inside_mesh(pts)
                self.alpha_smpl = (signs + 1) / 2
                norm = torch.norm(reg_vecs, dim=1, keepdim=True) + 1e-8
                sdf = norm * signs[...,None]
                # Normalized SDF Gradient
                nerf_input.append(reg_vecs / norm)
                # SDF (scaled by a constant for faster convergence)
                nerf_input.append(torch.tanh(sdf*20))

        # Multiview image feature
        if feats is not None:
            xyz = self.projection(pts.permute((1,0))[None,...].expand([calibs.shape[0],-1,-1]), calibs, persps)
            xy = xyz[:, :2, :]                  # [self.num_views, 2, self.N_samples]
            if persps is not None:
                xy = xy / torch.tensor([[[self.width],[self.height]]], \
                    dtype = xyz.dtype, device = xyz.device) * 2 - 1
            latent = index(feats, xy)           # [self.num_views, C, self.N_samples(*2)]
            latent = latent.permute((2,0,1))    # [self.N_samples(*2), self.num_views, C]
            source_rgb = index(images[:self.num_views], xy)
            source_rgb = source_rgb.permute((2,0,1))
            latent = torch.cat([latent, source_rgb], -1)

            nerf_input = [inp[:,None,:].expand([-1,self.num_views,-1]) for inp in nerf_input] # expand each feature to num_views
            nerf_input += [latent]

        nerf_input = torch.cat(nerf_input, dim=-1) # [self.N_samples, self.num_views, C]
        return nerf_input, source_rgb

    def make_nerf_output(self, nerf_output, t_vals, norm, source_rgb, is_train=True):
        """
        Renders ray by integrating sample points
        """
        dists = t_vals[...,1:] - t_vals[...,:-1]
        dists = torch.cat([dists, torch.tensor([1e10], device=dists.device).expand(dists[...,:1].shape)], -1)  # [N_rays, N_samples]
        dists = dists * norm
        N_samples = t_vals.shape[-1]

        rgb = torch.sigmoid(nerf_output[...,:3])  # [N_rays, N_samples, 3]
        noise = torch.randn(nerf_output[...,3].shape, device=nerf_output.device) if is_train else 0
        alpha = 1.-torch.exp(-F.relu((nerf_output[...,3]+noise)))  # [N_rays, N_samples]
        weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1), device=nerf_output.device), 1.-alpha + 1e-10], -1), -1)[:, :-1]
        rgb_map = torch.sum(weights[...,None] * rgb, -2)  # [N_rays, 3]
        if self.use_attention:
            att = nerf_output[...,4:]
            source_rgb = source_rgb.reshape([-1, N_samples, self.num_views, 3])
            source_rgb = torch.cat([rgb.unsqueeze(-2), source_rgb], dim=-2)
            source_rgb_att = torch.sum(source_rgb * att[...,None], dim=-2)
            att_rgb_map = torch.sum(weights[...,None] * source_rgb_att, -2)  # [N_rays, 3]
            rgb_map = torch.cat([rgb_map, att_rgb_map], -1)
            if self.debug:
                for i in range(self.num_views):
                    source_rgb_i = source_rgb[:,:,i,:] * att[:, :, i, None]
                    rgb_map_i = torch.sum(weights[...,None] * source_rgb_i, -2)
                    rgb_map = torch.cat([rgb_map, rgb_map_i], -1)
            
        acc_map = torch.sum(weights, -1)
        if self.use_white_bkgd:
            rgb_map = rgb_map + (1.-acc_map[...,None])

        return rgb_map, weights
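
    # Compositing note: with alpha_i = 1 - exp(-relu(sigma_i)) per sample, the
    # quadrature weights above are w_i = alpha_i * prod_{j<i}(1 - alpha_j) and
    # rgb_map = sum_i w_i * c_i; the attention branch additionally blends the
    # reprojected source-view colors with the predicted color under the same
    # weights.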

    def render_rays(self, ray_batch, feats, images, masks, calibs, smpl, mesh_param, 
                    scan=None, persps=None, q_persps=None, is_train=True):
        """Volumetric rendering.
        Args:
        ray_batch: array of shape [batch_size, ...]. All information necessary
            for sampling along a ray, including: ray origin, ray direction, min
            dist, max dist, and unit-magnitude viewing direction.
        """
        self.alpha, self.angle_diff_grad = None, None
        eps = 1e-9
        N_rays = ray_batch.shape[0]
        rays_s, rays_e = ray_batch[:,0:3], ray_batch[:,3:6] # [N_rays, 3] each

        t_vals = torch.linspace(0., 1., steps=self.N_samples, device=ray_batch.device)
        t_vals = t_vals.repeat([N_rays, 1])
        # perturb during training
        if is_train:
            t_rand = (torch.rand(t_vals.shape, device=ray_batch.device) - 0.5) / (self.N_samples-1)
            t_vals = t_vals + t_rand

        pts = rays_e[:,None,:] * t_vals[...,None] + (1-t_vals[...,None]) * rays_s[:,None,:]
        pts = pts.reshape(-1, 3)

        # Use visual hull to skip sample points outside the body
        inside, smpl_vis, scan_vis = None, None, None
        if self.use_vh:
            inside, smpl_vis, scan_vis = self.inside_pts_vh(pts, masks, smpl, calibs, persps)
            pts = pts[inside]
            if len(pts) == 0:
                return self.default_rgb([N_rays, self.rgb_ch], dtype=torch.float32, device=ray_batch.device), \
                       torch.zeros([N_rays], dtype=torch.float32, device=ray_batch.device)

        # When train RenderPeople with scan ground truth, prepare 3D supervision
        if is_train and scan is not None:
            scan_verts, scan_faces = scan
            self.mesh_searcher.set_mesh(scan_verts, scan_faces)
            self.alpha_gt = (self.mesh_searcher.inside_mesh(pts) + 1) / 2

        # Prepare attention based appearance blending input
        viewdirs = (rays_s - rays_e)[:,None,:].expand(-1,self.N_samples,-1)
        viewdirs = viewdirs.reshape(-1,3)[inside].requires_grad_()
        attdirs = self.make_att_input(pts, viewdirs, calibs, smpl) if self.use_attention else []

        # Prepare geometry body shape embedding input for NeRF
        nerf_input, source_rgb = self.make_nerf_input(pts, feats, images, smpl, calibs, mesh_param, persps)
        
        # Feed to the network
        nerf_output = torch.cat([self.nerf(nerf_input[i:i+self.chunk], attdirs[i:i+self.chunk], smpl_vis=smpl_vis) \
                                 for i in range(0, nerf_input.shape[0], self.chunk)], 0)
        self.alpha = torch.sigmoid(nerf_output[...,3]*self.gamma)

        # If RenderPeople available, supervise the occlusion
        if self.use_occlusion_net:
            if is_train and scan is not None:
                self.occ_gt = scan_vis.float()
                self.occ = nerf_output[:, -self.num_views:]
            nerf_output = nerf_output[:, :-self.num_views]

        # Regularize the alpha distribution
        if self.regularization and is_train:
            self.alpha_grad = torch.autograd.grad(self.alpha, self.pts_nml, grad_outputs=torch.ones_like(self.alpha), retain_graph=True)[0]

        # use sparse multiplication to scatter inside-hull predictions back to the full sample set for NeRF integration
        if self.use_vh:
            inside_idx = torch.nonzero(inside)
            row_cols = torch.cat([inside_idx.view(1,-1), torch.arange(len(inside_idx), device=pts.device).view(1,-1)], 0)
            I = torch.sparse_coo_tensor(row_cols, torch.ones(len(inside_idx),dtype=pts.dtype, device=pts.device), size=(N_rays*self.N_samples, len(inside_idx)))
            nerf_output = torch.sparse.mm(I, nerf_output)
            nerf_output[~inside, :4] = -1e4
            full_source_rgb = torch.zeros([N_rays*self.N_samples, self.num_views, 3], device=pts.device)
            full_source_rgb[inside] = source_rgb
            source_rgb = full_source_rgb
        nerf_output = nerf_output.view(N_rays, self.N_samples, -1)
        
        norm = torch.norm(rays_e - rays_s, dim=-1, keepdim=True)
        if self.use_nml:
           center, body_scale = mesh_param['center'], mesh_param['body_scale']
           norm = norm / body_scale
        rgb_map, weights = self.make_nerf_output(nerf_output, t_vals, norm, source_rgb, is_train=is_train)
        z_vals = t_vals * q_persps[-2]  + (1-t_vals) * q_persps[-1] if persps is not None and q_persps is not None else 2*t_vals - 1
        depth = torch.sum(weights * z_vals, -1)

        # Regularize the angle difference of appearance
        if self.angle_diff and is_train:
            self.angle_diff_grad = torch.autograd.grad(rgb_map, viewdirs, grad_outputs=torch.ones_like(rgb_map), retain_graph=True)[0]

        return rgb_map, depth
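
    # [Illustrative sketch added by the editor, not part of the repository]
    # `render_rays` delegates the actual volume integration to `make_nerf_output`
    # (not shown here); for orientation, generic NeRF-style compositing with
    # hypothetical names/shapes would form the per-sample weights as:
    #
    #     trans = torch.cumprod(torch.cat([torch.ones_like(alpha[:, :1]),
    #                                      1. - alpha + 1e-10], -1), -1)[:, :-1]
    #     weights = alpha * trans                            # [N_rays, N_samples]
    #     rgb_map = torch.sum(weights[..., None] * rgb, -2)  # [N_rays, 3]
    #     depth   = torch.sum(weights * z_vals, -1)          # [N_rays]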

    def inside_pts_vh(self, pts, masks, smpl, calibs, persps=None):
        """
        Valid sample point selection via visual hull
        """
        xyz = self.projection(pts.permute((1,0))[None,...].expand([calibs.shape[0],-1,-1]), calibs, persps)
        xy = xyz[:, :2, :]
        if persps is not None:
            xy = xy / torch.tensor([[[self.width],[self.height]]], \
                dtype = xyz.dtype, device = xyz.device) * 2 - 1
        inside = index(masks, xy, 'nearest')
        inside = torch.prod(inside, dim=0).squeeze(0) > 0
        if (inside.sum() < self.chunk * 0.7) and self.use_vh_free:
            n_samples = int(inside.sum().item() * 0.3)  # slice bound must be an int, not a tensor
            idx = torch.randperm(len(inside))[:n_samples]
            inside[idx] = True
        smpl_vis, scan_vis = None, None
        if self.use_occlusion:
            smpl_depth = index(smpl['depth'], xy, 'nearest').squeeze(1).permute((1,0))
            smpl_depth = smpl_depth[inside]
            depth = xyz[:,2,:].permute((1,0))
            depth = depth[inside]
            smpl_vis = (((depth - smpl_depth) <= 0) * (smpl_depth > 0) + (smpl_depth == 0)) > 0
        if self.use_occlusion_net and 'scan_depth' in smpl.keys():
            scan_depth = index(smpl['scan_depth'], xy, 'nearest').squeeze(1).permute((1,0))[inside]
            depth = xyz[:,2,:].permute((1,0))[inside]
            scan_vis = (((depth - scan_depth) <= 0) * (scan_depth > 0) + (scan_depth == 0)) > 0
        return inside, smpl_vis, scan_vis
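
    # [Illustrative sketch added by the editor, not part of the repository]
    # The visual-hull test above keeps a point only if its projection lands in
    # the foreground mask of *every* source view; the product over the view
    # dimension acts as a logical AND. Conceptually (hypothetical helper name):
    #
    #     xy = project_to_pixels(pts, calibs, persps)   # [n_views, 2, N_pts]
    #     hits = index(masks, xy, 'nearest')            # [n_views, 1, N_pts]
    #     inside = hits.prod(dim=0).squeeze(0) > 0      # [N_pts] bool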


    def render(self, feats, images, masks, calibs, bbox, mesh_param, smpl=None, scan=None, persps=None):
        """
        Render an image from a given camera pose
        """
        self.debug_idx += 1
        if persps is None:
            rays_s, rays_e = self.get_rays(bbox, calibs[-1])  # (H, W, 3), (H, W, 3)
        else:
            rays_s, rays_e = self.get_rays_perspective(bbox, calibs[-1], persps[-1])

        top, bottom, left, right = bbox
        gt = images[-1].permute((1,2,0))[top:bottom, left:right]
        coords = torch.stack(torch.meshgrid(torch.linspace(0, bottom-top-1, int(bottom-top), device=calibs.device), 
                                            torch.linspace(0, right-left-1, int(right-left), device=calibs.device)), -1)  # (H, W, 2)

        coords = coords.view(-1,2)  # (H * W, 2)
        select_inds = np.random.choice(coords.shape[0], size=[self.N_rand*self.vh_overhead], replace=False)  # (N_rand,)
        select_coords = coords[select_inds].long()  # (N_rand, 2)
        rays_s = rays_s[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)
        rays_e = rays_e[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)
        batch_rays = torch.cat([rays_s, rays_e], 1)
        target = gt[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)
        
        persps = persps[:self.num_views] if persps is not None else None
        rgb, _ = self.render_rays(batch_rays, feats, images[:self.num_views], masks[:self.num_views], 
                               calibs[:self.num_views], smpl, mesh_param, scan, persps)
        loss = self.cal_loss(rgb, target)

        return loss

    def render_path(self, feats, images, masks, calibs, bbox, mesh_param, smpl=None, scan=None, persps=None):
        """
        Render a path along the given camera trajectory
        """
        top, bottom, left, right = bbox
        height, width = max(self.height, bottom-top), max(self.width, right-left)
        calibs_source, calibs_query = calibs[:self.num_views], calibs[self.num_views:]
        persps_source = persps[:self.num_views] if persps is not None else None
        persps_query = persps[self.num_views:] if persps is not None else None

        rgbs, depths = [], []
        # inference
        for idx, calib in enumerate(tqdm(calibs_query)):
            if persps is None:
                rays_s, rays_e = self.get_rays(bbox, calib)  # (H, W, 3), (H, W, 3)
            else:
                persp = persps[self.num_views+idx]
                rays_s, rays_e = self.get_rays_perspective(bbox, calib, persp)
            batch_rays = torch.cat([rays_s.view(-1,3), rays_e.view(-1,3)], 1)
            # self.idx = 0
            rgb, depth = [], []
            for i in range(0, batch_rays.shape[0], self.N_rand_infer):
                c, d = self.render_rays(batch_rays[i:i+self.N_rand_infer].detach(), feats, images[:self.num_views], \
                             masks[:self.num_views], calibs_source, smpl, mesh_param, persps=persps_source, q_persps=persps_query[idx], is_train=False)
                rgb.append(c[:, :self.rgb_ch])
                depth.append(d)
            rgb = torch.cat(rgb, 0).clone()
            depth = torch.cat(depth, 0).clone()
            img = self.default_rgb((height, width, self.rgb_ch), dtype=torch.float32, device=rgb.device)
            dimg = torch.zeros((height, width), dtype=torch.float32, device=rgb.device)
            img[top:bottom, left:right] = rgb.view(int(bottom-top), int(right-left), self.rgb_ch)
            dimg[top:bottom, left:right] = depth.view(int(bottom-top), int(right-left))
            rgbs.append(img)
            depths.append(dimg)
        rgbs = torch.stack(rgbs, dim=0)
        depths = torch.stack(depths, dim=0)

        return rgbs, depths

    def train_shape(self, feats, images, masks, calibs, bbox, mesh_param, smpl=None, scan=None, persps=None):
        """
        Unused
        """
        pts, alpha_gt = mesh_param['samples'].transpose(1,0), mesh_param['labels']
        persps = persps[:self.num_views] if persps is not None else None
        nerf_input, _ = self.make_nerf_input(pts, feats, images[:self.num_views], smpl, calibs[:self.num_views], mesh_param, persps=persps)
        nerf_output = torch.cat([self.nerf(nerf_input[i:i+self.chunk], alpha_only=True) \
                                    for i in range(0, pts.shape[0], self.chunk)], 0)
        alpha = torch.sigmoid(nerf_output*self.gamma).reshape(alpha_gt.shape)
        
        # loss = self.alpha_loss(alpha, alpha_gt[...,None].detach())
        loss = {'alpha': self.rgb_loss(alpha, alpha_gt)}
        if self.regularization:
            alpha_grad = torch.autograd.grad(alpha, self.pts_nml, grad_outputs=torch.ones_like(alpha), retain_graph=True)[0]
            loss.update({'alpha_reg': self.alpha_grad_loss(alpha_grad) * self.omega_reg})
        return loss

    def reconstruct(self, feats, images, masks, calibs, bbox, mesh_param, smpl=None, scan=None, persps=None):
        """
        Mesh reconstruction borrowed from PIFu
        """
        # Determine the 3D bounding box
        center, body_scale = mesh_param['center'].cpu().numpy(), mesh_param['body_scale']
        top, bottom, left, right = bbox
        left, right = 0, 512
        bb_min = [left-self.width/2, top-self.height/2, left-self.width/2]
        bb_max = [right-self.width/2, bottom-self.height/2, right-self.width/2]

        # Make a mesh grid in normalized body coordinates
        linspaces = [np.linspace(bb_min[i], bb_max[i], self.N_grid) for i in range(len(bb_min))]
        grids = np.stack(np.meshgrid(linspaces[0], linspaces[1], linspaces[2], indexing='ij'), -1)
        sh = grids.shape
        pts = grids / (self.width/2) * body_scale + center
        recon_kwargs = {
            'feats': feats, 'images': images, 'smpl': smpl, 'calibs': calibs[:self.num_views], 
            'mesh_param': mesh_param, 'persps': persps[:self.num_views] if persps is not None else None
        }

        # Reconstruct using progressive octree reconstruction
        sdf = self.octree_reconstruct(pts, masks, **recon_kwargs)
        verts, faces, normals, _ = measure.marching_cubes_lewiner(sdf, self.threshold)

        # Convert marching cubes coordinate back to world coordinate
        verts = (verts - self.N_grid/2) / self.N_grid * np.array([[right-left, bottom-top, right-left]])
        verts = verts / (self.width/2) * body_scale + center
        
        # use Laplacian smoothing if the mesh is noisy
        if self.opt.laplacian > 0:
            mesh = trimesh.Trimesh(verts, faces, process=False)
            trimesh.smoothing.filter_laplacian(mesh, iterations=self.opt.laplacian)
            verts, faces = mesh.vertices, mesh.faces
        pts = torch.tensor(verts, dtype=torch.float32, device=calibs.device)

        viewdirs = torch.from_numpy(normals.astype(np.float32)).to(calibs.device)
        attdirs = self.make_att_input(pts, viewdirs, calibs[:self.num_views], smpl) if self.use_attention else []

        rgbs = []
        for i in range(0, pts.shape[0], self.chunk):
            nerf_input, source_rgb = self.make_nerf_input(pts[i:i+self.chunk], **recon_kwargs)
            nerf_output = self.nerf(nerf_input, attdirs[i:i+self.chunk])
            rgb = torch.sigmoid(nerf_output[...,:3])
            if self.use_attention:
                att = nerf_output[...,4:4+self.num_views+1]
                source_rgb = source_rgb.view(-1, self.num_views, 3)
                source_rgb = torch.cat([rgb[:,None], source_rgb], dim=-2)
                rgb = torch.sum(source_rgb * att[...,None], dim=-2)
            rgbs.append(rgb)
        rgbs = torch.cat(rgbs, 0).cpu().numpy()

        return verts, faces, rgbs
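
    # [Editor's compatibility note] `measure.marching_cubes_lewiner`, used in
    # `reconstruct` above, was deprecated in scikit-image 0.17 and removed in
    # later releases; on newer versions the equivalent call is:
    #
    #     verts, faces, normals, _ = measure.marching_cubes(sdf, level=self.threshold)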
            
    def octree_reconstruct(self, coords, masks, **kwargs):
        """
        We use octree reconstruction for higher-resolution output, borrowed from PIFu
        """
        device = kwargs['calibs'].device
        calibs = kwargs['calibs']
        persps = kwargs['persps']
        resolution = [self.N_grid, self.N_grid, self.N_grid]
        sdf = np.zeros(resolution)
        notprocessed = np.zeros(resolution, dtype=bool)  # np.bool alias removed in NumPy 1.24
        notprocessed[:-1,:-1,:-1] = True
        # only voxel grid points that lie inside the visual hull are processed
        if self.use_vh:
            dilation_kernel = torch.ones((1,1,5,5), device=device, dtype=torch.float32)
            masks = torch.clamp(torch.nn.functional.conv2d(masks, dilation_kernel, padding=(2, 2)), 0, 1)
            masks_np = masks.permute([0,2,3,1]).cpu().numpy()
            pts = coords.reshape(-1, 3)
            notprocessed = notprocessed.reshape(-1)
            for i in range(0, pts.shape[0], self.chunk):
                inside, _, _ = self.inside_pts_vh(torch.tensor(pts[i:i+self.chunk], dtype=torch.float32, device=device), 
                                            masks, kwargs['smpl'], calibs, persps)
                inside = inside.cpu().numpy()
                outside = np.logical_not(inside.astype(bool))
                notprocessed_chunk = notprocessed[i:i+self.chunk].copy()
                notprocessed_chunk[outside] = False
                notprocessed[i:i+self.chunk] = notprocessed_chunk
            notprocessed = notprocessed.reshape(resolution)

        grid_mask = np.zeros(resolution, dtype=bool)
        reso = self.N_grid // 64

        center = kwargs['mesh_param']['center'].cpu().numpy()
        while reso > 0:
            grid_mask[0:self.N_grid:reso, 0:self.N_grid:reso, 0:self.N_grid:reso] = True
            test_mask = np.logical_and(grid_mask, notprocessed)
            pts = coords[test_mask, :]
            if pts.shape[0] == 0: 
                print("break")
                break

            pts_tensor = torch.tensor(pts, dtype=torch.float32, device=device)
            nerf_output = []
            for i in range(0, pts_tensor.shape[0], self.chunk):
                nerf_input, _ = self.make_nerf_input(pts_tensor[i:i+self.chunk], **kwargs)
                nerf_output.append(self.nerf(nerf_input, alpha_only=True))
            nerf_output = torch.cat(nerf_output, dim=0)
            sdf[test_mask] = torch.sigmoid(nerf_output*self.gamma).detach().cpu().numpy().reshape(-1)

            notprocessed[test_mask] = False

            # do interpolation
            if reso <= 1:
                break
            grid = np.arange(0, self.N_grid, reso)
            v = sdf[tuple(np.meshgrid(grid, grid, grid, indexing='ij'))]
            vs = [v[:-1,:-1,:-1], v[:-1,:-1,1:], v[:-1,1:,:-1], v[:-1,1:,1:], 
                    v[1:,:-1,:-1], v[1:,:-1,1:], v[1:,1:,:-1], v[1:,1:,1:]]
            grid = grid[:-1] + reso//2
            nonprocessed_grid = notprocessed[tuple(np.meshgrid(grid, grid, grid, indexing='ij'))]

            v = np.stack(vs, 0)
            v_min = v.min(0)
            v_max = v.max(0)
            v = 0.5*(v_min+v_max)
            skip_grid = np.logical_and(((v_max - v_min) < 0.01), nonprocessed_grid)
            xs, ys, zs = np.where(skip_grid)
            for x, y, z in zip(xs*reso, ys*reso, zs*reso):
                sdf[x:(x+reso+1), y:(y+reso+1), z:(z+reso+1)] = v[x//reso,y//reso,z//reso]
                notprocessed[x:(x+reso+1), y:(y+reso+1), z:(z+reso+1)] = False
            reso //= 2

        return sdf.reshape(resolution)
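
    # [Illustrative sketch added by the editor, not part of the repository]
    # The loop above is a coarse-to-fine octree evaluation; stripped to its
    # skeleton it reads:
    #
    #     reso = N_grid // 64
    #     while reso > 0:
    #         # query the network only at still-unprocessed grid points of stride `reso`
    #         # for every coarse cell whose 8 corner values nearly agree
    #         # (v_max - v_min < 0.01), fill the cell with the midpoint of the
    #         # min and max corner values and mark it processed
    #         reso //= 2
    #
    # Dense network queries are therefore spent only near the surface, where
    # neighboring corner values disagree.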


    def get_plucker_line(self, contour, w2c, cam, mesh_param):
        """
        Unused
        """
        x = (contour[:,0].float() - cam[2]) / cam[0]
        y = (contour[:,1].float() - cam[3]) / cam[1]
        near, far = cam[-2], cam[-1]
        center, body_scale = mesh_param['center'], mesh_param['body_scale']
        if len(cam) > 6:
            xp, yp = x, y
            for _ in range(3): # iterate to undistort
                x2 = x*x
                y2 = y*y
                xy = x*y
                r2 = x2 + y2
                c = (1 + r2*(cam[4]+r2*(cam[5]+r2*cam[8])))
                x = (xp - cam[6]*2*xy - cam[7]*(r2+2*x2)) / (c+1e-9)
                y = (yp - cam[7]*2*xy - cam[6]*(r2+2*y2)) / (c+1e-9)
        z = torch.ones_like(x)
        starts = torch.stack([x*near,y*near,z*near],-1)
        ends   = torch.stack([x*far, y*far, z*far], -1)
        c2w = torch.inverse(w2c)
        R, t = c2w[:3,:3], c2w[:3, 3]

        rays_s = torch.sum(starts[..., None, :]* R, -1) + t
        rays_e = torch.sum(ends[..., None, :]  * R, -1) + t
        rays_d = rays_e - rays_s
        rays_d = rays_d / (torch.norm(rays_d, dim=-1, keepdim=True) + 1e-8)
        if self.use_nml:
            rays_s = (rays_s - center) / body_scale
        return rays_d, torch.cross(rays_s, rays_d, dim=-1)
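
    # [Illustrative sketch added by the editor, not part of the repository]
    # The return value is the Plucker form of the ray: direction d and moment
    # m = o x d, which satisfy d . m = 0. A quick standalone check:
    #
    #     import numpy as np
    #     o, d = np.random.randn(3), np.random.randn(3)
    #     d /= np.linalg.norm(d)
    #     m = np.cross(o, d)
    #     assert abs(np.dot(d, m)) < 1e-6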

    def warp_mat(self, face_verts):
        """
        Unused
        """
        # [[x0, x1, x2, nx/sqrt(size)]
        #  [y0, y1, y2, ny/sqrt(size)]
        #  [z0, z1, z2, nz/sqrt(size)]
        #  [1,  1,  1,  0            ]]
        a = face_verts[:,1,:] - face_verts[:,0,:]
        b = face_verts[:,2,:] - face_verts[:,0,:]
        normal = torch.cross(a, b, dim=-1)
        # the cross product's norm equals twice the area of the triangle
        size = torch.norm(normal/2, dim=-1, keepdim=True)
        normal = normal / torch.sqrt(size + 1e-10)
        mat = torch.cat([face_verts.permute([0,2,1]), normal[...,None]], dim=-1)       # [N_samples, 3, 4]
        column = torch.tensor([[[1,1,1,0]]], dtype=torch.float32, device=face_verts.device).repeat([mat.shape[0],1,1])
        mat = torch.cat([mat, column], dim=1)

        return mat

    def canonical_warpping(self, verts, verts_can, pts):
        """
        Unused
        """
        warp = self.warp_mat(verts)
        warp_can = self.warp_mat(verts_can)

        pts_homo = torch.cat([pts, torch.ones([pts.shape[0], 1], dtype=torch.float32, device=pts.device)], dim=-1)
        pts_can = warp_can @ torch.inverse(warp) @ pts_homo[...,None]
        
        return pts_can[:,:3,0]

    def multiview_consistency(self, pts, feats, calibs, depth, persps=None):
        """
        Unused
        """
        xyz = self.projection(pts.permute((1,0))[None,...].expand([calibs.shape[0],-1,-1]), calibs, persps)
        xy = xyz[:, :2, :]                  # [self.num_views, 2, self.N_samples]
        z = xyz[:, 2:, :].permute((2,0,1))
        if persps is not None:
            xy = xy / torch.tensor([[[self.width],[self.height]]], \
                dtype = xyz.dtype, device = xyz.device) * 2 - 1
        latent = index(feats, xy)           # [self.num_views, C, self.N_samples(*2)]
        latent = latent.permute((2,0,1))    # [self.N_samples(*2), self.num_views, C]

        depth_z = index(depth, xy, mode='nearest')
        depth_z = depth_z.permute((2,0,1))
        
        return depth_z, z


================================================
FILE: lib/model/SRFilters.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F

class SRFilters(nn.Module):
    """
    Upsample the pixel-aligned feature 
    """
    def __init__(self, order=2, in_ch=256, out_ch=128):
        super(SRFilters, self).__init__()
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.image_factor = [0.5**(order-i) for i in range(0, order+1)]
        self.convs = nn.ModuleList([nn.Conv2d(in_ch+3, out_ch, kernel_size=3, padding=1)] +
                    [nn.Conv2d(out_ch+3, out_ch, kernel_size=3, padding=1) for i in range(order)])

    def forward(self, feat, images):
        for i, conv in enumerate(self.convs):
            im = F.interpolate(images, scale_factor=self.image_factor[i], mode='bicubic', align_corners=True) if self.image_factor[i] != 1 else images
            feat = F.interpolate(feat, scale_factor=2, mode='bicubic', align_corners=True) if i != 0 else feat
            feat = torch.cat([feat, im], dim=1)
            feat = conv(feat)
        return feat
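
# [Illustrative sketch added by the editor, not part of the repository]
# A shape smoke test of the module above (hypothetical tensors): with order=2
# the image is injected at 1/4, 1/2 and full resolution, while the feature map
# is bicubically upsampled 2x before each of the last two convolutions.
#
#     feat = torch.randn(1, 256, 128, 128)   # encoder feature at 1/4 resolution
#     imgs = torch.randn(1, 3, 512, 512)     # full-resolution source image
#     out = SRFilters(order=2)(feat, imgs)   # -> [1, 128, 512, 512]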

================================================
FILE: lib/model/__init__.py
================================================
from .GNR import GNR
from .NeRF import NeRF
from .NeRFRenderer import NeRFRenderer
from .Embedder import SphericalHarmonics, PositionalEncoding


================================================
FILE: lib/net_ddp.py
================================================
import torch
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import logging
import os
import random
import numpy as np
import math

def worker_init_fn(worker_id):
    random.seed(worker_id+100)
    np.random.seed(worker_id+100)
    torch.manual_seed(worker_id+100)

def ddp_init(args):
    local_rank = args.local_rank

    # torch.set_default_tensor_type('torch.cuda.FloatTensor') # RuntimeError: Expected a 'N2at13CUDAGeneratorE' but found 'PN2at9GeneratorE'
    dist.init_process_group(backend = 'nccl')  # 'nccl' for GPU, 'gloo/mpi' for CPU

    torch.cuda.set_device(local_rank)

    rank = torch.distributed.get_rank()
    random.seed(rank)
    np.random.seed(rank)
    torch.manual_seed(rank)
    print(f"local_rank {local_rank} rank {rank} launched...")

    return rank, local_rank
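
# [Editor's note] `ddp_init` relies on the process-group environment (RANK,
# WORLD_SIZE, MASTER_ADDR, ...) that torch's distributed launchers set up; a
# typical, hypothetical single-node launch would be:
#
#     python -m torch.distributed.launch --nproc_per_node=8 apps/run_genebody.py ...
#
# where the launcher hands each process its --local_rank, and ddp_init pins the
# process to that GPU before seeding with the global rank.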

def create_network(opt, net, local_rank=0):
    def load_network(opt, net, load=True):
        # init network from ckpts
        start_epoch, global_step = 0, 0
        ckpts = [os.path.join(opt.basedir, opt.name, f) for f in sorted(os.listdir(os.path.join(opt.basedir, opt.name))) if 'tar' in f]
        if len(ckpts) == 0:
            ckpts = [os.path.join(opt.basedir, opt.name, '..', f) for f in sorted(os.listdir(os.path.join(opt.basedir, opt.name, '..'))) if 'tar' in f]
        if len(ckpts) > 0:
            ckpt_path = ckpts[-1]
            logging.info(f'Reloading from {ckpt_path}')
            ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))  # load to cpu, otherwise this will occupy rank=0 's gpu memory
            if load:
                try:    # dp or ddp load itself
                    net.load_state_dict(ckpt['network_state_dict'])
                except:
                    try:    # dp_load_ddp
                        net.load_state_dict({'module.'+k: v for k, v in ckpt['network_state_dict'].items()})
                    except: # ddp_load_dp
                        net.load_state_dict({k[7:]: v for k, v in ckpt['network_state_dict'].items()})
            start_epoch = ckpt['epoch']
        # if no ckpts found, only init the encoder from PIFu
        elif opt.load_netG_checkpoint_path is not None:
            logging.info(f'loading for net G ... {opt.load_netG_checkpoint_path}')
            pretrained_net = torch.load(opt.load_netG_checkpoint_path, map_location=torch.device('cpu'))
            if opt.ddp:
                pretrained_image_filter = {k: v for k, v in pretrained_net.items() if k.startswith('image_filter')}
            else:
                pretrained_image_filter = {'module.'+k: v for k, v in pretrained_net.items() if k.startswith('image_filter')}
            if load:
                net.load_state_dict(pretrained_image_filter, strict=False)
        return start_epoch
    
    # DDP: load parameters first (only on master node), then make ddp model
    if opt.ddp:
        logging.info("use Distributed Data Parallel...")
        net = net.to(local_rank) 
        start_epoch = load_network(opt, net, load=(dist.get_rank() == 0))
        net = DDP(net, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
    # DP: make dp model, then load parameters to all devices 
    else:
        if torch.cuda.is_available():
            logging.info("use Data Parallel...")
            gpu_ids = [i for i in range(torch.cuda.device_count())]
            net = net.to(gpu_ids[0])
            net = torch.nn.DataParallel(net, device_ids=gpu_ids)
            start_epoch = load_network(opt, net)
            
    return net, start_epoch

def synchronize():
    if dist.get_world_size() > 1:
        dist.barrier()
    return

class ddpSampler:
    """
    ddp sampler for inference
    """
    def __init__(self, dataset, rank=None, num_replicas=None):
        if num_replicas is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = torch.distributed.get_world_size()
        if rank is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = torch.distributed.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas

    def indices(self):
        indices = list(range(len(self.dataset)))
        # add extra samples to make it evenly divisible
        indices += [indices[-1]] * (self.total_size - len(indices))
        # subsample
        indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
        return indices

    def len(self):
        return self.num_samples

    def distributed_concat(self, tensor, num_total_examples):
        output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
        torch.distributed.all_gather(output_tensors, tensor)
        concat = torch.cat(output_tensors, dim=0)
        # truncate the dummy elements added by the padding in indices()
        return concat[:num_total_examples]
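
# [Illustrative sketch added by the editor, not part of the repository]
# Intended usage of ddpSampler (hypothetical `run_inference`): each rank
# evaluates its contiguous slice, then results are all-gathered and truncated
# back to the true dataset length.
#
#     sampler = ddpSampler(dataset)
#     subset = torch.utils.data.Subset(dataset, sampler.indices())
#     preds = run_inference(subset)                        # [sampler.len(), ...]
#     preds = sampler.distributed_concat(preds, len(dataset))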

================================================
FILE: lib/net_util.py
================================================
import torch
from torch.nn import init
import torch.nn as nn
import torch.nn.functional as F
import functools

import numpy as np
from .mesh_util import *
from .geometry import index
import cv2
from PIL import Image
from tqdm import tqdm

def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3,
                     stride=strd, padding=padding, bias=bias)

def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights.

    Parameters:
        net (network)   -- network to be initialized
        init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)    -- scaling factor for normal, xavier and orthogonal.

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.
    """

    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find(
                'BatchNorm2d') != -1:  # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    # print('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function <init_func>


def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Return an initialized network.
    """
    if len(gpu_ids) > 0:
        assert (torch.cuda.is_available())
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
    init_weights(net, init_type, init_gain=init_gain)
    return net


def imageSpaceRotation(xy, rot):
    '''
    args:
        xy: (B, 2, N) input
        rot: (B, 2) x,y axis rotation angles

    the rotation center is always the image center (other rotation centers can be represented by an additional z translation)
    '''
    disp = rot.unsqueeze(2).sin().expand_as(xy)
    return (disp * xy).sum(dim=1)


def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
    """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028

    Arguments:
        netD (network)              -- discriminator network
        real_data (tensor array)    -- real images
        fake_data (tensor array)    -- generated images from the generator
        device (str)                -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
        type (str)                  -- if we mix real and fake data or not [real | fake | mixed].
        constant (float)            -- the constant used in the formula (||gradient||_2 - constant)^2
        lambda_gp (float)           -- weight for this loss

    Returns the gradient penalty loss
    """
    if lambda_gp > 0.0:
        if type == 'real':  # either use real images, fake images, or a linear interpolation of the two.
            interpolatesv = real_data
        elif type == 'fake':
            interpolatesv = fake_data
        elif type == 'mixed':
            alpha = torch.rand(real_data.shape[0], 1)
            alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(
                *real_data.shape)
            alpha = alpha.to(device)
            interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
        else:
            raise NotImplementedError('{} not implemented'.format(type))
        interpolatesv.requires_grad_(True)
        disc_interpolates = netD(interpolatesv)
        gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
                                        grad_outputs=torch.ones(disc_interpolates.size()).to(device),
                                        create_graph=True, retain_graph=True, only_inputs=True)
        gradients = gradients[0].view(real_data.size(0), -1)  # flatten the data
        gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp  # added eps
        return gradient_penalty, gradients
    else:
        return 0.0, None
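
# [Illustrative sketch added by the editor, not part of the repository]
# Typical, hypothetical use of the gradient penalty inside a WGAN-GP
# discriminator step:
#
#     gp, _ = cal_gradient_penalty(netD, real, fake.detach(), device, type='mixed')
#     d_loss = netD(fake.detach()).mean() - netD(real).mean() + gp
#     d_loss.backward()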

def get_norm_layer(norm_type='instance'):
    """Return a normalization layer
    Parameters:
        norm_type (str) -- the name of the normalization layer: batch | instance | group | none
    For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
    For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
    """
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    elif norm_type == 'group':
        norm_layer = functools.partial(nn.GroupNorm, 32)
    elif norm_type == 'none':
        norm_layer = None
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer
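
# [Editor's note] `get_norm_layer` returns a constructor (a functools.partial),
# not a layer instance; hypothetical usage:
#
#     norm_layer = get_norm_layer('instance')
#     block = nn.Sequential(nn.Conv2d(64, 64, 3, padding=1), norm_layer(64))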

class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

================================================
SYMBOL INDEX (168 symbols across 20 files)
================================================

FILE: apps/render_smpl_depth.py
  function load_obj_mesh (line 19) | def load_obj_mesh(mesh_file, with_normal=False, with_texture=False, with...
  function extract_float (line 139) | def extract_float(text):
  function natural_sort (line 149) | def natural_sort(files):
  function load_ply (line 155) | def load_ply(file_name):
  function distortPoints (line 230) | def distortPoints(p, dist):
  function rasterize (line 252) | def rasterize(v, tri, size, K = np.identity(3), \
  function render_view (line 297) | def render_view(intri, dists, c2ws, meshes, view, i):
  class Worker (line 328) | class Worker(Process):
    method __init__ (line 330) | def __init__(self, queue, lock):
    method run (line 335) | def run(self):

FILE: apps/run_genebody.py
  function loss_string (line 26) | def loss_string(loss_dict):
  function print_write (line 32) | def print_write(file, string):
  function to8b (line 37) | def to8b(img):
  function prepare_data (line 50) | def prepare_data(opt, data, local_rank=0):
  function cal_metrics (line 94) | def cal_metrics(metrics, rgbs, gts):
  function train (line 105) | def train(opt, rank=0, local_rank = 0):

FILE: genebody/download_tool.py
  function import_or_install (line 4) | def import_or_install(package):
  function parse_args (line 47) | def parse_args():
  function newSession (line 66) | def newSession():
  function save_hash (line 72) | def save_hash(path, code):
  function read_hash (line 76) | def read_hash(path):
  function checkHashes (line 81) | def checkHashes(localfile, cloud_hash, localroot, force):
  function getFiles (line 141) | def getFiles(originalUrl, download_path, force, download_root=None, req=...
  function fetch_with_pwd (line 267) | async def fetch_with_pwd(iurl, password):
  function havePwdGetFiles (line 298) | def havePwdGetFiles(iurl, password, download_path, force):
  function extractFiles (line 304) | def extractFiles(path, subset):
  function moveFiles (line 327) | def moveFiles(root):

FILE: genebody/genebody.py
  function image_cropping (line 8) | def image_cropping(mask, padding=0.1):
  class GeneBodyReader (line 65) | class GeneBodyReader():
    method __init__ (line 66) | def __init__(self, rootdir, loadsize=512):
    method get_views (line 75) | def get_views(self, subject):
    method get_frames (line 93) | def get_frames(self, subject):
    method get_cameras (line 99) | def get_cameras(self, subject):
    method get_smpl (line 102) | def get_smpl(self, subject, frame):
    method get_smpl_param (line 111) | def get_smpl_param(self, subject, frame):
    method get_data (line 133) | def get_data(self, subject, frame, camera_params, views):
    method get_near_far (line 172) | def get_near_far(self, verts, c2w, pad=0.5):
    method smpl_from_param (line 188) | def smpl_from_param(self, model_path, subject, smpl_param, smpl_scale):

FILE: genebody/mesh.py
  function max_precision (line 49) | def max_precision(type1, type2):
  function decode (line 83) | def decode(content, structure, num, form):
  function load_ply (line 131) | def load_ply(file_name):
  function save_ply (line 192) | def save_ply(file_name, elems, _type = 'binary_little_endian', comments ...
  function normalize_v3 (line 280) | def normalize_v3(arr):
  function compute_normal (line 291) | def compute_normal(vertices, faces):
  function load_obj_mesh (line 312) | def load_obj_mesh(mesh_file, with_normal=False, with_texture=False, with...
  function write_obj_mesh (line 433) | def write_obj_mesh(filename, verts, faces):

FILE: lib/data/GeneBodyDataset.py
  function mask_padding (line 22) | def mask_padding(mask, border = 5):
  function euler2rot (line 31) | def euler2rot(euler):
  function rot2euler (line 46) | def rot2euler(R):
  function gen_cam_views (line 52) | def gen_cam_views(transl, z_pitch, viewnum):
  class GeneBodyDataset (line 79) | class GeneBodyDataset(Dataset):
    method modify_commandline_options (line 81) | def modify_commandline_options(parser):
    method __init__ (line 83) | def __init__(self, opt, phase='eval', root=None, move_cam=0):
    method get_frames (line 126) | def get_frames(self, i = 0):
    method get_render_poses (line 145) | def get_render_poses(self, annots, move_cam=150):
    method __len__ (line 159) | def __len__(self):
    method image_cropping (line 165) | def image_cropping(self, mask):
    method get_near_far (line 208) | def get_near_far(self, smpl_verts, w2c):
    method get_image (line 230) | def get_image(self, sid, num_views, view_id=None, random_sample=False,...
    method smpl_from_param (line 387) | def smpl_from_param(self, model_path, subject, smpl_param, smpl_scale):
    method get_item (line 410) | def get_item(self, index):
    method __getitem__ (line 475) | def __getitem__(self, index):

FILE: lib/geometry.py
  function rot2euler (line 4) | def rot2euler(R):
  function euler2rot (line 10) | def euler2rot(euler):
  function batch_rodrigues (line 25) | def batch_rodrigues(theta):
  function quat_to_rotmat (line 41) | def quat_to_rotmat(quat):
  function index (line 64) | def index(feat, uv, mode='bilinear'):
  function orthogonal (line 81) | def orthogonal(points, calibrations, transforms=None):
  function perspective (line 100) | def perspective(points, w2c, camera):

FILE: lib/mesh_util.py
  function save_obj_mesh (line 7) | def save_obj_mesh(mesh_path, verts, faces):
  function save_obj_mesh_with_color (line 18) | def save_obj_mesh_with_color(mesh_path, verts, faces, colors):
  function save_obj_mesh_with_uv (line 30) | def save_obj_mesh_with_uv(mesh_path, verts, faces, uvs):

FILE: lib/metrics.py
  function chamfer (line 21) | def chamfer(x_verts, gt_verts, x_normals=None, gt_normals=None):
  function fscore (line 36) | def fscore(dist1, dist2, threshold=1.0):
  function psnr (line 51) | def psnr(x, gt):
  function ssim_channel (line 64) | def ssim_channel(x, gt):
  function ssim (line 85) | def ssim(x, gt):
  function lpips (line 107) | def lpips(x, gt, net=lpips_net):

FILE: lib/metrics_torch.py
  class LPIPS (line 8) | class LPIPS(torch.nn.Module):
    method __init__ (line 9) | def __init__(self):
    method forward (line 13) | def forward(self, x, gt):
  function psnr (line 27) | def psnr(x, gt):
  function gaussian (line 46) | def gaussian(window_size, sigma):
  function create_window (line 51) | def create_window(window_size, channel=1):
  function ssim_ (line 58) | def ssim_(img1, img2, window_size=11, window=None, size_average=True, fu...
  class SSIM (line 114) | class SSIM(torch.nn.Module):
    method __init__ (line 115) | def __init__(self, window_size=11, size_average=True, val_range=None):
    method forward (line 125) | def forward(self, img1, img2):

FILE: lib/model/Embedder.py
  class PositionalEncoding (line 8) | class PositionalEncoding:
    method __init__ (line 12) | def __init__(self, d, num_freqs=10, min_freq=None, max_freq=None, freq...
    method create_embedding_fn (line 19) | def create_embedding_fn(self, d):
    method embed (line 44) | def embed(self, inputs):
  class SphericalHarmonics (line 47) | class SphericalHarmonics:
    method __init__ (line 51) | def __init__(self, d = 3, rank = 3):
    method Lengdre_polynormial (line 56) | def Lengdre_polynormial(self, x, omx = None):
    method SH (line 69) | def SH(self, xyz):
    method embed (line 90) | def embed(self, inputs):

FILE: lib/model/GNR.py
  class GNR (line 12) | class GNR(nn.Module):
    method __init__ (line 14) | def __init__(self, opt):
    method image_rescale (line 40) | def image_rescale(self, images, masks):
    method get_image_feature (line 46) | def get_image_feature(self, data):
    method forward (line 57) | def forward(self, data, train_shape=False):
    method render_path (line 66) | def render_path(self, data):
    method reconstruct (line 74) | def reconstruct(self, data):

FILE: lib/model/HGFilters.py
  class HourGlass (line 12) | class HourGlass(nn.Module):
    method __init__ (line 13) | def __init__(self, num_modules, depth, num_features, norm='batch'):
    method _generate_network (line 22) | def _generate_network(self, level):
    method _forward (line 34) | def _forward(self, level, inp):
    method forward (line 60) | def forward(self, x):
  class HGFilter (line 64) | class HGFilter(nn.Module):
    method __init__ (line 65) | def __init__(self, opt):
    method forward (line 114) | def forward(self, x):

FILE: lib/model/NeRF.py
  class NeRF (line 9) | class NeRF(nn.Module):
    method __init__ (line 10) | def __init__(self, opt, D=8, W=256, input_ch=3, input_ch_atts=3, input...
    method forward (line 98) | def forward(self, x, attdirs=None, alpha_only=False, smpl_vis=None):
    method weighted_softmax (line 191) | def weighted_softmax(self, attention, weight):

FILE: lib/model/NeRFRenderer.py
  class NeRFRenderer (line 23) | class NeRFRenderer:
    method __init__ (line 24) | def __init__(self, opt, nerf, nerf_fine=None, projection='perspective'...
    method cal_loss (line 82) | def cal_loss(self, rgb, rgb_gt):
    method get_rays_orthogonal (line 99) | def get_rays_orthogonal(self, bbox, calib):
    method get_rays_perspective (line 120) | def get_rays_perspective(self, bbox, w2c, cam):
    method make_att_input (line 155) | def make_att_input(self, pts, viewdirs, calibs, smpl):
    method make_nerf_input (line 182) | def make_nerf_input(self, pts, feats, images, smpl, calibs, mesh_param...
    method make_nerf_output (line 243) | def make_nerf_output(self, nerf_output, t_vals, norm, source_rgb, is_t...
    method render_rays (line 276) | def render_rays(self, ray_batch, feats, images, masks, calibs, smpl, m...
    method inside_pts_vh (line 364) | def inside_pts_vh(self, pts, masks, smpl, calibs, persps=None):
    method render (line 393) | def render(self, feats, images, masks, calibs, bbox, mesh_param, smpl=...
    method render_path (line 423) | def render_path(self, feats, images, masks, calibs, bbox, mesh_param, ...
    method train_shape (line 462) | def train_shape(self, feats, images, masks, calibs, bbox, mesh_param, ...
    method reconstruct (line 480) | def reconstruct(self, feats, images, masks, calibs, bbox, mesh_param, ...
    method octree_reconstruct (line 534) | def octree_reconstruct(self, coords, masks, **kwargs):
    method get_plucker_line (line 608) | def get_plucker_line(self, contour, w2c, cam, mesh_param):
    method warp_mat (line 640) | def warp_mat(self, face_verts):
    method canonical_warpping (line 660) | def canonical_warpping(self, verts, verts_can, pts):
    method multiview_consistency (line 672) | def multiview_consistency(self, pts, feats, calibs, depth, persps=None):

FILE: lib/model/SRFilters.py
  class SRFilters (line 5) | class SRFilters(nn.Module):
    method __init__ (line 9) | def __init__(self, order=2, in_ch=256, out_ch=128):
    method forward (line 17) | def forward(self, feat, images):

FILE: lib/net_ddp.py
  function worker_init_fn (line 11) | def worker_init_fn(worker_id):
  function ddp_init (line 16) | def ddp_init(args):
  function create_network (line 32) | def create_network(opt, net, local_rank=0):
  function synchronize (line 81) | def synchronize():
  class ddpSampler (line 86) | class ddpSampler:
    method __init__ (line 90) | def __init__(self, dataset, rank=None, num_replicas=None):
    method indices (line 105) | def indices(self):
    method len (line 113) | def len(self):
    method distributed_concat (line 116) | def distributed_concat(self, tensor, num_total_examples):

FILE: lib/net_util.py
  function conv3x3 (line 14) | def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
  function init_weights (line 19) | def init_weights(net, init_type='normal', init_gain=0.02):
  function init_net (line 55) | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
  function imageSpaceRotation (line 73) | def imageSpaceRotation(xy, rot):
  function cal_gradient_penalty (line 85) | def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed...
  function get_norm_layer (line 123) | def get_norm_layer(norm_type='instance'):
  class Flatten (line 142) | class Flatten(nn.Module):
    method forward (line 143) | def forward(self, input):
  class ConvBlock (line 146) | class ConvBlock(nn.Module):
    method __init__ (line 147) | def __init__(self, in_planes, out_planes, norm='batch'):
    method forward (line 174) | def forward(self, x):

FILE: lib/options.py
  class BaseOptions (line 6) | class BaseOptions():
    method __init__ (line 7) | def __init__(self):
    method initialize (line 10) | def initialize(self, parser):
    method gather_options (line 136) | def gather_options(self):
    method print_options (line 148) | def print_options(self, opt):
    method parse (line 160) | def parse(self):

FILE: lib/ply_util.py
  function max_precision (line 47) | def max_precision(type1, type2):
  function decode (line 80) | def decode(content, structure, num, form):
  function load_ply (line 127) | def load_ply(file_name):
  function save_ply (line 187) | def save_ply(file_name, elems, _type = 'binary_little_endian', comments ...
