Repository: yiyuzhuang/IDOL Branch: main Commit: 9fd9296c28e8 Files: 58 Total size: 396.2 KB Directory structure: gitextract_sj4g9tfa/ ├── .gitignore ├── README.md ├── configs/ │ ├── idol_debug.yaml │ ├── idol_v0.yaml │ └── test_dataset.yaml ├── data_processing/ │ ├── prepare_cache.py │ ├── process_datasets.sh │ └── visualize_samples.py ├── dataset/ │ ├── README.md │ ├── sample/ │ │ └── param/ │ │ └── Kenya_female_fit_streetwear_50~60 years old_1501.npy │ └── visualize_samples.py ├── env/ │ └── README.md ├── lib/ │ ├── __init__.py │ ├── datasets/ │ │ ├── __init__.py │ │ ├── avatar_dataset.py │ │ └── dataloader.py │ ├── humanlrm_wrapper_sa_v1.py │ ├── mmutils/ │ │ ├── __init__.py │ │ └── initialize.py │ ├── models/ │ │ ├── __init__.py │ │ ├── decoders/ │ │ │ ├── __init__.py │ │ │ ├── uvmaps_decoder_gender.py │ │ │ └── vit_head.py │ │ ├── deformers/ │ │ │ ├── __init__.py │ │ │ ├── fast_snarf/ │ │ │ │ ├── cuda/ │ │ │ │ │ ├── filter/ │ │ │ │ │ │ ├── filter.cpp │ │ │ │ │ │ └── filter_kernel.cu │ │ │ │ │ ├── fuse_kernel/ │ │ │ │ │ │ ├── fuse_cuda.cpp │ │ │ │ │ │ └── fuse_cuda_kernel.cu │ │ │ │ │ └── precompute/ │ │ │ │ │ ├── precompute.cpp │ │ │ │ │ └── precompute_kernel.cu │ │ │ │ └── lib/ │ │ │ │ └── model/ │ │ │ │ ├── deformer_smpl.py │ │ │ │ └── deformer_smplx.py │ │ │ ├── smplx/ │ │ │ │ ├── __init__.py │ │ │ │ ├── body_models.py │ │ │ │ ├── joint_names.py │ │ │ │ ├── lbs.py │ │ │ │ ├── utils.py │ │ │ │ ├── vertex_ids.py │ │ │ │ └── vertex_joint_selector.py │ │ │ └── smplx_deformer_gender.py │ │ ├── renderers/ │ │ │ ├── __init__.py │ │ │ └── gau_renderer.py │ │ ├── sapiens/ │ │ │ ├── __init__.py │ │ │ └── sapiens_wrapper_torchscipt.py │ │ └── transformer_sa/ │ │ ├── __init__.py │ │ └── mae_decoder_v3_skip.py │ ├── ops/ │ │ ├── __init__.py │ │ └── activation.py │ └── utils/ │ ├── infer_util.py │ ├── mesh.py │ ├── mesh_utils.py │ └── train_util.py ├── run_demo.py ├── scripts/ │ ├── download_files.sh │ ├── fetch_template.sh │ └── pip_install.sh ├── setup.py └── train.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ # Python cache files __pycache__/ *.py[cod] *$py.class *.so # Project specific work_dirs/ test/ lib/models/deformers/smplx/SMPLX/* # IDE .idea/ .vscode/ *.swp *.swo # Distribution / packaging dist/ build/ *.egg-info/ # Jupyter Notebook .ipynb_checkpoints # Git related *.orig *.rej *.patch ================================================ FILE: README.md ================================================ # **IDOL: Instant Photorealistic 3D Human Creation from a Single Image** [![Website](https://img.shields.io/badge/Project-Website-0073e6)](https://yiyuzhuang.github.io/IDOL/) [![Paper](https://img.shields.io/badge/arXiv-PDF-b31b1b)](https://arxiv.org/pdf/2412.14963) [![Live Demo](https://img.shields.io/badge/Live-Demo-34C759)](https://yiyuzhuang.github.io/IDOL/) [![License](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/licenses/MIT) ---

Teaser Image for IDOL

---

## **Abstract**

This work introduces **IDOL**, a feed-forward, single-image human reconstruction framework that is fast, high-fidelity, and generalizable. Leveraging a large-scale dataset of 100K multi-view subjects, our method demonstrates exceptional generalizability and robustness in handling diverse human shapes, cross-domain data, severe viewpoints, and occlusions. With a uniform structured representation, the reconstructed avatars are directly animatable and easily editable, providing a significant step forward for various applications in graphics, vision, and beyond.

In summary, this project introduces:

- **IDOL**: A scalable pipeline for instant photorealistic 3D human reconstruction using a simple yet efficient feed-forward model.
- **HuGe100K Dataset**: We develop a data generation pipeline and present **HuGe100K**, a large-scale multi-view human dataset featuring diverse attributes, high-fidelity, high-resolution appearances, and a well-aligned SMPL-X model.
- **Application Support**: Enabling 3D human reconstruction and downstream tasks such as editing and animation.

---

## 📰 **News**

- **2024-12-18**: Paper is now available on arXiv.
- **2025-01-02**: The demo dataset containing 100 samples is now available for access. The remaining dataset is currently undergoing further cleaning and review.
- **2025-03-01**: 🎉 Paper accepted by CVPR 2025.
- **2025-03-01**: 🎉 We have released the inference code! Check out the [Code Release](#code-release) section for details.
- **2025-04-01**: 🔥 Full HuGe100K dataset is now available! See the [Dataset Access](#dataset-demo-access) section.
- **2025-04-05**: 🔥 Training code is now available! Check out the [Training Code](#training-code) section for details.

## 🚧 **Project Status**

We are actively working on releasing the following resources:

| Resource                | Status          | Expected Release Date      |
|-------------------------|-----------------|----------------------------|
| **Dataset Demo**        | ✅ Available    | **Now Live! (2025.01.02)** |
| **Inference Code**      | ✅ Available    | **Now Live! (2025.03.01)** |
| **Full Dataset Access** | ✅ Available    | **Now Live! (2025.04.01)** |
| **Online Demo**         | 🚧 In Progress  | **Before April 2025**      |
| **Training Code**       | ✅ Available    | **Now Live! (2025.04.05)** |

Stay tuned as we update this section with new releases! 🚀

## 💻 **Code Release**

### Installation & Environment Setup

Please refer to [env/README.md](env/README.md) for detailed environment setup instructions.

### Quick Start

Run the demo with different modes:

```bash
# Reconstruct the input image
python run_demo.py --render_mode reconstruct

# Generate novel poses (animation)
python run_demo.py --render_mode novel_pose

# Generate 360-degree view
python run_demo.py --render_mode novel_pose_A
```

### Training

#### Data Preparation

1. **Dataset Structure**: First, prepare your dataset with the following structure:

   ```
   dataset_root/
   ├── deepfashion/
   │   ├── image1/
   │   │   ├── videos/
   │   │   │   ├── xxx.mp4
   │   │   │   └── xxx.jpg
   │   │   └── param/
   │   │       └── xxx.npy
   │   └── image2/
   │       ├── videos/
   │       └── param/
   └── flux_batch1_5000/
       ├── image1/
       │   ├── videos/
       │   └── param/
       └── image2/
           ├── videos/
           └── param/
   ```
2. **Process Dataset**: Run the data processing script to generate cache files:

   ```bash
   # Process the dataset and generate cache files
   # Please modify the dataset path and the sample number in the script
   bash data_processing/process_datasets.sh
   ```

   This will generate cache files in the `processed_data` directory:

   - `deepfashion_train_140.npy`
   - `deepfashion_val_10.npy`
   - `deepfashion_test_50.npy`
   - `flux_batch1_5000_train_140.npy`
   - `flux_batch1_5000_val_10.npy`
   - `flux_batch1_5000_test_50.npy`

   A short snippet for sanity-checking these cache files is provided in the appendix at the end of this README.

3. **Configure Cache Path**: Update the cache path in your config file (e.g., `configs/idol_v0.yaml`):

   ```yaml
   params:
     cache_path: [
       ./processed_data/deepfashion_train_140.npy,
       ./processed_data/flux_batch1_5000_train_140.npy
     ]
   ```

#### Training

1. **Single-Node Training**: For single-node multi-GPU training:

   ```bash
   python train.py \
       --base configs/idol_v0.yaml \
       --num_nodes 1 \
       --gpus 0,1,2,3,4,5,6,7
   ```

2. **Multi-Node Training**: For multi-node training, specify additional parameters:

   ```bash
   python train.py \
       --base configs/idol_v0.yaml \
       --num_nodes <NUM_NODES> \
       --node_rank <NODE_RANK> \
       --master_addr <MASTER_ADDR> \
       --master_port <MASTER_PORT> \
       --gpus 0,1,2,3,4,5,6,7
   ```

   Example for a 2-node setup:

   ```bash
   # On master node (node 0):
   python train.py --base configs/idol_v0.yaml --num_nodes 2 --node_rank 0 --master_addr 192.168.1.100 --master_port 29500 --gpus 0,1,2,3,4,5,6,7

   # On worker node (node 1):
   python train.py --base configs/idol_v0.yaml --num_nodes 2 --node_rank 1 --master_addr 192.168.1.100 --master_port 29500 --gpus 0,1,2,3,4,5,6,7
   ```

3. **Resume Training**: To resume training from a checkpoint:

   ```bash
   python train.py \
       --base configs/idol_v0.yaml \
       --resume PATH/TO/MODEL.ckpt \
       --num_nodes 1 \
       --gpus 0,1,2,3,4,5,6,7
   ```

4. **Test and Evaluate Metrics**:

   ```bash
   # --base:         main config file (model)
   # --test_sd:      path to the .ckpt model you want to test
   # --test_dataset: (optional) dataset config used specifically for testing
   python train.py \
       --base configs/idol_v0.yaml \
       --num_nodes 1 \
       --gpus 0,1,2,3,4,5,6,7 \
       --test_sd /path/to/model_checkpoint.ckpt \
       --test_dataset ./configs/test_dataset.yaml
   ```

## Notes

- Make sure all GPUs have enough memory for the selected batch size
- For multi-node training, ensure network connectivity between nodes
- Monitor training progress using the logging system
- Adjust learning rate and other hyperparameters in the config file as needed

## 🌐 **Key Links**

- 📄 [**Paper on arXiv**](https://arxiv.org/pdf/2412.14963)
- 🌐 [**Project Website**](https://yiyuzhuang.github.io/IDOL/)
- 🚀 [**Live Demo**](https://your-live-demo-link.com) (Coming Soon!)

---

## 📊 **Dataset Demo Access**

We introduce **HuGe100K**, a large-scale multi-view human dataset, supporting 3D human reconstruction and animation research.

### ▶ **Watch the Demo Video**

Dataset GIF

### 📋 **Dataset Documentation** For detailed information about the dataset format, structure, and usage guidelines, please refer to our [Dataset Documentation](dataset/README.md). ### 🚀 **Access the Dataset**

🔥 HuGe100K - The largest multi-view human dataset with 100,000+ subjects! 🔥

High-resolution • Multi-view • Diverse poses • SMPL-X aligned

Apply for Access

Complete the form to get access credentials and download links!

### ⚖️ **License and Attribution** This dataset includes images derived from the **DeepFashion** dataset, originally provided by MMLAB at The Chinese University of Hong Kong. The use of DeepFashion images in this dataset has been explicitly authorized by the original authors solely for the purpose of creating and distributing this dataset. **Users must not further reproduce, distribute, sell, or commercially exploit any images or derived data originating from DeepFashion.** For any subsequent or separate use of the DeepFashion data, users must directly obtain authorization from MMLAB and comply with the original [DeepFashion License](https://mmlab.ie.cuhk.edu.hk/projects/DeepFashion.html). --- ## 📝 **Citation** If you find our work helpful, please cite us using the following BibTeX: ```bibtex @article{zhuang2024idolinstant, title={IDOL: Instant Photorealistic 3D Human Creation from a Single Image}, author={Yiyu Zhuang and Jiaxi Lv and Hao Wen and Qing Shuai and Ailing Zeng and Hao Zhu and Shifeng Chen and Yujiu Yang and Xun Cao and Wei Liu}, journal={arXiv preprint arXiv:2412.14963}, year={2024}, url={https://arxiv.org/abs/2412.14963}, } ``` ## **License** This project is licensed under the **MIT License**. - **Permissions**: This license grants permission to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the software. - **Condition**: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. - **Disclaimer**: The software is provided "as is", without warranty of any kind. For more information, see the full license [here](https://opensource.org/licenses/MIT). ## **Support Our Work** ⭐ If you find our work useful for your research or applications: - Please ⭐ **star our repository** to help us reach more people - Consider **citing our paper** in your publications (see [Citation](#citation) section) - Share our project with others who might benefit from it Your support helps us continue developing open-source research projects like this one! ## 📚 **Acknowledgments** This project is majorly built upon several excellent open-source projects: - [E3Gen](https://github.com/olivia23333/E3Gen): Efficient, Expressive and Editable Avatars Generation - [SAPIENS](https://github.com/facebookresearch/sapiens): High-resolution visual models for human-centric tasks - [GeoLRM](https://github.com/alibaba-yuanjing-aigclab/GeoLRM): Large Reconstruction Model for High-Quality 3D Generation - [3D Gaussian Splatting](https://github.com/graphdeco-inria/gaussian-splatting): Real-Time 3DGS Rendering We thank all the authors for their contributions to the open-source community. 
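## 🧪 **Appendix: Sanity-Checking the Processed Cache Files**

The cache files produced in the Data Preparation step are NumPy object arrays written by `data_processing/prepare_cache.py`; each entry is a per-scene dictionary with `video_path`, `image_paths`, `param_path`, and `image_ref` fields. The sketch below is an optional, minimal check (the file name is just one of the default outputs listed above — substitute any cache file you generated) that loads a cache and verifies the referenced files still exist on disk.

```python
import os
import numpy as np

# Example path: one of the default outputs of data_processing/process_datasets.sh.
cache_path = "./processed_data/deepfashion_train_140.npy"

# Each entry is a scene dictionary written by data_processing/prepare_cache.py.
scenes = np.load(cache_path, allow_pickle=True)
print(f"Loaded {len(scenes)} scenes from {cache_path}")
print("Scene keys:", sorted(scenes[0].keys()))  # video_path, image_paths, param_path, image_ref

# Verify that the referenced video and parameter files are reachable from this machine.
missing = [s["video_path"] for s in scenes if not os.path.exists(s["video_path"])]
missing += [s["param_path"] for s in scenes if not os.path.exists(s["param_path"])]
print(f"Missing referenced files: {len(missing)}")
```

If any files are reported missing, fix the dataset paths and re-run `data_processing/process_datasets.sh` before training.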
================================================ FILE: configs/idol_debug.yaml ================================================ debug: True # code_size: [32, 256, 256] code_size: [32, 1024, 1024] model: # base_learning_rate: 2.0e-04 # yy Need to check target: lib.SapiensGS_SA_v1 params: # optimizer add # use_bf16: true max_steps: 100_000 warmup_steps: 10_000 #12_000 use_checkpoint: true lambda_depth_tv: 0.05 lambda_lpips: 0 #2.0 lambda_mse: 20 #1.0 lambda_offset: 1 #offset_weight: 50 mse 20, lpips 0.1 neck_learning_rate: 5e-4 decoder_learning_rate: 5e-4 output_hidden_states: true # if True, will output the hidden states from sapiens shallow layer, for the neck decoder loss_coef: 0.5 init_iter: 500 scale_weight: 0.01 smplx_path: 'work_dirs/demo_data/Ways_to_Catch_360_clip1.json' code_reshape: [32, 96, 96] patch_size: 1 code_activation: type: tanh mean: 0.0 std: 0.5 clip_range: 2 grid_size: 64 encoder: target: lib.models.sapiens.SapiensWrapper_ts params: model_path: work_dirs/ckpt/sapiens_1b_epoch_173_torchscript.pt2 # model_path: /apdcephfs_cq8/share_1367250/harriswen/projects/sapiens_convert/checkpoints//sapiens_1b_epoch_173_torchscript.pt2 layer_num: 40 img_size: [1024, 736] freeze: True neck: target: lib.models.transformer_sa.neck_SA_v3_skip # TODO!! add a self attention version params: patch_size: 4 #4, in_chans: 32 #32, # the uv code dims num_patches: 9216 #4096 #num_patches #,#4096, # 16*16 embed_dim: 1536 # sapiens' latent dims # 1920 # 1920 for sapiens encoder2 #1024 # the feature extrators outputs decoder_embed_dim: 128 # 1024 decoder_depth: 2 # 8 decoder_num_heads: 4 #16, total_num_hidden_states: 12 mlp_ratio: 4. decoder: target: lib.models.decoders.UVNDecoder_gender params: interp_mode: bilinear base_layers: [16, 64] density_layers: [64, 1] color_layers: [16, 128, 9] offset_layers: [64, 3] use_dir_enc: false dir_layers: [16, 64] activation: silu bg_color: 1 sigma_activation: sigmoid sigmoid_saturation: 0.001 gender: neutral is_sub2: true ## update, make it into 10w gs points multires: 0 image_size: [640, 896] superres: false focal: 1120 up_cnn_in_channels: 128 # be the same as decoder_embed_dim reshape_type: VitHead vithead_param: in_channels: 128 # be the same as decoder_embed_dim out_channels: 32 deconv_out_channels: [128, 64] deconv_kernel_sizes: [4, 4] conv_out_channels: [128, 128] conv_kernel_sizes: [3, 3] fix_sigma: true dataset: target: lib.datasets.dataloader.DataModuleFromConfig params: batch_size: 1 #16 # 6 for lpips num_workers: 1 #2 # working when in debug mode debug_cache_path:./processed_data/flux_batch1_5000_test_50_local.npy train: target: lib.datasets.AvatarDataset params: data_prefix: None cache_path: [ ./processed_data/deepfashion_train_140_local.npy, ./processed_data/flux_batch1_5000_train_140_local.npy ] specific_observation_num: 5 better_range: true first_is_front: true if_include_video_ref_img: true prob_include_video_ref_img: 0.5 img_res: [640, 896] validation: target: lib.datasets.AvatarDataset params: data_prefix: None load_imgs: true specific_observation_num: 3 better_range: true first_is_front: true img_res: [640, 896] cache_path: [ ./processed_data/flux_batch1_5000_test_50_local.npy, #./processed_data/flux_batch1_5000_val_10_local.npy ] lightning: modelcheckpoint: params: every_n_train_steps: 4000 #2000 save_top_k: -1 save_last: true monitor: 'train/loss_mse' # ADD this logging in the wrapper_sa mode: "min" filename: 'sample-synData-epoch{epoch:02d}-val_loss{val/loss:.2f}' callbacks: {} trainer: num_sanity_val_steps: 1 accumulate_grad_batches: 1 
gradient_clip_val: 10.0 max_steps: 80000 check_val_every_n_epoch: 1 ## check validation set every 1 training batches in the current epoch benchmark: true val_check_interval: 1.0 ================================================ FILE: configs/idol_v0.yaml ================================================ debug: True # code_size: [32, 256, 256] code_size: [32, 1024, 1024] model: # base_learning_rate: 2.0e-04 # yy Need to check target: lib.SapiensGS_SA_v1 params: # optimizer add # use_bf16: true max_steps: 100_000 warmup_steps: 10_000 #12_000 use_checkpoint: true lambda_depth_tv: 0.05 lambda_lpips: 10 #2.0 lambda_mse: 20 #1.0 lambda_offset: 1 #offset_weight: 50 mse 20, lpips 0.1 neck_learning_rate: 5e-4 decoder_learning_rate: 5e-4 output_hidden_states: true # if True, will output the hidden states from sapiens shallow layer, for the neck decoder loss_coef: 0.5 init_iter: 500 scale_weight: 0.01 smplx_path: 'work_dirs/demo_data/Ways_to_Catch_360_clip1.json' code_reshape: [32, 96, 96] patch_size: 1 code_activation: type: tanh mean: 0.0 std: 0.5 clip_range: 2 grid_size: 64 encoder: target: lib.models.sapiens.SapiensWrapper_ts params: model_path: work_dirs/ckpt/sapiens_1b_epoch_173_torchscript.pt2 # model_path: /apdcephfs_cq8/share_1367250/harriswen/projects/sapiens_convert/checkpoints//sapiens_1b_epoch_173_torchscript.pt2 layer_num: 40 img_size: [1024, 736] freeze: True neck: target: lib.models.transformer_sa.neck_SA_v3_skip # TODO!! add a self attention version params: patch_size: 4 #4, in_chans: 32 #32, # the uv code dims num_patches: 9216 #4096 #num_patches #,#4096, # 16*16 embed_dim: 1536 # 1920 # 1920 for sapiens encoder2 #1024 # the feature extrators outputs decoder_embed_dim: 1536 # 1024 decoder_depth: 16 # 8 decoder_num_heads: 16 #16, total_num_hidden_states: 40 mlp_ratio: 4. 
decoder: target: lib.models.decoders.UVNDecoder_gender params: interp_mode: bilinear base_layers: [16, 64] density_layers: [64, 1] color_layers: [16, 128, 9] offset_layers: [64, 3] use_dir_enc: false dir_layers: [16, 64] activation: silu bg_color: 1 sigma_activation: sigmoid sigmoid_saturation: 0.001 gender: neutral is_sub2: true ## update, make it into 10w gs points multires: 0 image_size: [640, 896] superres: false focal: 1120 up_cnn_in_channels: 1536 # be the same as decoder_embed_dim reshape_type: VitHead vithead_param: in_channels: 1536 # be the same as decoder_embed_dim out_channels: 32 deconv_out_channels: [512, 512, 512, 256] deconv_kernel_sizes: [4, 4, 4, 4] conv_out_channels: [128, 128] conv_kernel_sizes: [3, 3] fix_sigma: true dataset: target: lib.datasets.dataloader.DataModuleFromConfig params: batch_size: 1 #16 # 6 for lpips num_workers: 2 #2 # working when in debug mode debug_cache_path: ./processed_data/flux_batch1_5000_test_50_local.npy train: target: lib.datasets.AvatarDataset params: data_prefix: None cache_path: [ ./processed_data/deepfashion_train_140_local.npy, ./processed_data/flux_batch1_5000_train_140_local.npy ] specific_observation_num: 5 better_range: true first_is_front: true if_include_video_ref_img: true prob_include_video_ref_img: 0.5 img_res: [640, 896] validation: target: lib.datasets.AvatarDataset params: data_prefix: None load_imgs: true specific_observation_num: 5 better_range: true first_is_front: true img_res: [640, 896] cache_path: [ ./processed_data/deepfashion_val_10_local.npy, ./processed_data/flux_batch1_5000_val_10_local.npy ] lightning: modelcheckpoint: params: every_n_train_steps: 4000 #2000 save_top_k: -1 save_last: true monitor: 'train/loss_mse' # ADD this logging in the wrapper_sa mode: "min" filename: 'sample-synData-epoch{epoch:02d}-val_loss{val/loss:.2f}' callbacks: {} trainer: num_sanity_val_steps: 0 accumulate_grad_batches: 1 gradient_clip_val: 10.0 max_steps: 80000 check_val_every_n_epoch: 1 ## check validation set every 1 training batches in the current epoch benchmark: true ================================================ FILE: configs/test_dataset.yaml ================================================ dataset: target: lib.datasets.dataloader.DataModuleFromConfig params: batch_size: 1 num_workers: 2 # working when in debug mode debug_cache_path: ./processed_data/flux_batch1_5000_test_50_local.npy train: target: lib.datasets.AvatarDataset params: data_prefix: None cache_path: [ ./processed_data/deepfashion_train_140_local.npy, ./processed_data/flux_batch1_5000_train_140_local.npy ] specific_observation_num: 5 better_range: true first_is_front: true if_include_video_ref_img: true prob_include_video_ref_img: 0.5 img_res: [640, 896] validation: target: lib.datasets.AvatarDataset params: data_prefix: None load_imgs: true specific_observation_num: 5 better_range: true first_is_front: true img_res: [640, 896] cache_path: [ ./processed_data/deepfashion_val_10_local.npy, ./processed_data/flux_batch1_5000_val_10_local.npy ] test: target: lib.datasets.AvatarDataset params: data_prefix: None load_imgs: true specific_observation_num: 5 better_range: true first_is_front: true img_res: [640, 896] cache_path: [ ./processed_data/deepfashion_test_50_local.npy, ./processed_data/flux_batch1_5000_test_50_local.npy ] ================================================ FILE: data_processing/prepare_cache.py ================================================ #!/usr/bin/env python # -*- coding: utf-8 -*- """ Data preparation script for DeepFashion video dataset. 
This script processes video files and their corresponding parameters, and splits the dataset into train/val/test sets. """ import os import numpy as np import argparse def parse_args(): """Parse command line arguments.""" parser = argparse.ArgumentParser(description="Prepare DeepFashion video dataset") parser.add_argument( "--video_dir", type=str, default="/apdcephfs/private_harriswen/data/deepfashion/", help="Base directory containing imageX folders" ) parser.add_argument( "--output_dir", type=str, default="./", help="Directory to save the processed data" ) parser.add_argument( "--prefix", type=str, default="DeepFashion", help="Prefix for the output file names" ) parser.add_argument( "--max_videos", type=int, default=20000, help="Maximum number of videos to process (for creating smaller datasets)" ) return parser.parse_args() def prepare_dataset(video_dir, output_dir, prefix, max_total_videos=20000): """ Prepare the DeepFashion dataset by processing videos and parameters. Args: video_dir: Base directory containing imageX folders output_dir: Directory to save processed data prefix: Prefix for output filenames max_total_videos: Maximum number of videos to process (default: 20000) """ # Find all imageX subdirectories image_dirs = [] for item in os.listdir(video_dir): if item.startswith("image") and os.path.isdir(os.path.join(video_dir, item)): image_dirs.append(item) image_dirs.sort() print(f"Found {len(image_dirs)} image directories: {image_dirs}") # Collect all video files all_video_files = [] all_param_files = [] all_dir_names = [] # import ipdb; ipdb.set_trace() for image_dir in image_dirs: videos_path = os.path.join(video_dir, image_dir, "videos") params_path = os.path.join(video_dir, image_dir, "param") if not os.path.exists(videos_path): print(f"Warning: Videos directory not found in {image_dir}, skipping.") continue if not os.path.exists(params_path): print(f"Warning: Parameters directory not found in {image_dir}, skipping.") continue # Get list of video names in current directory param_names = os.listdir(params_path) # filter the files with .npy extension param_names = [name for name in param_names if name.endswith(".npy")] for name in param_names: video_path = os.path.join(videos_path, name.replace(".npy", ".mp4")) param_path = os.path.join(params_path, name) # Check if both video and parameter files exist if not os.path.exists(video_path): print(f"Warning: Video file not found: {video_path}, skipping.") continue if not os.path.exists(param_path): print(f"Warning: Parameter file not found: {param_path}, skipping.") continue # Add to collection only if both files exist all_video_files.append(video_path) all_param_files.append(param_path) all_dir_names.append(image_dir) total_videos = len(all_video_files) print(f"Total valid videos found: {total_videos}") if total_videos == 0: print("Error: No valid video-parameter pairs found. 
Please check your data paths.") return # Limit number of videos to process if max_total_videos < total_videos: # Randomly shuffle and select first max_total_videos indices = list(range(total_videos)) np.random.shuffle(indices) indices = indices[:max_total_videos] all_video_files = [all_video_files[i] for i in indices] all_param_files = [all_param_files[i] for i in indices] all_dir_names = [all_dir_names[i] for i in indices] print(f"Limiting to {max_total_videos} videos") # Process videos and parameters scenes = [] processed_count = 0 skipped_count = 0 for video_path, param_path, dir_name in zip(all_video_files, all_param_files, all_dir_names): processed_count += 1 case_name = os.path.basename(video_path) print(f"Processing {processed_count}/{len(all_video_files)}: {dir_name}/{case_name}") try: # Create scene dictionary scenes.append(dict( video_path=video_path, image_paths=None, # only fill it for the data in a images sequence instead of a video param_path=param_path, image_ref=video_path.replace(".mp4", ".jpg") )) except Exception as e: print(f"Error processing {video_path}: {e}") skipped_count += 1 print(f"Total scenes collected: {len(scenes)}") print(f"Total scenes skipped: {skipped_count}") if len(scenes) == 0: print("Error: No scenes could be processed. Please check your data.") return # Split dataset total_scenes = len(scenes) test_scenes = scenes[-50:] if total_scenes > 50 else [] val_scenes = scenes[-60:-50] if total_scenes > 60 else [] train_scenes = scenes[:-60] if total_scenes > 60 else scenes # Save each split splits = { "train": train_scenes, "val": val_scenes, "test": test_scenes, "all": scenes } # Create output directory os.makedirs(output_dir, exist_ok=True) # Save each split to separate file for split_name, split_data in splits.items(): if not split_data: continue cache_path = os.path.join( output_dir, f"{prefix}_{split_name}_{len(split_data)}.npy" ) np.save(cache_path, split_data) print(f"Saved {split_name} split with {len(split_data)} samples to {cache_path}") if __name__ == "__main__": # Parse command line arguments args = parse_args() # Prepare and save the dataset prepare_dataset(args.video_dir, args.output_dir, args.prefix, args.max_videos) print(f"Done processing {args.video_dir} dataset") ================================================ FILE: data_processing/process_datasets.sh ================================================ #!/bin/bash # Data processing script for multiple datasets # This script processes all specified datasets and saves the results to output directories # Define the list of dataset paths DATASET_PATHS=( "/PATH/TO/deepfashion" "/PATH/TO/flux_batch1_5000" "/PATH/TO/flux_batch2" # Add more dataset paths here as needed ) # Output base directory for processed cache files OUTPUT_BASE_DIR="./processed_data" # Maximum videos to process per dataset (set to a smaller number for testing) # if you want to process all videos, set MAX_VIDEOS to a very large number MAX_VIDEOS=200 # Process each dataset for DATASET_PATH in "${DATASET_PATHS[@]}"; do # Extract dataset name from path (use the last directory name as prefix) DATASET_NAME=$(basename "$DATASET_PATH") # Create output directory for this dataset OUTPUT_DIR="${OUTPUT_BASE_DIR}" mkdir -p "$OUTPUT_DIR" echo "===== Processing ${DATASET_NAME} Dataset =====" echo "Source: ${DATASET_PATH}" echo "Destination: ${OUTPUT_DIR}" # Run the processing script python data_processing/prepare_cache.py \ --video_dir "${DATASET_PATH}" \ --output_dir "${OUTPUT_DIR}" \ --prefix "${DATASET_NAME}" \ --max_videos "${MAX_VIDEOS}" 
# Check if processing was successful if [ $? -ne 0 ]; then echo "Error processing ${DATASET_NAME} dataset" echo "Continuing with next dataset..." else echo "Successfully processed ${DATASET_NAME} dataset" fi echo "----------------------------------------" done echo "===== All datasets processing completed =====" echo "Results saved to: ${OUTPUT_BASE_DIR}" # List all processed datasets echo "Processed datasets:" for DATASET_PATH in "${DATASET_PATHS[@]}"; do DATASET_NAME=$(basename "$DATASET_PATH") echo "- ${DATASET_NAME}: ${OUTPUT_BASE_DIR}/${DATASET_NAME}" done ================================================ FILE: data_processing/visualize_samples.py ================================================ import torch import numpy as np import os os.environ["PYOPENGL_PLATFORM"] = "osmesa" import smplx import trimesh import pyrender import imageio def init_smplx_model(): """Initialize the SMPL-X model with predefined settings.""" body_model = smplx.SMPLX('PATH_TO_YOUR_SMPLX_FOLDER', gender="neutral", create_body_pose=False, create_betas=False, create_global_orient=False, create_transl=False, create_expression=True, create_jaw_pose=True, create_leye_pose=True, create_reye_pose=True, create_right_hand_pose=False, create_left_hand_pose=False, use_pca=False, num_pca_comps=12, num_betas=10, flat_hand_mean=False) return body_model # Load SMPL-X parameters param_path = "./100samples/Apose/param/Argentina_male_buff_thermal wear_20~30 years old_1573.npy" param = np.load(param_path, allow_pickle=True).item() # Extract SMPL-X parameters smpl_params = param['smpl_params'].reshape(1, -1) scale, transl, global_orient, pose, betas, left_hand_pose, right_hand_pose, jaw_pose, leye_pose, reye_pose, expression = torch.split(smpl_params, [1, 3, 3, 63, 10, 45, 45, 3, 3, 3, 10], dim=1) # Initialize SMPL-X model and generate vertices device = torch.device("cpu") model = init_smplx_model().to(device) output = model(global_orient=global_orient, body_pose=pose, betas=betas, left_hand_pose=left_hand_pose, right_hand_pose=right_hand_pose, jaw_pose=jaw_pose, leye_pose=leye_pose, reye_pose=reye_pose, expression=expression) vertices = output.vertices[0].detach().cpu().numpy() faces = model.faces # Create a Trimesh and Pyrender mesh mesh = trimesh.Trimesh(vertices, faces) mesh_pyrender = pyrender.Mesh.from_trimesh(mesh) rendered_images_list = [] # Loop through multiple camera views for idx in range(24): scene = pyrender.Scene() scene.add(mesh_pyrender) # Load and process camera parameters camera_params = param['poses'] intrinsic_params = camera_params[idx][1] # fx, fy, cx, cy extrinsic_params = camera_params[idx][0] # R|T # Set up Pyrender camera camera = pyrender.IntrinsicsCamera(fx=intrinsic_params[0], fy=intrinsic_params[1], cx=intrinsic_params[2], cy=intrinsic_params[3]) # Convert COLMAP coordinates to Pyrender-compatible transformation extrinsic_params_inv = torch.inverse(extrinsic_params.clone()) scale_factor = extrinsic_params_inv[:3, :3].norm(dim=1) extrinsic_params_inv[:3, 1:3] = -extrinsic_params_inv[:3, 1:3] extrinsic_params_inv[3, :3] = 0 # Add camera and lighting scene.add(camera, pose=extrinsic_params_inv) light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=10.0) scene.add(light, pose=extrinsic_params_inv) # Render the scene renderer = pyrender.OffscreenRenderer(640, 896) color, depth = renderer.render(scene) rendered_images_list.append(color) renderer.delete() # Save rendered images as a video rendered_images = np.stack(rendered_images_list) imageio.mimwrite('rendered_results.mp4', rendered_images, 
fps=15) print("Rendered results saved as rendered_results.mp4") # Load an existing video and test alignment video_path = param_path.replace("param", "videos").replace("npy", "mp4") input_video = imageio.get_reader(video_path) input_frames = [frame for frame in input_video] blended_frames = [(0.5 * frame + 0.5 * render_frame).astype(np.uint8) for render_frame, frame in zip(rendered_images, input_frames)] imageio.mimwrite('aligned_results.mp4', blended_frames, fps=15) print("Blended video saved as aligned_results.mp4") ================================================ FILE: dataset/README.md ================================================ # 🌟 HuGe100K Dataset Documentation ## 📊 Dataset Overview HuGe100K is a large-scale multi-view human dataset featuring diverse attributes, high-fidelity appearances, and well-aligned SMPL-X models. ## 📁 File Format and Structure The dataset is organized with the following structure: ``` HuGe100K/ ├── flux_batch1/ │ ├── images[0...9]/ # different batch of images │ │ ├── videos/ # Folder for .mp4 and .jpg files │ │ │ ├── Algeria_female_average_high fashion_50~60 years old_844.jpg │ │ │ ├── Algeria_female_average_high fashion_50~60 years old_844.mp4 │ │ │ └── ... (more .jpg and .mp4) │ │ └── param/ # Folder for parameter files (.npy) │ │ ├── Algeria_female_average_high fashion_50~60 years old_844.npy │ │ └── ... (more .npy files) ├── flux_batch2/ │ └── ... (similar structure with images[0...9]) ├── flux_batch3/ │ └── ... (similar structure with images[0...9]) ├── flux_batch4/ │ └── ... (similar structure with images[0...9]) ├── flux_batch5/ │ └── ... (similar structure with images[0...9]) ├── flux_batch6/ │ └── ... (similar structure with images[0...9]) ├── flux_batch7/ │ └── ... (similar structure with images[0...9]) ├── flux_batch8/ │ └── ... (similar structure with images[0...9]) ├── flux_batch9/ │ └── ... (similar structure with images[0...9]) └── deepfashion/ └── ... (similar structure with images[0...9]) ``` Where: - Each `images[X]` folder contains: - `videos/`: Reference images and generatedvideo files - `param/`: Camera and body pose parameters - **flux_batch1 through flux_batch7**: Contains subjects in A-pose - **flux_batch8 and flux_batch9**: Contains subjects in diverse poses - **deepfashion**: Contains subjects in A-pose (derived from the DeepFashion dataset) ### File Naming Convention Files follow the naming pattern: `Area_Gender_BodyType_Clothing_Age_ID.extension` For example: - `Algeria_female_average_high fashion_50~60 years old_844.jpg`: Reference image of an Algerian female with average build in high fashion clothing - `Algeria_female_average_high fashion_50~60 years old_844.npy`: Parameter file for the same subject ### 📸 Sample Visualization
*Sample visualization: reference image (Kenya, female, fit, streetwear) =MVChamp=> generated multi-view images of the same subject*
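A minimal sketch for inspecting a parameter file is shown below. It assumes the sample shipped with this repository under `dataset/sample/param/` (adjust the path for other subjects); the key names and the 189-value `smpl_params` layout follow `visualize_samples.py`.

```python
import numpy as np

# Example: the sample parameter file shipped under dataset/sample/param/.
param_path = "dataset/sample/param/Kenya_female_fit_streetwear_50~60 years old_1501.npy"
param = np.load(param_path, allow_pickle=True).item()
print("Top-level keys:", sorted(param.keys()))  # should include 'smpl_params' and 'poses'

# 189 SMPL-X values, ordered as in visualize_samples.py:
# scale(1), transl(3), global_orient(3), body_pose(63), betas(10),
# left_hand_pose(45), right_hand_pose(45), jaw_pose(3), leye_pose(3),
# reye_pose(3), expression(10)
smpl_params = np.asarray(param["smpl_params"]).reshape(-1)
sizes = [1, 3, 3, 63, 10, 45, 45, 3, 3, 3, 10]
names = ["scale", "transl", "global_orient", "body_pose", "betas",
         "left_hand_pose", "right_hand_pose", "jaw_pose",
         "leye_pose", "reye_pose", "expression"]
for name, chunk in zip(names, np.split(smpl_params, np.cumsum(sizes)[:-1])):
    print(f"{name}: {chunk.shape}")

# One (extrinsic, intrinsic) pair per rendered view.
poses = param["poses"]
print("Number of views:", len(poses))
extrinsic = np.asarray(poses[0][0])   # 4x4 camera extrinsic matrix (R|T)
intrinsic = np.asarray(poses[0][1])   # fx, fy, cx, cy
print("Extrinsic shape:", extrinsic.shape, "| Intrinsics:", intrinsic)
```

Each entry in `poses` corresponds to one of the 24 rendered views and stores the camera extrinsic matrix together with the `(fx, fy, cx, cy)` intrinsics used to render that view.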
## 📈 Dataset Statistics - **Total Subjects**: 100,000+ - **Views per Subject**: Multiple viewpoints covering 360° in 24 views - **Pose Types**: A-pose and diverse poses ## 🔍 Visualizing the Dataset For visualization and data parsing examples, please refer to our provided script: `visualize_samples.py`. This script demonstrates how to: - Load the SMPL-X parameters from `.npy` files - Render the 3D human model from multiple camera views - Compare rendered results with the original video frames Requirements for visualization: - SMPL-X model (download from [official website](https://smpl-x.is.tue.mpg.de/)) - Python packages: `pyrender`, `trimesh`, `smplx`, `numpy`, `torch` Example usage: ```bash python visualize_samples.py ``` The script will generate: - `rendered_results.mp4`: Rendered views of the 3D model - `aligned_results.mp4`: Blended visualization of rendered model with original frames ## 📋 Usage Guidelines 1. **Research Purposes Only**: This dataset is intended for academic and research purposes. 2. **Citation Required**: If you use this dataset in your research, please cite our paper. 3. **No Commercial Use**: Commercial use is permitted only with explicit permission from us at yiyu.zhuang@smail.nju.edu.cn. 4. **DeepFashion Derivatives**: See License and Attribution section below for special requirements. ## ⚖️ License and Attribution (DeepFashion) This dataset includes images derived from the **DeepFashion** dataset, originally provided by MMLAB at The Chinese University of Hong Kong. The use of DeepFashion images in this dataset has been explicitly authorized by the original authors solely for the purpose of creating and distributing this dataset. **Users must not further reproduce, distribute, sell, or commercially exploit any images or derived data originating from DeepFashion.** For any subsequent or separate use of the DeepFashion data, users must directly obtain authorization from MMLAB and comply with the original [DeepFashion License](https://mmlab.ie.cuhk.edu.hk/projects/DeepFashion.html). 
================================================ FILE: dataset/visualize_samples.py ================================================ import torch import numpy as np import os os.environ["PYOPENGL_PLATFORM"] = "osmesa" import smplx import trimesh import pyrender import imageio def init_smplx_model(): """Initialize the SMPL-X model with predefined settings.""" body_model = smplx.SMPLX('PATH_TO_YOUR_SMPLX_FOLDER', gender="neutral", create_body_pose=False, create_betas=False, create_global_orient=False, create_transl=False, create_expression=True, create_jaw_pose=True, create_leye_pose=True, create_reye_pose=True, create_right_hand_pose=False, create_left_hand_pose=False, use_pca=False, num_pca_comps=12, num_betas=10, flat_hand_mean=False) return body_model # Load SMPL-X parameters param_path = "./samples/param/Kenya_female_fit_streetwear_50~60 years old_1501.npy" param = np.load(param_path, allow_pickle=True).item() # Extract SMPL-X parameters smpl_params = param['smpl_params'].reshape(1, -1) scale, transl, global_orient, pose, betas, left_hand_pose, right_hand_pose, jaw_pose, leye_pose, reye_pose, expression = torch.split(smpl_params, [1, 3, 3, 63, 10, 45, 45, 3, 3, 3, 10], dim=1) # Initialize SMPL-X model and generate vertices device = torch.device("cpu") model = init_smplx_model().to(device) output = model(global_orient=global_orient, body_pose=pose, betas=betas, left_hand_pose=left_hand_pose, right_hand_pose=right_hand_pose, jaw_pose=jaw_pose, leye_pose=leye_pose, reye_pose=reye_pose, expression=expression) vertices = output.vertices[0].detach().cpu().numpy() faces = model.faces # Create a Trimesh and Pyrender mesh mesh = trimesh.Trimesh(vertices, faces) mesh_pyrender = pyrender.Mesh.from_trimesh(mesh) rendered_images_list = [] # Loop through multiple camera views for idx in range(24): scene = pyrender.Scene() scene.add(mesh_pyrender) # Load and process camera parameters camera_params = param['poses'] intrinsic_params = camera_params[idx][1] # fx, fy, cx, cy extrinsic_params = camera_params[idx][0] # R|T # Set up Pyrender camera camera = pyrender.IntrinsicsCamera(fx=intrinsic_params[0], fy=intrinsic_params[1], cx=intrinsic_params[2], cy=intrinsic_params[3]) # Convert COLMAP coordinates to Pyrender-compatible transformation extrinsic_params_inv = torch.inverse(extrinsic_params.clone()) scale_factor = extrinsic_params_inv[:3, :3].norm(dim=1) extrinsic_params_inv[:3, 1:3] = -extrinsic_params_inv[:3, 1:3] extrinsic_params_inv[3, :3] = 0 # Add camera and lighting scene.add(camera, pose=extrinsic_params_inv) light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=10.0) scene.add(light, pose=extrinsic_params_inv) # Render the scene renderer = pyrender.OffscreenRenderer(640, 896) color, depth = renderer.render(scene) rendered_images_list.append(color) renderer.delete() # Save rendered images as a video rendered_images = np.stack(rendered_images_list) imageio.mimwrite('rendered_results.mp4', rendered_images, fps=15) print("Rendered results saved as rendered_results.mp4") # Load an existing video and test alignment video_path = param_path.replace("param", "videos").replace("npy", "mp4") input_video = imageio.get_reader(video_path) input_frames = [frame for frame in input_video] blended_frames = [(0.5 * frame + 0.5 * render_frame).astype(np.uint8) for render_frame, frame in zip(rendered_images, input_frames)] imageio.mimwrite('aligned_results.mp4', blended_frames, fps=15) print("Blended video saved as aligned_results.mp4") ================================================ FILE: 
env/README.md ================================================ # Environment Setup Guide ## Prerequisites - Python 3.10 - CUDA 11.8 - PyTorch 2.3.1 ## Installation Steps ### 1. Environment Preparation First, create and activate a conda environment: ```bash conda create -n idol python=3.10 conda activate idol ``` Install all dependencies: ```bash bash scripts/pip_install.sh ``` ### 2. Download Required Models Before proceeding, please register on: - [SMPL-X website](https://smpl-x.is.tue.mpg.de/) - [FLAME website](https://flame.is.tue.mpg.de/) Then download the template files: ```bash bash scripts/fetch_template.sh ``` ### 3. Download Pretrained Models and caches with: ```bash bash scripts/download_files.sh # download pretrained models ``` Or mannually download the following models from HuggingFace: - [IDOL Model Checkpoint](https://huggingface.co/yiyuzhuang/IDOL/blob/main/model.ckpt) - [Sapiens Pretrained Model](https://huggingface.co/yiyuzhuang/IDOL/blob/main/sapiens_1b_epoch_173_torchscript.pt2) ## System Requirements - **GPU**: NVIDIA GPU with CUDA 11.8 support - **GPU RAM**: Recommended 24GB+ - **Storage**: At least 15GB free space ## Common Issues & Solutions **Issue**: ``` ImportError: libGL.so.1: cannot open shared object file: No such file or directory ``` when importing OpenCV (`import cv2`) **Solution**: ```bash # For Ubuntu/Debian sudo apt-get install libgl1-mesa-glx ``` ### 2. Gaussian Splatting Antialiasing Issue **Issue**: Error related to `antialiasing=True` setting in `GaussianRasterizationSettings` **Solution**: This issue arises due to updates in the Gaussian Splatting repository. Reinstalling the module from the GitHub repository with the latest version should resolve the problem. ================================================ FILE: lib/__init__.py ================================================ from .models import * from .mmutils import * from .humanlrm_wrapper_sa_v1 import SapiensGS_SA_v1 ================================================ FILE: lib/datasets/__init__.py ================================================ from .avatar_dataset import AvatarDataset from .dataloader import DataModuleFromConfig ================================================ FILE: lib/datasets/avatar_dataset.py ================================================ import os import random import numpy as np import torch import json import pickle from torch.utils.data import Dataset import torchvision.transforms.functional as F import pickle from torch.utils.data import Dataset import webdataset as wds # from lib.utils.train_util import print import cv2 import av from omegaconf import OmegaConf, ListConfig def load_pose(path): with open(path, 'rb') as f: pose_param = json.load(f) c2w = np.array(pose_param['cam_param'], dtype=np.float32).reshape(4,4) cam_center = c2w[:3, 3] w2c = np.linalg.inv(c2w) # pose[:,:2] *= -1 # pose = np.loadtxt(path, dtype=np.float32, delimiter=' ').reshape(9, 4) return [torch.from_numpy(w2c), torch.from_numpy(cam_center)] def load_npy(file_path): return np.load(file_path, allow_pickle=True) def load_smpl(path, smpl_type='smpl'): filetype = path.split('.')[-1] with open(path, 'rb') as f: if filetype=='pkl': smpl_param_data = pickle.load(f) elif filetype == 'json': smpl_param_data = json.load(f) else: assert False if smpl_type=='smpl': with open(os.path.join(os.path.split(path)[0][:-5], 'pose', '000_000.json'), 'rb') as f: tf_param = json.load(f) smpl_param = np.concatenate([np.array(tf_param['scale']).reshape(1, -1), np.array(tf_param['center'])[None], 
smpl_param_data['global_orient'], smpl_param_data['body_pose'].reshape(1, -1), smpl_param_data['betas']], axis=1) elif smpl_type == 'smplx': tf_param = np.load(os.path.join(os.path.dirname(os.path.dirname(path)), 'scale_offset.npy'), allow_pickle=True).item() # smpl_param = np.concatenate([np.array([tf_param['scale']]).reshape(1, -1), tf_param['offset'].reshape(1, -1), smpl_param = np.concatenate([np.array([[1]]), np.array([[0,0,0]]), np.array(smpl_param_data['global_orient']).reshape(1, -1), np.array(smpl_param_data['body_pose']).reshape(1, -1), np.array(smpl_param_data['betas']).reshape(1, -1), np.array(smpl_param_data['left_hand_pose']).reshape(1, -1), np.array(smpl_param_data['right_hand_pose']).reshape(1, -1), np.array(smpl_param_data['jaw_pose']).reshape(1, -1), np.array(smpl_param_data['leye_pose']).reshape(1, -1), np.array(smpl_param_data['reye_pose']).reshape(1, -1), np.array(smpl_param_data['expression']).reshape(1, -1)], axis=1) else: assert False return torch.from_numpy(smpl_param.astype(np.float32)).reshape(-1) class AvatarDataset(Dataset): def __init__(self, data_prefix, code_dir=None, # code_only=False, load_imgs=True, load_norm=False, specific_observation_idcs=None, specific_observation_num=None, first_is_front=False, # yy add # If True, it will returns a random sampled batch with the front view in the first place better_range=False, # yy add # If True, the views will not be fully random, but will be selected by a better skip if_include_video_ref_img= False,# yy add Define a variable to indicate whether to include reference images from the video prob_include_video_ref_img= 0.2, # yy add Define a variable to specify the probability allow_k_angles_near_the_front = 0, # yy add, if value > 0, the front view will be allowed to be selected from the range of [front_view - allow_k_angles_near_the_front, front_view + allow_k_angles_near_the_front] # num_test_imgs=0, if_use_swap_face_v1=False, # yy add, if True, use the swap face v1 random_test_imgs=False, scene_id_as_name=False, cache_path=None, cache_repeat=None, # be the same length with the cache_path test_pose_override=None, num_train_imgs=-1, load_cond_data=True, load_test_data=True, max_num_scenes=-1, # for debug or testing # radius=0.5, radius=1.0, img_res=[640, 896], test_mode=False, step=1, # only for debug & visualization purpose crop=False # randomly crop the image with upper body inputs ): super(AvatarDataset, self).__init__() self.data_prefix = data_prefix self.code_dir = code_dir # self.code_only = code_only self.load_imgs = load_imgs self.load_norm = load_norm self.specific_observation_idcs = specific_observation_idcs self.specific_observation_num = specific_observation_num self.first_is_front = first_is_front self.if_include_video_ref_img= if_include_video_ref_img self.prob_include_video_ref_img = prob_include_video_ref_img self.allow_k_angles_near_the_front = allow_k_angles_near_the_front self.better_range = better_range # self.num_test_imgs = num_test_imgs self.random_test_imgs = random_test_imgs self.scene_id_as_name = scene_id_as_name self.cache_path = cache_path self.cache_repeat = cache_repeat self.test_pose_override = test_pose_override self.num_train_imgs = num_train_imgs self.load_cond_data = load_cond_data self.load_test_data = load_test_data self.max_num_scenes = max_num_scenes self.step = step self.if_use_swap_face_v1 = if_use_swap_face_v1 # import ipdb; ipdb.set_trace() self.img_res = [int(i) for i in img_res] self.radius = torch.tensor([radius], dtype=torch.float32).expand(3) self.center = 
torch.zeros_like(self.radius) self.load_scenes() self.crop = crop self.test_poses = self.test_intrinsics = None self.defalut_focal = 1120 #40 * (self.img_res[0]/32) # focal 80mm, sensor 32mm self.default_fxy_cxy = torch.tensor([self.defalut_focal, self.defalut_focal, self.img_res[1]//2, self.img_res[0]//2]).reshape(1, 4) self.test_mode = test_mode if self.test_mode: self.parse_scene = self.parse_scene_test def load_scenes(self): if isinstance(self.cache_path, ListConfig): cache_list = [] case_num_per_dataset = 1000000000 for ii, path in enumerate(self.cache_path): cache = np.load(path, allow_pickle=True) if self.cache_repeat is not None: cache = np.repeat(cache, self.cache_repeat[ii], axis=0) print("done loading ", path) cache_list.extend(cache[:case_num_per_dataset]) scenes = cache_list print(f"=========intialized totally {len(scenes)} scenes===========") else: if self.cache_path is not None and os.path.exists(self.cache_path): scenes = np.load(self.cache_path, allow_pickle=True) print("load ", self.cache_path) else: print(f"{self.cache_path} is not exist") raise FileNotFoundError(f"maybe {self.cache_path} is not exist") end = len(scenes) if self.max_num_scenes >= 0: end = min(end, self.max_num_scenes * self.step) self.scenes = scenes[:end:self.step] self.num_scenes = len(self.scenes) def parse_scene(self, scene_id): scene = self.scenes[scene_id] input_is_video = False # flag of if the input is video, some operations should be different # print(scene) # scene['video_path'] = "/data/jxlv/transformers/src/A_pose_MEN-Denim-id_00000089-01_7_additional/result.mp4" # scene['image_paths'] = None #"/data/jxlv/transformers/src/A_pose_MEN-Denim-id_00000089-01_7_additional/source_seg.png" # import pdb # pdb.set_trace() # =========== loading the params =========== param = np.load(scene['param_path'], allow_pickle=True).item() scene.update(param) print(scene.keys()) # =========== loading the multi-view images =========== if scene['image_paths'] is None: input_is_video = True video_path = scene['video_path'] try: if self.if_use_swap_face_v1: image_paths_or_video = read_frames(scene['video_path'].replace('result.mp4', 'output.mp4')) else: image_paths_or_video = read_frames(scene['video_path']) except Exception as e: print(f"Error: {e}") print(f"Error in reading the video : {scene['video_path'].replace('result.mp4', 'output.mp4')}") image_paths_or_video = read_frames(scene['video_path']) # if 'pose_animate_service_0727' in video_path or 'flux' in video_path: # # move the first to the last # # TODO fixed this bug with a better cameras parameters # image_paths_or_video = image_paths_or_video[1:] + image_paths_or_video[0:1] if not input_is_video: image_paths_or_video = scene['image_paths'] scene_name = f"{scene_id:0>4d}" # image_paths[0].split('/')[-3] results = dict( scene_id=[scene_id], scene_name= '{:04d}'.format(scene_id) if self.scene_id_as_name else scene_name, # cpu_only=True ) # import pdb; pdb.set_trace() # if not self.code_only: poses = scene['poses'] smpl_params = scene['smpl_params'] # if input_is_video: # num_imgs = len(video) # else: num_imgs = len(image_paths_or_video) # front_view = num_imgs // 4 # randonly / specificically select the views of output smplx_cam_rotate = smpl_params[4: 7] #get global orient # 1, 3, 63, 10 # smpl_params[70:80] = torch.rand_like(smpl_params[70:80]); print("error !! 
need to delete this rand betas in avatarnet:287") # get betas front_view = find_front_camera_by_rotation(poses, smplx_cam_rotate) # inputs camera poses and smplx poses if self.allow_k_angles_near_the_front > 0: allow_n_views_near_the_front = round(self.allow_k_angles_near_the_front / 360 * num_imgs) new_front_view = random.randint(-allow_n_views_near_the_front, allow_n_views_near_the_front) + front_view if new_front_view >= num_imgs: new_front_view = new_front_view - num_imgs elif new_front_view < 0: new_front_view = new_front_view + num_imgs front_view = new_front_view print("changes the front views ranges", front_view, "+-", allow_n_views_near_the_front) if self.specific_observation_idcs is None: ######### if not specify views ######## # if self.num_train_imgs >= 0: # num_train_imgs = self.num_train_imgs # else: num_train_imgs = num_imgs if self.random_test_imgs: ###### randomly selected images with self.num_train_imgs ###### cond_inds = random.sample(range(num_imgs), self.num_train_imgs) elif self.specific_observation_num: ###### randomly selected "specific_observation_num" images ###### if self.first_is_front and self.specific_observation_num < 2: # self.specific_observation_num = 2 cond_inds =torch.tensor([front_view, front_view]) # first for input, second for supervised elif self.better_range: # select views by a uniform distribution range if self.first_is_front: # must include the front view num_train_imgs = self.specific_observation_num - 2 else: num_train_imgs = self.specific_observation_num skip_range = num_imgs//num_train_imgs # select random views from each range of [skip_range] seperate from [0, skip_range, 2*skip_range, ...], cond_inds = torch.randperm(num_train_imgs) * skip_range \ + torch.randint(low=0, high=skip_range, size=[num_train_imgs]) if self.first_is_front: # concat [the first view * 2] to the front of cond_inds cond_inds = torch.cat([torch.tensor([front_view, front_view]), cond_inds]) else: # previous version, random views are sampled cond_inds = torch.randperm(num_imgs)[:self.specific_observation_num] else: cond_inds = np.round(np.linspace(0, num_imgs - 1, num_train_imgs)).astype(np.int64) else: ######### selected target views ######## cond_inds = self.specific_observation_idcs test_inds = list(range(num_imgs)) if self.specific_observation_num: # yy note: if specific_observation_num is not None, then remove the test_inds test_inds = [] else: for cond_ind in cond_inds: test_inds.remove(cond_ind) cond_smpl_param_ref = torch.zeros([189]) if_use_smpl_param_ref = torch.Tensor([1]) # 默认使用ref smpl, if self.load_cond_data and len(cond_inds) > 0: # cond_imgs, cond_poses, cond_intrinsics, cond_img_paths, cond_smpl_param, cond_norm = gather_imgs(cond_inds) cond_imgs, cond_poses, cond_intrinsics, cond_img_paths, cond_smpl_param, cond_norm = \ gather_imgs(cond_inds, poses, image_paths_or_video, smpl_params, load_imgs=self.load_imgs, load_norm=self.load_norm, center=self.center, radius=self.radius, input_is_video=input_is_video) cond_smpl_param_ref = cond_smpl_param.clone() # the smpl_param_ref for the reference images if cond_intrinsics.shape[-1] == 3: # the old data format, which contains the value of camera center instead of fxfycxcy cond_intrinsics = self.default_fxy_cxy.clone().repeat(cond_intrinsics.shape[0], 1) # import pdb; pdb.set_trace() # print("video_path", video_path) if input_is_video: cond_img_paths = [f"{video_path[:-4]}_{i:0>4d}.png" for i in range(self.specific_observation_num)] # Replace the .mp4 into index.png if self.if_include_video_ref_img and 
input_is_video: # 设置一个随机数,如果小于某个概率,那么替换第一张图为另一个图片 if np.random.rand() < self.prob_include_video_ref_img: if 'image_ref' in scene: ref_image_path = scene['image_ref'] print("ref_image_path",ref_image_path) else: ref_image_path = video_path.replace(".mp4", ".jpg") # if "flux" in ref_image_path: # temperaturelly supports the inputs from the flux try: # replacement_img_path = ref_image_path # replacement_img = load_image(ref_image_path) # 假设有一个函数可以加载图片 # 使用cv2.IMREAD_UNCHANGED标志读取图片,以保留alpha通道 img = cv2.imread(ref_image_path, cv2.IMREAD_UNCHANGED) assert img is not None, f"img is None, {ref_image_path}" img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) img = torch.from_numpy(img.astype(np.float32) / 255) # (h, w, 3) # print("img.shape", img.shape) # print("cond_imgs.shape", cond_imgs.shape) # test_img_paths[0] = ref_image_path # results.update(test_imgs=test_imgs, test_img_paths=test_img_paths) # ======== loading the reference smplx for images ========== load_ref_smplx = False if load_ref_smplx: # if "flux" in ref_image_path: smplx_smplify_path = from_video_to_get_ref_smplx(video_path) # load json and get values with open(smplx_smplify_path) as f: data = json.load(f) RT = torch.concatenate([ torch.Tensor(data['camera']['R']), torch.Tensor(data['camera']['t']).reshape(3,1)], dim=1) RT = torch.cat([RT, torch.Tensor([[0,0,0,1]])], dim=0) intri = torch.Tensor(data['camera']['focal'] + data['camera']['princpt']) intri[[3,2]] = intri[[2,3]] intri = intri * self.default_fxy_cxy[0,-1] / intri[-1] # 假设 smpl_param_data 是已经加载好的数据 # (['root_pose', 'body_pose', 'jaw_pose', 'leye_pose', 'reye_pose', 'lhand_pose', 'rhand_pose', 'expr', 'trans', 'betas_save', 'nlf_smplx_betas', 'camera', 'img_path']) smpl_param_data = data # 从字典中提取所需的数据 global_orient = np.array(smpl_param_data['root_pose']).reshape(1, -1) body_pose = np.array(smpl_param_data['body_pose']).reshape(1, -1) shape = np.array(smpl_param_data['betas_save']).reshape(1, -1)[:, :10] left_hand_pose = np.array(smpl_param_data['lhand_pose']).reshape(1, -1) right_hand_pose = np.array(smpl_param_data['rhand_pose']).reshape(1, -1) # smpl_param_ref = np.concatenate([np.array([[1.]]), np.array([[0, 0, 0]]), smpl_param_ref = np.concatenate([np.array([[1.]]), np.array(smpl_param_data['trans']).reshape(1,3), global_orient,body_pose, shape, left_hand_pose, right_hand_pose, np.array(smpl_param_data['jaw_pose']).reshape(1, -1), np.array(smpl_param_data['leye_pose']).reshape(1, -1), np.array(smpl_param_data['reye_pose']).reshape(1, -1), np.array(smpl_param_data['expr']).reshape(1, -1)[:,:10]], axis=1) cond_poses[0] = RT # RT cond_intrinsics[0] = intri # fxfycxcy cond_smpl_param_ref = torch.Tensor(smpl_param_ref).reshape(-1) # 189, combines if_use_smpl_param_ref = torch.Tensor([1]) # use the smpl_param_ref # overwrite some datas cond_imgs[0] = img cond_img_paths[0] = ref_image_path except (FileNotFoundError, json.JSONDecodeError, KeyError, Exception) as e: # 记录错误信息到日志文件 # with open(self.log_file_path, 'a') as log_file: # log_file.write(f"{datetime.datetime.now()} {video_path} \n - An error occurred: {str(e)}\n") print(f"An error occurred: {e}") ''' randomly crop the first image for augmentation''' if self.crop: # print("crop", cond_imgs[0].shape) if random.random() < 0.5: # cond_imgs[0] = F.crop(cond_imgs[0], 0, 0, 512, 512) crop_imgs = cond_imgs[0] # 图像尺寸 h, w, _ = crop_imgs.shape # 随机偏移量 random_offset_head = np.random.randint(-h//7, -h//8) random_offset_body = np.random.randint(-h // 8, h // 8) # head_joint, upper_body_joint head_joint = [ w//2, h//7,] upper_body_joint = 
[w//2, h//2, ] # 计算裁剪区域 head_y = int(head_joint[1]) + random_offset_head body_y = int(upper_body_joint[1]) + random_offset_body # 确保裁剪区域在图像范围内 head_y = max(0, min(h, head_y)) body_y = max(0, min(h, body_y)) # 裁剪区域的高度和宽度 crop_height = body_y - head_y crop_width =int(crop_height * 640 / 896) # 确保裁剪区域在图像范围内 start_x = max(0, min(w - crop_width, int(w // 2 - crop_width // 2))) end_x = start_x + crop_width start_y = max(0, head_y) end_y = min(h, body_y) # 裁剪图像 cropped_img = crop_imgs[start_y:end_y, start_x:end_x] padded_img = F.resize(cropped_img.permute(2, 0, 1), [h, w]).permute(1, 2, 0) # save this img for debug # Image.fromarray((padded_img.numpy() * 255).astype(np.uint8)).save(f"debug_crop.png") # rescale the image for augmentation cond_imgs[0] = random_scale_and_crop(padded_img, (0.8,1.2)) else: cond_imgs[0] = random_scale_and_crop(cond_imgs[0], (0.8,1.1)) results.update( cond_poses=cond_poses, cond_intrinsics=cond_intrinsics.to(torch.float32), cond_img_paths=cond_img_paths, cond_smpl_param=cond_smpl_param, cond_smpl_param_ref=cond_smpl_param_ref, if_use_smpl_param_ref=if_use_smpl_param_ref) if cond_imgs is not None: results.update(cond_imgs=cond_imgs) if cond_norm is not None: results.update(cond_norm=cond_norm) if self.load_test_data and len(test_inds) > 0: test_imgs, test_poses, test_intrinsics, test_img_paths, test_smpl_param, test_norm = \ gather_imgs(test_inds, poses, image_paths_or_video, smpl_params, load_imgs=self.load_imgs, load_norm=self.load_norm, center=self.center, radius=self.radius) if test_intrinsics.shape[-1] == 3: # the old data format, which contains the value of camera center instead of fxfycxcy test_intrinsics = self.default_fxy_cxy.clone().repeat(test_intrinsics.shape[0], 1) results.update( test_poses=test_poses, test_intrinsics=test_intrinsics, test_img_paths=test_img_paths, test_smpl_param=test_smpl_param) if test_imgs is not None: results.update(test_imgs=test_imgs) if test_norm is not None: results.update(test_norm=test_norm) if self.test_pose_override is not None: results.update(test_poses=self.test_poses, test_intrinsics=self.test_intrinsics) return results def __len__(self): return self.num_scenes def __getitem__(self, scene_id): try: scene = self.parse_scene(scene_id) except: print("ERROR in parsing ", scene_id) scene = self.parse_scene(0) return scene def parse_scene_test(self, scene_id): scene = self.scenes[scene_id] input_is_video = False # flag of if the input is video, some operations should be different # =========== loading the params =========== param = np.load(scene['param_path'], allow_pickle=True).item() scene.update(param) print(scene.keys()) # import ipdb; ipdb.set_trace() if scene['image_paths'] is None: input_is_video = True video_path = scene['video_path'] try: image_paths_or_video = read_frames(scene['video_path']) except Exception as e: print(f"Error: {e}") print(f"Error in reading the video : {scene['video_path'].replace('result.mp4', 'output.mp4')}") image_paths_or_video = read_frames(scene['video_path']) if not input_is_video: image_paths_or_video = scene['image_paths'] scene_name = f"{scene_id:0>4d}" # image_paths[0].split('/')[-3] results = dict( scene_id=[scene_id], scene_name= '{:04d}'.format(scene_id) if self.scene_id_as_name else scene_name, # cpu_only=True ) # if not self.code_only: poses = scene['poses'] smpl_params = scene['smpl_params'] # if input_is_video: # num_imgs = len(video) # else: num_imgs = len(image_paths_or_video) # front_view = num_imgs // 4 # randonly / specificically select the views of output smplx_cam_rotate = 
smpl_params[4: 7] #get global orient # 1, 3, 63, 10 # smpl_params[70:80] = torch.rand_like(smpl_params[70:80]); print("error !! need to delete this rand betas in avatarnet:287") # get betas front_view = find_front_camera_by_rotation(poses, smplx_cam_rotate) # inputs camera poses and smplx poses if self.allow_k_angles_near_the_front > 0: allow_n_views_near_the_front = round(self.allow_k_angles_near_the_front / 360 * num_imgs) new_front_view = random.randint(-allow_n_views_near_the_front, allow_n_views_near_the_front) + front_view if new_front_view >= num_imgs: new_front_view = new_front_view - num_imgs elif new_front_view < 0: new_front_view = new_front_view + num_imgs front_view = new_front_view print("changes the front views ranges", front_view, "+-", allow_n_views_near_the_front) num_train_imgs = num_imgs test_inds = torch.Tensor( list(range(num_imgs))) cond_inds = np.concatenate([np.array([front_view]),test_inds]).astype(np.int64) # first for input, second for supervised test_inds = cond_inds.tolist() # if self.specific_observation_num: # yy note: if specific_observation_num is not None, then remove the test_inds # test_inds = [] # else: # for cond_ind in cond_inds: # test_inds.remove(cond_ind) # cond_inds = cond_inds cond_smpl_param_ref = torch.zeros([189]) if_use_smpl_param_ref = torch.Tensor([1]) # 默认使用ref smpl, if self.load_cond_data and len(cond_inds) > 0: # cond_imgs, cond_poses, cond_intrinsics, cond_img_paths, cond_smpl_param, cond_norm = gather_imgs(cond_inds) cond_imgs, cond_poses, cond_intrinsics, cond_img_paths, cond_smpl_param, cond_norm = \ gather_imgs(cond_inds, poses, image_paths_or_video, smpl_params, load_imgs=self.load_imgs, load_norm=self.load_norm, center=self.center, radius=self.radius, \ input_is_video=input_is_video) cond_smpl_param_ref = cond_smpl_param.clone() # the smpl_param_ref for the reference images if cond_intrinsics.shape[-1] == 3: # the old data format, which contains the value of camera center instead of fxfycxcy cond_intrinsics = self.default_fxy_cxy.clone().repeat(cond_intrinsics.shape[0], 1) if input_is_video: cond_img_paths = [f"{video_path[:-4]}_{i:0>4d}.png" for i in range(self.specific_observation_num)] # Replace the .mp4 into index.png results.update( cond_poses=cond_poses, cond_intrinsics=cond_intrinsics.to(torch.float32), cond_img_paths=cond_img_paths, cond_smpl_param=cond_smpl_param, cond_smpl_param_ref=cond_smpl_param_ref, if_use_smpl_param_ref=if_use_smpl_param_ref) if cond_imgs is not None: results.update(cond_imgs=cond_imgs) if cond_norm is not None: results.update(cond_norm=cond_norm) if self.load_test_data and len(test_inds) > 0: print("input_is_video", input_is_video) test_imgs, test_poses, test_intrinsics, test_img_paths, test_smpl_param, test_norm = \ gather_imgs(test_inds, poses, image_paths_or_video, smpl_params, load_imgs=self.load_imgs, load_norm=self.load_norm, center=self.center, radius=self.radius, \ input_is_video=input_is_video) if test_intrinsics.shape[-1] == 3: # the old data format, which contains the value of camera center instead of fxfycxcy test_intrinsics = self.default_fxy_cxy.clone().repeat(test_intrinsics.shape[0], 1) results.update( test_poses=test_poses, test_intrinsics=test_intrinsics, test_img_paths=test_img_paths, test_smpl_param=test_smpl_param) if test_imgs is not None: results.update(test_imgs=test_imgs) if test_norm is not None: results.update(test_norm=test_norm) if self.test_pose_override is not None: results.update(test_poses=self.test_poses, test_intrinsics=self.test_intrinsics) return results def 
gather_imgs(img_ids, poses, image_paths_or_video, smpl_params, load_imgs=True, load_norm=False, center=None, radius=None, input_is_video=False): imgs_list = [] if load_imgs else None norm_list = [] if load_norm else None poses_list = [] cam_centers_list = [] img_paths_list = [] for img_id in img_ids: pose = poses[img_id][0] cam_centers_list.append((poses[img_id][1]).to(torch.float)) # (C) c2w = pose.to(torch.float)#torch.FloatTensor(pose) # 虽然是c2w但其实存的值应该是w2c (R|T) cam_to_ndc = torch.cat( [c2w[:3, :3], (c2w[:3, 3:] - center[:, None]) / radius[:, None]], dim=-1) poses_list.append( torch.cat([ cam_to_ndc, cam_to_ndc.new_tensor([[0.0, 0.0, 0.0, 1.0]]) ], dim=-2)) if input_is_video: # img_paths_list.append(video[img_id]) img = image_paths_or_video[img_id] # for img in imgs: # add the ajustment to make the color > [250,250,250] to be white mask_white = np.all(img[:,:,:3] > 250, axis=-1, keepdims=False) # Image.fromarray(img).save(f"debug.png") # Image.fromarray(mask_white).save(f"debug_mask.png") img[mask_white] = [255, 255, 255] # Image.fromarray(img).save(f"debug_afmask.png") img = torch.from_numpy(img.astype(np.float32) / 255) # (h, w, 3) imgs_list.append(img) else: img_paths_list.append(image_paths_or_video[img_id]) if load_imgs: # img = mmcv.imread(image_paths[img_id], channel_order='rgb') # 使用cv2.IMREAD_UNCHANGED标志读取图片,以保留alpha通道 print("Loading, .......", image_paths_or_video[img_id]) print("Loading, .......", image_paths_or_video[img_id]) print("Loading, .......", image_paths_or_video[img_id]) img = cv2.imread(image_paths_or_video[img_id], cv2.IMREAD_UNCHANGED) print("img.shape", img.shape) # 将透明像素的RGB值设置为白色(255, 255, 255) img[img[..., 3] == 0] = [255, 255, 255, 255] img = img[..., :3] img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) img = torch.from_numpy(img.astype(np.float32) / 255) # (h, w, 3) imgs_list.append(img) if load_norm: # have not support the input type is video norm = cv2.imread(image_paths_or_video[img_id].replace('rgb', 'norm'), cv2.IMREAD_UNCHANGED) norm = cv2.cvtColor(norm, cv2.COLOR_BGR2RGB) norm = torch.from_numpy(norm.astype(np.float32) / 255) norm_list.append(norm) poses_list = torch.stack(poses_list, dim=0) # (n, 4, 4) cam_centers_list = torch.stack(cam_centers_list, dim=0) if load_imgs: imgs_list = torch.stack(imgs_list, dim=0) # (n, h, w, 3) if load_norm: norm_list = torch.stack(norm_list, dim=0) return imgs_list, poses_list, cam_centers_list, img_paths_list, smpl_params, norm_list from scipy.spatial.transform import Rotation as R def calculate_angle(vector1, vector2): unit_vector1 = vector1 / torch.linalg.norm(vector1) unit_vector2 = vector2 / torch.linalg.norm(vector2) dot_product = torch.dot(unit_vector1, unit_vector2) angle = torch.arccos(dot_product) return angle def axis_angle_to_rotation_matrix(axis_angle): if_input_is_torch = torch.is_tensor(axis_angle) if if_input_is_torch: dtype_torch = axis_angle.dtype axis_angle = axis_angle.numpy() r = R.from_rotvec(axis_angle) rotation_matrix = r.as_matrix() if if_input_is_torch: rotation_matrix = torch.from_numpy(rotation_matrix).to(dtype_torch) return rotation_matrix # def find_front_camera_by_global_orient(global_orient, camera_direction): # front_direction = np.array([0, 0, -1]) # 人体正面方向 # min_angle = float('inf') # front_camera_idx = -1 # for idx, global_orient in enumerate(global_orient_list): # angle = calculate_angle(body_direction, front_direction) # if angle < min_angle: # min_angle = angle # front_camera_idx = idx # return front_camera_idx def find_front_camera_by_rotation(poses, global_orient_human): # 
front_direction = global_orient_human # 人体正面方向 rotation_matrix = axis_angle_to_rotation_matrix(global_orient_human) front_direction = rotation_matrix @ torch.Tensor([0, 0, -1]) # 人体的朝向 min_angle = float('inf') front_camera_idx = -1 for idx, pose in enumerate(poses): rotation_matrix = pose[0][:3, :3] camera_direction = rotation_matrix @ torch.Tensor([0, 0, 1]) # 相机的朝向 angle = calculate_angle(camera_direction, front_direction).to(camera_direction.dtype) if angle < min_angle: min_angle = angle front_camera_idx = idx return front_camera_idx def read_frames(video_path): container = av.open(video_path) video_stream = next(s for s in container.streams if s.type == "video") frames = [] for packet in container.demux(video_stream): for frame in packet.decode(): # image = Image.frombytes( # "RGB", # (frame.width, frame.height), # frame.to_rgb().to_ndarray(), # ) image = frame.to_rgb().to_ndarray() frames.append(image) return frames def prepare_camera( resolution_x, resolution_y, num_views=24, stides=1): # resolution_x = 640 # resolution_y = 896 import math focal_length = 40 #80 sensor_width = 32 # # 创建 Pyrender 相机 # camera = pyrender.PerspectiveCamera(yfov=fov, aspectRatio=aspect_ratio) focal_length = focal_length * (resolution_y/sensor_width) K = np.array( [[focal_length, 0, resolution_x//2], [0, focal_length, resolution_y//2], [0, 0, 1]] ) # print("update!! the camera intrisic is error 0819") def look_at(camera_position, target_position, up_vector): # colmap +z forward, +y down forward = -(camera_position - target_position) / np.linalg.norm(camera_position - target_position) right = np.cross(up_vector, forward) up = np.cross(forward, right) return np.column_stack((right, up, forward)) camera_pose_list = [] for frame_idx in range(0, num_views, stides): # 设置相机的位置和方向 camera_dist = 1.5 #3 #1.2 * 2 phi = math.radians(90) theta = (frame_idx / num_views) * math.pi * 2 camera_location = np.array( [camera_dist * math.sin(phi) * math.cos(theta), camera_dist * math.cos(phi), -camera_dist * math.sin(phi) * math.sin(theta),] ) # print(camera_location) camera_pose = np.eye(4) camera_pose[:3, 3] = camera_location # print("camera_location", camera_location) # from smplx import look_at # 设置相机位置和目标位置 camera_position = camera_location target_position = np.array([0.0, 0.0, 0.0]) # 计算相机的旋转矩阵,使其朝向目标 # up_vector = np.array([0.0, 1.0, 0.0]) up_vector = np.array([0.0, -1.0, 0.0]) # colmap rotation_matrix = look_at(camera_position, target_position, up_vector) # 更新相机的位置和旋转 camera_pose[:3, :3] = rotation_matrix camera_pose[:3, 3] = camera_position camera_pose_list.append(camera_pose) return K, camera_pose_list def from_video_to_get_ref_smplx(video_path): # 分解路径 video_dir = os.path.dirname(video_path) video_name = video_dir.split("/")[-1] # 视频文件夹名称 # 替换视频目录为 smplify 目录 if "flux" in video_dir: smplify_dir = video_dir.replace('/videos/', '/smplx_smplify/') elif "DeepFashion" in video_dir: smplify_dir = video_dir.replace('/video/', '/smplx_smplify/').replace("A_pose_", "") # # 获取视频文件名(不包括扩展名) # video_name = os.path.splitext(video_file)[0] # 构建 JSON 文件路径 if smplify_dir[-1] == '/': smplify_dir = smplify_dir[:-1] json_path = smplify_dir+".json" #os.path.join(smplify_dir, f"{video_name}.json") return json_path def random_scale_and_crop(image: torch.Tensor, scale_range=(0.8, 1.2)) -> torch.Tensor: """ Randomly scale the input image and crop/pad to maintain original size. 
Args: image: Input image tensor of shape [H, W, 3] scale_range: Range for scaling factor, default (0.8, 1.2) Returns: Scaled and cropped/padded image tensor of shape [H, W, 3] """ is_numpy = False if not torch.is_tensor(image): image = torch.from_numpy(image) is_numpy = True # 获取图像的高度和宽度 h, w = image.shape[:2] # 生成随机缩放因子 scale_factor = random.uniform(*scale_range) # 计算新的高度和宽度 new_h = int(h * scale_factor) new_w = int(w * scale_factor) # 使用 torchvision.transforms.functional.resize 进行缩放 scaled_image = F.resize(image.permute(2, 0, 1), [new_h, new_w]).permute(1, 2, 0) # 如果缩放后的图像比原图大,进行居中裁剪 if new_h > h or new_w > w: top = (new_h - h) // 2 left = (new_w - w) // 2 scaled_image = scaled_image[top:top + h, left:left + w] else: # 如果缩放后的图像比原图小,进行居中填充 padded_image = torch.ones((h, w, 3), dtype=image.dtype) top = h-new_h #(h - new_h) // 2 # H不应该居中 left = (w - new_w) // 2 padded_image[top:top + new_h, left:left + new_w] = scaled_image scaled_image = padded_image if is_numpy: scaled_image = scaled_image.numpy() return scaled_image if __name__ == "__main__": import os params = { "data_prefix": None, "cache_path": ListConfig([ "./processed_data/deepfashion_train_145_local.npy", "./processed_data/flux_batch1_5000_train_145_local.npy" ]), "specific_observation_num": 5, "better_range": True, "first_is_front": True, "if_include_video_ref_img": True, "prob_include_video_ref_img": 0.5, "img_res": [640, 896], 'test_mode': True } data = AvatarDataset(**params) sample = data[0] print(sample.keys()) import os import torch.distributed as dist os.environ['RANK'] = '0' os.environ['WORLD_SIZE'] = '1' os.environ['MASTER_ADDR'] = '127.0.0.1' os.environ['MASTER_PORT'] = '29500' dist.init_process_group(backend='nccl', rank=0, world_size=1) # test the batch loader from torch.utils.data import DataLoader dataloader = DataLoader(data, batch_size=10, shuffle=True, collate_fn=custom_collate_fn) from torch.utils.data.distributed import DistributedSampler import webdataset as wds # sampler = DistributedSampler(data) # training is true!~ sampler = None dataloader = wds.WebLoader(data, batch_size=10, num_workers=1, shuffle=False, sampler=sampler, ) try: for i, batch in enumerate(dataloader): print(batch.keys()) # break except Exception as e: import traceback print("Caught an exception during dataloader iteration:") traceback.print_exc() ================================================ FILE: lib/datasets/dataloader.py ================================================ import os, sys import json import numpy as np import webdataset as wds import pytorch_lightning as pl import torch from torch.utils.data import Dataset from torch.utils.data.distributed import DistributedSampler file_dir = os.path.abspath(os.path.dirname(__file__)) project_root = os.path.join(file_dir, '..', '..') sys.path.append(project_root) from lib.utils.train_util import instantiate_from_config from torch.utils.data import DataLoader class DataModuleFromConfig(pl.LightningDataModule): def __init__( self, batch_size=8, num_workers=4, train=None, validation=None, test=None, **kwargs, ): super().__init__() self.batch_size = batch_size self.num_workers = num_workers self.dataset_configs = dict() if train is not None: self.dataset_configs['train'] = train if validation is not None: self.dataset_configs['validation'] = validation if test is not None: self.dataset_configs['test'] = test def setup(self, stage): self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs) def train_dataloader(self): sampler = 
DistributedSampler(self.datasets['train']) if torch.distributed.is_initialized() else None return DataLoader( self.datasets['train'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=(sampler is None), sampler=sampler, pin_memory=True, drop_last=True, ) def val_dataloader(self): sampler = DistributedSampler(self.datasets['validation']) if torch.distributed.is_initialized() else None return DataLoader( self.datasets['validation'], batch_size=1, num_workers=self.num_workers, shuffle=False, sampler=sampler, pin_memory=True ) def test_dataloader(self): return DataLoader( self.datasets['test'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, pin_memory=True ) ================================================ FILE: lib/humanlrm_wrapper_sa_v1.py ================================================ import os import math import json from torch.optim import Adam from torch.nn.parallel.distributed import DistributedDataParallel import torch import torch.nn.functional as F from torchvision.transforms import InterpolationMode from torchvision.utils import make_grid, save_image from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure import pytorch_lightning as pl from pytorch_lightning.utilities.grads import grad_norm from einops import rearrange, repeat from lib.utils.train_util import instantiate_from_config from lib.ops.activation import TruncExp import time import matplotlib.pyplot as plt from PIL import Image import numpy as np from lib.utils.train_util import main_print from typing import List, Optional, Tuple, Union def get_1d_rotary_pos_embed( dim: int, pos: Union[torch.Tensor, int], theta: float = 10000.0, use_real=False, linear_factor=1.0, ntk_factor=1.0, repeat_interleave_real=True, ): """ Precompute the frequency tensor for complex exponentials (cis) with given dimensions. This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64 data type. Args: dim (`int`): Dimension of the frequency tensor. pos (`torch.Tensor` or `int`): Position indices for the frequency tensor. [S] or scalar theta (`float`, *optional*, defaults to 10000.0): Scaling factor for frequency computation. Defaults to 10000.0. use_real (`bool`, *optional*): If True, return real part and imaginary part separately. Otherwise, return complex numbers. linear_factor (`float`, *optional*, defaults to 1.0): Scaling factor for the context extrapolation. Defaults to 1.0. ntk_factor (`float`, *optional*, defaults to 1.0): Scaling factor for the NTK-Aware RoPE. Defaults to 1.0. repeat_interleave_real (`bool`, *optional*, defaults to `True`): If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`. Otherwise, they are concateanted with themselves. Returns: `torch.Tensor`: Precomputed frequency tensor with complex exponentials. 
[S, D/2] """ assert dim % 2 == 0 if isinstance(pos, int): pos = torch.arange(pos) theta = theta * ntk_factor freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) / linear_factor # [D/2] t = pos # torch.from_numpy(pos).to(freqs.device) # type: ignore # [S] freqs = freqs.to(device=t.device, dtype=t.dtype) # type: ignore freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2] if use_real and repeat_interleave_real: freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D] freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] return freqs_cos, freqs_sin elif use_real: freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1) # [S, D] freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1) # [S, D] return freqs_cos, freqs_sin else: freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] return freqs_cis class FluxPosEmbed(torch.nn.Module): # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 def __init__(self, theta: int, axes_dim: [int]): super().__init__() self.theta = theta self.axes_dim = axes_dim def forward(self, ids: torch.Tensor) -> torch.Tensor: n_axes = ids.shape[-1] cos_out = [] sin_out = [] pos = ids.float() is_mps = ids.device.type == "mps" freqs_dtype = torch.float32 if is_mps else torch.float64 for i in range(n_axes): cos, sin = get_1d_rotary_pos_embed( self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True#, freqs_dtype=freqs_dtype ) cos_out.append(cos) sin_out.append(sin) freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) return freqs_cos, freqs_sin class SapiensGS_SA_v1(pl.LightningModule): def __init__( self, encoder=dict(type='mmpretrain.VisionTransformer'), neck=dict(type='mmpretrain.VisionTransformer'), decoder=dict(), diffusion_use_ema=True, freeze_decoder=False, image_cond=False, code_permute=None, code_reshape=None, autocast_dtype=None, ortho=True, return_norm=False, # reshape_type='reshape', # 'cnn' code_size=None, decoder_use_ema=None, bg_color=1, training_mode=None, # stage2's flag, default None for stage1 patch_size: int = 4, warmup_steps: int = 12_000, use_checkpoint: bool = True, lambda_depth_tv: float = 0.05, lambda_lpips: float = 2.0, lambda_mse: float = 1.0, lambda_l1: float=0, lambda_ssim: float=0, neck_learning_rate: float = 5e-4, decoder_learning_rate: float = 1e-3, encoder_learning_rate: float=0, max_steps: int = 100_000, loss_coef: float = 0.5, init_iter: int = 500, lambda_offset: int = 50, # offset_weight: 50 scale_weight: float = 0.01, is_debug: bool = False, # if debug, then it will not returns lpips code_activation: dict=None, output_hidden_states: bool=False, # if True, will output the hidden states from sapiens shallow layer, for the neck decoder loss_weights_views: List = [], # the loss weights for the views, if empty, will use the same weights for all the views **kwargs ): super(SapiensGS_SA_v1, self).__init__() ## ========== part -- Add the code to save this parameters for optimizers ======== self.warmup_steps = warmup_steps self.use_checkpoint = use_checkpoint self.lambda_depth_tv = lambda_depth_tv self.lambda_lpips = lambda_lpips self.lambda_mse = lambda_mse self.lambda_l1 = lambda_l1 self.lambda_ssim = lambda_ssim self.neck_learning_rate = neck_learning_rate self.decoder_learning_rate = decoder_learning_rate self.encoder_learning_rate = encoder_learning_rate self.max_steps = max_steps self.loss_coef = loss_coef 
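# --- Editor's note (illustrative sketch, not part of the original module): a standalone
# --- shape check for the RoPE helper get_1d_rotary_pos_embed defined above. The dim and
# --- position values are arbitrary examples; float positions are used here because
# --- FluxPosEmbed.forward passes ids.float().
import torch

_cos, _sin = get_1d_rotary_pos_embed(dim=16, pos=torch.arange(8, dtype=torch.float32), use_real=True)
assert _cos.shape == _sin.shape == (8, 16)   # [S, D]: each frequency interleaved with itself
_freqs_cis = get_1d_rotary_pos_embed(dim=16, pos=8, use_real=False)
assert _freqs_cis.shape == (8, 8)            # [S, D/2], complex64 built via torch.polar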
self.init_iter = init_iter self.lambda_offset = lambda_offset self.scale_weight = scale_weight self.is_debug = is_debug ## ========== end part ======== self.code_size = code_size if code_activation['type'] == 'tanh': self.code_activation = torch.nn.Tanh() else: self.code_activation = TruncExp() #build_module(code_activation) # self.grid_size = grid_size self.decoder = instantiate_from_config(decoder) self.decoder_use_ema = decoder_use_ema if decoder_use_ema: raise NotImplementedError("decoder_use_ema has not been implemented") if self.decoder_use_ema: self.decoder_ema = deepcopy(self.decoder) self.encoder = instantiate_from_config(encoder) # get_obj_from_str(config["target"]) self.code_size = code_reshape self.code_clip_range = [-1,1] # ============= begin config ============= # transformer from class MAEPretrainDecoder(BaseModule): # compress the token number of the uv code self.patch_size = patch_size self.code_patch_size = self.patch_size self.num_patches_axis = code_reshape[-1]//self.patch_size # reshape it for the upsampling self.num_patches = self.num_patches_axis ** 2 self.code_feat_dims = code_reshape[0] # only used for the upsampling of 'reshape' type self.code_resolution = code_reshape[-1] # only used for the upsampling of 'reshape' type self.reshape_type = self.decoder.reshape_type self.inputs_front_only = True self.render_loss_all_view = True self.if_include_video_ref_img = True self.training_mode = training_mode self.loss_weights_views = torch.Tensor(loss_weights_views).reshape(-1) / sum(loss_weights_views) # normalize the weights # ========== config meaning =========== self.neck = instantiate_from_config(neck) self.ids_restore = torch.arange(0, self.num_patches).unsqueeze(0) self.freeze_decoder = freeze_decoder if self.freeze_decoder: self.decoder.requires_grad_(False) if self.decoder_use_ema: self.decoder_ema.requires_grad_(False) self.image_cond = image_cond self.code_permute = code_permute self.code_reshape = code_reshape self.code_reshape_inv = [self.code_size[axis] for axis in self.code_permute] if code_permute is not None \ else self.code_size self.code_permute_inv = [self.code_permute.index(axis) for axis in range(len(self.code_permute))] \ if code_permute is not None else None self.autocast_dtype = autocast_dtype self.ortho = ortho self.return_norm = return_norm '''add a flag for the skip connection from sapiens shallow layer''' self.output_hidden_states = output_hidden_states ''' add the in-the-wild images visualization''' if self.lambda_lpips > 0: self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg') else: self.lpips = None self.ssim = StructuralSimilarityIndexMeasure() self.validation_step_outputs = [] self.validation_step_code_outputs = [] # saving the code self.validation_step_nvPose_outputs = [] self.validation_metrics = [] # loading the smplx for the nv pose import json import numpy as np # evaluate the animation smplx_path = './work_dirs/demo_data/Ways_to_Catch_360_clip1.json' with open(smplx_path, 'r') as f: smplx_pose_param = json.load(f) smplx_param_list = [] for par in smplx_pose_param['annotations']: k = par['smplx_params'] for i in k.keys(): k[i] = np.array(k[i]) left_hands = np.array([1.4624, -0.1615, 0.1361, 1.3851, -0.2597, 0.0247, -0.0683, -0.4478, -0.6652, -0.7290, 0.0084, -0.4818]) betas = torch.zeros((10)) smplx_param = \ np.concatenate([np.array([1]), np.array([0,0.,0]), np.array([0, -1, 0])*k['root_orient'], \ k['pose_body'],betas, \ k['pose_hand'], k['pose_jaw'], np.zeros(6), k['face_expr'][:10]], axis=0).reshape(1,-1) # 
print(smplx_param.shape) smplx_param_list.append(smplx_param) smplx_params = np.concatenate(smplx_param_list, 0) self.smplx_params = torch.Tensor(smplx_params).cuda() def get_default_smplx_params(self): A_pose = torch.Tensor([[ 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1047, 0.0000, 0.0000, -0.1047, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, -0.7854, 0.0000, 0.0000, 0.7854, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, -0.7470, 1.0966, 0.0169, -0.0534, -0.0212, 0.0782, -0.0348, 0.0260, 0.0060, 0.0118, -0.1117, -0.0429, 0.4164, -0.1088, 0.0660, 0.7562, 0.0964, 0.0909, 0.1885, 0.1181, -0.0509, 0.5296, 0.1437, -0.0552, 0.7049, 0.0192, 0.0923, 0.3379, 0.4570, 0.1963, 0.6255, 0.2147, 0.0660, 0.5069, 0.3697, 0.0603, 0.0795, 0.1419, 0.0859, 0.6355, 0.3033, 0.0579, 0.6314, 0.1761, 0.1321, 0.3734, -0.8510, -0.2769, 0.0915, 0.4998, -0.0266, -0.0529, -0.5356, -0.0460, 0.2774, -0.1117, 0.0429, -0.4164, -0.1088, -0.0660, -0.7562, 0.0964, -0.0909, -0.1885, 0.1181, 0.0509, -0.5296, 0.1437, 0.0552, -0.7049, 0.0192, -0.0923, -0.3379, 0.4570, -0.1963, -0.6255, 0.2147, -0.0660, -0.5069, 0.3697, -0.0603, -0.0795, 0.1419, -0.0859, -0.6355, 0.3033, -0.0579, -0.6314, 0.1761, -0.1321, -0.3734, -0.8510, 0.2769, -0.0915, 0.4998, 0.0266, 0.0529, -0.5356, 0.0460, -0.2774, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]) return A_pose def forward_decoder(self, decoder, code, target_rgbs, cameras, smpl_params=None, return_decoder_loss=False, init=False): decoder = self.decoder_ema if self.freeze_decoder and self.decoder_use_ema else self.decoder num_imgs = target_rgbs.shape[1] outputs = decoder( code, smpl_params, cameras, num_imgs, return_loss=return_decoder_loss, init=init, return_norm=False) return outputs def on_fit_start(self): if self.global_rank == 0: os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True) os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True) os.makedirs(os.path.join(self.logdir, 'images_val_code'), exist_ok=True) def forward(self, data): # print("iter") num_scenes = len(data['scene_id']) # 8 if 'cond_imgs' in data: cond_imgs = data['cond_imgs'] # (num_scenes, num_imgs, h, w, 3) cond_intrinsics = data['cond_intrinsics'] # (num_scenes, num_imgs, 4), in [fx, fy, cx, cy] cond_poses = data['cond_poses'] # (num_scenes, num_imgs, 4, 4) smpl_params = data['cond_smpl_param'] # (num_scenes, c) # if 'cond_norm' in data:cond_norm = data['cond_norm'] else: cond_norm = None num_scenes, num_imgs, h, w, _ = cond_imgs.size() cameras = torch.cat([cond_intrinsics, cond_poses.reshape(num_scenes, num_imgs, -1)], dim=-1) if self.if_include_video_ref_img: # new! 
we render to all view, including the input image; # don't compute this loss cameras = cameras[:,1:] num_imgs = num_imgs - 1 if self.inputs_front_only: # default setting to use the first image as input inputs_img_idx = [0] else: raise NotImplementedError("inputs_front_only is False") inputs_img = cond_imgs[:,inputs_img_idx[0],...].permute([0,3,1,2]) # target_imgs = cond_imgs[:, 1:] assert cameras.shape[1] == target_imgs.shape[1] if self.is_debug: try: code = self.forward_image_to_uv(inputs_img, is_training=self.training) #TODO check where the validation except Exception as e: # OOM main_print(e) code = torch.zeros([num_scenes, 32, 256, 256]).to(inputs_img.dtype).to(inputs_img.device) else: code = self.forward_image_to_uv(inputs_img, is_training=self.training) #TODO check where the validation decoder = self.decoder_ema if self.freeze_decoder and self.decoder_use_ema else self.decoder # uvmaps_decoder_gender's forward output = decoder( code, smpl_params, cameras, num_imgs, return_loss=False, init=(self.global_step < self.init_iter), return_norm=False) #(['scales', 'norm', 'image', 'offset']) output['code'] = code output['target_imgs'] = target_imgs output['inputs_img'] = cond_imgs[:,[0],...] # for visualization if self.global_rank == 0 and self.global_step % 200 == 0 and self.is_debug: overlay_imgs = 0.5 * target_imgs + 0.5 * output['image'] overlay_imgs = rearrange(overlay_imgs, 'b n h w c -> b h n w c') overlay_imgs = rearrange(overlay_imgs, ' b h n w c -> (b h) (n w) c') overlay_imgs = overlay_imgs.to(torch.float32).detach().cpu().numpy() overlay_imgs = (overlay_imgs * 255).astype(np.uint8) Image.fromarray(overlay_imgs).save(f'debug_{self.global_step}.jpg') return output def forward_image_to_uv(self, inputs_img, is_training=True): ''' inputs_img: torch.Tensor, bs, 3, H, W return code : bs, 256, 256, 32 ''' if self.decoder_learning_rate <= 0: with torch.no_grad(): features_flatten = self.encoder(inputs_img, use_my_proces=True, output_hidden_states=self.output_hidden_states) else: features_flatten = self.encoder(inputs_img, use_my_proces=True, output_hidden_states=self.output_hidden_states) if self.ids_restore.device !=features_flatten.device: self.ids_restore = self.ids_restore.to(features_flatten.device) ids_restore = self.ids_restore.expand([features_flatten.shape[0], -1]) uv_code = self.neck(features_flatten, ids_restore) batch_size, token_num, dims_feature = uv_code.shape if self.reshape_type=='reshape': feature_map = uv_code.reshape(batch_size, self.num_patches_axis, self.num_patches_axis,\ self.code_feat_dims, self.code_patch_size, self.code_patch_size) # torch.Size([1, 64, 64, 32, 4, 4, ]) feature_map = feature_map.permute(0, 3, 1, 4, 2, 5) # ([1, 32, 64, 4, 64, 4]) feature_map = feature_map.reshape(batch_size, self.code_feat_dims, self.code_resolution, self.code_resolution) # torch.Size([1, 32, 256, 256]) code = feature_map # [1, 32, 256, 256] else: feature_map = uv_code.reshape(batch_size, self.num_patches_axis, self.num_patches_axis,dims_feature) # torch.Size([1, 64, 64, 512, ]) if isinstance(self.decoder, DistributedDataParallel): code = self.decoder.module.upsample_conv(feature_map.permute([0,3,1,2])) # torch.Size([1, 32, 256, 256]) else: code = self.decoder.upsample_conv(feature_map.permute([0,3,1,2])) # torch.Size([1, 32, 256, 256]) code = self.code_activation(code) return code def compute_loss(self, render_out): render_images = render_out['image'] # .Size([1, 5, 896, 640, 3]), range [0, 1] target_images = render_out['target_imgs'] target_images 
=target_images.to(render_images) if self.is_debug: render_images_tmp= rearrange(render_images, 'b n h w c -> (b n) c h w') target_images_tmp = rearrange(target_images, 'b n h w c -> (b n) c h w') all_images = torch.cat([render_images_tmp, target_images_tmp], dim=2) all_images = render_images_tmp*0.5 + target_images_tmp*0.5 grid = make_grid(all_images, nrow=4, normalize=True, value_range=(0, 1)) save_image(grid, "./debug.png") main_print("saving into ./debug.png") render_images = rearrange(render_images, 'b n h w c -> (b n) c h w') * 2.0 - 1.0 target_images = rearrange(target_images, 'b n h w c -> (b n) c h w') * 2.0 - 1.0 if self.lambda_mse<=0: loss_mse = 0 else: if self.loss_weights_views.numel() != 0: b, n, _, _, _ = render_out['image'].shape loss_weights_views = self.loss_weights_views.unsqueeze(0).to(render_images.device) loss_weights_views = loss_weights_views.repeat(b,1).reshape(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) loss_mse = weighted_mse_loss(render_images, target_images, loss_weights_views) main_print("weighted sum mse") else: loss_mse = F.mse_loss(render_images, target_images) if self.lambda_l1<=0: loss_l1 = 0 else: loss_l1 = F.l1_loss(render_images, target_images) if self.lambda_ssim <= 0: loss_ssim = 0 else: loss_ssim = 1 - self.ssim(render_images, target_images) if not self.is_debug: if self.lambda_lpips<=0: loss_lpips = 0 else: if self.loss_weights_views.numel() != 0: with torch.cuda.amp.autocast(): loss_lpips = self.lpips(render_images.clamp(-1, 1), target_images) else: loss_lpips = 0 with torch.cuda.amp.autocast(): for img_idx in range(render_images.shape[0]): loss_lpips += self.lpips(render_images[[img_idx]].clamp(-1, 1), target_images[[img_idx]]) loss_lpips /= render_images.shape[0] else: loss_lpips = 0 loss_gs_offset = render_out['offset'] loss = loss_mse * self.lambda_mse \ + loss_l1 * self.lambda_l1 \ + loss_ssim * self.lambda_ssim \ + loss_lpips * self.lambda_lpips \ + loss_gs_offset * self.lambda_offset prefix = 'train' loss_dict = {} loss_dict.update({f'{prefix}/loss_mse': loss_mse}) loss_dict.update({f'{prefix}/loss_lpips': loss_lpips}) loss_dict.update({f'{prefix}/loss_gs_offset': loss_gs_offset}) loss_dict.update({f'{prefix}/loss_ssim': loss_ssim}) loss_dict.update({f'{prefix}/loss_l1': loss_l1}) loss_dict.update({f'{prefix}/loss': loss}) return loss, loss_dict def compute_metrics(self, render_out): # NOTE: all the rgb value range is [0, 1] # render_out.keys = (['scales', 'norm', 'image', 'offset', 'code', 'target_imgs']) render_images = render_out['image'].clamp(0, 1) # .Size([1, 5, 896, 640, 3]), range [0, 1] target_images = render_out['target_imgs'] if target_images.dtype!=render_images.dtype: target_images = target_images.to(render_images.dtype) render_images = rearrange(render_images, 'b n h w c -> (b n) c h w') target_images = rearrange(target_images, 'b n h w c -> (b n) c h w').to(render_images) mse = F.mse_loss(render_images, target_images).mean() psnr = 10 * torch.log10(1.0 / mse) ssim = self.ssim(render_images, target_images) render_images = render_images * 2.0 - 1.0 target_images = target_images * 2.0 - 1.0 if self.lambda_lpips<=0: lpips = torch.Tensor([0]).to(render_images.device).to(render_images.dtype) else: with torch.cuda.amp.autocast(): lpips = self.lpips(render_images, target_images) metrics = { 'val/mse': mse, 'val/pnsr': psnr, 'val/ssim': ssim, 'val/lpips': lpips, } return metrics def new_on_before_optimizer_step(self): norms = grad_norm(self.neck, norm_type=2) if 'grad_2.0_norm_total' in norms: 
self.log_dict({'grad_norm/lrm_generator': norms['grad_2.0_norm_total']}) @torch.no_grad() def validation_step(self, batch, batch_idx): render_out = self.forward(batch) metrics = self.compute_metrics(render_out) self.validation_metrics.append(metrics) render_images = render_out['image'] render_images = rearrange(render_images, 'b n h w c -> b c h (n w)') gt_images = render_out['target_imgs'] gt_images = rearrange(gt_images, 'b n h w c-> b c h (n w)') log_images = torch.cat([render_images, gt_images], dim=-2) self.validation_step_outputs.append(log_images) self.validation_step_code_outputs.append( render_out['code']) render_out_comb = self.forward_nvPose(batch, smplx_given=None) self.validation_step_nvPose_outputs.append(render_out_comb) def forward_nvPose(self, batch, smplx_given): ''' smplx_given: torch.Tensor, bs, 189 it will returns images with cameras_num * poses_num ''' _, num_img, _,_ = batch['cond_poses'].shape # write a code to seperately input the smplx_params if smplx_given == None: step_pose = self.smplx_params.shape[0] // num_img smplx_given = self.smplx_params else: step_pose = 1 render_out_list = [] for i in range(num_img): target_pose = smplx_given[[i*step_pose]] bk = batch['cond_smpl_param'].clone() batch['cond_smpl_param'][:, 7:70] = target_pose[:, 7:70] # copy body_pose batch['cond_smpl_param'][:, 80:80+93] = target_pose[:, 80:80+93]# copy pose_hand + pose_jaw batch['cond_smpl_param'][:, 179:189] = target_pose[:, 179:189]# copy face expression render_out_new = self.forward(batch) render_out_list.append(render_out_new['image']) render_out_comb = torch.cat(render_out_list, dim=2) # stack in the H axis render_out_comb = rearrange(render_out_comb, 'b n h w c -> b c h (n w)') return render_out_comb def on_validation_epoch_end(self): # images = torch.cat(self.validation_step_outputs, dim=-1) all_images = self.all_gather(images).cpu() all_images = rearrange(all_images, 'r b c h w -> (r b) c h w') # nv pose images_pose = torch.cat(self.validation_step_nvPose_outputs, dim=-1) all_images_pose = self.all_gather(images_pose).cpu() all_images_pose = rearrange(all_images_pose, 'r b c h w -> (r b) c h w') if self.global_rank == 0: image_path = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png') grid = make_grid(all_images, nrow=1, normalize=True, value_range=(0, 1)) save_image(grid, image_path) main_print(f"Saved image to {image_path}") metrics = {} for key in self.validation_metrics[0].keys(): metrics[key] = torch.stack([m[key] for m in self.validation_metrics]).mean() self.log_dict(metrics, prog_bar=True, logger=True, on_step=False, on_epoch=True) # code for saving the nvPose images image_path_nvPose = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}_nvPose.png') grid_nvPose = make_grid(all_images_pose, nrow=1, normalize=True, value_range=(0, 1)) save_image(grid_nvPose, image_path_nvPose) main_print(f"Saved image to {image_path_nvPose}") # code for saving the code images for i, code in enumerate(self.validation_step_code_outputs): image_path = os.path.join(self.logdir, 'images_val_code') num_scenes, num_chn, h, w = code.size() code_viz = code.reshape(num_scenes, 4, 8, h, w).to(torch.float32).cpu().numpy() code_viz = code_viz.transpose(0, 1, 3, 2, 4).reshape(num_scenes, 4 * h, 8 * w) for j, code_viz_single in enumerate(code_viz): plt.imsave(os.path.join(image_path, f'val_{self.global_step:07d}_{i*num_scenes+j:04d}' + '.png'), code_viz_single, vmin=self.code_clip_range[0], vmax=self.code_clip_range[1]) self.validation_step_outputs.clear() 
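# --- Editor's note (illustrative sketch, not part of the original class): the 189-dim
# --- SMPL-X parameter vector used throughout this file appears to follow the layout
# --- below, inferred from the concatenation built in the dataset code and the slices
# --- copied in forward_nvPose. The helper name is hypothetical.
import torch

def split_smplx_param_189(param: torch.Tensor) -> dict:
    """Split a [..., 189] parameter vector into named SMPL-X components (assumed layout)."""
    assert param.shape[-1] == 189
    return {
        'scale':         param[..., 0:1],
        'transl':        param[..., 1:4],
        'global_orient': param[..., 4:7],     # axis-angle root orientation
        'body_pose':     param[..., 7:70],    # 21 joints x 3
        'betas':         param[..., 70:80],   # 10 shape coefficients
        'left_hand':     param[..., 80:125],  # 15 joints x 3
        'right_hand':    param[..., 125:170], # 15 joints x 3
        'jaw_pose':      param[..., 170:173],
        'leye_pose':     param[..., 173:176],
        'reye_pose':     param[..., 176:179],
        'expression':    param[..., 179:189], # 10 expression coefficients
    }

# forward_nvPose above re-poses an identity by overwriting body_pose (7:70),
# hands + jaw (80:173) and expression (179:189) from a driving pose, while keeping
# scale, translation, global orientation, betas and eye poses from the source subject.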
self.validation_step_nvPose_outputs.clear() self.validation_metrics.clear() self.validation_step_code_outputs.clear() def on_test_start(self): if self.global_rank == 0: os.makedirs(os.path.join(self.logdir, 'images_test'), exist_ok=True) def on_test_epoch_end(self): metrics = {} metrics_mean = {} metrics_var = {} for key in self.validation_metrics[0].keys(): tmp = torch.stack([m[key] for m in self.validation_metrics]).cpu().numpy() metrics_mean[key] = tmp.mean() metrics_var[key] = tmp.var() formatted_metrics = {} for key in metrics_mean.keys(): formatted_metrics[key] = f"{metrics_mean[key]:.4f}±{metrics_var[key]:.4f}" for key in self.validation_metrics[0].keys(): metrics[key] = torch.stack([m[key] for m in self.validation_metrics]).cpu().numpy().tolist() final_dict = {"average": formatted_metrics, 'details': metrics} metric_path = os.path.join(self.logdir, f'metrics.json') with open(metric_path, 'w') as f: json.dump(final_dict, f, indent=4) main_print(f"Saved metrics to {metric_path}") for key in metrics.keys(): metrics[key] = torch.stack([m[key] for m in self.validation_metrics]).mean() main_print(metrics) self.validation_metrics.clear() def configure_optimizers(self): # define the optimizer and the scheduler for neck and decoder main_print("WARNING currently, we only support the single optimizer for both neck and decoder") learning_rate = self.neck_learning_rate params= [ {'params': self.neck.parameters(), 'lr': self.neck_learning_rate, }, {'params': self.decoder.parameters(), 'lr': self.decoder_learning_rate}, ] if hasattr(self, "encoder_learning_rate") and self.encoder_learning_rate>0: params.append({'params': self.encoder.parameters(), 'lr': self.encoder_learning_rate}) main_print("============add the encoder into the optimizer============") optimizer = torch.optim.Adam( params ) T_warmup, T_max, eta_min = self.warmup_steps, self.max_steps, 0.001 lr_lambda = lambda step: \ eta_min + (1 - math.cos(math.pi * step / T_warmup)) * (1 - eta_min) * 0.5 if step < T_warmup else \ eta_min + (1 + math.cos(math.pi * (step - T_warmup) / (T_max - T_warmup))) * (1 - eta_min) * 0.5 scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) return {'optimizer': optimizer, 'lr_scheduler': scheduler} def training_step(self, batch, batch_idx): scheduler = self.lr_schedulers() scheduler.step() render_gt = None #? 
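# --- Editor's note (illustrative sketch, not part of training_step): a standalone
# --- re-statement of the LambdaLR multiplier built in configure_optimizers above.
# --- It ramps from eta_min to 1.0 over the warmup steps with a half-cosine, then
# --- decays back to eta_min by max_steps. Defaults mirror warmup_steps=12_000 and
# --- max_steps=100_000 from this class; the function name is hypothetical.
import math

def lr_multiplier(step: int, warmup: int = 12_000, total: int = 100_000, eta_min: float = 0.001) -> float:
    """Learning-rate multiplier: half-cosine warmup followed by cosine decay."""
    if step < warmup:
        return eta_min + (1 - math.cos(math.pi * step / warmup)) * (1 - eta_min) * 0.5
    return eta_min + (1 + math.cos(math.pi * (step - warmup) / (total - warmup))) * (1 - eta_min) * 0.5

# e.g. lr_multiplier(0) -> eta_min, lr_multiplier(12_000) -> 1.0, lr_multiplier(100_000) -> eta_min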
render_out = self.forward(batch) loss, loss_dict = self.compute_loss(render_out) self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True) if self.global_step % 200 == 0 and self.global_rank == 0: self.new_on_before_optimizer_step() # log the norm if self.global_step % 200 == 0 and self.global_rank == 0: if self.if_include_video_ref_img and self.training: render_images = torch.cat([ torch.ones_like(render_out['image'][:,0:1]), render_out['image']], dim=1) target_images = torch.cat([ render_out['inputs_img'], render_out['target_imgs']], dim=1) target_images = rearrange( target_images, 'b n h w c -> b c h (n w)') render_images = rearrange( render_images, 'b n h w c-> b c h (n w)') grid = torch.cat([ target_images, render_images, 0.5*render_images + 0.5*target_images, ], dim=-2) grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1)) image_path = os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.jpg') save_image(grid, image_path) main_print(f"Saved image to {image_path}") return loss @torch.no_grad() def test_step(self, batch, batch_idx): # input_dict, render_gt = self.prepare_validation_batch_data(batch) render_out = self.forward(batch) render_gt = render_out['target_imgs'] render_img = render_out['image'] # Compute metrics metrics = self.compute_metrics(render_out) self.validation_metrics.append(metrics) # Save images target_images = rearrange( render_gt, 'b n h w c -> b c h (n w)') render_images = rearrange( render_img, 'b n h w c -> b c h (n w)') grid = torch.cat([ target_images, render_images, ], dim=-2) grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1)) # self.logger.log_image('train/render', [grid], step=self.global_step) image_path = os.path.join(self.logdir, 'images_test', f'{batch_idx:07d}.png') save_image(grid, image_path) # code visualize code = render_out['code'] self.decoder.visualize(code, batch['scene_name'], os.path.dirname(image_path), code_range=self.code_clip_range) print(f"Saved image to {image_path}") def on_test_start(self): if self.global_rank == 0: os.makedirs(os.path.join(self.logdir, 'images_test'), exist_ok=True) def on_test_epoch_end(self): metrics = {} metrics_mean = {} metrics_var = {} for key in self.validation_metrics[0].keys(): tmp = torch.stack([m[key] for m in self.validation_metrics]).cpu().numpy() metrics_mean[key] = tmp.mean() metrics_var[key] = tmp.var() # trans format into "mean±var" formatted_metrics = {} for key in metrics_mean.keys(): formatted_metrics[key] = f"{metrics_mean[key]:.4f}±{metrics_var[key]:.4f}" for key in self.validation_metrics[0].keys(): metrics[key] = torch.stack([m[key] for m in self.validation_metrics]).cpu().numpy().tolist() # saving into a dictionary final_dict = {"average": formatted_metrics, 'details': metrics} metric_path = os.path.join(self.logdir, f'metrics.json') with open(metric_path, 'w') as f: json.dump(final_dict, f, indent=4) print(f"Saved metrics to {metric_path}") for key in metrics.keys(): metrics[key] = torch.stack([m[key] for m in self.validation_metrics]).mean() print(metrics) self.validation_metrics.clear() def weighted_mse_loss(render_images, target_images, weights): squared_diff = (render_images - target_images) ** 2 main_print(squared_diff.shape, weights.shape) weighted_squared_diff = squared_diff * weights loss_mse_weighted = weighted_squared_diff.mean() return loss_mse_weighted ================================================ FILE: lib/mmutils/__init__.py 
================================================ from .initialize import xavier_init, constant_init ================================================ FILE: lib/mmutils/initialize.py ================================================ import torch.nn as nn def constant_init(module: nn.Module, val: float, bias: float = 0) -> None: if hasattr(module, 'weight') and module.weight is not None: nn.init.constant_(module.weight, val) if hasattr(module, 'bias') and module.bias is not None: nn.init.constant_(module.bias, bias) def xavier_init(module: nn.Module, gain: float = 1, bias: float = 0, distribution: str = 'normal') -> None: assert distribution in ['uniform', 'normal'] if hasattr(module, 'weight') and module.weight is not None: if distribution == 'uniform': nn.init.xavier_uniform_(module.weight, gain=gain) else: nn.init.xavier_normal_(module.weight, gain=gain) if hasattr(module, 'bias') and module.bias is not None: nn.init.constant_(module.bias, bias) ================================================ FILE: lib/models/__init__.py ================================================ from .decoders import * ================================================ FILE: lib/models/decoders/__init__.py ================================================ from .uvmaps_decoder_gender import UVNDecoder_gender __all__ = [ 'UVNDecoder_gender'] ================================================ FILE: lib/models/decoders/uvmaps_decoder_gender.py ================================================ import os import matplotlib.pyplot as plt import torch import torch.nn as nn import torch.nn.functional as F from torch import einsum from pytorch3d import ops from lib.mmutils import xavier_init, constant_init import numpy as np import time import cv2 import math from simple_knn._C import distCUDA2 from pytorch3d.transforms import quaternion_to_matrix from ..deformers import SMPLXDeformer_gender from ..renderers import GRenderer, get_covariance, batch_rodrigues from lib.ops import TruncExp import torchvision from lib.utils.train_util import main_print def ensure_dtype(input_tensor, target_dtype=torch.float32): """ Ensure tensor dtype matches target dtype. If not, convert it. 
""" if input_tensor.dtype != target_dtype: input_tensor = input_tensor.to(dtype=target_dtype) return input_tensor class UVNDecoder_gender(nn.Module): activation_dict = { 'relu': nn.ReLU, 'silu': nn.SiLU, 'softplus': nn.Softplus, 'trunc_exp': TruncExp, 'sigmoid': nn.Sigmoid} def __init__(self, *args, interp_mode='bilinear', base_layers=[3 * 32, 128], density_layers=[128, 1], color_layers=[128, 128, 3], offset_layers=[128, 3], scale_layers=[128, 3], radius_layers=[128, 3], use_dir_enc=True, dir_layers=None, scene_base_size=None, scene_rand_dims=(0, 1), activation='silu', sigma_activation='sigmoid', sigmoid_saturation=0.001, code_dropout=0.0, flip_z=False, extend_z=False, gender='neutral', multires=0, bg_color=0, image_size=1024, superres=False, focal = 1280, # the default focal defination reshape_type=None, # if true, it will create a cnn layers to upsample the uv features fix_sigma=False, # if true, the density of GS will be fixed up_cnn_in_channels = None, # the channel number of the upsample cnn vithead_param=None, # the vit head for decode to uv features is_sub2=False, # if true, will use the sub2 uv map **kwargs): super().__init__() self.interp_mode = interp_mode self.in_chn = base_layers[0] self.use_dir_enc = use_dir_enc if scene_base_size is None: self.scene_base = None else: rand_size = [1 for _ in scene_base_size] for dim in scene_rand_dims: rand_size[dim] = scene_base_size[dim] init_base = torch.randn(rand_size).expand(scene_base_size).clone() self.scene_base = nn.Parameter(init_base) self.dir_encoder = None self.sigmoid_saturation = sigmoid_saturation self.deformer = SMPLXDeformer_gender(gender, is_sub2=is_sub2) self.renderer = GRenderer(image_size=image_size, bg_color=bg_color, f=focal) if superres: self.superres = None else: self.superres = None self.gender= gender self.reshape_type = reshape_type if reshape_type=='cnn': self.upsample_conv = torch.nn.ConvTranspose2d(512, 32, kernel_size=4, stride=4,).cuda() elif reshape_type == 'VitHead': # changes the up block's layernorm into the feature channel norm instead of the full image norm from lib.models.decoders.vit_head import VitHead self.upsample_conv = VitHead(**vithead_param) # 256, 128, 128 -> 128, 256, 256 -> 64, 512, 512, ->32, 1024, 1024 base_cache_dir = 'work_dirs/cache' if is_sub2: base_cache_dir = 'work_dirs/cache_sub2' # main_print("!!!!!!!!!!!!!!!!!!! using the sub2 uv map !!!!!!!!!!!!!!!!!!!") if gender == 'neutral': select_uv = torch.as_tensor(np.load(base_cache_dir+'/init_uv_smplx_newNeutral.npy')) self.register_buffer('select_coord', select_uv.unsqueeze(0)*2.-1.) init_pcd = torch.as_tensor(np.load(base_cache_dir+'/init_pcd_smplx_newNeutral.npy')) self.register_buffer('init_pcd', init_pcd.unsqueeze(0), persistent=False) # 0.9-- -1 elif gender == 'male': assert NotImplementedError("Haven't create the init_uv_smplx_thu in v_template") select_uv = torch.as_tensor(np.load(base_cache_dir+'/init_uv_smplx_thu.npy')) self.register_buffer('select_coord', select_uv.unsqueeze(0)*2.-1.) init_pcd = torch.as_tensor(np.load(base_cache_dir+'/init_pcd_smplx_thu.npy')) self.register_buffer('init_pcd', init_pcd.unsqueeze(0), persistent=False) # 0.9-- -1 self.num_init = self.init_pcd.shape[1] main_print(f"!!!!!!!!!!!!!!!!!!! 
cur points number are {self.num_init} !!!!!!!!!!!!!!!!!!!") self.init_pcd = self.init_pcd self.multires = multires # 0 Haven't if multires > 0: uv_map = torch.as_tensor(np.load(base_cache_dir+'/init_uvmap_smplx_thu.npy')) pcd_map = torch.as_tensor(np.load(base_cache_dir+'/init_posmap_smplx_thu.npy')) input_coord = torch.cat([pcd_map, uv_map], dim=1) self.register_buffer('input_freq', input_coord, persistent=False) base_layers[0] += 5 color_layers[0] += 5 else: self.init_uv = None activation_layer = self.activation_dict[activation.lower()] base_net = [] # linear (in=18, out=64, bias=True) for i in range(len(base_layers) - 1): base_net.append(nn.Conv2d(base_layers[i], base_layers[i + 1], 3, padding=1)) if i != len(base_layers) - 2: base_net.append(nn.BatchNorm2d(base_layers[i+1])) base_net.append(activation_layer()) self.base_net = nn.Sequential(*base_net) self.base_bn = nn.BatchNorm2d(base_layers[-1]) self.base_activation = activation_layer() density_net = [] # linear(in=64, out=1, bias=True), sigmoid for i in range(len(density_layers) - 1): density_net.append(nn.Conv2d(density_layers[i], density_layers[i + 1], 1)) if i != len(density_layers) - 2: density_net.append(nn.BatchNorm2d(density_layers[i+1])) density_net.append(activation_layer()) density_net.append(self.activation_dict[sigma_activation.lower()]()) self.density_net = nn.Sequential(*density_net) offset_net = [] # linear(in=64, out=1, bias=True), sigmoid for i in range(len(offset_layers) - 1): offset_net.append(nn.Conv2d(offset_layers[i], offset_layers[i + 1], 1)) if i != len(offset_layers) - 2: offset_net.append(nn.BatchNorm2d(offset_layers[i+1])) offset_net.append(activation_layer()) self.offset_net = nn.Sequential(*offset_net) self.dir_net = None color_net = [] # linear(in=64, out=3, bias=True), sigmoid for i in range(len(color_layers) - 2): color_net.append(nn.Conv2d(color_layers[i], color_layers[i + 1], kernel_size=3, padding=1)) color_net.append(nn.BatchNorm2d(color_layers[i+1])) color_net.append(activation_layer()) color_net.append(nn.Conv2d(color_layers[-2], color_layers[-1], kernel_size=1)) color_net.append(nn.Sigmoid()) self.color_net = nn.Sequential(*color_net) self.code_dropout = nn.Dropout2d(code_dropout) if code_dropout > 0 else None self.flip_z = flip_z self.extend_z = extend_z if self.gender == 'neutral': init_rot = torch.as_tensor(np.load(base_cache_dir+'/init_rot_smplx_newNeutral.npy')) self.register_buffer('init_rot', init_rot, persistent=False) face_mask = torch.as_tensor(np.load(base_cache_dir+'/face_mask_thu_newNeutral.npy')) self.register_buffer('face_mask', face_mask.unsqueeze(0), persistent=False) hands_mask = torch.as_tensor(np.load(base_cache_dir+'/hands_mask_thu_newNeutral.npy')) self.register_buffer('hands_mask', hands_mask.unsqueeze(0), persistent=False) outside_mask = torch.as_tensor(np.load(base_cache_dir+'/outside_mask_thu_newNeutral.npy')) self.register_buffer('outside_mask', outside_mask.unsqueeze(0), persistent=False) else: assert NotImplementedError("Haven't create the init_rot in v_template") init_rot = torch.as_tensor(np.load(base_cache_dir+'/init_rot_smplx_thu.npy')) self.register_buffer('init_rot', init_rot, persistent=False) face_mask = torch.as_tensor(np.load(base_cache_dir+'/face_mask_thu.npy')) self.register_buffer('face_mask', face_mask.unsqueeze(0), persistent=False) hands_mask = torch.as_tensor(np.load(base_cache_dir+'/hands_mask_thu.npy')) self.register_buffer('hands_mask', hands_mask.unsqueeze(0), persistent=False) outside_mask = 
torch.as_tensor(np.load(base_cache_dir+'/outside_mask_thu.npy')) self.register_buffer('outside_mask', outside_mask.unsqueeze(0), persistent=False) self.iter = 0 self.init_weights() self.if_rotate_gaussian = False self.fix_sigma = fix_sigma def init_weights(self): for m in self.modules(): if isinstance(m, nn.Linear): xavier_init(m, distribution='uniform') if self.dir_net is not None: constant_init(self.dir_net[-1], 0) if self.offset_net is not None: self.offset_net[-1].weight.data.uniform_(-1e-5, 1e-5) self.offset_net[-1].bias.data.zero_() def extract_pcd(self, code, smpl_params, init=False, zeros_hands_off=False): ''' Args: B == num_scenes code (tensor): latent code. shape: [B, C, H, W] smpl_params (tensor): SMPL parameters. shape: [B_pose, 189] init (bool): Not used Returns: defm_pcd (tensor): deformed point cloud. shape: [B, N, B_pose, 3] sigmas, rgbs, offset, radius, rot(tensor): GS attributes. shape: [B, N, C] tfs(tensor): deformation matrics. shape: [B, N, C] ''' if isinstance(code, list): num_scenes, _, h, w = code[0].size() else: num_scenes, n_channels, h, w = code.size() init_pcd = self.init_pcd.repeat(num_scenes, 1, 1) # T-posed space points, for computing the skinning weights sigmas, rgbs, radius, rot, offset = self._decode(code, init=init) # the person-specify attributes of GS if self.fix_sigma: sigmas = torch.ones_like(sigmas) if zeros_hands_off: offset[self.hands_mask[...,None].expand(num_scenes, -1, 3)] = 0 canon_pcd = init_pcd + offset self.deformer.prepare_deformer(smpl_params, num_scenes, device=canon_pcd.device) defm_pcd, tfs = self.deformer(canon_pcd, rot, mask=(self.face_mask+self.hands_mask+self.outside_mask), cano=False, if_rotate_gaussian=self.if_rotate_gaussian) return defm_pcd, sigmas, rgbs, offset, radius, tfs, rot def deform_pcd(self, code, smpl_params, init=False, zeros_hands_off=False, value=0.1): ''' Args: B == num_scenes code (List): list of data smpl_params (tensor): SMPL parameters. shape: [B_pose, 189] init (bool): Not used Returns: defm_pcd (tensor): deformed point cloud. shape: [B, N, B_pose, 3] sigmas, rgbs, offset, radius, rot(tensor): GS attributes. shape: [B, N, C] tfs(tensor): deformation matrics. 
shape: [B, N, C] ''' sigmas, rgbs, radius, rot, offset = code num_scenes = sigmas.shape[0] init_pcd = self.init_pcd.repeat(num_scenes, 1, 1) #T-posed space points, for computing the skinning weights if self.fix_sigma: sigmas = torch.ones_like(sigmas) if zeros_hands_off: offset[self.hands_mask[...,None].expand(num_scenes, -1, 3)] = torch.clamp(offset[self.hands_mask[...,None].expand(num_scenes, -1, 3)], -value, value) canon_pcd = init_pcd + offset self.deformer.prepare_deformer(smpl_params, num_scenes, device=canon_pcd.device) defm_pcd, tfs = self.deformer(canon_pcd, rot, mask=(self.face_mask+self.hands_mask+self.outside_mask), cano=False, if_rotate_gaussian=self.if_rotate_gaussian) return defm_pcd, sigmas, rgbs, offset, radius, tfs, rot def _sample_feature(self,results,): # outputs, sigma_uv, offset_uv, rgbs_uv, radius_uv, rot_uv = results['output'], results['sigma'], results['offset'], results['rgbs'], results['radius'], results['rot'] sigma = results['sigma'] outputs = results['output'] if isinstance(sigma, list): num_scenes, _, h, w = sigma[0].shape select_coord = self.select_coord.unsqueeze(1).repeat(num_scenes, 1, 1, 1) elif sigma.dim() == 4: num_scenes, n_channels, h, w = sigma.shape select_coord = self.select_coord.unsqueeze(1).repeat(num_scenes, 1, 1, 1) else: assert False output_attr = F.grid_sample(outputs, select_coord, mode=self.interp_mode, padding_mode='border', align_corners=False).reshape(num_scenes, 13, -1).permute(0, 2, 1) sigma, offset, rgbs, radius, rot = output_attr.split([1, 3, 3, 3, 3], dim=2) if self.sigmoid_saturation > 0: rgbs = rgbs * (1 + self.sigmoid_saturation * 2) - self.sigmoid_saturation radius = (radius - 0.5) * 2 rot = (rot - 0.5) * np.pi return sigma, rgbs, radius, rot, offset def _decode_feature(self, point_code, init=False): if isinstance(point_code, list): num_scenes, _, h, w = point_code[0].shape geo_code, tex_code = point_code # select_coord = self.select_coord.unsqueeze(1).repeat(num_scenes, 1, 1, 1) if self.multires != 0: input_freq = self.input_freq.repeat(num_scenes, 1, 1, 1) elif point_code.dim() == 4: num_scenes, n_channels, h, w = point_code.shape # select_coord = self.select_coord.unsqueeze(1).repeat(num_scenes, 1, 1, 1) if self.multires != 0: input_freq = self.input_freq.repeat(num_scenes, 1, 1, 1) geo_code, tex_code = point_code.split(16, dim=1) else: assert False base_in = geo_code if self.multires == 0 else torch.cat([geo_code, input_freq], dim=1) base_x = self.base_net(base_in) base_x_act = self.base_activation(self.base_bn(base_x)) sigma = self.density_net(base_x_act) offset = self.offset_net(base_x_act) color_in = tex_code if self.multires == 0 else torch.cat([tex_code, input_freq], dim=1) rgbs_radius_rot = self.color_net(color_in) outputs = torch.cat([sigma, offset, rgbs_radius_rot], dim=1) main_print(outputs.shape) sigma, offset, rgbs, radius, rot = outputs.split([1, 3, 3, 3, 3], dim=1) results = {'output':outputs, 'sigma': sigma, 'offset': offset, 'rgbs': rgbs, 'radius': radius, 'rot': rot} return results def _decode(self, point_code, init=False): if isinstance(point_code, list): num_scenes, _, h, w = point_code[0].shape geo_code, tex_code = point_code select_coord = self.select_coord.unsqueeze(1).repeat(num_scenes, 1, 1, 1) if self.multires != 0: input_freq = self.input_freq.repeat(num_scenes, 1, 1, 1) elif point_code.dim() == 4: num_scenes, n_channels, h, w = point_code.shape select_coord = self.select_coord.unsqueeze(1).repeat(num_scenes, 1, 1, 1) if self.multires != 0: input_freq = self.input_freq.repeat(num_scenes, 1, 1, 1) 
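# The plane code stores geometry and texture features in separate channel
# groups: the split below takes the first 16 channels as geo_code and the
# remaining channels as tex_code. When multires > 0, the precomputed
# position/UV-map encoding (input_freq, presumably a 3-channel position map
# plus a 2-channel UV map, given the "+5" channel adjustment in __init__)
# is concatenated onto each branch before its convolutional head.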
geo_code, tex_code = point_code.split(16, dim=1) else: assert False base_in = geo_code if self.multires == 0 else torch.cat([geo_code, input_freq], dim=1) base_x = self.base_net(base_in) base_x_act = self.base_activation(self.base_bn(base_x)) sigma = self.density_net(base_x_act) offset = self.offset_net(base_x_act) color_in = tex_code if self.multires == 0 else torch.cat([tex_code, input_freq], dim=1) rgbs_radius_rot = self.color_net(color_in) outputs = torch.cat([sigma, offset, rgbs_radius_rot], dim=1) output_attr = F.grid_sample(outputs, select_coord, mode=self.interp_mode, padding_mode='border', align_corners=False).reshape(num_scenes, 13, -1).permute(0, 2, 1) sigma, offset, rgbs, radius, rot = output_attr.split([1, 3, 3, 3, 3], dim=2) if self.sigmoid_saturation > 0: rgbs = rgbs * (1 + self.sigmoid_saturation * 2) - self.sigmoid_saturation radius = (radius - 0.5) * 2 rot = (rot - 0.5) * np.pi return sigma, rgbs, radius, rot, offset def gaussian_render(self, pcd, sigmas, rgbs, normals, rot, num_scenes, num_imgs, cameras, use_scale=False, radius=None, \ return_norm=False, return_viz=False, mask=None): # add mask or visible points to images or select ind to images ''' render the gaussian to images return_norm: return the normals of the gaussian (haven't been used) return_viz: return the mask of the gaussian mask: the mask of the gaussian ''' assert num_scenes == 1 pcd = pcd.reshape(-1, 3) if use_scale: dist2 = distCUDA2(pcd) dist2 = torch.clamp_min((dist2), 0.0000001) scales = torch.sqrt(dist2)[...,None].repeat(1, 3).detach() # distence between different points scale = (radius+1)*scales # scaling_modifier # radius[-1--1], scale of GS cov3D = get_covariance(scale, rot).reshape(-1, 6) # inputs rot is the rotations images_all = [] viz_masks = [] if return_viz else None norm_all = [] if return_norm else None if mask != None: pcd = pcd[mask] rgbs = rgbs[mask] sigmas = sigmas[mask] cov3D = cov3D[mask] normals = normals[mask] if 1: for i in range(num_imgs): self.renderer.prepare(cameras[i]) image = self.renderer.render_gaussian(means3D=pcd, colors_precomp=rgbs, rotations=None, opacities=sigmas, scales=None, cov3D_precomp=cov3D) images_all.append(image) if return_viz: viz_mask = self.renderer.render_gaussian(means3D=pcd, colors_precomp=pcd.clone(), rotations=None, opacities=sigmas*0+1, scales=None, cov3D_precomp=cov3D) viz_masks.append(viz_mask) images_all = torch.stack(images_all, dim=0).unsqueeze(0).permute(0, 1, 3, 4, 2) if return_viz: viz_masks = torch.stack(viz_masks, dim=0).unsqueeze(0).permute(0, 1, 3, 4, 2).reshape(1, -1, 3) dist_sq, idx, neighbors = ops.knn_points(pcd.unsqueeze(0), viz_masks[:, ::10], K=1, return_nn=True) viz_masks = (dist_sq < 0.0001)[0] # ===== END the original code for batch size = 1 ===== if use_scale: return images_all, norm_all, viz_masks, scale else: return images_all, norm_all, viz_masks, None def visualize(self, code, scene_name, viz_dir, code_range=[-1, 1]): num_scenes, num_chn, h, w = code.size() code_viz = code.reshape(num_scenes, 4, 8, h, w).to(torch.float32).cpu().numpy() if not self.flip_z: code_viz = code_viz[..., ::-1, :] code_viz = code_viz.transpose(0, 1, 3, 2, 4).reshape(num_scenes, 4 * h, 8 * w) for code_single, code_viz_single, scene_name_single in zip(code, code_viz, scene_name): plt.imsave(os.path.join(viz_dir, 'a_scene_' + scene_name_single + '.png'), code_viz_single, vmin=code_range[0], vmax=code_range[1]) def forward(self, code, smpl_params, cameras, num_imgs, return_loss=False, return_norm=False, init=False, mask=None, zeros_hands_off=False): 
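# High-level flow of forward(): decode the latent plane code into per-point
# Gaussian attributes and deform the canonical points with the SMPL-X
# parameters (extract_pcd), turn the predicted axis-angle residuals into
# per-Gaussian rotations, then splat every scene with gaussian_render() for
# each camera; the returned dict bundles the rendered images, per-Gaussian
# scales and, during training, an offset regularisation term.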
""" Args: density_bitfield: Shape (num_scenes, griz_size**3 // 8) YY: grid_size, dt_gamma, perturb, T_thresh are deleted code: Shape (num_scenes, *code_size) cameras: Shape (num_scenes, num_imgs, 19(3+16)) smpl_params: Shape (num_scenes, 189) """ # import ipdb; ipdb.set_trace() if isinstance(code, list): num_scenes = len(code[0]) else: num_scenes = len(code) assert num_scenes > 0 self.iter+=1 image = [] scales = [] norm = [] if return_norm else None viz_masks = [] if not self.training else None xyzs, sigmas, rgbs, offsets, radius, tfs, rot = self.extract_pcd(code, smpl_params, init=init, zeros_hands_off=zeros_hands_off) if zeros_hands_off: main_print('zeros_hands_off is on!') main_print('zeros_hands_off is on!') offsets[self.hands_mask[...,None].repeat(num_scenes, 1, 3)] = 0 rgbs[self.hands_mask[...,None].repeat(num_scenes, 1, 3)] = 0 rgbs[self.hands_mask[...,None].repeat(num_scenes, 1, 3)] = 0 R_delta = batch_rodrigues(rot.reshape(-1, 3)) R = torch.bmm(self.init_rot.repeat(num_scenes, 1, 1), R_delta) R_def = torch.bmm(tfs.flatten(0, 1)[:, :3, :3], R) normals = (R_def[:, :, -1]).reshape(num_scenes, -1, 3) R_def_batch = R_def.reshape(num_scenes, -1, 3, 3) return_to_bfloat16 = True if xyzs.dtype==torch.bfloat16 else False ####### ============ translate the output to BF16 ================= # return_to_bfloat16 = False # I don't want to trans it back to bf16 if return_to_bfloat16: main_print("changes the return_to_bfloat16") cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius = [ensure_dtype(item, torch.float32) for item in (cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius)] # with torch.amp.autocast(enabled=False, device_type='cuda'): if 1: for camera_single, R_def_single, pcd_single, rgbs_single, sigmas_single, normal_single, radius_single in zip(cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius): image_single, norm_single, viz_mask, scale = self.gaussian_render(pcd_single, sigmas_single, rgbs_single, normal_single, R_def_single, 1, num_imgs, camera_single, use_scale=True, \ radius=radius_single, return_norm=return_norm, return_viz=not self.training) image.append(image_single) scales.append(scale.unsqueeze(0)) if return_norm: norm.append(norm_single) if not self.training: viz_masks.append(viz_mask) image = torch.cat(image, dim=0) scales = torch.cat(scales, dim=0) norm = torch.cat(norm, dim=0) if return_norm else None viz_masks = torch.cat(viz_masks, dim=0) if (not self.training) and viz_masks else None main_print("not trans the rendered results to float16") if False: image = image.to(torch.bfloat16) scales = scales.to(torch.bfloat16) if return_norm: norm = norm.to(torch.bfloat16) if viz_masks is not None: viz_masks = viz_masks.to(torch.bfloat16) offsets = offsets.to(torch.bfloat16) if self.training: offset_dist = offsets ** 2 weighted_offset = torch.mean(offset_dist) + torch.mean(offset_dist[self.hands_mask.repeat(num_scenes, 1)]) #+ torch.mean(offset_dist[self.face_mask.repeat(num_scenes, 1)]) else: weighted_offset = offsets results = dict( viz_masks=viz_masks, scales=scales, norm=norm, image=image, offset=weighted_offset) if return_loss: results.update(decoder_reg_loss=self.loss()) return results def forward_render(self, code, cameras, num_imgs, return_loss=False, return_norm=False, init=False, mask=None, zeros_hands_off=False): """ Args: density_bitfield: Shape (num_scenes, griz_size**3 // 8) YY: grid_size, dt_gamma, perturb, T_thresh are deleted code: Shape (num_scenes, *code_size) cameras: Shape (num_scenes, num_imgs, 19(3+16)) smpl_params: Shape (num_scenes, 189) 
""" image = [] scales = [] norm = [] if return_norm else None viz_masks = [] if not self.training else None xyzs, sigmas, rgbs, offsets, radius, tfs, rot = code num_scenes = xyzs.shape[0] if zeros_hands_off: main_print('zeros_hands_off is on!') main_print('zeros_hands_off is on!') offsets[self.hands_mask[...,None].repeat(num_scenes, 1, 3)] = 0 rgbs[self.hands_mask[...,None].repeat(num_scenes, 1, 3)] = 0 rgbs[self.hands_mask[...,None].repeat(num_scenes, 1, 3)] = 0 R_delta = batch_rodrigues(rot.reshape(-1, 3)) R = torch.bmm(self.init_rot.repeat(num_scenes, 1, 1), R_delta) R_def = torch.bmm(tfs.flatten(0, 1)[:, :3, :3], R) normals = (R_def[:, :, -1]).reshape(num_scenes, -1, 3) R_def_batch = R_def.reshape(num_scenes, -1, 3, 3) # import ipdb; ipdb.set_trace() return_to_bfloat16 = True if xyzs.dtype==torch.bfloat16 else False ####### ============ translate the output to BF16 ================= # return_to_bfloat16 = False # I don't want to trans it back to bf16 if return_to_bfloat16: main_print("changes the return_to_bfloat16") cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius = [ensure_dtype(item, torch.float32) for item in (cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius)] if 1: for camera_single, R_def_single, pcd_single, rgbs_single, sigmas_single, normal_single, radius_single in zip(cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius): image_single, norm_single, viz_mask, scale = self.gaussian_render(pcd_single, sigmas_single, rgbs_single, normal_single, R_def_single, 1, num_imgs, camera_single, use_scale=True, \ radius=radius_single, return_norm=return_norm, return_viz=not self.training) image.append(image_single) scales.append(scale.unsqueeze(0)) if return_norm: norm.append(norm_single) if not self.training: viz_masks.append(viz_mask) image = torch.cat(image, dim=0) scales = torch.cat(scales, dim=0) norm = torch.cat(norm, dim=0) if return_norm else None viz_masks = torch.cat(viz_masks, dim=0) if (not self.training) and viz_masks else None main_print("not trans the rendered results to float16") if False: image = image.to(torch.bfloat16) scales = scales.to(torch.bfloat16) if return_norm: norm = norm.to(torch.bfloat16) if viz_masks is not None: viz_masks = viz_masks.to(torch.bfloat16) offsets = offsets.to(torch.bfloat16) if self.training: offset_dist = offsets ** 2 weighted_offset = torch.mean(offset_dist) + torch.mean(offset_dist[self.hands_mask.repeat(num_scenes, 1)]) #+ torch.mean(offset_dist[self.face_mask.repeat(num_scenes, 1)]) else: weighted_offset = offsets results = dict( viz_masks=viz_masks, scales=scales, norm=norm, image=image, offset=weighted_offset) if return_loss: results.update(decoder_reg_loss=self.loss()) return results def forward_testing_time(self, code, smpl_params, cameras, num_imgs, return_loss=False, return_norm=False, init=False, mask=None, zeros_hands_off=False): """ Args: density_bitfield: Shape (num_scenes, griz_size**3 // 8) YY: grid_size, dt_gamma, perturb, T_thresh are deleted code: Shape (num_scenes, *code_size) cameras: Shape (num_scenes, num_imgs, 19(3+16)) smpl_params: Shape (num_scenes, 189) """ if isinstance(code, list): num_scenes = len(code[0]) else: num_scenes = len(code) assert num_scenes > 0 self.iter+=1 image = [] scales = [] norm = [] if return_norm else None viz_masks = [] if not self.training else None start_time = time.time() xyzs, sigmas, rgbs, offsets, radius, tfs, rot = self.extract_pcd(code, smpl_params, init=init, zeros_hands_off=zeros_hands_off) end_time_to_3D = time.time() time_code_to_3d = end_time_to_3D- 
start_time R_delta = batch_rodrigues(rot.reshape(-1, 3)) R = torch.bmm(self.init_rot.repeat(num_scenes, 1, 1), R_delta) R_def = torch.bmm(tfs.flatten(0, 1)[:, :3, :3], R) normals = (R_def[:, :, -1]).reshape(num_scenes, -1, 3) R_def_batch = R_def.reshape(num_scenes, -1, 3, 3) if 1: for camera_single, R_def_single, pcd_single, rgbs_single, sigmas_single, normal_single, radius_single in zip(cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius): image_single, norm_single, viz_mask, scale = self.gaussian_render(pcd_single, sigmas_single, rgbs_single, normal_single, R_def_single, 1, num_imgs, camera_single, use_scale=True, \ radius=radius_single, return_norm=False, return_viz=not self.training) image.append(image_single) scales.append(scale.unsqueeze(0)) if return_norm: norm.append(norm_single) if not self.training: viz_masks.append(viz_mask) image = torch.cat(image, dim=0) scales = torch.cat(scales, dim=0) norm = torch.cat(norm, dim=0) if return_norm else None viz_masks = torch.cat(viz_masks, dim=0) if (not self.training) and viz_masks else None time_3D_to_img = time.time() - end_time_to_3D if False: image = image.to(torch.bfloat16) scales = scales.to(torch.bfloat16) if return_norm: norm = norm.to(torch.bfloat16) if viz_masks is not None: viz_masks = viz_masks.to(torch.bfloat16) offsets = offsets.to(torch.bfloat16) results = dict( image=image) return results, time_code_to_3d, time_3D_to_img ================================================ FILE: lib/models/decoders/vit_head.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from typing import Sequence, Tuple, Optional class VitHead(nn.Module): def __init__(self, in_channels: int, out_channels: int, deconv_out_channels: Optional[Sequence[int]] = (256, 256, 256), deconv_kernel_sizes: Optional[Sequence[int]] = (4, 4, 4), conv_out_channels: Optional[Sequence[int]] = None, conv_kernel_sizes: Optional[Sequence[int]] = None, ): super(VitHead, self).__init__() if deconv_out_channels: if deconv_kernel_sizes is None or len(deconv_out_channels) != len(deconv_kernel_sizes): raise ValueError( '"deconv_out_channels" and "deconv_kernel_sizes" should ' 'be integer sequences with the same length. Got ' f'mismatched lengths {deconv_out_channels} and ' f'{deconv_kernel_sizes}') self.deconv_layers = self._make_deconv_layers( in_channels=in_channels, layer_out_channels=deconv_out_channels, layer_kernel_sizes=deconv_kernel_sizes, ) in_channels = deconv_out_channels[-1] else: self.deconv_layers = nn.Identity() if conv_out_channels: if conv_kernel_sizes is None or len(conv_out_channels) != len(conv_kernel_sizes): raise ValueError( '"conv_out_channels" and "conv_kernel_sizes" should ' 'be integer sequences with the same length. 
Got ' f'mismatched lengths {conv_out_channels} and ' f'{conv_kernel_sizes}') self.conv_layers = self._make_conv_layers( in_channels=in_channels, layer_out_channels=conv_out_channels, layer_kernel_sizes=conv_kernel_sizes) in_channels = conv_out_channels[-1] else: self.conv_layers = nn.Identity() self.cls_seg = nn.Conv2d(in_channels, out_channels, kernel_size=1) def _make_conv_layers(self, in_channels: int, layer_out_channels: Sequence[int], layer_kernel_sizes: Sequence[int]) -> nn.Module: """Create convolutional layers by given parameters.""" layers = [] for out_channels, kernel_size in zip(layer_out_channels, layer_kernel_sizes): padding = (kernel_size - 1) // 2 layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding)) layers.append(nn.InstanceNorm2d(out_channels)) layers.append(nn.SiLU(inplace=True)) in_channels = out_channels return nn.Sequential(*layers) def _make_deconv_layers(self, in_channels: int, layer_out_channels: Sequence[int], layer_kernel_sizes: Sequence[int]) -> nn.Module: """Create deconvolutional layers by given parameters.""" layers = [] for out_channels, kernel_size in zip(layer_out_channels, layer_kernel_sizes): if kernel_size == 4: padding = 1 output_padding = 0 elif kernel_size == 3: padding = 1 output_padding = 1 elif kernel_size == 2: padding = 0 output_padding = 0 else: raise ValueError(f'Unsupported kernel size {kernel_size} for deconvolutional layers') layers.append(nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=2, padding=padding, output_padding=output_padding, bias=False)) layers.append(nn.InstanceNorm2d(out_channels)) layers.append(nn.SiLU(inplace=True)) in_channels = out_channels return nn.Sequential(*layers) def forward(self, inputs): x = self.deconv_layers(inputs) x = self.conv_layers(x) out = self.cls_seg(x) return out if __name__ == "__main__": # Example usage: model = VitHead(in_channels=1536, out_channels=21, deconv_out_channels=(768, 768, 512, 512), deconv_kernel_sizes=(4, 4, 4, 4), conv_out_channels=(512, 256, 128), conv_kernel_sizes=(1, 1, 1), ) inputs = torch.randn(1, 1536, 64, 64) outputs = model(inputs) print(outputs.shape) ================================================ FILE: lib/models/deformers/__init__.py ================================================ from .smplx_deformer_gender import SMPLXDeformer_gender __all__ = ['SMPLXDeformer_gender'] ================================================ FILE: lib/models/deformers/fast_snarf/cuda/filter/filter.cpp ================================================ #include #include #include #include void launch_filter( torch::Tensor &output, const torch::Tensor &x, const torch::Tensor &mask); void filter(const torch::Tensor &x, const torch::Tensor &mask, torch::Tensor &output) { const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); launch_filter(output, x, mask); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("filter", &filter); } ================================================ FILE: lib/models/deformers/fast_snarf/cuda/filter/filter_kernel.cu ================================================ #include #include #include #include #include #include #include #include #include "ATen/Functions.h" #include "ATen/core/TensorAccessor.h" #include "c10/cuda/CUDAException.h" #include "c10/cuda/CUDAStream.h" #include #define TensorAcc4R PackedTensorAccessor32 #define TensorAcc5R PackedTensorAccessor32 using namespace at; using namespace at::cuda::detail; template C10_LAUNCH_BOUNDS_1(512) __global__ void filter( const index_t nthreads, 
PackedTensorAccessor32 x, PackedTensorAccessor32 mask, PackedTensorAccessor32 output) { index_t n_batch = mask.size(0); index_t n_point = mask.size(1); index_t n_init = mask.size(2); CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) { const index_t i_batch = index / (n_batch*n_point); const index_t i_point = index % (n_batch*n_point); for(index_t i = 0; i < n_init; i++) { if(!mask[i_batch][i_point][i]){ output[i_batch][i_point][i] = false; continue; } scalar_t xi0 = x[i_batch][i_point][i][0]; scalar_t xi1 = x[i_batch][i_point][i][1]; scalar_t xi2 = x[i_batch][i_point][i][2]; bool flag = true; for(index_t j = i+1; j < n_init; j++){ if(!mask[i_batch][i_point][j]){ continue; } scalar_t d0 = xi0 - x[i_batch][i_point][j][0]; scalar_t d1 = xi1 - x[i_batch][i_point][j][1]; scalar_t d2 = xi2 - x[i_batch][i_point][j][2]; scalar_t dist = d0*d0 + d1*d1 + d2*d2; if(dist<0.0001*0.0001){ flag=false; break; } } output[i_batch][i_point][i] = flag; } } } void launch_filter( Tensor &output, const Tensor &x, const Tensor &mask) { // calculate #threads required int64_t B = output.size(0); int64_t N = output.size(1); int64_t count = B*N; if (count > 0) { AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filter", [&] { filter <<>>( static_cast(count), x.packed_accessor32(), mask.packed_accessor32(), output.packed_accessor32()); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); } } ================================================ FILE: lib/models/deformers/fast_snarf/cuda/fuse_kernel/fuse_cuda.cpp ================================================ #include "ATen/Functions.h" #include "ATen/core/TensorBody.h" #include #include #include #include void launch_broyden_kernel(torch::Tensor &x, const torch::Tensor &xd_tgt, const torch::Tensor &grid, const torch::Tensor &grid_J_inv, const torch::Tensor &tfs, const torch::Tensor &bone_ids, bool align_corners, // torch::Tensor &J_inv, torch::Tensor &is_valid, const torch::Tensor& offset, const torch::Tensor& scale, float cvg_threshold, float dvg_threshold); void fuse_broyden(torch::Tensor &x, const torch::Tensor &xd_tgt, const torch::Tensor &grid, const torch::Tensor &grid_J_inv, const torch::Tensor &tfs, const torch::Tensor &bone_ids, bool align_corners, // torch::Tensor& J_inv, torch::Tensor &is_valid, torch::Tensor& offset, torch::Tensor& scale, float cvg_threshold, float dvg_threshold) { const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); launch_broyden_kernel(x, xd_tgt, grid, grid_J_inv, tfs, bone_ids, align_corners, is_valid, offset, scale, cvg_threshold, dvg_threshold); return; } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fuse_broyden", &fuse_broyden); } ================================================ FILE: lib/models/deformers/fast_snarf/cuda/fuse_kernel/fuse_cuda_kernel.cu ================================================ #include "ATen/Functions.h" #include "ATen/core/TensorAccessor.h" #include "c10/cuda/CUDAException.h" #include "c10/cuda/CUDAStream.h" #include #include #include #include #include #include #include #include #include #include using namespace std::chrono; using namespace at; using namespace at::cuda::detail; template __device__ void fuse_J_inv_update(const index_t index, scalar_t* J_inv, scalar_t x0, scalar_t x1, scalar_t x2, scalar_t g0, scalar_t g1, scalar_t g2) { // index_t s_J = J_inv.strides[0]; // index_t s_x = delta_x.strides[0]; // index_t s_g = delta_gx.strides[0]; index_t s_J = 9; index_t s_x = 3; index_t s_g = 3; scalar_t J00 = J_inv[3*0 + 0]; scalar_t J01 = J_inv[3*0 + 1]; scalar_t J02 = J_inv[3*0 + 2]; scalar_t J10 = J_inv[3*1 + 
0]; scalar_t J11 = J_inv[3*1 + 1]; scalar_t J12 = J_inv[3*1 + 2]; scalar_t J20 = J_inv[3*2 + 0]; scalar_t J21 = J_inv[3*2 + 1]; scalar_t J22 = J_inv[3*2 + 2]; auto c0 = J00 * x0 + J10 * x1 + J20 * x2; auto c1 = J01 * x0 + J11 * x1 + J21 * x2; auto c2 = J02 * x0 + J12 * x1 + J22 * x2; auto s = c0 * g0 + c1 * g1 + c2 * g2; auto r0 = -J00 * g0 - J01 * g1 - J02 * g2; auto r1 = -J10 * g0 - J11 * g1 - J12 * g2; auto r2 = -J20 * g0 - J21 * g1 - J22 * g2; J_inv[3*0 + 0] += c0 * (r0 + x0) / s; J_inv[3*0 + 1] += c1 * (r0 + x0) / s; J_inv[3*0 + 2] += c2 * (r0 + x0) / s; J_inv[3*1 + 0] += c0 * (r1 + x1) / s; J_inv[3*1 + 1] += c1 * (r1 + x1) / s; J_inv[3*1 + 2] += c2 * (r1 + x1) / s; J_inv[3*2 + 0] += c0 * (r2 + x2) / s; J_inv[3*2 + 1] += c1 * (r2 + x2) / s; J_inv[3*2 + 2] += c2 * (r2 + x2) / s; } static __forceinline__ __device__ bool within_bounds_3d(int d, int h, int w, int D, int H, int W) { return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W; } template static __forceinline__ __device__ scalar_t grid_sampler_unnormalize(scalar_t coord, int size, bool align_corners) { if (align_corners) { // unnormalize coord from [-1, 1] to [0, size - 1] return ((coord + 1.f) / 2) * (size - 1); } else { // unnormalize coord from [-1, 1] to [-0.5, size - 0.5] return ((coord + 1.f) * size - 1) / 2; } } // Clips coordinates to between 0 and clip_limit - 1 template static __forceinline__ __device__ scalar_t clip_coordinates(scalar_t in, int clip_limit) { return ::min(static_cast(clip_limit - 1), ::max(in, static_cast(0))); } template static __forceinline__ __device__ scalar_t safe_downgrade_to_int_range(scalar_t x){ // -100.0 does not have special meaning. This is just to make sure // it's not within_bounds_2d or within_bounds_3d, and does not cause // undefined behavior. See #35506. if (x > INT_MAX-1 || x < INT_MIN || !::isfinite(static_cast(x))) return static_cast(-100.0); return x; } template static __forceinline__ __device__ scalar_t compute_coordinates(scalar_t coord, int size, bool align_corners) { // clip coordinates to image borders // coord = clip_coordinates(coord, size); coord = safe_downgrade_to_int_range(coord); return coord; } template static __forceinline__ __device__ scalar_t grid_sampler_compute_source_index(scalar_t coord, int size, bool align_corners) { coord = grid_sampler_unnormalize(coord, size, align_corners); coord = compute_coordinates(coord, size, align_corners); return coord; } template __device__ void grid_sampler_3d( index_t i_batch, TensorInfo input, scalar_t grid_x, scalar_t grid_y, scalar_t grid_z, // TensorInfo output, PackedTensorAccessor32 input_p, // [1, 3, 8, 32, 32] scalar_t* output, // PackedTensorAccessor32 output_p, // [1800000, 3, 1] bool align_corners, bool nearest) { index_t C = input.sizes[1]; index_t inp_D = input.sizes[2]; index_t inp_H = input.sizes[3]; index_t inp_W = input.sizes[4]; // broyden x.sizes=[1800000,3,1] index_t inp_sN = input.strides[0]; index_t inp_sC = input.strides[1]; index_t inp_sD = input.strides[2]; index_t inp_sH = input.strides[3]; index_t inp_sW = input.strides[4]; index_t out_sC = 1; //output size is same as grid size... 
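// What follows: unnormalize the grid coordinate (grid_x, grid_y, grid_z) into
// voxel indices, then sample the C channels of input_p at that location,
// either by trilinear interpolation over the eight neighbouring voxels or,
// when `nearest` is set, by nearest-neighbour lookup; out-of-bounds corners
// contribute zero. The result is written to the small per-thread `output`
// buffer (one value per channel).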
// return; // get the corresponding input x, y, z co-ordinates from grid scalar_t ix = grid_x; scalar_t iy = grid_y; scalar_t iz = grid_z; // c0 ix,iy,iz=-0.848051,0.592726,0.259927 // c1 ix,iy,iz=2.355216,24.687256,4.409743 ix = grid_sampler_compute_source_index(ix, inp_W, align_corners); iy = grid_sampler_compute_source_index(iy, inp_H, align_corners); iz = grid_sampler_compute_source_index(iz, inp_D, align_corners); if(!nearest){ // get corner pixel values from (x, y, z) // for 4d, we used north-east-south-west // for 5d, we add top-bottom index_t ix_tnw = static_cast(::floor(ix)); index_t iy_tnw = static_cast(::floor(iy)); index_t iz_tnw = static_cast(::floor(iz)); index_t ix_tne = ix_tnw + 1; index_t iy_tne = iy_tnw; index_t iz_tne = iz_tnw; index_t ix_tsw = ix_tnw; index_t iy_tsw = iy_tnw + 1; index_t iz_tsw = iz_tnw; index_t ix_tse = ix_tnw + 1; index_t iy_tse = iy_tnw + 1; index_t iz_tse = iz_tnw; index_t ix_bnw = ix_tnw; index_t iy_bnw = iy_tnw; index_t iz_bnw = iz_tnw + 1; index_t ix_bne = ix_tnw + 1; index_t iy_bne = iy_tnw; index_t iz_bne = iz_tnw + 1; index_t ix_bsw = ix_tnw; index_t iy_bsw = iy_tnw + 1; index_t iz_bsw = iz_tnw + 1; index_t ix_bse = ix_tnw + 1; index_t iy_bse = iy_tnw + 1; index_t iz_bse = iz_tnw + 1; // get surfaces to each neighbor: scalar_t tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz); scalar_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz); scalar_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz); scalar_t tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz); scalar_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse); scalar_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw); scalar_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne); scalar_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw); for (index_t xyz = 0; xyz < C; xyz++) { output[xyz] = 0; if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) { output[xyz] += input_p[i_batch][xyz][iz_tnw][iy_tnw][ix_tnw] * tnw; // *out_ptr_NCDHW += inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] * tnw; } if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) { output[xyz] += input_p[i_batch][xyz][iz_tne][iy_tne][ix_tne] * tne; // *out_ptr_NCDHW += inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] * tne; } if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) { output[xyz] += input_p[i_batch][xyz][iz_tsw][iy_tsw][ix_tsw] * tsw; // *out_ptr_NCDHW += inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] * tsw; } if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) { output[xyz] += input_p[i_batch][xyz][iz_tse][iy_tse][ix_tse] * tse; // *out_ptr_NCDHW += inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] * tse; } if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) { output[xyz] += input_p[i_batch][xyz][iz_bnw][iy_bnw][ix_bnw] * bnw; // *out_ptr_NCDHW += inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] * bnw; } if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) { output[xyz] += input_p[i_batch][xyz][iz_bne][iy_bne][ix_bne] * bne; // *out_ptr_NCDHW += inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] * bne; } if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) { output[xyz] += input_p[i_batch][xyz][iz_bsw][iy_bsw][ix_bsw] * bsw; // *out_ptr_NCDHW += inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] * bsw; } if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) { output[xyz] += 
input_p[i_batch][xyz][iz_bse][iy_bse][ix_bse] * bse; // *out_ptr_NCDHW += inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] * bse; } } } else{ index_t ix_nearest = static_cast(::round(ix)); index_t iy_nearest = static_cast(::round(iy)); index_t iz_nearest = static_cast(::round(iz)); for (index_t xyz = 0; xyz < C; xyz++) { if (within_bounds_3d(iz_nearest, iy_nearest, ix_nearest, inp_D, inp_H, inp_W)) { output[xyz] = input_p[i_batch][xyz][iz_nearest][iy_nearest][ix_nearest]; } else { output[xyz] = static_cast(0); } } } } template C10_LAUNCH_BOUNDS_1(512) __global__ void broyden_kernel( const index_t npoints, const index_t n_batch, const index_t n_point, const index_t n_init, TensorInfo voxel_ti, TensorInfo voxel_J_ti, PackedTensorAccessor32 x, // shape=(N,200000, 9, 3) PackedTensorAccessor32 xd_tgt, // shape=(N,200000, 3) PackedTensorAccessor32 voxel, // shape=(N,3,8,32,32) PackedTensorAccessor32 grid_J_inv, // shape=(N,9,8,32,32) PackedTensorAccessor32 tfs, // shape=(N,24,4,4) PackedTensorAccessor32 bone_ids, // shape=(9) // PackedTensorAccessor32 J_inv,// shape=(N,200000, 9, 9) PackedTensorAccessor32 is_valid,// shape=(N,200000, 9) PackedTensorAccessor32 offset, // shape=(N, 1, 3) PackedTensorAccessor32 scale, // shape=(N, 1, 3) float cvg_threshold, float dvg_threshold, int N ) { index_t index = blockIdx.x * blockDim.x + threadIdx.x; if(index >= npoints) return; const index_t i_batch = index / (n_point*n_init); const index_t i_point = (index % (n_point*n_init)) / n_init; const index_t i_init = (index % (n_point*n_init)) % n_init; if(!is_valid[i_batch][i_point][i_init]){ return; } scalar_t gx[3]; scalar_t gx_new[3]; scalar_t xd_tgt_index[3]; xd_tgt_index[0] = xd_tgt[i_batch][i_point][0]; xd_tgt_index[1] = xd_tgt[i_batch][i_point][1]; xd_tgt_index[2] = xd_tgt[i_batch][i_point][2]; scalar_t x_l[3]; int i_bone = bone_ids[i_init]; scalar_t ixd = xd_tgt_index[0] - tfs[i_batch][i_bone][0][3]; scalar_t iyd = xd_tgt_index[1] - tfs[i_batch][i_bone][1][3]; scalar_t izd = xd_tgt_index[2] - tfs[i_batch][i_bone][2][3]; x_l[0] = ixd * tfs[i_batch][i_bone][0][0] + iyd * tfs[i_batch][i_bone][1][0] + izd * tfs[i_batch][i_bone][2][0]; x_l[1] = ixd * tfs[i_batch][i_bone][0][1] + iyd * tfs[i_batch][i_bone][1][1] + izd * tfs[i_batch][i_bone][2][1]; x_l[2] = ixd * tfs[i_batch][i_bone][0][2] + iyd * tfs[i_batch][i_bone][1][2] + izd * tfs[i_batch][i_bone][2][2]; // scalar_t ix = scale[0][0][0] * (x_l[0] + offset[0][0][0]); // scalar_t iy = scale[0][0][1] * (x_l[1] + offset[0][0][1]); // scalar_t iz = scale[0][0][2] * (x_l[2] + offset[0][0][2]); scalar_t J_local[12]; grid_sampler_3d( i_batch, voxel_J_ti, scale[0][0][0] * (x_l[0] + offset[0][0][0]), scale[0][0][1] * (x_l[1] + offset[0][0][1]), scale[0][0][2] * (x_l[2] + offset[0][0][2]), grid_J_inv, J_local, true, false); scalar_t J_inv_local[9]; J_inv_local[3*0 + 0] = J_local[4*0 + 0]; J_inv_local[3*1 + 0] = J_local[4*0 + 1]; J_inv_local[3*2 + 0] = J_local[4*0 + 2]; J_inv_local[3*0 + 1] = J_local[4*1 + 0]; J_inv_local[3*1 + 1] = J_local[4*1 + 1]; J_inv_local[3*2 + 1] = J_local[4*1 + 2]; J_inv_local[3*0 + 2] = J_local[4*2 + 0]; J_inv_local[3*1 + 2] = J_local[4*2 + 1]; J_inv_local[3*2 + 2] = J_local[4*2 + 2]; for(int i=0; i<10; i++) { scalar_t J00 = J_inv_local[3*0 + 0]; scalar_t J01 = J_inv_local[3*0 + 1]; scalar_t J02 = J_inv_local[3*0 + 2]; scalar_t J10 = J_inv_local[3*1 + 0]; scalar_t J11 = J_inv_local[3*1 + 1]; scalar_t J12 = J_inv_local[3*1 + 2]; scalar_t J20 = J_inv_local[3*2 + 0]; scalar_t J21 = J_inv_local[3*2 + 1]; scalar_t J22 = 
J_inv_local[3*2 + 2]; // gx = g(x) if (i==0){ // grid_sampler_3d( i_batch, voxel_ti, // scale[0][0][0] * (x_l[0] + offset[0][0][0]), // scale[0][0][1] * (x_l[1] + offset[0][0][1]), // scale[0][0][2] * (x_l[2] + offset[0][0][2]), // voxel, // gx, // true, // false); gx[0] = J_local[4*0+0] * x_l[0] + J_local[4*0+1] * x_l[1] + J_local[4*0+2] * x_l[2] + J_local[4*0+3]; gx[1] = J_local[4*1+0] * x_l[0] + J_local[4*1+1] * x_l[1] + J_local[4*1+2] * x_l[2] + J_local[4*1+3]; gx[2] = J_local[4*2+0] * x_l[0] + J_local[4*2+1] * x_l[1] + J_local[4*2+2] * x_l[2] + J_local[4*2+3]; gx[0] = gx[0] - xd_tgt_index[0]; gx[1] = gx[1] - xd_tgt_index[1]; gx[2] = gx[2] - xd_tgt_index[2]; } else{ gx[0] = gx_new[0]; gx[1] = gx_new[1]; gx[2] = gx_new[2]; } // update = -J_inv @ gx scalar_t u0 = -J00*gx[0] + -J01*gx[1] + -J02*gx[2]; scalar_t u1 = -J10*gx[0] + -J11*gx[1] + -J12*gx[2]; scalar_t u2 = -J20*gx[0] + -J21*gx[1] + -J22*gx[2]; // x += update x_l[0] += u0; x_l[1] += u1; x_l[2] += u2; scalar_t ix = scale[0][0][0] * (x_l[0] + offset[0][0][0]); scalar_t iy = scale[0][0][1] * (x_l[1] + offset[0][0][1]); scalar_t iz = scale[0][0][2] * (x_l[2] + offset[0][0][2]); // gx_new = g(x) grid_sampler_3d( i_batch, voxel_J_ti, ix, iy, iz, grid_J_inv, J_local, true, false); gx_new[0] = J_local[4*0+0] * x_l[0] + J_local[4*0+1] * x_l[1] + J_local[4*0+2] * x_l[2] + J_local[4*0+3] - xd_tgt_index[0]; gx_new[1] = J_local[4*1+0] * x_l[0] + J_local[4*1+1] * x_l[1] + J_local[4*1+2] * x_l[2] + J_local[4*1+3] - xd_tgt_index[1]; gx_new[2] = J_local[4*2+0] * x_l[0] + J_local[4*2+1] * x_l[1] + J_local[4*2+2] * x_l[2] + J_local[4*2+3] - xd_tgt_index[2]; // grid_sampler_3d( i_batch, voxel_ti, // scale[0][0][0] * (x_l[0] + offset[0][0][0]), // scale[0][0][1] * (x_l[1] + offset[0][0][1]), // scale[0][0][2] * (x_l[2] + offset[0][0][2]), // voxel, // gx_new, // true, // false); // gx_new[0] = gx_new[0] - xd_tgt_index[0]; // gx_new[1] = gx_new[1] - xd_tgt_index[1]; // gx_new[2] = gx_new[2] - xd_tgt_index[2]; // convergence checking scalar_t norm_gx = gx_new[0]*gx_new[0] + gx_new[1]*gx_new[1] + gx_new[2]*gx_new[2]; // convergence/divergence criterion if(norm_gx < cvg_threshold*cvg_threshold) { bool is_valid_ = ix >= -1 && ix <= 1 && iy >= -1 && iy <= 1 && iz >= -1 && iz <= 1; is_valid[i_batch][i_point][i_init] = is_valid_; if (is_valid_){ x[i_batch][i_point][i_init][0] = x_l[0]; x[i_batch][i_point][i_init][1] = x_l[1]; x[i_batch][i_point][i_init][2] = x_l[2]; // J_inv[i_batch][i_point][i_init][0][0] = J00; // J_inv[i_batch][i_point][i_init][0][1] = J01; // J_inv[i_batch][i_point][i_init][0][2] = J02; // J_inv[i_batch][i_point][i_init][1][0] = J10; // J_inv[i_batch][i_point][i_init][1][1] = J11; // J_inv[i_batch][i_point][i_init][1][2] = J12; // J_inv[i_batch][i_point][i_init][2][0] = J20; // J_inv[i_batch][i_point][i_init][2][1] = J21; // J_inv[i_batch][i_point][i_init][2][2] = J22; } return; } else if(norm_gx > dvg_threshold*dvg_threshold) { is_valid[i_batch][i_point][i_init] = false; return; } // delta_x = update scalar_t delta_x_0 = u0; scalar_t delta_x_1 = u1; scalar_t delta_x_2 = u2; // delta_gx = gx_new - gx scalar_t delta_gx_0 = gx_new[0] - gx[0]; scalar_t delta_gx_1 = gx_new[1] - gx[1]; scalar_t delta_gx_2 = gx_new[2] - gx[2]; fuse_J_inv_update(index, J_inv_local, delta_x_0, delta_x_1, delta_x_2, delta_gx_0, delta_gx_1, delta_gx_2); } is_valid[i_batch][i_point][i_init] = false; return; } void launch_broyden_kernel( Tensor &x, const Tensor &xd_tgt, const Tensor &voxel, const Tensor &grid_J_inv, const Tensor &tfs, const Tensor &bone_ids, bool 
align_corners, // Tensor &J_inv, Tensor &is_valid, const Tensor& offset, const Tensor& scale, float cvg_threshold, float dvg_threshold) { // calculate #threads required int64_t n_batch = xd_tgt.size(0); int64_t n_point = xd_tgt.size(1); int64_t n_init = bone_ids.size(0); int64_t count = n_batch * n_point * n_init; if (count > 0) { AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fuse_kernel_cuda", [&] { broyden_kernel <<>>(static_cast(count), static_cast(n_batch), static_cast(n_point), static_cast(n_init), getTensorInfo(voxel), getTensorInfo(grid_J_inv), x.packed_accessor32(), xd_tgt.packed_accessor32(), voxel.packed_accessor32(), grid_J_inv.packed_accessor32(), tfs.packed_accessor32(), bone_ids.packed_accessor32(), // J_inv.packed_accessor32(), is_valid.packed_accessor32(), offset.packed_accessor32(), scale.packed_accessor32(), cvg_threshold, dvg_threshold, 0); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); } cudaDeviceSynchronize(); } ================================================ FILE: lib/models/deformers/fast_snarf/cuda/precompute/precompute.cpp ================================================ #include #include #include #include void launch_precompute(const torch::Tensor &voxel_w, const torch::Tensor &tfs, torch::Tensor &voxel_d, torch::Tensor &voxel_J, const torch::Tensor &offset, const torch::Tensor &scale); void precompute(const torch::Tensor &voxel_w, const torch::Tensor &tfs, torch::Tensor &voxel_d, torch::Tensor &voxel_J, const torch::Tensor &offset, const torch::Tensor &scale) { const at::cuda::OptionalCUDAGuard device_guard(device_of(voxel_w)); launch_precompute(voxel_w, tfs, voxel_d, voxel_J, offset, scale); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("precompute", &precompute); } ================================================ FILE: lib/models/deformers/fast_snarf/cuda/precompute/precompute_kernel.cu ================================================ #include "ATen/Functions.h" #include "ATen/core/TensorAccessor.h" #include "c10/cuda/CUDAException.h" #include "c10/cuda/CUDAStream.h" #include #include #include #include #include #include #include #include #include #include using namespace std::chrono; using namespace at; using namespace at::cuda::detail; template C10_LAUNCH_BOUNDS_1(512) __global__ void precompute_kernel( const index_t npoints, const index_t d, const index_t h, const index_t w, PackedTensorAccessor32 voxel_w, // shape=(N,200000, 9, 3) PackedTensorAccessor32 tfs, // shape=(N,200000, 3) PackedTensorAccessor32 voxel_d, // shape=(N,3,8,32,32) PackedTensorAccessor32 voxel_J, // shape=(N,9,8,32,32) PackedTensorAccessor32 offset, // shape=(N, 1, 3) PackedTensorAccessor32 scale // shape=(N, 1, 3) ) { index_t index = blockIdx.x * blockDim.x + threadIdx.x; if(index >= npoints) return; index_t idx_b = index / (d*h*w); index_t idx_d = index % (d*h*w) / (h*w); index_t idx_h = index % (d*h*w) % (h*w) / w; index_t idx_w = index % (d*h*w) % (h*w) % w; scalar_t coord_x = ( ((scalar_t)idx_w) / (w-1) * 2 -1) / scale[0][0][0] - offset[0][0][0]; scalar_t coord_y = ( ((scalar_t)idx_h) / (h-1) * 2 -1) / scale[0][0][1] - offset[0][0][1]; scalar_t coord_z = ( ((scalar_t)idx_d) / (d-1) * 2 -1) / scale[0][0][2] - offset[0][0][2]; scalar_t J[12]; for(index_t i0 = 0; i0 < 3; i0++){ for(index_t i1 = 0; i1 < 4; i1++){ J[i0*4 + i1] = 0; for(index_t j = 0; j < 24; j++){ J[i0*4 + i1] += voxel_w[0][j][idx_d][idx_h][idx_w]*tfs[idx_b][j][i0][i1]; } } } for(index_t i0 = 0; i0 < 3; i0++){ for(index_t i1 = 0; i1 < 4; i1++){ voxel_J[idx_b][i0*4 + i1][idx_d][idx_h][idx_w] = J[i0*4 + i1]; } } 
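// Having cached the skinning-weight-blended 3x4 transform J in voxel_J, apply
// the same transform to this voxel's canonical coordinate to obtain its
// deformed position, which is stored in voxel_d for later grid sampling.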
for(index_t i0 = 0; i0 < 3; i0++){ scalar_t xi = J[i0*4 + 0]*coord_x + J[i0*4 + 1]*coord_y + J[i0*4 + 2]*coord_z + J[i0*4 + 3]; voxel_d[idx_b][i0][idx_d][idx_h][idx_w] = xi; } } void launch_precompute( const Tensor &voxel_w, const Tensor &tfs, Tensor &voxel_d, Tensor &voxel_J, const Tensor &offset, const Tensor &scale ) { // calculate #threads required int64_t n_batch = voxel_d.size(0); int64_t d = voxel_d.size(2); int64_t h = voxel_d.size(3); int64_t w = voxel_d.size(4); int64_t count = n_batch*d*h*w; if (count > 0) { AT_DISPATCH_FLOATING_TYPES_AND_HALF(voxel_w.scalar_type(), "precompute", [&] { precompute_kernel <<>>(static_cast(count), static_cast(d), static_cast(h), static_cast(w), voxel_w.packed_accessor32(), tfs.packed_accessor32(), voxel_d.packed_accessor32(), voxel_J.packed_accessor32(), offset.packed_accessor32(), scale.packed_accessor32() ); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); } cudaDeviceSynchronize(); } ================================================ FILE: lib/models/deformers/fast_snarf/lib/model/deformer_smpl.py ================================================ import torch from torch import einsum import torch.nn.functional as F import os from torch.utils.cpp_extension import load import fuse_cuda import filter_cuda import precompute_cuda import numpy as np class ForwardDeformer(torch.nn.Module): """ Tensor shape abbreviation: B: batch size N: number of points J: number of bones I: number of init D: space dimension """ def __init__(self, **kwargs): super().__init__() self.soft_blend = 20 self.init_bones = [0, 1, 2, 4, 5, 12, 15, 16, 17, 18, 19] self.init_bones_cuda = torch.tensor(self.init_bones).int() self.global_scale = 1.2 def forward(self, xd, cond, mask, tfs, eval_mode=False): """Given deformed point return its caonical correspondence Args: xd (tensor): deformed points in batch. shape: [B, N, D] cond (dict): conditional input. tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1] Returns: xc (tensor): canonical correspondences. shape: [B, N, I, D] others (dict): other useful outputs. """ xc_opt, others = self.search(xd, cond, mask, tfs, eval_mode=True) if eval_mode: return xc_opt, others def precompute(self, tfs): b, c, d, h, w = tfs.shape[0], 3, self.resolution//4, self.resolution, self.resolution voxel_d = torch.zeros((b,3,d,h,w), device=tfs.device) voxel_J = torch.zeros((b,12,d,h,w), device=tfs.device) precompute_cuda.precompute(self.lbs_voxel_final, tfs, voxel_d, voxel_J, self.offset_kernel, self.scale_kernel) self.voxel_d = voxel_d self.voxel_J = voxel_J def search(self, xd, cond, mask, tfs, eval_mode=False): """Search correspondences. Args: xd (tensor): deformed points in batch. shape: [B, N, D] xc_init (tensor): deformed points in batch. shape: [B, N, I, D] cond (dict): conditional input. tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1] Returns: xc_opt (tensor): canonoical correspondences of xd. shape: [B, N, I, D] valid_ids (tensor): identifiers of converged points. 
[B, N, I] """ # reshape to [B,?,D] for other functions # run broyden without grad with torch.no_grad(): result = self.broyden_cuda(xd, self.voxel_d, self.voxel_J, tfs, mask) return result['result'], result def broyden_cuda(self, xd_tgt, voxel, voxel_J_inv, tfs, mask, cvg_thresh=2e-4, dvg_thresh=1): """ Args: g: f: (N, 3, 1) -> (N, 3, 1) x: (N, 3, 1) J_inv: (N, 3, 3) """ b,n,_ = xd_tgt.shape n_init = self.init_bones_cuda.shape[0] xc_init_IN = torch.zeros((b,n,n_init,3),device=xd_tgt.device,dtype=torch.float) is_valid = mask.expand(b,n,n_init).clone() if self.init_bones_cuda.device != xd_tgt.device: self.init_bones_cuda = self.init_bones_cuda.to(xd_tgt.device) fuse_cuda.fuse_broyden(xc_init_IN, xd_tgt, voxel, voxel_J_inv, tfs, self.init_bones_cuda, True, is_valid, self.offset_kernel, self.scale_kernel, cvg_thresh, dvg_thresh) is_valid_new = torch.zeros_like(is_valid) filter_cuda.filter(xc_init_IN, is_valid, is_valid_new) return {"result": xc_init_IN, 'valid_ids': is_valid_new} #, 'J_inv': J_inv_init_IN} def forward_skinning(self, xc, cond, tfs, mask=None): """Canonical point -> deformed point Args: xc (tensor): canonoical points in batch. shape: [B, N, D] cond (dict): conditional input. tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1] Returns: xd (tensor): deformed point. shape: [B, N, D] """ w = self.query_weights(xc, cond, mask=mask) b,n,_ = xc.shape xd = torch.zeros((b,n,3), device=xc.device, dtype=torch.float) w_tf = torch.eye(4, device=xc.device, dtype=torch.float).reshape(1, 1, 4, 4).expand(b, n, -1, -1).clone() xd[mask, :], w_tf[mask, :] = skinning_mask(xc[mask,:], w[mask,:], tfs, inverse=False) return xd, w_tf def switch_to_explicit(self,resolution=32,smpl_verts=None, smpl_faces=None, smpl_weights=None, use_smpl=False): self.resolution = resolution # convert to voxel grid # device = self.device b, c, d, h, w = 1, 24, resolution//4, resolution, resolution self.ratio = h/d grid = create_voxel_grid(d, h, w) device = grid.device gt_bbox = torch.cat([smpl_verts.min(dim=1).values, smpl_verts.max(dim=1).values], dim=0).to(device) offset = (gt_bbox[0] + gt_bbox[1])[None,None,:] * 0.5 scale = (gt_bbox[1] - gt_bbox[0]).max()/2 * self.global_scale self.register_buffer('scale', scale) self.register_buffer('offset', offset) self.register_buffer('offset_kernel', -self.offset) scale_kernel = torch.zeros_like(self.offset) scale_kernel[...] 
= 1./self.scale scale_kernel[:,:,-1] = scale_kernel[:,:,-1] * self.ratio self.register_buffer('scale_kernel', scale_kernel) def normalize(x): x_normalized = (x+self.offset_kernel)*self.scale_kernel return x_normalized def denormalize(x): x_denormalized = x.clone() #/self.global_scale x_denormalized[..., -1] = x_denormalized[..., -1]/self.ratio x_denormalized *= self.scale x_denormalized += self.offset return x_denormalized self.normalize = normalize self.denormalize = denormalize grid_denorm = self.denormalize(grid) weights = query_weights_smpl(grid_denorm, smpl_verts=smpl_verts.detach().clone(), smpl_weights=smpl_weights.detach().clone()).detach().clone() self.register_buffer('lbs_voxel_final', weights.detach()) self.register_buffer('grid_denorm',grid_denorm) def query_weights( xc, cond=None, mask=None, mode='bilinear'): w = F.grid_sample(self.lbs_voxel_final.expand(xc.shape[0],-1,-1,-1,-1), self.normalize(xc).unsqueeze(2).unsqueeze(2),align_corners=True, mode=mode,padding_mode='border') w = w.squeeze(-1).squeeze(-1).permute(0,2,1) return w self.query_weights = query_weights def update_lbs_voxel(self): self.lbs_voxel_final = F.softmax( self.lbs_voxel*20,dim=1) def query_weights( xc, cond=None, mask=None, mode='bilinear'): w = F.grid_sample(self.lbs_voxel_final.expand(xc.shape[0],-1,-1,-1,-1), self.normalize(xc).unsqueeze(2).unsqueeze(2),align_corners=True, mode=mode,padding_mode='border') w = w.squeeze(-1).squeeze(-1).permute(0,2,1) return w self.query_weights = query_weights def query_sdf_smpl(self, x, smpl_verts, smpl_faces, smpl_weights): device = x.device resolution=128 b, c, d, h, w = 1, 24, resolution//4, resolution, resolution grid = create_voxel_grid(d, h, w, device) grid = self.denormalize(grid) import trimesh mesh = trimesh.Trimesh(vertices=smpl_verts.data.cpu().numpy()[0], faces=smpl_faces.data.cpu().numpy()) BVH = cubvh.cuBVH(mesh.vertices, mesh.faces) sdf, face_id, uvw = BVH.signed_distance(grid, return_uvw=True, mode='watertight') # [N], [N], [N, 3] sdf = sdf.reshape(1, -1, 1) b, c, d, h, w = 1, 1, resolution//4, resolution, resolution sdf = -sdf.permute(0,2,1).reshape(b,c,d,h,w) return sdf.detach() def skinning_normal(self, xc, normal, tfs, cond=None, mask=None, inverse=False): ''' skinning normals Args: x (tensor): canonical points. shape: [B, N, D] normal (tensor): canonical normals. shape: [B, N, D] tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1] Returns: posed normal (tensor): posed normals. shape: [B, N, D] ''' if xc.ndim == 2: xc = xc.unsqueeze(0) if normal.ndim == 2: normal = normal.unsqueeze(0) w = self.query_weights(xc, cond, mask=mask) p_h = F.pad(normal, (0, 1), value=0) p_h = torch.einsum('bpn, bnij, bpj->bpi', w, tfs, p_h) return p_h[:, :, :3] def skinning_mask(x, w, tfs, inverse=False): """Linear blend skinning Args: x (tensor): canonical points. shape: [B, N, D] w (tensor): conditional input. [B, N, J] tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1] Returns: x (tensor): skinned points. shape: [B, N, D] """ x_h = F.pad(x, (0, 1), value=1.0) p,n = w.shape if inverse: # p:n_point, n:n_bone, i,k: n_dim+1 w_tf = einsum("bpn,bnij->bpij", w, tfs) x_h = x_h.view(b,p,1,4).expand(b,p,4,4) x_h = (fast_inverse(w_tf)*x_h).sum(-1) else: w_tf = einsum("pn,nij->pij", w, tfs.squeeze(0)) x_h = x_h.view(p,1,4).expand(p,4,4) x_h = (w_tf*x_h).sum(-1) return x_h[:, :3], w_tf def skinning(x, w, tfs, inverse=False): """Linear blend skinning Args: x (tensor): canonical points. shape: [B, N, D] w (tensor): conditional input. 
[B, N, J] tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1] Returns: x (tensor): skinned points. shape: [B, N, D] """ x_h = F.pad(x, (0, 1), value=1.0) b,p,n = w.shape if inverse: # p:n_point, n:n_bone, i,k: n_dim+1 w_tf = einsum("bpn,bnij->bpij", w, fast_inverse(tfs)) x_h = x_h.view(b,p,1,4).expand(b,p,4,4) # x_h = (fast_inverse(w_tf)*x_h).sum(-1) x_h = (w_tf*x_h).sum(-1) else: w_tf = einsum("bpn,bnij->bpij", w, tfs) x_h = x_h.view(b,p,1,4).expand(b,p,4,4) x_h = (w_tf*x_h).sum(-1) return x_h[:, :, :3] def fast_inverse(T): shape = T.shape T = T.reshape(-1,4,4) R = T[:, :3,:3] t = T[:, :3,3].unsqueeze(-1) R_inv = R.transpose(1,2) t_inv = -bmv(R_inv,t) T_inv = T T_inv[:,:3,:3] = R_inv T_inv[:,:3,3] = t_inv.squeeze(-1) return T_inv.reshape(shape) def bmv(m, v): return (m*v.transpose(-1,-2).expand(-1,3,-1)).sum(-1,keepdim=True) def create_voxel_grid(d, h, w, device='cuda'): x_range = (torch.linspace(-1,1,steps=w,device=device)).view(1, 1, 1, w).expand(1, d, h, w) # [1, H, W, D] y_range = (torch.linspace(-1,1,steps=h,device=device)).view(1, 1, h, 1).expand(1, d, h, w) # [1, H, W, D] z_range = (torch.linspace(-1,1,steps=d,device=device)).view(1, d, 1, 1).expand(1, d, h, w) # [1, H, W, D] grid = torch.cat((x_range, y_range, z_range), dim=0).reshape(1, 3,-1).permute(0,2,1) return grid def query_weights_smpl(x, smpl_verts, smpl_weights): import pytorch3d.ops as ops device = smpl_weights.device distance_batch, index_batch, neighbor_points = ops.knn_points(x.to(device),smpl_verts.to(device).detach(),K=10,return_nn=True) # neighbor_points = neighbor_points[0] distance_batch = distance_batch[0].sqrt().clamp_(0.00003,0.1) index_batch = index_batch[0] # GPU_id = index_batch.get_device() # print(GPU_id) weights = smpl_weights[0,index_batch] ws=1./distance_batch ws=ws/ws.sum(-1,keepdim=True) weights = (ws[:,:,None]*weights).sum(1)[None] resolution = 64 b, c, d, h, w = 1, 24, resolution//4, resolution, resolution weights = weights.permute(0,2,1).reshape(b,c,d,h,w) return weights.detach() ================================================ FILE: lib/models/deformers/fast_snarf/lib/model/deformer_smplx.py ================================================ import torch from torch import einsum import torch.nn.functional as F import os from torch.utils.cpp_extension import load import fuse_cuda import filter_cuda import precompute_cuda import numpy as np class ForwardDeformer(torch.nn.Module): """ Tensor shape abbreviation: B: batch size N: number of points J: number of bones I: number of init D: space dimension """ def __init__(self, **kwargs): super().__init__() self.soft_blend = 20 self.init_bones = [0, 1, 2, 4, 5, 12, 15, 16, 17, 18, 19] self.init_bones_cuda = torch.tensor(self.init_bones).int() self.global_scale = 1.2 def forward_skinning(self, xc, shape_offset, pose_offset, cond, tfs, tfs_inv, poseoff_ori, lbsw=None, mask=None, pts_query_lbs=None): """Canonical point -> deformed point B == num_scenes Args: xc (tensor): canonoical points in batch. shape: [B, N, D] cond (dict): conditional input. tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1] template -> posed tfs_inv (tensor): inverse bone transformation matrices. shape: [B, J, D+1, D+1] T-posed -> template pts_query_lbs (tensor): canonoical points in batch. shape: [B, N, D] , for Lbs weights Returns: xd (tensor): deformed point. 
shape: [B, N, D] # """ if pts_query_lbs==None: w = self.query_weights(xc, cond) else: w = self.query_weights(pts_query_lbs, cond) w[:, mask[0]] = lbsw[mask] ''' # saving the points xc[1, 50200, 3] with the colored mask [1, 50200] in a ply import trimesh mesh = trimesh.Trimesh(vertices=xc[0].detach().cpu().numpy()) mesh.visual.vertex_colors = mask[0].cpu().numpy().reshape(-1,1) * np.array([[255,0,0,255]]) mesh.export("./xc_mask.ply") ''' b,n,_ = xc.shape xc_cano, w_tf_inv = skinning(xc, w, tfs_inv.expand(b, -1, -1, -1), inverse=False) # T pose space -> template space xc_cano_ori = xc_cano - poseoff_ori.expand(b, -1, -1) xc_shape = xc_cano_ori + shape_offset + pose_offset # template space, use the given points in the template space to forwarding # b,n,_ = xc_shape.shape x_deform, w_tf = skinning(xc_shape, w, tfs, inverse=False) w_tf_all = w_tf @ w_tf_inv.expand(b, -1, -1, -1) # from T-pose to posed space return x_deform, w_tf_all def switch_to_explicit(self,resolution=32,smpl_verts=None, smpl_faces=None, smpl_weights=None, use_smpl=False): self.resolution = resolution # convert to voxel grid b, c, d, h, w = 1, 55, resolution//4, resolution, resolution self.ratio = h/d grid = create_voxel_grid(d, h, w) device = grid.device gt_bbox = torch.cat([smpl_verts.min(dim=1).values, smpl_verts.max(dim=1).values], dim=0).to(device) offset = (gt_bbox[0] + gt_bbox[1])[None,None,:] * 0.5 scale = (gt_bbox[1] - gt_bbox[0]).max()/2 * self.global_scale self.register_buffer('scale', scale) self.register_buffer('offset', offset) self.register_buffer('offset_kernel', -self.offset) scale_kernel = torch.zeros_like(self.offset) scale_kernel[...] = 1./self.scale scale_kernel[:,:,-1] = scale_kernel[:,:,-1] * self.ratio self.register_buffer('scale_kernel', scale_kernel) def normalize(x): x_normalized = (x+self.offset_kernel)*self.scale_kernel return x_normalized def denormalize(x): x_denormalized = x.clone() x_denormalized[..., -1] = x_denormalized[..., -1]/self.ratio x_denormalized *= self.scale x_denormalized += self.offset return x_denormalized self.normalize = normalize self.denormalize = denormalize grid_denorm = self.denormalize(grid) weights = query_weights_smpl(grid_denorm, smpl_verts=smpl_verts.detach().clone(), smpl_weights=smpl_weights.detach().clone()).detach().clone() self.register_buffer('lbs_voxel_final', weights.detach()) self.register_buffer('grid_denorm',grid_denorm) def query_weights( xc, cond=None, mask=None, mode='bilinear'): w = F.grid_sample(self.lbs_voxel_final.expand(xc.shape[0],-1,-1,-1,-1), self.normalize(xc).unsqueeze(2).unsqueeze(2),align_corners=True, mode=mode,padding_mode='border') w = w.squeeze(-1).squeeze(-1).permute(0,2,1) return w self.query_weights = query_weights def update_lbs_voxel(self): self.lbs_voxel_final = F.softmax( self.lbs_voxel*20,dim=1) def query_weights( xc, cond=None, mask=None, mode='bilinear'): w = F.grid_sample(self.lbs_voxel_final.expand(xc.shape[0],-1,-1,-1,-1), self.normalize(xc).unsqueeze(2).unsqueeze(2),align_corners=True, mode=mode,padding_mode='border') w = w.squeeze(-1).squeeze(-1).permute(0,2,1) return w self.query_weights = query_weights def query_sdf_smpl(self, x, smpl_verts, smpl_faces, smpl_weights): device = x.device resolution=128 b, c, d, h, w = 1, 24, resolution//4, resolution, resolution grid = create_voxel_grid(d, h, w, device) grid = self.denormalize(grid) import trimesh mesh = trimesh.Trimesh(vertices=smpl_verts.data.cpu().numpy()[0], faces=smpl_faces.data.cpu().numpy()) BVH = cubvh.cuBVH(mesh.vertices, mesh.faces) sdf, face_id, uvw = 
BVH.signed_distance(grid, return_uvw=True, mode='watertight') # [N], [N], [N, 3] sdf = sdf.reshape(1, -1, 1) b, c, d, h, w = 1, 1, resolution//4, resolution, resolution sdf = -sdf.permute(0,2,1).reshape(b,c,d,h,w) return sdf.detach() def skinning_normal(self, xc, normal, tfs, cond=None, mask=None, inverse=False): ''' skinning normals Args: x (tensor): canonical points. shape: [B, N, D] normal (tensor): canonical normals. shape: [B, N, D] tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1] Returns: posed normal (tensor): posed normals. shape: [B, N, D] ''' if xc.ndim == 2: xc = xc.unsqueeze(0) if normal.ndim == 2: normal = normal.unsqueeze(0) w = self.query_weights(xc, cond, mask=mask) p_h = F.pad(normal, (0, 1), value=0) p_h = torch.einsum('bpn, bnij, bpj->bpi', w, tfs, p_h) return p_h[:, :, :3] def skinning_mask(x, w, tfs, inverse=False): """Linear blend skinning Args: x (tensor): canonical points. shape: [B, N, D] w (tensor): conditional input. [B, N, J] tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1] Returns: x (tensor): skinned points. shape: [B, N, D] """ x_h = F.pad(x, (0, 1), value=1.0) p,n = w.shape if inverse: # p:n_point, n:n_bone, i,k: n_dim+1 w_tf = einsum("bpn,bnij->bpij", w, tfs) x_h = x_h.view(b,p,1,4).expand(b,p,4,4) x_h = (fast_inverse(w_tf)*x_h).sum(-1) else: w_tf = einsum("pn,nij->pij", w, tfs.squeeze(0)) x_h = x_h.view(p,1,4).expand(p,4,4) x_h = (w_tf*x_h).sum(-1) return x_h[:, :3], w_tf def skinning(x, w, tfs, inverse=False): """Linear blend skinning Args: x (tensor): canonical points. shape: [B, N, D] w (tensor): conditional input. [B, N, J] tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1] Returns: x (tensor): skinned points. shape: [B, N, D] """ x_h = F.pad(x, (0, 1), value=1.0) b,p,n = w.shape if inverse: # p:n_point, n:n_bone, i,k: n_dim+1 w_tf = einsum("bpn,bnij->bpij", w, fast_inverse(tfs)) x_h = x_h.view(b,p,1,4).expand(b,p,4,4) # x_h = (fast_inverse(w_tf)*x_h).sum(-1) x_h = (w_tf*x_h).sum(-1) else: w_tf = einsum("bpn,bnij->bpij", w, tfs) x_h = x_h.view(b,p,1,4).expand(b,p,4,4) x_h = (w_tf*x_h).sum(-1) return x_h[:, :, :3], w_tf def fast_inverse(T): shape = T.shape T = T.reshape(-1,4,4) R = T[:, :3,:3] t = T[:, :3,3].unsqueeze(-1) R_inv = R.transpose(1,2) t_inv = -bmv(R_inv,t) T_inv = T T_inv[:,:3,:3] = R_inv T_inv[:,:3,3] = t_inv.squeeze(-1) return T_inv.reshape(shape) def bmv(m, v): return (m*v.transpose(-1,-2).expand(-1,3,-1)).sum(-1,keepdim=True) def create_voxel_grid(d, h, w, device='cuda'): x_range = (torch.linspace(-1,1,steps=w,device=device)).view(1, 1, 1, w).expand(1, d, h, w) # [1, H, W, D] y_range = (torch.linspace(-1,1,steps=h,device=device)).view(1, 1, h, 1).expand(1, d, h, w) # [1, H, W, D] z_range = (torch.linspace(-1,1,steps=d,device=device)).view(1, d, 1, 1).expand(1, d, h, w) # [1, H, W, D] grid = torch.cat((x_range, y_range, z_range), dim=0).reshape(1, 3,-1).permute(0,2,1) return grid def query_weights_smpl(x, smpl_verts, smpl_weights): import pytorch3d.ops as ops device = smpl_weights.device distance_batch, index_batch, neighbor_points = ops.knn_points(x.to(device),smpl_verts.to(device).detach(),K=10,return_nn=True) # neighbor_points = neighbor_points[0] distance_batch = distance_batch[0].sqrt().clamp_(0.00003,0.1) index_batch = index_batch[0] weights = smpl_weights[0,index_batch] ws=1./distance_batch ws=ws/ws.sum(-1,keepdim=True) weights = (ws[:,:,None]*weights).sum(1)[None] resolution = 64 b, c, d, h, w = 1, 55, resolution//4, resolution, resolution weights = 
weights.permute(0,2,1).reshape(b,c,d,h,w) return weights.detach() ================================================ FILE: lib/models/deformers/smplx/__init__.py ================================================ # -*- coding: utf-8 -*- # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is # holder of all proprietary rights on this computer program. # You can only use this computer program if you have closed # a license agreement with MPG or you get the right to use the computer # program from someone who is authorized to grant you that right. # Any use of the computer program without a valid license is prohibited and # liable to prosecution. # # Copyright©2019 Max-Planck-Gesellschaft zur Förderung # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute # for Intelligent Systems. All rights reserved. # # Contact: ps-license@tuebingen.mpg.de from .body_models import ( SMPL, SMPLX ) ================================================ FILE: lib/models/deformers/smplx/body_models.py ================================================ # -*- coding: utf-8 -*- # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is # holder of all proprietary rights on this computer program. # You can only use this computer program if you have closed # a license agreement with MPG or you get the right to use the computer # program from someone who is authorized to grant you that right. # Any use of the computer program without a valid license is prohibited and # liable to prosecution. # # Copyright©2019 Max-Planck-Gesellschaft zur Förderung # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute # for Intelligent Systems. All rights reserved. # # Contact: ps-license@tuebingen.mpg.de from typing import Optional, Dict, Union import os import os.path as osp import pickle import numpy as np import torch import torch.nn as nn from .lbs import ( lbs, blend_shapes, vertices2landmarks) from .vertex_ids import vertex_ids as VERTEX_IDS from .utils import ( Struct, to_np, to_tensor, Tensor, Array, SMPLOutput, SMPLHOutput, SMPLXOutput) from .vertex_joint_selector import VertexJointSelector class SMPL(nn.Module): NUM_JOINTS = 23 NUM_BODY_JOINTS = 23 SHAPE_SPACE_DIM = 300 def __init__( self, model_path: str, kid_template_path: str = '', data_struct: Optional[Struct] = None, create_betas: bool = True, betas: Optional[Tensor] = None, num_betas: int = 10, create_global_orient: bool = True, global_orient: Optional[Tensor] = None, create_body_pose: bool = True, body_pose: Optional[Tensor] = None, create_transl: bool = True, transl: Optional[Tensor] = None, dtype=torch.float32, batch_size: int = 1, joint_mapper=None, gender: str = 'neutral', age: str = 'adult', vertex_ids: Dict[str, int] = None, v_template: Optional[Union[Tensor, Array]] = None, **kwargs ) -> None: ''' SMPL model constructor Parameters ---------- model_path: str The path to the folder or to the file where the model parameters are stored data_struct: Strct A struct object. If given, then the parameters of the model are read from the object. Otherwise, the model tries to read the parameters from the given `model_path`. (default = None) create_global_orient: bool, optional Flag for creating a member variable for the global orientation of the body. (default = True) global_orient: torch.tensor, optional, Bx3 The default value for the global orientation variable. (default = None) create_body_pose: bool, optional Flag for creating a member variable for the pose of the body. 
(default = True) body_pose: torch.tensor, optional, Bx(Body Joints * 3) The default value for the body pose variable. (default = None) num_betas: int, optional Number of shape components to use (default = 10). create_betas: bool, optional Flag for creating a member variable for the shape space (default = True). betas: torch.tensor, optional, Bx10 The default value for the shape member variable. (default = None) create_transl: bool, optional Flag for creating a member variable for the translation of the body. (default = True) transl: torch.tensor, optional, Bx3 The default value for the transl variable. (default = None) dtype: torch.dtype, optional The data type for the created variables batch_size: int, optional The batch size used for creating the member variables joint_mapper: object, optional An object that re-maps the joints. Useful if one wants to re-order the SMPL joints to some other convention (e.g. MSCOCO) (default = None) gender: str, optional Which gender to load vertex_ids: dict, optional A dictionary containing the indices of the extra vertices that will be selected ''' self.gender = gender self.age = age if data_struct is None: if osp.isdir(model_path): model_fn = 'SMPL_{}.{ext}'.format(gender.upper(), ext='pkl') smpl_path = os.path.join(model_path, model_fn) else: smpl_path = model_path assert osp.exists(smpl_path), 'Path {} does not exist!'.format( smpl_path) with open(smpl_path, 'rb') as smpl_file: data_struct = Struct(**pickle.load(smpl_file, encoding='latin1')) super(SMPL, self).__init__() self.batch_size = batch_size shapedirs = data_struct.shapedirs if (shapedirs.shape[-1] < self.SHAPE_SPACE_DIM): print(f'WARNING: You are using a {self.name()} model, with only' ' 10 shape coefficients.') num_betas = min(num_betas, 10) else: num_betas = min(num_betas, self.SHAPE_SPACE_DIM) if self.age=='kid': v_template_smil = np.load(kid_template_path) v_template_smil -= np.mean(v_template_smil, axis=0) v_template_diff = np.expand_dims(v_template_smil - data_struct.v_template, axis=2) shapedirs = np.concatenate((shapedirs[:, :, :num_betas], v_template_diff), axis=2) num_betas = num_betas + 1 self._num_betas = num_betas shapedirs = shapedirs[:, :, :num_betas] # The shape components self.register_buffer( 'shapedirs', to_tensor(to_np(shapedirs), dtype=dtype)) if vertex_ids is None: # SMPL and SMPL-H share the same topology, so any extra joints can # be drawn from the same place vertex_ids = VERTEX_IDS['smplh'] self.dtype = dtype self.joint_mapper = joint_mapper self.vertex_joint_selector = VertexJointSelector( vertex_ids=vertex_ids, **kwargs) self.faces = data_struct.f self.register_buffer('faces_tensor', to_tensor(to_np(self.faces, dtype=np.int64), dtype=torch.long)) if create_betas: if betas is None: default_betas = torch.zeros( [batch_size, self.num_betas], dtype=dtype) else: if torch.is_tensor(betas): default_betas = betas.clone().detach() else: default_betas = torch.tensor(betas, dtype=dtype) self.register_parameter( 'betas', nn.Parameter(default_betas, requires_grad=True)) # The tensor that contains the global rotation of the model # It is separated from the pose of the joints in case we wish to # optimize only over one of them if create_global_orient: if global_orient is None: default_global_orient = torch.zeros( [batch_size, 3], dtype=dtype) else: if torch.is_tensor(global_orient): default_global_orient = global_orient.clone().detach() else: default_global_orient = torch.tensor( global_orient, dtype=dtype) global_orient = nn.Parameter(default_global_orient, requires_grad=True) 
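# Illustrative aside (toy module, not part of SMPL): every create_* flag in this
# constructor registers a learnable default via register_parameter, and forward()
# falls back to that default whenever the corresponding argument is omitted.
# A minimal standalone sketch of the same pattern:
import torch
import torch.nn as nn

class _PoseDefaults(nn.Module):
    def __init__(self, batch_size=1):
        super().__init__()
        # analogous to the create_global_orient branch above
        self.register_parameter(
            'global_orient', nn.Parameter(torch.zeros(batch_size, 3)))

    def forward(self, global_orient=None):
        return global_orient if global_orient is not None else self.global_orient

_m = _PoseDefaults()
assert _m(None).shape == (1, 3)              # falls back to the registered default
assert _m(torch.ones(2, 3)).shape == (2, 3)  # an explicit argument overrides it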
self.register_parameter('global_orient', global_orient) if create_body_pose: if body_pose is None: default_body_pose = torch.zeros( [batch_size, self.NUM_BODY_JOINTS * 3], dtype=dtype) else: if torch.is_tensor(body_pose): default_body_pose = body_pose.clone().detach() else: default_body_pose = torch.tensor(body_pose, dtype=dtype) self.register_parameter( 'body_pose', nn.Parameter(default_body_pose, requires_grad=True)) if create_transl: if transl is None: default_transl = torch.zeros([batch_size, 3], dtype=dtype, requires_grad=True) else: default_transl = torch.tensor(transl, dtype=dtype) self.register_parameter( 'transl', nn.Parameter(default_transl, requires_grad=True)) if v_template is None: v_template = data_struct.v_template if not torch.is_tensor(v_template): v_template = to_tensor(to_np(v_template), dtype=dtype) # The vertices of the template model self.register_buffer('v_template', v_template) j_regressor = to_tensor(to_np( data_struct.J_regressor), dtype=dtype) self.register_buffer('J_regressor', j_regressor) # Pose blend shape basis: 6890 x 3 x 207, reshaped to 6890*3 x 207 num_pose_basis = data_struct.posedirs.shape[-1] # 207 x 20670 posedirs = np.reshape(data_struct.posedirs, [-1, num_pose_basis]).T self.register_buffer('posedirs', to_tensor(to_np(posedirs), dtype=dtype)) # indices of parents for each joints parents = to_tensor(to_np(data_struct.kintree_table[0])).long() parents[0] = -1 self.register_buffer('parents', parents) lbs_weights = to_tensor(to_np(data_struct.weights), dtype=dtype) self.register_buffer('lbs_weights', lbs_weights) @property def num_betas(self): return self._num_betas @property def num_expression_coeffs(self): return 0 def create_mean_pose(self, data_struct) -> Tensor: pass def name(self) -> str: return 'SMPL' @torch.no_grad() def reset_params(self, **params_dict) -> None: for param_name, param in self.named_parameters(): if param_name in params_dict: param[:] = torch.tensor(params_dict[param_name]) else: param.fill_(0) def get_num_verts(self) -> int: return self.v_template.shape[0] def get_num_faces(self) -> int: return self.faces.shape[0] def extra_repr(self) -> str: msg = [ f'Gender: {self.gender.upper()}', f'Number of joints: {self.J_regressor.shape[0]}', f'Betas: {self.num_betas}', ] return '\n'.join(msg) def forward_shape( self, betas: Optional[Tensor] = None, ) -> SMPLOutput: betas = betas if betas is not None else self.betas v_shaped = self.v_template + blend_shapes(betas, self.shapedirs) return SMPLOutput(vertices=v_shaped, betas=betas, v_shaped=v_shaped) def forward( self, betas: Optional[Tensor] = None, body_pose: Optional[Tensor] = None, global_orient: Optional[Tensor] = None, transl: Optional[Tensor] = None, return_verts=True, return_full_pose: bool = False, pose2rot: bool = True, scale: Optional[Tensor] = None, **kwargs ) -> SMPLOutput: ''' Forward pass for the SMPL model Parameters ---------- global_orient: torch.tensor, optional, shape Bx3 If given, ignore the member variable and use it as the global rotation of the body. Useful if someone wishes to predicts this with an external model. (default=None) betas: torch.tensor, optional, shape BxN_b If given, ignore the member variable `betas` and use it instead. For example, it can used if shape parameters `betas` are predicted from some external model. (default=None) body_pose: torch.tensor, optional, shape Bx(J*3) If given, ignore the member variable `body_pose` and use it instead. 
For example, it can used if someone predicts the pose of the body joints are predicted from some external model. It should be a tensor that contains joint rotations in axis-angle format. (default=None) transl: torch.tensor, optional, shape Bx3 If given, ignore the member variable `transl` and use it instead. For example, it can used if the translation `transl` is predicted from some external model. (default=None) return_verts: bool, optional Return the vertices. (default=True) return_full_pose: bool, optional Returns the full axis-angle pose vector (default=False) Returns ------- ''' # If no shape and pose parameters are passed along, then use the # ones from the module global_orient = (global_orient if global_orient is not None else self.global_orient) body_pose = body_pose if body_pose is not None else self.body_pose betas = betas if betas is not None else self.betas apply_trans = transl is not None or hasattr(self, 'transl') if transl is None and hasattr(self, 'transl'): transl = self.transl full_pose = torch.cat([global_orient, body_pose], dim=1) scale = scale if scale is not None else torch.ones([global_orient.shape[0], 1], dtype=global_orient.dtype, device = global_orient.device) batch_size = max(betas.shape[0], global_orient.shape[0], body_pose.shape[0]) if betas.shape[0] != batch_size: num_repeats = int(batch_size / betas.shape[0]) betas = betas.expand(num_repeats, -1) vertices, joints, A, T, shape_offset, pose_offset = lbs( betas, full_pose, self.v_template, self.shapedirs, self.posedirs, self.J_regressor, self.parents, self.lbs_weights, pose2rot=pose2rot) joints = self.vertex_joint_selector(vertices, joints) # Map the joints to the current dataset if self.joint_mapper is not None: joints = self.joint_mapper(joints) if apply_trans: joints += transl.unsqueeze(dim=1) vertices += transl.unsqueeze(dim=1) A[..., :3, 3] += transl.unsqueeze(dim=1) T[..., :3, 3] += transl.unsqueeze(dim=1) joints = joints * (scale.reshape(-1,1,1)) vertices = vertices * (scale.reshape(-1,1,1)) A[..., :3,:3] = A[..., :3,:3] * (scale.reshape(-1, 1,1,1)) T[..., :3,:3] = T[..., :3,:3] * (scale.reshape(-1,1,1,1)) output = SMPLOutput(vertices=vertices if return_verts else None, global_orient=global_orient, body_pose=body_pose, joints=joints, betas=betas, full_pose=full_pose if return_full_pose else None, A=A, T=T, shape_offset=shape_offset, pose_offset=pose_offset) return output class SMPLH(SMPL): # The hand joints are replaced by MANO NUM_BODY_JOINTS = SMPL.NUM_JOINTS - 2 NUM_HAND_JOINTS = 15 NUM_JOINTS = NUM_BODY_JOINTS + 2 * NUM_HAND_JOINTS def __init__( self, model_path, kid_template_path: str = '', data_struct: Optional[Struct] = None, create_left_hand_pose: bool = True, left_hand_pose: Optional[Tensor] = None, create_right_hand_pose: bool = True, right_hand_pose: Optional[Tensor] = None, use_pca: bool = True, num_pca_comps: int = 6, num_betas=16, flat_hand_mean: bool = False, batch_size: int = 1, gender: str = 'neutral', age: str = 'adult', dtype=torch.float32, vertex_ids=None, use_compressed: bool = True, ext: str = 'pkl', **kwargs ) -> None: ''' SMPLH model constructor Parameters ---------- model_path: str The path to the folder or to the file where the model parameters are stored data_struct: Strct A struct object. If given, then the parameters of the model are read from the object. Otherwise, the model tries to read the parameters from the given `model_path`. (default = None) create_left_hand_pose: bool, optional Flag for creating a member variable for the pose of the left hand. 
(default = True) left_hand_pose: torch.tensor, optional, BxP The default value for the left hand pose member variable. (default = None) create_right_hand_pose: bool, optional Flag for creating a member variable for the pose of the right hand. (default = True) right_hand_pose: torch.tensor, optional, BxP The default value for the right hand pose member variable. (default = None) num_pca_comps: int, optional The number of PCA components to use for each hand. (default = 6) flat_hand_mean: bool, optional If False, then the pose of the hand is initialized to False. batch_size: int, optional The batch size used for creating the member variables gender: str, optional Which gender to load dtype: torch.dtype, optional The data type for the created variables vertex_ids: dict, optional A dictionary containing the indices of the extra vertices that will be selected ''' self.num_pca_comps = num_pca_comps # If no data structure is passed, then load the data from the given # model folder if data_struct is None: # Load the model if osp.isdir(model_path): model_fn = 'SMPLH_{}.{ext}'.format(gender.upper(), ext=ext) smplh_path = os.path.join(model_path, model_fn) else: smplh_path = model_path assert osp.exists(smplh_path), 'Path {} does not exist!'.format( smplh_path) if ext == 'pkl': with open(smplh_path, 'rb') as smplh_file: model_data = pickle.load(smplh_file, encoding='latin1') elif ext == 'npz': model_data = np.load(smplh_path, allow_pickle=True) else: raise ValueError('Unknown extension: {}'.format(ext)) data_struct = Struct(**model_data) if vertex_ids is None: vertex_ids = VERTEX_IDS['smplh'] super(SMPLH, self).__init__( model_path=model_path, kid_template_path=kid_template_path, data_struct=data_struct, num_betas=num_betas, batch_size=batch_size, vertex_ids=vertex_ids, gender=gender, age=age, use_compressed=use_compressed, dtype=dtype, ext=ext, **kwargs) self.use_pca = use_pca self.num_pca_comps = num_pca_comps self.flat_hand_mean = flat_hand_mean left_hand_components = data_struct.hands_componentsl[:num_pca_comps] right_hand_components = data_struct.hands_componentsr[:num_pca_comps] self.np_left_hand_components = left_hand_components self.np_right_hand_components = right_hand_components if self.use_pca: self.register_buffer( 'left_hand_components', torch.tensor(left_hand_components, dtype=dtype)) self.register_buffer( 'right_hand_components', torch.tensor(right_hand_components, dtype=dtype)) if self.flat_hand_mean: left_hand_mean = np.zeros_like(data_struct.hands_meanl) else: left_hand_mean = data_struct.hands_meanl if self.flat_hand_mean: right_hand_mean = np.zeros_like(data_struct.hands_meanr) else: right_hand_mean = data_struct.hands_meanr self.register_buffer('left_hand_mean', to_tensor(left_hand_mean, dtype=self.dtype)) self.register_buffer('right_hand_mean', to_tensor(right_hand_mean, dtype=self.dtype)) # Create the buffers for the pose of the left hand hand_pose_dim = num_pca_comps if use_pca else 3 * self.NUM_HAND_JOINTS if create_left_hand_pose: if left_hand_pose is None: default_lhand_pose = torch.zeros([batch_size, hand_pose_dim], dtype=dtype) else: default_lhand_pose = torch.tensor(left_hand_pose, dtype=dtype) left_hand_pose_param = nn.Parameter(default_lhand_pose, requires_grad=True) self.register_parameter('left_hand_pose', left_hand_pose_param) if create_right_hand_pose: if right_hand_pose is None: default_rhand_pose = torch.zeros([batch_size, hand_pose_dim], dtype=dtype) else: default_rhand_pose = torch.tensor(right_hand_pose, dtype=dtype) right_hand_pose_param = 
nn.Parameter(default_rhand_pose, requires_grad=True) self.register_parameter('right_hand_pose', right_hand_pose_param) # Create the buffer for the mean pose. pose_mean_tensor = self.create_mean_pose( data_struct, flat_hand_mean=flat_hand_mean) if not torch.is_tensor(pose_mean_tensor): pose_mean_tensor = torch.tensor(pose_mean_tensor, dtype=dtype) self.register_buffer('pose_mean', pose_mean_tensor) def create_mean_pose(self, data_struct, flat_hand_mean=False): # Create the array for the mean pose. If flat_hand is false, then use # the mean that is given by the data, rather than the flat open hand global_orient_mean = torch.zeros([3], dtype=self.dtype) body_pose_mean = torch.zeros([self.NUM_BODY_JOINTS * 3], dtype=self.dtype) pose_mean = torch.cat([global_orient_mean, body_pose_mean, self.left_hand_mean, self.right_hand_mean], dim=0) return pose_mean def name(self) -> str: return 'SMPL+H' def extra_repr(self): msg = super(SMPLH, self).extra_repr() msg = [msg] if self.use_pca: msg.append(f'Number of PCA components: {self.num_pca_comps}') msg.append(f'Flat hand mean: {self.flat_hand_mean}') return '\n'.join(msg) def forward( self, betas: Optional[Tensor] = None, global_orient: Optional[Tensor] = None, body_pose: Optional[Tensor] = None, left_hand_pose: Optional[Tensor] = None, right_hand_pose: Optional[Tensor] = None, transl: Optional[Tensor] = None, return_verts: bool = True, return_full_pose: bool = False, pose2rot: bool = True, **kwargs ) -> SMPLHOutput: ''' ''' # If no shape and pose parameters are passed along, then use the # ones from the module global_orient = (global_orient if global_orient is not None else self.global_orient) body_pose = body_pose if body_pose is not None else self.body_pose betas = betas if betas is not None else self.betas left_hand_pose = (left_hand_pose if left_hand_pose is not None else self.left_hand_pose) right_hand_pose = (right_hand_pose if right_hand_pose is not None else self.right_hand_pose) apply_trans = transl is not None or hasattr(self, 'transl') if transl is None: if hasattr(self, 'transl'): transl = self.transl if self.use_pca: left_hand_pose = torch.einsum( 'bi,ij->bj', [left_hand_pose, self.left_hand_components]) right_hand_pose = torch.einsum( 'bi,ij->bj', [right_hand_pose, self.right_hand_components]) full_pose = torch.cat([global_orient, body_pose, left_hand_pose, right_hand_pose], dim=1) full_pose += self.pose_mean vertices, joints = lbs(betas, full_pose, self.v_template, self.shapedirs, self.posedirs, self.J_regressor, self.parents, self.lbs_weights, pose2rot=pose2rot) # Add any extra joints that might be needed joints = self.vertex_joint_selector(vertices, joints) if self.joint_mapper is not None: joints = self.joint_mapper(joints) if apply_trans: joints += transl.unsqueeze(dim=1) vertices += transl.unsqueeze(dim=1) output = SMPLHOutput(vertices=vertices if return_verts else None, joints=joints, betas=betas, global_orient=global_orient, body_pose=body_pose, left_hand_pose=left_hand_pose, right_hand_pose=right_hand_pose, full_pose=full_pose if return_full_pose else None) return output class SMPLX(SMPLH): ''' SMPL-X (SMPL eXpressive) is a unified body model, with shape parameters trained jointly for the face, hands and body. SMPL-X uses standard vertex based linear blend skinning with learned corrective blend shapes, has N=10475 vertices and K=54 joints, which includes joints for the neck, jaw, eyeballs and fingers. 
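Example (illustrative only; the SMPL-X model file must be downloaded separately, and
the flags below roughly mirror the construction used by SMPLXDeformer_gender in this repo):

    import torch
    model = SMPLX('lib/models/deformers/smplx/SMPLX', gender='neutral',
                  use_pca=True, num_pca_comps=12, num_betas=10, ext='pkl')
    output = model(betas=torch.zeros(1, 10),
                   global_orient=torch.zeros(1, 3),
                   body_pose=torch.zeros(1, 63),
                   left_hand_pose=torch.zeros(1, 12),
                   right_hand_pose=torch.zeros(1, 12),
                   transl=torch.zeros(1, 3))
    # output.vertices: [1, 10475, 3]; output.A / output.T carry the per-joint and
    # per-vertex transformation matrices consumed by the deformers in this repo.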
''' NUM_BODY_JOINTS = SMPLH.NUM_BODY_JOINTS NUM_HAND_JOINTS = 15 NUM_FACE_JOINTS = 3 NUM_JOINTS = NUM_BODY_JOINTS + 2 * NUM_HAND_JOINTS + NUM_FACE_JOINTS EXPRESSION_SPACE_DIM = 100 NECK_IDX = 12 def __init__( self, model_path: str, kid_template_path: str = '', num_expression_coeffs: int = 10, create_expression: bool = True, expression: Optional[Tensor] = None, create_jaw_pose: bool = True, jaw_pose: Optional[Tensor] = None, create_leye_pose: bool = True, leye_pose: Optional[Tensor] = None, create_reye_pose=True, reye_pose: Optional[Tensor] = None, use_face_contour: bool = False, batch_size: int = 1, gender: str = 'neutral', age: str = 'adult', dtype=torch.float32, ext: str = 'npz', **kwargs ) -> None: ''' SMPLX model constructor Parameters ---------- model_path: str The path to the folder or to the file where the model parameters are stored num_expression_coeffs: int, optional Number of expression components to use (default = 10). create_expression: bool, optional Flag for creating a member variable for the expression space (default = True). expression: torch.tensor, optional, Bx10 The default value for the expression member variable. (default = None) create_jaw_pose: bool, optional Flag for creating a member variable for the jaw pose. (default = False) jaw_pose: torch.tensor, optional, Bx3 The default value for the jaw pose variable. (default = None) create_leye_pose: bool, optional Flag for creating a member variable for the left eye pose. (default = False) leye_pose: torch.tensor, optional, Bx10 The default value for the left eye pose variable. (default = None) create_reye_pose: bool, optional Flag for creating a member variable for the right eye pose. (default = False) reye_pose: torch.tensor, optional, Bx10 The default value for the right eye pose variable. 
(default = None) use_face_contour: bool, optional Whether to compute the keypoints that form the facial contour batch_size: int, optional The batch size used for creating the member variables gender: str, optional Which gender to load dtype: torch.dtype The data type for the created variables ''' # Load the model if osp.isdir(model_path): model_fn = 'SMPLX_{}.{ext}'.format(gender.upper(), ext=ext) smplx_path = os.path.join(model_path, model_fn) else: smplx_path = model_path assert osp.exists(smplx_path), 'Path {} does not exist!'.format( smplx_path) if ext == 'pkl': with open(smplx_path, 'rb') as smplx_file: model_data = pickle.load(smplx_file, encoding='latin1') elif ext == 'npz': model_data = np.load(smplx_path, allow_pickle=True) else: raise ValueError('Unknown extension: {}'.format(ext)) data_struct = Struct(**model_data) super(SMPLX, self).__init__( model_path=model_path, kid_template_path=kid_template_path, data_struct=data_struct, dtype=dtype, batch_size=batch_size, vertex_ids=VERTEX_IDS['smplx'], gender=gender, age=age, ext=ext, **kwargs) lmk_faces_idx = data_struct.lmk_faces_idx self.register_buffer('lmk_faces_idx', torch.tensor(lmk_faces_idx, dtype=torch.long)) lmk_bary_coords = data_struct.lmk_bary_coords self.register_buffer('lmk_bary_coords', torch.tensor(lmk_bary_coords, dtype=dtype)) self.bone_parents = to_np(data_struct.kintree_table[0]) self.use_face_contour = use_face_contour if self.use_face_contour: dynamic_lmk_faces_idx = data_struct.dynamic_lmk_faces_idx dynamic_lmk_faces_idx = torch.tensor( dynamic_lmk_faces_idx, dtype=torch.long) self.register_buffer('dynamic_lmk_faces_idx', dynamic_lmk_faces_idx) dynamic_lmk_bary_coords = data_struct.dynamic_lmk_bary_coords dynamic_lmk_bary_coords = torch.tensor( dynamic_lmk_bary_coords, dtype=dtype) self.register_buffer('dynamic_lmk_bary_coords', dynamic_lmk_bary_coords) neck_kin_chain = find_joint_kin_chain(self.NECK_IDX, self.parents) self.register_buffer( 'neck_kin_chain', torch.tensor(neck_kin_chain, dtype=torch.long)) if create_jaw_pose: if jaw_pose is None: default_jaw_pose = torch.zeros([batch_size, 3], dtype=dtype) else: default_jaw_pose = torch.tensor(jaw_pose, dtype=dtype) jaw_pose_param = nn.Parameter(default_jaw_pose, requires_grad=True) self.register_parameter('jaw_pose', jaw_pose_param) if create_leye_pose: if leye_pose is None: default_leye_pose = torch.zeros([batch_size, 3], dtype=dtype) else: default_leye_pose = torch.tensor(leye_pose, dtype=dtype) leye_pose_param = nn.Parameter(default_leye_pose, requires_grad=True) self.register_parameter('leye_pose', leye_pose_param) if create_reye_pose: if reye_pose is None: default_reye_pose = torch.zeros([batch_size, 3], dtype=dtype) else: default_reye_pose = torch.tensor(reye_pose, dtype=dtype) reye_pose_param = nn.Parameter(default_reye_pose, requires_grad=True) self.register_parameter('reye_pose', reye_pose_param) shapedirs = data_struct.shapedirs if len(shapedirs.shape) < 3: shapedirs = shapedirs[:, :, None] if (shapedirs.shape[-1] < self.SHAPE_SPACE_DIM + self.EXPRESSION_SPACE_DIM): print(f'WARNING: You are using a {self.name()} model, with only' ' 10 shape and 10 expression coefficients.') expr_start_idx = 10 expr_end_idx = 20 num_expression_coeffs = min(num_expression_coeffs, 10) else: expr_start_idx = self.SHAPE_SPACE_DIM expr_end_idx = self.SHAPE_SPACE_DIM + num_expression_coeffs num_expression_coeffs = min( num_expression_coeffs, self.EXPRESSION_SPACE_DIM) self._num_expression_coeffs = num_expression_coeffs expr_dirs = shapedirs[:, :, expr_start_idx:expr_end_idx] 
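# Layout sketch for the slicing just above (toy tensor, not the real SMPL-X data):
# a full SMPL-X shapedirs stacks SHAPE_SPACE_DIM=300 shape directions followed by
# EXPRESSION_SPACE_DIM=100 expression directions along the last axis, so expressions
# start at index 300; legacy files with only 10+10 components use the 10:20 slice.
import torch
_full_basis = torch.zeros(10475, 3, 400)    # V x 3 x (300 shape + 100 expression)
_shape_dirs = _full_basis[:, :, :10]        # first 10 shape components
_expr_dirs = _full_basis[:, :, 300:310]     # first 10 expression components
print(_shape_dirs.shape, _expr_dirs.shape)  # torch.Size([10475, 3, 10]) twice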
self.register_buffer( 'expr_dirs', to_tensor(to_np(expr_dirs), dtype=dtype)) if create_expression: if expression is None: default_expression = torch.zeros( [batch_size, self.num_expression_coeffs], dtype=dtype) else: default_expression = torch.tensor(expression, dtype=dtype) expression_param = nn.Parameter(default_expression, requires_grad=True) self.register_parameter('expression', expression_param) def name(self) -> str: return 'SMPL-X' @property def num_expression_coeffs(self): return self._num_expression_coeffs def create_mean_pose(self, data_struct, flat_hand_mean=False): # Create the array for the mean pose. If flat_hand is false, then use # the mean that is given by the data, rather than the flat open hand global_orient_mean = torch.zeros([3], dtype=self.dtype) body_pose_mean = torch.zeros([self.NUM_BODY_JOINTS * 3], dtype=self.dtype) jaw_pose_mean = torch.zeros([3], dtype=self.dtype) leye_pose_mean = torch.zeros([3], dtype=self.dtype) reye_pose_mean = torch.zeros([3], dtype=self.dtype) pose_mean = np.concatenate([global_orient_mean, body_pose_mean, jaw_pose_mean, leye_pose_mean, reye_pose_mean, self.left_hand_mean, self.right_hand_mean], axis=0) return pose_mean def extra_repr(self): msg = super(SMPLX, self).extra_repr() msg = [ msg, f'Number of Expression Coefficients: {self.num_expression_coeffs}' ] return '\n'.join(msg) def forward( self, betas: Optional[Tensor] = None, global_orient: Optional[Tensor] = None, body_pose: Optional[Tensor] = None, left_hand_pose: Optional[Tensor] = None, right_hand_pose: Optional[Tensor] = None, transl: Optional[Tensor] = None, expression: Optional[Tensor] = None, jaw_pose: Optional[Tensor] = None, leye_pose: Optional[Tensor] = None, reye_pose: Optional[Tensor] = None, return_verts: bool = True, return_full_pose: bool = False, pose2rot: bool = True, return_shaped: bool = True, use_pca:bool = True, # specify where to use pca 12 for hands' pose **kwargs ) -> SMPLXOutput: ''' Forward pass for the SMPLX model Parameters ---------- global_orient: torch.tensor, optional, shape Bx3 If given, ignore the member variable and use it as the global rotation of the body. Useful if someone wishes to predicts this with an external model. (default=None) betas: torch.tensor, optional, shape BxN_b If given, ignore the member variable `betas` and use it instead. For example, it can used if shape parameters `betas` are predicted from some external model. (default=None) expression: torch.tensor, optional, shape BxN_e If given, ignore the member variable `expression` and use it instead. For example, it can used if expression parameters `expression` are predicted from some external model. body_pose: torch.tensor, optional, shape Bx(J*3) If given, ignore the member variable `body_pose` and use it instead. For example, it can used if someone predicts the pose of the body joints are predicted from some external model. It should be a tensor that contains joint rotations in axis-angle format. (default=None) left_hand_pose: torch.tensor, optional, shape BxP If given, ignore the member variable `left_hand_pose` and use this instead. It should either contain PCA coefficients or joint rotations in axis-angle format. right_hand_pose: torch.tensor, optional, shape BxP If given, ignore the member variable `right_hand_pose` and use this instead. It should either contain PCA coefficients or joint rotations in axis-angle format. jaw_pose: torch.tensor, optional, shape Bx3 If given, ignore the member variable `jaw_pose` and use this instead. 
It should either joint rotations in axis-angle format. transl: torch.tensor, optional, shape Bx3 If given, ignore the member variable `transl` and use it instead. For example, it can used if the translation `transl` is predicted from some external model. (default=None) return_verts: bool, optional Return the vertices. (default=True) return_full_pose: bool, optional Returns the full axis-angle pose vector (default=False) Returns ------- output: ModelOutput A named tuple of type `ModelOutput` ''' # If no shape and pose parameters are passed along, then use the # ones from the module if global_orient is None: assert False global_orient = (global_orient if global_orient is not None else self.global_orient) body_pose = body_pose if body_pose is not None else self.body_pose betas = betas if betas is not None else self.betas left_hand_pose = (left_hand_pose if left_hand_pose is not None else self.left_hand_pose) right_hand_pose = (right_hand_pose if right_hand_pose is not None else self.right_hand_pose) jaw_pose = jaw_pose if jaw_pose is not None else self.jaw_pose leye_pose = leye_pose if leye_pose is not None else self.leye_pose reye_pose = reye_pose if reye_pose is not None else self.reye_pose expression = expression if expression is not None else self.expression apply_trans = transl is not None or hasattr(self, 'transl') if transl is None: if hasattr(self, 'transl'): transl = self.transl if self.use_pca and use_pca: left_hand_pose = torch.einsum( 'bi,ij->bj', [left_hand_pose, self.left_hand_components]) right_hand_pose = torch.einsum( 'bi,ij->bj', [right_hand_pose, self.right_hand_components]) full_pose = torch.cat([global_orient.reshape(-1, 1, 3), body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3), jaw_pose.reshape(-1, 1, 3), leye_pose.reshape(-1, 1, 3), reye_pose.reshape(-1, 1, 3), left_hand_pose.reshape(-1, 15, 3), right_hand_pose.reshape(-1, 15, 3)], dim=1).reshape(-1, 165) # Add the mean pose of the model. 
Does not affect the body, only the # hands when flat_hand_mean == False full_pose += self.pose_mean batch_size = max(betas.shape[0], global_orient.shape[0], body_pose.shape[0]) # Concatenate the shape and expression coefficients scale = int(batch_size / betas.shape[0]) if scale > 1: betas = betas.expand(scale, -1) shape_components = torch.cat([betas, expression], dim=-1) shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1) vertices, joints_smpl, A, T, shape_offset, pose_offset, pose_feature = lbs(shape_components, full_pose, self.v_template, shapedirs, self.posedirs, self.J_regressor, self.parents, self.lbs_weights, pose2rot=pose2rot) lmk_faces_idx = self.lmk_faces_idx.unsqueeze( dim=0).expand(batch_size, -1).contiguous() lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat( self.batch_size, 1, 1) if self.use_face_contour: lmk_idx_and_bcoords = find_dynamic_lmk_idx_and_bcoords( vertices, full_pose, self.dynamic_lmk_faces_idx, self.dynamic_lmk_bary_coords, self.neck_kin_chain, pose2rot=True, ) dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1) lmk_bary_coords = torch.cat( [lmk_bary_coords.expand(batch_size, -1, -1), dyn_lmk_bary_coords], 1) landmarks = vertices2landmarks(vertices, self.faces_tensor, lmk_faces_idx, lmk_bary_coords) # Add any extra joints that might be needed joints = self.vertex_joint_selector(vertices, joints_smpl) # Add the landmarks to the joints joints = torch.cat([joints, landmarks], dim=1) # Map the joints to the current dataset if self.joint_mapper is not None: joints = self.joint_mapper(joints=joints, vertices=vertices) if apply_trans: joints_smpl += transl.unsqueeze(dim=1) joints += transl.unsqueeze(dim=1) vertices += transl.unsqueeze(dim=1) A[..., :3, 3] += transl.unsqueeze(dim=1) T[..., :3, 3] += transl.unsqueeze(dim=1) v_shaped = None if return_shaped: v_shaped = self.v_template + blend_shapes(betas, self.shapedirs) else: v_shaped = Tensor(0) output = SMPLXOutput(vertices=vertices if return_verts else None, joints=joints_smpl, betas=shape_components, expression=expression, global_orient=global_orient, body_pose=body_pose, v_shaped=v_shaped, full_pose=full_pose if return_full_pose else None, A=A, T=T, shape_offset=shape_offset, pose_offset=pose_offset, pose_feature=pose_feature) return output ================================================ FILE: lib/models/deformers/smplx/joint_names.py ================================================ # -*- coding: utf-8 -*- # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is # holder of all proprietary rights on this computer program. # You can only use this computer program if you have closed # a license agreement with MPG or you get the right to use the computer # program from someone who is authorized to grant you that right. # Any use of the computer program without a valid license is prohibited and # liable to prosecution. # # Copyright©2019 Max-Planck-Gesellschaft zur Förderung # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute # for Intelligent Systems. All rights reserved. 
# # Contact: ps-license@tuebingen.mpg.de JOINT_NAMES = [ 'pelvis', 'left_hip', 'right_hip', 'spine1', 'left_knee', 'right_knee', 'spine2', 'left_ankle', 'right_ankle', 'spine3', 'left_foot', 'right_foot', 'neck', 'left_collar', 'right_collar', 'head', 'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow', 'left_wrist', 'right_wrist', 'jaw', 'left_eye_smplhf', 'right_eye_smplhf', 'left_index1', 'left_index2', 'left_index3', 'left_middle1', 'left_middle2', 'left_middle3', 'left_pinky1', 'left_pinky2', 'left_pinky3', 'left_ring1', 'left_ring2', 'left_ring3', 'left_thumb1', 'left_thumb2', 'left_thumb3', 'right_index1', 'right_index2', 'right_index3', 'right_middle1', 'right_middle2', 'right_middle3', 'right_pinky1', 'right_pinky2', 'right_pinky3', 'right_ring1', 'right_ring2', 'right_ring3', 'right_thumb1', 'right_thumb2', 'right_thumb3', 'nose', 'right_eye', 'left_eye', 'right_ear', 'left_ear', 'left_big_toe', 'left_small_toe', 'left_heel', 'right_big_toe', 'right_small_toe', 'right_heel', 'left_thumb', 'left_index', 'left_middle', 'left_ring', 'left_pinky', 'right_thumb', 'right_index', 'right_middle', 'right_ring', 'right_pinky', 'right_eye_brow1', 'right_eye_brow2', 'right_eye_brow3', 'right_eye_brow4', 'right_eye_brow5', 'left_eye_brow5', 'left_eye_brow4', 'left_eye_brow3', 'left_eye_brow2', 'left_eye_brow1', 'nose1', 'nose2', 'nose3', 'nose4', 'right_nose_2', 'right_nose_1', 'nose_middle', 'left_nose_1', 'left_nose_2', 'right_eye1', 'right_eye2', 'right_eye3', 'right_eye4', 'right_eye5', 'right_eye6', 'left_eye4', 'left_eye3', 'left_eye2', 'left_eye1', 'left_eye6', 'left_eye5', 'right_mouth_1', 'right_mouth_2', 'right_mouth_3', 'mouth_top', 'left_mouth_3', 'left_mouth_2', 'left_mouth_1', 'left_mouth_5', # 59 in OpenPose output 'left_mouth_4', # 58 in OpenPose output 'mouth_bottom', 'right_mouth_4', 'right_mouth_5', 'right_lip_1', 'right_lip_2', 'lip_top', 'left_lip_2', 'left_lip_1', 'left_lip_3', 'lip_bottom', 'right_lip_3', # Face contour 'right_contour_1', 'right_contour_2', 'right_contour_3', 'right_contour_4', 'right_contour_5', 'right_contour_6', 'right_contour_7', 'right_contour_8', 'contour_middle', 'left_contour_8', 'left_contour_7', 'left_contour_6', 'left_contour_5', 'left_contour_4', 'left_contour_3', 'left_contour_2', 'left_contour_1', ] SMPLH_JOINT_NAMES = [ 'pelvis', 'left_hip', 'right_hip', 'spine1', 'left_knee', 'right_knee', 'spine2', 'left_ankle', 'right_ankle', 'spine3', 'left_foot', 'right_foot', 'neck', 'left_collar', 'right_collar', 'head', 'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow', 'left_wrist', 'right_wrist', 'left_index1', 'left_index2', 'left_index3', 'left_middle1', 'left_middle2', 'left_middle3', 'left_pinky1', 'left_pinky2', 'left_pinky3', 'left_ring1', 'left_ring2', 'left_ring3', 'left_thumb1', 'left_thumb2', 'left_thumb3', 'right_index1', 'right_index2', 'right_index3', 'right_middle1', 'right_middle2', 'right_middle3', 'right_pinky1', 'right_pinky2', 'right_pinky3', 'right_ring1', 'right_ring2', 'right_ring3', 'right_thumb1', 'right_thumb2', 'right_thumb3', 'nose', 'right_eye', 'left_eye', 'right_ear', 'left_ear', 'left_big_toe', 'left_small_toe', 'left_heel', 'right_big_toe', 'right_small_toe', 'right_heel', 'left_thumb', 'left_index', 'left_middle', 'left_ring', 'left_pinky', 'right_thumb', 'right_index', 'right_middle', 'right_ring', 'right_pinky', ] ================================================ FILE: lib/models/deformers/smplx/lbs.py ================================================ # -*- coding: utf-8 -*- # 
Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is # holder of all proprietary rights on this computer program. # You can only use this computer program if you have closed # a license agreement with MPG or you get the right to use the computer # program from someone who is authorized to grant you that right. # Any use of the computer program without a valid license is prohibited and # liable to prosecution. # # Copyright©2019 Max-Planck-Gesellschaft zur Förderung # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute # for Intelligent Systems. All rights reserved. # # Contact: ps-license@tuebingen.mpg.de from __future__ import absolute_import from __future__ import print_function from __future__ import division from typing import Tuple, List import numpy as np import torch import torch.nn.functional as F from .utils import rot_mat_to_euler, Tensor def find_dynamic_lmk_idx_and_bcoords( vertices: Tensor, pose: Tensor, dynamic_lmk_faces_idx: Tensor, dynamic_lmk_b_coords: Tensor, neck_kin_chain: List[int], pose2rot: bool = True, ) -> Tuple[Tensor, Tensor]: ''' Compute the faces, barycentric coordinates for the dynamic landmarks To do so, we first compute the rotation of the neck around the y-axis and then use a pre-computed look-up table to find the faces and the barycentric coordinates that will be used. Special thanks to Soubhik Sanyal (soubhik.sanyal@tuebingen.mpg.de) for providing the original TensorFlow implementation and for the LUT. Parameters ---------- vertices: torch.tensor BxVx3, dtype = torch.float32 The tensor of input vertices pose: torch.tensor Bx(Jx3), dtype = torch.float32 The current pose of the body model dynamic_lmk_faces_idx: torch.tensor L, dtype = torch.long The look-up table from neck rotation to faces dynamic_lmk_b_coords: torch.tensor Lx3, dtype = torch.float32 The look-up table from neck rotation to barycentric coordinates neck_kin_chain: list A python list that contains the indices of the joints that form the kinematic chain of the neck. dtype: torch.dtype, optional Returns ------- dyn_lmk_faces_idx: torch.tensor, dtype = torch.long A tensor of size BxL that contains the indices of the faces that will be used to compute the current dynamic landmarks. dyn_lmk_b_coords: torch.tensor, dtype = torch.float32 A tensor of size BxL that contains the indices of the faces that will be used to compute the current dynamic landmarks. 
''' dtype = vertices.dtype batch_size = vertices.shape[0] if pose2rot: aa_pose = torch.index_select(pose.view(batch_size, -1, 3), 1, neck_kin_chain) rot_mats = batch_rodrigues( aa_pose.view(-1, 3)).view(batch_size, -1, 3, 3) else: rot_mats = torch.index_select( pose.view(batch_size, -1, 3, 3), 1, neck_kin_chain) rel_rot_mat = torch.eye( 3, device=vertices.device, dtype=dtype).unsqueeze_(dim=0).repeat( batch_size, 1, 1) for idx in range(len(neck_kin_chain)): rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat) y_rot_angle = torch.round( torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi, max=39)).to(dtype=torch.long) neg_mask = y_rot_angle.lt(0).to(dtype=torch.long) mask = y_rot_angle.lt(-39).to(dtype=torch.long) neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle) y_rot_angle = (neg_mask * neg_vals + (1 - neg_mask) * y_rot_angle) dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx, 0, y_rot_angle) dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords, 0, y_rot_angle) return dyn_lmk_faces_idx, dyn_lmk_b_coords def vertices2landmarks( vertices: Tensor, faces: Tensor, lmk_faces_idx: Tensor, lmk_bary_coords: Tensor ) -> Tensor: ''' Calculates landmarks by barycentric interpolation Parameters ---------- vertices: torch.tensor BxVx3, dtype = torch.float32 The tensor of input vertices faces: torch.tensor Fx3, dtype = torch.long The faces of the mesh lmk_faces_idx: torch.tensor L, dtype = torch.long The tensor with the indices of the faces used to calculate the landmarks. lmk_bary_coords: torch.tensor Lx3, dtype = torch.float32 The tensor of barycentric coordinates that are used to interpolate the landmarks Returns ------- landmarks: torch.tensor BxLx3, dtype = torch.float32 The coordinates of the landmarks for each mesh in the batch ''' # Extract the indices of the vertices for each face # BxLx3 batch_size, num_verts = vertices.shape[:2] device = vertices.device lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view( batch_size, -1, 3) lmk_faces += torch.arange( batch_size, dtype=torch.long, device=device).view(-1, 1, 1) * num_verts lmk_vertices = vertices.view(-1, 3)[lmk_faces].view( batch_size, -1, 3, 3) landmarks = torch.einsum('blfi,blf->bli', [lmk_vertices, lmk_bary_coords]) return landmarks def lbs( betas: Tensor, pose: Tensor, v_template: Tensor, shapedirs: Tensor, posedirs: Tensor, J_regressor: Tensor, parents: Tensor, lbs_weights: Tensor, pose2rot: bool = True, ) -> Tuple[Tensor, Tensor]: ''' Performs Linear Blend Skinning with the given shape and pose parameters Parameters ---------- betas : torch.tensor BxNB The tensor of shape parameters pose : torch.tensor Bx(J + 1) * 3 The pose parameters in axis-angle format v_template torch.tensor BxVx3 The template mesh that will be deformed shapedirs : torch.tensor 1xNB The tensor of PCA shape displacements posedirs : torch.tensor Px(V * 3) The pose PCA coefficients J_regressor : torch.tensor JxV The regressor array that is used to calculate the joints from the position of the vertices parents: torch.tensor J The array that describes the kinematic tree for the model lbs_weights: torch.tensor N x V x (J + 1) The linear blend skinning weights that represent how much the rotation matrix of each part affects each vertex pose2rot: bool, optional Flag on whether to convert the input pose tensor to rotation matrices. The default value is True. 
If False, then the pose tensor should already contain rotation matrices and have a size of Bx(J + 1)x9 dtype: torch.dtype, optional Returns ------- verts: torch.tensor BxVx3 The vertices of the mesh after applying the shape and pose displacements. joints: torch.tensor BxJx3 The joints of the model ''' batch_size = max(betas.shape[0], pose.shape[0]) device, dtype = betas.device, betas.dtype # Add shape contribution shape_offset = blend_shapes(betas, shapedirs) v_shaped = v_template + shape_offset # Get the joints # NxJx3 array J = vertices2joints(J_regressor, v_shaped) # 3. Add pose blend shapes # N x J x 3 x 3 ident = torch.eye(3, dtype=dtype, device=device) if pose2rot: rot_mats = batch_rodrigues(pose.view(-1, 3)).view( [batch_size, -1, 3, 3]) pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1]) # (N x P) x (P, V * 3) -> N x V x 3 pose_offsets = torch.matmul( pose_feature, posedirs).view(batch_size, -1, 3) else: pose_feature = pose[:, 1:].view(batch_size, -1, 3, 3) - ident rot_mats = pose.view(batch_size, -1, 3, 3) pose_offsets = torch.matmul(pose_feature.view(batch_size, -1), posedirs).view(batch_size, -1, 3) v_posed = pose_offsets + v_shaped # 4. Get the global joint location J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype) # 5. Do skinning: # W is N x V x (J + 1) W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1]) # (N x V x (J + 1)) x (N x (J + 1) x 16) num_joints = J_regressor.shape[0] T = torch.matmul(W, A.view(batch_size, num_joints, 16)) \ .view(batch_size, -1, 4, 4) homogen_coord = torch.ones([batch_size, v_posed.shape[1], 1], dtype=dtype, device=device) v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2) v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1)) verts = v_homo[:, :, :3, 0] return verts, J_transformed, A, T, shape_offset, pose_offsets, pose_feature def vertices2joints(J_regressor: Tensor, vertices: Tensor) -> Tensor: ''' Calculates the 3D joint locations from the vertices Parameters ---------- J_regressor : torch.tensor JxV The regressor array that is used to calculate the joints from the position of the vertices vertices : torch.tensor BxVx3 The tensor of mesh vertices Returns ------- torch.tensor BxJx3 The location of the joints ''' return torch.einsum('bik,ji->bjk', [vertices, J_regressor]) def blend_shapes(betas: Tensor, shape_disps: Tensor) -> Tensor: ''' Calculates the per vertex displacement due to the blend shapes Parameters ---------- betas : torch.tensor Bx(num_betas) Blend shape coefficients shape_disps: torch.tensor Vx3x(num_betas) Blend shapes Returns ------- torch.tensor BxVx3 The per-vertex displacement due to shape deformation ''' # Displacement[b, m, k] = sum_{l} betas[b, l] * shape_disps[m, k, l] # i.e. Multiply each shape displacement by its corresponding beta and # then sum them. 
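# Self-contained sanity check of the einsum just below (toy sizes, not model data):
# it confirms that 'bl,mkl->bmk' is exactly the weighted sum described in the comment
# above, producing a B x V x 3 per-vertex displacement.
import torch
_betas = torch.randn(2, 5)                  # B x num_betas
_disps = torch.randn(7, 3, 5)               # V x 3 x num_betas
_ein = torch.einsum('bl,mkl->bmk', [_betas, _disps])
_ref = (_betas[:, None, None, :] * _disps[None]).sum(-1)
assert torch.allclose(_ein, _ref)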
blend_shape = torch.einsum('bl,mkl->bmk', [betas, shape_disps]) return blend_shape def batch_rodrigues( rot_vecs: Tensor, epsilon: float = 1e-8, ) -> Tensor: ''' Calculates the rotation matrices for a batch of rotation vectors Parameters ---------- rot_vecs: torch.tensor Nx3 array of N axis-angle vectors Returns ------- R: torch.tensor Nx3x3 The rotation matrices for the given axis-angle parameters ''' batch_size = rot_vecs.shape[0] device, dtype = rot_vecs.device, rot_vecs.dtype angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True) rot_dir = rot_vecs / angle cos = torch.unsqueeze(torch.cos(angle), dim=1) sin = torch.unsqueeze(torch.sin(angle), dim=1) # Bx1 arrays rx, ry, rz = torch.split(rot_dir, 1, dim=1) K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device) zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device) K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \ .view((batch_size, 3, 3)) ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0) rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K) return rot_mat def transform_mat(R: Tensor, t: Tensor) -> Tensor: ''' Creates a batch of transformation matrices Args: - R: Bx3x3 array of a batch of rotation matrices - t: Bx3x1 array of a batch of translation vectors Returns: - T: Bx4x4 Transformation matrix ''' # No padding left or right, only add an extra row return torch.cat([F.pad(R, [0, 0, 0, 1]), F.pad(t, [0, 0, 0, 1], value=1)], dim=2) def batch_rigid_transform( rot_mats: Tensor, joints: Tensor, parents: Tensor, dtype=torch.float32 ) -> Tensor: """ Applies a batch of rigid transformations to the joints Parameters ---------- rot_mats : torch.tensor BxNx3x3 Tensor of rotation matrices joints : torch.tensor BxNx3 Locations of joints parents : torch.tensor BxN The kinematic tree of each object dtype : torch.dtype, optional: The data type of the created tensors, the default is torch.float32 Returns ------- posed_joints : torch.tensor BxNx3 The locations of the joints after applying the pose rotations rel_transforms : torch.tensor BxNx4x4 The relative (with respect to the root joint) rigid transformations for all the joints """ joints = torch.unsqueeze(joints, dim=-1) rel_joints = joints.clone() rel_joints[:, 1:] -= joints[:, parents[1:]] transforms_mat = transform_mat( rot_mats.reshape(-1, 3, 3), rel_joints.reshape(-1, 3, 1)).reshape(-1, joints.shape[1], 4, 4) transform_chain = [transforms_mat[:, 0]] for i in range(1, parents.shape[0]): # Subtract the joint location at the rest pose # No need for rotation, since it's identity when at rest curr_res = torch.matmul(transform_chain[parents[i]], transforms_mat[:, i]) transform_chain.append(curr_res) transforms = torch.stack(transform_chain, dim=1) # The last column of the transformations contains the posed joints posed_joints = transforms[:, :, :3, 3] joints_homogen = F.pad(joints, [0, 0, 0, 1]) rel_transforms = transforms - F.pad( torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0]) return posed_joints, rel_transforms ================================================ FILE: lib/models/deformers/smplx/utils.py ================================================ # -*- coding: utf-8 -*- # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is # holder of all proprietary rights on this computer program. # You can only use this computer program if you have closed # a license agreement with MPG or you get the right to use the computer # program from someone who is authorized to grant you that right. 
# Any use of the computer program without a valid license is prohibited and # liable to prosecution. # # Copyright©2019 Max-Planck-Gesellschaft zur Förderung # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute # for Intelligent Systems. All rights reserved. # # Contact: ps-license@tuebingen.mpg.de from optparse import Option from typing import NewType, Union, Optional from dataclasses import dataclass, asdict, fields import numpy as np import torch Tensor = NewType('Tensor', torch.Tensor) Array = NewType('Array', np.ndarray) @dataclass class ModelOutput: vertices: Optional[Tensor] = None joints: Optional[Tensor] = None full_pose: Optional[Tensor] = None global_orient: Optional[Tensor] = None transl: Optional[Tensor] = None v_shaped: Optional[Tensor] = None def __getitem__(self, key): return getattr(self, key) def get(self, key, default=None): return getattr(self, key, default) def __iter__(self): return self.keys() def keys(self): keys = [t.name for t in fields(self)] return iter(keys) def values(self): values = [getattr(self, t.name) for t in fields(self)] return iter(values) def items(self): data = [(t.name, getattr(self, t.name)) for t in fields(self)] return iter(data) @dataclass class SMPLOutput(ModelOutput): betas: Optional[Tensor] = None body_pose: Optional[Tensor] = None T: Optional[Tensor] = None A: Optional[Tensor] = None shape_offset: Optional[Tensor] = None pose_offset: Optional[Tensor] = None pose_feature: Optional[Tensor] = None @dataclass class SMPLHOutput(SMPLOutput): left_hand_pose: Optional[Tensor] = None right_hand_pose: Optional[Tensor] = None transl: Optional[Tensor] = None @dataclass class SMPLXOutput(SMPLHOutput): expression: Optional[Tensor] = None jaw_pose: Optional[Tensor] = None def find_joint_kin_chain(joint_id, kinematic_tree): kin_chain = [] curr_idx = joint_id while curr_idx != -1: kin_chain.append(curr_idx) curr_idx = kinematic_tree[curr_idx] return kin_chain def to_tensor( array: Union[Array, Tensor], dtype=torch.float32 ) -> Tensor: if torch.is_tensor(array): return array else: return torch.tensor(array, dtype=dtype) class Struct(object): def __init__(self, **kwargs): for key, val in kwargs.items(): setattr(self, key, val) def to_np(array, dtype=np.float32): if 'scipy.sparse' in str(type(array)): array = array.todense() return np.array(array, dtype=dtype) def rot_mat_to_euler(rot_mats): # Calculates rotation matrix to euler angles # Careful for extreme cases of eular angles like [0.0, pi, 0.0] sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] + rot_mats[:, 1, 0] * rot_mats[:, 1, 0]) return torch.atan2(-rot_mats[:, 2, 0], sy) ================================================ FILE: lib/models/deformers/smplx/vertex_ids.py ================================================ # -*- coding: utf-8 -*- # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is # holder of all proprietary rights on this computer program. # You can only use this computer program if you have closed # a license agreement with MPG or you get the right to use the computer # program from someone who is authorized to grant you that right. # Any use of the computer program without a valid license is prohibited and # liable to prosecution. # # Copyright©2019 Max-Planck-Gesellschaft zur Förderung # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute # for Intelligent Systems. All rights reserved. 
# # Contact: ps-license@tuebingen.mpg.de from __future__ import print_function from __future__ import absolute_import from __future__ import division # Joint name to vertex mapping. SMPL/SMPL-H/SMPL-X vertices that correspond to # MSCOCO and OpenPose joints vertex_ids = { 'smplh': { 'nose': 332, 'reye': 6260, 'leye': 2800, 'rear': 4071, 'lear': 583, 'rthumb': 6191, 'rindex': 5782, 'rmiddle': 5905, 'rring': 6016, 'rpinky': 6133, 'lthumb': 2746, 'lindex': 2319, 'lmiddle': 2445, 'lring': 2556, 'lpinky': 2673, 'LBigToe': 3216, 'LSmallToe': 3226, 'LHeel': 3387, 'RBigToe': 6617, 'RSmallToe': 6624, 'RHeel': 6787 }, 'smplx': { 'nose': 9120, 'reye': 9929, 'leye': 9448, 'rear': 616, 'lear': 6, 'rthumb': 8079, 'rindex': 7669, 'rmiddle': 7794, 'rring': 7905, 'rpinky': 8022, 'lthumb': 5361, 'lindex': 4933, 'lmiddle': 5058, 'lring': 5169, 'lpinky': 5286, 'LBigToe': 5770, 'LSmallToe': 5780, 'LHeel': 8846, 'RBigToe': 8463, 'RSmallToe': 8474, 'RHeel': 8635 }, 'mano': { 'thumb': 744, 'index': 320, 'middle': 443, 'ring': 554, 'pinky': 671, } } ================================================ FILE: lib/models/deformers/smplx/vertex_joint_selector.py ================================================ # -*- coding: utf-8 -*- # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is # holder of all proprietary rights on this computer program. # You can only use this computer program if you have closed # a license agreement with MPG or you get the right to use the computer # program from someone who is authorized to grant you that right. # Any use of the computer program without a valid license is prohibited and # liable to prosecution. # # Copyright©2019 Max-Planck-Gesellschaft zur Förderung # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute # for Intelligent Systems. All rights reserved. 
# # Contact: ps-license@tuebingen.mpg.de from __future__ import absolute_import from __future__ import print_function from __future__ import division import numpy as np import torch import torch.nn as nn from .utils import to_tensor class VertexJointSelector(nn.Module): def __init__(self, vertex_ids=None, use_hands=False, use_feet_keypoints=False, **kwargs): super(VertexJointSelector, self).__init__() extra_joints_idxs = [] face_keyp_idxs = np.array([ vertex_ids['nose'], vertex_ids['reye'], vertex_ids['leye'], vertex_ids['rear'], vertex_ids['lear']], dtype=np.int64) extra_joints_idxs = np.concatenate([extra_joints_idxs, face_keyp_idxs]) if use_feet_keypoints: feet_keyp_idxs = np.array([vertex_ids['LBigToe'], vertex_ids['LSmallToe'], vertex_ids['LHeel'], vertex_ids['RBigToe'], vertex_ids['RSmallToe'], vertex_ids['RHeel']], dtype=np.int32) extra_joints_idxs = np.concatenate( [extra_joints_idxs, feet_keyp_idxs]) if use_hands: self.tip_names = ['thumb', 'index', 'middle', 'ring', 'pinky'] tips_idxs = [] for hand_id in ['l', 'r']: for tip_name in self.tip_names: tips_idxs.append(vertex_ids[hand_id + tip_name]) extra_joints_idxs = np.concatenate( [extra_joints_idxs, tips_idxs]) self.register_buffer('extra_joints_idxs', to_tensor(extra_joints_idxs, dtype=torch.long)) def forward(self, vertices, joints): extra_joints = torch.index_select(vertices, 1, self.extra_joints_idxs) joints = torch.cat([joints, extra_joints], dim=1) return joints ================================================ FILE: lib/models/deformers/smplx_deformer_gender.py ================================================ # Modified from Deformer of AG3D from .fast_snarf.lib.model.deformer_smplx import ForwardDeformer, skinning from .smplx import SMPLX import torch from pytorch3d import ops import numpy as np import pickle import json from pytorch3d.transforms import quaternion_to_matrix, matrix_to_quaternion class SMPLXDeformer_gender(torch.nn.Module): def __init__(self, gender, is_sub2=False) -> None: super().__init__() self.body_model = SMPLX('lib/models/deformers/smplx/SMPLX', gender=gender, \ create_body_pose=False, \ create_betas=False, \ create_global_orient=False, \ create_transl=False, create_expression=False, create_jaw_pose=False, create_leye_pose=False, create_reye_pose=False, create_right_hand_pose=False, create_left_hand_pose=False, use_pca=True, num_pca_comps=12, num_betas=10, flat_hand_mean=False,ext='pkl') self.deformer = ForwardDeformer() self.threshold = 0.12 base_cache_dir = 'work_dirs/cache' if is_sub2: base_cache_dir = 'work_dirs/cache_sub2' if gender == 'neutral': init_spdir_neutral = torch.as_tensor(np.load(base_cache_dir+'/init_spdir_smplx_thu_newNeutral.npy')) self.register_buffer('init_spdir', init_spdir_neutral, persistent=False) init_podir_neutral = torch.as_tensor(np.load(base_cache_dir+'/init_podir_smplx_thu_newNeutral.npy')) self.register_buffer('init_podir', init_podir_neutral, persistent=False) init_lbs_weights = torch.as_tensor(np.load(base_cache_dir+'/init_lbsw_smplx_thu_newNeutral.npy')) self.register_buffer('init_lbsw', init_lbs_weights.unsqueeze(0), persistent=False) init_faces = torch.as_tensor(np.load(base_cache_dir+'/init_faces_smplx_newNeutral.npy')) self.register_buffer('init_faces', init_faces.unsqueeze(0), persistent=False) elif gender == 'male': init_spdir_male = torch.as_tensor(np.load(base_cache_dir+'/init_spdir_smplx_thu_newMale.npy')) self.register_buffer('init_spdir', init_spdir_male, persistent=False) init_podir_male = 
torch.as_tensor(np.load(base_cache_dir+'/init_podir_smplx_thu_newMale.npy')) self.register_buffer('init_podir', init_podir_male, persistent=False) init_lbs_weights = torch.as_tensor(np.load(base_cache_dir+'/init_lbsw_smplx_thu_newMale.npy')) self.register_buffer('init_lbsw', init_lbs_weights.unsqueeze(0), persistent=False) init_faces = torch.as_tensor(np.load(base_cache_dir+'/init_faces_smplx_neuMale.npy')) self.register_buffer('init_faces', init_faces.unsqueeze(0), persistent=False) self.initialize() self.initialized = True def initialize(self): ''' Will only be called once, used to initialize lbs volume ''' batch_size = 1 device = self.body_model.posedirs.device # canonical space is defined in t-pose / star-pose body_pose_t = torch.zeros((batch_size, 63)).to(device) jaw_pose_t = torch.zeros((batch_size, 3)).to(device) ##flat_hand_mean = False left_hand_pose_t = torch.tensor([1.4624, -0.1615, 0.1361, 1.3851, -0.2597, 0.0247, -0.0683, -0.4478, -0.6652, -0.7290, 0.0084, -0.4818]).unsqueeze(0).to(device) right_hand_pose_t = torch.tensor([1.4624, -0.1615, 0.1361, 1.3851, -0.2597, 0.0247, -0.0683, -0.4478, -0.6652, -0.7290, 0.0084, -0.4818]).unsqueeze(0).to(device) ## flat_hand_mean = True leye_pose_t = torch.zeros((batch_size, 3)).to(device) reye_pose_t = torch.zeros((batch_size, 3)).to(device) expression_t = torch.zeros((batch_size, 10)).to(device) global_orient = torch.zeros((batch_size, 3)).to(device) betas = torch.zeros((batch_size, 10)).to(device) smpl_outputs = self.body_model(betas=betas, body_pose=body_pose_t, jaw_pose=jaw_pose_t, left_hand_pose=left_hand_pose_t, right_hand_pose=right_hand_pose_t, leye_pose=leye_pose_t, reye_pose=reye_pose_t, expression=expression_t, transl=None, global_orient=global_orient) tfs_inv_t = torch.inverse(smpl_outputs.A.float().detach()) # from template to posed space vs_template = smpl_outputs.vertices smpl_faces = torch.as_tensor(self.body_model.faces.astype(np.int64)) pose_offset_cano = torch.matmul(smpl_outputs.pose_feature, self.init_podir).reshape(1, -1, 3) pose_offset_cano = torch.cat([pose_offset_cano[:, self.init_faces[..., i]] for i in range(3)], dim=1).mean(1) self.register_buffer('tfs_inv_t', tfs_inv_t, persistent=False) self.register_buffer('vs_template', vs_template, persistent=False) self.register_buffer('smpl_faces', smpl_faces, persistent=False) self.register_buffer('pose_offset_cano', pose_offset_cano, persistent=False) # initialize SNARF smpl_verts = smpl_outputs.vertices.float().detach().clone() self.deformer.switch_to_explicit(resolution=64, smpl_verts=smpl_verts, smpl_faces=self.smpl_faces, smpl_weights=self.body_model.lbs_weights.clone()[None].detach(), use_smpl=True) self.dtype = torch.float32 self.deformer.lbs_voxel_final = self.deformer.lbs_voxel_final.type(self.dtype) self.deformer.grid_denorm = self.deformer.grid_denorm.type(self.dtype) self.deformer.scale = self.deformer.scale.type(self.dtype) self.deformer.offset = self.deformer.offset.type(self.dtype) self.deformer.scale_kernel = self.deformer.scale_kernel.type(self.dtype) self.deformer.offset_kernel = self.deformer.offset_kernel.type(self.dtype) def forword_body_model(self, smpl_params, point_pool=4): batchsize = smpl_params.shape[0] if_use_pca=True if smpl_params.shape[1] == 123: scale, transl, global_orient, pose, betas, left_hand_pose, right_hand_pose, jaw_pose, leye_pose, reye_pose, expression = torch.split(smpl_params, [1, 3, 3, 63, 10, 12, 12, 3, 3, 3, 10], dim=1) else: # not use pca 12 , 189 scale, transl, global_orient, pose, betas, left_hand_pose, right_hand_pose, 
jaw_pose, leye_pose, reye_pose, expression = torch.split(smpl_params, [1, 3, 3, 63, 10, 45, 45, 3, 3, 3, 10], dim=1) if_use_pca = False smpl_params = { 'betas': betas.reshape(-1, 10), 'expression': expression.reshape(-1, 10), 'body_pose': pose.reshape(-1, 63), 'left_hand_pose': left_hand_pose.reshape(batchsize, -1), 'right_hand_pose': right_hand_pose.reshape(batchsize, -1), 'jaw_pose': jaw_pose.reshape(-1, 3), 'leye_pose': leye_pose.reshape(-1, 3), 'reye_pose': reye_pose.reshape(-1, 3), 'global_orient': global_orient.reshape(-1, 3), 'transl': transl.reshape(-1, 3), 'scale': scale.reshape(-1, 1) } device = smpl_params["betas"].device smpl_outputs = self.body_model(**smpl_params, use_pca=if_use_pca) return smpl_outputs def prepare_deformer(self, smpl_params=None, num_scenes=1, device=None): if smpl_params is None: smpl_params = torch.zeros((num_scenes, 120)).to(device) scale, global_orient, pose, betas, left_hand_pose, right_hand_pose, jaw_pose, leye_pose, reye_pose, expression = torch.split(smpl_params, [1, 3, 63, 10, 12, 12, 3, 3, 3, 10], dim=1) left_hand_pose = torch.tensor([1.4624, -0.1615, 0.1361, 1.3851, -0.2597, 0.0247, -0.0683, -0.4478, -0.6652, -0.7290, 0.0084, -0.4818]).unsqueeze(0).to(device).repeat(num_scenes, 1) right_hand_pose = torch.tensor([1.4624, -0.1615, 0.1361, 1.3851, -0.2597, 0.0247, -0.0683, -0.4478, -0.6652, -0.7290, 0.0084, -0.4818]).unsqueeze(0).to(device).repeat(num_scenes, 1) smpl_params = { 'betas': betas, 'expression': expression, 'body_pose': pose, 'left_hand_pose': left_hand_pose, 'right_hand_pose': right_hand_pose, 'jaw_pose': jaw_pose, 'leye_pose': leye_pose, 'reye_pose': reye_pose, 'global_orient': global_orient, 'transl': None, 'scale': None, } else: batchsize = smpl_params.shape[0] if_use_pca=True if smpl_params.shape[1] == 123: scale, transl, global_orient, pose, betas, left_hand_pose, right_hand_pose, jaw_pose, leye_pose, reye_pose, expression = torch.split(smpl_params, [1, 3, 3, 63, 10, 12, 12, 3, 3, 3, 10], dim=1) else: # not use pca 12 , 165 scale, transl, global_orient, pose, betas, left_hand_pose, right_hand_pose, jaw_pose, leye_pose, reye_pose, expression = torch.split(smpl_params, [1, 3, 3, 63, 10, 45, 45, 3, 3, 3, 10], dim=1) if_use_pca = False smpl_params = { 'betas': betas.reshape(-1, 10), 'expression': expression.reshape(-1, 10), 'body_pose': pose.reshape(-1, 63), 'left_hand_pose': left_hand_pose.reshape(batchsize, -1), 'right_hand_pose': right_hand_pose.reshape(batchsize, -1), 'jaw_pose': jaw_pose.reshape(-1, 3), 'leye_pose': leye_pose.reshape(-1, 3), 'reye_pose': reye_pose.reshape(-1, 3), 'global_orient': global_orient.reshape(-1, 3), 'transl': transl.reshape(-1, 3), 'scale': scale.reshape(-1, 1) } device = smpl_params["betas"].device if not self.initialized: self.initialize(smpl_params["betas"]) self.initialized = True smpl_outputs = self.body_model(**smpl_params, use_pca=if_use_pca) self.smpl_outputs = smpl_outputs tfs = (smpl_outputs.A) @ self.tfs_inv_t.expand(smpl_outputs.A.shape[0],-1,-1,-1) self.tfs = tfs # self.tfs_A @ self.tfs_inv_t self.tfs_A = smpl_outputs.A # X_posed = smpl_outputs.A @ X_template, and (self.tfs_inv_t) @ X_tposed = X_template; # so X_posed = (smpl_outputs.A @ self.tfs_inv_t) @ X_tposed == equal to ==> self.tfs_A @ self.tfs_inv_t @ X_tposed self.shape_offset = torch.einsum('bl,mkl->bmk', [smpl_outputs.betas, self.init_spdir]) # betas-torch.Size([1, 20]) ; init_spdir-([25254, 3, 20]) self.pose_offset = torch.matmul(smpl_outputs.pose_feature, self.init_podir).reshape(self.shape_offset.shape) # batch_size, ([1, 25254, 
3]) def __call__(self, pts_in, rot_in, mask=None, cano=True, offset_gs=None, if_rotate_gaussian=False): ''' to calculate the skinning results pts_in (tensor, [bs, N, 3]): the canonical space points + offset_gs, represented a batch of clothed human rot_in (tensor, [bs, N, 3]): the canonical space gaussians points' rotation mask (tensor, [bs, N]): the mask of the vertices (face, hands), 1 for the vertices that use the skinning weights from template directly cono (bool): if True, return the input pts directly offset_gs (tensor, [bs, N, 3]): the estimated offset of the vertices in the canonical space use some of the attributes from the "prepare_deformer" to calculate the skinning, including: pose_offset[bs_pose, N, 3] shape_offset[bs_pose, N, 3] ''' pts = pts_in.clone() rot = rot_in.clone() if cano: return pts, None else: init_faces = self.init_faces b, n, _ = pts.shape smpl_nn = False if smpl_nn: # deformer based on SMPL nearest neighbor search k = 1 dist_sq, idx, neighbors = ops.knn_points(pts, self.smpl_outputs.vertices.float().expand(b, -1, -1), K=k, return_nn=True) dist = dist_sq.sqrt().clamp_(0.00003, 0.1) weights = self.body_model.lbs_weights.clone()[idx] ws=1./dist ws=ws/ws.sum(-1,keepdim=True) weights = (ws[..., None]*weights).sum(2).detach() shape_offset = torch.cat([self.shape_offset[:, init_faces[..., i]] for i in range(3)], dim=1).mean(1) pts += shape_offset pts_cano_all, w_tf = skinning(pts, weights, self.tfs, inverse=False) pts_cano_all = pts_cano_all.unsqueeze(2) else: # defromer based on fast-SNARF shape_offset = torch.cat([self.shape_offset[:, init_faces[..., i]] for i in range(3)], dim=1).mean(1) pose_offset = torch.cat([self.pose_offset[:, init_faces[..., i]] for i in range(3)], dim=1).mean(1) pts_query_lbs = pts.detach() # T_pose + gs_offset pts_cano_all, w_tf = self.deformer.forward_skinning(pts, shape_offset, pose_offset, cond=None, tfs=self.tfs_A, tfs_inv=self.tfs_inv_t, \ poseoff_ori=self.pose_offset_cano, lbsw=self.init_lbsw, mask=mask) pts_cano_all = pts_cano_all.reshape(b, n, -1, 3) if if_rotate_gaussian: # rotate the gaussian points # pts_cano_all = rot # rot_mats = quaternion_to_matrix(rot) # rot_mats = torch.einsum('nxy,nyz->nxz', w_tf[..., :3, :3], rot_mats) # rot_res = matrix_to_quaternion(rot_mats) # return pts_cano_all, w_tf.clone(), rot_res raise NotImplementedError("Code is not correct!") assert pts_in.dim() != 2 return pts_cano_all, w_tf.clone() ================================================ FILE: lib/models/renderers/__init__.py ================================================ from .gau_renderer import GRenderer, get_covariance, batch_rodrigues __all__ = ['GRenderer'] ================================================ FILE: lib/models/renderers/gau_renderer.py ================================================ from diff_gaussian_rasterization import ( GaussianRasterizationSettings, GaussianRasterizer, ) import torch import torch.nn as nn def batch_rodrigues(rot_vecs, epsilon = 1e-8): ''' Calculates the rotation matrices for a batch of rotation vectors Parameters ---------- rot_vecs: torch.tensor Nx3 array of N axis-angle vectors Returns ------- R: torch.tensor Nx3x3 The rotation matrices for the given axis-angle parameters ''' batch_size = rot_vecs.shape[0] device, dtype = rot_vecs.device, rot_vecs.dtype angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True) rot_dir = rot_vecs / angle cos = torch.unsqueeze(torch.cos(angle), dim=1) sin = torch.unsqueeze(torch.sin(angle), dim=1) # Bx1 arrays rx, ry, rz = torch.split(rot_dir, 1, dim=1) K = 
torch.zeros((batch_size, 3, 3), dtype=dtype, device=device) zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device) K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \ .view((batch_size, 3, 3)) ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0) rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K) return rot_mat def build_scaling_rotation(s, r, tfs): L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device=s.device) R = build_rotation(r) R_ = R L[:,0,0] = s[:,0] L[:,1,1] = s[:,1] L[:,2,2] = s[:,2] L = R_ @ L return L def strip_symmetric(sym): return strip_lowerdiag(sym) def strip_lowerdiag(L): uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device=L.device) uncertainty[:, 0] = L[:, 0, 0] uncertainty[:, 1] = L[:, 0, 1] uncertainty[:, 2] = L[:, 0, 2] uncertainty[:, 3] = L[:, 1, 1] uncertainty[:, 4] = L[:, 1, 2] uncertainty[:, 5] = L[:, 2, 2] return uncertainty def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation, tfs): L = build_scaling_rotation(scaling_modifier * scaling, rotation, tfs) actual_covariance = L @ L.transpose(1, 2) symm = strip_symmetric(actual_covariance) return symm def build_rotation(r): norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) q = r / norm[:, None] R = torch.zeros((q.size(0), 3, 3), device=q.device) r = q[:, 0] x = q[:, 1] y = q[:, 2] z = q[:, 3] R[:, 0, 0] = 1 - 2 * (y*y + z*z) R[:, 0, 1] = 2 * (x*y - r*z) R[:, 0, 2] = 2 * (x*z + r*y) R[:, 1, 0] = 2 * (x*y + r*z) R[:, 1, 1] = 1 - 2 * (x*x + z*z) R[:, 1, 2] = 2 * (y*z - r*x) R[:, 2, 0] = 2 * (x*z - r*y) R[:, 2, 1] = 2 * (y*z + r*x) R[:, 2, 2] = 1 - 2 * (x*x + y*y) return R def get_covariance(scaling, rotation, scaling_modifier = 1): L = torch.zeros_like(rotation) L[:, 0, 0] = scaling[:, 0] L[:, 1, 1] = scaling[:, 1] L[:, 2, 2] = scaling[:, 2] actual_covariance = rotation @ (L**2) @ rotation.permute(0, 2, 1) return strip_symmetric(actual_covariance) import math class GRenderer(nn.Module): def __init__(self, image_size=256, anti_alias=False, f=5000, near=0.01, far=40, bg_color=0): super().__init__() self.anti_alias = anti_alias self.image_size = image_size self.tanfov = 2 * math.atan(self.image_size[0] / (2 * f)) if bg_color == 0: bg = torch.tensor([0, 0, 0], dtype=torch.float32) else: bg = torch.tensor([1, 1, 1], dtype=torch.float32) self.register_buffer('bg', bg) opengl_proj = torch.tensor([[2 * f / self.image_size[0], 0.0, 0.0, 0.0], [0.0, 2 * f / self.image_size[1], 0.0, 0.0], [0.0, 0.0, far / (far - near), -(far * near) / (far - near)], [0.0, 0.0, 1.0, 0.0]]).float().unsqueeze(0).transpose(1, 2) self.register_buffer('opengl_proj', opengl_proj) if anti_alias: image_size = [s*2 for s in image_size] def prepare(self, cameras): if cameras.shape[-1] == 20: # use the new format: intrisic(fx, fy, cx, cy) + extrinsic(RT) w2c = cameras[4:].reshape(4, 4) cam_center = torch.inverse(w2c)[:3, 3] intrisics = cameras[:4] fov = get_fov(intrisics[0:2], intrisics[2].item(), self.image_size) tanfovx = fov[1] tanfovy = fov[1] w2c = w2c.unsqueeze(0).transpose(1, 2) proj_matrix = get_proj_yy(intrisics[0], self.image_size, 100, 0.01).to(torch.float32).to(intrisics.device) full_proj = torch.bmm(w2c, proj_matrix).to(torch.float32) elif cameras.shape[-1] == 19: # [:3] C, [3: ] RT cam_center = cameras[:3] # C w2c = cameras[3:].reshape(4, 4) w2c = w2c.unsqueeze(0).transpose(1, 2) # RT full_proj = w2c.bmm(self.opengl_proj).to(torch.float32) self.full_proj = full_proj tanfovx = self.tanfov tanfovy = self.tanfov 
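        # The settings assembled next hand the rasterizer everything it needs
        # for this view: the render resolution, the FOV terms computed above,
        # the background color, the (transposed) world-to-camera matrix, the
        # combined world-to-clip matrix, and the camera center, so a single
        # GaussianRasterizer call can project and splat the 3D Gaussians.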
self.raster_settings = GaussianRasterizationSettings( image_height=self.image_size[1], image_width=self.image_size[0], tanfovx=tanfovx, tanfovy=tanfovy, bg=self.bg.to(cameras.dtype), scale_modifier=1.0, viewmatrix=w2c, projmatrix=full_proj, sh_degree=0, campos=cam_center, prefiltered=False, debug=False, antialiasing=True # NEW version of GS ) self.rasterizer = GaussianRasterizer(raster_settings=self.raster_settings) def render_gaussian(self, means3D, colors_precomp, rotations, opacities, scales, cov3D_precomp=None): ''' mode: normal, phong, texture ''' screenspace_points = ( torch.zeros_like( means3D, dtype=means3D.dtype, requires_grad=True, device=means3D.device, ) + 0 ) try: screenspace_points.retain_grad() except: pass if cov3D_precomp != None: image, _, _= self.rasterizer(means3D=means3D, colors_precomp=colors_precomp, \ opacities=opacities, means2D=screenspace_points, cov3D_precomp=cov3D_precomp) else: image, _, _ = self.rasterizer(means3D=means3D, colors_precomp=colors_precomp, \ rotations=torch.nn.functional.normalize(rotations), opacities=opacities, scales=scales, \ means2D=screenspace_points) return image def get_view_matrix(R, t): Rt = torch.cat((R, t.view(3,1)),1) view_matrix = torch.cat((Rt, torch.FloatTensor([0,0,0,1]).cuda().view(1,4))) return view_matrix def get_proj_yy(f, image_size, far, near): opengl_proj = torch.tensor([[2 * f / image_size[0], 0.0, 0.0, 0.0], [0.0, 2 * f / image_size[1], 0.0, 0.0], [0.0, 0.0, far / (far - near), -(far * near) / (far - near)], [0.0, 0.0, 1.0, 0.0]]).float().unsqueeze(0).transpose(1, 2) return opengl_proj def get_proj_matrix(fovY,fovX, z_near, z_far, z_sign): tanHalfFovY = math.tan((fovY / 2)) tanHalfFovX = math.tan((fovX / 2)) top = tanHalfFovY * z_near bottom = -top right = tanHalfFovX * z_near left = -right z_sign = 1.0 proj_matrix = torch.zeros(4, 4).float().cuda() proj_matrix[0, 0] = 2.0 * z_near / (right - left) proj_matrix[1, 1] = 2.0 * z_near / (top - bottom) proj_matrix[0, 2] = (right + left) / (right - left) proj_matrix[1, 2] = (top + bottom) / (top - bottom) proj_matrix[3, 2] = z_sign proj_matrix[2, 2] = z_sign * z_far / (z_far - z_near) proj_matrix[2, 3] = -(z_far * z_near) / (z_far - z_near) return proj_matrix def get_fov(focal, princpt, img_shape): fov_x = 2 * torch.atan(img_shape[1] / (2 * focal[0])) fov_y = 2 * torch.atan(img_shape[0] / (2 * focal[1])) fov = torch.FloatTensor([fov_x, fov_y]).cuda() return fov ================================================ FILE: lib/models/sapiens/__init__.py ================================================ from .sapiens_wrapper_torchscipt import SapiensWrapper_ts ================================================ FILE: lib/models/sapiens/sapiens_wrapper_torchscipt.py ================================================ # Copyright (c) 2023, Zexin He # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
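# Overview: this module wraps a TorchScript-exported Sapiens vision
# transformer. The encoder is loaded via torch.jit.load, and pretrain_forward
# re-runs the patch embedding, interpolates the positional embedding to the
# input token grid, and iterates the transformer layers so that intermediate
# hidden states can optionally be returned for the downstream decoder.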
import torch import torch.nn as nn from transformers import Dinov2Backbone from torchvision import transforms from einops import rearrange from torch import Tensor def pretrain_forward(sp_lite, inputs: Tensor, layer_num: int, return_hidden_states=False) -> Tensor: B = inputs.size(0) patch_embed_output, _50, _51, _52, _53 = sp_lite.backbone.patch_embed(inputs) cls_token = sp_lite.backbone.cls_token.expand(B, -1, -1) x = torch.cat([cls_token, patch_embed_output], dim=1) cls_pos_embed, patch_pos_embed = sp_lite.backbone.pos_embed[:, 0:1, :], sp_lite.backbone.pos_embed[:, 1:, :] dim = cls_pos_embed.shape[-1] #64x64 patch_pos_embed = patch_pos_embed.reshape(-1, 64, 64, dim) patch_pos_embed_ = patch_pos_embed.permute(0, 3, 1, 2) patch_pos_embed = torch.nn.functional.interpolate( patch_pos_embed_, size = (_52, _53), mode="bicubic", align_corners=False, ) patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(-1, _52 * _53, dim) patch_pos_embed = torch.cat([cls_pos_embed, patch_pos_embed], dim=1) x = x + patch_pos_embed x = sp_lite.backbone.drop_after_pos(x) if return_hidden_states: hidden_states = [] hidden_states.append(x) for i in range(layer_num): x = getattr(sp_lite.backbone.layers, str(i))(x) hidden_states.append(x) x = sp_lite.backbone.ln1(x) cls_output = x[:, 0] # Assuming class token is at index 0 patch_tokens = x[:, 1:] # Remaining are patch tokens output = patch_tokens.view(B, _52, _53, -1).permute(0, 3, 1, 2) if return_hidden_states: return output, hidden_states return output class SapiensWrapper_ts(nn.Module): """ Sapiens wrapper using huggingface transformer implementation. """ def __init__(self, model_path: str = 'facebook/dinov2-base', freeze=True, img_size=None, layer_num=None): super().__init__() if layer_num == None: if "0.3b" in model_path: self.layer_num = 24 else: self.layer_num = 48 else: self.layer_num = layer_num self.model = torch.jit.load(model_path) if img_size is None: self.my_processor = transforms.Compose([ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) else: self.my_processor = transforms.Compose([ transforms.Resize(size=img_size), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) self.interpolate_pos_encoding=True if freeze: self._freeze() def forward(self, image, use_my_proces=False, requires_grad=False, output_hidden_states=False): # image: [B, N, C, H, W] # RGB image with [0,1] scale and properly sized if image.ndim == 5: B, N, _, H, W = image.shape mv = True image = image.flatten(0, 1) if image.ndim == 4: N, _, H, W = image.shape B = None else: raise NotImplementedError device = image.device if not use_my_proces: inputs = self.image_processor(image, return_tensors="pt") inputs['pixel_values'] = inputs['pixel_values'].to(device) else: inputs = self.my_processor(image) inputs = {'pixel_values': inputs} if requires_grad==False: with torch.no_grad(): outputs = pretrain_forward(self.model, inputs['pixel_values'], layer_num=self.layer_num, return_hidden_states=output_hidden_states) else: outputs = pretrain_forward(self.model, inputs['pixel_values'], layer_num=self.layer_num, return_hidden_states=output_hidden_states) last_feature_map = outputs[0] if not output_hidden_states: if B is None: # dim = 5 last_feature_map = rearrange(last_feature_map, 'n dim h w -> n (h w) dim') # N, N_tk, C else: last_feature_map = rearrange(last_feature_map, 'bn dim h w -> bn (h w) dim') last_feature_map = last_feature_map.reshape(B, N, last_feature_map.shape[-2], last_feature_map.shape[-1]) if output_hidden_states: 
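            # The per-layer token maps collected above are stacked into a
            # [N, N_layer, N_tokens, C] tensor below, and the leading class
            # token is dropped so that only patch tokens are returned.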
hidden_states = torch.stack(outputs[1], 0).permute(1, 0, 2, 3) # N, N_layer, N_tk, C hidden_states = hidden_states[:, :, 1:,:] # N, N_layer, N_tk, C if output_hidden_states: return hidden_states else: return last_feature_map def _freeze(self): print(f"======== Freezing DinoWrapper ========") self.model.eval() for name, param in self.model.named_parameters(): param.requires_grad = False if __name__ == "__main__": model = SapiensWrapper_ts() model.eval() image = torch.rand(1, 3, 896, 640) output = model(image, use_my_proces=True, output_hidden_states=True) output = pretrain_forward(model, image, layer_num=24) print(output) print("done") ================================================ FILE: lib/models/transformer_sa/__init__.py ================================================ from .mae_decoder_v3_skip import neck_SA_v3_skip ================================================ FILE: lib/models/transformer_sa/mae_decoder_v3_skip.py ================================================ import torch import torch.nn as nn import numpy as np from timm.models.vision_transformer import PatchEmbed, Block, checkpoint_seq from typing import Union def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False): """ Create 2D sin/cos positional embeddings. Args: embed_dim (`int`): Embedding dimension. grid_size (`int`): The grid height and width. add_cls_token (`bool`, *optional*, defaults to `False`): Whether or not to add a classification (CLS) token. Returns: (`torch.FloatTensor` of shape (grid_size*grid_size, embed_dim) or (1+grid_size*grid_size, embed_dim): the position embeddings (with or without classification token) """ grid_h = np.arange(grid_size, dtype=np.float32) grid_w = np.arange(grid_size, dtype=np.float32) grid = np.meshgrid(grid_w, grid_h) # here w goes first grid = np.stack(grid, axis=0) grid = grid.reshape([2, 1, grid_size, grid_size]) pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) if add_cls_token: pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) return pos_embed def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): if embed_dim % 2 != 0: raise ValueError("embed_dim must be even") # use half of dimensions to encode grid_h emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) return emb def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): """ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) """ if embed_dim % 2 != 0: raise ValueError("embed_dim must be even") omega = np.arange(embed_dim // 2, dtype=float) omega /= embed_dim / 2.0 omega = 1.0 / 10000**omega # (D/2,) pos = pos.reshape(-1) # (M,) out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product emb_sin = np.sin(out) # (M, D/2) emb_cos = np.cos(out) # (M, D/2) emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) return emb class neck_SA_v3_skip(nn.Module): def __init__(self, patch_size=4, in_chans=32, num_patches=196, embed_dim=1024, decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, mlp_ratio=4, norm_layer=nn.LayerNorm, total_num_hidden_states=25, connect_mode:Union['uniform', 'zeros', 'shadow']='uniform', if_checkpoint_seq=False): super().__init__() self.num_patches = num_patches # Decoder-specific self.if_checkpoint_seq = if_checkpoint_seq # to save the memory self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) 
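        # The mask token is a single learnable query that forward_decoder
        # repeats once per output patch, adds to the sin/cos-initialized
        # positional embedding, and appends after the encoder tokens, so the
        # decoder blocks produce one feature vector per query position.
        # E.g. with B=2 and 196 query positions: mask_tokens -> [2, 196, C_dec],
        # query_x = mask_tokens + decoder_pos_embed -> [2, 196, C_dec].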
self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches, decoder_embed_dim), requires_grad=True) # fixed sin-cos embedding self.decoder_blocks_depart = nn.ModuleList([ Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) #qk_scale=None for i in range(decoder_depth)]) self.decoder_norm = norm_layer(decoder_embed_dim) self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch if connect_mode == 'uniform': skip = total_num_hidden_states// (decoder_depth-1) self.select_hidden_states = [skip*i for i in range(decoder_depth)] # for 25 self.select_hidden_states[-1] = total_num_hidden_states - 1 self.select_hidden_states = self.select_hidden_states[::-1] # inverse the order elif connect_mode == 'zeros': # print('!!!!!!!!!!! zeros !!!!!!!!!!!!!!!!') self.select_hidden_states = [0, 0, 0, 0, 0, 0] self.decoder_embed = nn.ModuleList([ nn.Linear(embed_dim, decoder_embed_dim, bias=True) for _ in range(decoder_depth) ]) self.initialize_weights() def initialize_weights(self): # Initialization # Initialize (and freeze) pos_embed by sin-cos embedding decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int((self.num_patches)**.5), add_cls_token=False) self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) # Initialize nn.Linear and nn.LayerNorm self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): # We use xavier_uniform following official JAX ViT: torch.nn.init.xavier_uniform_(m.weight) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) def forward_decoder(self, in_features, ids_restore): # Embed tokens B, N_l, N_f, C = in_features.shape select_in_features = in_features[:, self.select_hidden_states, :, :] # parallelly embed the hidden states wiht jit forks = [torch.jit.fork(self.decoder_embed[i], select_in_features[:, i]) for i in range(len(self.select_hidden_states))] x_list = [torch.jit.wait(fork) for fork in forks] x_all_states = torch.stack(x_list) # N_l, B, N_feat, C # Add pos embed mask_tokens = self.mask_token.repeat(B, ids_restore.shape[1], 1) # B, N_q, C query_x = mask_tokens + self.decoder_pos_embed # Append mask tokens to sequence x = torch.zeros_like(x_all_states[0]) x = torch.cat([x, query_x], dim=1) # no cls token # B, N_q+N_f, C # # Apply Transformer blocks # v0 # Apply Transformer blocks # v1 for i, blk in enumerate( self.decoder_blocks_depart): x_add = x_all_states[i] x[:, :N_f, :] += x_add # add the hidden states x = blk(x) x = self.decoder_norm(x) x = x[:, -self.num_patches:, :] x_reshaped = x return x_reshaped def forward(self, encoded_latent, ids_restore): decoded_output = self.forward_decoder(encoded_latent, ids_restore) return decoded_output ================================================ FILE: lib/ops/__init__.py ================================================ from .activation import TruncExp ================================================ FILE: lib/ops/activation.py ================================================ import math import torch import torch.nn as nn from torch.autograd import Function from torch.cuda.amp import custom_bwd, custom_fwd class _trunc_exp(Function): @staticmethod @custom_fwd(cast_inputs=torch.float32) # cast to float32 def forward(ctx, x): exp_x = torch.exp(x) ctx.save_for_backward(exp_x) return exp_x @staticmethod @custom_bwd def backward(ctx, 
g): exp_x = ctx.saved_tensors[0] return g * exp_x.clamp(min=1e-6, max=1e6) trunc_exp = _trunc_exp.apply class TruncExp(nn.Module): @staticmethod def forward(x): return _trunc_exp.apply(x) ================================================ FILE: lib/utils/infer_util.py ================================================ import os import imageio import rembg import torch import numpy as np import PIL.Image from PIL import Image from typing import Any import json from pathlib import Path from torchvision.transforms import ToTensor from rembg import remove # For background removal from pytorch3d.transforms import axis_angle_to_matrix, matrix_to_axis_angle from lib.models.deformers.smplx.lbs import batch_rodrigues import cv2 from PIL import Image import numpy as np import json # import random import math # import av def reset_first_frame_rotation(root_orient, trans): """ Set the root_orient rotation matrix of the first frame to the identity matrix (no rotation), keep the relative rotation relationships of other frames, and adjust trans accordingly. Parameters: root_orient: Tensor of shape (N, 3), representing the axis-angle parameters for N frames. trans: Tensor of shape (N, 3), representing the translation parameters for N frames. Returns: new_root_orient: Tensor of shape (N, 3), adjusted axis-angle parameters. new_trans: Tensor of shape (N, 3), adjusted translation parameters. """ # Convert the root_orient of the first frame to a rotation matrix R_0 = axis_angle_to_matrix(root_orient[0:1]) # Shape: (1, 3, 3) # Compute the inverse of the first frame's rotation matrix R_0_inv = torch.inverse(R_0) # Shape: (1, 3, 3) # Initialize lists for new root_orient and trans new_root_orient = [] new_trans = [] for i in range(root_orient.shape[0]): # Rotation matrix of the current frame R_i = axis_angle_to_matrix(root_orient[i:i+1]) # Shape: (1, 3, 3) R_new = torch.matmul(R_0_inv, R_i) # Shape: (1, 3, 3) # Convert the rotation matrix back to axis-angle representation axis_angle_new = matrix_to_axis_angle(R_new) # Shape: (1, 3) new_root_orient.append(axis_angle_new) # Adjust the translation for the current frame trans_i = trans[i:i+1] # Shape: (1, 3) trans_new = torch.matmul(R_0_inv, trans_i.T).T # Shape: (1, 3) new_trans.append(trans_new) # Stack the results of new_root_orient and new_trans new_root_orient = torch.cat(new_root_orient, dim=0) # Shape: (N, 3) new_trans = torch.cat(new_trans, dim=0) # Shape: (N, 3) # Adjust the new translations relative to the first frame new_trans = new_trans - new_trans[[0], :] return new_root_orient, new_trans from scipy.spatial.transform import Rotation def rotation_matrix_to_rodrigues(rotation_matrices): # reshape rotation_matrices to (-1, 3, 3) reshaped_matrices = rotation_matrices.reshape(-1, 3, 3) rotation = Rotation.from_matrix(reshaped_matrices) rodrigues_vectors = rotation.as_rotvec() return rodrigues_vectors def get_hand_pose_mean(): import numpy as np hand_pose_mean= np.array([[ 0.11167871, 0.04289218, -0.41644183, 0.10881133, -0.06598568, -0.75622 , -0.09639297, -0.09091566, -0.18845929, -0.11809504, 0.05094385, -0.5295845 , -0.14369841, 0.0552417 , -0.7048571 , -0.01918292, -0.09233685, -0.3379135 , -0.45703298, -0.19628395, -0.6254575 , -0.21465237, -0.06599829, -0.50689423, -0.36972436, -0.06034463, -0.07949023, -0.1418697 , -0.08585263, -0.63552827, -0.3033416 , -0.05788098, -0.6313892 , -0.17612089, -0.13209307, -0.37335458, 0.8509643 , 0.27692273, -0.09154807, -0.49983943, 0.02655647, 0.05288088, 0.5355592 , 0.04596104, -0.27735803, 0.11167871, -0.04289218, 
0.41644183, 0.10881133, 0.06598568, 0.75622 , -0.09639297, 0.09091566, 0.18845929, -0.11809504, -0.05094385, 0.5295845 , -0.14369841, -0.0552417 , 0.7048571 , -0.01918292, 0.09233685, 0.3379135 , -0.45703298, 0.19628395, 0.6254575 , -0.21465237, 0.06599829, 0.50689423, -0.36972436, 0.06034463, 0.07949023, -0.1418697 , 0.08585263, 0.63552827, -0.3033416 , 0.05788098, 0.6313892 , -0.17612089, 0.13209307, 0.37335458, 0.8509643 , -0.27692273, 0.09154807, -0.49983943, -0.02655647, -0.05288088, 0.5355592 , -0.04596104, 0.27735803]]) return hand_pose_mean def load_smplify_json(smplx_smplify_path): with open(smplx_smplify_path) as f: data = json.load(f) # Prepare camera transformation matrix (R | t) RT = torch.concatenate([torch.Tensor(data['camera']['R']), torch.Tensor(data['camera']['t']).reshape(3, 1) * 2], dim=1) RT = torch.cat([RT, torch.Tensor([[0, 0, 0, 1]])], dim=0) # Create intrinsic parameters tensor intri = torch.Tensor(data['camera']['focal'] + data['camera']['princpt']) # intri[[3, 2]] = intri[[2, 3]] # # Set default focal length and image resolution # default_focal = 1120 # Default focal length # img_res = [640, 896] # default_fxy_cxy = torch.tensor([default_focal, default_focal, img_res[1] // 2, img_res[0] // 2]).reshape(1, 4) # # Adjust intrinsic parameters based on default focal and resolution # intri = intri * default_fxy_cxy[0, -2] / intri[-2] # intri[-2:] = default_fxy_cxy[0, -2:] # Force consistent image width and height # Extract SMPL parameters from data smpl_param_data = data global_orient = np.array(smpl_param_data['root_pose']).reshape(1, -1) body_pose = np.array(smpl_param_data['body_pose']).reshape(1, -1) shape = np.array(smpl_param_data['betas_save']).reshape(1, -1)[:, :10] left_hand_pose = np.array(smpl_param_data['lhand_pose']).reshape(1, -1) right_hand_pose = np.array(smpl_param_data['rhand_pose']).reshape(1, -1) # Concatenate all parameters into a single tensor for SMPL model smpl_param_ref = np.concatenate([np.array([[1.]]), np.array(smpl_param_data['trans']).reshape(1, 3), global_orient, body_pose, shape, left_hand_pose, right_hand_pose, np.array(smpl_param_data['jaw_pose']).reshape(1, -1), np.zeros_like(np.array(smpl_param_data['leye_pose']).reshape(1, -1)), np.zeros_like(np.array(smpl_param_data['reye_pose']).reshape(1, -1)), np.zeros_like(np.array(smpl_param_data['expr']).reshape(1, -1)[:, :10])], axis=1) return RT, intri, torch.Tensor(smpl_param_ref).reshape(-1) # Return transformation, intrinsic, and SMPL parameters def load_image(input_path, output_folder, image_frame_ratio=None): input_img_path = Path(input_path) vids = [] save_path = os.path.join(output_folder, f"{input_img_path.name}") print(f"Processing: {save_path}") image = Image.open(input_img_path) if image.mode == "RGBA": pass else: # remove bg image = remove(image.convert("RGBA"), alpha_matting=True) # resize object in frame image_arr = np.array(image) in_w, in_h = image_arr.shape[:2] ret, mask = cv2.threshold( np.array(image.split()[-1]), 0, 255, cv2.THRESH_BINARY ) x, y, w, h = cv2.boundingRect(mask) max_size = max(w, h) side_len = ( int(max_size / image_frame_ratio) if image_frame_ratio is not None else int(max_size / 0.85) ) padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8) center = side_len // 2 padded_image[ center - h // 2 : center - h // 2 + h, center - w // 2 : center - w // 2 + w, ] = image_arr[y : y + h, x : x + w] rgba = Image.fromarray(padded_image).resize((896, 896), Image.LANCZOS) # crop the width into 640 in the center rgba = rgba.crop([128, 0, 640+128, 896]) # white 
bg rgba_arr = np.array(rgba) / 255.0 rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:]) input_image = Image.fromarray((rgb * 255).astype(np.uint8)) image = ToTensor()(input_image) return image def prepare_camera( resolution_x = 640, resolution_y = 640, focal_length = 600,sensor_width = 32, camera_dist = 20, num_views=1, stides=1): def look_at(camera_position, target_position, up_vector): # colmap +z forward, +y down forward = -(camera_position - target_position) / np.linalg.norm(camera_position - target_position) right = np.cross(up_vector, forward) up = np.cross(forward, right) return np.column_stack((right, up, forward)) # set the intrisics focal_length = focal_length * (resolution_y/sensor_width) K = np.array( [[focal_length, 0, resolution_x//2], [0, focal_length, resolution_y//2], [0, 0, 1]] ) # set the extrisics camera_pose_list = [] for frame_idx in range(0, num_views, stides): phi = math.radians(90) theta = (3 / 4) * math.pi * 2 camera_location = np.array( [camera_dist * math.sin(phi) * math.cos(theta), camera_dist * math.cos(phi), -camera_dist * math.sin(phi) * math.sin(theta),] ) # print(camera_location) camera_pose = np.eye(4) camera_pose[:3, 3] = camera_location # Set camera position and target position camera_position = camera_location target_position = np.array([0.0, 0.0, 0.0]) # Compute the camera's rotation matrix to look at the target up_vector = np.array([0.0, -1.0, 0.0]) # colmap rotation_matrix = look_at(camera_position, target_position, up_vector) # Update camera position and rotation camera_pose[:3, :3] = rotation_matrix camera_pose[:3, 3] = camera_position camera_pose_list.append(camera_pose) return K, camera_pose_list def construct_camera(K, cam_list, device='cuda'): num_imgs = len(cam_list) front_idx = num_imgs//4*3 cam_list = cam_list[front_idx:] + cam_list[:front_idx] cam_raw = np.array(cam_list) cam_raw[:, :3, 3] = cam_raw[:, :3, 3] cam = np.linalg.inv(cam_raw) cam = torch.Tensor(cam) intrics = torch.Tensor([K[0,0],K[1,1], K[0,2], K[1,2]]).reshape(-1) scale = 0.5 # diffrent from the synthetic data, the scale is process first # trans from (3,) to (batch_size, 3,1) trans = [0, 0.2, 0] #in the center trans_bt = torch.Tensor(trans).reshape(1, 3, 1).expand(cam.shape[0], 3, 1) cam[:,:3,3] = cam[:,:3,3] + torch.bmm(cam[:,:3,:3], trans_bt).reshape(-1, 3) # T = Rt+T torch.Size([24, 3, 1]) cam[:,:3,:3] = cam[:,:3,:3] * scale # R = sR cam_c2w = torch.inverse(cam) cam_w2c = cam poses = [] for i_cam in range(cam.shape[0]): poses.append( torch.concat([ (intrics.reshape(-1)).to(torch.float32), #C ! # C ? T 理论上要给C (cam_w2c[i_cam]).to(torch.float32).reshape(-1), # RT #Rt|C ? 
RT 理论上要给RT ], dim=0)) cameras = torch.stack(poses).to(device) # [N, 19] return cameras def get_name_str(name): path_ = os.path.basename(os.path.dirname(name)) + os.path.basename(name) return path_ def load_smplx_from_npy(smplx_path, device='cuda'): hand_mean = get_hand_pose_mean().reshape(-1) smplx_pose_param = np.load(smplx_path, allow_pickle=True) # if "person1" in smplx_pose_param: # smplx_pose_param = smplx_pose_param['person1'] smplx_pose_param = { 'root_orient': smplx_pose_param[:, :3], # controls the global root orientation 'pose_body': smplx_pose_param[:, 3:3+63], # controls the body 'pose_hand': smplx_pose_param[:, 66:66+90], # controls the finger articulation 'pose_jaw': smplx_pose_param[:, 66+90:66+93], # controls the yaw pose 'face_expr': smplx_pose_param[:, 159:159+50], # controls the face expression 'face_shape': smplx_pose_param[:, 209:209+100], # controls the face shape 'trans': smplx_pose_param[:, 309:309+3], # controls the global body position 'betas': smplx_pose_param[:, 312:], # controls the body shape. Body shape is static } smplx_param_list = [] for i in range(1, 1800, 1): # for i in k.keys(): # k[i] = np.array(k[i]) left_hands = np.array([1.4624, -0.1615, 0.1361, 1.3851, -0.2597, 0.0247, -0.0683, -0.4478, -0.6652, -0.7290, 0.0084, -0.4818]) betas = torch.zeros((10)) smplx_param = \ np.concatenate([np.array([1]), smplx_pose_param['trans'][i], smplx_pose_param['root_orient'][i], \ smplx_pose_param['pose_body'][i],betas, \ smplx_pose_param['pose_hand'][i]-hand_mean, smplx_pose_param['pose_jaw'][i], np.zeros(6), smplx_pose_param['face_expr'][i][:10]], axis=0).reshape(1,-1) smplx_param_list.append(smplx_param) smplx_params = np.concatenate(smplx_param_list, 0) smpl_params = torch.Tensor(smplx_params).to(device) return smpl_params def add_root_rotate_to_smplx(smpl_tmp, frames_num=180, device='cuda'): from cv2 import Rodrigues initial_matrix = batch_rodrigues(smpl_tmp.reshape(1,189)[:, 4:7]).cpu().numpy().copy() # Rotate a rotation matrix by 360 degrees around the y-axis. 
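    # Each output frame keeps the input SMPL-X parameters but swaps the global
    # orientation: a y-axis rotation of (360 // frames_num) * frame_index
    # degrees is composed with the initial root rotation and converted back to
    # axis-angle, producing a turntable-style sweep around the subject.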
# frames_num = 180 all_smpl = [] # Combine the rotations all_smpl = [] for idx_f in range(frames_num): new_smpl = smpl_tmp.clone() angle = 360//frames_num * idx_f y_angle = np.radians(angle) y_rotation_matrix = np.array([ [ np.cos(y_angle),0, np.sin(y_angle)], [0, 1, 0], [-np.sin(y_angle), 0, np.cos(y_angle)], ]) final_matrix = y_rotation_matrix[None] @ initial_matrix new_smpl[4:7] = torch.Tensor(rotation_matrix_to_rodrigues(torch.Tensor(final_matrix))).to(device) all_smpl.append(new_smpl) all_smpl = torch.stack(all_smpl, 0) smpl_params = all_smpl.to(device) return smpl_params def load_smplx_from_json(smplx_path, device='cuda'): # format of motion-x hand_mean = get_hand_pose_mean().reshape(-1) with open(smplx_path, 'r') as f: smplx_pose_param = json.load(f) smplx_param_list = [] for par in smplx_pose_param['annotations']: k = par['smplx_params'] for i in k.keys(): k[i] = np.array(k[i]) betas = torch.zeros((10)) # ######### wrist pose fix ################ smplx_param = \ np.concatenate([np.array([1]), k['trans'], k['root_orient']*np.array([1, 1, 1]), \ k['pose_body'],betas, \ k['pose_hand']-hand_mean, k['pose_jaw'], np.zeros(6), np.zeros_like(k['face_expr'][:10])], axis=0).reshape(1,-1) smplx_param_list.append(smplx_param) smplx_params = np.concatenate(smplx_param_list, 0) print(smplx_params.shape) smpl_params = torch.Tensor(smplx_params).to(device) return smpl_params def get_image_dimensions(input_path): with Image.open(input_path) as img: return img.height, img.width def construct_camera_from_motionx(smplx_path, device='cuda'): with open(smplx_path, 'r') as f: smplx_pose_param = json.load(f) cam_exts = [] cam_ints = [] for par in smplx_pose_param['annotations']: cam = par['cam_params'] R = np.array(cam['cam_R']) K = np.array(cam['intrins']) T = np.array(cam['cam_T']) cam['cam_T'][1] = -cam['cam_T'][1] cam['cam_T'][2] = -cam['cam_T'][2] extrix = np.eye(4) extrix[:3, :3] = R extrix[:3,3] = T cam_exts.append(extrix) intrix = K cam_ints.append(intrix) # target N,20 cam_exts_array = np.array(cam_exts) cam_exts_stack = torch.Tensor(cam_exts_array).to(device).reshape(-1, 16) cam_ints_stack = torch.Tensor(cam_ints).to(device).reshape(-1, 4) cameras = torch.cat([cam_ints_stack, cam_exts_stack], dim=-1).reshape(-1,1, 20) return cameras def remove_background(image: PIL.Image.Image, rembg_session: Any = None, force: bool = False, **rembg_kwargs, ) -> PIL.Image.Image: do_remove = True if image.mode == "RGBA" and image.getextrema()[3][0] < 255: do_remove = False do_remove = do_remove or force if do_remove: image = rembg.remove(image, session=rembg_session, **rembg_kwargs) return image def resize_foreground( image: PIL.Image.Image, ratio: float, ) -> PIL.Image.Image: image = np.array(image) assert image.shape[-1] == 4 alpha = np.where(image[..., 3] > 0) y1, y2, x1, x2 = ( alpha[0].min(), alpha[0].max(), alpha[1].min(), alpha[1].max(), ) # crop the foreground fg = image[y1:y2, x1:x2] # pad to square size = max(fg.shape[0], fg.shape[1]) ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2 ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0 new_image = np.pad( fg, ((ph0, ph1), (pw0, pw1), (0, 0)), mode="constant", constant_values=((0, 0), (0, 0), (0, 0)), ) # compute padding according to the ratio new_size = int(new_image.shape[0] / ratio) # pad to size, double side ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2 ph1, pw1 = new_size - size - ph0, new_size - size - pw0 new_image = np.pad( new_image, ((ph0, ph1), (pw0, pw1), (0, 0)), mode="constant", constant_values=((0, 0), 
(0, 0), (0, 0)), ) new_image = PIL.Image.fromarray(new_image) return new_image def images_to_video( images: torch.Tensor, output_path: str, fps: int = 30, ) -> None: # images: (N, C, H, W) video_dir = os.path.dirname(output_path) video_name = os.path.basename(output_path) os.makedirs(video_dir, exist_ok=True) frames = [] for i in range(len(images)): frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \ f"Frame shape mismatch: {frame.shape} vs {images.shape}" assert frame.min() >= 0 and frame.max() <= 255, \ f"Frame value out of range: {frame.min()} ~ {frame.max()}" frames.append(frame) imageio.mimwrite(output_path, np.stack(frames), fps=fps, quality=10) def save_video( frames: torch.Tensor, output_path: str, fps: int = 30, ) -> None: # images: (N, C, H, W) frames = [(frame.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) for frame in frames] writer = imageio.get_writer(output_path, fps=fps) for frame in frames: writer.append_data(frame) writer.close() ================================================ FILE: lib/utils/mesh.py ================================================ import os import cv2 import torch import trimesh import numpy as np def dot(x, y): return torch.sum(x * y, -1, keepdim=True) def length(x, eps=1e-20): return torch.sqrt(torch.clamp(dot(x, x), min=eps)) def safe_normalize(x, eps=1e-20): return x / length(x, eps) class Mesh: def __init__( self, v=None, f=None, vn=None, fn=None, vt=None, ft=None, albedo=None, vc=None, # vertex color device=None, ): self.device = device self.v = v self.vn = vn self.vt = vt self.f = f self.fn = fn self.ft = ft # only support a single albedo self.albedo = albedo # support vertex color is no albedo self.vc = vc self.ori_center = 0 self.ori_scale = 1 @classmethod def load(cls, path=None, resize=True, renormal=True, retex=False, front_dir='+z', **kwargs): # assume init with kwargs if path is None: mesh = cls(**kwargs) # obj supports face uv elif path.endswith(".obj"): mesh = cls.load_obj(path, **kwargs) # trimesh only supports vertex uv, but can load more formats else: mesh = cls.load_trimesh(path, **kwargs) print(f"[Mesh loading] v: {mesh.v.shape}, f: {mesh.f.shape}") # auto-normalize if resize: mesh.auto_size() # auto-fix normal if renormal or mesh.vn is None: mesh.auto_normal() print(f"[Mesh loading] vn: {mesh.vn.shape}, fn: {mesh.fn.shape}") # auto-fix texcoords if retex or (mesh.albedo is not None and mesh.vt is None): mesh.auto_uv(cache_path=path) print(f"[Mesh loading] vt: {mesh.vt.shape}, ft: {mesh.ft.shape}") # rotate front dir to +z if front_dir != "+z": # axis switch if "-z" in front_dir: T = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, -1]], device=mesh.device, dtype=torch.float32) elif "+x" in front_dir: T = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=mesh.device, dtype=torch.float32) elif "-x" in front_dir: T = torch.tensor([[0, 0, -1], [0, 1, 0], [1, 0, 0]], device=mesh.device, dtype=torch.float32) elif "+y" in front_dir: T = torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]], device=mesh.device, dtype=torch.float32) elif "-y" in front_dir: T = torch.tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]], device=mesh.device, dtype=torch.float32) else: T = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32) # rotation (how many 90 degrees) if '1' in front_dir: T @= torch.tensor([[0, -1, 0], [1, 0, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32) elif '2' in front_dir: T @= torch.tensor([[1, 
0, 0], [0, -1, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32) elif '3' in front_dir: T @= torch.tensor([[0, 1, 0], [-1, 0, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32) mesh.v @= T mesh.vn @= T return mesh # load from obj file @classmethod def load_obj(cls, path, albedo_path=None, device=None): assert os.path.splitext(path)[-1] == ".obj" mesh = cls() # device if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") mesh.device = device # load obj with open(path, "r") as f: lines = f.readlines() def parse_f_v(fv): xs = [int(x) - 1 if x != "" else -1 for x in fv.split("/")] xs.extend([-1] * (3 - len(xs))) return xs[0], xs[1], xs[2] # NOTE: we ignore usemtl, and assume the mesh ONLY uses one material (first in mtl) vertices, texcoords, normals = [], [], [] faces, tfaces, nfaces = [], [], [] mtl_path = None for line in lines: split_line = line.split() # empty line if len(split_line) == 0: continue prefix = split_line[0].lower() # mtllib if prefix == "mtllib": mtl_path = split_line[1] # usemtl elif prefix == "usemtl": pass # ignored # v/vn/vt elif prefix == "v": vertices.append([float(v) for v in split_line[1:]]) elif prefix == "vn": normals.append([float(v) for v in split_line[1:]]) elif prefix == "vt": val = [float(v) for v in split_line[1:]] texcoords.append([val[0], 1.0 - val[1]]) elif prefix == "f": vs = split_line[1:] nv = len(vs) v0, t0, n0 = parse_f_v(vs[0]) for i in range(nv - 2): # triangulate (assume vertices are ordered) v1, t1, n1 = parse_f_v(vs[i + 1]) v2, t2, n2 = parse_f_v(vs[i + 2]) faces.append([v0, v1, v2]) tfaces.append([t0, t1, t2]) nfaces.append([n0, n1, n2]) mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device) mesh.vt = ( torch.tensor(texcoords, dtype=torch.float32, device=device) if len(texcoords) > 0 else None ) mesh.vn = ( torch.tensor(normals, dtype=torch.float32, device=device) if len(normals) > 0 else None ) mesh.f = torch.tensor(faces, dtype=torch.int32, device=device) mesh.ft = ( torch.tensor(tfaces, dtype=torch.int32, device=device) if len(texcoords) > 0 else None ) mesh.fn = ( torch.tensor(nfaces, dtype=torch.int32, device=device) if len(normals) > 0 else None ) # see if there is vertex color use_vertex_color = False if mesh.v.shape[1] == 6: use_vertex_color = True mesh.vc = mesh.v[:, 3:] mesh.v = mesh.v[:, :3] print(f"[load_obj] use vertex color: {mesh.vc.shape}") # try to load texture image if not use_vertex_color: # try to retrieve mtl file mtl_path_candidates = [] if mtl_path is not None: mtl_path_candidates.append(mtl_path) mtl_path_candidates.append(os.path.join(os.path.dirname(path), mtl_path)) mtl_path_candidates.append(path.replace(".obj", ".mtl")) mtl_path = None for candidate in mtl_path_candidates: if os.path.exists(candidate): mtl_path = candidate break # if albedo_path is not provided, try retrieve it from mtl if mtl_path is not None and albedo_path is None: with open(mtl_path, "r") as f: lines = f.readlines() for line in lines: split_line = line.split() # empty line if len(split_line) == 0: continue prefix = split_line[0] # NOTE: simply use the first map_Kd as albedo! 
if "map_Kd" in prefix: albedo_path = os.path.join(os.path.dirname(path), split_line[1]) print(f"[load_obj] use texture from: {albedo_path}") break # still not found albedo_path, or the path doesn't exist if albedo_path is None or not os.path.exists(albedo_path): # init an empty texture print(f"[load_obj] init empty albedo!") # albedo = np.random.rand(1024, 1024, 3).astype(np.float32) albedo = np.ones((1024, 1024, 3), dtype=np.float32) * np.array([0.5, 0.5, 0.5]) # default color else: albedo = cv2.imread(albedo_path, cv2.IMREAD_UNCHANGED) albedo = cv2.cvtColor(albedo, cv2.COLOR_BGR2RGB) albedo = albedo.astype(np.float32) / 255 print(f"[load_obj] load texture: {albedo.shape}") mesh.albedo = torch.tensor(albedo, dtype=torch.float32, device=device) return mesh @classmethod def load_trimesh(cls, path, device=None): mesh = cls() # device if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") mesh.device = device # use trimesh to load ply/glb, assume only has one single RootMesh... _data = trimesh.load(path) if isinstance(_data, trimesh.Scene): if len(_data.geometry) == 1: _mesh = list(_data.geometry.values())[0] else: # manual concat, will lose texture _concat = [] for g in _data.geometry.values(): if isinstance(g, trimesh.Trimesh): _concat.append(g) _mesh = trimesh.util.concatenate(_concat) else: _mesh = _data if _mesh.visual.kind == 'vertex': vertex_colors = _mesh.visual.vertex_colors vertex_colors = np.array(vertex_colors[..., :3]).astype(np.float32) / 255 mesh.vc = torch.tensor(vertex_colors, dtype=torch.float32, device=device) print(f"[load_trimesh] use vertex color: {mesh.vc.shape}") elif _mesh.visual.kind == 'texture': _material = _mesh.visual.material if isinstance(_material, trimesh.visual.material.PBRMaterial): texture = np.array(_material.baseColorTexture).astype(np.float32) / 255 elif isinstance(_material, trimesh.visual.material.SimpleMaterial): texture = np.array(_material.to_pbr().baseColorTexture).astype(np.float32) / 255 else: raise NotImplementedError(f"material type {type(_material)} not supported!") mesh.albedo = torch.tensor(texture, dtype=torch.float32, device=device) print(f"[load_trimesh] load texture: {texture.shape}") else: texture = np.ones((1024, 1024, 3), dtype=np.float32) * np.array([0.5, 0.5, 0.5]) mesh.albedo = torch.tensor(texture, dtype=torch.float32, device=device) print(f"[load_trimesh] failed to load texture.") vertices = _mesh.vertices try: texcoords = _mesh.visual.uv texcoords[:, 1] = 1 - texcoords[:, 1] except Exception as e: texcoords = None try: normals = _mesh.vertex_normals except Exception as e: normals = None # trimesh only support vertex uv... 
faces = tfaces = nfaces = _mesh.faces mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device) mesh.vt = ( torch.tensor(texcoords, dtype=torch.float32, device=device) if texcoords is not None else None ) mesh.vn = ( torch.tensor(normals, dtype=torch.float32, device=device) if normals is not None else None ) mesh.f = torch.tensor(faces, dtype=torch.int32, device=device) mesh.ft = ( torch.tensor(tfaces, dtype=torch.int32, device=device) if texcoords is not None else None ) mesh.fn = ( torch.tensor(nfaces, dtype=torch.int32, device=device) if normals is not None else None ) return mesh # aabb def aabb(self): return torch.min(self.v, dim=0).values, torch.max(self.v, dim=0).values # unit size @torch.no_grad() def auto_size(self): vmin, vmax = self.aabb() self.ori_center = (vmax + vmin) / 2 self.ori_scale = 1.2 / torch.max(vmax - vmin).item() self.v = (self.v - self.ori_center) * self.ori_scale def auto_normal(self): i0, i1, i2 = self.f[:, 0].long(), self.f[:, 1].long(), self.f[:, 2].long() v0, v1, v2 = self.v[i0, :], self.v[i1, :], self.v[i2, :] face_normals = torch.cross(v1 - v0, v2 - v0, dim=1) # Splat face normals to vertices vn = torch.zeros_like(self.v) vn.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals) vn.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals) vn.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals) # Normalize, replace zero (degenerated) normals with some default value vn = torch.where( dot(vn, vn) > 1e-20, vn, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device), ) vn = safe_normalize(vn) self.vn = vn self.fn = self.f def auto_uv(self, cache_path=None, vmap=True): # try to load cache if cache_path is not None: cache_path = os.path.splitext(cache_path)[0] + "_uv.npz" if cache_path is not None and os.path.exists(cache_path): data = np.load(cache_path) vt_np, ft_np, vmapping = data["vt"], data["ft"], data["vmapping"] else: import xatlas v_np = self.v.detach().cpu().numpy() f_np = self.f.detach().int().cpu().numpy() atlas = xatlas.Atlas() atlas.add_mesh(v_np, f_np) chart_options = xatlas.ChartOptions() # chart_options.max_iterations = 4 atlas.generate(chart_options=chart_options) vmapping, ft_np, vt_np = atlas[0] # [N], [M, 3], [N, 2] # save to cache if cache_path is not None: np.savez(cache_path, vt=vt_np, ft=ft_np, vmapping=vmapping) vt = torch.from_numpy(vt_np.astype(np.float32)).to(self.device) ft = torch.from_numpy(ft_np.astype(np.int32)).to(self.device) self.vt = vt self.ft = ft if vmap: # remap v/f to vt/ft, so each v correspond to a unique vt. (necessary for gltf) vmapping = torch.from_numpy(vmapping.astype(np.int64)).long().to(self.device) self.align_v_to_vt(vmapping) def align_v_to_vt(self, vmapping=None): # remap v/f and vn/vn to vt/ft. 
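        # Duplicating vertices here ensures each vertex carries exactly one UV coordinate;
        # glTF needs per-vertex attributes, so write_glb() invokes this when the vertex and
        # texcoord counts differ.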
if vmapping is None: ft = self.ft.view(-1).long() f = self.f.view(-1).long() vmapping = torch.zeros(self.vt.shape[0], dtype=torch.long, device=self.device) vmapping[ft] = f # scatter, randomly choose one if index is not unique self.v = self.v[vmapping] self.f = self.ft # assume fn == f if self.vn is not None: self.vn = self.vn[vmapping] self.fn = self.ft def to(self, device): self.device = device for name in ["v", "f", "vn", "fn", "vt", "ft", "albedo"]: tensor = getattr(self, name) if tensor is not None: setattr(self, name, tensor.to(device)) return self def write(self, path): if path.endswith(".ply"): self.write_ply(path) elif path.endswith(".obj"): self.write_obj(path) elif path.endswith(".glb") or path.endswith(".gltf"): self.write_glb(path) else: raise NotImplementedError(f"format {path} not supported!") # write to ply file (only geom) def write_ply(self, path): v_np = self.v.detach().cpu().numpy() f_np = self.f.detach().cpu().numpy() _mesh = trimesh.Trimesh(vertices=v_np, faces=f_np) _mesh.export(path) # write to gltf/glb file (geom + texture) def write_glb(self, path): assert self.vn is not None and self.vt is not None # should be improved to support export without texture... # assert self.v.shape[0] == self.vn.shape[0] and self.v.shape[0] == self.vt.shape[0] if self.v.shape[0] != self.vt.shape[0]: self.align_v_to_vt() # assume f == fn == ft import pygltflib f_np = self.f.detach().cpu().numpy().astype(np.uint32) v_np = self.v.detach().cpu().numpy().astype(np.float32) # vn_np = self.vn.detach().cpu().numpy().astype(np.float32) vt_np = self.vt.detach().cpu().numpy().astype(np.float32) albedo = self.albedo.detach().cpu().numpy() albedo = (albedo * 255).astype(np.uint8) albedo = cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR) f_np_blob = f_np.flatten().tobytes() v_np_blob = v_np.tobytes() # vn_np_blob = vn_np.tobytes() vt_np_blob = vt_np.tobytes() albedo_blob = cv2.imencode('.png', albedo)[1].tobytes() gltf = pygltflib.GLTF2( scene=0, scenes=[pygltflib.Scene(nodes=[0])], nodes=[pygltflib.Node(mesh=0)], meshes=[pygltflib.Mesh(primitives=[ pygltflib.Primitive( # indices to accessors (0 is triangles) attributes=pygltflib.Attributes( POSITION=1, TEXCOORD_0=2, ), indices=0, material=0, ) ])], materials=[ pygltflib.Material( pbrMetallicRoughness=pygltflib.PbrMetallicRoughness( baseColorTexture=pygltflib.TextureInfo(index=0, texCoord=0), metallicFactor=0.0, roughnessFactor=1.0, ), alphaCutoff=0, doubleSided=True, ) ], textures=[ pygltflib.Texture(sampler=0, source=0), ], samplers=[ pygltflib.Sampler(magFilter=pygltflib.LINEAR, minFilter=pygltflib.LINEAR_MIPMAP_LINEAR, wrapS=pygltflib.REPEAT, wrapT=pygltflib.REPEAT), ], images=[ # use embedded (buffer) image pygltflib.Image(bufferView=3, mimeType="image/png"), ], buffers=[ pygltflib.Buffer(byteLength=len(f_np_blob) + len(v_np_blob) + len(vt_np_blob) + len(albedo_blob)) ], # buffer view (based on dtype) bufferViews=[ # triangles; as flatten (element) array pygltflib.BufferView( buffer=0, byteLength=len(f_np_blob), target=pygltflib.ELEMENT_ARRAY_BUFFER, # GL_ELEMENT_ARRAY_BUFFER (34963) ), # positions; as vec3 array pygltflib.BufferView( buffer=0, byteOffset=len(f_np_blob), byteLength=len(v_np_blob), byteStride=12, # vec3 target=pygltflib.ARRAY_BUFFER, # GL_ARRAY_BUFFER (34962) ), # texcoords; as vec2 array pygltflib.BufferView( buffer=0, byteOffset=len(f_np_blob) + len(v_np_blob), byteLength=len(vt_np_blob), byteStride=8, # vec2 target=pygltflib.ARRAY_BUFFER, ), # texture; as none target pygltflib.BufferView( buffer=0, byteOffset=len(f_np_blob) + 
len(v_np_blob) + len(vt_np_blob), byteLength=len(albedo_blob), ), ], accessors=[ # 0 = triangles pygltflib.Accessor( bufferView=0, componentType=pygltflib.UNSIGNED_INT, # GL_UNSIGNED_INT (5125) count=f_np.size, type=pygltflib.SCALAR, max=[int(f_np.max())], min=[int(f_np.min())], ), # 1 = positions pygltflib.Accessor( bufferView=1, componentType=pygltflib.FLOAT, # GL_FLOAT (5126) count=len(v_np), type=pygltflib.VEC3, max=v_np.max(axis=0).tolist(), min=v_np.min(axis=0).tolist(), ), # 2 = texcoords pygltflib.Accessor( bufferView=2, componentType=pygltflib.FLOAT, count=len(vt_np), type=pygltflib.VEC2, max=vt_np.max(axis=0).tolist(), min=vt_np.min(axis=0).tolist(), ), ], ) # set actual data gltf.set_binary_blob(f_np_blob + v_np_blob + vt_np_blob + albedo_blob) # glb = b"".join(gltf.save_to_bytes()) gltf.save(path) # write to obj file (geom + texture) def write_obj(self, path): mtl_path = path.replace(".obj", ".mtl") albedo_path = path.replace(".obj", "_albedo.png") v_np = self.v.detach().cpu().numpy() vt_np = self.vt.detach().cpu().numpy() if self.vt is not None else None vn_np = self.vn.detach().cpu().numpy() if self.vn is not None else None f_np = self.f.detach().cpu().numpy() ft_np = self.ft.detach().cpu().numpy() if self.ft is not None else None fn_np = self.fn.detach().cpu().numpy() if self.fn is not None else None with open(path, "w") as fp: fp.write(f"mtllib {os.path.basename(mtl_path)} \n") for v in v_np: fp.write(f"v {v[0]} {v[1]} {v[2]} \n") if vt_np is not None: for v in vt_np: fp.write(f"vt {v[0]} {1 - v[1]} \n") if vn_np is not None: for v in vn_np: fp.write(f"vn {v[0]} {v[1]} {v[2]} \n") fp.write(f"usemtl defaultMat \n") for i in range(len(f_np)): fp.write( f'f {f_np[i, 0] + 1}/{ft_np[i, 0] + 1 if ft_np is not None else ""}/{fn_np[i, 0] + 1 if fn_np is not None else ""} \ {f_np[i, 1] + 1}/{ft_np[i, 1] + 1 if ft_np is not None else ""}/{fn_np[i, 1] + 1 if fn_np is not None else ""} \ {f_np[i, 2] + 1}/{ft_np[i, 2] + 1 if ft_np is not None else ""}/{fn_np[i, 2] + 1 if fn_np is not None else ""} \n' ) with open(mtl_path, "w") as fp: fp.write(f"newmtl defaultMat \n") fp.write(f"Ka 1 1 1 \n") fp.write(f"Kd 1 1 1 \n") fp.write(f"Ks 0 0 0 \n") fp.write(f"Tr 1 \n") fp.write(f"illum 1 \n") fp.write(f"Ns 0 \n") fp.write(f"map_Kd {os.path.basename(albedo_path)} \n") albedo = self.albedo.detach().cpu().numpy() albedo = (albedo * 255).astype(np.uint8) cv2.imwrite(albedo_path, cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR)) ================================================ FILE: lib/utils/mesh_utils.py ================================================ import numpy as np import pymeshlab as pml import torch def gaussian_3d_coeff(xyzs, covs): # xyzs: [N, 3] # covs: [N, 6] x, y, z = xyzs[:, 0], xyzs[:, 1], xyzs[:, 2] a, b, c, d, e, f = covs[:, 0], covs[:, 1], covs[:, 2], covs[:, 3], covs[:, 4], covs[:, 5] # eps must be small enough !!! inv_det = 1 / (a * d * f + 2 * e * c * b - e**2 * a - c**2 * d - b**2 * f + 1e-24) inv_a = (d * f - e**2) * inv_det inv_b = (e * c - b * f) * inv_det inv_c = (e * b - c * d) * inv_det inv_d = (a * f - c**2) * inv_det inv_e = (b * c - e * a) * inv_det inv_f = (a * d - b**2) * inv_det power = -0.5 * (x**2 * inv_a + y**2 * inv_d + z**2 * inv_f) - x * y * inv_b - x * z * inv_c - y * z * inv_e power[power > 0] = -1e10 # abnormal values... 
# make weights 0
    return torch.exp(power)


def poisson_mesh_reconstruction(points, normals=None):
    # points/normals: [N, 3] np.ndarray
    import open3d as o3d

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)

    # outlier removal
    pcd, ind = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=10)

    # normals
    if normals is None:
        pcd.estimate_normals()
    else:
        pcd.normals = o3d.utility.Vector3dVector(normals[ind])

    # visualize
    o3d.visualization.draw_geometries([pcd], point_show_normal=False)

    mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
        pcd, depth=9
    )
    vertices_to_remove = densities < np.quantile(densities, 0.1)
    mesh.remove_vertices_by_mask(vertices_to_remove)

    # visualize
    o3d.visualization.draw_geometries([mesh])

    vertices = np.asarray(mesh.vertices)
    triangles = np.asarray(mesh.triangles)

    print(
        f"[INFO] poisson mesh reconstruction: {points.shape} --> {vertices.shape} / {triangles.shape}"
    )

    return vertices, triangles


def decimate_mesh(
    verts, faces, target, backend="pymeshlab", remesh=False, optimalplacement=True
):
    # optimalplacement: default is True, but for flat mesh must turn False to prevent spike artifact.

    _ori_vert_shape = verts.shape
    _ori_face_shape = faces.shape

    if backend == "pyfqmr":
        import pyfqmr

        solver = pyfqmr.Simplify()
        solver.setMesh(verts, faces)
        solver.simplify_mesh(target_count=target, preserve_border=False, verbose=False)
        verts, faces, normals = solver.getMesh()
    else:
        m = pml.Mesh(verts, faces)
        ms = pml.MeshSet()
        ms.add_mesh(m, "mesh")  # will copy!

        # filters
        # ms.meshing_decimation_clustering(threshold=pml.PercentageValue(1))
        ms.meshing_decimation_quadric_edge_collapse(
            targetfacenum=int(target), optimalplacement=optimalplacement
        )

        if remesh:
            # ms.apply_coord_taubin_smoothing()
            ms.meshing_isotropic_explicit_remeshing(
                iterations=3, targetlen=pml.PercentageValue(1)
            )

        # extract mesh
        m = ms.current_mesh()
        verts = m.vertex_matrix()
        faces = m.face_matrix()

    print(
        f"[INFO] mesh decimation: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}"
    )

    return verts, faces


def clean_mesh(
    verts,
    faces,
    v_pct=1,
    min_f=64,
    min_d=20,
    repair=True,
    remesh=True,
    remesh_size=0.01,
):
    # verts: [N, 3]
    # faces: [N, 3]
    import pymeshlab as pml
    from importlib.metadata import version

    PML_VER = version('pymeshlab')

    if PML_VER == '2022.2.post3':
        pml.PercentageValue = pml.Percentage
        pml.PureValue = pml.AbsoluteValue

    _ori_vert_shape = verts.shape
    _ori_face_shape = faces.shape

    m = pml.Mesh(verts, faces)
    ms = pml.MeshSet()
    ms.add_mesh(m, "mesh")  # will copy!
# filters ms.meshing_remove_unreferenced_vertices() # verts not refed by any faces if v_pct > 0: ms.meshing_merge_close_vertices( threshold=pml.Percentage(v_pct) ) # 1/10000 of bounding box diagonal ms.meshing_remove_duplicate_faces() # faces defined by the same verts ms.meshing_remove_null_faces() # faces with area == 0 if min_d > 0: ms.meshing_remove_connected_component_by_diameter( mincomponentdiag=pml.Percentage(min_d) ) if min_f > 0: ms.meshing_remove_connected_component_by_face_number(mincomponentsize=min_f) if repair: # ms.meshing_remove_t_vertices(method=0, threshold=40, repeat=True) ms.meshing_repair_non_manifold_edges(method=0) ms.meshing_repair_non_manifold_vertices(vertdispratio=0) if remesh: # ms.apply_coord_taubin_smoothing() ms.meshing_isotropic_explicit_remeshing( iterations=3, targetlen=pml.PureValue(remesh_size) ) # extract mesh m = ms.current_mesh() verts = m.vertex_matrix() faces = m.face_matrix() print( f"[INFO] mesh cleaning: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}" ) return verts, faces ================================================ FILE: lib/utils/train_util.py ================================================ import importlib import os from pytorch_lightning.utilities import rank_zero_only @rank_zero_only def main_print(*args): print(*args) def count_params(model, verbose=False): total_params = sum(p.numel() for p in model.parameters()) if verbose: main_print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") return total_params def instantiate_from_config(config): if not "target" in config: if config == '__is_first_stage__': return None elif config == "__is_unconditional__": return None raise KeyError("Expected key `target` to instantiate.") return get_obj_from_str(config["target"])(**config.get("params", dict())) def get_obj_from_str(string, reload=False): main_print(string) module, cls = string.rsplit(".", 1) if reload: module_imp = importlib.import_module(module) importlib.reload(module_imp) return getattr(importlib.import_module(module, package=None), cls) ================================================ FILE: run_demo.py ================================================ import os os.environ["CUDA_VISIBLE_DEVICES"] = "0" # os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7" import argparse import torch from tqdm import tqdm from torchvision.transforms import v2 from pytorch_lightning import seed_everything from omegaconf import OmegaConf from tqdm import tqdm from einops import rearrange from lib.utils.infer_util import * from lib.utils.train_util import instantiate_from_config import torchvision import json ############################################################################### # Arguments. 
############################################################################### def parse_args(): """Parse command line arguments""" parser = argparse.ArgumentParser() parser.add_argument('--config', type=str, help='Path to config file.', required=False) parser.add_argument('--input_path', type=str, help='Path to input image or directory.', required=False) parser.add_argument('--resume_path', type=str, help='Path to saved ckpt.', required=False) parser.add_argument('--output_path', type=str, default='outputs/', help='Output directory.') parser.add_argument('--seed', type=int, default=42, help='Random seed for sampling.') parser.add_argument('--distance', type=float, default=1.5, help='Render distance.') parser.add_argument('--no_rembg', action='store_true', help='Do not remove input background.') parser.add_argument('--render_mode', type=str, default='novel_pose', choices=['novel_pose', 'reconstruct', 'novel_pose_A'], help='Rendering mode: novel_pose (animation), reconstruct (reconstruction), or novel_pose_A (360-degree view with A-pose)') return parser.parse_args() ############################################################################### # Stage 0: Configuration. ############################################################################### device = torch.device('cuda') def process_data_on_gpu(args, model, gpu_id, img_paths_list, smplx_ref_path_list, smplx_path_driven_list): torch.cuda.set_device(gpu_id) model = model.cuda() image_plist = [] render_mode = args.render_mode cam_idx = 0 # 12 # fixed cameras and changes pose for novel poses num_imgs = 60 if_load_betas = True if_use_video_cam = False # If the SMPLX sequence provides camera parameters, this can be set to true. if_uniform_coordinates = True # Normalize the SMPL-X sequence for the purpose of driving. 
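
    # Layout of the 189-dim SMPL-X parameter vector, as inferred from the indexing below
    # (the remaining slots are left untouched here): [:, 1:4] global translation,
    # [:, 4:7] root orientation, [:, 70:80] shape betas.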
for input_path, smplx_ref_path, smplx_path in tqdm(zip(img_paths_list, smplx_ref_path_list, smplx_path_driven_list), total = len(img_paths_list)): print(f"Processing: {input_path}") args.input_path = input_path args.input_path_smpl = smplx_ref_path # get a name for results name = get_name_str(args.input_path) + get_name_str(smplx_path) ############################################################################### # Stage 1: Parameters loading ############################################################################### ''' # Stage 1.1: SMPLX loading (Beta)''' if args.input_path_smpl is not None: # smplx = np.load(args.input_path_smpl, allow_pickle=True).item() smplx = json.load(open(args.input_path_smpl)) if "shapes" in smplx.keys(): smplx['betas'] = smplx['shapes'] else: smplx['betas'] = smplx['betas_save'] smpl_params = torch.zeros(1, 189).to(device) if if_load_betas: smpl_params[:, 70:80] = torch.Tensor(smplx['betas']).to(device) ''' # Stage 1.2: SMPLX loading (Pose)''' # animation if render_mode in ['novel_pose'] : if smplx_path.endswith(".npy"): smpl_params = load_smplx_from_npy(smplx_path) smpl_params[:, 70:80] = torch.Tensor(smplx['betas']).to(device) # ========= Note: If the video camera is not used, center everything at the origin ======== if if_uniform_coordinates: print("''' Ending --- Adjusting root orientation angles '''") # Extract root orientation and translation from SMPL parameters root_orient = smpl_params[:, 4:7] trans = smpl_params[:, 1:4] # Reset the first frame's rotation and adjust translations new_root_orient, new_trans = reset_first_frame_rotation(root_orient, trans) # Update the root orientation and translation in the SMPL parameters smpl_params[:, 4:7] = new_root_orient smpl_params[:, 1:4] = new_trans.squeeze() # Apply the new translation elif smplx_path.endswith(".json"): ''' for motion-x input ''' smpl_params = load_smplx_from_json(smplx_path) smpl_params[:, 70:80] = torch.Tensor(smplx['betas']).to(device) if_use_video_cam = True elif render_mode in ['reconstruct']: RT_rec, intri_rec, smpl_rec = load_smplify_json(smplx_ref_path) H_rec, W_rec = get_image_dimensions(input_path) '''Apply root rotation for a full 360-degree view of the object''' if_add_root_rotate = True if if_add_root_rotate == True: smpl_params = add_root_rotate_to_smplx(smpl_rec, num_imgs) print(" '''ending --- invert the root angles'''") else: smpl_params = smpl_params.to(device) num_imgs = 1 elif render_mode in ['novel_pose_A']: smpl_params = model.get_default_smplx_params().squeeze() smpl_params = smpl_params.to(device) smpl_params = add_root_rotate_to_smplx(smpl_params.clone(), num_imgs) smpl_params[:, 70:80] = torch.Tensor(smplx['betas']).to(device) else: raise NotImplementedError(f"Render mode '{render_mode}' is not supported.") '''# Stage 1.3: Image loading ''' image = load_image(args.input_path, args.output_folders['ref']) H,W = 896,640 image_bs = image.unsqueeze(0).to(device) num_imgs = 180 ''' # Stage 1.4 Camera loading''' if not if_use_video_cam: # prepare cameras K, cam_list = prepare_camera(resolution_x=H, resolution_y=W, num_views=num_imgs, stides=1) cameras = construct_camera(K, cam_list) if render_mode == 'novel_pose': # if poses are changed, cameras will be fixed intrics = torch.Tensor([K[0,0],K[1,1], 256, 256]).reshape(-1) model.decoder.renderer.image_size = [512, 512] assert cameras.shape[-1] == 20 cameras[:, :4] = intrics cameras = cameras[cam_idx:cam_idx+1] num_imgs = smpl_params.shape[0] cameras = cameras.repeat(num_imgs, 1) cameras = cameras[:, None, :] # length of the 
pose sequences print("modify the render images's resolution into 512x512 ") elif render_mode in ['reconstruct']: # using reference smplify's smplx and camera cameras = torch.concat([intri_rec.reshape(-1,4), RT_rec.reshape(-1, 16)], dim=1) # H, W = int(intri_rec[2] * 2), int(intri_rec[3] * 2) model.decoder.renderer.image_size = [W_rec, H_rec]; print(f"modify the render images's resolution into {H_rec}x{W_rec}") cameras = cameras.reshape(1,1,20).expand(num_imgs,1,-1) cameras = cameras.cuda() elif render_mode == 'novel_pose_A': model.decoder.renderer.image_size = [W, H] cameras = cameras[0].reshape(1,1,20).expand(num_imgs,1,-1) elif if_use_video_cam: # for the animation with motion-x cameras = construct_camera_from_motionx(smplx_path) H, W = 2*cameras[0, 0, [3]].int().item(), 2*cameras[0,0, [2]].int().item() model.decoder.renderer.image_size = [W, H]; print(f"modify the render images's resolution into {H}x{W}") # model.decoder.renderer = ############################################################################### # Stage 2: Reconstruction. ############################################################################### sample = image_bs[[0]] # N, 3, H, W, # if if_use_dataset: # sample = rearrange(sample, 'b h w c -> b c h w') # N, 3, H, W, image_path_idx = os.path.join(args.output_folders['ref'], f'{name}_ref.jpg') torchvision.utils.save_image( sample[0], image_path_idx) with torch.no_grad(): # get latents code = model.forward_image_to_uv(sample, is_training=False) with torch.no_grad(): output_list = [] num_imgs_batch = 5 total_frames = min(smpl_params.shape[0],300) res_uv = None for i in tqdm(range(0, total_frames, num_imgs_batch)): if i+num_imgs_batch > total_frames: num_imgs_batch = total_frames - i code_bt = code.expand(num_imgs_batch, -1, -1, -1) # cameras_bt = cameras.expand(num_imgs_batch, -1, -1) cameras_bt = cameras[i:i+num_imgs_batch] if render_mode in ['reconstruct', 'novel_pose_A'] and res_uv is not None: pass else: res_uv = model.decoder._decode_feature(code_bt) # Decouple UV attributes res_points = model.decoder._sample_feature(res_uv) # Sampling # Animate res_def_points = model.decoder.deform_pcd(res_points, smpl_params[i:i+num_imgs_batch].to(code_bt.dtype), zeros_hands_off=True, value=0.02) output = model.decoder.forward_render(res_def_points, cameras_bt.to(code_bt.dtype), num_imgs=1) image = output["image"][:, 0].cpu().to(torch.float32) print("output shape ", output["image"][:, 0].shape) output_list.append(image) # [:, 0] stands to get the all scenes (poses) del output output = torch.concatenate(output_list, 0) frames = rearrange(output, "b h w c -> b c h w")#.cpu().numpy() video_path_idx = os.path.join(args.output_folders['video'], f'{name}.mp4') save_video( frames[:,:4,...].to(torch.float32), video_path_idx, ) image_plist.append(frames) print("saving into ", video_path_idx) return image_plist def setup_directories(base_path, config_name): """Create output directories for results""" dirs = { 'image': os.path.join(base_path, config_name, 'images'), 'video': os.path.join(base_path, config_name), 'ref': os.path.join(base_path, config_name) } for path in dirs.values(): os.makedirs(path, exist_ok=True) return dirs def main(): """Main execution function""" # Parse arguments and set random seed args = parse_args() args.config = "configs/idol_v0.yaml" args.resume_path = "work_dirs/ckpt/model.ckpt" config = OmegaConf.load(args.config) config_name = os.path.basename(args.config).replace('.yaml', '') model_config = config.model resume_path = args.resume_path # Initialize model model 
= instantiate_from_config(model_config) model.encoder = model.encoder.to(torch.bfloat16) ; print("moving encoder to bf16") model = model.__class__.load_from_checkpoint(resume_path, **config.model.params) model = model.to(device) model = model.eval() # Setup input paths img_paths_list = ['work_dirs/demo_data/4.jpg'] smplx_ref_path_list = ['work_dirs/demo_data/4.json'] smplx_path_driven_list = ['work_dirs/demo_data/Ways_to_Catch_360_clip1.json'] # smplx_path_driven_list = ['work_dirs/demo_data/finedance-5-144.npy.npy'] # Setup output directories # args.output_path = "./test/" # args.render_mode = 'reconstruct' # 'novel_pose_A' #'reconstruct' #'novel_pose' # make output directories args.output_folders = setup_directories(args.output_path, config_name) # Process data image_plist = process_data_on_gpu( args, model, 0, img_paths_list, smplx_ref_path_list, smplx_path_driven_list ) return image_plist if __name__ == "__main__": main() ================================================ FILE: scripts/download_files.sh ================================================ #!/bin/bash # Create necessary directories mkdir -p work_dirs/ mkdir -p work_dirs/ckpt # Download files from HuggingFace echo "Downloading model files..." wget https://huggingface.co/yiyuzhuang/IDOL/resolve/main/model.ckpt -O work_dirs/ckpt/model.ckpt wget https://huggingface.co/yiyuzhuang/IDOL/resolve/main/sapiens_1b_epoch_173_torchscript.pt2 -O work_dirs/ckpt/sapiens_1b_epoch_173_torchscript.pt2 wget https://huggingface.co/yiyuzhuang/IDOL/resolve/main/cache_sub2.zip -O work_dirs/cache_sub2.zip # Unzip cache file echo "Extracting cache files..." unzip -o work_dirs/cache_sub2.zip -d work_dirs/ rm work_dirs/cache_sub2.zip # Remove zip file after extraction echo "Download and extraction completed!" ================================================ FILE: scripts/fetch_template.sh ================================================ #!/bin/bash urle () { [[ "${1}" ]] || return 1; local LANG=C i x; for (( i = 0; i < ${#1}; i++ )); do x="${1:i:1}"; [[ "${x}" == [a-zA-Z0-9.~-] ]] && echo -n "${x}" || printf '%%%02X' "'${x}"; done; echo; } mkdir -p lib/models/deformers/smplx/SMPLX # username and password input echo -e "\nYou need to register at https://smpl-x.is.tue.mpg.de/, according to Installation Instruction." read -p "Username (SMPL-X):" username read -p "Password (SMPL-X):" password username=$(urle $username) password=$(urle $password) # SMPLX echo -e "\nDownloading SMPL-X model..." wget --post-data "username=$username&password=$password" 'https://download.is.tue.mpg.de/download.php?domain=smplx&sfile=models_smplx_v1_1.zip' -O 'models_smplx_v1_1.zip' --no-check-certificate --continue unzip models_smplx_v1_1.zip -d lib/models/deformers/smplx/SMPLX mv lib/models/deformers/smplx/SMPLX/models/smplx/* lib/models/deformers/smplx/SMPLX rm -rf lib/models/deformers/smplx/SMPLX/models rm -f models_smplx_v1_1.zip mkdir -p work_dirs/cache/template cd work_dirs/cache/template echo -e "\nDownloading SMPL-X segmentation info..." wget https://github.com/Meshcapade/wiki/blob/main/assets/SMPL_body_segmentation/smplx/smplx_vert_segmentation.json echo -e "\nDownloading SMPL-X UV info..." wget --post-data "username=$username&password=$password" 'https://download.is.tue.mpg.de/download.php?domain=smplx&sfile=smplx_uv.zip' -O './smplx_uv.zip' --no-check-certificate --continue unzip smplx_uv.zip -d ./smplx_uv mv smplx_uv/smplx_uv.obj ./ rm -f smplx_uv.zip rm -rf smplx_uv echo -e "\nDownloading SMPL-X FLAME correspondence info..." 
wget --post-data "username=$username&password=$password" 'https://download.is.tue.mpg.de/download.php?domain=smplx&sfile=smplx_mano_flame_correspondences.zip' -O './smplx_mano_flame_correspondences.zip' --no-check-certificate --continue
unzip smplx_mano_flame_correspondences.zip -d ./smplx_mano_flame_correspondences
mv smplx_mano_flame_correspondences/SMPL-X__FLAME_vertex_ids.npy ./
rm -f smplx_mano_flame_correspondences.zip
rm -rf smplx_mano_flame_correspondences

echo -e "\nDownloading FLAME template from neural-head-avatars repo..."
wget https://raw.githubusercontent.com/philgras/neural-head-avatars/main/assets/flame/head_template_mesh_mouth.obj

echo -e "\nDownloading FLAME template from DECA repo..."
wget https://raw.githubusercontent.com/yfeng95/DECA/master/data/head_template.obj

echo -e "\nYou need to register at http://flame.is.tue.mpg.de/, according to Installation Instruction."
read -p "Username (FLAME):" username
read -p "Password (FLAME):" password
username=$(urle $username)
password=$(urle $password)

echo -e "\nDownloading FLAME segmentation info..."
wget 'https://files.is.tue.mpg.de/tbolkart/FLAME/FLAME_masks.zip' -O 'FLAME_masks.zip' --no-check-certificate --continue
unzip FLAME_masks.zip -d ./FLAME_masks
mv FLAME_masks/FLAME_masks.pkl ./
rm -f FLAME_masks.zip
rm -rf FLAME_masks

cd ../../..
echo -e "\n Finish"

================================================
FILE: scripts/pip_install.sh
================================================
#!/bin/bash

# Complete environment setup process

# Step 0: Ensure you create a Conda environment
# and Activate the environment
# conda activate idol

# Step 1: Install Pytorch with CUDA:
pip install torch==2.3.1+cu118 torchvision==0.18.1+cu118 torchaudio==2.3.1+cu118 \
    --index-url https://download.pytorch.org/whl/cu118

# Step 2: Use pip to install additional dependencies
pip_packages=(
    "absl-py==2.1.0"
    "accelerate==0.29.1"
    "addict==2.4.0"
    "albumentations==1.4.17"
    "bitsandbytes"
    "deepspeed==0.15.1"
    "diffusers==0.20.2"
    "einops==0.8.0"
    "fastapi==0.111.0"
    "gradio==3.41.2"
    "matplotlib==3.8.4"
    "numpy==1.26.3"
    "opencv-python==4.9.0.80"
    "pandas==2.2.2"
    "pillow==10.3.0"
    "scikit-image==0.23.2"
    "scipy==1.13.0"
    "timm==0.9.16"
    "transformers==4.40.1"
    "pytorch-lightning==2.3.1"
    "omegaconf==2.3.0"
    "av"
    "webdataset"
    "omegaconf"
    "rembg==2.0.57"
    "tensorboard"
)

# Install pip packages in bulk
for package in "${pip_packages[@]}"
do
    pip install "$package"
done

# Create submodule directory if it doesn't exist
mkdir -p submodule
cd submodule

# Step 3: Install PyTorch3D
git clone https://github.com/facebookresearch/pytorch3d.git
cd pytorch3d
git checkout v0.7.7
pip install -e .
cd ..

# Step 4: Install Simple-KNN
git clone https://gitlab.inria.fr/bkerbl/simple-knn.git
cd simple-knn
pip install -e .
cd ..

# Step 5: Install Gaussian Splatting
git clone https://github.com/graphdeco-inria/gaussian-splatting --recursive
cd gaussian-splatting/submodules/diff-gaussian-rasterization
python setup.py develop
cd ../../..

# Step 6: Install Sapiens
git clone https://github.com/facebookresearch/sapiens
cd sapiens/engine
pip install -e .
cd ../pretrain
pip install -e .
cd ../../..

# Step 7: Install deformation module
python setup.py develop

echo "idol environment setup completed!"
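
# Optional sanity check (a minimal sketch, not one of the numbered steps above): confirm
# that PyTorch sees the GPU and that the locally built extensions import cleanly. The
# module names fuse_cuda / filter_cuda / precompute_cuda come from setup.py at the repo root.
python -c "import torch; print('torch', torch.__version__, 'CUDA available:', torch.cuda.is_available())"
python -c "import pytorch3d; print('pytorch3d', pytorch3d.__version__)"
python -c "import fuse_cuda, filter_cuda, precompute_cuda; print('fast_snarf CUDA extensions OK')"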
================================================ FILE: setup.py ================================================ from setuptools import setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension import os cuda_dir = "lib/models/deformers/fast_snarf/cuda" setup( name='fuse', ext_modules=[ CUDAExtension('fuse_cuda', [f'{cuda_dir}/fuse_kernel/fuse_cuda.cpp', f'{cuda_dir}/fuse_kernel/fuse_cuda_kernel.cu']), CUDAExtension('filter_cuda', [f'{cuda_dir}/filter/filter.cpp', f'{cuda_dir}/filter/filter_kernel.cu']), CUDAExtension('precompute_cuda', [f'{cuda_dir}/precompute/precompute.cpp', f'{cuda_dir}/precompute/precompute_kernel.cu']) ], cmdclass={ 'build_ext': BuildExtension }) ================================================ FILE: train.py ================================================ import os, sys # os.environ["WANDB_MODE"] = "dryrun" # default setting to save locally from lib.utils.train_util import main_print import torch # Check GPU information if torch.cuda.is_available(): gpu_info = torch.cuda.get_device_name() if "H20" in gpu_info or "H800" in gpu_info: os.environ["NCCL_SOCKET_IFNAME"] = "bond1" # for H20 # If using H20 GPU, set network interface main_print("changing the network interface to bond1") if "H800" in gpu_info: # Set precision for matrix multiplication torch.set_float32_matmul_precision('medium') # or 'high' import argparse import shutil import subprocess from omegaconf import OmegaConf import torch from pytorch_lightning import seed_everything from pytorch_lightning.trainer import Trainer from pytorch_lightning.strategies import DDPStrategy from pytorch_lightning.strategies import DeepSpeedStrategy from pytorch_lightning.callbacks import Callback from pytorch_lightning.utilities import rank_zero_only from lib.utils.train_util import instantiate_from_config from pytorch_lightning import loggers as pl_loggers def get_parser(**parser_kwargs): def str2bool(v): if isinstance(v, bool): return v if v.lower() in ("yes", "true", "t", "y", "1"): return True elif v.lower() in ("no", "false", "f", "n", "0"): return False else: raise argparse.ArgumentTypeError("Boolean value expected.") parser = argparse.ArgumentParser(**parser_kwargs) parser.add_argument( "-r", "--resume", type=str, default=None, help="resume from checkpoint", ) parser.add_argument( "--resume_weights_only", action="store_true", help="only resume model weights", ) parser.add_argument( "--resume_not_loading_decoder", action="store_true", help="only resume model weights excepts decoder", ) # parser.add_argument( # "--custom_loading_for_PA", # action="store_true", # help="customly loading the PA network", # ) parser.add_argument( "-b", "--base", type=str, default="base_config.yaml", help="path to base configs", ) parser.add_argument( "-n", "--name", type=str, default="", help="experiment name", ) parser.add_argument( "--num_nodes", type=int, default=1, help="number of nodes to use", ) parser.add_argument( "--gpus", type=str, default="0,", help="gpu ids to use", ) parser.add_argument( "-s", "--seed", type=int, default=42, help="seed for seed_everything", ) parser.add_argument( "-l", "--logdir", type=str, default="logs", help="directory for logging data", ) parser.add_argument( "--test_sd", type=str, default="", help="path to state dict for testing", ) parser.add_argument( "--test_dataset", type=str, default="./configs/test_dataset.yaml", help="path to state dict for testing", ) parser.add_argument( "--is_debug", action="store_true", help="flag to specify if in debug mode, if true, it will returns more 
results", ) parser.add_argument( "--training_mode", type=str, default=None, help="flag to specify the training strategy", ) return parser class SetupCallback(Callback): def __init__(self, resume, logdir, ckptdir, cfgdir, config): super().__init__() self.resume = resume self.logdir = logdir self.ckptdir = ckptdir self.cfgdir = cfgdir self.config = config def on_fit_start(self, trainer, pl_module): if trainer.global_rank == 0: # Create logdirs and save configs os.makedirs(self.logdir, exist_ok=True) os.makedirs(self.ckptdir, exist_ok=True) os.makedirs(self.cfgdir, exist_ok=True) main_print("Project config") main_print(OmegaConf.to_yaml(self.config)) OmegaConf.save(self.config, os.path.join(self.cfgdir, "project.yaml")) class CodeSnapshot(Callback): """ Modified from https://github.com/threestudio-project/threestudio/blob/main/threestudio/utils/callbacks.py#L60 """ def __init__(self, savedir, exclude_patterns=None): self.savedir = savedir # Default excluded files and folders patterns self.exclude_patterns = exclude_patterns or [ "*.mp4", "*.npy", "work_dirs/*", "processed_data/*", "logs/*" ] def get_file_list(self): # Get git tracked files, excluding configs directory tracked_files = subprocess.check_output( 'git ls-files -- ":!:configs/*"', shell=True ).splitlines() # Get untracked but not ignored files untracked_files = subprocess.check_output( "git ls-files --others --exclude-standard", shell=True ).splitlines() # Merge file lists and decode all_files = [b.decode() for b in set(tracked_files) | set(untracked_files)] # Apply exclusion pattern filtering filtered_files = [] for file_path in all_files: should_exclude = False for pattern in self.exclude_patterns: if self._match_pattern(file_path, pattern): should_exclude = True break if not should_exclude: filtered_files.append(file_path) return filtered_files def _match_pattern(self, file_path, pattern): """Check if file path matches the given pattern""" # Handle directory wildcard patterns (e.g., work_dirs/*) if pattern.endswith('/*'): dir_prefix = pattern[:-1] # Remove '*' return file_path.startswith(dir_prefix) # Handle file extension patterns (e.g., *.mp4) if pattern.startswith('*'): ext = pattern[1:] # Get extension part return file_path.endswith(ext) # Exact matching return file_path == pattern @rank_zero_only def save_code_snapshot(self): os.makedirs(self.savedir, exist_ok=True) for f in self.get_file_list(): if not os.path.exists(f) or os.path.isdir(f): continue os.makedirs(os.path.join(self.savedir, os.path.dirname(f)), exist_ok=True) shutil.copyfile(f, os.path.join(self.savedir, f)) def on_fit_start(self, trainer, pl_module): try: self.save_code_snapshot() except: main_print( "Code snapshot is not saved. Please make sure you have git installed and are in a git repository." 
) if __name__ == "__main__": # add cwd for convenience and to make classes in this file available when # running as `python main.py` sys.path.append(os.getcwd()) parser = get_parser() opt, unknown = parser.parse_known_args() cfg_fname = os.path.split(opt.base)[-1] cfg_name = os.path.splitext(cfg_fname)[0] exp_name = "-" + opt.name if opt.name != "" else "" logdir = os.path.join(opt.logdir, cfg_name+exp_name) # init configs config = OmegaConf.load(opt.base) lightning_config = config.lightning trainer_config = lightning_config.trainer # modify some config for debug mode if opt.is_debug: lightning_config['trainer']['val_check_interval'] = 1 exp_name = 'debug' logdir = os.path.join(opt.logdir, cfg_name+exp_name) config.model.params['is_debug'] = True config.dataset.batch_size = 1 #ss config.dataset.num_workers = 1 config.dataset.params.train.params.cache_path = config.dataset.params.debug_cache_path ckptdir = os.path.join(logdir, "checkpoints") cfgdir = os.path.join(logdir, "configs") codedir = os.path.join(logdir, "code") seed_everything(opt.seed) main_print(f"Running on GPUs {opt.gpus}") ngpu = len(opt.gpus.strip(",").split(',')) trainer_config['devices'] = ngpu trainer_opt = argparse.Namespace(**trainer_config) lightning_config.trainer = trainer_config # testing setting if len(opt.test_sd) > 0: config_dataset = OmegaConf.load(opt.test_dataset) config.dataset = config_dataset.dataset precision_config = {'precision':"bf16"} # model model = instantiate_from_config(config.model) if precision_config['precision'] == "bf16": model.encoder = model.encoder.to(torch.bfloat16) if opt.resume and opt.resume_weights_only: if opt.resume_not_loading_decoder: main_print("========Loading only model weights excepts decoder ==============") # Load complete state dictionary state_dict = torch.load(opt.resume, map_location='cpu')['state_dict'] # Create a new state dictionary only containing the parts you want to load new_state_dict = {k: v for k, v in state_dict.items() if not (k.startswith('encoder') or k.startswith('decoder') or k.startswith('lpips'))} # Load the remaining state dictionary model.load_state_dict(state_dict, strict=False) del state_dict with torch.amp.autocast( device_type='cpu'): state_dict = torch.load(opt.resume, map_location='cpu')['state_dict'] main_print([k for k in state_dict.keys() if not k.startswith('lpips') ]) new_state_dict = {k: v for k, v in state_dict.items() if not k.startswith('lpips')} model.load_state_dict(new_state_dict, strict=False) model = model.to('cuda') model.logdir = logdir # trainer and callbacks trainer_kwargs = dict() # logger param_log = { 'save_dir': logdir, 'name': cfg_name+exp_name, } trainer_kwargs["logger"] = [ pl_loggers.TensorBoardLogger(**param_log), pl_loggers.CSVLogger(**param_log) ] # model checkpoint default_modelckpt_cfg = { "target": "pytorch_lightning.callbacks.ModelCheckpoint", "params": { "dirpath": ckptdir, "filename": "{step:08}", "verbose": True, "save_last": True, "every_n_train_steps": 5000, "save_top_k": -1, # save all checkpoints } } if "modelcheckpoint" in lightning_config: modelckpt_cfg = lightning_config.modelcheckpoint else: modelckpt_cfg = OmegaConf.create() modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg) # callbacks default_callbacks_cfg = { "setup_callback": { "target": "train.SetupCallback", "params": { "resume": opt.resume, "logdir": logdir, "ckptdir": ckptdir, "cfgdir": cfgdir, "config": config, } }, "learning_rate_logger": { "target": "pytorch_lightning.callbacks.LearningRateMonitor", "params": { 
"logging_interval": "step", } }, "code_snapshot": { "target": "train.CodeSnapshot", "params": { "savedir": codedir, } }, } default_callbacks_cfg["checkpoint_callback"] = modelckpt_cfg if "callbacks" in lightning_config: callbacks_cfg = lightning_config.callbacks else: callbacks_cfg = OmegaConf.create() callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg) trainer_kwargs["callbacks"] = [ instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg] training_mode = "DDP" if opt.training_mode is None else opt.training_mode if training_mode == 'DDP': trainer_kwargs["strategy"] = DDPStrategy(find_unused_parameters=False, static_graph=True) # TODO modify to True elif training_mode == 'ZERO': # DeepSpeed strategy = DeepSpeedStrategy(config='./configs/deepspeed_config.json') trainer_kwargs["strategy"] = strategy# TODO modify to True elif training_mode == 'FSDP': from pytorch_lightning.strategies import FSDPStrategy fsdp_strategy = FSDPStrategy( auto_wrap_policy=None, activation_checkpointing_policy=None, cpu_offload=False, # Whether to offload model parameters to CPU limit_all_gathers=False, # Whether to limit all gather operations sync_module_states=True, # Whether to synchronize module states # use_sharded_checkpoint=True, # Whether to use sharded checkpoints mixed_precision='bf16', # Mixed precision training, default is 'bf16' ) trainer_kwargs["strategy"] = fsdp_strategy main_print(f" ............ trying training strategy {training_mode} ...........") trainer = Trainer(**precision_config, **trainer_config, **trainer_kwargs, num_nodes=opt.num_nodes) trainer.logdir = logdir # data data = instantiate_from_config(config.dataset) data.prepare_data() data.setup("fit") # configure learning rate base_lr = config.model.params.neck_learning_rate if 'accumulate_grad_batches' in lightning_config.trainer: accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches else: accumulate_grad_batches = 1 main_print(f"accumulate_grad_batches = {accumulate_grad_batches}") lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches model.learning_rate = base_lr main_print("++++ NOT USING LR SCALING ++++") main_print(f"Setting learning rate to {model.learning_rate:.2e}") # trainer.fit(model, data) # debug if len(opt.test_sd) > 0: sd = torch.load(opt.test_sd, map_location='cpu') model.load_state_dict(sd, strict=False) with torch.amp.autocast(device_type='cpu'): # import ipdb; ipdb.set_trace() def load_folder_ckpt(checkpoint_dir): # For DeepSpeed loading # Get all .pt files pt_files = [os.path.join(checkpoint_dir, f) for f in os.listdir(checkpoint_dir) if f.endswith('.pt')] # Initialize model state dictionary model_state_dict = {} # Load each .pt file and merge into model state dictionary for pt_file in pt_files: state_dict = torch.load(pt_file, map_location='cpu') model_state_dict.update(state_dict) return model_state_dict if os.path.isdir(opt.test_sd): state_dict = load_folder_ckpt(opt.test_sd+"/checkpoint") # Load checkpoint success = model.load_checkpoint(opt.test_sd, load_optimizer_states=True, load_lr_scheduler_states=True) else: state_dict = torch.load(opt.test_sd, map_location='cpu')['state_dict'] main_print([k for k in state_dict.keys() if not k.startswith('lpips') ]) new_state_dict = {k: v for k, v in state_dict.items() if not k.startswith('lpips')} new_state_dict = {k: v for k, v in new_state_dict.items() if not k.startswith('encoder')} model.load_state_dict(new_state_dict, strict=False) main_print(f"========testing =====, loading from {opt.test_sd} 
================") model = model.to('cuda') with torch.no_grad(): trainer.test(model, data) else: # run training loop try: if opt.resume and not opt.resume_weights_only: trainer.fit(model, data, ckpt_path=opt.resume) else: trainer.fit(model, data) except Exception as e: main_print(f"An error occurred: {e}") torch.cuda.empty_cache()