Repository: yiyuzhuang/IDOL
Branch: main
Commit: 9fd9296c28e8
Files: 58
Total size: 396.2 KB
Directory structure:
gitextract_sj4g9tfa/
├── .gitignore
├── README.md
├── configs/
│   ├── idol_debug.yaml
│   ├── idol_v0.yaml
│   └── test_dataset.yaml
├── data_processing/
│   ├── prepare_cache.py
│   ├── process_datasets.sh
│   └── visualize_samples.py
├── dataset/
│   ├── README.md
│   ├── sample/
│   │   └── param/
│   │       └── Kenya_female_fit_streetwear_50~60 years old_1501.npy
│   └── visualize_samples.py
├── env/
│   └── README.md
├── lib/
│   ├── __init__.py
│   ├── datasets/
│   │   ├── __init__.py
│   │   ├── avatar_dataset.py
│   │   └── dataloader.py
│   ├── humanlrm_wrapper_sa_v1.py
│   ├── mmutils/
│   │   ├── __init__.py
│   │   └── initialize.py
│   ├── models/
│   │   ├── __init__.py
│   │   ├── decoders/
│   │   │   ├── __init__.py
│   │   │   ├── uvmaps_decoder_gender.py
│   │   │   └── vit_head.py
│   │   ├── deformers/
│   │   │   ├── __init__.py
│   │   │   ├── fast_snarf/
│   │   │   │   ├── cuda/
│   │   │   │   │   ├── filter/
│   │   │   │   │   │   ├── filter.cpp
│   │   │   │   │   │   └── filter_kernel.cu
│   │   │   │   │   ├── fuse_kernel/
│   │   │   │   │   │   ├── fuse_cuda.cpp
│   │   │   │   │   │   └── fuse_cuda_kernel.cu
│   │   │   │   │   └── precompute/
│   │   │   │   │       ├── precompute.cpp
│   │   │   │   │       └── precompute_kernel.cu
│   │   │   │   └── lib/
│   │   │   │       └── model/
│   │   │   │           ├── deformer_smpl.py
│   │   │   │           └── deformer_smplx.py
│   │   │   ├── smplx/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── body_models.py
│   │   │   │   ├── joint_names.py
│   │   │   │   ├── lbs.py
│   │   │   │   ├── utils.py
│   │   │   │   ├── vertex_ids.py
│   │   │   │   └── vertex_joint_selector.py
│   │   │   └── smplx_deformer_gender.py
│   │   ├── renderers/
│   │   │   ├── __init__.py
│   │   │   └── gau_renderer.py
│   │   ├── sapiens/
│   │   │   ├── __init__.py
│   │   │   └── sapiens_wrapper_torchscipt.py
│   │   └── transformer_sa/
│   │       ├── __init__.py
│   │       └── mae_decoder_v3_skip.py
│   ├── ops/
│   │   ├── __init__.py
│   │   └── activation.py
│   └── utils/
│       ├── infer_util.py
│       ├── mesh.py
│       ├── mesh_utils.py
│       └── train_util.py
├── run_demo.py
├── scripts/
│   ├── download_files.sh
│   ├── fetch_template.sh
│   └── pip_install.sh
├── setup.py
└── train.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Python cache files
__pycache__/
*.py[cod]
*$py.class
*.so
# Project specific
work_dirs/
test/
lib/models/deformers/smplx/SMPLX/*
# IDE
.idea/
.vscode/
*.swp
*.swo
# Distribution / packaging
dist/
build/
*.egg-info/
# Jupyter Notebook
.ipynb_checkpoints
# Git related
*.orig
*.rej
*.patch
================================================
FILE: README.md
================================================
# **IDOL: Instant Photorealistic 3D Human Creation from a Single Image**
[**Project Page**](https://yiyuzhuang.github.io/IDOL/)
[**Paper (arXiv)**](https://arxiv.org/pdf/2412.14963)
[**Demo**](https://yiyuzhuang.github.io/IDOL/)
[**License: MIT**](https://opensource.org/licenses/MIT)
---
<p align="center">
<img src="./asset/images/Teaser_v2.png" alt="Teaser Image for IDOL" width="85%">
</p>
---
## **Abstract**
This work introduces **IDOL**, a feed-forward, single-image human reconstruction framework that is fast, high-fidelity, and generalizable.
Leveraging a large-scale dataset of 100K multi-view subjects, our method demonstrates exceptional generalizability and robustness in handling diverse human shapes, cross-domain data, severe viewpoints, and occlusions.
With a uniform structured representation, the reconstructed avatars are directly animatable and easily editable, providing a significant step forward for various applications in graphics, vision, and beyond.
In summary, this project introduces:
- **IDOL**: A scalable pipeline for instant photorealistic 3D human reconstruction using a simple yet efficient feed-forward model.
- **HuGe100K Dataset**: We develop a data generation pipeline and present **HuGe100K**, a large-scale multi-view human dataset featuring diverse attributes, high-fidelity, high-resolution appearances, and well-aligned SMPL-X models.
- **Application Support**: Enabling 3D human reconstruction and downstream tasks such as editing and animation.
---
## 📰 **News**
- **2024-12-18**: Paper is now available on arXiv.
- **2025-01-02**: The demo dataset containing 100 samples is now available for access. The remaining dataset is currently undergoing further cleaning and review.
- **2025-03-01**: 🎉 Paper accepted by CVPR 2025.
- **2025-03-01**: 🎉 We have released the inference code! Check out the [Code Release](#code-release) section for details.
- **2025-04-01**: 🔥 Full HuGe100K dataset is now available! See the [Dataset Access](#dataset-demo-access) section.
- **2025-04-05**: 🔥 Training code is now available! Check out the [Training Code](#training-code) section for details.
## 🚧 **Project Status**
We are actively working on releasing the following resources:
| Resource | Status | Expected Release Date |
|-----------------------------|---------------------|----------------------------|
| **Dataset Demo** | ✅ Available | **Now Live! (2025.01.02)** |
| **Inference Code** | ✅ Available | **Now Live! (2025.03.01)** |
| **Full Dataset Access** | ✅ Available | **Now Live! (2025.04.01)** |
| **Online Demo** | 🚧 In Progress | **Before April 2025** |
| **Training Code** | ✅ Available | **Now Live! (2025.04.05)** |
Stay tuned as we update this section with new releases! 🚀
## 💻 **Code Release**
### Installation & Environment Setup
Please refer to [env/README.md](env/README.md) for detailed environment setup instructions.
### Quick Start
Run the demo in one of three modes:
```bash
# Reconstruct the input image
python run_demo.py --render_mode reconstruct
# Generate novel poses (animation)
python run_demo.py --render_mode novel_pose
# Generate 360-degree view
python run_demo.py --render_mode novel_pose_A
```
### Training
#### Data Preparation
1. **Dataset Structure**: First, prepare your dataset with the following structure:
```
dataset_root/
├── deepfashion/
│   ├── image1/
│   │   ├── videos/
│   │   │   ├── xxx.mp4
│   │   │   └── xxx.jpg
│   │   └── param/
│   │       └── xxx.npy
│   └── image2/
│       ├── videos/
│       └── param/
└── flux_batch1_5000/
    ├── image1/
    │   ├── videos/
    │   └── param/
    └── image2/
        ├── videos/
        └── param/
```
2. **Process Dataset**: Run the data processing script to generate cache files:
```bash
# Process the dataset and generate cache files
# Please modify the dataset path and the sample number in the script
bash data_processing/process_datasets.sh
```
This will generate cache files in the `processed_data` directory:
- `deepfashion_train_140.npy`
- `deepfashion_val_10.npy`
- `deepfashion_test_50.npy`
- `flux_batch1_5000_train_140.npy`
- `flux_batch1_5000_val_10.npy`
- `flux_batch1_5000_test_50.npy`
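Each cache file is a NumPy object array of scene dictionaries with the fields written by `data_processing/prepare_cache.py` (`video_path`, `image_paths`, `param_path`, `image_ref`). A minimal sketch for inspecting one, assuming the default output location:

```python
# Sketch: inspect a generated cache file. The scene-dict fields mirror what
# data_processing/prepare_cache.py writes; the path below is an example.
import numpy as np

def load_cache(cache_path):
    """Load a cache .npy file and return the list of scene dicts."""
    return list(np.load(cache_path, allow_pickle=True))

# Example usage (assumes the cache has already been generated):
# scenes = load_cache("./processed_data/deepfashion_train_140.npy")
# print(len(scenes), scenes[0]["video_path"])
```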
3. **Configure Cache Path**: Update the cache path in your config file (e.g., `configs/idol_v0.yaml`):
```yaml
params:
  cache_path: [
    ./processed_data/deepfashion_train_140.npy,
    ./processed_data/flux_batch1_5000_train_140.npy
  ]
```
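Before launching a long training run, it can be worth checking that every file listed under `cache_path` actually exists on disk. A small stdlib-only sketch (the helper name `missing_caches` is ours, not part of the codebase):

```python
# Sketch: verify that all configured cache files exist before training.
# The paths mirror the YAML example above; adjust them to your own config.
import os

def missing_caches(cache_paths):
    """Return the cache paths that are missing on disk."""
    return [p for p in cache_paths if not os.path.isfile(p)]

cache_paths = [
    "./processed_data/deepfashion_train_140.npy",
    "./processed_data/flux_batch1_5000_train_140.npy",
]
for path in missing_caches(cache_paths):
    print(f"Missing cache file: {path}")
```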
#### Training
1. **Single-Node Training**: For single-node multi-GPU training:
```bash
python train.py \
--base configs/idol_v0.yaml \
--num_nodes 1 \
--gpus 0,1,2,3,4,5,6,7
```
2. **Multi-Node Training**: For multi-node training, specify additional parameters:
```bash
python train.py \
--base configs/idol_v0.yaml \
--num_nodes <total_nodes> \
--node_rank <current_node_rank> \
--master_addr <master_node_ip> \
--master_port <port_number> \
--gpus 0,1,2,3,4,5,6,7
```
Example for a 2-node setup:
```bash
# On master node (node 0):
python train.py --base configs/idol_v0.yaml --num_nodes 2 --node_rank 0 --master_addr 192.168.1.100 --master_port 29500 --gpus 0,1,2,3,4,5,6,7
# On worker node (node 1):
python train.py --base configs/idol_v0.yaml --num_nodes 2 --node_rank 1 --master_addr 192.168.1.100 --master_port 29500 --gpus 0,1,2,3,4,5,6,7
```
3. **Resume Training**: To resume training from a checkpoint:
```bash
python train.py \
--base configs/idol_v0.yaml \
--resume PATH/TO/MODEL.ckpt \
--num_nodes 1 \
--gpus 0,1,2,3,4,5,6,7
```
4. **Test and Evaluate Metrics**:
```bash
# --base:         main config file (model)
# --test_sd:      path to the .ckpt checkpoint you want to test
# --test_dataset: (optional) dataset config used specifically for testing
python train.py \
    --base configs/idol_v0.yaml \
    --num_nodes 1 \
    --gpus 0,1,2,3,4,5,6,7 \
    --test_sd /path/to/model_checkpoint.ckpt \
    --test_dataset ./configs/test_dataset.yaml
```
## Notes
- Make sure all GPUs have enough memory for the selected batch size
- For multi-node training, ensure network connectivity between nodes
- Monitor training progress using the logging system
- Adjust learning rate and other hyperparameters in the config file as needed
## 🌐 **Key Links**
- 📄 [**Paper on arXiv**](https://arxiv.org/pdf/2412.14963)
- 🌐 [**Project Website**](https://yiyuzhuang.github.io/IDOL/)
- 🚀 [**Live Demo**](https://your-live-demo-link.com) (Coming Soon!)
---
## 📊 **Dataset Demo Access**
We introduce **HuGe100K**, a large-scale multi-view human dataset, supporting 3D human reconstruction and animation research.
### ▶ **Watch the Demo Video**
<p align="center">
<img src="./asset/videos/dataset.gif" alt="Dataset GIF" width="85%">
</p>
### 📋 **Dataset Documentation**
For detailed information about the dataset format, structure, and usage guidelines, please refer to our [Dataset Documentation](dataset/README.md).
### 🚀 **Access the Dataset**
<div align="center">
<p><strong>🔥 HuGe100K - The largest multi-view human dataset with 100,000+ subjects! 🔥</strong></p>
<p>High-resolution • Multi-view • Diverse poses • SMPL-X aligned</p>
<a href="https://docs.google.com/forms/d/e/1FAIpQLSeVqrA9Mc_ODdcTZsB3GgrxgSNZk5deOzK4f64N72xlQFhvzQ/viewform?usp=dialog">
<img src="https://img.shields.io/badge/Apply_for_Access-HuGe100K_Dataset-FF6B6B?style=for-the-badge&logo=googleforms&logoColor=white" alt="Apply for Access" width="300px">
</a>
<p><i>Complete the form to get access credentials and download links!</i></p>
</div>
### ⚖️ **License and Attribution**
This dataset includes images derived from the **DeepFashion** dataset, originally provided by MMLAB at The Chinese University of Hong Kong. The use of DeepFashion images in this dataset has been explicitly authorized by the original authors solely for the purpose of creating and distributing this dataset. **Users must not further reproduce, distribute, sell, or commercially exploit any images or derived data originating from DeepFashion.** For any subsequent or separate use of the DeepFashion data, users must directly obtain authorization from MMLAB and comply with the original [DeepFashion License](https://mmlab.ie.cuhk.edu.hk/projects/DeepFashion.html).
---
## 📝 **Citation**
If you find our work helpful, please cite us using the following BibTeX:
```bibtex
@article{zhuang2024idolinstant,
title={IDOL: Instant Photorealistic 3D Human Creation from a Single Image},
author={Yiyu Zhuang and Jiaxi Lv and Hao Wen and Qing Shuai and Ailing Zeng and Hao Zhu and Shifeng Chen and Yujiu Yang and Xun Cao and Wei Liu},
journal={arXiv preprint arXiv:2412.14963},
year={2024},
url={https://arxiv.org/abs/2412.14963},
}
```
## **License**
This project is licensed under the **MIT License**.
- **Permissions**: This license grants permission to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the software.
- **Condition**: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
- **Disclaimer**: The software is provided "as is", without warranty of any kind.
For more information, see the full license [here](https://opensource.org/licenses/MIT).
## **Support Our Work** ⭐
If you find our work useful for your research or applications:
- Please ⭐ **star our repository** to help us reach more people
- Consider **citing our paper** in your publications (see [Citation](#citation) section)
- Share our project with others who might benefit from it
Your support helps us continue developing open-source research projects like this one!
## 📚 **Acknowledgments**
This project is built primarily on several excellent open-source projects:
- [E3Gen](https://github.com/olivia23333/E3Gen): Efficient, Expressive and Editable Avatars Generation
- [SAPIENS](https://github.com/facebookresearch/sapiens): High-resolution visual models for human-centric tasks
- [GeoLRM](https://github.com/alibaba-yuanjing-aigclab/GeoLRM): Large Reconstruction Model for High-Quality 3D Generation
- [3D Gaussian Splatting](https://github.com/graphdeco-inria/gaussian-splatting): Real-Time 3DGS Rendering
We thank all the authors for their contributions to the open-source community.
================================================
FILE: configs/idol_debug.yaml
================================================
debug: True
# code_size: [32, 256, 256]
code_size: [32, 1024, 1024]

model:
  # base_learning_rate: 2.0e-04 # yy Need to check
  target: lib.SapiensGS_SA_v1
  params:
    # use_bf16: true
    max_steps: 100_000
    warmup_steps: 10_000 # 12_000
    use_checkpoint: true
    lambda_depth_tv: 0.05
    lambda_lpips: 0 # 2.0
    lambda_mse: 20 # 1.0
    lambda_offset: 1 # offset_weight: 50, mse 20, lpips 0.1
    neck_learning_rate: 5e-4
    decoder_learning_rate: 5e-4
    output_hidden_states: true # if True, output hidden states from shallow Sapiens layers for the neck decoder
    loss_coef: 0.5
    init_iter: 500
    scale_weight: 0.01
    smplx_path: 'work_dirs/demo_data/Ways_to_Catch_360_clip1.json'
    code_reshape: [32, 96, 96]
    patch_size: 1
    code_activation:
      type: tanh
      mean: 0.0
      std: 0.5
      clip_range: 2
    grid_size: 64
    encoder:
      target: lib.models.sapiens.SapiensWrapper_ts
      params:
        model_path: work_dirs/ckpt/sapiens_1b_epoch_173_torchscript.pt2
        layer_num: 40
        img_size: [1024, 736]
        freeze: True
    neck:
      target: lib.models.transformer_sa.neck_SA_v3_skip # TODO: add a self-attention version
      params:
        patch_size: 4
        in_chans: 32 # the UV code dims
        num_patches: 9216
        embed_dim: 1536 # Sapiens latent dims (the feature extractor's output)
        decoder_embed_dim: 128
        decoder_depth: 2
        decoder_num_heads: 4
        total_num_hidden_states: 12
        mlp_ratio: 4.
    decoder:
      target: lib.models.decoders.UVNDecoder_gender
      params:
        interp_mode: bilinear
        base_layers: [16, 64]
        density_layers: [64, 1]
        color_layers: [16, 128, 9]
        offset_layers: [64, 3]
        use_dir_enc: false
        dir_layers: [16, 64]
        activation: silu
        bg_color: 1
        sigma_activation: sigmoid
        sigmoid_saturation: 0.001
        gender: neutral
        is_sub2: true # subdivide to ~100k Gaussian points
        multires: 0
        image_size: [640, 896]
        superres: false
        focal: 1120
        up_cnn_in_channels: 128 # must match decoder_embed_dim
        reshape_type: VitHead
        vithead_param:
          in_channels: 128 # must match decoder_embed_dim
          out_channels: 32
          deconv_out_channels: [128, 64]
          deconv_kernel_sizes: [4, 4]
          conv_out_channels: [128, 128]
          conv_kernel_sizes: [3, 3]
        fix_sigma: true

dataset:
  target: lib.datasets.dataloader.DataModuleFromConfig
  params:
    batch_size: 1 # 16; 6 for lpips
    num_workers: 1
    # used when in debug mode
    debug_cache_path: ./processed_data/flux_batch1_5000_test_50_local.npy
    train:
      target: lib.datasets.AvatarDataset
      params:
        data_prefix: None
        cache_path: [
          ./processed_data/deepfashion_train_140_local.npy,
          ./processed_data/flux_batch1_5000_train_140_local.npy
        ]
        specific_observation_num: 5
        better_range: true
        first_is_front: true
        if_include_video_ref_img: true
        prob_include_video_ref_img: 0.5
        img_res: [640, 896]
    validation:
      target: lib.datasets.AvatarDataset
      params:
        data_prefix: None
        load_imgs: true
        specific_observation_num: 3
        better_range: true
        first_is_front: true
        img_res: [640, 896]
        cache_path: [
          ./processed_data/flux_batch1_5000_test_50_local.npy,
          # ./processed_data/flux_batch1_5000_val_10_local.npy
        ]

lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 4000 # 2000
      save_top_k: -1
      save_last: true
      monitor: 'train/loss_mse' # add this logging in the wrapper_sa
      mode: "min"
      filename: 'sample-synData-epoch{epoch:02d}-val_loss{val/loss:.2f}'
  callbacks: {}
  trainer:
    num_sanity_val_steps: 1
    accumulate_grad_batches: 1
    gradient_clip_val: 10.0
    max_steps: 80000
    check_val_every_n_epoch: 1 # run validation once per training epoch
    benchmark: true
    val_check_interval: 1.0
================================================
FILE: configs/idol_v0.yaml
================================================
debug: True
# code_size: [32, 256, 256]
code_size: [32, 1024, 1024]

model:
  # base_learning_rate: 2.0e-04 # yy Need to check
  target: lib.SapiensGS_SA_v1
  params:
    # use_bf16: true
    max_steps: 100_000
    warmup_steps: 10_000 # 12_000
    use_checkpoint: true
    lambda_depth_tv: 0.05
    lambda_lpips: 10 # 2.0
    lambda_mse: 20 # 1.0
    lambda_offset: 1 # offset_weight: 50, mse 20, lpips 0.1
    neck_learning_rate: 5e-4
    decoder_learning_rate: 5e-4
    output_hidden_states: true # if True, output hidden states from shallow Sapiens layers for the neck decoder
    loss_coef: 0.5
    init_iter: 500
    scale_weight: 0.01
    smplx_path: 'work_dirs/demo_data/Ways_to_Catch_360_clip1.json'
    code_reshape: [32, 96, 96]
    patch_size: 1
    code_activation:
      type: tanh
      mean: 0.0
      std: 0.5
      clip_range: 2
    grid_size: 64
    encoder:
      target: lib.models.sapiens.SapiensWrapper_ts
      params:
        model_path: work_dirs/ckpt/sapiens_1b_epoch_173_torchscript.pt2
        layer_num: 40
        img_size: [1024, 736]
        freeze: True
    neck:
      target: lib.models.transformer_sa.neck_SA_v3_skip # TODO: add a self-attention version
      params:
        patch_size: 4
        in_chans: 32 # the UV code dims
        num_patches: 9216
        embed_dim: 1536 # Sapiens latent dims (the feature extractor's output)
        decoder_embed_dim: 1536 # 1024
        decoder_depth: 16 # 8
        decoder_num_heads: 16
        total_num_hidden_states: 40
        mlp_ratio: 4.
    decoder:
      target: lib.models.decoders.UVNDecoder_gender
      params:
        interp_mode: bilinear
        base_layers: [16, 64]
        density_layers: [64, 1]
        color_layers: [16, 128, 9]
        offset_layers: [64, 3]
        use_dir_enc: false
        dir_layers: [16, 64]
        activation: silu
        bg_color: 1
        sigma_activation: sigmoid
        sigmoid_saturation: 0.001
        gender: neutral
        is_sub2: true # subdivide to ~100k Gaussian points
        multires: 0
        image_size: [640, 896]
        superres: false
        focal: 1120
        up_cnn_in_channels: 1536 # must match decoder_embed_dim
        reshape_type: VitHead
        vithead_param:
          in_channels: 1536 # must match decoder_embed_dim
          out_channels: 32
          deconv_out_channels: [512, 512, 512, 256]
          deconv_kernel_sizes: [4, 4, 4, 4]
          conv_out_channels: [128, 128]
          conv_kernel_sizes: [3, 3]
        fix_sigma: true

dataset:
  target: lib.datasets.dataloader.DataModuleFromConfig
  params:
    batch_size: 1 # 16; 6 for lpips
    num_workers: 2
    # used when in debug mode
    debug_cache_path: ./processed_data/flux_batch1_5000_test_50_local.npy
    train:
      target: lib.datasets.AvatarDataset
      params:
        data_prefix: None
        cache_path: [
          ./processed_data/deepfashion_train_140_local.npy,
          ./processed_data/flux_batch1_5000_train_140_local.npy
        ]
        specific_observation_num: 5
        better_range: true
        first_is_front: true
        if_include_video_ref_img: true
        prob_include_video_ref_img: 0.5
        img_res: [640, 896]
    validation:
      target: lib.datasets.AvatarDataset
      params:
        data_prefix: None
        load_imgs: true
        specific_observation_num: 5
        better_range: true
        first_is_front: true
        img_res: [640, 896]
        cache_path: [
          ./processed_data/deepfashion_val_10_local.npy,
          ./processed_data/flux_batch1_5000_val_10_local.npy
        ]

lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 4000 # 2000
      save_top_k: -1
      save_last: true
      monitor: 'train/loss_mse' # add this logging in the wrapper_sa
      mode: "min"
      filename: 'sample-synData-epoch{epoch:02d}-val_loss{val/loss:.2f}'
  callbacks: {}
  trainer:
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
    gradient_clip_val: 10.0
    max_steps: 80000
    check_val_every_n_epoch: 1 # run validation once per training epoch
    benchmark: true
================================================
FILE: configs/test_dataset.yaml
================================================
dataset:
  target: lib.datasets.dataloader.DataModuleFromConfig
  params:
    batch_size: 1
    num_workers: 2
    # used when in debug mode
    debug_cache_path: ./processed_data/flux_batch1_5000_test_50_local.npy
    train:
      target: lib.datasets.AvatarDataset
      params:
        data_prefix: None
        cache_path: [
          ./processed_data/deepfashion_train_140_local.npy,
          ./processed_data/flux_batch1_5000_train_140_local.npy
        ]
        specific_observation_num: 5
        better_range: true
        first_is_front: true
        if_include_video_ref_img: true
        prob_include_video_ref_img: 0.5
        img_res: [640, 896]
    validation:
      target: lib.datasets.AvatarDataset
      params:
        data_prefix: None
        load_imgs: true
        specific_observation_num: 5
        better_range: true
        first_is_front: true
        img_res: [640, 896]
        cache_path: [
          ./processed_data/deepfashion_val_10_local.npy,
          ./processed_data/flux_batch1_5000_val_10_local.npy
        ]
    test:
      target: lib.datasets.AvatarDataset
      params:
        data_prefix: None
        load_imgs: true
        specific_observation_num: 5
        better_range: true
        first_is_front: true
        img_res: [640, 896]
        cache_path: [
          ./processed_data/deepfashion_test_50_local.npy,
          ./processed_data/flux_batch1_5000_test_50_local.npy
        ]
================================================
FILE: data_processing/prepare_cache.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Data preparation script for the DeepFashion video dataset.

This script processes video files and their corresponding parameters,
and splits the dataset into train/val/test sets.
"""
import argparse
import os

import numpy as np


def parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(description="Prepare DeepFashion video dataset")
    parser.add_argument(
        "--video_dir",
        type=str,
        default="/apdcephfs/private_harriswen/data/deepfashion/",
        help="Base directory containing imageX folders",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./",
        help="Directory to save the processed data",
    )
    parser.add_argument(
        "--prefix",
        type=str,
        default="DeepFashion",
        help="Prefix for the output file names",
    )
    parser.add_argument(
        "--max_videos",
        type=int,
        default=20000,
        help="Maximum number of videos to process (for creating smaller datasets)",
    )
    return parser.parse_args()


def prepare_dataset(video_dir, output_dir, prefix, max_total_videos=20000):
    """
    Prepare the DeepFashion dataset by processing videos and parameters.

    Args:
        video_dir: Base directory containing imageX folders
        output_dir: Directory to save processed data
        prefix: Prefix for output filenames
        max_total_videos: Maximum number of videos to process (default: 20000)
    """
    # Find all imageX subdirectories
    image_dirs = []
    for item in os.listdir(video_dir):
        if item.startswith("image") and os.path.isdir(os.path.join(video_dir, item)):
            image_dirs.append(item)
    image_dirs.sort()
    print(f"Found {len(image_dirs)} image directories: {image_dirs}")

    # Collect all video files
    all_video_files = []
    all_param_files = []
    all_dir_names = []
    for image_dir in image_dirs:
        videos_path = os.path.join(video_dir, image_dir, "videos")
        params_path = os.path.join(video_dir, image_dir, "param")
        if not os.path.exists(videos_path):
            print(f"Warning: Videos directory not found in {image_dir}, skipping.")
            continue
        if not os.path.exists(params_path):
            print(f"Warning: Parameters directory not found in {image_dir}, skipping.")
            continue

        # Keep only parameter files with a .npy extension
        param_names = [name for name in os.listdir(params_path) if name.endswith(".npy")]
        for name in param_names:
            video_path = os.path.join(videos_path, name.replace(".npy", ".mp4"))
            param_path = os.path.join(params_path, name)
            # Check that both the video and parameter files exist
            if not os.path.exists(video_path):
                print(f"Warning: Video file not found: {video_path}, skipping.")
                continue
            if not os.path.exists(param_path):
                print(f"Warning: Parameter file not found: {param_path}, skipping.")
                continue
            # Add to the collection only if both files exist
            all_video_files.append(video_path)
            all_param_files.append(param_path)
            all_dir_names.append(image_dir)

    total_videos = len(all_video_files)
    print(f"Total valid videos found: {total_videos}")
    if total_videos == 0:
        print("Error: No valid video-parameter pairs found. Please check your data paths.")
        return

    # Limit the number of videos to process
    if max_total_videos < total_videos:
        # Randomly shuffle and keep the first max_total_videos
        indices = list(range(total_videos))
        np.random.shuffle(indices)
        indices = indices[:max_total_videos]
        all_video_files = [all_video_files[i] for i in indices]
        all_param_files = [all_param_files[i] for i in indices]
        all_dir_names = [all_dir_names[i] for i in indices]
        print(f"Limiting to {max_total_videos} videos")

    # Build scene records
    scenes = []
    processed_count = 0
    skipped_count = 0
    for video_path, param_path, dir_name in zip(all_video_files, all_param_files, all_dir_names):
        processed_count += 1
        case_name = os.path.basename(video_path)
        print(f"Processing {processed_count}/{len(all_video_files)}: {dir_name}/{case_name}")
        try:
            scenes.append(dict(
                video_path=video_path,
                image_paths=None,  # only filled for data stored as an image sequence instead of a video
                param_path=param_path,
                image_ref=video_path.replace(".mp4", ".jpg"),
            ))
        except Exception as e:
            print(f"Error processing {video_path}: {e}")
            skipped_count += 1

    print(f"Total scenes collected: {len(scenes)}")
    print(f"Total scenes skipped: {skipped_count}")
    if len(scenes) == 0:
        print("Error: No scenes could be processed. Please check your data.")
        return

    # Split dataset: last 50 scenes for test, the 10 before those for val, the rest for train
    total_scenes = len(scenes)
    test_scenes = scenes[-50:] if total_scenes > 50 else []
    val_scenes = scenes[-60:-50] if total_scenes > 60 else []
    train_scenes = scenes[:-60] if total_scenes > 60 else scenes

    splits = {
        "train": train_scenes,
        "val": val_scenes,
        "test": test_scenes,
        "all": scenes,
    }

    # Save each non-empty split to a separate file
    os.makedirs(output_dir, exist_ok=True)
    for split_name, split_data in splits.items():
        if not split_data:
            continue
        cache_path = os.path.join(
            output_dir,
            f"{prefix}_{split_name}_{len(split_data)}.npy",
        )
        np.save(cache_path, split_data)
        print(f"Saved {split_name} split with {len(split_data)} samples to {cache_path}")


if __name__ == "__main__":
    # Parse command line arguments
    args = parse_args()
    # Prepare and save the dataset
    prepare_dataset(args.video_dir, args.output_dir, args.prefix, args.max_videos)
    print(f"Done processing {args.video_dir} dataset")
================================================
FILE: data_processing/process_datasets.sh
================================================
#!/bin/bash
# Data processing script for multiple datasets.
# Processes each dataset listed below and saves the cache files to the output directory.

# Define the list of dataset paths
DATASET_PATHS=(
    "/PATH/TO/deepfashion"
    "/PATH/TO/flux_batch1_5000"
    "/PATH/TO/flux_batch2"
    # Add more dataset paths here as needed
)

# Output base directory for processed cache files
OUTPUT_BASE_DIR="./processed_data"

# Maximum videos to process per dataset (set to a smaller number for testing;
# to process all videos, set MAX_VIDEOS to a very large number)
MAX_VIDEOS=200

# Process each dataset
for DATASET_PATH in "${DATASET_PATHS[@]}"; do
    # Extract the dataset name from the path (use the last directory name as prefix)
    DATASET_NAME=$(basename "$DATASET_PATH")

    # Create the output directory for this dataset
    OUTPUT_DIR="${OUTPUT_BASE_DIR}"
    mkdir -p "$OUTPUT_DIR"

    echo "===== Processing ${DATASET_NAME} Dataset ====="
    echo "Source: ${DATASET_PATH}"
    echo "Destination: ${OUTPUT_DIR}"

    # Run the processing script
    python data_processing/prepare_cache.py \
        --video_dir "${DATASET_PATH}" \
        --output_dir "${OUTPUT_DIR}" \
        --prefix "${DATASET_NAME}" \
        --max_videos "${MAX_VIDEOS}"

    # Check if processing was successful
    if [ $? -ne 0 ]; then
        echo "Error processing ${DATASET_NAME} dataset"
        echo "Continuing with next dataset..."
    else
        echo "Successfully processed ${DATASET_NAME} dataset"
    fi
    echo "----------------------------------------"
done

echo "===== All datasets processing completed ====="
echo "Results saved to: ${OUTPUT_BASE_DIR}"

# List all processed datasets
echo "Processed datasets:"
for DATASET_PATH in "${DATASET_PATHS[@]}"; do
    DATASET_NAME=$(basename "$DATASET_PATH")
    echo "- ${DATASET_NAME}: ${OUTPUT_BASE_DIR}/${DATASET_NAME}"
done
================================================
FILE: data_processing/visualize_samples.py
================================================
import os

os.environ["PYOPENGL_PLATFORM"] = "osmesa"

import imageio
import numpy as np
import pyrender
import smplx
import torch
import trimesh


def init_smplx_model():
    """Initialize the SMPL-X model with predefined settings."""
    body_model = smplx.SMPLX(
        'PATH_TO_YOUR_SMPLX_FOLDER',
        gender="neutral",
        create_body_pose=False,
        create_betas=False,
        create_global_orient=False,
        create_transl=False,
        create_expression=True,
        create_jaw_pose=True,
        create_leye_pose=True,
        create_reye_pose=True,
        create_right_hand_pose=False,
        create_left_hand_pose=False,
        use_pca=False,
        num_pca_comps=12,
        num_betas=10,
        flat_hand_mean=False,
    )
    return body_model


# Load SMPL-X parameters
param_path = "./100samples/Apose/param/Argentina_male_buff_thermal wear_20~30 years old_1573.npy"
param = np.load(param_path, allow_pickle=True).item()

# Extract SMPL-X parameters (189 values in total)
smpl_params = param['smpl_params'].reshape(1, -1)
(scale, transl, global_orient, pose, betas, left_hand_pose, right_hand_pose,
 jaw_pose, leye_pose, reye_pose, expression) = torch.split(
    smpl_params, [1, 3, 3, 63, 10, 45, 45, 3, 3, 3, 10], dim=1)

# Initialize the SMPL-X model and generate vertices
device = torch.device("cpu")
model = init_smplx_model().to(device)
output = model(global_orient=global_orient, body_pose=pose, betas=betas,
               left_hand_pose=left_hand_pose, right_hand_pose=right_hand_pose,
               jaw_pose=jaw_pose, leye_pose=leye_pose, reye_pose=reye_pose,
               expression=expression)
vertices = output.vertices[0].detach().cpu().numpy()
faces = model.faces

# Create Trimesh and Pyrender meshes
mesh = trimesh.Trimesh(vertices, faces)
mesh_pyrender = pyrender.Mesh.from_trimesh(mesh)

rendered_images_list = []
# Loop through multiple camera views
for idx in range(24):
    scene = pyrender.Scene()
    scene.add(mesh_pyrender)

    # Load and process camera parameters
    camera_params = param['poses']
    intrinsic_params = camera_params[idx][1]  # fx, fy, cx, cy
    extrinsic_params = camera_params[idx][0]  # [R|T]

    # Set up the Pyrender camera
    camera = pyrender.IntrinsicsCamera(fx=intrinsic_params[0], fy=intrinsic_params[1],
                                       cx=intrinsic_params[2], cy=intrinsic_params[3])

    # Convert COLMAP coordinates to a Pyrender-compatible transformation
    extrinsic_params_inv = torch.inverse(extrinsic_params.clone())
    scale_factor = extrinsic_params_inv[:3, :3].norm(dim=1)
    extrinsic_params_inv[:3, 1:3] = -extrinsic_params_inv[:3, 1:3]
    extrinsic_params_inv[3, :3] = 0

    # Add camera and lighting
    scene.add(camera, pose=extrinsic_params_inv)
    light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=10.0)
    scene.add(light, pose=extrinsic_params_inv)

    # Render the scene
    renderer = pyrender.OffscreenRenderer(640, 896)
    color, depth = renderer.render(scene)
    rendered_images_list.append(color)
    renderer.delete()

# Save the rendered images as a video
rendered_images = np.stack(rendered_images_list)
imageio.mimwrite('rendered_results.mp4', rendered_images, fps=15)
print("Rendered results saved as rendered_results.mp4")

# Load the corresponding source video and check alignment
video_path = param_path.replace("param", "videos").replace("npy", "mp4")
input_video = imageio.get_reader(video_path)
input_frames = [frame for frame in input_video]
blended_frames = [(0.5 * frame + 0.5 * render_frame).astype(np.uint8)
                  for render_frame, frame in zip(rendered_images, input_frames)]
imageio.mimwrite('aligned_results.mp4', blended_frames, fps=15)
print("Blended video saved as aligned_results.mp4")
================================================
FILE: dataset/README.md
================================================
# 🌟 HuGe100K Dataset Documentation
## 📊 Dataset Overview
HuGe100K is a large-scale multi-view human dataset featuring diverse attributes, high-fidelity appearances, and well-aligned SMPL-X models.
## 📁 File Format and Structure
The dataset is organized with the following structure:
```
HuGe100K/
├── flux_batch1/
│ ├── images[0...9]/ # different batch of images
│ │ ├── videos/ # Folder for .mp4 and .jpg files
│ │ │ ├── Algeria_female_average_high fashion_50~60 years old_844.jpg
│ │ │ ├── Algeria_female_average_high fashion_50~60 years old_844.mp4
│ │ │ └── ... (more .jpg and .mp4)
│ │ └── param/ # Folder for parameter files (.npy)
│ │ ├── Algeria_female_average_high fashion_50~60 years old_844.npy
│ │ └── ... (more .npy files)
├── flux_batch2/
│ └── ... (similar structure with images[0...9])
├── flux_batch3/
│ └── ... (similar structure with images[0...9])
├── flux_batch4/
│ └── ... (similar structure with images[0...9])
├── flux_batch5/
│ └── ... (similar structure with images[0...9])
├── flux_batch6/
│ └── ... (similar structure with images[0...9])
├── flux_batch7/
│ └── ... (similar structure with images[0...9])
├── flux_batch8/
│ └── ... (similar structure with images[0...9])
├── flux_batch9/
│ └── ... (similar structure with images[0...9])
└── deepfashion/
└── ... (similar structure with images[0...9])
```
Where:
- Each `images[X]` folder contains:
- `videos/`: Reference images and generated video files
- `param/`: Camera and body pose parameters
- **flux_batch1 through flux_batch7**: Contains subjects in A-pose
- **flux_batch8 and flux_batch9**: Contains subjects in diverse poses
- **deepfashion**: Contains subjects in A-pose (derived from the DeepFashion dataset)
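The layout above can be traversed programmatically. The sketch below pairs each parameter file with its reference image and video; the glob pattern and the `iter_subjects` helper are assumptions based on the tree shown here, not part of the released tooling:

```python
from pathlib import Path

def iter_subjects(root):
    """Yield (param_path, image_path, video_path) triples for one dataset root.

    Assumes the HuGe100K layout above: <batch>/images<X>/{param,videos}/,
    where the .jpg/.mp4 share the .npy file's stem.
    """
    for param_path in Path(root).glob("*/images*/param/*.npy"):
        videos_dir = param_path.parent.parent / "videos"
        stem = param_path.stem
        yield param_path, videos_dir / f"{stem}.jpg", videos_dir / f"{stem}.mp4"
```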
### File Naming Convention
Files follow the naming pattern: `Area_Gender_BodyType_Clothing_Age_ID.extension`
For example:
- `Algeria_female_average_high fashion_50~60 years old_844.jpg`: Reference image of an Algerian female with average build in high fashion clothing
- `Algeria_female_average_high fashion_50~60 years old_844.npy`: Parameter file for the same subject
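A filename can be split back into its attributes. The `parse_name` helper below is a minimal sketch assuming the underscore-delimited pattern described above (note that fields such as `Clothing` and `Age` may themselves contain spaces, but not underscores):

```python
from pathlib import Path

def parse_name(filename: str) -> dict:
    """Split 'Area_Gender_BodyType_Clothing_Age_ID.extension' into its fields."""
    stem = Path(filename).stem
    area, gender, body_type, clothing, age, subject_id = stem.split("_")
    return {"area": area, "gender": gender, "body_type": body_type,
            "clothing": clothing, "age": age, "id": subject_id}

info = parse_name("Algeria_female_average_high fashion_50~60 years old_844.jpg")
print(info["area"], info["id"])  # prints: Algeria 844
```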
### 📸 Sample Visualization
<div style="display: flex; align-items: center; justify-content: center; gap: 10px; flex-wrap: nowrap; width: 100%;">
<img src="sample/videos/Kenya_female_fit_streetwear_50~60 years old_1501.jpg" alt="Kenya Female Fit Streetwear Image" style="max-width: 45%; width: 45%; height: auto;">
<span style="font-weight: bold;"> =MVChamp=> </span>
<!-- <video autoplay loop muted playsinline style="max-width: 45%; width: 45%; height: auto;">
<source src="sample/videos/Kenya_female_fit_streetwear_50~60 years old_1501.gif" type="video/mp4">
Your browser does not support the video tag.
</video> -->
<img src="sample/videos/Kenya_female_fit_streetwear_50~60 years old_1501.gif" alt="Kenya Female Fit Streetwear Image" style="max-width: 45%; width: 45%; height: auto;">
</div>
## 📈 Dataset Statistics
- **Total Subjects**: 100,000+
- **Views per Subject**: 24 viewpoints covering 360°
- **Pose Types**: A-pose and diverse poses
## 🔍 Visualizing the Dataset
For visualization and data parsing examples, please refer to our provided script:
`visualize_samples.py`. This script demonstrates how to:
- Load the SMPL-X parameters from `.npy` files
- Render the 3D human model from multiple camera views
- Compare rendered results with the original video frames
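Each `.npy` file stores a dict whose `smpl_params` entry is a flat 189-dimensional vector. The sketch below mirrors how `visualize_samples.py` splits it into components (shown with a dummy NumPy array instead of a real file, so it is self-contained; the script itself uses `torch.split` with the same sizes):

```python
import numpy as np

# Component sizes used by visualize_samples.py when splitting param['smpl_params'].
sizes = [1, 3, 3, 63, 10, 45, 45, 3, 3, 3, 10]
names = ["scale", "transl", "global_orient", "body_pose", "betas",
         "left_hand_pose", "right_hand_pose", "jaw_pose",
         "leye_pose", "reye_pose", "expression"]

dummy = np.zeros((1, sum(sizes)), dtype=np.float32)  # stands in for param['smpl_params']
parts = dict(zip(names, np.split(dummy, np.cumsum(sizes)[:-1], axis=1)))
print({k: v.shape for k, v in parts.items()})
```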
Requirements for visualization:
- SMPL-X model (download from [official website](https://smpl-x.is.tue.mpg.de/))
- Python packages: `pyrender`, `trimesh`, `smplx`, `numpy`, `torch`
Example usage:
```bash
python visualize_samples.py
```
The script will generate:
- `rendered_results.mp4`: Rendered views of the 3D model
- `aligned_results.mp4`: Blended visualization of rendered model with original frames
## 📋 Usage Guidelines
1. **Research Purposes Only**: This dataset is intended for academic and research purposes.
2. **Citation Required**: If you use this dataset in your research, please cite our paper.
3. **No Commercial Use**: Commercial use requires explicit permission from us at yiyu.zhuang@smail.nju.edu.cn.
4. **DeepFashion Derivatives**: See License and Attribution section below for special requirements.
## ⚖️ License and Attribution (DeepFashion)
This dataset includes images derived from the **DeepFashion** dataset, originally provided by MMLAB at The Chinese University of Hong Kong. The use of DeepFashion images in this dataset has been explicitly authorized by the original authors solely for the purpose of creating and distributing this dataset. **Users must not further reproduce, distribute, sell, or commercially exploit any images or derived data originating from DeepFashion.** For any subsequent or separate use of the DeepFashion data, users must directly obtain authorization from MMLAB and comply with the original [DeepFashion License](https://mmlab.ie.cuhk.edu.hk/projects/DeepFashion.html).
================================================
FILE: dataset/visualize_samples.py
================================================
import torch
import numpy as np
import os
os.environ["PYOPENGL_PLATFORM"] = "osmesa"
import smplx
import trimesh
import pyrender
import imageio
def init_smplx_model():
"""Initialize the SMPL-X model with predefined settings."""
body_model = smplx.SMPLX('PATH_TO_YOUR_SMPLX_FOLDER',
gender="neutral",
create_body_pose=False,
create_betas=False,
create_global_orient=False,
create_transl=False,
create_expression=True,
create_jaw_pose=True,
create_leye_pose=True,
create_reye_pose=True,
create_right_hand_pose=False,
create_left_hand_pose=False,
use_pca=False,
num_pca_comps=12,
num_betas=10,
flat_hand_mean=False)
return body_model
# Load SMPL-X parameters
param_path = "./sample/param/Kenya_female_fit_streetwear_50~60 years old_1501.npy"
param = np.load(param_path, allow_pickle=True).item()
# Extract SMPL-X parameters
smpl_params = param['smpl_params'].reshape(1, -1)
scale, transl, global_orient, pose, betas, left_hand_pose, right_hand_pose, jaw_pose, leye_pose, reye_pose, expression = torch.split(smpl_params, [1, 3, 3, 63, 10, 45, 45, 3, 3, 3, 10], dim=1)
# Initialize SMPL-X model and generate vertices
device = torch.device("cpu")
model = init_smplx_model().to(device)
output = model(global_orient=global_orient, body_pose=pose, betas=betas, left_hand_pose=left_hand_pose,
right_hand_pose=right_hand_pose, jaw_pose=jaw_pose, leye_pose=leye_pose, reye_pose=reye_pose,
expression=expression)
vertices = output.vertices[0].detach().cpu().numpy()
faces = model.faces
# Create a Trimesh and Pyrender mesh
mesh = trimesh.Trimesh(vertices, faces)
mesh_pyrender = pyrender.Mesh.from_trimesh(mesh)
rendered_images_list = []
# Loop through multiple camera views
for idx in range(24):
scene = pyrender.Scene()
scene.add(mesh_pyrender)
# Load and process camera parameters
camera_params = param['poses']
intrinsic_params = camera_params[idx][1] # fx, fy, cx, cy
extrinsic_params = camera_params[idx][0] # R|T
# Set up Pyrender camera
camera = pyrender.IntrinsicsCamera(fx=intrinsic_params[0], fy=intrinsic_params[1],
cx=intrinsic_params[2], cy=intrinsic_params[3])
# Convert COLMAP coordinates to Pyrender-compatible transformation
extrinsic_params_inv = torch.inverse(extrinsic_params.clone())
scale_factor = extrinsic_params_inv[:3, :3].norm(dim=1)
extrinsic_params_inv[:3, 1:3] = -extrinsic_params_inv[:3, 1:3]
extrinsic_params_inv[3, :3] = 0
# Add camera and lighting
scene.add(camera, pose=extrinsic_params_inv)
light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=10.0)
scene.add(light, pose=extrinsic_params_inv)
# Render the scene
renderer = pyrender.OffscreenRenderer(640, 896)
color, depth = renderer.render(scene)
rendered_images_list.append(color)
renderer.delete()
# Save rendered images as a video
rendered_images = np.stack(rendered_images_list)
imageio.mimwrite('rendered_results.mp4', rendered_images, fps=15)
print("Rendered results saved as rendered_results.mp4")
# Load an existing video and test alignment
video_path = param_path.replace("param", "videos").replace("npy", "mp4")
input_video = imageio.get_reader(video_path)
input_frames = [frame for frame in input_video]
blended_frames = [(0.5 * frame + 0.5 * render_frame).astype(np.uint8) for render_frame, frame in zip(rendered_images, input_frames)]
imageio.mimwrite('aligned_results.mp4', blended_frames, fps=15)
print("Blended video saved as aligned_results.mp4")
================================================
FILE: env/README.md
================================================
# Environment Setup Guide
## Prerequisites
- Python 3.10
- CUDA 11.8
- PyTorch 2.3.1
## Installation Steps
### 1. Environment Preparation
First, create and activate a conda environment:
```bash
conda create -n idol python=3.10
conda activate idol
```
Install all dependencies:
```bash
bash scripts/pip_install.sh
```
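As a quick sanity check after installation, the core packages can be probed with the standard library. The package list below is an assumption drawn from the visualization requirements in the docs:

```python
import importlib.util

# Core packages the IDOL scripts rely on (list assumed from the docs).
required = ["torch", "numpy", "smplx", "trimesh", "pyrender"]
missing = [pkg for pkg in required if importlib.util.find_spec(pkg) is None]
print("missing packages:", missing or "none")
```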
### 2. Download Required Models
Before proceeding, please register on:
- [SMPL-X website](https://smpl-x.is.tue.mpg.de/)
- [FLAME website](https://flame.is.tue.mpg.de/)
Then download the template files:
```bash
bash scripts/fetch_template.sh
```
### 3. Download Pretrained Models and Caches
Download the pretrained models and caches with:
```bash
bash scripts/download_files.sh # download pretrained models
```
Or manually download the following models from HuggingFace:
- [IDOL Model Checkpoint](https://huggingface.co/yiyuzhuang/IDOL/blob/main/model.ckpt)
- [Sapiens Pretrained Model](https://huggingface.co/yiyuzhuang/IDOL/blob/main/sapiens_1b_epoch_173_torchscript.pt2)
## System Requirements
- **GPU**: NVIDIA GPU with CUDA 11.8 support
- **GPU RAM**: Recommended 24GB+
- **Storage**: At least 15GB free space
## Common Issues & Solutions
### 1. OpenCV `libGL` Import Error
**Issue**: The following error occurs when importing OpenCV (`import cv2`):
```
ImportError: libGL.so.1: cannot open shared object file: No such file or directory
```
**Solution**:
```bash
# For Ubuntu/Debian
sudo apt-get install libgl1-mesa-glx
```
### 2. Gaussian Splatting Antialiasing Issue
**Issue**: Error related to `antialiasing=True` setting in `GaussianRasterizationSettings`
**Solution**:
This issue arises due to updates in the Gaussian Splatting repository. Reinstalling the module from the GitHub repository with the latest version should resolve the problem.
================================================
FILE: lib/__init__.py
================================================
from .models import *
from .mmutils import *
from .humanlrm_wrapper_sa_v1 import SapiensGS_SA_v1
================================================
FILE: lib/datasets/__init__.py
================================================
from .avatar_dataset import AvatarDataset
from .dataloader import DataModuleFromConfig
================================================
FILE: lib/datasets/avatar_dataset.py
================================================
import os
import random
import numpy as np
import torch
import json
import pickle
from torch.utils.data import Dataset
import torchvision.transforms.functional as F
import webdataset as wds
# from lib.utils.train_util import print
import cv2
import av
from omegaconf import OmegaConf, ListConfig
def load_pose(path):
with open(path, 'rb') as f:
pose_param = json.load(f)
c2w = np.array(pose_param['cam_param'], dtype=np.float32).reshape(4,4)
cam_center = c2w[:3, 3]
w2c = np.linalg.inv(c2w)
# pose[:,:2] *= -1
# pose = np.loadtxt(path, dtype=np.float32, delimiter=' ').reshape(9, 4)
return [torch.from_numpy(w2c), torch.from_numpy(cam_center)]
def load_npy(file_path):
return np.load(file_path, allow_pickle=True)
def load_smpl(path, smpl_type='smpl'):
filetype = path.split('.')[-1]
with open(path, 'rb') as f:
if filetype=='pkl':
smpl_param_data = pickle.load(f)
elif filetype == 'json':
smpl_param_data = json.load(f)
else:
assert False
if smpl_type=='smpl':
with open(os.path.join(os.path.split(path)[0][:-5], 'pose', '000_000.json'), 'rb') as f:
tf_param = json.load(f)
smpl_param = np.concatenate([np.array(tf_param['scale']).reshape(1, -1), np.array(tf_param['center'])[None],
smpl_param_data['global_orient'], smpl_param_data['body_pose'].reshape(1, -1), smpl_param_data['betas']], axis=1)
elif smpl_type == 'smplx':
tf_param = np.load(os.path.join(os.path.dirname(os.path.dirname(path)), 'scale_offset.npy'), allow_pickle=True).item()
# smpl_param = np.concatenate([np.array([tf_param['scale']]).reshape(1, -1), tf_param['offset'].reshape(1, -1),
smpl_param = np.concatenate([np.array([[1]]), np.array([[0,0,0]]),
np.array(smpl_param_data['global_orient']).reshape(1, -1), np.array(smpl_param_data['body_pose']).reshape(1, -1),
np.array(smpl_param_data['betas']).reshape(1, -1), np.array(smpl_param_data['left_hand_pose']).reshape(1, -1), np.array(smpl_param_data['right_hand_pose']).reshape(1, -1),
np.array(smpl_param_data['jaw_pose']).reshape(1, -1), np.array(smpl_param_data['leye_pose']).reshape(1, -1), np.array(smpl_param_data['reye_pose']).reshape(1, -1),
np.array(smpl_param_data['expression']).reshape(1, -1)], axis=1)
else:
assert False
return torch.from_numpy(smpl_param.astype(np.float32)).reshape(-1)
class AvatarDataset(Dataset):
def __init__(self,
data_prefix,
code_dir=None,
# code_only=False,
load_imgs=True,
load_norm=False,
specific_observation_idcs=None,
specific_observation_num=None,
first_is_front=False, # yy add # If True, the returned randomly sampled batch places the front view first
better_range=False, # yy add # If True, views are not fully random but are selected with a better-spaced skip
if_include_video_ref_img=False, # yy add # whether to include reference images from the video
prob_include_video_ref_img=0.2, # yy add # probability of including the video's reference image
allow_k_angles_near_the_front=0, # yy add # if > 0, the front view may be selected from the range [front_view - allow_k_angles_near_the_front, front_view + allow_k_angles_near_the_front]
# num_test_imgs=0,
if_use_swap_face_v1=False, # yy add, if True, use the swap face v1
random_test_imgs=False,
scene_id_as_name=False,
cache_path=None,
cache_repeat=None, # be the same length with the cache_path
test_pose_override=None,
num_train_imgs=-1,
load_cond_data=True,
load_test_data=True,
max_num_scenes=-1, # for debug or testing
# radius=0.5,
radius=1.0,
img_res=[640, 896],
test_mode=False,
step=1, # only for debug & visualization purpose
crop=False # randomly crop the image with upper body inputs
):
super(AvatarDataset, self).__init__()
self.data_prefix = data_prefix
self.code_dir = code_dir
# self.code_only = code_only
self.load_imgs = load_imgs
self.load_norm = load_norm
self.specific_observation_idcs = specific_observation_idcs
self.specific_observation_num = specific_observation_num
self.first_is_front = first_is_front
self.if_include_video_ref_img= if_include_video_ref_img
self.prob_include_video_ref_img = prob_include_video_ref_img
self.allow_k_angles_near_the_front = allow_k_angles_near_the_front
self.better_range = better_range
# self.num_test_imgs = num_test_imgs
self.random_test_imgs = random_test_imgs
self.scene_id_as_name = scene_id_as_name
self.cache_path = cache_path
self.cache_repeat = cache_repeat
self.test_pose_override = test_pose_override
self.num_train_imgs = num_train_imgs
self.load_cond_data = load_cond_data
self.load_test_data = load_test_data
self.max_num_scenes = max_num_scenes
self.step = step
self.if_use_swap_face_v1 = if_use_swap_face_v1
# import ipdb; ipdb.set_trace()
self.img_res = [int(i) for i in img_res]
self.radius = torch.tensor([radius], dtype=torch.float32).expand(3)
self.center = torch.zeros_like(self.radius)
self.load_scenes()
self.crop = crop
self.test_poses = self.test_intrinsics = None
self.defalut_focal = 1120 #40 * (self.img_res[0]/32) # focal 80mm, sensor 32mm
self.default_fxy_cxy = torch.tensor([self.defalut_focal, self.defalut_focal, self.img_res[1]//2, self.img_res[0]//2]).reshape(1, 4)
self.test_mode = test_mode
if self.test_mode:
self.parse_scene = self.parse_scene_test
def load_scenes(self):
if isinstance(self.cache_path, ListConfig):
cache_list = []
case_num_per_dataset = 1000000000
for ii, path in enumerate(self.cache_path):
cache = np.load(path, allow_pickle=True)
if self.cache_repeat is not None:
cache = np.repeat(cache, self.cache_repeat[ii], axis=0)
print("done loading ", path)
cache_list.extend(cache[:case_num_per_dataset])
scenes = cache_list
print(f"=========initialized a total of {len(scenes)} scenes===========")
else:
if self.cache_path is not None and os.path.exists(self.cache_path):
scenes = np.load(self.cache_path, allow_pickle=True)
print("load ", self.cache_path)
else:
print(f"{self.cache_path} does not exist")
raise FileNotFoundError(f"{self.cache_path} does not exist")
end = len(scenes)
if self.max_num_scenes >= 0:
end = min(end, self.max_num_scenes * self.step)
self.scenes = scenes[:end:self.step]
self.num_scenes = len(self.scenes)
def parse_scene(self, scene_id):
scene = self.scenes[scene_id]
input_is_video = False # flag: whether the input is a video; some operations differ in that case
# print(scene)
# scene['video_path'] = "/data/jxlv/transformers/src/A_pose_MEN-Denim-id_00000089-01_7_additional/result.mp4"
# scene['image_paths'] = None #"/data/jxlv/transformers/src/A_pose_MEN-Denim-id_00000089-01_7_additional/source_seg.png"
# import pdb
# pdb.set_trace()
# =========== loading the params ===========
param = np.load(scene['param_path'], allow_pickle=True).item()
scene.update(param)
print(scene.keys())
# =========== loading the multi-view images ===========
if scene['image_paths'] is None:
input_is_video = True
video_path = scene['video_path']
try:
if self.if_use_swap_face_v1:
image_paths_or_video = read_frames(scene['video_path'].replace('result.mp4', 'output.mp4'))
else:
image_paths_or_video = read_frames(scene['video_path'])
except Exception as e:
print(f"Error: {e}")
print(f"Error in reading the video : {scene['video_path'].replace('result.mp4', 'output.mp4')}")
image_paths_or_video = read_frames(scene['video_path'])
# if 'pose_animate_service_0727' in video_path or 'flux' in video_path:
# # move the first to the last
# # TODO fixed this bug with a better cameras parameters
# image_paths_or_video = image_paths_or_video[1:] + image_paths_or_video[0:1]
if not input_is_video:
image_paths_or_video = scene['image_paths']
scene_name = f"{scene_id:0>4d}" # image_paths[0].split('/')[-3]
results = dict(
scene_id=[scene_id],
scene_name=
'{:04d}'.format(scene_id) if self.scene_id_as_name else scene_name,
# cpu_only=True
)
# import pdb; pdb.set_trace()
# if not self.code_only:
poses = scene['poses']
smpl_params = scene['smpl_params']
# if input_is_video:
# num_imgs = len(video)
# else:
num_imgs = len(image_paths_or_video)
# front_view = num_imgs // 4
# randomly / specifically select the output views
smplx_cam_rotate = smpl_params[4: 7] #get global orient # 1, 3, 63, 10
# smpl_params[70:80] = torch.rand_like(smpl_params[70:80]); print("error !! need to delete this rand betas in avatarnet:287") # get betas
front_view = find_front_camera_by_rotation(poses, smplx_cam_rotate) # inputs camera poses and smplx poses
if self.allow_k_angles_near_the_front > 0:
allow_n_views_near_the_front = round(self.allow_k_angles_near_the_front / 360 * num_imgs)
new_front_view = random.randint(-allow_n_views_near_the_front, allow_n_views_near_the_front) + front_view
if new_front_view >= num_imgs:
new_front_view = new_front_view - num_imgs
elif new_front_view < 0:
new_front_view = new_front_view + num_imgs
front_view = new_front_view
print("front view changed within range", front_view, "+-", allow_n_views_near_the_front)
if self.specific_observation_idcs is None: ######### if not specify views ########
# if self.num_train_imgs >= 0:
# num_train_imgs = self.num_train_imgs
# else:
num_train_imgs = num_imgs
if self.random_test_imgs: ###### randomly selected images with self.num_train_imgs ######
cond_inds = random.sample(range(num_imgs), self.num_train_imgs)
elif self.specific_observation_num: ###### randomly selected "specific_observation_num" images ######
if self.first_is_front and self.specific_observation_num < 2:
# self.specific_observation_num = 2
cond_inds =torch.tensor([front_view, front_view]) # first for input, second for supervised
elif self.better_range: # select views by a uniform distribution range
if self.first_is_front: # must include the front view
num_train_imgs = self.specific_observation_num - 2
else:
num_train_imgs = self.specific_observation_num
skip_range = num_imgs//num_train_imgs
# select one random view from each interval of size [skip_range], i.e. from [0, skip_range, 2*skip_range, ...]
cond_inds = torch.randperm(num_train_imgs) * skip_range \
+ torch.randint(low=0, high=skip_range, size=[num_train_imgs])
if self.first_is_front: # concat [the first view * 2] to the front of cond_inds
cond_inds = torch.cat([torch.tensor([front_view, front_view]), cond_inds])
else: # previous version, random views are sampled
cond_inds = torch.randperm(num_imgs)[:self.specific_observation_num]
else:
cond_inds = np.round(np.linspace(0, num_imgs - 1, num_train_imgs)).astype(np.int64)
else: ######### selected target views ########
cond_inds = self.specific_observation_idcs
test_inds = list(range(num_imgs))
if self.specific_observation_num: # yy note: if specific_observation_num is not None, then remove the test_inds
test_inds = []
else:
for cond_ind in cond_inds:
test_inds.remove(cond_ind)
cond_smpl_param_ref = torch.zeros([189])
if_use_smpl_param_ref = torch.Tensor([1]) # use the reference SMPL by default
if self.load_cond_data and len(cond_inds) > 0:
# cond_imgs, cond_poses, cond_intrinsics, cond_img_paths, cond_smpl_param, cond_norm = gather_imgs(cond_inds)
cond_imgs, cond_poses, cond_intrinsics, cond_img_paths, cond_smpl_param, cond_norm = \
gather_imgs(cond_inds, poses, image_paths_or_video, smpl_params, load_imgs=self.load_imgs, load_norm=self.load_norm, center=self.center, radius=self.radius,
input_is_video=input_is_video)
cond_smpl_param_ref = cond_smpl_param.clone() # the smpl_param_ref for the reference images
if cond_intrinsics.shape[-1] == 3: # the old data format, which contains the value of camera center instead of fxfycxcy
cond_intrinsics = self.default_fxy_cxy.clone().repeat(cond_intrinsics.shape[0], 1)
# import pdb; pdb.set_trace()
# print("video_path", video_path)
if input_is_video:
cond_img_paths = [f"{video_path[:-4]}_{i:0>4d}.png" for i in range(self.specific_observation_num)] # replace the .mp4 path with per-index .png paths
if self.if_include_video_ref_img and input_is_video:
# draw a random number; if it is below the given probability, replace the first image with the reference image
if np.random.rand() < self.prob_include_video_ref_img:
if 'image_ref' in scene:
ref_image_path = scene['image_ref']
print("ref_image_path",ref_image_path)
else:
ref_image_path = video_path.replace(".mp4", ".jpg")
# if "flux" in ref_image_path: # temporarily supports inputs from flux
try:
# replacement_img_path = ref_image_path
# replacement_img = load_image(ref_image_path) # assumes a helper that loads an image
# read with cv2.IMREAD_UNCHANGED to preserve the alpha channel
img = cv2.imread(ref_image_path, cv2.IMREAD_UNCHANGED)
assert img is not None, f"img is None, {ref_image_path}"
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = torch.from_numpy(img.astype(np.float32) / 255) # (h, w, 3)
# print("img.shape", img.shape)
# print("cond_imgs.shape", cond_imgs.shape)
# test_img_paths[0] = ref_image_path
# results.update(test_imgs=test_imgs, test_img_paths=test_img_paths)
# ======== loading the reference smplx for images ==========
load_ref_smplx = False
if load_ref_smplx:
# if "flux" in ref_image_path:
smplx_smplify_path = from_video_to_get_ref_smplx(video_path)
# load json and get values
with open(smplx_smplify_path) as f:
data = json.load(f)
RT = torch.concatenate([ torch.Tensor(data['camera']['R']), torch.Tensor(data['camera']['t']).reshape(3,1)], dim=1)
RT = torch.cat([RT, torch.Tensor([[0,0,0,1]])], dim=0)
intri = torch.Tensor(data['camera']['focal'] + data['camera']['princpt'])
intri[[3,2]] = intri[[2,3]]
intri = intri * self.default_fxy_cxy[0,-1] / intri[-1]
# assume smpl_param_data has already been loaded
# (['root_pose', 'body_pose', 'jaw_pose', 'leye_pose', 'reye_pose', 'lhand_pose', 'rhand_pose', 'expr', 'trans', 'betas_save', 'nlf_smplx_betas', 'camera', 'img_path'])
smpl_param_data = data
# extract the required fields from the dict
global_orient = np.array(smpl_param_data['root_pose']).reshape(1, -1)
body_pose = np.array(smpl_param_data['body_pose']).reshape(1, -1)
shape = np.array(smpl_param_data['betas_save']).reshape(1, -1)[:, :10]
left_hand_pose = np.array(smpl_param_data['lhand_pose']).reshape(1, -1)
right_hand_pose = np.array(smpl_param_data['rhand_pose']).reshape(1, -1)
# smpl_param_ref = np.concatenate([np.array([[1.]]), np.array([[0, 0, 0]]),
smpl_param_ref = np.concatenate([np.array([[1.]]), np.array(smpl_param_data['trans']).reshape(1,3),
global_orient,body_pose,
shape, left_hand_pose, right_hand_pose,
np.array(smpl_param_data['jaw_pose']).reshape(1, -1), np.array(smpl_param_data['leye_pose']).reshape(1, -1), np.array(smpl_param_data['reye_pose']).reshape(1, -1),
np.array(smpl_param_data['expr']).reshape(1, -1)[:,:10]], axis=1)
cond_poses[0] = RT # RT
cond_intrinsics[0] = intri # fxfycxcy
cond_smpl_param_ref = torch.Tensor(smpl_param_ref).reshape(-1) # 189, combines
if_use_smpl_param_ref = torch.Tensor([1]) # use the smpl_param_ref
# overwrite some datas
cond_imgs[0] = img
cond_img_paths[0] = ref_image_path
except (FileNotFoundError, json.JSONDecodeError, KeyError, Exception) as e:
# log the error message to a log file
# with open(self.log_file_path, 'a') as log_file:
# log_file.write(f"{datetime.datetime.now()} {video_path} \n - An error occurred: {str(e)}\n")
print(f"An error occurred: {e}")
''' randomly crop the first image for augmentation'''
if self.crop:
# print("crop", cond_imgs[0].shape)
if random.random() < 0.5:
# cond_imgs[0] = F.crop(cond_imgs[0], 0, 0, 512, 512)
crop_imgs = cond_imgs[0]
# image size
h, w, _ = crop_imgs.shape
# random offsets
random_offset_head = np.random.randint(-h//7, -h//8)
random_offset_body = np.random.randint(-h // 8, h // 8)
# head_joint, upper_body_joint
head_joint = [ w//2, h//7,]
upper_body_joint = [w//2, h//2, ]
# compute the crop region
head_y = int(head_joint[1]) + random_offset_head
body_y = int(upper_body_joint[1]) + random_offset_body
# clamp the crop region to the image bounds
head_y = max(0, min(h, head_y))
body_y = max(0, min(h, body_y))
# height and width of the crop region
crop_height = body_y - head_y
crop_width =int(crop_height * 640 / 896)
# clamp the crop region to the image bounds
start_x = max(0, min(w - crop_width, int(w // 2 - crop_width // 2)))
end_x = start_x + crop_width
start_y = max(0, head_y)
end_y = min(h, body_y)
# crop the image
cropped_img = crop_imgs[start_y:end_y, start_x:end_x]
padded_img = F.resize(cropped_img.permute(2, 0, 1), [h, w]).permute(1, 2, 0)
# save this img for debug
# Image.fromarray((padded_img.numpy() * 255).astype(np.uint8)).save(f"debug_crop.png")
# rescale the image for augmentation
cond_imgs[0] = random_scale_and_crop(padded_img, (0.8,1.2))
else:
cond_imgs[0] = random_scale_and_crop(cond_imgs[0], (0.8,1.1))
results.update(
cond_poses=cond_poses,
cond_intrinsics=cond_intrinsics.to(torch.float32),
cond_img_paths=cond_img_paths,
cond_smpl_param=cond_smpl_param,
cond_smpl_param_ref=cond_smpl_param_ref,
if_use_smpl_param_ref=if_use_smpl_param_ref)
if cond_imgs is not None:
results.update(cond_imgs=cond_imgs)
if cond_norm is not None:
results.update(cond_norm=cond_norm)
if self.load_test_data and len(test_inds) > 0:
test_imgs, test_poses, test_intrinsics, test_img_paths, test_smpl_param, test_norm = \
gather_imgs(test_inds, poses, image_paths_or_video, smpl_params, load_imgs=self.load_imgs, load_norm=self.load_norm, center=self.center, radius=self.radius)
if test_intrinsics.shape[-1] == 3: # the old data format, which contains the value of camera center instead of fxfycxcy
test_intrinsics = self.default_fxy_cxy.clone().repeat(test_intrinsics.shape[0], 1)
results.update(
test_poses=test_poses,
test_intrinsics=test_intrinsics,
test_img_paths=test_img_paths,
test_smpl_param=test_smpl_param)
if test_imgs is not None:
results.update(test_imgs=test_imgs)
if test_norm is not None:
results.update(test_norm=test_norm)
if self.test_pose_override is not None:
results.update(test_poses=self.test_poses, test_intrinsics=self.test_intrinsics)
return results
def __len__(self):
return self.num_scenes
def __getitem__(self, scene_id):
try:
scene = self.parse_scene(scene_id)
except Exception as e:
print(f"ERROR in parsing {scene_id}: {e}")
scene = self.parse_scene(0)
return scene
def parse_scene_test(self, scene_id):
scene = self.scenes[scene_id]
input_is_video = False # flag: whether the input is a video; some operations differ in that case
# =========== loading the params ===========
param = np.load(scene['param_path'], allow_pickle=True).item()
scene.update(param)
print(scene.keys())
# import ipdb; ipdb.set_trace()
if scene['image_paths'] is None:
input_is_video = True
video_path = scene['video_path']
try:
image_paths_or_video = read_frames(scene['video_path'])
except Exception as e:
print(f"Error: {e}")
print(f"Error in reading the video : {scene['video_path'].replace('result.mp4', 'output.mp4')}")
image_paths_or_video = read_frames(scene['video_path'])
if not input_is_video:
image_paths_or_video = scene['image_paths']
scene_name = f"{scene_id:0>4d}" # image_paths[0].split('/')[-3]
results = dict(
scene_id=[scene_id],
scene_name=
'{:04d}'.format(scene_id) if self.scene_id_as_name else scene_name,
# cpu_only=True
)
# if not self.code_only:
poses = scene['poses']
smpl_params = scene['smpl_params']
# if input_is_video:
# num_imgs = len(video)
# else:
num_imgs = len(image_paths_or_video)
# front_view = num_imgs // 4
# randomly / specifically select the output views
smplx_cam_rotate = smpl_params[4: 7] #get global orient # 1, 3, 63, 10
# smpl_params[70:80] = torch.rand_like(smpl_params[70:80]); print("error !! need to delete this rand betas in avatarnet:287") # get betas
front_view = find_front_camera_by_rotation(poses, smplx_cam_rotate) # inputs camera poses and smplx poses
if self.allow_k_angles_near_the_front > 0:
allow_n_views_near_the_front = round(self.allow_k_angles_near_the_front / 360 * num_imgs)
new_front_view = random.randint(-allow_n_views_near_the_front, allow_n_views_near_the_front) + front_view
if new_front_view >= num_imgs:
new_front_view = new_front_view - num_imgs
elif new_front_view < 0:
new_front_view = new_front_view + num_imgs
front_view = new_front_view
print("front view changed within range", front_view, "+-", allow_n_views_near_the_front)
num_train_imgs = num_imgs
test_inds = torch.Tensor( list(range(num_imgs)))
cond_inds = np.concatenate([np.array([front_view]),test_inds]).astype(np.int64) # first for input, second for supervised
test_inds = cond_inds.tolist()
# if self.specific_observation_num: # yy note: if specific_observation_num is not None, then remove the test_inds
# test_inds = []
# else:
# for cond_ind in cond_inds:
# test_inds.remove(cond_ind)
# cond_inds = cond_inds
cond_smpl_param_ref = torch.zeros([189])
if_use_smpl_param_ref = torch.Tensor([1]) # use the reference SMPL by default
if self.load_cond_data and len(cond_inds) > 0:
# cond_imgs, cond_poses, cond_intrinsics, cond_img_paths, cond_smpl_param, cond_norm = gather_imgs(cond_inds)
cond_imgs, cond_poses, cond_intrinsics, cond_img_paths, cond_smpl_param, cond_norm = \
gather_imgs(cond_inds, poses, image_paths_or_video, smpl_params, load_imgs=self.load_imgs, load_norm=self.load_norm, center=self.center, radius=self.radius, \
input_is_video=input_is_video)
cond_smpl_param_ref = cond_smpl_param.clone() # the smpl_param_ref for the reference images
if cond_intrinsics.shape[-1] == 3: # the old data format, which contains the value of camera center instead of fxfycxcy
cond_intrinsics = self.default_fxy_cxy.clone().repeat(cond_intrinsics.shape[0], 1)
if input_is_video:
cond_img_paths = [f"{video_path[:-4]}_{i:0>4d}.png" for i in range(self.specific_observation_num)] # replace the .mp4 path with per-index .png paths
results.update(
cond_poses=cond_poses,
cond_intrinsics=cond_intrinsics.to(torch.float32),
cond_img_paths=cond_img_paths,
cond_smpl_param=cond_smpl_param,
cond_smpl_param_ref=cond_smpl_param_ref,
if_use_smpl_param_ref=if_use_smpl_param_ref)
if cond_imgs is not None:
results.update(cond_imgs=cond_imgs)
if cond_norm is not None:
results.update(cond_norm=cond_norm)
if self.load_test_data and len(test_inds) > 0:
print("input_is_video", input_is_video)
test_imgs, test_poses, test_intrinsics, test_img_paths, test_smpl_param, test_norm = \
gather_imgs(test_inds, poses, image_paths_or_video, smpl_params, load_imgs=self.load_imgs, load_norm=self.load_norm, center=self.center, radius=self.radius, \
input_is_video=input_is_video)
if test_intrinsics.shape[-1] == 3: # the old data format, which contains the value of camera center instead of fxfycxcy
test_intrinsics = self.default_fxy_cxy.clone().repeat(test_intrinsics.shape[0], 1)
results.update(
test_poses=test_poses,
test_intrinsics=test_intrinsics,
test_img_paths=test_img_paths,
test_smpl_param=test_smpl_param)
if test_imgs is not None:
results.update(test_imgs=test_imgs)
if test_norm is not None:
results.update(test_norm=test_norm)
if self.test_pose_override is not None:
results.update(test_poses=self.test_poses, test_intrinsics=self.test_intrinsics)
return results
def gather_imgs(img_ids, poses, image_paths_or_video, smpl_params, load_imgs=True, load_norm=False, center=None, radius=None, input_is_video=False):
imgs_list = [] if load_imgs else None
norm_list = [] if load_norm else None
poses_list = []
cam_centers_list = []
img_paths_list = []
for img_id in img_ids:
pose = poses[img_id][0]
cam_centers_list.append((poses[img_id][1]).to(torch.float)) # (C)
c2w = pose.to(torch.float) # named c2w, but the stored values are actually w2c (R|T)
cam_to_ndc = torch.cat(
[c2w[:3, :3], (c2w[:3, 3:] - center[:, None]) / radius[:, None]], dim=-1)
poses_list.append(
torch.cat([
cam_to_ndc,
cam_to_ndc.new_tensor([[0.0, 0.0, 0.0, 1.0]])
], dim=-2))
if input_is_video:
# img_paths_list.append(video[img_id])
img = image_paths_or_video[img_id]
# force near-white pixels (RGB > 250) to pure white
mask_white = np.all(img[:,:,:3] > 250, axis=-1, keepdims=False)
img[mask_white] = [255, 255, 255]
img = torch.from_numpy(img.astype(np.float32) / 255) # (h, w, 3)
imgs_list.append(img)
else:
img_paths_list.append(image_paths_or_video[img_id])
if load_imgs:
# read with cv2.IMREAD_UNCHANGED to keep the alpha channel
print("Loading .......", image_paths_or_video[img_id])
img = cv2.imread(image_paths_or_video[img_id], cv2.IMREAD_UNCHANGED)
# set the RGB of fully transparent pixels to white (255, 255, 255)
img[img[..., 3] == 0] = [255, 255, 255, 255]
img = img[..., :3]
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = torch.from_numpy(img.astype(np.float32) / 255) # (h, w, 3)
imgs_list.append(img)
if load_norm: # video input is not supported here yet
norm = cv2.imread(image_paths_or_video[img_id].replace('rgb', 'norm'), cv2.IMREAD_UNCHANGED)
norm = cv2.cvtColor(norm, cv2.COLOR_BGR2RGB)
norm = torch.from_numpy(norm.astype(np.float32) / 255)
norm_list.append(norm)
poses_list = torch.stack(poses_list, dim=0) # (n, 4, 4)
cam_centers_list = torch.stack(cam_centers_list, dim=0)
if load_imgs:
imgs_list = torch.stack(imgs_list, dim=0) # (n, h, w, 3)
if load_norm:
norm_list = torch.stack(norm_list, dim=0)
return imgs_list, poses_list, cam_centers_list, img_paths_list, smpl_params, norm_list
from scipy.spatial.transform import Rotation as R
def calculate_angle(vector1, vector2):
unit_vector1 = vector1 / torch.linalg.norm(vector1)
unit_vector2 = vector2 / torch.linalg.norm(vector2)
dot_product = torch.dot(unit_vector1, unit_vector2)
angle = torch.arccos(dot_product)
return angle
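For reference, the dot-product angle computation above can be checked with a minimal numpy sketch (standalone, not repository code; the `np.clip` guard is an addition to avoid `arccos` domain errors from floating-point rounding):

```python
import numpy as np

def angle_between(v1, v2):
    # Normalize, dot, arccos -- the same steps as calculate_angle above.
    u1 = v1 / np.linalg.norm(v1)
    u2 = v2 / np.linalg.norm(v2)
    # np.clip guards against arccos domain errors from rounding.
    return np.arccos(np.clip(np.dot(u1, u2), -1.0, 1.0))

angle_between(np.array([1.0, 0.0, 0.0]), np.array([0.0, 1.0, 0.0]))  # -> pi / 2
```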
def axis_angle_to_rotation_matrix(axis_angle):
if_input_is_torch = torch.is_tensor(axis_angle)
if if_input_is_torch:
dtype_torch = axis_angle.dtype
axis_angle = axis_angle.numpy()
r = R.from_rotvec(axis_angle)
rotation_matrix = r.as_matrix()
if if_input_is_torch:
rotation_matrix = torch.from_numpy(rotation_matrix).to(dtype_torch)
return rotation_matrix
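Under the hood, scipy's `R.from_rotvec` applies Rodrigues' rotation formula. A standalone numpy sketch of that conversion (`rotvec_to_matrix` is a hypothetical helper, not repository code):

```python
import numpy as np

def rotvec_to_matrix(axis_angle):
    # Rodrigues' rotation formula -- the conversion scipy performs
    # inside axis_angle_to_rotation_matrix above.
    angle = np.linalg.norm(axis_angle)
    if angle < 1e-12:
        return np.eye(3)
    k = axis_angle / angle  # unit rotation axis
    K = np.array([[0.0, -k[2], k[1]],
                  [k[2], 0.0, -k[0]],
                  [-k[1], k[0], 0.0]])  # cross-product matrix of k
    return np.eye(3) + np.sin(angle) * K + (1.0 - np.cos(angle)) * (K @ K)

# 90 degrees about +z maps the x-axis onto the y-axis.
rotated = rotvec_to_matrix(np.array([0.0, 0.0, np.pi / 2])) @ np.array([1.0, 0.0, 0.0])
```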
def find_front_camera_by_rotation(poses, global_orient_human):
rotation_matrix = axis_angle_to_rotation_matrix(global_orient_human)
front_direction = rotation_matrix @ torch.Tensor([0, 0, -1]) # forward direction of the body
min_angle = float('inf')
front_camera_idx = -1
for idx, pose in enumerate(poses):
rotation_matrix = pose[0][:3, :3]
camera_direction = rotation_matrix @ torch.Tensor([0, 0, 1]) # viewing direction of the camera
angle = calculate_angle(camera_direction, front_direction).to(camera_direction.dtype)
if angle < min_angle:
min_angle = angle
front_camera_idx = idx
return front_camera_idx
def read_frames(video_path):
container = av.open(video_path)
video_stream = next(s for s in container.streams if s.type == "video")
frames = []
for packet in container.demux(video_stream):
for frame in packet.decode():
# image = Image.frombytes(
# "RGB",
# (frame.width, frame.height),
# frame.to_rgb().to_ndarray(),
# )
image = frame.to_rgb().to_ndarray()
frames.append(image)
return frames
def prepare_camera(resolution_x, resolution_y, num_views=24, strides=1):
import math
focal_length = 40 # in mm
sensor_width = 32
# convert the focal length from mm to pixels
focal_length = focal_length * (resolution_y/sensor_width)
K = np.array(
[[focal_length, 0, resolution_x//2],
[0, focal_length, resolution_y//2],
[0, 0, 1]]
)
def look_at(camera_position, target_position, up_vector): # colmap convention: +z forward, +y down
forward = -(camera_position - target_position) / np.linalg.norm(camera_position - target_position)
right = np.cross(up_vector, forward)
up = np.cross(forward, right)
return np.column_stack((right, up, forward))
camera_pose_list = []
for frame_idx in range(0, num_views, strides):
# place the camera on a circle around the subject
camera_dist = 1.5
phi = math.radians(90)
theta = (frame_idx / num_views) * math.pi * 2
camera_location = np.array(
[camera_dist * math.sin(phi) * math.cos(theta),
camera_dist * math.cos(phi),
-camera_dist * math.sin(phi) * math.sin(theta)]
)
camera_pose = np.eye(4)
# orient the camera towards the origin (colmap convention, so up is -y)
target_position = np.array([0.0, 0.0, 0.0])
up_vector = np.array([0.0, -1.0, 0.0]) # colmap
rotation_matrix = look_at(camera_location, target_position, up_vector)
camera_pose[:3, :3] = rotation_matrix
camera_pose[:3, 3] = camera_location
camera_pose_list.append(camera_pose)
return K, camera_pose_list
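The intrinsics arithmetic above converts a focal length in millimetres to pixels via the sensor size. A standalone sketch with the defaults used here (40 mm lens, 32 mm sensor, 640x896 render), not repository code:

```python
import math
import numpy as np

# Pixel focal length, mirroring the arithmetic inside prepare_camera above.
focal_mm, sensor_mm, res_x, res_y = 40.0, 32.0, 640, 896
focal_px = focal_mm * (res_y / sensor_mm)  # 40 * 896 / 32 = 1120.0
K = np.array([[focal_px, 0, res_x // 2],
              [0, focal_px, res_y // 2],
              [0, 0, 1.0]])

# Each view sits on a horizontal circle of radius 1.5 around the origin.
camera_dist, phi, theta = 1.5, math.radians(90), math.radians(45)
loc = np.array([camera_dist * math.sin(phi) * math.cos(theta),
                camera_dist * math.cos(phi),
                -camera_dist * math.sin(phi) * math.sin(theta)])
```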
def from_video_to_get_ref_smplx(video_path):
# derive the smplify JSON path from the video path
video_dir = os.path.dirname(video_path)
video_name = video_dir.split("/")[-1] # name of the video folder
# map the video directory to the corresponding smplx_smplify directory
if "flux" in video_dir:
smplify_dir = video_dir.replace('/videos/', '/smplx_smplify/')
elif "DeepFashion" in video_dir:
smplify_dir = video_dir.replace('/video/', '/smplx_smplify/').replace("A_pose_", "")
# build the JSON file path
if smplify_dir[-1] == '/': smplify_dir = smplify_dir[:-1]
json_path = smplify_dir+".json"
return json_path
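The path mapping above can be illustrated with a standalone sketch of the "flux" branch (`video_to_smplify_json` and the example path are hypothetical, mirroring the string replacement above):

```python
import os

def video_to_smplify_json(video_path: str) -> str:
    # Mirrors from_video_to_get_ref_smplx for the "flux" layout:
    # .../videos/<name>/clip.mp4 -> .../smplx_smplify/<name>.json
    video_dir = os.path.dirname(video_path)
    smplify_dir = video_dir.replace('/videos/', '/smplx_smplify/').rstrip('/')
    return smplify_dir + ".json"

video_to_smplify_json("/data/flux/videos/subj01/clip.mp4")
# -> "/data/flux/smplx_smplify/subj01.json"
```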
def random_scale_and_crop(image: torch.Tensor, scale_range=(0.8, 1.2)) -> torch.Tensor:
"""
Randomly scale the input image and crop/pad to maintain original size.
Args:
image: Input image tensor of shape [H, W, 3]
scale_range: Range for scaling factor, default (0.8, 1.2)
Returns:
Scaled and cropped/padded image tensor of shape [H, W, 3]
"""
is_numpy = False
if not torch.is_tensor(image):
image = torch.from_numpy(image)
is_numpy = True
# image height and width
h, w = image.shape[:2]
# draw a random scaling factor
scale_factor = random.uniform(*scale_range)
# compute the scaled size
new_h = int(h * scale_factor)
new_w = int(w * scale_factor)
# resize with torchvision.transforms.functional.resize
scaled_image = F.resize(image.permute(2, 0, 1), [new_h, new_w]).permute(1, 2, 0)
# if the scaled image is larger than the original, centre-crop it
if new_h > h or new_w > w:
top = (new_h - h) // 2
left = (new_w - w) // 2
scaled_image = scaled_image[top:top + h, left:left + w]
else:
# if the scaled image is smaller, pad it with white
padded_image = torch.ones((h, w, 3), dtype=image.dtype)
top = h-new_h # bottom-aligned vertically (not centred), centred horizontally
left = (w - new_w) // 2
padded_image[top:top + new_h, left:left + new_w] = scaled_image
scaled_image = padded_image
if is_numpy:
scaled_image = scaled_image.numpy()
return scaled_image
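The size-preserving behaviour of `random_scale_and_crop` (resize, then centre-crop or bottom-aligned white padding back to the original resolution) can be sketched in plain numpy with a fixed scale; this is a simplified analogue using nearest-neighbour resize, not the torchvision path above:

```python
import numpy as np

def scale_keep_size(image: np.ndarray, scale: float) -> np.ndarray:
    # Nearest-neighbour resize, then centre-crop (if larger) or
    # bottom-aligned white padding (if smaller) back to the original (h, w).
    h, w = image.shape[:2]
    nh, nw = int(h * scale), int(w * scale)
    ys = (np.arange(nh) * h // nh).clip(0, h - 1)
    xs = (np.arange(nw) * w // nw).clip(0, w - 1)
    scaled = image[ys][:, xs]
    if nh > h or nw > w:
        top, left = (nh - h) // 2, (nw - w) // 2
        return scaled[top:top + h, left:left + w]
    out = np.ones((h, w, 3), dtype=image.dtype)  # white background
    top, left = h - nh, (w - nw) // 2            # bottom-aligned, as above
    out[top:top + nh, left:left + nw] = scaled
    return out

out = scale_keep_size(np.random.rand(8, 6, 3).astype(np.float32), 0.8)
```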
if __name__ == "__main__":
import os
params = {
"data_prefix": None,
"cache_path": ListConfig([
"./processed_data/deepfashion_train_145_local.npy",
"./processed_data/flux_batch1_5000_train_145_local.npy"
]),
"specific_observation_num": 5,
"better_range": True,
"first_is_front": True,
"if_include_video_ref_img": True,
"prob_include_video_ref_img": 0.5,
"img_res": [640, 896],
'test_mode': True
}
data = AvatarDataset(**params)
sample = data[0]
print(sample.keys())
import os
import torch.distributed as dist
os.environ['RANK'] = '0'
os.environ['WORLD_SIZE'] = '1'
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29500'
dist.init_process_group(backend='nccl', rank=0, world_size=1)
# test the batch loader
from torch.utils.data import DataLoader
dataloader = DataLoader(data, batch_size=10, shuffle=True, collate_fn=custom_collate_fn)
from torch.utils.data.distributed import DistributedSampler
import webdataset as wds
# sampler = DistributedSampler(data) # training is true!~
sampler = None
dataloader = wds.WebLoader(data, batch_size=10, num_workers=1, shuffle=False, sampler=sampler, )
try:
for i, batch in enumerate(dataloader):
print(batch.keys())
# break
except Exception as e:
import traceback
print("Caught an exception during dataloader iteration:")
traceback.print_exc()
================================================
FILE: lib/datasets/dataloader.py
================================================
import os, sys
import json
import numpy as np
import webdataset as wds
import pytorch_lightning as pl
import torch
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler
file_dir = os.path.abspath(os.path.dirname(__file__))
project_root = os.path.join(file_dir, '..', '..')
sys.path.append(project_root)
from lib.utils.train_util import instantiate_from_config
from torch.utils.data import DataLoader
class DataModuleFromConfig(pl.LightningDataModule):
def __init__(
self,
batch_size=8,
num_workers=4,
train=None,
validation=None,
test=None,
**kwargs,
):
super().__init__()
self.batch_size = batch_size
self.num_workers = num_workers
self.dataset_configs = dict()
if train is not None:
self.dataset_configs['train'] = train
if validation is not None:
self.dataset_configs['validation'] = validation
if test is not None:
self.dataset_configs['test'] = test
def setup(self, stage):
self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)
def train_dataloader(self):
sampler = DistributedSampler(self.datasets['train']) if torch.distributed.is_initialized() else None
return DataLoader(
self.datasets['train'],
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=(sampler is None),
sampler=sampler,
pin_memory=True,
drop_last=True,
)
def val_dataloader(self):
sampler = DistributedSampler(self.datasets['validation']) if torch.distributed.is_initialized() else None
return DataLoader(
self.datasets['validation'],
batch_size=1,
num_workers=self.num_workers,
shuffle=False,
sampler=sampler,
pin_memory=True
)
def test_dataloader(self):
return DataLoader(
self.datasets['test'],
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False,
pin_memory=True
)
================================================
FILE: lib/humanlrm_wrapper_sa_v1.py
================================================
import os
import math
import json
from torch.optim import Adam
from torch.nn.parallel.distributed import DistributedDataParallel
import torch
import torch.nn.functional as F
from torchvision.transforms import InterpolationMode
from torchvision.utils import make_grid, save_image
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure
import pytorch_lightning as pl
from pytorch_lightning.utilities.grads import grad_norm
from einops import rearrange, repeat
from lib.utils.train_util import instantiate_from_config
from lib.ops.activation import TruncExp
import time
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from lib.utils.train_util import main_print
from typing import List, Optional, Tuple, Union
def get_1d_rotary_pos_embed(
dim: int,
pos: Union[torch.Tensor, int],
theta: float = 10000.0,
use_real=False,
linear_factor=1.0,
ntk_factor=1.0,
repeat_interleave_real=True,
):
"""
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
data type.
Args:
dim (`int`): Dimension of the frequency tensor.
pos (`torch.Tensor` or `int`): Position indices for the frequency tensor. [S] or scalar
theta (`float`, *optional*, defaults to 10000.0):
Scaling factor for frequency computation. Defaults to 10000.0.
use_real (`bool`, *optional*):
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
linear_factor (`float`, *optional*, defaults to 1.0):
Scaling factor for the context extrapolation. Defaults to 1.0.
ntk_factor (`float`, *optional*, defaults to 1.0):
Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
repeat_interleave_real (`bool`, *optional*, defaults to `True`):
If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
Otherwise, they are concatenated with themselves.
Returns:
`torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
"""
assert dim % 2 == 0
if isinstance(pos, int):
pos = torch.arange(pos)
theta = theta * ntk_factor
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) / linear_factor # [D/2]
t = pos # torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
freqs = freqs.to(device=t.device, dtype=t.dtype) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2]
if use_real and repeat_interleave_real:
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
return freqs_cos, freqs_sin
elif use_real:
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1) # [S, D]
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1) # [S, D]
return freqs_cos, freqs_sin
else:
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
return freqs_cis
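A minimal numpy sketch of the interleaved-real branch above (`use_real=True`, `repeat_interleave_real=True`), standalone and showing only the shapes involved:

```python
import numpy as np

dim, seq_len, theta = 8, 5, 10000.0
pos = np.arange(seq_len, dtype=np.float64)           # positions [S]
freqs = 1.0 / theta ** (np.arange(0, dim, 2) / dim)  # base frequencies [D/2]
angles = np.outer(pos, freqs)                        # [S, D/2]
freqs_cos = np.repeat(np.cos(angles), 2, axis=1)     # interleaved [S, D]
freqs_sin = np.repeat(np.sin(angles), 2, axis=1)     # interleaved [S, D]
```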
class FluxPosEmbed(torch.nn.Module):
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
cos_out = []
sin_out = []
pos = ids.float()
is_mps = ids.device.type == "mps"
freqs_dtype = torch.float32 if is_mps else torch.float64
for i in range(n_axes):
cos, sin = get_1d_rotary_pos_embed(
self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True#, freqs_dtype=freqs_dtype
)
cos_out.append(cos)
sin_out.append(sin)
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
return freqs_cos, freqs_sin
class SapiensGS_SA_v1(pl.LightningModule):
def __init__(
self,
encoder=dict(type='mmpretrain.VisionTransformer'),
neck=dict(type='mmpretrain.VisionTransformer'),
decoder=dict(),
diffusion_use_ema=True,
freeze_decoder=False,
image_cond=False,
code_permute=None,
code_reshape=None,
autocast_dtype=None,
ortho=True,
return_norm=False,
# reshape_type='reshape', # 'cnn'
code_size=None,
decoder_use_ema=None,
bg_color=1,
training_mode=None, # stage2's flag, default None for stage1
patch_size: int = 4,
warmup_steps: int = 12_000,
use_checkpoint: bool = True,
lambda_depth_tv: float = 0.05,
lambda_lpips: float = 2.0,
lambda_mse: float = 1.0,
lambda_l1: float=0,
lambda_ssim: float=0,
neck_learning_rate: float = 5e-4,
decoder_learning_rate: float = 1e-3,
encoder_learning_rate: float=0,
max_steps: int = 100_000,
loss_coef: float = 0.5,
init_iter: int = 500,
lambda_offset: int = 50, # offset_weight: 50
scale_weight: float = 0.01,
is_debug: bool = False, # if debug, then it will not returns lpips
code_activation: dict=None,
output_hidden_states: bool=False, # if True, will output the hidden states from sapiens shallow layer, for the neck decoder
loss_weights_views: List = [], # the loss weights for the views, if empty, will use the same weights for all the views
**kwargs
):
super(SapiensGS_SA_v1, self).__init__()
## ========== part -- Add the code to save this parameters for optimizers ========
self.warmup_steps = warmup_steps
self.use_checkpoint = use_checkpoint
self.lambda_depth_tv = lambda_depth_tv
self.lambda_lpips = lambda_lpips
self.lambda_mse = lambda_mse
self.lambda_l1 = lambda_l1
self.lambda_ssim = lambda_ssim
self.neck_learning_rate = neck_learning_rate
self.decoder_learning_rate = decoder_learning_rate
self.encoder_learning_rate = encoder_learning_rate
self.max_steps = max_steps
self.loss_coef = loss_coef
self.init_iter = init_iter
self.lambda_offset = lambda_offset
self.scale_weight = scale_weight
self.is_debug = is_debug
## ========== end part ========
self.code_size = code_size
if code_activation['type'] == 'tanh':
self.code_activation = torch.nn.Tanh()
else:
self.code_activation = TruncExp() #build_module(code_activation)
# self.grid_size = grid_size
self.decoder = instantiate_from_config(decoder)
self.decoder_use_ema = decoder_use_ema
if decoder_use_ema:
raise NotImplementedError("decoder_use_ema has not been implemented")
if self.decoder_use_ema:
self.decoder_ema = deepcopy(self.decoder)
self.encoder = instantiate_from_config(encoder)
# get_obj_from_str(config["target"])
self.code_size = code_reshape
self.code_clip_range = [-1,1]
# ============= begin config =============
# transformer from class MAEPretrainDecoder(BaseModule):
# compress the token number of the uv code
self.patch_size = patch_size
self.code_patch_size = self.patch_size
self.num_patches_axis = code_reshape[-1]//self.patch_size # reshape it for the upsampling
self.num_patches = self.num_patches_axis ** 2
self.code_feat_dims = code_reshape[0] # only used for the upsampling of 'reshape' type
self.code_resolution = code_reshape[-1] # only used for the upsampling of 'reshape' type
self.reshape_type = self.decoder.reshape_type
self.inputs_front_only = True
self.render_loss_all_view = True
self.if_include_video_ref_img = True
self.training_mode = training_mode
self.loss_weights_views = torch.Tensor(loss_weights_views).reshape(-1) / sum(loss_weights_views) # normalize the weights
# ========== config meaning ===========
self.neck = instantiate_from_config(neck)
self.ids_restore = torch.arange(0, self.num_patches).unsqueeze(0)
self.freeze_decoder = freeze_decoder
if self.freeze_decoder:
self.decoder.requires_grad_(False)
if self.decoder_use_ema:
self.decoder_ema.requires_grad_(False)
self.image_cond = image_cond
self.code_permute = code_permute
self.code_reshape = code_reshape
self.code_reshape_inv = [self.code_size[axis] for axis in self.code_permute] if code_permute is not None \
else self.code_size
self.code_permute_inv = [self.code_permute.index(axis) for axis in range(len(self.code_permute))] \
if code_permute is not None else None
self.autocast_dtype = autocast_dtype
self.ortho = ortho
self.return_norm = return_norm
'''add a flag for the skip connection from sapiens shallow layer'''
self.output_hidden_states = output_hidden_states
''' add the in-the-wild images visualization'''
if self.lambda_lpips > 0:
self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg')
else:
self.lpips = None
self.ssim = StructuralSimilarityIndexMeasure()
self.validation_step_outputs = []
self.validation_step_code_outputs = [] # saving the code
self.validation_step_nvPose_outputs = []
self.validation_metrics = []
# loading the smplx for the nv pose
import json
import numpy as np
# evaluate the animation
smplx_path = './work_dirs/demo_data/Ways_to_Catch_360_clip1.json'
with open(smplx_path, 'r') as f:
smplx_pose_param = json.load(f)
smplx_param_list = []
for par in smplx_pose_param['annotations']:
k = par['smplx_params']
for i in k.keys():
k[i] = np.array(k[i])
left_hands = np.array([1.4624, -0.1615, 0.1361, 1.3851, -0.2597, 0.0247, -0.0683, -0.4478,
-0.6652, -0.7290, 0.0084, -0.4818])
betas = torch.zeros((10))
smplx_param = \
np.concatenate([np.array([1]), np.array([0,0.,0]), np.array([0, -1, 0])*k['root_orient'], \
k['pose_body'],betas, \
k['pose_hand'], k['pose_jaw'], np.zeros(6), k['face_expr'][:10]], axis=0).reshape(1,-1)
# print(smplx_param.shape)
smplx_param_list.append(smplx_param)
smplx_params = np.concatenate(smplx_param_list, 0)
self.smplx_params = torch.Tensor(smplx_params).cuda()
def get_default_smplx_params(self):
A_pose = torch.Tensor([[ 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.1047, 0.0000, 0.0000, -0.1047, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, -0.7854, 0.0000,
0.0000, 0.7854, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, -0.7470, 1.0966,
0.0169, -0.0534, -0.0212, 0.0782, -0.0348, 0.0260, 0.0060, 0.0118,
-0.1117, -0.0429, 0.4164, -0.1088, 0.0660, 0.7562, 0.0964, 0.0909,
0.1885, 0.1181, -0.0509, 0.5296, 0.1437, -0.0552, 0.7049, 0.0192,
0.0923, 0.3379, 0.4570, 0.1963, 0.6255, 0.2147, 0.0660, 0.5069,
0.3697, 0.0603, 0.0795, 0.1419, 0.0859, 0.6355, 0.3033, 0.0579,
0.6314, 0.1761, 0.1321, 0.3734, -0.8510, -0.2769, 0.0915, 0.4998,
-0.0266, -0.0529, -0.5356, -0.0460, 0.2774, -0.1117, 0.0429, -0.4164,
-0.1088, -0.0660, -0.7562, 0.0964, -0.0909, -0.1885, 0.1181, 0.0509,
-0.5296, 0.1437, 0.0552, -0.7049, 0.0192, -0.0923, -0.3379, 0.4570,
-0.1963, -0.6255, 0.2147, -0.0660, -0.5069, 0.3697, -0.0603, -0.0795,
0.1419, -0.0859, -0.6355, 0.3033, -0.0579, -0.6314, 0.1761, -0.1321,
-0.3734, -0.8510, 0.2769, -0.0915, 0.4998, 0.0266, 0.0529, -0.5356,
0.0460, -0.2774, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])
return A_pose
def forward_decoder(self, decoder, code, target_rgbs, cameras,
smpl_params=None, return_decoder_loss=False, init=False):
decoder = self.decoder_ema if self.freeze_decoder and self.decoder_use_ema else self.decoder
num_imgs = target_rgbs.shape[1]
outputs = decoder(
code, smpl_params, cameras,
num_imgs, return_loss=return_decoder_loss, init=init, return_norm=False)
return outputs
def on_fit_start(self):
if self.global_rank == 0:
os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True)
os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True)
os.makedirs(os.path.join(self.logdir, 'images_val_code'), exist_ok=True)
def forward(self, data):
# print("iter")
num_scenes = len(data['scene_id']) # 8
if 'cond_imgs' in data:
cond_imgs = data['cond_imgs'] # (num_scenes, num_imgs, h, w, 3)
cond_intrinsics = data['cond_intrinsics'] # (num_scenes, num_imgs, 4), in [fx, fy, cx, cy]
cond_poses = data['cond_poses'] # (num_scenes, num_imgs, 4, 4)
smpl_params = data['cond_smpl_param'] # (num_scenes, c)
# if 'cond_norm' in data:cond_norm = data['cond_norm'] else: cond_norm = None
num_scenes, num_imgs, h, w, _ = cond_imgs.size()
cameras = torch.cat([cond_intrinsics, cond_poses.reshape(num_scenes, num_imgs, -1)], dim=-1)
if self.if_include_video_ref_img: # the first camera is the input view; drop it so no loss is computed on the input image
cameras = cameras[:,1:]
num_imgs = num_imgs - 1
if self.inputs_front_only: # default setting to use the first image as input
inputs_img_idx = [0]
else:
raise NotImplementedError("inputs_front_only is False")
inputs_img = cond_imgs[:,inputs_img_idx[0],...].permute([0,3,1,2]) #
target_imgs = cond_imgs[:, 1:]
assert cameras.shape[1] == target_imgs.shape[1]
if self.is_debug:
try:
code = self.forward_image_to_uv(inputs_img, is_training=self.training) #TODO check where the validation
except Exception as e: # OOM
main_print(e)
code = torch.zeros([num_scenes, 32, 256, 256]).to(inputs_img.dtype).to(inputs_img.device)
else:
code = self.forward_image_to_uv(inputs_img, is_training=self.training) #TODO check where the validation
decoder = self.decoder_ema if self.freeze_decoder and self.decoder_use_ema else self.decoder
# uvmaps_decoder_gender's forward
output = decoder(
code, smpl_params, cameras,
num_imgs, return_loss=False, init=(self.global_step < self.init_iter), return_norm=False) #(['scales', 'norm', 'image', 'offset'])
output['code'] = code
output['target_imgs'] = target_imgs
output['inputs_img'] = cond_imgs[:,[0],...]
# for visualization
if self.global_rank == 0 and self.global_step % 200 == 0 and self.is_debug:
overlay_imgs = 0.5 * target_imgs + 0.5 * output['image']
overlay_imgs = rearrange(overlay_imgs, 'b n h w c -> b h n w c')
overlay_imgs = rearrange(overlay_imgs, ' b h n w c -> (b h) (n w) c')
overlay_imgs = overlay_imgs.to(torch.float32).detach().cpu().numpy()
overlay_imgs = (overlay_imgs * 255).astype(np.uint8)
Image.fromarray(overlay_imgs).save(f'debug_{self.global_step}.jpg')
return output
def forward_image_to_uv(self, inputs_img, is_training=True):
'''
inputs_img: torch.Tensor, bs, 3, H, W
return
code : bs, 256, 256, 32
'''
if self.decoder_learning_rate <= 0:
with torch.no_grad():
features_flatten = self.encoder(inputs_img, use_my_proces=True, output_hidden_states=self.output_hidden_states)
else:
features_flatten = self.encoder(inputs_img, use_my_proces=True, output_hidden_states=self.output_hidden_states)
if self.ids_restore.device !=features_flatten.device:
self.ids_restore = self.ids_restore.to(features_flatten.device)
ids_restore = self.ids_restore.expand([features_flatten.shape[0], -1])
uv_code = self.neck(features_flatten, ids_restore)
batch_size, token_num, dims_feature = uv_code.shape
if self.reshape_type=='reshape':
feature_map = uv_code.reshape(batch_size, self.num_patches_axis, self.num_patches_axis,\
self.code_feat_dims, self.code_patch_size, self.code_patch_size) # torch.Size([1, 64, 64, 32, 4, 4, ])
feature_map = feature_map.permute(0, 3, 1, 4, 2, 5) # ([1, 32, 64, 4, 64, 4])
feature_map = feature_map.reshape(batch_size, self.code_feat_dims, self.code_resolution, self.code_resolution) # torch.Size([1, 32, 256, 256])
code = feature_map # [1, 32, 256, 256]
else:
feature_map = uv_code.reshape(batch_size, self.num_patches_axis, self.num_patches_axis,dims_feature) # torch.Size([1, 64, 64, 512, ])
if isinstance(self.decoder, DistributedDataParallel):
code = self.decoder.module.upsample_conv(feature_map.permute([0,3,1,2])) # torch.Size([1, 32, 256, 256])
else:
code = self.decoder.upsample_conv(feature_map.permute([0,3,1,2])) # torch.Size([1, 32, 256, 256])
code = self.code_activation(code)
return code
def compute_loss(self, render_out):
render_images = render_out['image'] # .Size([1, 5, 896, 640, 3]), range [0, 1]
target_images = render_out['target_imgs']
target_images =target_images.to(render_images)
if self.is_debug:
render_images_tmp= rearrange(render_images, 'b n h w c -> (b n) c h w')
target_images_tmp = rearrange(target_images, 'b n h w c -> (b n) c h w')
all_images = torch.cat([render_images_tmp, target_images_tmp], dim=2)
all_images = render_images_tmp*0.5 + target_images_tmp*0.5
grid = make_grid(all_images, nrow=4, normalize=True, value_range=(0, 1))
save_image(grid, "./debug.png")
main_print("saving into ./debug.png")
render_images = rearrange(render_images, 'b n h w c -> (b n) c h w') * 2.0 - 1.0
target_images = rearrange(target_images, 'b n h w c -> (b n) c h w') * 2.0 - 1.0
if self.lambda_mse<=0:
loss_mse = 0
else:
if self.loss_weights_views.numel() != 0:
b, n, _, _, _ = render_out['image'].shape
loss_weights_views = self.loss_weights_views.unsqueeze(0).to(render_images.device)
loss_weights_views = loss_weights_views.repeat(b,1).reshape(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
loss_mse = weighted_mse_loss(render_images, target_images, loss_weights_views)
main_print("weighted sum mse")
else:
loss_mse = F.mse_loss(render_images, target_images)
if self.lambda_l1<=0:
loss_l1 = 0
else:
loss_l1 = F.l1_loss(render_images, target_images)
if self.lambda_ssim <= 0:
loss_ssim = 0
else:
loss_ssim = 1 - self.ssim(render_images, target_images)
if not self.is_debug:
if self.lambda_lpips<=0:
loss_lpips = 0
else:
if self.loss_weights_views.numel() != 0:
with torch.cuda.amp.autocast():
loss_lpips = self.lpips(render_images.clamp(-1, 1), target_images)
else:
loss_lpips = 0
with torch.cuda.amp.autocast():
for img_idx in range(render_images.shape[0]):
loss_lpips += self.lpips(render_images[[img_idx]].clamp(-1, 1), target_images[[img_idx]])
loss_lpips /= render_images.shape[0]
else:
loss_lpips = 0
loss_gs_offset = render_out['offset']
loss = loss_mse * self.lambda_mse \
+ loss_l1 * self.lambda_l1 \
+ loss_ssim * self.lambda_ssim \
+ loss_lpips * self.lambda_lpips \
+ loss_gs_offset * self.lambda_offset
prefix = 'train'
loss_dict = {}
loss_dict.update({f'{prefix}/loss_mse': loss_mse})
loss_dict.update({f'{prefix}/loss_lpips': loss_lpips})
loss_dict.update({f'{prefix}/loss_gs_offset': loss_gs_offset})
loss_dict.update({f'{prefix}/loss_ssim': loss_ssim})
loss_dict.update({f'{prefix}/loss_l1': loss_l1})
loss_dict.update({f'{prefix}/loss': loss})
return loss, loss_dict
def compute_metrics(self, render_out):
# NOTE: all the rgb value range is [0, 1]
# render_out.keys = (['scales', 'norm', 'image', 'offset', 'code', 'target_imgs'])
render_images = render_out['image'].clamp(0, 1) # .Size([1, 5, 896, 640, 3]), range [0, 1]
target_images = render_out['target_imgs']
if target_images.dtype!=render_images.dtype:
target_images = target_images.to(render_images.dtype)
render_images = rearrange(render_images, 'b n h w c -> (b n) c h w')
target_images = rearrange(target_images, 'b n h w c -> (b n) c h w').to(render_images)
mse = F.mse_loss(render_images, target_images).mean()
psnr = 10 * torch.log10(1.0 / mse)
ssim = self.ssim(render_images, target_images)
render_images = render_images * 2.0 - 1.0
target_images = target_images * 2.0 - 1.0
if self.lambda_lpips<=0:
lpips = torch.Tensor([0]).to(render_images.device).to(render_images.dtype)
else:
with torch.cuda.amp.autocast():
lpips = self.lpips(render_images, target_images)
metrics = {
'val/mse': mse,
'val/psnr': psnr,
'val/ssim': ssim,
'val/lpips': lpips,
}
return metrics
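The PSNR computed in `compute_metrics` follows the standard relation for images in [0, 1]: the peak value is 1, so `psnr = 10 * log10(1 / mse)`. A tiny standalone check:

```python
import math

def psnr_from_mse(mse: float) -> float:
    # PSNR for images scaled to [0, 1]: psnr = 10 * log10(1 / mse).
    return 10.0 * math.log10(1.0 / mse)

psnr_from_mse(0.01)  # -> 20.0
```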
def new_on_before_optimizer_step(self):
norms = grad_norm(self.neck, norm_type=2)
if 'grad_2.0_norm_total' in norms:
self.log_dict({'grad_norm/lrm_generator': norms['grad_2.0_norm_total']})
@torch.no_grad()
def validation_step(self, batch, batch_idx):
render_out = self.forward(batch)
metrics = self.compute_metrics(render_out)
self.validation_metrics.append(metrics)
render_images = render_out['image']
render_images = rearrange(render_images, 'b n h w c -> b c h (n w)')
gt_images = render_out['target_imgs']
gt_images = rearrange(gt_images, 'b n h w c-> b c h (n w)')
log_images = torch.cat([render_images, gt_images], dim=-2)
self.validation_step_outputs.append(log_images)
self.validation_step_code_outputs.append( render_out['code'])
render_out_comb = self.forward_nvPose(batch, smplx_given=None)
self.validation_step_nvPose_outputs.append(render_out_comb)
def forward_nvPose(self, batch, smplx_given):
'''
smplx_given: torch.Tensor, shape (bs, 189)
Returns images covering cameras_num * poses_num combinations.
'''
_, num_img, _,_ = batch['cond_poses'].shape
# substitute each target pose into the conditioning smplx_params separately
if smplx_given is None:
step_pose = self.smplx_params.shape[0] // num_img
smplx_given = self.smplx_params
else:
step_pose = 1
render_out_list = []
bk = batch['cond_smpl_param'].clone()  # keep the original conditioning params
for i in range(num_img):
target_pose = smplx_given[[i*step_pose]]
batch['cond_smpl_param'][:, 7:70] = target_pose[:, 7:70]  # copy body_pose
batch['cond_smpl_param'][:, 80:80+93] = target_pose[:, 80:80+93]  # copy pose_hand + pose_jaw
batch['cond_smpl_param'][:, 179:189] = target_pose[:, 179:189]  # copy face expression
render_out_new = self.forward(batch)
render_out_list.append(render_out_new['image'])
batch['cond_smpl_param'] = bk  # restore the original conditioning params
render_out_comb = torch.cat(render_out_list, dim=2)  # stack along the H axis
render_out_comb = rearrange(render_out_comb, 'b n h w c -> b c h (n w)')
return render_out_comb
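The slicing in forward_nvPose implies a 189-dim SMPL-X parameter layout ([:, 7:70] body pose, [:, 80:173] hand + jaw pose, [:, 179:189] expression). A stand-alone sketch of that substitution; the layout is inferred from the code above and the helper name is hypothetical:

```python
import torch

def copy_pose_components(cond_param, target_pose):
    # assumed layout, inferred from forward_nvPose's slices
    out = cond_param.clone()
    out[:, 7:70] = target_pose[:, 7:70]              # body pose
    out[:, 80:80 + 93] = target_pose[:, 80:80 + 93]  # hand + jaw pose
    out[:, 179:189] = target_pose[:, 179:189]        # facial expression
    return out

cond = torch.zeros(1, 189)
tgt = torch.ones(1, 189)
mixed = copy_pose_components(cond, tgt)
# untouched slots (e.g. the leading 7 values) keep the conditioning values
assert mixed[:, 0:7].sum() == 0 and mixed[:, 7:70].sum() == 63
```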
def on_validation_epoch_end(self): #
images = torch.cat(self.validation_step_outputs, dim=-1)
all_images = self.all_gather(images).cpu()
all_images = rearrange(all_images, 'r b c h w -> (r b) c h w')
# nv pose
images_pose = torch.cat(self.validation_step_nvPose_outputs, dim=-1)
all_images_pose = self.all_gather(images_pose).cpu()
all_images_pose = rearrange(all_images_pose, 'r b c h w -> (r b) c h w')
if self.global_rank == 0:
image_path = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png')
grid = make_grid(all_images, nrow=1, normalize=True, value_range=(0, 1))
save_image(grid, image_path)
main_print(f"Saved image to {image_path}")
metrics = {}
for key in self.validation_metrics[0].keys():
metrics[key] = torch.stack([m[key] for m in self.validation_metrics]).mean()
self.log_dict(metrics, prog_bar=True, logger=True, on_step=False, on_epoch=True)
# code for saving the nvPose images
image_path_nvPose = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}_nvPose.png')
grid_nvPose = make_grid(all_images_pose, nrow=1, normalize=True, value_range=(0, 1))
save_image(grid_nvPose, image_path_nvPose)
main_print(f"Saved image to {image_path_nvPose}")
# code for saving the code images
for i, code in enumerate(self.validation_step_code_outputs):
image_path = os.path.join(self.logdir, 'images_val_code')
num_scenes, num_chn, h, w = code.size()
code_viz = code.reshape(num_scenes, 4, 8, h, w).to(torch.float32).cpu().numpy()
code_viz = code_viz.transpose(0, 1, 3, 2, 4).reshape(num_scenes, 4 * h, 8 * w)
for j, code_viz_single in enumerate(code_viz):
plt.imsave(os.path.join(image_path, f'val_{self.global_step:07d}_{i*num_scenes+j:04d}' + '.png'), code_viz_single,
vmin=self.code_clip_range[0], vmax=self.code_clip_range[1])
self.validation_step_outputs.clear()
self.validation_step_nvPose_outputs.clear()
self.validation_metrics.clear()
self.validation_step_code_outputs.clear()
def configure_optimizers(self):
# define the optimizer and the scheduler for neck and decoder
main_print("WARNING: currently only a single optimizer is supported for both the neck and the decoder")
learning_rate = self.neck_learning_rate
params= [
{'params': self.neck.parameters(), 'lr': self.neck_learning_rate, },
{'params': self.decoder.parameters(), 'lr': self.decoder_learning_rate},
]
if hasattr(self, "encoder_learning_rate") and self.encoder_learning_rate>0:
params.append({'params': self.encoder.parameters(), 'lr': self.encoder_learning_rate})
main_print("============add the encoder into the optimizer============")
optimizer = torch.optim.Adam(
params
)
T_warmup, T_max, eta_min = self.warmup_steps, self.max_steps, 0.001
lr_lambda = lambda step: \
eta_min + (1 - math.cos(math.pi * step / T_warmup)) * (1 - eta_min) * 0.5 if step < T_warmup else \
eta_min + (1 + math.cos(math.pi * (step - T_warmup) / (T_max - T_warmup))) * (1 - eta_min) * 0.5
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
return {'optimizer': optimizer, 'lr_scheduler': scheduler}
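The lr_lambda above implements a cosine ramp up to a factor of 1 during warmup, followed by cosine decay back to eta_min. The same formula as a named function, for reading only:

```python
import math

def warmup_cosine(step, T_warmup, T_max, eta_min=0.001):
    if step < T_warmup:
        # cosine ramp from eta_min up to 1 over the warmup steps
        return eta_min + (1 - math.cos(math.pi * step / T_warmup)) * (1 - eta_min) * 0.5
    # cosine decay from 1 back down to eta_min over the remaining steps
    return eta_min + (1 + math.cos(math.pi * (step - T_warmup) / (T_max - T_warmup))) * (1 - eta_min) * 0.5

assert abs(warmup_cosine(0, 100, 1000) - 0.001) < 1e-9    # start of warmup
assert abs(warmup_cosine(100, 100, 1000) - 1.0) < 1e-9    # peak after warmup
assert abs(warmup_cosine(1000, 100, 1000) - 0.001) < 1e-9 # fully decayed
```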
def training_step(self, batch, batch_idx):
scheduler = self.lr_schedulers()
scheduler.step()
render_out = self.forward(batch)
loss, loss_dict = self.compute_loss(render_out)
self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
if self.global_step % 200 == 0 and self.global_rank == 0:
self.new_on_before_optimizer_step() # log the norm
if self.global_step % 200 == 0 and self.global_rank == 0:
if self.if_include_video_ref_img and self.training:
render_images = torch.cat([ torch.ones_like(render_out['image'][:,0:1]), render_out['image']], dim=1)
target_images = torch.cat([ render_out['inputs_img'], render_out['target_imgs']], dim=1)
target_images = rearrange(
target_images, 'b n h w c -> b c h (n w)')
render_images = rearrange(
render_images, 'b n h w c-> b c h (n w)')
grid = torch.cat([
target_images, render_images, 0.5*render_images + 0.5*target_images,
], dim=-2)
grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1))
image_path = os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.jpg')
save_image(grid, image_path)
main_print(f"Saved image to {image_path}")
return loss
@torch.no_grad()
def test_step(self, batch, batch_idx):
# input_dict, render_gt = self.prepare_validation_batch_data(batch)
render_out = self.forward(batch)
render_gt = render_out['target_imgs']
render_img = render_out['image']
# Compute metrics
metrics = self.compute_metrics(render_out)
self.validation_metrics.append(metrics)
# Save images
target_images = rearrange(
render_gt, 'b n h w c -> b c h (n w)')
render_images = rearrange(
render_img, 'b n h w c -> b c h (n w)')
grid = torch.cat([
target_images, render_images,
], dim=-2)
grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1))
# self.logger.log_image('train/render', [grid], step=self.global_step)
image_path = os.path.join(self.logdir, 'images_test', f'{batch_idx:07d}.png')
save_image(grid, image_path)
# code visualize
code = render_out['code']
self.decoder.visualize(code, batch['scene_name'],
os.path.dirname(image_path), code_range=self.code_clip_range)
print(f"Saved image to {image_path}")
def on_test_start(self):
if self.global_rank == 0:
os.makedirs(os.path.join(self.logdir, 'images_test'), exist_ok=True)
def on_test_epoch_end(self):
metrics = {}
metrics_mean = {}
metrics_var = {}
for key in self.validation_metrics[0].keys():
tmp = torch.stack([m[key] for m in self.validation_metrics]).cpu().numpy()
metrics_mean[key] = tmp.mean()
metrics_var[key] = tmp.var()
# format each metric as "mean±var"
formatted_metrics = {}
for key in metrics_mean.keys():
formatted_metrics[key] = f"{metrics_mean[key]:.4f}±{metrics_var[key]:.4f}"
for key in self.validation_metrics[0].keys():
metrics[key] = torch.stack([m[key] for m in self.validation_metrics]).cpu().numpy().tolist()
# saving into a dictionary
final_dict = {"average": formatted_metrics,
'details': metrics}
metric_path = os.path.join(self.logdir, 'metrics.json')
with open(metric_path, 'w') as f:
json.dump(final_dict, f, indent=4)
print(f"Saved metrics to {metric_path}")
for key in metrics.keys():
metrics[key] = torch.stack([m[key] for m in self.validation_metrics]).mean()
print(metrics)
self.validation_metrics.clear()
def weighted_mse_loss(render_images, target_images, weights):
squared_diff = (render_images - target_images) ** 2
weighted_squared_diff = squared_diff * weights
loss_mse_weighted = weighted_squared_diff.mean()
return loss_mse_weighted
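weighted_mse_loss relies on broadcasting between the squared error and the weight map. A small usage sketch with illustrative shapes; a single-channel weight map broadcasts over RGB:

```python
import torch

render = torch.zeros(1, 3, 4, 4)
target = torch.full((1, 3, 4, 4), 0.5)
weights = torch.ones(1, 1, 4, 4)  # broadcasts over the channel dimension
loss = ((render - target) ** 2 * weights).mean()
print(loss)  # 0.5^2 = 0.25 everywhere, so the mean is 0.25
```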
================================================
FILE: lib/mmutils/__init__.py
================================================
from .initialize import xavier_init, constant_init
================================================
FILE: lib/mmutils/initialize.py
================================================
import torch.nn as nn
def constant_init(module: nn.Module, val: float, bias: float = 0) -> None:
if hasattr(module, 'weight') and module.weight is not None:
nn.init.constant_(module.weight, val)
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias)
def xavier_init(module: nn.Module,
gain: float = 1,
bias: float = 0,
distribution: str = 'normal') -> None:
assert distribution in ['uniform', 'normal']
if hasattr(module, 'weight') and module.weight is not None:
if distribution == 'uniform':
nn.init.xavier_uniform_(module.weight, gain=gain)
else:
nn.init.xavier_normal_(module.weight, gain=gain)
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias)
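The two helpers above are thin wrappers over torch's nn.init; a minimal usage sketch with illustrative modules:

```python
import torch.nn as nn

conv = nn.Conv2d(3, 8, kernel_size=3)
# equivalent of xavier_init(conv, distribution='uniform') with bias=0
nn.init.xavier_uniform_(conv.weight, gain=1)
nn.init.constant_(conv.bias, 0)

bn = nn.BatchNorm2d(8)
# equivalent of constant_init(bn, 1) with bias=0
nn.init.constant_(bn.weight, 1)
nn.init.constant_(bn.bias, 0)

assert conv.bias.abs().sum() == 0
assert bn.weight.min() == 1 and bn.weight.max() == 1
```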
================================================
FILE: lib/models/__init__.py
================================================
from .decoders import *
================================================
FILE: lib/models/decoders/__init__.py
================================================
from .uvmaps_decoder_gender import UVNDecoder_gender
__all__ = ['UVNDecoder_gender']
================================================
FILE: lib/models/decoders/uvmaps_decoder_gender.py
================================================
import os
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum
from pytorch3d import ops
from lib.mmutils import xavier_init, constant_init
import numpy as np
import time
import cv2
import math
from simple_knn._C import distCUDA2
from pytorch3d.transforms import quaternion_to_matrix
from ..deformers import SMPLXDeformer_gender
from ..renderers import GRenderer, get_covariance, batch_rodrigues
from lib.ops import TruncExp
import torchvision
from lib.utils.train_util import main_print
def ensure_dtype(input_tensor, target_dtype=torch.float32):
"""
Ensure tensor dtype matches target dtype.
If not, convert it.
"""
if input_tensor.dtype != target_dtype:
input_tensor = input_tensor.to(dtype=target_dtype)
return input_tensor
class UVNDecoder_gender(nn.Module):
activation_dict = {
'relu': nn.ReLU,
'silu': nn.SiLU,
'softplus': nn.Softplus,
'trunc_exp': TruncExp,
'sigmoid': nn.Sigmoid}
def __init__(self,
*args,
interp_mode='bilinear',
base_layers=[3 * 32, 128],
density_layers=[128, 1],
color_layers=[128, 128, 3],
offset_layers=[128, 3],
scale_layers=[128, 3],
radius_layers=[128, 3],
use_dir_enc=True,
dir_layers=None,
scene_base_size=None,
scene_rand_dims=(0, 1),
activation='silu',
sigma_activation='sigmoid',
sigmoid_saturation=0.001,
code_dropout=0.0,
flip_z=False,
extend_z=False,
gender='neutral',
multires=0,
bg_color=0,
image_size=1024,
superres=False,
focal = 1280, # the default focal length definition
reshape_type=None, # if set, create CNN layers to upsample the UV features
fix_sigma=False, # if true, the density of the Gaussians is fixed
up_cnn_in_channels = None, # the channel number of the upsample CNN
vithead_param=None, # the ViT head for decoding to UV features
is_sub2=False, # if true, use the sub2 UV map
**kwargs):
super().__init__()
self.interp_mode = interp_mode
self.in_chn = base_layers[0]
self.use_dir_enc = use_dir_enc
if scene_base_size is None:
self.scene_base = None
else:
rand_size = [1 for _ in scene_base_size]
for dim in scene_rand_dims:
rand_size[dim] = scene_base_size[dim]
init_base = torch.randn(rand_size).expand(scene_base_size).clone()
self.scene_base = nn.Parameter(init_base)
self.dir_encoder = None
self.sigmoid_saturation = sigmoid_saturation
self.deformer = SMPLXDeformer_gender(gender, is_sub2=is_sub2)
self.renderer = GRenderer(image_size=image_size, bg_color=bg_color, f=focal)
# a super-resolution module is not implemented; both branches leave it unset
self.superres = None
self.gender = gender
self.reshape_type = reshape_type
if reshape_type=='cnn':
self.upsample_conv = torch.nn.ConvTranspose2d(512, 32, kernel_size=4, stride=4,).cuda()
elif reshape_type == 'VitHead': # changes the up block's layernorm into the feature channel norm instead of the full image norm
from lib.models.decoders.vit_head import VitHead
self.upsample_conv = VitHead(**vithead_param)
# 256, 128, 128 -> 128, 256, 256 -> 64, 512, 512, ->32, 1024, 1024
base_cache_dir = 'work_dirs/cache'
if is_sub2:
base_cache_dir = 'work_dirs/cache_sub2'
# main_print("!!!!!!!!!!!!!!!!!!! using the sub2 uv map !!!!!!!!!!!!!!!!!!!")
if gender == 'neutral':
select_uv = torch.as_tensor(np.load(base_cache_dir+'/init_uv_smplx_newNeutral.npy'))
self.register_buffer('select_coord', select_uv.unsqueeze(0)*2.-1.)
init_pcd = torch.as_tensor(np.load(base_cache_dir+'/init_pcd_smplx_newNeutral.npy'))
self.register_buffer('init_pcd', init_pcd.unsqueeze(0), persistent=False) # 0.9-- -1
elif gender == 'male':
raise NotImplementedError("Haven't created the init_uv_smplx_thu in v_template")
select_uv = torch.as_tensor(np.load(base_cache_dir+'/init_uv_smplx_thu.npy'))
self.register_buffer('select_coord', select_uv.unsqueeze(0)*2.-1.)
init_pcd = torch.as_tensor(np.load(base_cache_dir+'/init_pcd_smplx_thu.npy'))
self.register_buffer('init_pcd', init_pcd.unsqueeze(0), persistent=False) # 0.9-- -1
self.num_init = self.init_pcd.shape[1]
main_print(f"!!!!!!!!!!!!!!!!!!! cur points number are {self.num_init} !!!!!!!!!!!!!!!!!!!")
self.multires = multires  # 0 disables the positional-encoding branch below
if multires > 0:
uv_map = torch.as_tensor(np.load(base_cache_dir+'/init_uvmap_smplx_thu.npy'))
pcd_map = torch.as_tensor(np.load(base_cache_dir+'/init_posmap_smplx_thu.npy'))
input_coord = torch.cat([pcd_map, uv_map], dim=1)
self.register_buffer('input_freq', input_coord, persistent=False)
base_layers[0] += 5
color_layers[0] += 5
else:
self.init_uv = None
activation_layer = self.activation_dict[activation.lower()]
base_net = []  # 3x3 Conv2d stack mapping base_layers[0] -> base_layers[-1]
for i in range(len(base_layers) - 1):
base_net.append(nn.Conv2d(base_layers[i], base_layers[i + 1], 3, padding=1))
if i != len(base_layers) - 2:
base_net.append(nn.BatchNorm2d(base_layers[i+1]))
base_net.append(activation_layer())
self.base_net = nn.Sequential(*base_net)
self.base_bn = nn.BatchNorm2d(base_layers[-1])
self.base_activation = activation_layer()
density_net = []  # 1x1 Conv2d stack ending in the sigma activation
for i in range(len(density_layers) - 1):
density_net.append(nn.Conv2d(density_layers[i], density_layers[i + 1], 1))
if i != len(density_layers) - 2:
density_net.append(nn.BatchNorm2d(density_layers[i+1]))
density_net.append(activation_layer())
density_net.append(self.activation_dict[sigma_activation.lower()]())
self.density_net = nn.Sequential(*density_net)
offset_net = []  # 1x1 Conv2d stack predicting per-texel offsets
for i in range(len(offset_layers) - 1):
offset_net.append(nn.Conv2d(offset_layers[i], offset_layers[i + 1], 1))
if i != len(offset_layers) - 2:
offset_net.append(nn.BatchNorm2d(offset_layers[i+1]))
offset_net.append(activation_layer())
self.offset_net = nn.Sequential(*offset_net)
self.dir_net = None
color_net = []  # Conv2d stack for color/radius/rotation, ending in a sigmoid
for i in range(len(color_layers) - 2):
color_net.append(nn.Conv2d(color_layers[i], color_layers[i + 1], kernel_size=3, padding=1))
color_net.append(nn.BatchNorm2d(color_layers[i+1]))
color_net.append(activation_layer())
color_net.append(nn.Conv2d(color_layers[-2], color_layers[-1], kernel_size=1))
color_net.append(nn.Sigmoid())
self.color_net = nn.Sequential(*color_net)
self.code_dropout = nn.Dropout2d(code_dropout) if code_dropout > 0 else None
self.flip_z = flip_z
self.extend_z = extend_z
if self.gender == 'neutral':
init_rot = torch.as_tensor(np.load(base_cache_dir+'/init_rot_smplx_newNeutral.npy'))
self.register_buffer('init_rot', init_rot, persistent=False)
face_mask = torch.as_tensor(np.load(base_cache_dir+'/face_mask_thu_newNeutral.npy'))
self.register_buffer('face_mask', face_mask.unsqueeze(0), persistent=False)
hands_mask = torch.as_tensor(np.load(base_cache_dir+'/hands_mask_thu_newNeutral.npy'))
self.register_buffer('hands_mask', hands_mask.unsqueeze(0), persistent=False)
outside_mask = torch.as_tensor(np.load(base_cache_dir+'/outside_mask_thu_newNeutral.npy'))
self.register_buffer('outside_mask', outside_mask.unsqueeze(0), persistent=False)
else:
raise NotImplementedError("Haven't created the init_rot in v_template")
init_rot = torch.as_tensor(np.load(base_cache_dir+'/init_rot_smplx_thu.npy'))
self.register_buffer('init_rot', init_rot, persistent=False)
face_mask = torch.as_tensor(np.load(base_cache_dir+'/face_mask_thu.npy'))
self.register_buffer('face_mask', face_mask.unsqueeze(0), persistent=False)
hands_mask = torch.as_tensor(np.load(base_cache_dir+'/hands_mask_thu.npy'))
self.register_buffer('hands_mask', hands_mask.unsqueeze(0), persistent=False)
outside_mask = torch.as_tensor(np.load(base_cache_dir+'/outside_mask_thu.npy'))
self.register_buffer('outside_mask', outside_mask.unsqueeze(0), persistent=False)
self.iter = 0
self.init_weights()
self.if_rotate_gaussian = False
self.fix_sigma = fix_sigma
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Linear):
xavier_init(m, distribution='uniform')
if self.dir_net is not None:
constant_init(self.dir_net[-1], 0)
if self.offset_net is not None:
self.offset_net[-1].weight.data.uniform_(-1e-5, 1e-5)
self.offset_net[-1].bias.data.zero_()
def extract_pcd(self, code, smpl_params, init=False, zeros_hands_off=False):
'''
Args:
B == num_scenes
code (tensor): latent code. shape: [B, C, H, W]
smpl_params (tensor): SMPL parameters. shape: [B_pose, 189]
init (bool): Not used
Returns:
defm_pcd (tensor): deformed point cloud. shape: [B, N, B_pose, 3]
sigmas, rgbs, offset, radius, rot(tensor): GS attributes. shape: [B, N, C]
tfs (tensor): deformation matrices. shape: [B, N, C]
'''
if isinstance(code, list):
num_scenes, _, h, w = code[0].size()
else:
num_scenes, n_channels, h, w = code.size()
init_pcd = self.init_pcd.repeat(num_scenes, 1, 1) # T-posed space points, for computing the skinning weights
sigmas, rgbs, radius, rot, offset = self._decode(code, init=init) # the person-specify attributes of GS
if self.fix_sigma:
sigmas = torch.ones_like(sigmas)
if zeros_hands_off:
offset[self.hands_mask[...,None].expand(num_scenes, -1, 3)] = 0
canon_pcd = init_pcd + offset
self.deformer.prepare_deformer(smpl_params, num_scenes, device=canon_pcd.device)
defm_pcd, tfs = self.deformer(canon_pcd, rot, mask=(self.face_mask+self.hands_mask+self.outside_mask), cano=False, if_rotate_gaussian=self.if_rotate_gaussian)
return defm_pcd, sigmas, rgbs, offset, radius, tfs, rot
def deform_pcd(self, code, smpl_params, init=False, zeros_hands_off=False, value=0.1):
'''
Args:
B == num_scenes
code (List): list of data
smpl_params (tensor): SMPL parameters. shape: [B_pose, 189]
init (bool): Not used
Returns:
defm_pcd (tensor): deformed point cloud. shape: [B, N, B_pose, 3]
sigmas, rgbs, offset, radius, rot(tensor): GS attributes. shape: [B, N, C]
tfs (tensor): deformation matrices. shape: [B, N, C]
'''
sigmas, rgbs, radius, rot, offset = code
num_scenes = sigmas.shape[0]
init_pcd = self.init_pcd.repeat(num_scenes, 1, 1) #T-posed space points, for computing the skinning weights
if self.fix_sigma:
sigmas = torch.ones_like(sigmas)
if zeros_hands_off:
offset[self.hands_mask[...,None].expand(num_scenes, -1, 3)] = torch.clamp(offset[self.hands_mask[...,None].expand(num_scenes, -1, 3)], -value, value)
canon_pcd = init_pcd + offset
self.deformer.prepare_deformer(smpl_params, num_scenes, device=canon_pcd.device)
defm_pcd, tfs = self.deformer(canon_pcd, rot, mask=(self.face_mask+self.hands_mask+self.outside_mask), cano=False, if_rotate_gaussian=self.if_rotate_gaussian)
return defm_pcd, sigmas, rgbs, offset, radius, tfs, rot
def _sample_feature(self, results):
# outputs, sigma_uv, offset_uv, rgbs_uv, radius_uv, rot_uv = results['output'], results['sigma'], results['offset'], results['rgbs'], results['radius'], results['rot']
sigma = results['sigma']
outputs = results['output']
if isinstance(sigma, list):
num_scenes, _, h, w = sigma[0].shape
select_coord = self.select_coord.unsqueeze(1).repeat(num_scenes, 1, 1, 1)
elif sigma.dim() == 4:
num_scenes, n_channels, h, w = sigma.shape
select_coord = self.select_coord.unsqueeze(1).repeat(num_scenes, 1, 1, 1)
else:
assert False
output_attr = F.grid_sample(outputs, select_coord, mode=self.interp_mode, padding_mode='border', align_corners=False).reshape(num_scenes, 13, -1).permute(0, 2, 1)
sigma, offset, rgbs, radius, rot = output_attr.split([1, 3, 3, 3, 3], dim=2)
if self.sigmoid_saturation > 0:
rgbs = rgbs * (1 + self.sigmoid_saturation * 2) - self.sigmoid_saturation
radius = (radius - 0.5) * 2
rot = (rot - 0.5) * np.pi
return sigma, rgbs, radius, rot, offset
def _decode_feature(self, point_code, init=False):
if isinstance(point_code, list):
num_scenes, _, h, w = point_code[0].shape
geo_code, tex_code = point_code
# select_coord = self.select_coord.unsqueeze(1).repeat(num_scenes, 1, 1, 1)
if self.multires != 0:
input_freq = self.input_freq.repeat(num_scenes, 1, 1, 1)
elif point_code.dim() == 4:
num_scenes, n_channels, h, w = point_code.shape
# select_coord = self.select_coord.unsqueeze(1).repeat(num_scenes, 1, 1, 1)
if self.multires != 0:
input_freq = self.input_freq.repeat(num_scenes, 1, 1, 1)
geo_code, tex_code = point_code.split(16, dim=1)
else:
assert False
base_in = geo_code if self.multires == 0 else torch.cat([geo_code, input_freq], dim=1)
base_x = self.base_net(base_in)
base_x_act = self.base_activation(self.base_bn(base_x))
sigma = self.density_net(base_x_act)
offset = self.offset_net(base_x_act)
color_in = tex_code if self.multires == 0 else torch.cat([tex_code, input_freq], dim=1)
rgbs_radius_rot = self.color_net(color_in)
outputs = torch.cat([sigma, offset, rgbs_radius_rot], dim=1)
sigma, offset, rgbs, radius, rot = outputs.split([1, 3, 3, 3, 3], dim=1)
results = {'output':outputs, 'sigma': sigma, 'offset': offset, 'rgbs': rgbs, 'radius': radius, 'rot': rot}
return results
def _decode(self, point_code, init=False):
if isinstance(point_code, list):
num_scenes, _, h, w = point_code[0].shape
geo_code, tex_code = point_code
select_coord = self.select_coord.unsqueeze(1).repeat(num_scenes, 1, 1, 1)
if self.multires != 0:
input_freq = self.input_freq.repeat(num_scenes, 1, 1, 1)
elif point_code.dim() == 4:
num_scenes, n_channels, h, w = point_code.shape
select_coord = self.select_coord.unsqueeze(1).repeat(num_scenes, 1, 1, 1)
if self.multires != 0:
input_freq = self.input_freq.repeat(num_scenes, 1, 1, 1)
geo_code, tex_code = point_code.split(16, dim=1)
else:
assert False
base_in = geo_code if self.multires == 0 else torch.cat([geo_code, input_freq], dim=1)
base_x = self.base_net(base_in)
base_x_act = self.base_activation(self.base_bn(base_x))
sigma = self.density_net(base_x_act)
offset = self.offset_net(base_x_act)
color_in = tex_code if self.multires == 0 else torch.cat([tex_code, input_freq], dim=1)
rgbs_radius_rot = self.color_net(color_in)
outputs = torch.cat([sigma, offset, rgbs_radius_rot], dim=1)
output_attr = F.grid_sample(outputs, select_coord, mode=self.interp_mode, padding_mode='border', align_corners=False).reshape(num_scenes, 13, -1).permute(0, 2, 1)
sigma, offset, rgbs, radius, rot = output_attr.split([1, 3, 3, 3, 3], dim=2)
if self.sigmoid_saturation > 0:
rgbs = rgbs * (1 + self.sigmoid_saturation * 2) - self.sigmoid_saturation
radius = (radius - 0.5) * 2
rot = (rot - 0.5) * np.pi
return sigma, rgbs, radius, rot, offset
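_decode's final lookup samples the stacked (B, 13, H, W) attribute maps at N per-point UV coordinates via F.grid_sample. A minimal sketch of that sampling pattern with random stand-in data:

```python
import torch
import torch.nn.functional as F

B, C, H, W, N = 1, 13, 16, 16, 100
uv_maps = torch.randn(B, C, H, W)                # stacked sigma/offset/color/radius/rot maps
coords = torch.rand(B, 1, N, 2) * 2 - 1          # per-point UVs in grid_sample's [-1, 1] range
feats = F.grid_sample(uv_maps, coords, mode='bilinear',
                      padding_mode='border', align_corners=False)
feats = feats.reshape(B, C, N).permute(0, 2, 1)  # (B, N, 13) per-point attributes
sigma, offset, rgbs, radius, rot = feats.split([1, 3, 3, 3, 3], dim=2)
assert sigma.shape == (B, N, 1) and rot.shape == (B, N, 3)
```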
def gaussian_render(self, pcd, sigmas, rgbs, normals, rot, num_scenes, num_imgs, cameras, use_scale=False, radius=None, \
return_norm=False, return_viz=False, mask=None):
# add mask or visible points to images or select ind to images
'''
render the gaussian to images
return_norm: return the normals of the gaussian (haven't been used)
return_viz: return the mask of the gaussian
mask: the mask of the gaussian
'''
assert num_scenes == 1
pcd = pcd.reshape(-1, 3)
if use_scale:
dist2 = distCUDA2(pcd)
dist2 = torch.clamp_min(dist2, 0.0000001)
scales = torch.sqrt(dist2)[...,None].repeat(1, 3).detach() # distance to neighboring points
scale = (radius+1)*scales # radius in [-1, 1] modulates the per-Gaussian scale
cov3D = get_covariance(scale, rot).reshape(-1, 6) # rot holds the per-point rotations
images_all = []
viz_masks = [] if return_viz else None
norm_all = [] if return_norm else None
if mask is not None:
pcd = pcd[mask]
rgbs = rgbs[mask]
sigmas = sigmas[mask]
cov3D = cov3D[mask]
normals = normals[mask]
if 1:
for i in range(num_imgs):
self.renderer.prepare(cameras[i])
image = self.renderer.render_gaussian(means3D=pcd, colors_precomp=rgbs,
rotations=None, opacities=sigmas, scales=None, cov3D_precomp=cov3D)
images_all.append(image)
if return_viz:
viz_mask = self.renderer.render_gaussian(means3D=pcd, colors_precomp=pcd.clone(),
rotations=None, opacities=sigmas*0+1, scales=None, cov3D_precomp=cov3D)
viz_masks.append(viz_mask)
images_all = torch.stack(images_all, dim=0).unsqueeze(0).permute(0, 1, 3, 4, 2)
if return_viz:
viz_masks = torch.stack(viz_masks, dim=0).unsqueeze(0).permute(0, 1, 3, 4, 2).reshape(1, -1, 3)
dist_sq, idx, neighbors = ops.knn_points(pcd.unsqueeze(0), viz_masks[:, ::10], K=1, return_nn=True)
viz_masks = (dist_sq < 0.0001)[0]
# ===== END the original code for batch size = 1 =====
if use_scale:
return images_all, norm_all, viz_masks, scale
else:
return images_all, norm_all, viz_masks, None
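gaussian_render initializes isotropic scales from distCUDA2, which (to our understanding) returns the mean squared distance of each point to its nearest neighbors. A torch-only stand-in illustrating the same scale initialization; the k=3 neighbor count is an assumption:

```python
import torch

def mean_nn_sq_dist(pcd, k=3):
    # stand-in (assumption) for simple_knn's distCUDA2: mean squared distance
    # from each point to its k nearest neighbors
    d2 = torch.cdist(pcd, pcd) ** 2
    d2 = d2 + torch.eye(len(pcd)) * 1e10  # mask out self-distances
    knn, _ = d2.topk(k, dim=1, largest=False)
    return knn.mean(dim=1)

pcd = torch.tensor([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])
dist2 = torch.clamp_min(mean_nn_sq_dist(pcd), 1e-7)
scales = torch.sqrt(dist2)[..., None].repeat(1, 3)  # isotropic base scale per point
assert scales.shape == (4, 3)
```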
def visualize(self, code, scene_name, viz_dir, code_range=[-1, 1]):
num_scenes, num_chn, h, w = code.size()
code_viz = code.reshape(num_scenes, 4, 8, h, w).to(torch.float32).cpu().numpy()
if not self.flip_z:
code_viz = code_viz[..., ::-1, :]
code_viz = code_viz.transpose(0, 1, 3, 2, 4).reshape(num_scenes, 4 * h, 8 * w)
for code_single, code_viz_single, scene_name_single in zip(code, code_viz, scene_name):
plt.imsave(os.path.join(viz_dir, 'a_scene_' + scene_name_single + '.png'), code_viz_single,
vmin=code_range[0], vmax=code_range[1])
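visualize tiles the 32-channel latent into a 4x8 grid of h x w panels before saving. The reshape/transpose dance, reproduced in torch for clarity:

```python
import torch

num_scenes, h, w = 2, 8, 8
code = torch.randn(num_scenes, 32, h, w)
# split the 32 channels into a 4x8 grid, then tile the panels into one image per scene
code_viz = code.reshape(num_scenes, 4, 8, h, w).permute(0, 1, 3, 2, 4).reshape(num_scenes, 4 * h, 8 * w)
assert code_viz.shape == (2, 32, 64)
# panel (row i, col j) holds channel i*8 + j; the top-left panel is channel 0
assert torch.equal(code_viz[0, :h, :w], code[0, 0])
```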
def forward(self, code, smpl_params, cameras, num_imgs,
return_loss=False, return_norm=False, init=False, mask=None, zeros_hands_off=False):
"""
Args:
density_bitfield: Shape (num_scenes, grid_size**3 // 8)
YY:
grid_size, dt_gamma, perturb, T_thresh are deleted
code: Shape (num_scenes, *code_size)
cameras: Shape (num_scenes, num_imgs, 19(3+16))
smpl_params: Shape (num_scenes, 189)
"""
if isinstance(code, list):
num_scenes = len(code[0])
else:
num_scenes = len(code)
assert num_scenes > 0
self.iter+=1
image = []
scales = []
norm = [] if return_norm else None
viz_masks = [] if not self.training else None
xyzs, sigmas, rgbs, offsets, radius, tfs, rot = self.extract_pcd(code, smpl_params, init=init, zeros_hands_off=zeros_hands_off)
if zeros_hands_off:
main_print('zeros_hands_off is on!')
offsets[self.hands_mask[...,None].repeat(num_scenes, 1, 3)] = 0
rgbs[self.hands_mask[...,None].repeat(num_scenes, 1, 3)] = 0
R_delta = batch_rodrigues(rot.reshape(-1, 3))
R = torch.bmm(self.init_rot.repeat(num_scenes, 1, 1), R_delta)
R_def = torch.bmm(tfs.flatten(0, 1)[:, :3, :3], R)
normals = (R_def[:, :, -1]).reshape(num_scenes, -1, 3)
R_def_batch = R_def.reshape(num_scenes, -1, 3, 3)
cast_from_bfloat16 = xyzs.dtype == torch.bfloat16  # the renderer expects float32 inputs
if cast_from_bfloat16:
main_print("casting bf16 tensors to float32 for rendering")
cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius = [ensure_dtype(item, torch.float32) for item in (cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius)]
if 1:
for camera_single, R_def_single, pcd_single, rgbs_single, sigmas_single, normal_single, radius_single in zip(cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius):
image_single, norm_single, viz_mask, scale = self.gaussian_render(pcd_single, sigmas_single, rgbs_single, normal_single, R_def_single, 1, num_imgs, camera_single, use_scale=True, \
radius=radius_single, return_norm=return_norm, return_viz=not self.training)
image.append(image_single)
scales.append(scale.unsqueeze(0))
if return_norm:
norm.append(norm_single)
if not self.training:
viz_masks.append(viz_mask)
image = torch.cat(image, dim=0)
scales = torch.cat(scales, dim=0)
norm = torch.cat(norm, dim=0) if return_norm else None
viz_masks = torch.cat(viz_masks, dim=0) if (not self.training) and viz_masks else None
if self.training:
offset_dist = offsets ** 2
weighted_offset = torch.mean(offset_dist) + torch.mean(offset_dist[self.hands_mask.repeat(num_scenes, 1)]) #+ torch.mean(offset_dist[self.face_mask.repeat(num_scenes, 1)])
else:
weighted_offset = offsets
results = dict(
viz_masks=viz_masks,
scales=scales,
norm=norm,
image=image,
offset=weighted_offset)
if return_loss:
results.update(decoder_reg_loss=self.loss())
return results
def forward_render(self, code, cameras, num_imgs,
return_loss=False, return_norm=False, init=False, mask=None, zeros_hands_off=False):
"""
Args:
density_bitfield: Shape (num_scenes, grid_size**3 // 8)
YY:
grid_size, dt_gamma, perturb, T_thresh are deleted
code: Shape (num_scenes, *code_size)
cameras: Shape (num_scenes, num_imgs, 19(3+16))
smpl_params: Shape (num_scenes, 189)
"""
image = []
scales = []
norm = [] if return_norm else None
viz_masks = [] if not self.training else None
xyzs, sigmas, rgbs, offsets, radius, tfs, rot = code
num_scenes = xyzs.shape[0]
if zeros_hands_off:
main_print('zeros_hands_off is on!')
offsets[self.hands_mask[...,None].repeat(num_scenes, 1, 3)] = 0
rgbs[self.hands_mask[...,None].repeat(num_scenes, 1, 3)] = 0
R_delta = batch_rodrigues(rot.reshape(-1, 3))
R = torch.bmm(self.init_rot.repeat(num_scenes, 1, 1), R_delta)
R_def = torch.bmm(tfs.flatten(0, 1)[:, :3, :3], R)
normals = (R_def[:, :, -1]).reshape(num_scenes, -1, 3)
R_def_batch = R_def.reshape(num_scenes, -1, 3, 3)
cast_from_bfloat16 = xyzs.dtype == torch.bfloat16  # the renderer expects float32 inputs
if cast_from_bfloat16:
main_print("casting bf16 tensors to float32 for rendering")
cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius = [ensure_dtype(item, torch.float32) for item in (cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius)]
for camera_single, R_def_single, pcd_single, rgbs_single, sigmas_single, normal_single, radius_single in zip(cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius):
image_single, norm_single, viz_mask, scale = self.gaussian_render(pcd_single, sigmas_single, rgbs_single, normal_single, R_def_single, 1, num_imgs, camera_single, use_scale=True, \
radius=radius_single, return_norm=return_norm, return_viz=not self.training)
image.append(image_single)
scales.append(scale.unsqueeze(0))
if return_norm:
norm.append(norm_single)
if not self.training:
viz_masks.append(viz_mask)
image = torch.cat(image, dim=0)
scales = torch.cat(scales, dim=0)
norm = torch.cat(norm, dim=0) if return_norm else None
viz_masks = torch.cat(viz_masks, dim=0) if (not self.training) and viz_masks else None
main_print("not trans the rendered results to float16")
if False:
image = image.to(torch.bfloat16)
scales = scales.to(torch.bfloat16)
if return_norm:
norm = norm.to(torch.bfloat16)
if viz_masks is not None:
viz_masks = viz_masks.to(torch.bfloat16)
offsets = offsets.to(torch.bfloat16)
if self.training:
offset_dist = offsets ** 2
weighted_offset = torch.mean(offset_dist) + torch.mean(offset_dist[self.hands_mask.repeat(num_scenes, 1)]) #+ torch.mean(offset_dist[self.face_mask.repeat(num_scenes, 1)])
else:
weighted_offset = offsets
results = dict(
viz_masks=viz_masks,
scales=scales,
norm=norm,
image=image,
offset=weighted_offset)
if return_loss:
results.update(decoder_reg_loss=self.loss())
return results
def forward_testing_time(self, code, smpl_params, cameras, num_imgs,
return_loss=False, return_norm=False, init=False, mask=None, zeros_hands_off=False):
"""
Args:
density_bitfield: Shape (num_scenes, griz_size**3 // 8)
YY:
grid_size, dt_gamma, perturb, T_thresh are deleted
code: Shape (num_scenes, *code_size)
cameras: Shape (num_scenes, num_imgs, 19(3+16))
smpl_params: Shape (num_scenes, 189)
"""
if isinstance(code, list):
num_scenes = len(code[0])
else:
num_scenes = len(code)
assert num_scenes > 0
self.iter += 1
image = []
scales = []
norm = [] if return_norm else None
viz_masks = [] if not self.training else None
start_time = time.time()
xyzs, sigmas, rgbs, offsets, radius, tfs, rot = self.extract_pcd(code, smpl_params, init=init, zeros_hands_off=zeros_hands_off)
end_time_to_3D = time.time()
time_code_to_3d = end_time_to_3D - start_time
R_delta = batch_rodrigues(rot.reshape(-1, 3))
R = torch.bmm(self.init_rot.repeat(num_scenes, 1, 1), R_delta)
R_def = torch.bmm(tfs.flatten(0, 1)[:, :3, :3], R)
normals = (R_def[:, :, -1]).reshape(num_scenes, -1, 3)
R_def_batch = R_def.reshape(num_scenes, -1, 3, 3)
for camera_single, R_def_single, pcd_single, rgbs_single, sigmas_single, normal_single, radius_single in zip(cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius):
image_single, norm_single, viz_mask, scale = self.gaussian_render(pcd_single, sigmas_single, rgbs_single, normal_single, R_def_single, 1, num_imgs, camera_single, use_scale=True, \
radius=radius_single, return_norm=False, return_viz=not self.training)
image.append(image_single)
scales.append(scale.unsqueeze(0))
if return_norm:
norm.append(norm_single)
if not self.training:
viz_masks.append(viz_mask)
image = torch.cat(image, dim=0)
scales = torch.cat(scales, dim=0)
norm = torch.cat(norm, dim=0) if return_norm else None
viz_masks = torch.cat(viz_masks, dim=0) if (not self.training) and viz_masks else None
time_3D_to_img = time.time() - end_time_to_3D
if False:
image = image.to(torch.bfloat16)
scales = scales.to(torch.bfloat16)
if return_norm:
norm = norm.to(torch.bfloat16)
if viz_masks is not None:
viz_masks = viz_masks.to(torch.bfloat16)
offsets = offsets.to(torch.bfloat16)
results = dict(
image=image)
return results, time_code_to_3d, time_3D_to_img
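The training-time offset regulariser shared by the `forward_*` methods above can be sketched standalone. This is a minimal sketch, assuming `hands_mask` is a per-vertex boolean mask as in the decoder; note the hand region is counted a second time so hands stay close to the template:

```python
import torch

def weighted_offset_loss(offsets: torch.Tensor, hands_mask: torch.Tensor) -> torch.Tensor:
    """Penalise squared vertex offsets everywhere, plus an extra term
    over the hand vertices.

    offsets: (num_scenes, num_verts, 3); hands_mask: (num_verts,) bool.
    Assumes hands_mask selects at least one vertex (else the second mean is NaN).
    """
    num_scenes = offsets.shape[0]
    offset_dist = offsets ** 2                      # (S, V, 3)
    mask = hands_mask[None].expand(num_scenes, -1)  # (S, V)
    # boolean indexing selects hand vertices -> shape (K, 3)
    return offset_dist.mean() + offset_dist[mask].mean()
```

At test time the code above skips this weighting and returns the raw offsets instead.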
================================================
FILE: lib/models/decoders/vit_head.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Sequence, Tuple, Optional
class VitHead(nn.Module):
def __init__(self,
in_channels: int,
out_channels: int,
deconv_out_channels: Optional[Sequence[int]] = (256, 256, 256),
deconv_kernel_sizes: Optional[Sequence[int]] = (4, 4, 4),
conv_out_channels: Optional[Sequence[int]] = None,
conv_kernel_sizes: Optional[Sequence[int]] = None,
):
super(VitHead, self).__init__()
if deconv_out_channels:
if deconv_kernel_sizes is None or len(deconv_out_channels) != len(deconv_kernel_sizes):
raise ValueError(
'"deconv_out_channels" and "deconv_kernel_sizes" should '
'be integer sequences with the same length. Got '
f'mismatched lengths {deconv_out_channels} and '
f'{deconv_kernel_sizes}')
self.deconv_layers = self._make_deconv_layers(
in_channels=in_channels,
layer_out_channels=deconv_out_channels,
layer_kernel_sizes=deconv_kernel_sizes,
)
in_channels = deconv_out_channels[-1]
else:
self.deconv_layers = nn.Identity()
if conv_out_channels:
if conv_kernel_sizes is None or len(conv_out_channels) != len(conv_kernel_sizes):
raise ValueError(
'"conv_out_channels" and "conv_kernel_sizes" should '
'be integer sequences with the same length. Got '
f'mismatched lengths {conv_out_channels} and '
f'{conv_kernel_sizes}')
self.conv_layers = self._make_conv_layers(
in_channels=in_channels,
layer_out_channels=conv_out_channels,
layer_kernel_sizes=conv_kernel_sizes)
in_channels = conv_out_channels[-1]
else:
self.conv_layers = nn.Identity()
self.cls_seg = nn.Conv2d(in_channels, out_channels, kernel_size=1)
def _make_conv_layers(self, in_channels: int,
layer_out_channels: Sequence[int],
layer_kernel_sizes: Sequence[int]) -> nn.Module:
"""Create convolutional layers by given parameters."""
layers = []
for out_channels, kernel_size in zip(layer_out_channels, layer_kernel_sizes):
padding = (kernel_size - 1) // 2
layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding))
layers.append(nn.InstanceNorm2d(out_channels))
layers.append(nn.SiLU(inplace=True))
in_channels = out_channels
return nn.Sequential(*layers)
def _make_deconv_layers(self, in_channels: int,
layer_out_channels: Sequence[int],
layer_kernel_sizes: Sequence[int]) -> nn.Module:
"""Create deconvolutional layers by given parameters."""
layers = []
for out_channels, kernel_size in zip(layer_out_channels, layer_kernel_sizes):
if kernel_size == 4:
padding = 1
output_padding = 0
elif kernel_size == 3:
padding = 1
output_padding = 1
elif kernel_size == 2:
padding = 0
output_padding = 0
else:
raise ValueError(f'Unsupported kernel size {kernel_size} for deconvolutional layers')
layers.append(nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=2, padding=padding, output_padding=output_padding, bias=False))
layers.append(nn.InstanceNorm2d(out_channels))
layers.append(nn.SiLU(inplace=True))
in_channels = out_channels
return nn.Sequential(*layers)
def forward(self, inputs):
x = self.deconv_layers(inputs)
x = self.conv_layers(x)
out = self.cls_seg(x)
return out
if __name__ == "__main__":
# Example usage:
model = VitHead(in_channels=1536, out_channels=21, deconv_out_channels=(768, 768, 512, 512),
deconv_kernel_sizes=(4, 4, 4, 4),
conv_out_channels=(512, 256, 128), conv_kernel_sizes=(1, 1, 1),
)
inputs = torch.randn(1, 1536, 64, 64)
outputs = model(inputs)
print(outputs.shape)
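Each deconv stage in `_make_deconv_layers` is configured so that a stride-2 `ConvTranspose2d` exactly doubles the spatial size for all three supported kernels. The arithmetic can be checked with the standard transposed-convolution size formula (a sketch of the padding scheme above):

```python
def deconv_out_size(n: int, kernel_size: int, stride: int = 2) -> int:
    """Output spatial size of one ConvTranspose2d as configured in
    VitHead._make_deconv_layers: out = (n - 1)*stride - 2*pad + k + out_pad."""
    padding, output_padding = {4: (1, 0), 3: (1, 1), 2: (0, 0)}[kernel_size]
    return (n - 1) * stride - 2 * padding + kernel_size + output_padding
```

So the `__main__` example (four k=4 stages on a 64x64 feature map, followed by size-preserving 1x1 convs) produces a 1024x1024 output.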
================================================
FILE: lib/models/deformers/__init__.py
================================================
from .smplx_deformer_gender import SMPLXDeformer_gender
__all__ = ['SMPLXDeformer_gender']
================================================
FILE: lib/models/deformers/fast_snarf/cuda/filter/filter.cpp
================================================
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <vector>
#include <c10/cuda/CUDAGuard.h>
void launch_filter(
torch::Tensor &output, const torch::Tensor &x, const torch::Tensor &mask);
void filter(const torch::Tensor &x, const torch::Tensor &mask, torch::Tensor &output) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
launch_filter(output, x, mask);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("filter", &filter);
}
================================================
FILE: lib/models/deformers/fast_snarf/cuda/filter/filter_kernel.cu
================================================
#include <vector>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/core/TensorBase.h>
#include <ATen/Dispatch.h>
#include <c10/macros/Macros.h>
#include "ATen/Functions.h"
#include "ATen/core/TensorAccessor.h"
#include "c10/cuda/CUDAException.h"
#include "c10/cuda/CUDAStream.h"
#include <chrono>
#define TensorAcc4R PackedTensorAccessor32<scalar_t,4,RestrictPtrTraits>
#define TensorAcc5R PackedTensorAccessor32<scalar_t,5,RestrictPtrTraits>
using namespace at;
using namespace at::cuda::detail;
template <typename scalar_t, typename index_t>
C10_LAUNCH_BOUNDS_1(512)
__global__ void filter(
const index_t nthreads,
PackedTensorAccessor32<scalar_t, 4, RestrictPtrTraits> x,
PackedTensorAccessor32<bool, 3, RestrictPtrTraits> mask,
PackedTensorAccessor32<bool, 3, RestrictPtrTraits> output) {
index_t n_batch = mask.size(0);
index_t n_point = mask.size(1);
index_t n_init = mask.size(2);
CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
// decompose the flat thread index (launcher spawns n_batch * n_point threads)
const index_t i_batch = index / n_point;
const index_t i_point = index % n_point;
for(index_t i = 0; i < n_init; i++) {
if(!mask[i_batch][i_point][i]){
output[i_batch][i_point][i] = false;
continue;
}
scalar_t xi0 = x[i_batch][i_point][i][0];
scalar_t xi1 = x[i_batch][i_point][i][1];
scalar_t xi2 = x[i_batch][i_point][i][2];
bool flag = true;
for(index_t j = i+1; j < n_init; j++){
if(!mask[i_batch][i_point][j]){
continue;
}
scalar_t d0 = xi0 - x[i_batch][i_point][j][0];
scalar_t d1 = xi1 - x[i_batch][i_point][j][1];
scalar_t d2 = xi2 - x[i_batch][i_point][j][2];
scalar_t dist = d0*d0 + d1*d1 + d2*d2;
if(dist<0.0001*0.0001){
flag=false;
break;
}
}
output[i_batch][i_point][i] = flag;
}
}
}
void launch_filter(
Tensor &output,
const Tensor &x,
const Tensor &mask) {
// calculate #threads required
int64_t B = output.size(0);
int64_t N = output.size(1);
int64_t count = B*N;
if (count > 0) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filter", [&] {
filter<scalar_t>
<<<GET_BLOCKS(count, 512), 512, 0, at::cuda::getCurrentCUDAStream()>>>(
static_cast<int>(count),
x.packed_accessor32<scalar_t,4,RestrictPtrTraits>(),
mask.packed_accessor32<bool,3,RestrictPtrTraits>(),
output.packed_accessor32<bool,3,RestrictPtrTraits>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}
}
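The CUDA `filter` op above deduplicates candidate canonical correspondences: a candidate survives only if it is valid and no *later* valid candidate lies within 1e-4 of it. A pure-PyTorch reference, written for readability rather than speed (shapes follow the kernel's accessors):

```python
import torch

def filter_reference(x: torch.Tensor, mask: torch.Tensor, thresh: float = 1e-4) -> torch.Tensor:
    """x: (B, N, I, 3) candidate points; mask: (B, N, I) bool validity.
    Mirrors filter_kernel.cu: candidate i is kept iff it is valid and no
    later valid candidate j > i is closer than `thresh`."""
    B, N, I = mask.shape
    out = mask.clone()
    for i in range(I):
        for j in range(i + 1, I):
            d2 = (x[:, :, i] - x[:, :, j]).pow(2).sum(-1)   # squared distance, (B, N)
            dup = mask[:, :, i] & mask[:, :, j] & (d2 < thresh * thresh)
            out[:, :, i] &= ~dup
    return out
```

Keeping the later of two duplicates matches the kernel's loop direction (the survivor is the candidate with the largest index).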
================================================
FILE: lib/models/deformers/fast_snarf/cuda/fuse_kernel/fuse_cuda.cpp
================================================
#include "ATen/Functions.h"
#include "ATen/core/TensorBody.h"
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <vector>
#include <c10/cuda/CUDAGuard.h>
void launch_broyden_kernel(torch::Tensor &x,
const torch::Tensor &xd_tgt,
const torch::Tensor &grid,
const torch::Tensor &grid_J_inv,
const torch::Tensor &tfs,
const torch::Tensor &bone_ids,
bool align_corners,
// torch::Tensor &J_inv,
torch::Tensor &is_valid,
const torch::Tensor& offset,
const torch::Tensor& scale,
float cvg_threshold,
float dvg_threshold);
void fuse_broyden(torch::Tensor &x,
const torch::Tensor &xd_tgt,
const torch::Tensor &grid,
const torch::Tensor &grid_J_inv,
const torch::Tensor &tfs,
const torch::Tensor &bone_ids,
bool align_corners,
// torch::Tensor& J_inv,
torch::Tensor &is_valid,
torch::Tensor& offset,
torch::Tensor& scale,
float cvg_threshold,
float dvg_threshold) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
launch_broyden_kernel(x, xd_tgt, grid, grid_J_inv, tfs, bone_ids, align_corners, is_valid, offset, scale, cvg_threshold, dvg_threshold);
return;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fuse_broyden", &fuse_broyden);
}
================================================
FILE: lib/models/deformers/fast_snarf/cuda/fuse_kernel/fuse_cuda_kernel.cu
================================================
#include "ATen/Functions.h"
#include "ATen/core/TensorAccessor.h"
#include "c10/cuda/CUDAException.h"
#include "c10/cuda/CUDAStream.h"
#include <ratio>
#include <vector>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/core/TensorBase.h>
#include <ATen/Dispatch.h>
#include <c10/macros/Macros.h>
#include <chrono>
using namespace std::chrono;
using namespace at;
using namespace at::cuda::detail;
template <typename scalar_t, typename index_t>
__device__ void fuse_J_inv_update(const index_t index,
scalar_t* J_inv,
scalar_t x0,
scalar_t x1,
scalar_t x2,
scalar_t g0,
scalar_t g1,
scalar_t g2) {
scalar_t J00 = J_inv[3*0 + 0];
scalar_t J01 = J_inv[3*0 + 1];
scalar_t J02 = J_inv[3*0 + 2];
scalar_t J10 = J_inv[3*1 + 0];
scalar_t J11 = J_inv[3*1 + 1];
scalar_t J12 = J_inv[3*1 + 2];
scalar_t J20 = J_inv[3*2 + 0];
scalar_t J21 = J_inv[3*2 + 1];
scalar_t J22 = J_inv[3*2 + 2];
auto c0 = J00 * x0 + J10 * x1 + J20 * x2;
auto c1 = J01 * x0 + J11 * x1 + J21 * x2;
auto c2 = J02 * x0 + J12 * x1 + J22 * x2;
auto s = c0 * g0 + c1 * g1 + c2 * g2;
auto r0 = -J00 * g0 - J01 * g1 - J02 * g2;
auto r1 = -J10 * g0 - J11 * g1 - J12 * g2;
auto r2 = -J20 * g0 - J21 * g1 - J22 * g2;
J_inv[3*0 + 0] += c0 * (r0 + x0) / s;
J_inv[3*0 + 1] += c1 * (r0 + x0) / s;
J_inv[3*0 + 2] += c2 * (r0 + x0) / s;
J_inv[3*1 + 0] += c0 * (r1 + x1) / s;
J_inv[3*1 + 1] += c1 * (r1 + x1) / s;
J_inv[3*1 + 2] += c2 * (r1 + x1) / s;
J_inv[3*2 + 0] += c0 * (r2 + x2) / s;
J_inv[3*2 + 1] += c1 * (r2 + x2) / s;
J_inv[3*2 + 2] += c2 * (r2 + x2) / s;
}
static __forceinline__ __device__
bool within_bounds_3d(int d, int h, int w, int D, int H, int W) {
return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W;
}
template <typename scalar_t>
static __forceinline__ __device__
scalar_t grid_sampler_unnormalize(scalar_t coord, int size, bool align_corners) {
if (align_corners) {
// unnormalize coord from [-1, 1] to [0, size - 1]
return ((coord + 1.f) / 2) * (size - 1);
} else {
// unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
return ((coord + 1.f) * size - 1) / 2;
}
}
// Clips coordinates to between 0 and clip_limit - 1
template <typename scalar_t>
static __forceinline__ __device__
scalar_t clip_coordinates(scalar_t in, int clip_limit) {
return ::min(static_cast<scalar_t>(clip_limit - 1), ::max(in, static_cast<scalar_t>(0)));
}
template<typename scalar_t>
static __forceinline__ __device__
scalar_t safe_downgrade_to_int_range(scalar_t x){
// -100.0 does not have special meaning. This is just to make sure
// it's not within_bounds_2d or within_bounds_3d, and does not cause
// undefined behavior. See #35506.
if (x > INT_MAX-1 || x < INT_MIN || !::isfinite(static_cast<double>(x)))
return static_cast<scalar_t>(-100.0);
return x;
}
template<typename scalar_t>
static __forceinline__ __device__
scalar_t compute_coordinates(scalar_t coord, int size, bool align_corners) {
// clip coordinates to image borders
// coord = clip_coordinates(coord, size);
coord = safe_downgrade_to_int_range(coord);
return coord;
}
template <typename scalar_t>
static __forceinline__ __device__
scalar_t grid_sampler_compute_source_index(scalar_t coord, int size, bool align_corners) {
coord = grid_sampler_unnormalize(coord, size, align_corners);
coord = compute_coordinates(coord, size, align_corners);
return coord;
}
template <typename scalar_t, typename index_t>
__device__ void grid_sampler_3d(
index_t i_batch,
TensorInfo<scalar_t, index_t> input,
scalar_t grid_x,
scalar_t grid_y,
scalar_t grid_z,
// TensorInfo<scalar_t, index_t> output,
PackedTensorAccessor32<scalar_t, 5> input_p, // [1, 3, 8, 32, 32]
scalar_t* output,
// PackedTensorAccessor32<scalar_t, 3> output_p, // [1800000, 3, 1]
bool align_corners,
bool nearest) {
index_t C = input.sizes[1];
index_t inp_D = input.sizes[2];
index_t inp_H = input.sizes[3];
index_t inp_W = input.sizes[4];
// broyden x.sizes=[1800000,3,1]
index_t inp_sN = input.strides[0];
index_t inp_sC = input.strides[1];
index_t inp_sD = input.strides[2];
index_t inp_sH = input.strides[3];
index_t inp_sW = input.strides[4];
index_t out_sC = 1; //output size is same as grid size...
// return;
// get the corresponding input x, y, z co-ordinates from grid
scalar_t ix = grid_x;
scalar_t iy = grid_y;
scalar_t iz = grid_z;
// c0 ix,iy,iz=-0.848051,0.592726,0.259927
// c1 ix,iy,iz=2.355216,24.687256,4.409743
ix = grid_sampler_compute_source_index(ix, inp_W, align_corners);
iy = grid_sampler_compute_source_index(iy, inp_H, align_corners);
iz = grid_sampler_compute_source_index(iz, inp_D, align_corners);
if(!nearest){
// get corner pixel values from (x, y, z)
// for 4d, we used north-east-south-west
// for 5d, we add top-bottom
index_t ix_tnw = static_cast<index_t>(::floor(ix));
index_t iy_tnw = static_cast<index_t>(::floor(iy));
index_t iz_tnw = static_cast<index_t>(::floor(iz));
index_t ix_tne = ix_tnw + 1;
index_t iy_tne = iy_tnw;
index_t iz_tne = iz_tnw;
index_t ix_tsw = ix_tnw;
index_t iy_tsw = iy_tnw + 1;
index_t iz_tsw = iz_tnw;
index_t ix_tse = ix_tnw + 1;
index_t iy_tse = iy_tnw + 1;
index_t iz_tse = iz_tnw;
index_t ix_bnw = ix_tnw;
index_t iy_bnw = iy_tnw;
index_t iz_bnw = iz_tnw + 1;
index_t ix_bne = ix_tnw + 1;
index_t iy_bne = iy_tnw;
index_t iz_bne = iz_tnw + 1;
index_t ix_bsw = ix_tnw;
index_t iy_bsw = iy_tnw + 1;
index_t iz_bsw = iz_tnw + 1;
index_t ix_bse = ix_tnw + 1;
index_t iy_bse = iy_tnw + 1;
index_t iz_bse = iz_tnw + 1;
// get surfaces to each neighbor:
scalar_t tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
scalar_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);
scalar_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);
scalar_t tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);
scalar_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);
scalar_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);
scalar_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
scalar_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);
for (index_t xyz = 0; xyz < C; xyz++) {
output[xyz] = 0;
if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
output[xyz] += input_p[i_batch][xyz][iz_tnw][iy_tnw][ix_tnw] * tnw;
// *out_ptr_NCDHW += inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] * tnw;
}
if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
output[xyz] += input_p[i_batch][xyz][iz_tne][iy_tne][ix_tne] * tne;
// *out_ptr_NCDHW += inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] * tne;
}
if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
output[xyz] += input_p[i_batch][xyz][iz_tsw][iy_tsw][ix_tsw] * tsw;
// *out_ptr_NCDHW += inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] * tsw;
}
if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
output[xyz] += input_p[i_batch][xyz][iz_tse][iy_tse][ix_tse] * tse;
// *out_ptr_NCDHW += inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] * tse;
}
if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
output[xyz] += input_p[i_batch][xyz][iz_bnw][iy_bnw][ix_bnw] * bnw;
// *out_ptr_NCDHW += inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] * bnw;
}
if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
output[xyz] += input_p[i_batch][xyz][iz_bne][iy_bne][ix_bne] * bne;
// *out_ptr_NCDHW += inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] * bne;
}
if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
output[xyz] += input_p[i_batch][xyz][iz_bsw][iy_bsw][ix_bsw] * bsw;
// *out_ptr_NCDHW += inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] * bsw;
}
if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
output[xyz] += input_p[i_batch][xyz][iz_bse][iy_bse][ix_bse] * bse;
// *out_ptr_NCDHW += inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] * bse;
}
}
}
else{
index_t ix_nearest = static_cast<index_t>(::round(ix));
index_t iy_nearest = static_cast<index_t>(::round(iy));
index_t iz_nearest = static_cast<index_t>(::round(iz));
for (index_t xyz = 0; xyz < C; xyz++) {
if (within_bounds_3d(iz_nearest, iy_nearest, ix_nearest, inp_D, inp_H, inp_W)) {
output[xyz] = input_p[i_batch][xyz][iz_nearest][iy_nearest][ix_nearest];
} else {
output[xyz] = static_cast<scalar_t>(0);
}
}
}
}
template <typename scalar_t, typename index_t>
C10_LAUNCH_BOUNDS_1(512)
__global__ void broyden_kernel(
const index_t npoints,
const index_t n_batch,
const index_t n_point,
const index_t n_init,
TensorInfo<scalar_t, index_t> voxel_ti,
TensorInfo<scalar_t, index_t> voxel_J_ti,
PackedTensorAccessor32<scalar_t, 4> x, // shape=(N,200000, 9, 3)
PackedTensorAccessor32<scalar_t, 3> xd_tgt, // shape=(N,200000, 3)
PackedTensorAccessor32<scalar_t, 5> voxel, // shape=(N,3,8,32,32)
PackedTensorAccessor32<scalar_t, 5> grid_J_inv, // blended 3x4 transforms, shape=(N,12,8,32,32)
PackedTensorAccessor32<scalar_t, 4> tfs, // shape=(N,24,4,4)
PackedTensorAccessor32<int, 1> bone_ids, // shape=(9)
// PackedTensorAccessor32<scalar_t, 5> J_inv,// shape=(N,200000, 9, 9)
PackedTensorAccessor32<bool, 3> is_valid,// shape=(N,200000, 9)
PackedTensorAccessor32<scalar_t, 3> offset, // shape=(N, 1, 3)
PackedTensorAccessor32<scalar_t, 3> scale, // shape=(N, 1, 3)
float cvg_threshold,
float dvg_threshold,
int N
)
{
index_t index = blockIdx.x * blockDim.x + threadIdx.x;
if(index >= npoints) return;
const index_t i_batch = index / (n_point*n_init);
const index_t i_point = (index % (n_point*n_init)) / n_init;
const index_t i_init = (index % (n_point*n_init)) % n_init;
if(!is_valid[i_batch][i_point][i_init]){
return;
}
scalar_t gx[3];
scalar_t gx_new[3];
scalar_t xd_tgt_index[3];
xd_tgt_index[0] = xd_tgt[i_batch][i_point][0];
xd_tgt_index[1] = xd_tgt[i_batch][i_point][1];
xd_tgt_index[2] = xd_tgt[i_batch][i_point][2];
scalar_t x_l[3];
int i_bone = bone_ids[i_init];
scalar_t ixd = xd_tgt_index[0] - tfs[i_batch][i_bone][0][3];
scalar_t iyd = xd_tgt_index[1] - tfs[i_batch][i_bone][1][3];
scalar_t izd = xd_tgt_index[2] - tfs[i_batch][i_bone][2][3];
x_l[0] = ixd * tfs[i_batch][i_bone][0][0]
+ iyd * tfs[i_batch][i_bone][1][0]
+ izd * tfs[i_batch][i_bone][2][0];
x_l[1] = ixd * tfs[i_batch][i_bone][0][1]
+ iyd * tfs[i_batch][i_bone][1][1]
+ izd * tfs[i_batch][i_bone][2][1];
x_l[2] = ixd * tfs[i_batch][i_bone][0][2]
+ iyd * tfs[i_batch][i_bone][1][2]
+ izd * tfs[i_batch][i_bone][2][2];
// scalar_t ix = scale[0][0][0] * (x_l[0] + offset[0][0][0]);
// scalar_t iy = scale[0][0][1] * (x_l[1] + offset[0][0][1]);
// scalar_t iz = scale[0][0][2] * (x_l[2] + offset[0][0][2]);
scalar_t J_local[12];
grid_sampler_3d( i_batch, voxel_J_ti,
scale[0][0][0] * (x_l[0] + offset[0][0][0]),
scale[0][0][1] * (x_l[1] + offset[0][0][1]),
scale[0][0][2] * (x_l[2] + offset[0][0][2]),
grid_J_inv,
J_local,
true,
false);
scalar_t J_inv_local[9];
J_inv_local[3*0 + 0] = J_local[4*0 + 0];
J_inv_local[3*1 + 0] = J_local[4*0 + 1];
J_inv_local[3*2 + 0] = J_local[4*0 + 2];
J_inv_local[3*0 + 1] = J_local[4*1 + 0];
J_inv_local[3*1 + 1] = J_local[4*1 + 1];
J_inv_local[3*2 + 1] = J_local[4*1 + 2];
J_inv_local[3*0 + 2] = J_local[4*2 + 0];
J_inv_local[3*1 + 2] = J_local[4*2 + 1];
J_inv_local[3*2 + 2] = J_local[4*2 + 2];
for(int i=0; i<10; i++) {
scalar_t J00 = J_inv_local[3*0 + 0];
scalar_t J01 = J_inv_local[3*0 + 1];
scalar_t J02 = J_inv_local[3*0 + 2];
scalar_t J10 = J_inv_local[3*1 + 0];
scalar_t J11 = J_inv_local[3*1 + 1];
scalar_t J12 = J_inv_local[3*1 + 2];
scalar_t J20 = J_inv_local[3*2 + 0];
scalar_t J21 = J_inv_local[3*2 + 1];
scalar_t J22 = J_inv_local[3*2 + 2];
// gx = g(x)
if (i==0){
// grid_sampler_3d( i_batch, voxel_ti,
// scale[0][0][0] * (x_l[0] + offset[0][0][0]),
// scale[0][0][1] * (x_l[1] + offset[0][0][1]),
// scale[0][0][2] * (x_l[2] + offset[0][0][2]),
// voxel,
// gx,
// true,
// false);
gx[0] = J_local[4*0+0] * x_l[0] + J_local[4*0+1] * x_l[1] + J_local[4*0+2] * x_l[2] + J_local[4*0+3];
gx[1] = J_local[4*1+0] * x_l[0] + J_local[4*1+1] * x_l[1] + J_local[4*1+2] * x_l[2] + J_local[4*1+3];
gx[2] = J_local[4*2+0] * x_l[0] + J_local[4*2+1] * x_l[1] + J_local[4*2+2] * x_l[2] + J_local[4*2+3];
gx[0] = gx[0] - xd_tgt_index[0];
gx[1] = gx[1] - xd_tgt_index[1];
gx[2] = gx[2] - xd_tgt_index[2];
}
else{
gx[0] = gx_new[0];
gx[1] = gx_new[1];
gx[2] = gx_new[2];
}
// update = -J_inv @ gx
scalar_t u0 = -J00*gx[0] + -J01*gx[1] + -J02*gx[2];
scalar_t u1 = -J10*gx[0] + -J11*gx[1] + -J12*gx[2];
scalar_t u2 = -J20*gx[0] + -J21*gx[1] + -J22*gx[2];
// x += update
x_l[0] += u0;
x_l[1] += u1;
x_l[2] += u2;
scalar_t ix = scale[0][0][0] * (x_l[0] + offset[0][0][0]);
scalar_t iy = scale[0][0][1] * (x_l[1] + offset[0][0][1]);
scalar_t iz = scale[0][0][2] * (x_l[2] + offset[0][0][2]);
// gx_new = g(x)
grid_sampler_3d( i_batch, voxel_J_ti,
ix,
iy,
iz,
grid_J_inv,
J_local,
true,
false);
gx_new[0] = J_local[4*0+0] * x_l[0] + J_local[4*0+1] * x_l[1] + J_local[4*0+2] * x_l[2] + J_local[4*0+3] - xd_tgt_index[0];
gx_new[1] = J_local[4*1+0] * x_l[0] + J_local[4*1+1] * x_l[1] + J_local[4*1+2] * x_l[2] + J_local[4*1+3] - xd_tgt_index[1];
gx_new[2] = J_local[4*2+0] * x_l[0] + J_local[4*2+1] * x_l[1] + J_local[4*2+2] * x_l[2] + J_local[4*2+3] - xd_tgt_index[2];
// grid_sampler_3d( i_batch, voxel_ti,
// scale[0][0][0] * (x_l[0] + offset[0][0][0]),
// scale[0][0][1] * (x_l[1] + offset[0][0][1]),
// scale[0][0][2] * (x_l[2] + offset[0][0][2]),
// voxel,
// gx_new,
// true,
// false);
// gx_new[0] = gx_new[0] - xd_tgt_index[0];
// gx_new[1] = gx_new[1] - xd_tgt_index[1];
// gx_new[2] = gx_new[2] - xd_tgt_index[2];
// convergence checking
scalar_t norm_gx = gx_new[0]*gx_new[0] + gx_new[1]*gx_new[1] + gx_new[2]*gx_new[2];
// convergence/divergence criterion
if(norm_gx < cvg_threshold*cvg_threshold) {
bool is_valid_ = ix >= -1 && ix <= 1 && iy >= -1 && iy <= 1 && iz >= -1 && iz <= 1;
is_valid[i_batch][i_point][i_init] = is_valid_;
if (is_valid_){
x[i_batch][i_point][i_init][0] = x_l[0];
x[i_batch][i_point][i_init][1] = x_l[1];
x[i_batch][i_point][i_init][2] = x_l[2];
// J_inv[i_batch][i_point][i_init][0][0] = J00;
// J_inv[i_batch][i_point][i_init][0][1] = J01;
// J_inv[i_batch][i_point][i_init][0][2] = J02;
// J_inv[i_batch][i_point][i_init][1][0] = J10;
// J_inv[i_batch][i_point][i_init][1][1] = J11;
// J_inv[i_batch][i_point][i_init][1][2] = J12;
// J_inv[i_batch][i_point][i_init][2][0] = J20;
// J_inv[i_batch][i_point][i_init][2][1] = J21;
// J_inv[i_batch][i_point][i_init][2][2] = J22;
}
return;
}
else if(norm_gx > dvg_threshold*dvg_threshold) {
is_valid[i_batch][i_point][i_init] = false;
return;
}
// delta_x = update
scalar_t delta_x_0 = u0;
scalar_t delta_x_1 = u1;
scalar_t delta_x_2 = u2;
// delta_gx = gx_new - gx
scalar_t delta_gx_0 = gx_new[0] - gx[0];
scalar_t delta_gx_1 = gx_new[1] - gx[1];
scalar_t delta_gx_2 = gx_new[2] - gx[2];
fuse_J_inv_update(index, J_inv_local, delta_x_0, delta_x_1, delta_x_2, delta_gx_0, delta_gx_1, delta_gx_2);
}
is_valid[i_batch][i_point][i_init] = false;
return;
}
void launch_broyden_kernel(
Tensor &x,
const Tensor &xd_tgt,
const Tensor &voxel,
const Tensor &grid_J_inv,
const Tensor &tfs,
const Tensor &bone_ids,
bool align_corners,
// Tensor &J_inv,
Tensor &is_valid,
const Tensor& offset,
const Tensor& scale,
float cvg_threshold,
float dvg_threshold) {
// calculate #threads required
int64_t n_batch = xd_tgt.size(0);
int64_t n_point = xd_tgt.size(1);
int64_t n_init = bone_ids.size(0);
int64_t count = n_batch * n_point * n_init;
if (count > 0) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fuse_kernel_cuda", [&] {
broyden_kernel
<<<GET_BLOCKS(count, 512), 512, 0,
at::cuda::getCurrentCUDAStream()>>>(static_cast<int>(count),
static_cast<int>(n_batch),
static_cast<int>(n_point),
static_cast<int>(n_init),
getTensorInfo<scalar_t, int>(voxel),
getTensorInfo<scalar_t, int>(grid_J_inv),
x.packed_accessor32<scalar_t, 4>(),
xd_tgt.packed_accessor32<scalar_t, 3>(),
voxel.packed_accessor32<scalar_t, 5>(),
grid_J_inv.packed_accessor32<scalar_t, 5>(),
tfs.packed_accessor32<scalar_t, 4>(),
bone_ids.packed_accessor32<int, 1>(),
// J_inv.packed_accessor32<scalar_t, 5>(),
is_valid.packed_accessor32<bool, 3>(),
offset.packed_accessor32<scalar_t, 3>(),
scale.packed_accessor32<scalar_t, 3>(),
cvg_threshold,
dvg_threshold,
0);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}
cudaDeviceSynchronize();
}
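`broyden_kernel` runs, per candidate, a few iterations of Broyden's method with the "good Broyden" rank-one update of the inverse Jacobian (`fuse_J_inv_update` implements J⁺ += (Δx − J⁺Δg)(ΔxᵀJ⁺) / (ΔxᵀJ⁺Δg)). A NumPy sketch of the same iteration on a generic root-finding problem; the function `g`, starting point, and thresholds here are illustrative, not taken from the kernel:

```python
import numpy as np

def broyden(g, x0, J_inv, cvg=1e-8, dvg=10.0, max_iter=50):
    """Broyden root search for g(x) = 0, mirroring broyden_kernel:
    step x += -J_inv @ g(x), then refresh J_inv with the 'good Broyden'
    inverse rank-one formula used by fuse_J_inv_update."""
    x = np.asarray(x0, dtype=float).copy()
    gx = g(x)
    for _ in range(max_iter):
        dx = -J_inv @ gx                 # update = -J_inv @ gx
        x = x + dx
        gx_new = g(x)
        n2 = gx_new @ gx_new
        if n2 < cvg * cvg:               # convergence criterion
            return x, True
        if n2 > dvg * dvg:               # divergence criterion
            return x, False
        dg = gx_new - gx
        c = J_inv.T @ dx                 # c = J_invᵀ Δx
        s = c @ dg                       # s = Δxᵀ J_inv Δg
        J_inv = J_inv + np.outer(dx - J_inv @ dg, c) / s
        gx = gx_new
    return x, False
```

The kernel additionally checks that the converged point stays inside the [-1, 1] sampling grid before marking it valid.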
================================================
FILE: lib/models/deformers/fast_snarf/cuda/precompute/precompute.cpp
================================================
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <vector>
#include <c10/cuda/CUDAGuard.h>
void launch_precompute(const torch::Tensor &voxel_w, const torch::Tensor &tfs, torch::Tensor &voxel_d, torch::Tensor &voxel_J, const torch::Tensor &offset, const torch::Tensor &scale);
void precompute(const torch::Tensor &voxel_w, const torch::Tensor &tfs, torch::Tensor &voxel_d, torch::Tensor &voxel_J, const torch::Tensor &offset, const torch::Tensor &scale) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(voxel_w));
launch_precompute(voxel_w, tfs, voxel_d, voxel_J, offset, scale);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("precompute", &precompute);
}
================================================
FILE: lib/models/deformers/fast_snarf/cuda/precompute/precompute_kernel.cu
================================================
#include "ATen/Functions.h"
#include "ATen/core/TensorAccessor.h"
#include "c10/cuda/CUDAException.h"
#include "c10/cuda/CUDAStream.h"
#include <ratio>
#include <vector>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/core/TensorBase.h>
#include <ATen/Dispatch.h>
#include <c10/macros/Macros.h>
#include <chrono>
using namespace std::chrono;
using namespace at;
using namespace at::cuda::detail;
template <typename scalar_t, typename index_t>
C10_LAUNCH_BOUNDS_1(512)
__global__ void precompute_kernel(
const index_t npoints,
const index_t d,
const index_t h,
const index_t w,
PackedTensorAccessor32<scalar_t, 5> voxel_w, // skinning weights, shape=(1, 24, D, H, W)
PackedTensorAccessor32<scalar_t, 4> tfs, // bone transforms, shape=(N, 24, 4, 4)
PackedTensorAccessor32<scalar_t, 5> voxel_d, // deformed voxel centres, shape=(N, 3, D, H, W)
PackedTensorAccessor32<scalar_t, 5> voxel_J, // blended 3x4 transforms, shape=(N, 12, D, H, W)
PackedTensorAccessor32<scalar_t, 3> offset, // shape=(N, 1, 3)
PackedTensorAccessor32<scalar_t, 3> scale // shape=(N, 1, 3)
)
{
index_t index = blockIdx.x * blockDim.x + threadIdx.x;
if(index >= npoints) return;
index_t idx_b = index / (d*h*w);
index_t idx_d = index % (d*h*w) / (h*w);
index_t idx_h = index % (d*h*w) % (h*w) / w;
index_t idx_w = index % (d*h*w) % (h*w) % w;
scalar_t coord_x = ( ((scalar_t)idx_w) / (w-1) * 2 -1) / scale[0][0][0] - offset[0][0][0];
scalar_t coord_y = ( ((scalar_t)idx_h) / (h-1) * 2 -1) / scale[0][0][1] - offset[0][0][1];
scalar_t coord_z = ( ((scalar_t)idx_d) / (d-1) * 2 -1) / scale[0][0][2] - offset[0][0][2];
scalar_t J[12];
for(index_t i0 = 0; i0 < 3; i0++){
for(index_t i1 = 0; i1 < 4; i1++){
J[i0*4 + i1] = 0;
for(index_t j = 0; j < 24; j++){
J[i0*4 + i1] += voxel_w[0][j][idx_d][idx_h][idx_w]*tfs[idx_b][j][i0][i1];
}
}
}
for(index_t i0 = 0; i0 < 3; i0++){
for(index_t i1 = 0; i1 < 4; i1++){
voxel_J[idx_b][i0*4 + i1][idx_d][idx_h][idx_w] = J[i0*4 + i1];
}
}
for(index_t i0 = 0; i0 < 3; i0++){
scalar_t xi = J[i0*4 + 0]*coord_x + J[i0*4 + 1]*coord_y + J[i0*4 + 2]*coord_z + J[i0*4 + 3];
voxel_d[idx_b][i0][idx_d][idx_h][idx_w] = xi;
}
}
void launch_precompute(
    const Tensor &voxel_w,
    const Tensor &tfs,
    Tensor &voxel_d,
    Tensor &voxel_J,
    const Tensor &offset,
    const Tensor &scale
) {
    // calculate #threads required
    int64_t n_batch = voxel_d.size(0);
    int64_t d = voxel_d.size(2);
    int64_t h = voxel_d.size(3);
    int64_t w = voxel_d.size(4);
    int64_t count = n_batch * d * h * w;
    if (count > 0) {
        AT_DISPATCH_FLOATING_TYPES_AND_HALF(voxel_w.scalar_type(), "precompute", [&] {
            precompute_kernel
                <<<GET_BLOCKS(count, 512), 512, 0,
                   at::cuda::getCurrentCUDAStream()>>>(
                    static_cast<int>(count),
                    static_cast<int>(d),
                    static_cast<int>(h),
                    static_cast<int>(w),
                    voxel_w.packed_accessor32<scalar_t, 5>(),
                    tfs.packed_accessor32<scalar_t, 4>(),
                    voxel_d.packed_accessor32<scalar_t, 5>(),
                    voxel_J.packed_accessor32<scalar_t, 5>(),
                    offset.packed_accessor32<scalar_t, 3>(),
                    scale.packed_accessor32<scalar_t, 3>()
                );
            C10_CUDA_KERNEL_LAUNCH_CHECK();
        });
    }
    cudaDeviceSynchronize();
}
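The kernel above blends, per voxel, the 24 bone transforms by the skinning weights stored in voxel_w and applies the blended 3x4 transform to the voxel's canonical coordinate. For readers without a CUDA toolchain, the same computation can be sketched in pure PyTorch; tensor names mirror the kernel's accessors, but the helper name and exact shapes are illustrative assumptions, not the repo's API:

```python
import torch

def precompute_reference(voxel_w, tfs, offset, scale):
    """Pure-PyTorch sketch of the precompute kernel (illustrative, not the repo API).

    voxel_w: (B, 24, D, H, W) per-bone skinning weights (the kernel reads batch 0)
    tfs:     (B, 24, 4, 4)    bone transformation matrices
    offset:  (B, 1, 3), scale: (B, 1, 3) grid de-normalisation, as in the kernel
    Returns voxel_d (B, 3, D, H, W) and voxel_J (B, 12, D, H, W).
    """
    B, J, D, H, W = voxel_w.shape
    # Blend the top 3x4 rows of each bone transform with the voxel's weights.
    T = tfs[:, :, :3, :].reshape(B, J, 12)                    # (B, 24, 12)
    voxel_J = torch.einsum('jdhw,bjk->bkdhw', voxel_w[0], T)  # (B, 12, D, H, W)
    # Voxel-centre coordinates in [-1, 1], then undo scale/offset like the kernel.
    zs, ys, xs = torch.meshgrid(
        torch.linspace(-1, 1, D), torch.linspace(-1, 1, H),
        torch.linspace(-1, 1, W), indexing='ij')
    coords = torch.stack([xs, ys, zs])                        # (3, D, H, W)
    coords = coords / scale[0, 0].view(3, 1, 1, 1) - offset[0, 0].view(3, 1, 1, 1)
    # Apply the blended 3x4 transform to the homogeneous coordinate.
    hom = torch.cat([coords, torch.ones(1, D, H, W)], dim=0)  # (4, D, H, W)
    voxel_d = torch.einsum('bikdhw,kdhw->bidhw',
                           voxel_J.reshape(B, 3, 4, D, H, W), hom)
    return voxel_d, voxel_J
```

A cheap sanity check: with identity transforms and weights that sum to one per voxel, voxel_d should reproduce the grid coordinates themselves.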
================================================
FILE: lib/models/deformers/fast_snarf/lib/model/deformer_smpl.py
================================================
import torch
from torch import einsum
import torch.nn.functional as F
import os

from torch.utils.cpp_extension import load

import fuse_cuda
import filter_cuda
import precompute_cuda

import numpy as np
class ForwardDeformer(torch.nn.Module):
    """
    Tensor shape abbreviation:
        B: batch size
        N: number of points
        J: number of bones
        I: number of init
        D: space dimension
    """
    def __init__(self, **kwargs):
        super().__init__()
        self.soft_blend = 20
        self.init_bones = [0, 1, 2, 4, 5, 12, 15, 16, 17, 18, 19]
        self.init_bones_cuda = torch.tensor(self.init_bones).int()
        self.global_scale = 1.2
    def forward(self, xd, cond, mask, tfs, eval_mode=False):
        """Given deformed points, return their canonical correspondences.

        Args:
            xd (tensor): deformed points in batch. shape: [B, N, D]
            cond (dict): conditional input.
            tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1]

        Returns:
            xc (tensor): canonical correspondences. shape: [B, N, I, D]
            others (dict): other useful outputs.
        """
        xc_opt, others = self.search(xd, cond, mask, tfs, eval_mode=True)
        if eval_mode:
            return xc_opt, others
    def precompute(self, tfs):
        b = tfs.shape[0]
        d, h, w = self.resolution // 4, self.resolution, self.resolution
        voxel_d = torch.zeros((b, 3, d, h, w), device=tfs.device)
        voxel_J = torch.zeros((b, 12, d, h, w), device=tfs.device)
        precompute_cuda.precompute(self.lbs_voxel_final, tfs, voxel_d, voxel_J,
                                   self.offset_kernel, self.scale_kernel)
        self.voxel_d = voxel_d
        self.voxel_J = voxel_J
    def search(self, xd, cond, mask, tfs, eval_mode=False):
        """Search correspondences.

        Args:
            xd (tensor): deformed points in batch. shape: [B, N, D]
            cond (dict): conditional input.
            tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1]

        Returns:
            xc_opt (tensor): canonical correspondences of xd. shape: [B, N, I, D]
            valid_ids (tensor): identifiers of converged points. [B, N, I]
        """
        # run the Broyden iterations without gradient tracking
        with torch.no_grad():
            result = self.broyden_cuda(xd, self.voxel_d, self.voxel_J, tfs, mask)
        return result['result'], result
    def broyden_cuda(self,
                     xd_tgt,
                     voxel,
                     voxel_J_inv,
                     tfs,
                     mask,
                     cvg_thresh=2e-4,
                     dvg_thresh=1):
        """Solve for canonical correspondences with the fused Broyden CUDA kernel.

        Args:
            xd_tgt (tensor): target deformed points. shape: [B, N, 3]
            voxel (tensor): precomputed deformed voxel coordinates.
            voxel_J_inv (tensor): precomputed voxel Jacobians.
            tfs (tensor): bone transformation matrices. shape: [B, J, 4, 4]
            mask (tensor): validity mask for the input points.
            cvg_thresh (float): convergence threshold.
            dvg_thresh (float): divergence threshold.
        """
        b, n, _ = xd_tgt.shape
        n_init = self.init_bones_cuda.shape[0]
        xc_init_IN = torch.zeros((b, n, n_init, 3), device=xd_tgt.device, dtype=torch.float)
        is_valid = mask.expand(b, n, n_init).clone()
        if self.init_bones_cuda.device != xd_tgt.device:
            self.init_bones_cuda = self.init_bones_cuda.to(xd_tgt.device)
        fuse_cuda.fuse_broyden(xc_init_IN, xd_tgt, voxel, voxel_J_inv, tfs,
                               self.init_bones_cuda, True, is_valid,
                               self.offset_kernel, self.scale_kernel,
                               cvg_thresh, dvg_thresh)
        is_valid_new = torch.zeros_like(is_valid)
        filter_cuda.filter(xc_init_IN, is_valid, is_valid_new)
        return {"result": xc_init_IN, "valid_ids": is_valid_new}
    def forward_skinning(self, xc, cond, tfs, mask=None):
        """Canonical point -> deformed point

        Args:
            xc (tensor): canonical points in batch. shape: [B, N, D]
            cond (dict): conditional input.
            tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1]

        Returns:
            xd (tensor): deformed points. shape: [B, N, D]
        """
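The forward_skinning and skinning routines in this file implement standard linear blend skinning: each point is deformed by the weight-blended bone transforms, x_d = (Σ_j w_j T_j)·[x_c; 1]. A minimal dense-weight sketch of that equation (the helper name and shapes are illustrative; the repo instead gathers weights from the precomputed LBS voxel grid):

```python
import torch

def linear_blend_skinning(xc, weights, tfs):
    """Minimal dense-weight LBS sketch (hypothetical helper, not the repo API).

    xc:      (B, N, 3)    canonical points
    weights: (B, N, J)    skinning weights, rows summing to 1
    tfs:     (B, J, 4, 4) bone transformation matrices
    Returns deformed points, shape (B, N, 3).
    """
    # Blend the per-bone transforms into one 4x4 transform per point.
    T = torch.einsum('bnj,bjkl->bnkl', weights, tfs)            # (B, N, 4, 4)
    xh = torch.cat([xc, torch.ones_like(xc[..., :1])], dim=-1)  # homogeneous coords
    return torch.einsum('bnkl,bnl->bnk', T, xh)[..., :3]
```

With identity bone transforms and normalised weights the blend is the identity, so the deformed points equal the canonical points, which makes a handy unit test.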
SYMBOL INDEX (260 symbols across 29 files)
FILE: data_processing/prepare_cache.py
function parse_args (line 15) | def parse_args():
function prepare_dataset (line 45) | def prepare_dataset(video_dir, output_dir, prefix, max_total_videos=20000):
FILE: data_processing/visualize_samples.py
function init_smplx_model (line 11) | def init_smplx_model():
FILE: dataset/visualize_samples.py
function init_smplx_model (line 11) | def init_smplx_model():
FILE: lib/datasets/avatar_dataset.py
function load_pose (line 20) | def load_pose(path):
function load_npy (line 30) | def load_npy(file_path):
function load_smpl (line 33) | def load_smpl(path, smpl_type='smpl'):
class AvatarDataset (line 63) | class AvatarDataset(Dataset):
method __init__ (line 64) | def __init__(self,
method load_scenes (line 149) | def load_scenes(self):
method parse_scene (line 180) | def parse_scene(self, scene_id):
method __len__ (line 454) | def __len__(self):
method __getitem__ (line 457) | def __getitem__(self, scene_id):
method parse_scene_test (line 466) | def parse_scene_test(self, scene_id):
function gather_imgs (line 583) | def gather_imgs(img_ids, poses, image_paths_or_video, smpl_params, load_...
function calculate_angle (line 644) | def calculate_angle(vector1, vector2):
function axis_angle_to_rotation_matrix (line 650) | def axis_angle_to_rotation_matrix(axis_angle):
function find_front_camera_by_rotation (line 676) | def find_front_camera_by_rotation(poses, global_orient_human):
function read_frames (line 693) | def read_frames(video_path):
function prepare_camera (line 711) | def prepare_camera( resolution_x, resolution_y, num_views=24, stides=1):
function from_video_to_get_ref_smplx (line 769) | def from_video_to_get_ref_smplx(video_path):
function random_scale_and_crop (line 789) | def random_scale_and_crop(image: torch.Tensor, scale_range=(0.8, 1.2)) -...
FILE: lib/datasets/dataloader.py
class DataModuleFromConfig (line 19) | class DataModuleFromConfig(pl.LightningDataModule):
method __init__ (line 20) | def __init__(
method setup (line 42) | def setup(self, stage):
method train_dataloader (line 45) | def train_dataloader(self):
method val_dataloader (line 57) | def val_dataloader(self):
method test_dataloader (line 68) | def test_dataloader(self):
FILE: lib/humanlrm_wrapper_sa_v1.py
function get_1d_rotary_pos_embed (line 28) | def get_1d_rotary_pos_embed(
class FluxPosEmbed (line 81) | class FluxPosEmbed(torch.nn.Module):
method __init__ (line 83) | def __init__(self, theta: int, axes_dim: [int]):
method forward (line 88) | def forward(self, ids: torch.Tensor) -> torch.Tensor:
class SapiensGS_SA_v1 (line 105) | class SapiensGS_SA_v1(pl.LightningModule):
method __init__ (line 107) | def __init__(
method get_default_smplx_params (line 276) | def get_default_smplx_params(self):
method forward_decoder (line 302) | def forward_decoder(self, decoder, code, target_rgbs, cameras,
method on_fit_start (line 311) | def on_fit_start(self):
method forward (line 319) | def forward(self, data):
method forward_image_to_uv (line 376) | def forward_image_to_uv(self, inputs_img, is_training=True):
method compute_loss (line 410) | def compute_loss(self, render_out):
method compute_metrics (line 481) | def compute_metrics(self, render_out):
method new_on_before_optimizer_step (line 513) | def new_on_before_optimizer_step(self):
method validation_step (line 519) | def validation_step(self, batch, batch_idx):
method forward_nvPose (line 537) | def forward_nvPose(self, batch, smplx_given):
method on_validation_epoch_end (line 563) | def on_validation_epoch_end(self): #
method on_test_start (line 608) | def on_test_start(self):
method on_test_epoch_end (line 612) | def on_test_epoch_end(self):
method configure_optimizers (line 643) | def configure_optimizers(self):
method training_step (line 665) | def training_step(self, batch, batch_idx):
method test_step (line 701) | def test_step(self, batch, batch_idx):
method on_test_start (line 732) | def on_test_start(self):
method on_test_epoch_end (line 736) | def on_test_epoch_end(self):
function weighted_mse_loss (line 769) | def weighted_mse_loss(render_images, target_images, weights):
FILE: lib/mmutils/initialize.py
function constant_init (line 3) | def constant_init(module: nn.Module, val: float, bias: float = 0) -> None:
function xavier_init (line 10) | def xavier_init(module: nn.Module,
FILE: lib/models/decoders/uvmaps_decoder_gender.py
function ensure_dtype (line 25) | def ensure_dtype(input_tensor, target_dtype=torch.float32):
class UVNDecoder_gender (line 35) | class UVNDecoder_gender(nn.Module):
method __init__ (line 44) | def __init__(self,
method init_weights (line 213) | def init_weights(self):
method extract_pcd (line 224) | def extract_pcd(self, code, smpl_params, init=False, zeros_hands_off=F...
method deform_pcd (line 253) | def deform_pcd(self, code, smpl_params, init=False, zeros_hands_off=Fa...
method _sample_feature (line 280) | def _sample_feature(self,results,):
method _decode_feature (line 303) | def _decode_feature(self, point_code, init=False):
method _decode (line 334) | def _decode(self, point_code, init=False):
method gaussian_render (line 371) | def gaussian_render(self, pcd, sigmas, rgbs, normals, rot, num_scenes,...
method visualize (line 424) | def visualize(self, code, scene_name, viz_dir, code_range=[-1, 1]):
method forward (line 434) | def forward(self, code, smpl_params, cameras, num_imgs,
method forward_render (line 527) | def forward_render(self, code, cameras, num_imgs,
method forward_testing_time (line 613) | def forward_testing_time(self, code, smpl_params, cameras, num_imgs,
FILE: lib/models/decoders/vit_head.py
class VitHead (line 6) | class VitHead(nn.Module):
method __init__ (line 7) | def __init__(self,
method _make_conv_layers (line 52) | def _make_conv_layers(self, in_channels: int,
method _make_deconv_layers (line 66) | def _make_deconv_layers(self, in_channels: int,
method forward (line 91) | def forward(self, inputs):
FILE: lib/models/deformers/fast_snarf/cuda/filter/filter.cpp
function filter (line 13) | void filter(const torch::Tensor &x, const torch::Tensor &mask, torch::Te...
function PYBIND11_MODULE (line 18) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: lib/models/deformers/fast_snarf/cuda/fuse_kernel/fuse_cuda.cpp
function fuse_broyden (line 25) | void fuse_broyden(torch::Tensor &x,
function PYBIND11_MODULE (line 47) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: lib/models/deformers/fast_snarf/cuda/precompute/precompute.cpp
function precompute (line 12) | void precompute(const torch::Tensor &voxel_w, const torch::Tensor &tfs, ...
function PYBIND11_MODULE (line 18) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: lib/models/deformers/fast_snarf/lib/model/deformer_smpl.py
class ForwardDeformer (line 14) | class ForwardDeformer(torch.nn.Module):
method __init__ (line 24) | def __init__(self, **kwargs):
method forward (line 36) | def forward(self, xd, cond, mask, tfs, eval_mode=False):
method precompute (line 54) | def precompute(self, tfs):
method search (line 63) | def search(self, xd, cond, mask, tfs, eval_mode=False):
method broyden_cuda (line 85) | def broyden_cuda(self,
method forward_skinning (line 116) | def forward_skinning(self, xc, cond, tfs, mask=None):
method switch_to_explicit (line 135) | def switch_to_explicit(self,resolution=32,smpl_verts=None, smpl_faces=...
method update_lbs_voxel (line 189) | def update_lbs_voxel(self):
method query_sdf_smpl (line 199) | def query_sdf_smpl(self, x, smpl_verts, smpl_faces, smpl_weights):
method skinning_normal (line 221) | def skinning_normal(self, xc, normal, tfs, cond=None, mask=None, inver...
function skinning_mask (line 242) | def skinning_mask(x, w, tfs, inverse=False):
function skinning (line 271) | def skinning(x, w, tfs, inverse=False):
function fast_inverse (line 300) | def fast_inverse(T):
function bmv (line 317) | def bmv(m, v):
function create_voxel_grid (line 320) | def create_voxel_grid(d, h, w, device='cuda'):
function query_weights_smpl (line 329) | def query_weights_smpl(x, smpl_verts, smpl_weights):
FILE: lib/models/deformers/fast_snarf/lib/model/deformer_smplx.py
class ForwardDeformer (line 14) | class ForwardDeformer(torch.nn.Module):
method __init__ (line 24) | def __init__(self, **kwargs):
method forward_skinning (line 35) | def forward_skinning(self, xc, shape_offset, pose_offset, cond, tfs, t...
method switch_to_explicit (line 78) | def switch_to_explicit(self,resolution=32,smpl_verts=None, smpl_faces=...
method update_lbs_voxel (line 132) | def update_lbs_voxel(self):
method query_sdf_smpl (line 142) | def query_sdf_smpl(self, x, smpl_verts, smpl_faces, smpl_weights):
method skinning_normal (line 164) | def skinning_normal(self, xc, normal, tfs, cond=None, mask=None, inver...
function skinning_mask (line 185) | def skinning_mask(x, w, tfs, inverse=False):
function skinning (line 214) | def skinning(x, w, tfs, inverse=False):
function fast_inverse (line 243) | def fast_inverse(T):
function bmv (line 260) | def bmv(m, v):
function create_voxel_grid (line 264) | def create_voxel_grid(d, h, w, device='cuda'):
function query_weights_smpl (line 273) | def query_weights_smpl(x, smpl_verts, smpl_weights):
FILE: lib/models/deformers/smplx/body_models.py
class SMPL (line 39) | class SMPL(nn.Module):
method __init__ (line 44) | def __init__(
method num_betas (line 260) | def num_betas(self):
method num_expression_coeffs (line 264) | def num_expression_coeffs(self):
method create_mean_pose (line 267) | def create_mean_pose(self, data_struct) -> Tensor:
method name (line 270) | def name(self) -> str:
method reset_params (line 274) | def reset_params(self, **params_dict) -> None:
method get_num_verts (line 281) | def get_num_verts(self) -> int:
method get_num_faces (line 284) | def get_num_faces(self) -> int:
method extra_repr (line 287) | def extra_repr(self) -> str:
method forward_shape (line 295) | def forward_shape(
method forward (line 303) | def forward(
class SMPLH (line 405) | class SMPLH(SMPL):
method __init__ (line 412) | def __init__(
method create_mean_pose (line 570) | def create_mean_pose(self, data_struct, flat_hand_mean=False):
method name (line 582) | def name(self) -> str:
method extra_repr (line 585) | def extra_repr(self):
method forward (line 593) | def forward(
class SMPLX (line 661) | class SMPLX(SMPLH):
method __init__ (line 677) | def __init__(
method name (line 859) | def name(self) -> str:
method num_expression_coeffs (line 863) | def num_expression_coeffs(self):
method create_mean_pose (line 866) | def create_mean_pose(self, data_struct, flat_hand_mean=False):
method extra_repr (line 884) | def extra_repr(self):
method forward (line 892) | def forward(
FILE: lib/models/deformers/smplx/lbs.py
function find_dynamic_lmk_idx_and_bcoords (line 30) | def find_dynamic_lmk_idx_and_bcoords(
function vertices2landmarks (line 108) | def vertices2landmarks(
function lbs (line 152) | def lbs(
function vertices2joints (line 251) | def vertices2joints(J_regressor: Tensor, vertices: Tensor) -> Tensor:
function blend_shapes (line 271) | def blend_shapes(betas: Tensor, shape_disps: Tensor) -> Tensor:
function batch_rodrigues (line 295) | def batch_rodrigues(
function transform_mat (line 332) | def transform_mat(R: Tensor, t: Tensor) -> Tensor:
function batch_rigid_transform (line 345) | def batch_rigid_transform(
FILE: lib/models/deformers/smplx/utils.py
class ModelOutput (line 28) | class ModelOutput:
method __getitem__ (line 36) | def __getitem__(self, key):
method get (line 39) | def get(self, key, default=None):
method __iter__ (line 42) | def __iter__(self):
method keys (line 45) | def keys(self):
method values (line 49) | def values(self):
method items (line 53) | def items(self):
class SMPLOutput (line 59) | class SMPLOutput(ModelOutput):
class SMPLHOutput (line 69) | class SMPLHOutput(SMPLOutput):
class SMPLXOutput (line 76) | class SMPLXOutput(SMPLHOutput):
function find_joint_kin_chain (line 81) | def find_joint_kin_chain(joint_id, kinematic_tree):
function to_tensor (line 90) | def to_tensor(
class Struct (line 99) | class Struct(object):
method __init__ (line 100) | def __init__(self, **kwargs):
function to_np (line 105) | def to_np(array, dtype=np.float32):
function rot_mat_to_euler (line 111) | def rot_mat_to_euler(rot_mats):
FILE: lib/models/deformers/smplx/vertex_joint_selector.py
class VertexJointSelector (line 29) | class VertexJointSelector(nn.Module):
method __init__ (line 31) | def __init__(self, vertex_ids=None,
method forward (line 73) | def forward(self, vertices, joints):
FILE: lib/models/deformers/smplx_deformer_gender.py
class SMPLXDeformer_gender (line 12) | class SMPLXDeformer_gender(torch.nn.Module):
method __init__ (line 14) | def __init__(self, gender, is_sub2=False) -> None:
method initialize (line 68) | def initialize(self):
method forword_body_model (line 125) | def forword_body_model(self, smpl_params, point_pool=4):
method prepare_deformer (line 150) | def prepare_deformer(self, smpl_params=None, num_scenes=1, device=None):
method __call__ (line 215) | def __call__(self, pts_in, rot_in, mask=None, cano=True, offset_gs=Non...
FILE: lib/models/renderers/gau_renderer.py
function batch_rodrigues (line 9) | def batch_rodrigues(rot_vecs, epsilon = 1e-8):
function build_scaling_rotation (line 42) | def build_scaling_rotation(s, r, tfs):
function strip_symmetric (line 54) | def strip_symmetric(sym):
function strip_lowerdiag (line 57) | def strip_lowerdiag(L):
function build_covariance_from_scaling_rotation (line 68) | def build_covariance_from_scaling_rotation(scaling, scaling_modifier, ro...
function build_rotation (line 74) | def build_rotation(r):
function get_covariance (line 97) | def get_covariance(scaling, rotation, scaling_modifier = 1):
class GRenderer (line 105) | class GRenderer(nn.Module):
method __init__ (line 106) | def __init__(self, image_size=256, anti_alias=False, f=5000, near=0.01...
method prepare (line 126) | def prepare(self, cameras):
method render_gaussian (line 165) | def render_gaussian(self, means3D, colors_precomp, rotations, opacitie...
function get_view_matrix (line 195) | def get_view_matrix(R, t):
function get_proj_yy (line 200) | def get_proj_yy(f, image_size, far, near):
function get_proj_matrix (line 206) | def get_proj_matrix(fovY,fovX, z_near, z_far, z_sign):
function get_fov (line 226) | def get_fov(focal, princpt, img_shape):
FILE: lib/models/sapiens/sapiens_wrapper_torchscipt.py
function pretrain_forward (line 23) | def pretrain_forward(sp_lite, inputs: Tensor, layer_num: int, return_hid...
class SapiensWrapper_ts (line 63) | class SapiensWrapper_ts(nn.Module):
method __init__ (line 68) | def __init__(self,
method forward (line 94) | def forward(self, image, use_my_proces=False, requires_grad=False, out...
method _freeze (line 139) | def _freeze(self):
FILE: lib/models/transformer_sa/mae_decoder_v3_skip.py
function get_2d_sincos_pos_embed (line 9) | def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
function get_2d_sincos_pos_embed_from_grid (line 37) | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
function get_1d_sincos_pos_embed_from_grid (line 49) | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
class neck_SA_v3_skip (line 71) | class neck_SA_v3_skip(nn.Module):
method __init__ (line 72) | def __init__(self, patch_size=4, in_chans=32, num_patches=196, embed_d...
method initialize_weights (line 106) | def initialize_weights(self):
method _init_weights (line 116) | def _init_weights(self, m):
method forward_decoder (line 126) | def forward_decoder(self, in_features, ids_restore):
method forward (line 160) | def forward(self, encoded_latent, ids_restore):
FILE: lib/ops/activation.py
class _trunc_exp (line 8) | class _trunc_exp(Function):
method forward (line 11) | def forward(ctx, x):
method backward (line 18) | def backward(ctx, g):
class TruncExp (line 26) | class TruncExp(nn.Module):
method forward (line 29) | def forward(x):
FILE: lib/utils/infer_util.py
function reset_first_frame_rotation (line 26) | def reset_first_frame_rotation(root_orient, trans):
function rotation_matrix_to_rodrigues (line 73) | def rotation_matrix_to_rodrigues(rotation_matrices):
function get_hand_pose_mean (line 82) | def get_hand_pose_mean():
function load_smplify_json (line 105) | def load_smplify_json(smplx_smplify_path):
function load_image (line 144) | def load_image(input_path, output_folder, image_frame_ratio=None):
function prepare_camera (line 191) | def prepare_camera( resolution_x = 640, resolution_y = 640, focal_length...
function construct_camera (line 239) | def construct_camera(K, cam_list, device='cuda'):
function get_name_str (line 266) | def get_name_str(name):
function load_smplx_from_npy (line 272) | def load_smplx_from_npy(smplx_path, device='cuda'):
function add_root_rotate_to_smplx (line 303) | def add_root_rotate_to_smplx(smpl_tmp, frames_num=180, device='cuda'):
function load_smplx_from_json (line 328) | def load_smplx_from_json(smplx_path, device='cuda'):
function get_image_dimensions (line 354) | def get_image_dimensions(input_path):
function construct_camera_from_motionx (line 358) | def construct_camera_from_motionx(smplx_path, device='cuda'):
function remove_background (line 385) | def remove_background(image: PIL.Image.Image,
function resize_foreground (line 399) | def resize_foreground(
function images_to_video (line 440) | def images_to_video(
function save_video (line 461) | def save_video(
FILE: lib/utils/mesh.py
function dot (line 7) | def dot(x, y):
function length (line 11) | def length(x, eps=1e-20):
function safe_normalize (line 15) | def safe_normalize(x, eps=1e-20):
class Mesh (line 18) | class Mesh:
method __init__ (line 19) | def __init__(
method load (line 47) | def load(cls, path=None, resize=True, renormal=True, retex=False, fron...
method load_obj (line 100) | def load_obj(cls, path, albedo_path=None, device=None):
method load_trimesh (line 238) | def load_trimesh(cls, path, device=None):
method aabb (line 325) | def aabb(self):
method auto_size (line 330) | def auto_size(self):
method auto_normal (line 336) | def auto_normal(self):
method auto_uv (line 359) | def auto_uv(self, cache_path=None, vmap=True):
method align_v_to_vt (line 392) | def align_v_to_vt(self, vmapping=None):
method to (line 407) | def to(self, device):
method write (line 415) | def write(self, path):
method write_ply (line 426) | def write_ply(self, path):
method write_glb (line 435) | def write_glb(self, path):
method write_obj (line 568) | def write_obj(self, path):
FILE: lib/utils/mesh_utils.py
function gaussian_3d_coeff (line 6) | def gaussian_3d_coeff(xyzs, covs):
function poisson_mesh_reconstruction (line 27) | def poisson_mesh_reconstruction(points, normals=None):
function decimate_mesh (line 67) | def decimate_mesh(
function clean_mesh (line 111) | def clean_mesh(
FILE: lib/utils/train_util.py
function main_print (line 7) | def main_print(*args):
function count_params (line 12) | def count_params(model, verbose=False):
function instantiate_from_config (line 19) | def instantiate_from_config(config):
function get_obj_from_str (line 29) | def get_obj_from_str(string, reload=False):
FILE: run_demo.py
function parse_args (line 20) | def parse_args():
function process_data_on_gpu (line 44) | def process_data_on_gpu(args, model, gpu_id, img_paths_list, smplx_ref_p...
function setup_directories (line 237) | def setup_directories(base_path, config_name):
function main (line 248) | def main():
FILE: train.py
function get_parser (line 34) | def get_parser(**parser_kwargs):
class SetupCallback (line 134) | class SetupCallback(Callback):
method __init__ (line 135) | def __init__(self, resume, logdir, ckptdir, cfgdir, config):
method on_fit_start (line 143) | def on_fit_start(self, trainer, pl_module):
class CodeSnapshot (line 156) | class CodeSnapshot(Callback):
method __init__ (line 160) | def __init__(self, savedir, exclude_patterns=None):
method get_file_list (line 167) | def get_file_list(self):
method _match_pattern (line 194) | def _match_pattern(self, file_path, pattern):
method save_code_snapshot (line 210) | def save_code_snapshot(self):
method on_fit_start (line 218) | def on_fit_start(self, trainer, pl_module):
function load_folder_ckpt (line 428) | def load_folder_ckpt(checkpoint_dir):
Condensed preview — 58 files, each showing path, character count, and a content snippet.
[
{
"path": ".gitignore",
"chars": 325,
"preview": "# Python cache files\r\n__pycache__/\r\n*.py[cod]\r\n*$py.class\r\n*.so\r\n\r\n# Project specific\r\nwork_dirs/\r\ntest/\r\nlib/models/def"
},
{
"path": "README.md",
"chars": 11054,
"preview": "# **IDOL: Instant Photorealistic 3D Human Creation from a Single Image** \r\n\r\n[ -> None:\n if hasattr(module,"
},
{
"path": "lib/models/__init__.py",
"chars": 25,
"preview": "\nfrom .decoders import *\n"
},
{
"path": "lib/models/decoders/__init__.py",
"chars": 88,
"preview": "\nfrom .uvmaps_decoder_gender import UVNDecoder_gender\n\n__all__ = [ 'UVNDecoder_gender']\n"
},
{
"path": "lib/models/decoders/uvmaps_decoder_gender.py",
"chars": 31646,
"preview": "import os\nimport matplotlib.pyplot as plt\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch"
},
{
"path": "lib/models/decoders/vit_head.py",
"chars": 4542,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom typing import Sequence, Tuple, Optional\n\nclass V"
},
{
"path": "lib/models/deformers/__init__.py",
"chars": 92,
"preview": "\nfrom .smplx_deformer_gender import SMPLXDeformer_gender\n\n__all__ = ['SMPLXDeformer_gender']"
},
{
"path": "lib/models/deformers/fast_snarf/cuda/filter/filter.cpp",
"chars": 467,
"preview": "#include <torch/extension.h>\n#include <ATen/ATen.h>\n#include <vector>\n#include <c10/cuda/CUDAGuard.h>\n\n\n\n\n\nvoid launch_f"
},
{
"path": "lib/models/deformers/fast_snarf/cuda/filter/filter_kernel.cu",
"chars": 2906,
"preview": "#include <vector>\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/detail/TensorInfo.cuh>\n#include <ATen/cuda/deta"
},
{
"path": "lib/models/deformers/fast_snarf/cuda/fuse_kernel/fuse_cuda.cpp",
"chars": 1715,
"preview": "#include \"ATen/Functions.h\"\n#include \"ATen/core/TensorBody.h\"\n#include <torch/extension.h>\n#include <ATen/ATen.h>\n#inclu"
},
{
"path": "lib/models/deformers/fast_snarf/cuda/fuse_kernel/fuse_cuda_kernel.cu",
"chars": 20733,
"preview": "#include \"ATen/Functions.h\"\n#include \"ATen/core/TensorAccessor.h\"\n#include \"c10/cuda/CUDAException.h\"\n#include \"c10/cuda"
},
{
"path": "lib/models/deformers/fast_snarf/cuda/precompute/precompute.cpp",
"chars": 695,
"preview": "#include <torch/extension.h>\n#include <ATen/ATen.h>\n#include <vector>\n#include <c10/cuda/CUDAGuard.h>\n\n\n\n\n\nvoid launch_p"
},
{
"path": "lib/models/deformers/fast_snarf/cuda/precompute/precompute_kernel.cu",
"chars": 4204,
"preview": "#include \"ATen/Functions.h\"\n#include \"ATen/core/TensorAccessor.h\"\n#include \"c10/cuda/CUDAException.h\"\n#include \"c10/cuda"
},
{
"path": "lib/models/deformers/fast_snarf/lib/model/deformer_smpl.py",
"chars": 11948,
"preview": "import torch\nfrom torch import einsum\nimport torch.nn.functional as F\nimport os\n\nfrom torch.utils.cpp_extension import l"
},
{
"path": "lib/models/deformers/fast_snarf/lib/model/deformer_smplx.py",
"chars": 9971,
"preview": "import torch\nfrom torch import einsum\nimport torch.nn.functional as F\nimport os\n\nfrom torch.utils.cpp_extension import l"
},
{
"path": "lib/models/deformers/smplx/__init__.py",
"chars": 728,
"preview": "# -*- coding: utf-8 -*-\n\n# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is\n# holder of all propri"
},
{
"path": "lib/models/deformers/smplx/body_models.py",
"chars": 45117,
"preview": "# -*- coding: utf-8 -*-\n\n# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is\n# holder of all propr"
},
{
"path": "lib/models/deformers/smplx/joint_names.py",
"chars": 4896,
"preview": "# -*- coding: utf-8 -*-\n\n# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is\n# holder of all propri"
},
{
"path": "lib/models/deformers/smplx/lbs.py",
"chars": 13913,
"preview": "# -*- coding: utf-8 -*-\n\n# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is\n# holder of all propri"
},
{
"path": "lib/models/deformers/smplx/utils.py",
"chars": 3326,
"preview": "# -*- coding: utf-8 -*-\n\n# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is\n# holder of all propri"
},
{
"path": "lib/models/deformers/smplx/vertex_ids.py",
"chars": 2204,
"preview": "# -*- coding: utf-8 -*-\n\n# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is\n# holder of all propri"
},
{
"path": "lib/models/deformers/smplx/vertex_joint_selector.py",
"chars": 2702,
"preview": "# -*- coding: utf-8 -*-\n\n# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is\n# holder of all propri"
},
{
"path": "lib/models/deformers/smplx_deformer_gender.py",
"chars": 14596,
"preview": "# Modified from Deformer of AG3D\n\nfrom .fast_snarf.lib.model.deformer_smplx import ForwardDeformer, skinning\nfrom .smplx"
},
{
"path": "lib/models/renderers/__init__.py",
"chars": 93,
"preview": "\nfrom .gau_renderer import GRenderer, get_covariance, batch_rodrigues\n__all__ = ['GRenderer']"
},
{
"path": "lib/models/renderers/gau_renderer.py",
"chars": 8192,
"preview": "from diff_gaussian_rasterization import (\n GaussianRasterizationSettings,\n GaussianRasterizer,\n)\nimport torch\nimpo"
},
{
"path": "lib/models/sapiens/__init__.py",
"chars": 57,
"preview": "from .sapiens_wrapper_torchscipt import SapiensWrapper_ts"
},
{
"path": "lib/models/sapiens/sapiens_wrapper_torchscipt.py",
"chars": 5789,
"preview": "# Copyright (c) 2023, Zexin He\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use thi"
},
{
"path": "lib/models/transformer_sa/__init__.py",
"chars": 48,
"preview": "from .mae_decoder_v3_skip import neck_SA_v3_skip"
},
{
"path": "lib/models/transformer_sa/mae_decoder_v3_skip.py",
"chars": 6383,
"preview": "import torch\nimport torch.nn as nn\nimport numpy as np\nfrom timm.models.vision_transformer import PatchEmbed, Block, chec"
},
{
"path": "lib/ops/__init__.py",
"chars": 33,
"preview": "from .activation import TruncExp\n"
},
{
"path": "lib/ops/activation.py",
"chars": 637,
"preview": "import math\nimport torch\nimport torch.nn as nn\nfrom torch.autograd import Function\nfrom torch.cuda.amp import custom_bwd"
},
{
"path": "lib/utils/infer_util.py",
"chars": 18522,
"preview": "import os\nimport imageio\nimport rembg\nimport torch\nimport numpy as np\nimport PIL.Image\nfrom PIL import Image\nfrom typing"
},
{
"path": "lib/utils/mesh.py",
"chars": 23321,
"preview": "import os\nimport cv2\nimport torch\nimport trimesh\nimport numpy as np\n\ndef dot(x, y):\n return torch.sum(x * y, -1, keep"
},
{
"path": "lib/utils/mesh_utils.py",
"chars": 5142,
"preview": "import numpy as np\nimport pymeshlab as pml\nimport torch\n\n\ndef gaussian_3d_coeff(xyzs, covs):\n # xyzs: [N, 3]\n # co"
},
{
"path": "lib/utils/train_util.py",
"chars": 1009,
"preview": "import importlib\n\nimport os\n\nfrom pytorch_lightning.utilities import rank_zero_only\n@rank_zero_only\ndef main_print(*args"
},
{
"path": "run_demo.py",
"chars": 12797,
"preview": "import os\nos.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0,1,2,3,4,5,6,7\"\nimport argpa"
},
{
"path": "scripts/download_files.sh",
"chars": 713,
"preview": "#!/bin/bash\n\n# Create necessary directories\nmkdir -p work_dirs/\nmkdir -p work_dirs/ckpt\n\n# Download files from HuggingFa"
},
{
"path": "scripts/fetch_template.sh",
"chars": 2969,
"preview": "#!/bin/bash\nurle () { [[ \"${1}\" ]] || return 1; local LANG=C i x; for (( i = 0; i < ${#1}; i++ )); do x=\"${1:i:1}\"; [[ \""
},
{
"path": "scripts/pip_install.sh",
"chars": 1894,
"preview": "#!/bin/bash\n\n# Complete environment setup process\n\n# Step 0: Ensure you create a Conda environment \n# and Activate the e"
},
{
"path": "setup.py",
"chars": 709,
"preview": "from setuptools import setup\r\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\r\nimport os\r\ncuda_dir ="
},
{
"path": "train.py",
"chars": 16200,
"preview": "import os, sys\n# os.environ[\"WANDB_MODE\"] = \"dryrun\" # default setting to save locally\nfrom lib.utils.train_util import "
}
]
// ... and 1 more file (download for full content)
About this extraction
This page contains the full source code of the yiyuzhuang/IDOL GitHub repository, extracted and formatted as plain text. The extraction includes 58 files (396.2 KB), approximately 107.8k tokens, and a symbol index with 260 extracted functions, classes, methods, constants, and types.