Repository: yiyuzhuang/IDOL
Branch: main
Commit: 9fd9296c28e8
Files: 58
Total size: 396.2 KB

Directory structure:
gitextract_sj4g9tfa/

├── .gitignore
├── README.md
├── configs/
│   ├── idol_debug.yaml
│   ├── idol_v0.yaml
│   └── test_dataset.yaml
├── data_processing/
│   ├── prepare_cache.py
│   ├── process_datasets.sh
│   └── visualize_samples.py
├── dataset/
│   ├── README.md
│   ├── sample/
│   │   └── param/
│   │       └── Kenya_female_fit_streetwear_50~60 years old_1501.npy
│   └── visualize_samples.py
├── env/
│   └── README.md
├── lib/
│   ├── __init__.py
│   ├── datasets/
│   │   ├── __init__.py
│   │   ├── avatar_dataset.py
│   │   └── dataloader.py
│   ├── humanlrm_wrapper_sa_v1.py
│   ├── mmutils/
│   │   ├── __init__.py
│   │   └── initialize.py
│   ├── models/
│   │   ├── __init__.py
│   │   ├── decoders/
│   │   │   ├── __init__.py
│   │   │   ├── uvmaps_decoder_gender.py
│   │   │   └── vit_head.py
│   │   ├── deformers/
│   │   │   ├── __init__.py
│   │   │   ├── fast_snarf/
│   │   │   │   ├── cuda/
│   │   │   │   │   ├── filter/
│   │   │   │   │   │   ├── filter.cpp
│   │   │   │   │   │   └── filter_kernel.cu
│   │   │   │   │   ├── fuse_kernel/
│   │   │   │   │   │   ├── fuse_cuda.cpp
│   │   │   │   │   │   └── fuse_cuda_kernel.cu
│   │   │   │   │   └── precompute/
│   │   │   │   │       ├── precompute.cpp
│   │   │   │   │       └── precompute_kernel.cu
│   │   │   │   └── lib/
│   │   │   │       └── model/
│   │   │   │           ├── deformer_smpl.py
│   │   │   │           └── deformer_smplx.py
│   │   │   ├── smplx/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── body_models.py
│   │   │   │   ├── joint_names.py
│   │   │   │   ├── lbs.py
│   │   │   │   ├── utils.py
│   │   │   │   ├── vertex_ids.py
│   │   │   │   └── vertex_joint_selector.py
│   │   │   └── smplx_deformer_gender.py
│   │   ├── renderers/
│   │   │   ├── __init__.py
│   │   │   └── gau_renderer.py
│   │   ├── sapiens/
│   │   │   ├── __init__.py
│   │   │   └── sapiens_wrapper_torchscipt.py
│   │   └── transformer_sa/
│   │       ├── __init__.py
│   │       └── mae_decoder_v3_skip.py
│   ├── ops/
│   │   ├── __init__.py
│   │   └── activation.py
│   └── utils/
│       ├── infer_util.py
│       ├── mesh.py
│       ├── mesh_utils.py
│       └── train_util.py
├── run_demo.py
├── scripts/
│   ├── download_files.sh
│   ├── fetch_template.sh
│   └── pip_install.sh
├── setup.py
└── train.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Python cache files
__pycache__/
*.py[cod]
*$py.class
*.so

# Project specific
work_dirs/
test/
lib/models/deformers/smplx/SMPLX/*

# IDE
.idea/
.vscode/
*.swp
*.swo

# Distribution / packaging
dist/
build/
*.egg-info/

# Jupyter Notebook
.ipynb_checkpoints


# Git related
*.orig
*.rej
*.patch 

================================================
FILE: README.md
================================================
# **IDOL: Instant Photorealistic 3D Human Creation from a Single Image**  

[![Website](https://img.shields.io/badge/Project-Website-0073e6)](https://yiyuzhuang.github.io/IDOL/)
[![Paper](https://img.shields.io/badge/arXiv-PDF-b31b1b)](https://arxiv.org/pdf/2412.14963)
[![Live Demo](https://img.shields.io/badge/Live-Demo-34C759)](https://yiyuzhuang.github.io/IDOL/)
[![License](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/licenses/MIT)


---

<p align="center">
  <img src="./asset/images/Teaser_v2.png" alt="Teaser Image for IDOL" width="85%">
</p>

---

## **Abstract**

This work introduces **IDOL**, a feed-forward, single-image human reconstruction framework that is fast, high-fidelity, and generalizable. 
Leveraging a large-scale dataset of 100K multi-view subjects, our method demonstrates exceptional generalizability and robustness in handling diverse human shapes, cross-domain data, severe viewpoints, and occlusions. 
With a uniform structured representation, the reconstructed avatars are directly animatable and easily editable, providing a significant step forward for various applications in graphics, vision, and beyond.

In summary, this project introduces:

- **IDOL**: A scalable pipeline for instant photorealistic 3D human reconstruction using a simple yet efficient feed-forward model.
- **HuGe100K Dataset**: We develop a data generation pipeline and present HuGe100K, a large-scale multi-view human dataset featuring diverse attributes, high-fidelity, high-resolution appearances, and well-aligned SMPL-X models.
- **Application Support**: Enabling 3D human reconstruction and downstream tasks such as editing and animation.


---
## 📰 **News** 

- **2024-12-18**: Paper is now available on arXiv.
- **2025-01-02**: The demo dataset containing 100 samples is now available for access. The remaining dataset is currently undergoing further cleaning and review.
- **2025-03-01**: 🎉 Paper accepted by CVPR 2025.
- **2025-03-01**: 🎉 We have released the inference code! Check out the [Code Release](#code-release) section for details.
- **2025-04-01**: 🔥 Full HuGe100K dataset is now available! See the [Dataset Access](#dataset-demo-access) section.
- **2025-04-05**: 🔥 Training code is now available! Check out the [Training](#training) section for details.

## 🚧 **Project Status**   

We are actively working on releasing the following resources:  

| Resource                | Status          | Expected Release Date          |
|-------------------------|-----------------|--------------------------------|
| **Dataset Demo**        | ✅ Available    | **Now Live! (2025.01.02)**     |
| **Inference Code**      | ✅ Available    | **Now Live! (2025.03.01)**     |
| **Full Dataset Access** | ✅ Available    | **Now Live! (2025.04.01)**     |
| **Online Demo**         | 🚧 In Progress  | **Before April 2025**          |
| **Training Code**       | ✅ Available    | **Now Live! (2025.04.05)**     |

Stay tuned as we update this section with new releases! 🚀  



## 💻 **Code Release** 

### Installation & Environment Setup

Please refer to [env/README.md](env/README.md) for detailed environment setup instructions.

### Quick Start
Run demo with different modes:
```bash
# Reconstruct the input image
python run_demo.py --render_mode reconstruct

# Generate novel poses (animation)
python run_demo.py --render_mode novel_pose

# Generate 360-degree view
python run_demo.py --render_mode novel_pose_A
```
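If you want to run all three modes back to back, a small Python wrapper can build the commands for you. This is a minimal sketch, assuming it runs from the repository root and that `--render_mode` is the only flag needed (it is the only one shown above); the helper names `demo_command` and `run_all_modes` are hypothetical and not part of the repository.

```python
import subprocess
import sys

# The three render modes listed in the Quick Start commands.
RENDER_MODES = ("reconstruct", "novel_pose", "novel_pose_A")


def demo_command(render_mode: str, python: str = sys.executable) -> list:
    """Build the argv for one run_demo.py invocation (assumes the repo root as CWD)."""
    if render_mode not in RENDER_MODES:
        raise ValueError(f"unknown render mode {render_mode!r}; expected one of {RENDER_MODES}")
    return [python, "run_demo.py", "--render_mode", render_mode]


def run_all_modes(dry_run: bool = True) -> list:
    """Return the commands for every render mode and, unless dry_run, execute them in order."""
    commands = [demo_command(mode) for mode in RENDER_MODES]
    if not dry_run:
        for cmd in commands:
            # check=True stops at the first failing mode instead of silently continuing.
            subprocess.run(cmd, check=True)
    return commands


if __name__ == "__main__":
    for cmd in run_all_modes(dry_run=True):
        print(" ".join(cmd))
```

Passing `dry_run=False` actually launches the demos; the default only prints the commands, which is handy for checking paths first.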

### Training

#### Data Preparation

1. **Dataset Structure**: First, prepare your dataset with the following structure:
   ```
   dataset_root/
   ├── deepfashion/
   │   ├── image1/
   │   │   ├── videos/
   │   │   │   ├── xxx.mp4
   │   │   │   └── xxx.jpg
   │   │   └── param/
   │   │       └── xxx.npy
   │   └── image2/
   │       ├── videos/
   │       └── param/
   └── flux_batch1_5000/
       ├── image1/
       │   ├── videos/
       │   └── param/
       └── image2/
           ├── videos/
           └── param/
   ```

2. **Process Dataset**: Run the data processing script to generate cache files:
   ```bash
   # Process the dataset and generate cache files
   # First edit DATASET_PATHS and MAX_VIDEOS in the script to match your data
   bash data_processing/process_datasets.sh
   ```

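   With the default `MAX_VIDEOS=200`, this writes cache files such as the following to the `processed_data` directory (each dataset also gets an `*_all_200.npy` file containing every sample):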
   - `deepfashion_train_140.npy`
   - `deepfashion_val_10.npy`
   - `deepfashion_test_50.npy`
   - `flux_batch1_5000_train_140.npy`
   - `flux_batch1_5000_val_10.npy`
   - `flux_batch1_5000_test_50.npy`

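3. **Configure Cache Path**: Point your config file (e.g., `configs/idol_v0.yaml`) at the generated caches. Training caches go under `dataset.params.train.params.cache_path` and validation caches under `dataset.params.validation.params.cache_path`. The shipped configs reference `*_local.npy` file names, so adjust them to match the files you actually generated:
   ```yaml
   dataset:
     params:
       train:
         params:
           cache_path: [
             ./processed_data/deepfashion_train_140.npy,
             ./processed_data/flux_batch1_5000_train_140.npy
           ]
   ```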
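The `_140`, `_10` and `_50` suffixes follow from the split rule in `data_processing/prepare_cache.py`: at most `MAX_VIDEOS` pairs are kept, the last 50 become the test split, the 10 before them the validation split, and the rest the training split. The sketch below mirrors that arithmetic so you can predict file names for other `MAX_VIDEOS` values; it only covers datasets that end up with more than 60 samples, and the helper names are hypothetical. Each cache is a NumPy object array of dicts (`video_path`, `image_paths`, `param_path`, `image_ref`), so it has to be loaded with `np.load(path, allow_pickle=True)`.

```python
# Hypothetical helpers mirroring the split rule in data_processing/prepare_cache.py
# for datasets that end up with more than 60 samples.


def split_sizes(num_valid_pairs: int, max_videos: int = 200) -> dict:
    """Return the sample count of each cache file written for one dataset."""
    # prepare_cache.py keeps at most max_videos randomly chosen video/param pairs.
    total = min(num_valid_pairs, max_videos)
    if total <= 60:
        raise ValueError("this sketch only covers datasets with more than 60 samples")
    # Last 50 samples -> test, the 10 before them -> val, everything else -> train;
    # the "all" file holds every kept sample.
    return {"train": total - 60, "val": 10, "test": 50, "all": total}


def expected_cache_files(prefix: str, num_valid_pairs: int, max_videos: int = 200) -> list:
    """Return the file names prepare_cache.py would write for one dataset prefix."""
    sizes = split_sizes(num_valid_pairs, max_videos)
    return [f"{prefix}_{split}_{count}.npy" for split, count in sizes.items()]


if __name__ == "__main__":
    # With MAX_VIDEOS=200 and a dataset of at least 200 valid pairs:
    for name in ("deepfashion", "flux_batch1_5000"):
        print(expected_cache_files(name, num_valid_pairs=5000))
```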

#### Running Training

1. **Single-Node Training**: For single-node multi-GPU training:
   ```bash
   python train.py \
     --base configs/idol_v0.yaml \
     --num_nodes 1 \
     --gpus 0,1,2,3,4,5,6,7
   ```

2. **Multi-Node Training**: For multi-node training, specify additional parameters:
   ```bash
   python train.py \
     --base configs/idol_v0.yaml \
     --num_nodes <total_nodes> \
     --node_rank <current_node_rank> \
     --master_addr <master_node_ip> \
     --master_port <port_number> \
     --gpus 0,1,2,3,4,5,6,7
   ```

   Example for a 2-node setup:
   ```bash
   # On master node (node 0):   
   python train.py --base configs/idol_v0.yaml --num_nodes 2 --node_rank 0 --master_addr 192.168.1.100 --master_port 29500 --gpus 0,1,2,3,4,5,6,7

   # On worker node (node 1):
   python train.py --base configs/idol_v0.yaml --num_nodes 2 --node_rank 1 --master_addr 192.168.1.100 --master_port 29500 --gpus 0,1,2,3,4,5,6,7
   ```

3. **Resume Training**: To resume training from a checkpoint:
   ```bash
   python train.py \
     --base configs/idol_v0.yaml \
     --resume PATH/TO/MODEL.ckpt \
     --num_nodes 1 \
     --gpus 0,1,2,3,4,5,6,7
   ```

4. **Test and Evaluate Metrics**:
   ```bash
   # --base:         main (model) config file
   # --test_sd:      path to the .ckpt checkpoint to evaluate
   # --test_dataset: (optional) dataset config used only for testing
   python train.py \
     --base configs/idol_v0.yaml \
     --num_nodes 1 \
     --gpus 0,1,2,3,4,5,6,7 \
     --test_sd /path/to/model_checkpoint.ckpt \
     --test_dataset ./configs/test_dataset.yaml
   ```
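For multi-node runs, every node gets the same command except `--node_rank`. The sketch below generates one command line per node using only the flags shown above; the helper name `train_commands` and its defaults (taken from the examples in this section) are illustrative assumptions, not part of `train.py`.

```python
import shlex


def train_commands(num_nodes: int, master_addr: str = "", master_port: int = 29500,
                   base: str = "configs/idol_v0.yaml", gpus: str = "0,1,2,3,4,5,6,7") -> list:
    """Return one train.py command line per node, indexed by node rank (rank 0 is the master)."""
    if num_nodes < 1:
        raise ValueError("num_nodes must be at least 1")
    if num_nodes > 1 and not master_addr:
        raise ValueError("multi-node training needs the master node's address")
    commands = []
    for rank in range(num_nodes):
        argv = ["python", "train.py", "--base", base, "--num_nodes", str(num_nodes)]
        if num_nodes > 1:
            # Every node points at the same master address/port; only --node_rank differs.
            argv += ["--node_rank", str(rank), "--master_addr", master_addr,
                     "--master_port", str(master_port)]
        argv += ["--gpus", gpus]
        # shlex.join quotes any argument containing shell-special characters.
        commands.append(shlex.join(argv))
    return commands


if __name__ == "__main__":
    for rank, cmd in enumerate(train_commands(2, "192.168.1.100")):
        print(f"# node {rank}:\n{cmd}")
```

Generating the commands from one place keeps `--num_nodes`, `--master_addr` and `--master_port` consistent across nodes, which is the usual source of hangs at startup.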

## Notes
- Make sure all GPUs have enough memory for the selected batch size
- For multi-node training, ensure network connectivity between nodes
- Monitor training progress using the logging system
- Adjust learning rate and other hyperparameters in the config file as needed
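When editing the model config, note that the comments in `configs/idol_v0.yaml` require `decoder.params.up_cnn_in_channels` and `decoder.params.vithead_param.in_channels` to equal `neck.params.decoder_embed_dim` (1536 in `idol_v0.yaml`, 128 in `idol_debug.yaml`). The sketch below checks those two constraints on an already-parsed config dict; the function names are hypothetical, and it does not validate any other hyperparameter.

```python
def check_decoder_widths(cfg: dict) -> list:
    """Return mismatches between the neck output width and the decoder input widths.

    Assumes cfg is the parsed YAML (a nested dict) laid out like configs/idol_v0.yaml.
    """
    model = cfg["model"]["params"]
    embed_dim = model["neck"]["params"]["decoder_embed_dim"]
    decoder = model["decoder"]["params"]
    problems = []
    # The config comments say up_cnn_in_channels must match decoder_embed_dim.
    up_in = decoder.get("up_cnn_in_channels")
    if up_in is not None and up_in != embed_dim:
        problems.append(f"decoder.up_cnn_in_channels={up_in} != neck.decoder_embed_dim={embed_dim}")
    # ...and the same holds for the ViT head's input channels.
    head_in = decoder.get("vithead_param", {}).get("in_channels")
    if head_in is not None and head_in != embed_dim:
        problems.append(f"decoder.vithead_param.in_channels={head_in} != neck.decoder_embed_dim={embed_dim}")
    return problems


def make_cfg(embed_dim: int, up_in: int, head_in: int) -> dict:
    """Build a minimal config dict containing only the keys the check reads (illustration only)."""
    return {
        "model": {
            "params": {
                "neck": {"params": {"decoder_embed_dim": embed_dim}},
                "decoder": {"params": {"up_cnn_in_channels": up_in,
                                       "vithead_param": {"in_channels": head_in}}},
            }
        }
    }


if __name__ == "__main__":
    print(check_decoder_widths(make_cfg(1536, 1536, 1536)))  # widths from idol_v0.yaml
    print(check_decoder_widths(make_cfg(1536, 128, 1536)))   # a width copied from idol_debug.yaml
```

With PyYAML installed, passing `yaml.safe_load` output for `configs/idol_v0.yaml` to `check_decoder_widths` should run the same check on the real file.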


## 🌐 **Key Links** 

- 📄 [**Paper on arXiv**](https://arxiv.org/pdf/2412.14963)  
- 🌐 [**Project Website**](https://yiyuzhuang.github.io/IDOL/)  
- 🚀 **Live Demo** (Coming Soon!)  

---

## 📊 **Dataset Demo Access**   

We introduce **HuGe100K**, a large-scale multi-view human dataset, supporting 3D human reconstruction and animation research.  

### ▶ **Watch the Demo Video**
<p align="center">
  <img src="./asset/videos/dataset.gif" alt="Dataset GIF" width="85%">
</p>

### 📋 **Dataset Documentation**
For detailed information about the dataset format, structure, and usage guidelines, please refer to our [Dataset Documentation](dataset/README.md).

### 🚀 **Access the Dataset**   

<div align="center">
  <p><strong>🔥 HuGe100K - The largest multi-view human dataset with 100,000+ subjects! 🔥</strong></p>
  <p>High-resolution • Multi-view • Diverse poses • SMPL-X aligned</p>
  

  <a href="https://docs.google.com/forms/d/e/1FAIpQLSeVqrA9Mc_ODdcTZsB3GgrxgSNZk5deOzK4f64N72xlQFhvzQ/viewform?usp=dialog">
    <img src="https://img.shields.io/badge/Apply_for_Access-HuGe100K_Dataset-FF6B6B?style=for-the-badge&logo=googleforms&logoColor=white" alt="Apply for Access" width="300px">
  </a>
  <p><i>Complete the form to get access credentials and download links!</i></p>
</div>

### ⚖️ **License and Attribution**

This dataset includes images derived from the **DeepFashion** dataset, originally provided by MMLAB at The Chinese University of Hong Kong. The use of DeepFashion images in this dataset has been explicitly authorized by the original authors solely for the purpose of creating and distributing this dataset. **Users must not further reproduce, distribute, sell, or commercially exploit any images or derived data originating from DeepFashion.** For any subsequent or separate use of the DeepFashion data, users must directly obtain authorization from MMLAB and comply with the original [DeepFashion License](https://mmlab.ie.cuhk.edu.hk/projects/DeepFashion.html).

---

## 📝 **Citation**   

If you find our work helpful, please cite us using the following BibTeX:

```bibtex
@article{zhuang2024idolinstant,                
  title={IDOL: Instant Photorealistic 3D Human Creation from a Single Image}, 
  author={Yiyu Zhuang and Jiaxi Lv and Hao Wen and Qing Shuai and Ailing Zeng and Hao Zhu and Shifeng Chen and Yujiu Yang and Xun Cao and Wei Liu},
  journal={arXiv preprint arXiv:2412.14963},
  year={2024},
  url={https://arxiv.org/abs/2412.14963}, 
}
```



## **License** 

This project is licensed under the **MIT License**.

- **Permissions**: This license grants permission to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the software.
- **Condition**: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
- **Disclaimer**: The software is provided "as is", without warranty of any kind.

For more information, see the full license [here](https://opensource.org/licenses/MIT).

## **Support Our Work** ⭐

If you find our work useful for your research or applications:

- Please ⭐ **star our repository** to help us reach more people
- Consider **citing our paper** in your publications (see [Citation](#citation) section)
- Share our project with others who might benefit from it

Your support helps us continue developing open-source research projects like this one!

## 📚 **Acknowledgments**

This project is largely built upon several excellent open-source projects:

- [E3Gen](https://github.com/olivia23333/E3Gen): Efficient, Expressive and Editable Avatars Generation
- [SAPIENS](https://github.com/facebookresearch/sapiens): High-resolution visual models for human-centric tasks
- [GeoLRM](https://github.com/alibaba-yuanjing-aigclab/GeoLRM): Large Reconstruction Model for High-Quality 3D Generation
- [3D Gaussian Splatting](https://github.com/graphdeco-inria/gaussian-splatting): Real-Time 3DGS Rendering

We thank all the authors for their contributions to the open-source community.


================================================
FILE: configs/idol_debug.yaml
================================================

debug: True
# code_size: [32, 256, 256]
code_size: [32, 1024, 1024]
model:
  # base_learning_rate: 2.0e-04 # yy Need to check
  target: lib.SapiensGS_SA_v1
  params:
   # optimizer add
    # use_bf16: true
    max_steps: 100_000
    warmup_steps: 10_000 #12_000
    use_checkpoint: true
    lambda_depth_tv: 0.05 
    lambda_lpips: 0 #2.0
    lambda_mse: 20 #1.0
    lambda_offset: 1 #offset_weight: 50 mse 20, lpips 0.1
    neck_learning_rate: 5e-4
    decoder_learning_rate: 5e-4


    output_hidden_states: true  # if True, will output the hidden states from sapiens shallow layer, for the neck decoder
    
    loss_coef: 0.5 
    init_iter: 500
    scale_weight: 0.01
    smplx_path:  'work_dirs/demo_data/Ways_to_Catch_360_clip1.json'
   
    code_reshape:  [32, 96, 96] 
    patch_size: 1
    code_activation:
      type: tanh
      mean: 0.0
      std: 0.5
      clip_range: 2
    grid_size: 64
    encoder:
      target: lib.models.sapiens.SapiensWrapper_ts 
      params:
        model_path:   work_dirs/ckpt/sapiens_1b_epoch_173_torchscript.pt2
        # model_path: /apdcephfs_cq8/share_1367250/harriswen/projects/sapiens_convert/checkpoints//sapiens_1b_epoch_173_torchscript.pt2
        layer_num: 40
        img_size: [1024, 736]
        freeze: True
    neck:
      target: lib.models.transformer_sa.neck_SA_v3_skip # TODO!! add a self attention version
      params:
        patch_size: 4 #4,
        in_chans: 32  #32, # the uv code  dims
        num_patches: 9216 #4096 #num_patches  #,#4096, # 16*16
        embed_dim: 1536 # sapiens' latent dims # 1920 # 1920 for sapiens encoder2  #1024 # the feature extrators outputs
        decoder_embed_dim: 128 # 1024
        decoder_depth: 2 # 8
        decoder_num_heads: 4 #16,
        total_num_hidden_states: 12 
        mlp_ratio: 4.
    decoder:
      target:  lib.models.decoders.UVNDecoder_gender 
      params:
        interp_mode: bilinear
        base_layers: [16, 64]
        density_layers: [64, 1]
        color_layers: [16, 128, 9]
        offset_layers: [64, 3]
        use_dir_enc: false
        dir_layers: [16, 64]
        activation: silu
        bg_color: 1
        sigma_activation: sigmoid
        sigmoid_saturation: 0.001
        gender: neutral
        is_sub2: true ## update, make it into 10w gs points
        multires: 0
        image_size: [640, 896]
        superres: false
        focal: 1120
        up_cnn_in_channels: 128 # be the same as decoder_embed_dim
        reshape_type: VitHead
        vithead_param:
          in_channels: 128 # be the same as decoder_embed_dim
          out_channels: 32
          deconv_out_channels: [128, 64]
          deconv_kernel_sizes: [4, 4]
          conv_out_channels: [128, 128]
          conv_kernel_sizes: [3, 3]
        fix_sigma: true

dataset:
  target: lib.datasets.dataloader.DataModuleFromConfig
  params:
    batch_size: 1 #16 # 6 for lpips
    num_workers: 1 #2
    # working when in debug mode
    debug_cache_path: ./processed_data/flux_batch1_5000_test_50_local.npy

    train: 
      target: lib.datasets.AvatarDataset
      params:
        data_prefix: None
     
        cache_path:  [
          ./processed_data/deepfashion_train_140_local.npy,
          ./processed_data/flux_batch1_5000_train_140_local.npy
        ]

        specific_observation_num: 5
        better_range: true
        first_is_front: true
        if_include_video_ref_img: true  
        prob_include_video_ref_img: 0.5
        img_res: [640, 896]
    validation:
      target: lib.datasets.AvatarDataset
      params:
        data_prefix: None
        load_imgs: true
        specific_observation_num: 3
        better_range: true
        first_is_front: true
        img_res: [640, 896]
        cache_path:       [
        ./processed_data/flux_batch1_5000_test_50_local.npy,
        #./processed_data/flux_batch1_5000_val_10_local.npy
        ]



lightning:
  modelcheckpoint:
    params:
      every_n_train_steps:  4000 #2000
      save_top_k: -1
      save_last: true
      monitor: 'train/loss_mse' # ADD this logging in the wrapper_sa
      mode: "min"
      filename: 'sample-synData-epoch{epoch:02d}-val_loss{val/loss:.2f}'
  callbacks: {}
  trainer:
    num_sanity_val_steps: 1
    accumulate_grad_batches: 1
    gradient_clip_val: 10.0
    max_steps: 80000
    check_val_every_n_epoch: 1  # run validation after every training epoch
    benchmark: true
    val_check_interval: 1.0

================================================
FILE: configs/idol_v0.yaml
================================================

debug: True
# code_size: [32, 256, 256]
code_size: [32, 1024, 1024]
model:
  # base_learning_rate: 2.0e-04 # yy Need to check
  target: lib.SapiensGS_SA_v1
  params:
   # optimizer add
    # use_bf16: true
    max_steps: 100_000
    warmup_steps: 10_000 #12_000
    use_checkpoint: true
    lambda_depth_tv: 0.05 
    lambda_lpips: 10 #2.0
    lambda_mse: 20 #1.0
    lambda_offset: 1 #offset_weight: 50 mse 20, lpips 0.1
    neck_learning_rate: 5e-4
    decoder_learning_rate: 5e-4


    output_hidden_states: true  # if True, will output the hidden states from sapiens shallow layer, for the neck decoder
    
    loss_coef: 0.5 
    init_iter: 500
    scale_weight: 0.01
    smplx_path:  'work_dirs/demo_data/Ways_to_Catch_360_clip1.json'
   
    code_reshape:  [32, 96, 96] 
    patch_size: 1
    code_activation:
      type: tanh
      mean: 0.0
      std: 0.5
      clip_range: 2
    grid_size: 64
    encoder:
      target: lib.models.sapiens.SapiensWrapper_ts 
      params:
        model_path:   work_dirs/ckpt/sapiens_1b_epoch_173_torchscript.pt2
        # model_path: /apdcephfs_cq8/share_1367250/harriswen/projects/sapiens_convert/checkpoints//sapiens_1b_epoch_173_torchscript.pt2
        layer_num: 40
        img_size: [1024, 736]
        freeze: True
    neck:
      target: lib.models.transformer_sa.neck_SA_v3_skip # TODO!! add a self attention version
      params:
        patch_size: 4 #4,
        in_chans: 32  #32, # the uv code  dims
        num_patches: 9216 #4096 #num_patches  #,#4096, # 16*16
        embed_dim: 1536 # 1920 # 1920 for sapiens encoder2  #1024 # the feature extrators outputs
        decoder_embed_dim: 1536 # 1024
        decoder_depth: 16 # 8
        decoder_num_heads: 16 #16,
        total_num_hidden_states: 40 
        mlp_ratio: 4.
    decoder:
      target:  lib.models.decoders.UVNDecoder_gender 
      params:
        interp_mode: bilinear
        base_layers: [16, 64]
        density_layers: [64, 1]
        color_layers: [16, 128, 9]
        offset_layers: [64, 3]
        use_dir_enc: false
        dir_layers: [16, 64]
        activation: silu
        bg_color: 1
        sigma_activation: sigmoid
        sigmoid_saturation: 0.001
        gender: neutral
        is_sub2: true ## update, make it into 10w gs points
        multires: 0
        image_size: [640, 896]
        superres: false
        focal: 1120
        up_cnn_in_channels: 1536 # be the same as decoder_embed_dim
        reshape_type: VitHead
        vithead_param:
          in_channels: 1536 # be the same as decoder_embed_dim
          out_channels: 32
          deconv_out_channels: [512, 512, 512, 256]
          deconv_kernel_sizes: [4, 4, 4, 4]
          conv_out_channels: [128, 128]
          conv_kernel_sizes: [3, 3]
        fix_sigma: true

dataset:
  target: lib.datasets.dataloader.DataModuleFromConfig
  params:
    batch_size: 1 #16 # 6 for lpips
    num_workers: 2 #2
    # working when in debug mode
    debug_cache_path:  ./processed_data/flux_batch1_5000_test_50_local.npy

    train: 
      target: lib.datasets.AvatarDataset
      params:
        data_prefix: None
     
        cache_path:  [
          ./processed_data/deepfashion_train_140_local.npy,
          ./processed_data/flux_batch1_5000_train_140_local.npy
        ]

        specific_observation_num: 5
        better_range: true
        first_is_front: true
        if_include_video_ref_img: true  
        prob_include_video_ref_img: 0.5
        img_res: [640, 896]
    validation:
      target: lib.datasets.AvatarDataset
      params:
        data_prefix: None
        load_imgs: true
        specific_observation_num: 5
        better_range: true
        first_is_front: true
        img_res: [640, 896]
        cache_path:  [
          ./processed_data/deepfashion_val_10_local.npy,
          ./processed_data/flux_batch1_5000_val_10_local.npy
        ]



lightning:
  modelcheckpoint:
    params:
      every_n_train_steps:  4000 #2000
      save_top_k: -1
      save_last: true
      monitor: 'train/loss_mse' # ADD this logging in the wrapper_sa
      mode: "min"
      filename: 'sample-synData-epoch{epoch:02d}-val_loss{val/loss:.2f}'
  callbacks: {}
  trainer:
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
    gradient_clip_val: 10.0
    max_steps: 80000
    check_val_every_n_epoch: 1  # run validation after every training epoch
    benchmark: true

================================================
FILE: configs/test_dataset.yaml
================================================

dataset:
  target: lib.datasets.dataloader.DataModuleFromConfig
  params:
    batch_size: 1 
    num_workers: 2 
    # working when in debug mode
    debug_cache_path:  ./processed_data/flux_batch1_5000_test_50_local.npy

    train: 
      target: lib.datasets.AvatarDataset
      params:
        data_prefix: None
     
        cache_path:  [
          ./processed_data/deepfashion_train_140_local.npy,
          ./processed_data/flux_batch1_5000_train_140_local.npy
        ]

        specific_observation_num: 5
        better_range: true
        first_is_front: true
        if_include_video_ref_img: true
        prob_include_video_ref_img: 0.5
        img_res: [640, 896]
    validation:
      target: lib.datasets.AvatarDataset
      params:
        data_prefix: None
        load_imgs: true
        specific_observation_num: 5
        better_range: true
        first_is_front: true
        img_res: [640, 896]
        cache_path:  [
          ./processed_data/deepfashion_val_10_local.npy,
          ./processed_data/flux_batch1_5000_val_10_local.npy
        ]
    test:
      target: lib.datasets.AvatarDataset
      params:
        data_prefix: None
        load_imgs: true
        specific_observation_num: 5
        better_range: true
        first_is_front: true
        img_res: [640, 896]
        cache_path:  [
          ./processed_data/deepfashion_test_50_local.npy,
          ./processed_data/flux_batch1_5000_test_50_local.npy
        ]

================================================
FILE: data_processing/prepare_cache.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Cache preparation script for multi-view video datasets such as DeepFashion.
This script processes video files and their corresponding parameters,
and splits the dataset into train/val/test sets.
"""

import os
import numpy as np
import argparse


def parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(description="Prepare DeepFashion video dataset")
    parser.add_argument(
        "--video_dir", 
        type=str, 
        default="/apdcephfs/private_harriswen/data/deepfashion/",
        help="Base directory containing imageX folders"
    )
    parser.add_argument(
        "--output_dir", 
        type=str, 
        default="./", 
        help="Directory to save the processed data"
    )
    parser.add_argument(
        "--prefix", 
        type=str, 
        default="DeepFashion", 
        help="Prefix for the output file names"
    )
    parser.add_argument(
        "--max_videos", 
        type=int, 
        default=20000, 
        help="Maximum number of videos to process (for creating smaller datasets)"
    )
    return parser.parse_args()


def prepare_dataset(video_dir, output_dir, prefix, max_total_videos=20000):
    """
    Prepare the DeepFashion dataset by processing videos and parameters.
    
    Args:
        video_dir: Base directory containing imageX folders
        output_dir: Directory to save processed data
        prefix: Prefix for output filenames
        max_total_videos: Maximum number of videos to process (default: 20000)
    """
    # Find all imageX subdirectories
    image_dirs = []
    for item in os.listdir(video_dir):
        if item.startswith("image") and os.path.isdir(os.path.join(video_dir, item)):
            image_dirs.append(item)
    
    image_dirs.sort()
    print(f"Found {len(image_dirs)} image directories: {image_dirs}")
    
    # Collect all video files
    all_video_files = []
    all_param_files = []
    all_dir_names = []
    for image_dir in image_dirs:
        videos_path = os.path.join(video_dir, image_dir, "videos")
        params_path = os.path.join(video_dir, image_dir, "param")
        
        if not os.path.exists(videos_path):
            print(f"Warning: Videos directory not found in {image_dir}, skipping.")
            continue
        
        if not os.path.exists(params_path):
            print(f"Warning: Parameters directory not found in {image_dir}, skipping.")
            continue
        
        # Get list of video names in current directory
        param_names = os.listdir(params_path)

        # filter the files with .npy extension
        param_names = [name for name in param_names if name.endswith(".npy")]
        
        for name in param_names:
            video_path = os.path.join(videos_path, name.replace(".npy", ".mp4"))
            param_path = os.path.join(params_path, name)
            
            # Check if both video and parameter files exist
            if not os.path.exists(video_path):
                print(f"Warning: Video file not found: {video_path}, skipping.")
                continue
                
            if not os.path.exists(param_path):
                print(f"Warning: Parameter file not found: {param_path}, skipping.")
                continue
                
            # Add to collection only if both files exist
            all_video_files.append(video_path)
            all_param_files.append(param_path)
            all_dir_names.append(image_dir)
    
    total_videos = len(all_video_files)
    print(f"Total valid videos found: {total_videos}")
    
    if total_videos == 0:
        # Exit with a non-zero status so process_datasets.sh reports the failure.
        raise SystemExit("Error: No valid video-parameter pairs found. Please check your data paths.")
    
    # Limit number of videos to process
    if max_total_videos < total_videos:
        # Randomly shuffle and select first max_total_videos
        indices = list(range(total_videos))
        np.random.shuffle(indices)
        indices = indices[:max_total_videos]
        
        all_video_files = [all_video_files[i] for i in indices]
        all_param_files = [all_param_files[i] for i in indices]
        all_dir_names = [all_dir_names[i] for i in indices]
        
        print(f"Limiting to {max_total_videos} videos")
    
    # Process videos and parameters
    scenes = []
    processed_count = 0
    skipped_count = 0
    
    for video_path, param_path, dir_name in zip(all_video_files, all_param_files, all_dir_names):
        processed_count += 1
        case_name = os.path.basename(video_path)
        
        print(f"Processing {processed_count}/{len(all_video_files)}: {dir_name}/{case_name}")
        
        try:
            # Create scene dictionary
            scenes.append(dict(
                video_path=video_path,
                image_paths=None, # only fill it for the data in a images sequence instead of a video
                param_path=param_path,
                image_ref=video_path.replace(".mp4", ".jpg")
            ))
        except Exception as e:
            print(f"Error processing {video_path}: {e}")
            skipped_count += 1
    
    print(f"Total scenes collected: {len(scenes)}")
    print(f"Total scenes skipped: {skipped_count}")
    
    if len(scenes) == 0:
        # Exit with a non-zero status so process_datasets.sh reports the failure.
        raise SystemExit("Error: No scenes could be processed. Please check your data.")
    
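    # Split dataset: the last 50 scenes form the test split and the 10 before them the
    # validation split; smaller datasets keep the splits disjoint instead of overlapping.
    total_scenes = len(scenes)
    n_test = 50 if total_scenes > 50 else 0
    n_val = 10 if total_scenes > 60 else 0
    n_train = total_scenes - n_test - n_val
    train_scenes = scenes[:n_train]
    val_scenes = scenes[n_train:n_train + n_val]
    test_scenes = scenes[n_train + n_val:]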
    
    # Save each split
    splits = {
        "train": train_scenes,
        "val": val_scenes,
        "test": test_scenes,
        "all": scenes
    }
    
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)
    
    # Save each split to separate file
    for split_name, split_data in splits.items():
        if not split_data:
            continue
            
        cache_path = os.path.join(
            output_dir, 
            f"{prefix}_{split_name}_{len(split_data)}.npy"
        )
        np.save(cache_path, split_data)
        print(f"Saved {split_name} split with {len(split_data)} samples to {cache_path}")


if __name__ == "__main__":
    # Parse command line arguments
    args = parse_args()
    
    # Prepare and save the dataset
    prepare_dataset(args.video_dir, args.output_dir, args.prefix, args.max_videos)
    print(f"Done processing {args.video_dir} dataset")


================================================
FILE: data_processing/process_datasets.sh
================================================
#!/bin/bash

# Data processing script for multiple datasets
# This script processes all specified datasets and saves the results to output directories

# Define the list of dataset paths
DATASET_PATHS=(
    "/PATH/TO/deepfashion"
    "/PATH/TO/flux_batch1_5000"
    "/PATH/TO/flux_batch2"
    # Add more dataset paths here as needed
)

# Output base directory for processed cache files
OUTPUT_BASE_DIR="./processed_data"

# Maximum number of videos to process per dataset (use a small number for testing).
# To process all videos, set MAX_VIDEOS to a very large number.
MAX_VIDEOS=200

# Process each dataset
for DATASET_PATH in "${DATASET_PATHS[@]}"; do
    # Extract dataset name from path (use the last directory name as prefix)
    DATASET_NAME=$(basename "$DATASET_PATH")
    
    # All datasets share one output directory; cache files are told apart by their dataset-name prefix
    OUTPUT_DIR="${OUTPUT_BASE_DIR}"
    mkdir -p "$OUTPUT_DIR"
    
    echo "===== Processing ${DATASET_NAME} Dataset ====="
    echo "Source: ${DATASET_PATH}"
    echo "Destination: ${OUTPUT_DIR}"
    
    # Run the processing script
    python data_processing/prepare_cache.py \
        --video_dir "${DATASET_PATH}" \
        --output_dir "${OUTPUT_DIR}" \
        --prefix "${DATASET_NAME}" \
        --max_videos "${MAX_VIDEOS}"
    
    # Check if processing was successful
    if [ $? -ne 0 ]; then
        echo "Error processing ${DATASET_NAME} dataset"
        echo "Continuing with next dataset..."
    else
        echo "Successfully processed ${DATASET_NAME} dataset"
    fi
    
    echo "----------------------------------------"
done

echo "===== All datasets processing completed ====="
echo "Results saved to: ${OUTPUT_BASE_DIR}"

# List all processed datasets
echo "Processed datasets:"
for DATASET_PATH in "${DATASET_PATHS[@]}"; do
    DATASET_NAME=$(basename "$DATASET_PATH")
    echo "- ${DATASET_NAME}: ${OUTPUT_BASE_DIR}/${DATASET_NAME}_*.npy"
done 

================================================
FILE: data_processing/visualize_samples.py
================================================

import torch
import numpy as np
import os
os.environ["PYOPENGL_PLATFORM"] = "osmesa"
import smplx
import trimesh
import pyrender
import imageio

def init_smplx_model():
    """Initialize the SMPL-X model with predefined settings."""
    body_model = smplx.SMPLX('PATH_TO_YOUR_SMPLX_FOLDER',
                             gender="neutral", 
                             create_body_pose=False, 
                             create_betas=False, 
                             create_global_orient=False, 
                             create_transl=False,
                             create_expression=True,
                             create_jaw_pose=True, 
                             create_leye_pose=True, 
                             create_reye_pose=True, 
                             create_right_hand_pose=False,
                             create_left_hand_pose=False,
                             use_pca=False,
                             num_pca_comps=12,
                             num_betas=10,
                             flat_hand_mean=False)
    return body_model

# Load SMPL-X parameters
param_path =  "./100samples/Apose/param/Argentina_male_buff_thermal wear_20~30 years old_1573.npy"
param = np.load(param_path, allow_pickle=True).item()

# Extract SMPL-X parameters
smpl_params = param['smpl_params'].reshape(1, -1)
scale, transl, global_orient, pose, betas, left_hand_pose, right_hand_pose, jaw_pose, leye_pose, reye_pose, expression = torch.split(smpl_params, [1, 3, 3, 63, 10, 45, 45, 3, 3, 3, 10], dim=1)

# Initialize SMPL-X model and generate vertices
device = torch.device("cpu")
model = init_smplx_model().to(device)
output = model(global_orient=global_orient, body_pose=pose, betas=betas, left_hand_pose=left_hand_pose,
               right_hand_pose=right_hand_pose, jaw_pose=jaw_pose, leye_pose=leye_pose, reye_pose=reye_pose,
               expression=expression)
vertices = output.vertices[0].detach().cpu().numpy()
faces = model.faces

# Create a Trimesh and Pyrender mesh
mesh = trimesh.Trimesh(vertices, faces)
mesh_pyrender = pyrender.Mesh.from_trimesh(mesh)
   
rendered_images_list = []

# Loop through multiple camera views
for idx in range(24):
    scene = pyrender.Scene()
    scene.add(mesh_pyrender)

    # Load and process camera parameters
    camera_params = param['poses']
    intrinsic_params = camera_params[idx][1]  # fx, fy, cx, cy
    extrinsic_params = camera_params[idx][0] # R|T

    # Set up Pyrender camera
    camera = pyrender.IntrinsicsCamera(fx=intrinsic_params[0], fy=intrinsic_params[1],
                                       cx=intrinsic_params[2], cy=intrinsic_params[3])

    # Convert COLMAP coordinates to Pyrender-compatible transformation
    extrinsic_params_inv = torch.inverse(extrinsic_params.clone())
    scale_factor = extrinsic_params_inv[:3, :3].norm(dim=1)
    extrinsic_params_inv[:3, 1:3] = -extrinsic_params_inv[:3, 1:3]
    extrinsic_params_inv[3, :3] = 0

    # Add camera and lighting
    scene.add(camera, pose=extrinsic_params_inv)
    light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=10.0)
    scene.add(light, pose=extrinsic_params_inv)

    # Render the scene
    renderer = pyrender.OffscreenRenderer(640, 896)
    color, depth = renderer.render(scene)
    rendered_images_list.append(color)
    renderer.delete()

# Save rendered images as a video
rendered_images = np.stack(rendered_images_list)
imageio.mimwrite('rendered_results.mp4', rendered_images, fps=15)
print("Rendered results saved as rendered_results.mp4")

# Load an existing video and test alignment
video_path = param_path.replace("param", "videos").replace("npy", "mp4")
input_video = imageio.get_reader(video_path)
input_frames = [frame for frame in input_video]
blended_frames = [(0.5 * frame + 0.5 * render_frame).astype(np.uint8) for render_frame, frame in zip(rendered_images, input_frames)]
imageio.mimwrite('aligned_results.mp4', blended_frames, fps=15)
print("Blended video saved as aligned_results.mp4")


================================================
FILE: dataset/README.md
================================================
# 🌟 HuGe100K Dataset Documentation

## 📊 Dataset Overview
HuGe100K is a large-scale multi-view human dataset featuring diverse attributes, high-fidelity appearances, and well-aligned SMPL-X models.

## 📁 File Format and Structure

The dataset is organized with the following structure:

```
HuGe100K/
├── flux_batch1/
│   ├── images[0...9]/            # Different batches of images
│   │   ├── videos/               # Folder for .mp4 and .jpg files
│   │   │   ├── Algeria_female_average_high fashion_50~60 years old_844.jpg
│   │   │   ├── Algeria_female_average_high fashion_50~60 years old_844.mp4
│   │   │   └── ... (more .jpg and .mp4)
│   │   └── param/               # Folder for parameter files (.npy)
│   │       ├── Algeria_female_average_high fashion_50~60 years old_844.npy
│   │       └── ... (more .npy files)
├── flux_batch2/
│   └── ... (similar structure with images[0...9])
├── flux_batch3/
│   └── ... (similar structure with images[0...9])
├── flux_batch4/
│   └── ... (similar structure with images[0...9])
├── flux_batch5/
│   └── ... (similar structure with images[0...9])
├── flux_batch6/
│   └── ... (similar structure with images[0...9])
├── flux_batch7/
│   └── ... (similar structure with images[0...9])
├── flux_batch8/
│   └── ... (similar structure with images[0...9])
├── flux_batch9/
│   └── ... (similar structure with images[0...9])
└── deepfashion/
    └── ... (similar structure with images[0...9])
```

Where:
- Each `images[X]` folder contains:
  - `videos/`: Reference images (`.jpg`) and generated multi-view videos (`.mp4`)
  - `param/`: Camera and SMPL-X body parameters (`.npy`)
- **flux_batch1 through flux_batch7**: Contain subjects in A-pose
- **flux_batch8 and flux_batch9**: Contain subjects in diverse poses
- **deepfashion**: Contains subjects in A-pose (derived from the DeepFashion dataset)
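
Because each `videos/` folder sits next to a `param/` folder with matching file stems, the (video, reference image, parameters) triplets can be collected with a directory walk. The sketch below assumes exactly the two-level `<batch>/images<X>/` layout shown above; `iter_samples` and `Sample` are illustrative names, not part of the released code.

```python
from pathlib import Path
from typing import Iterator, NamedTuple


class Sample(NamedTuple):
    video: Path  # generated multi-view video (.mp4)
    image: Path  # reference image (.jpg), same stem as the video
    param: Path  # camera + SMPL-X parameters (.npy), same stem as the video


def iter_samples(root: str) -> Iterator[Sample]:
    """Yield every subject that has a video, a reference image and a param file.

    Assumes the layout <root>/<batch>/<images_dir>/{videos,param}/ shown above.
    """
    for video in sorted(Path(root).glob("*/*/videos/*.mp4")):
        subset_dir = video.parent.parent  # e.g. flux_batch1/images0
        image = video.with_suffix(".jpg")
        param = subset_dir / "param" / (video.stem + ".npy")
        # Subjects missing either companion file are skipped instead of raising.
        if image.exists() and param.exists():
            yield Sample(video, image, param)
```

For example, `next(iter_samples("/PATH/TO/HuGe100K"))` returns the first complete subject; incomplete downloads are silently skipped.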

### File Naming Convention
Files follow the naming pattern: `Area_Gender_BodyType_Clothing_Age_ID.extension`

For example:
- `Algeria_female_average_high fashion_50~60 years old_844.jpg`: Reference image of an Algerian female aged 50 to 60, with an average build, wearing high-fashion clothing
- `Algeria_female_average_high fashion_50~60 years old_844.npy`: Parameter file for the same subject
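
The attributes in a file name can be recovered by splitting on underscores. This is a minimal sketch that assumes no field itself contains an underscore (spaces and `~`, as in `high fashion` or `50~60 years old`, are fine); `parse_subject_name` and `SubjectInfo` are hypothetical helpers.

```python
from pathlib import Path
from typing import NamedTuple


class SubjectInfo(NamedTuple):
    area: str
    gender: str
    body_type: str
    clothing: str
    age: str
    subject_id: int


def parse_subject_name(path: str) -> SubjectInfo:
    """Split `Area_Gender_BodyType_Clothing_Age_ID.extension` into its six fields."""
    parts = Path(path).stem.split("_")  # drop directory and extension first
    if len(parts) != 6:
        raise ValueError(f"unexpected file name: {path!r}")
    area, gender, body_type, clothing, age, subject_id = parts
    return SubjectInfo(area, gender, body_type, clothing, age, int(subject_id))
```

Because the `.jpg`, `.mp4` and `.npy` files of one subject share a stem, the same call works for any of the three.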

### 📸 Sample Visualization

<div style="display: flex; align-items: center; justify-content: center; gap: 10px; flex-wrap: nowrap; width: 100%;">
  <img src="sample/videos/Kenya_female_fit_streetwear_50~60 years old_1501.jpg" alt="Kenya Female Fit Streetwear Image" style="max-width: 45%; width: 45%; height: auto;">
  <span style="font-weight: bold;"> =MVChamp=> </span>
  <!-- <video autoplay loop muted playsinline style="max-width: 45%; width: 45%; height: auto;">
    <source src="sample/videos/Kenya_female_fit_streetwear_50~60 years old_1501.gif" type="video/mp4">
    Your browser does not support the video tag.
  </video> -->
    <img src="sample/videos/Kenya_female_fit_streetwear_50~60 years old_1501.gif" alt="Kenya Female Fit Streetwear Multi-view Video" style="max-width: 45%; width: 45%; height: auto;">
</div>



## 📈 Dataset Statistics

- **Total Subjects**: 100,000+
- **Views per Subject**: 24 views covering 360°
- **Pose Types**: A-pose and diverse poses
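
Assuming the 24 views are evenly spaced around the subject (15° apart), a loader that wants k training views spread around the body can split the ring into k equal bins and draw one view per bin. This mirrors the idea behind the `better_range` option of `AvatarDataset` in `lib/datasets/avatar_dataset.py`; the function below is a hypothetical standalone sketch in plain Python, not the repository's implementation.

```python
import random


def stratified_views(num_views: int, k: int, rng: random.Random) -> list[int]:
    """Pick k view indices, one uniformly at random from each of k equal bins."""
    if not 0 < k <= num_views:
        raise ValueError("need 0 < k <= num_views")
    bin_size = num_views // k
    # Bin b covers views [b * bin_size, (b + 1) * bin_size).
    picks = [b * bin_size + rng.randrange(bin_size) for b in range(k)]
    rng.shuffle(picks)  # random order, as torch.randperm gives in better_range
    return picks
```

When `num_views` is not divisible by k, the trailing `num_views % k` views are never chosen; the `better_range` code (`num_imgs // num_train_imgs` bins) has the same property.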

## 🔍 Visualizing the Dataset

For visualization and data parsing examples, please refer to our provided script:
`visualize_samples.py`. This script demonstrates how to:

- Load the SMPL-X parameters from `.npy` files
- Render the 3D human model from multiple camera views
- Compare rendered results with the original video frames
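
Each `.npy` file stores a pickled dictionary. `visualize_samples.py` reads `param['poses']` (per-view extrinsics and intrinsics) and a flat `param['smpl_params']` vector of 189 values, which it splits with `torch.split`. The NumPy sketch below reproduces that split by name, assuming the same ordering; `unpack_smplx` and `SMPLX_LAYOUT` are hypothetical helpers.

```python
import numpy as np

# (name, size) pairs in the order used by torch.split in visualize_samples.py
SMPLX_LAYOUT = [
    ("scale", 1), ("transl", 3), ("global_orient", 3), ("body_pose", 63),
    ("betas", 10), ("left_hand_pose", 45), ("right_hand_pose", 45),
    ("jaw_pose", 3), ("leye_pose", 3), ("reye_pose", 3), ("expression", 10),
]


def unpack_smplx(flat) -> dict:
    """Split a flat 189-value SMPL-X parameter vector into named arrays."""
    flat = np.asarray(flat, dtype=np.float32).reshape(-1)
    sizes = [size for _, size in SMPLX_LAYOUT]
    if flat.size != sum(sizes):  # 189
        raise ValueError(f"expected {sum(sizes)} values, got {flat.size}")
    # Split points are the running totals of the sizes, excluding the last.
    chunks = np.split(flat, np.cumsum(sizes)[:-1])
    return {name: chunk for (name, _), chunk in zip(SMPLX_LAYOUT, chunks)}
```

Usage: `parts = unpack_smplx(np.load(path, allow_pickle=True).item()["smpl_params"])`. `np.asarray` also accepts a CPU torch tensor that does not require grad.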

Requirements for visualization:
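- SMPL-X model (download from [official website](https://smpl-x.is.tue.mpg.de/))
- Python packages: `pyrender`, `trimesh`, `smplx`, `numpy`, `torch`, `imageio` (with `imageio-ffmpeg` for `.mp4` input/output)
- OSMesa for headless rendering (the script sets `PYOPENGL_PLATFORM=osmesa`)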

Example usage (first set `PATH_TO_YOUR_SMPLX_FOLDER` in `init_smplx_model()` to your local SMPL-X model folder):
```bash
python visualize_samples.py
```

The script will generate:
- `rendered_results.mp4`: Rendered views of the 3D model
- `aligned_results.mp4`: Blended visualization of rendered model with original frames

## 📋 Usage Guidelines

1. **Research Purposes Only**: This dataset is intended for academic and research purposes.
2. **Citation Required**: If you use this dataset in your research, please cite our paper.
3. **No Commercial Use Without Permission**: Commercial use requires explicit permission from us (yiyu.zhuang@smail.nju.edu.cn).
4. **DeepFashion Derivatives**: See License and Attribution section below for special requirements.

## ⚖️ License and Attribution (DeepFashion)

This dataset includes images derived from the **DeepFashion** dataset, originally provided by MMLAB at The Chinese University of Hong Kong. The use of DeepFashion images in this dataset has been explicitly authorized by the original authors solely for the purpose of creating and distributing this dataset. **Users must not further reproduce, distribute, sell, or commercially exploit any images or derived data originating from DeepFashion.** For any subsequent or separate use of the DeepFashion data, users must directly obtain authorization from MMLAB and comply with the original [DeepFashion License](https://mmlab.ie.cuhk.edu.hk/projects/DeepFashion.html). 

================================================
FILE: dataset/visualize_samples.py
================================================

import torch
import numpy as np
import os
os.environ["PYOPENGL_PLATFORM"] = "osmesa"
import smplx
import trimesh
import pyrender
import imageio

def init_smplx_model():
    """Initialize the SMPL-X model with predefined settings."""
    body_model = smplx.SMPLX('PATH_TO_YOUR_SMPLX_FOLDER',
                             gender="neutral", 
                             create_body_pose=False, 
                             create_betas=False, 
                             create_global_orient=False, 
                             create_transl=False,
                             create_expression=True,
                             create_jaw_pose=True, 
                             create_leye_pose=True, 
                             create_reye_pose=True, 
                             create_right_hand_pose=False,
                             create_left_hand_pose=False,
                             use_pca=False,
                             num_pca_comps=12,
                             num_betas=10,
                             flat_hand_mean=False)
    return body_model

# Load SMPL-X parameters
param_path = "./sample/param/Kenya_female_fit_streetwear_50~60 years old_1501.npy"
param = np.load(param_path, allow_pickle=True).item()

# Extract SMPL-X parameters
smpl_params = param['smpl_params'].reshape(1, -1)
scale, transl, global_orient, pose, betas, left_hand_pose, right_hand_pose, jaw_pose, leye_pose, reye_pose, expression = torch.split(smpl_params, [1, 3, 3, 63, 10, 45, 45, 3, 3, 3, 10], dim=1)

# Initialize SMPL-X model and generate vertices
device = torch.device("cpu")
model = init_smplx_model().to(device)
output = model(global_orient=global_orient, body_pose=pose, betas=betas, left_hand_pose=left_hand_pose,
               right_hand_pose=right_hand_pose, jaw_pose=jaw_pose, leye_pose=leye_pose, reye_pose=reye_pose,
               expression=expression)
vertices = output.vertices[0].detach().cpu().numpy()
faces = model.faces

# Create a Trimesh and Pyrender mesh
mesh = trimesh.Trimesh(vertices, faces)
mesh_pyrender = pyrender.Mesh.from_trimesh(mesh)
   
rendered_images_list = []

# Loop through multiple camera views
for idx in range(24):
    scene = pyrender.Scene()
    scene.add(mesh_pyrender)

    # Load and process camera parameters
    camera_params = param['poses']
    intrinsic_params = camera_params[idx][1]  # fx, fy, cx, cy
    extrinsic_params = camera_params[idx][0] # R|T

    # Set up Pyrender camera
    camera = pyrender.IntrinsicsCamera(fx=intrinsic_params[0], fy=intrinsic_params[1],
                                       cx=intrinsic_params[2], cy=intrinsic_params[3])

    # Convert COLMAP coordinates to Pyrender-compatible transformation
    extrinsic_params_inv = torch.inverse(extrinsic_params.clone())
    scale_factor = extrinsic_params_inv[:3, :3].norm(dim=1)
    extrinsic_params_inv[:3, 1:3] = -extrinsic_params_inv[:3, 1:3]
    extrinsic_params_inv[3, :3] = 0

    # Add camera and lighting
    scene.add(camera, pose=extrinsic_params_inv)
    light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=10.0)
    scene.add(light, pose=extrinsic_params_inv)

    # Render the scene
    renderer = pyrender.OffscreenRenderer(640, 896)
    color, depth = renderer.render(scene)
    rendered_images_list.append(color)
    renderer.delete()

# Save rendered images as a video
rendered_images = np.stack(rendered_images_list)
imageio.mimwrite('rendered_results.mp4', rendered_images, fps=15)
print("Rendered results saved as rendered_results.mp4")

# Load an existing video and test alignment
video_path = param_path.replace("param", "videos").replace("npy", "mp4")
input_video = imageio.get_reader(video_path)
input_frames = [frame for frame in input_video]
blended_frames = [(0.5 * frame + 0.5 * render_frame).astype(np.uint8) for render_frame, frame in zip(rendered_images, input_frames)]
imageio.mimwrite('aligned_results.mp4', blended_frames, fps=15)
print("Blended video saved as aligned_results.mp4")


================================================
FILE: env/README.md
================================================
# Environment Setup Guide

## Prerequisites

- Python 3.10
- CUDA 11.8
- PyTorch 2.3.1

## Installation Steps

### 1. Environment Preparation

First, create and activate a conda environment:
```bash
conda create -n idol python=3.10
conda activate idol
```


Install all dependencies:
```bash
bash scripts/pip_install.sh
```
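
After the install script finishes, a quick check against the prerequisites above (Python 3.10, PyTorch 2.3.1) can save debugging time later. The snippet below is a hypothetical helper, not part of the repository. It only reads package metadata through the standard library's `importlib.metadata`, and the package list is an example to adapt.

```python
import sys
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to its installed version string, or None if absent."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    print(f"python          {sys.version.split()[0]}")
    # Example list; "torch" should start with 2.3.1 per the prerequisites.
    for pkg, ver in installed_versions(["torch", "numpy", "opencv-python"]).items():
        print(f"{pkg:15s} {ver or 'NOT INSTALLED'}")
```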

### 2. Download Required Models

Before proceeding, please register on:
- [SMPL-X website](https://smpl-x.is.tue.mpg.de/)
- [FLAME website](https://flame.is.tue.mpg.de/)

Then download the template files:
```bash
bash scripts/fetch_template.sh
```

### 3. Download Pretrained Models and Caches
```bash
bash scripts/download_files.sh      # download pretrained models
```

Or manually download the following models from HuggingFace:
- [IDOL Model Checkpoint](https://huggingface.co/yiyuzhuang/IDOL/blob/main/model.ckpt)
- [Sapiens Pretrained Model](https://huggingface.co/yiyuzhuang/IDOL/blob/main/sapiens_1b_epoch_173_torchscript.pt2)

## System Requirements

- **GPU**: NVIDIA GPU with CUDA 11.8 support
- **GPU RAM**: Recommended 24GB+
- **Storage**: At least 15GB free space

## Common Issues & Solutions

### 1. OpenCV `libGL.so.1` Import Error
**Issue**: 
```
ImportError: libGL.so.1: cannot open shared object file: No such file or directory
```
when importing OpenCV (`import cv2`)

**Solution**:
```bash
# For Ubuntu/Debian
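sudo apt-get install libgl1-mesa-glx
# On newer releases without libgl1-mesa-glx (e.g. Debian 12, Ubuntu 24.04), use:
# sudo apt-get install libgl1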
```

### 2. Gaussian Splatting Antialiasing Issue
**Issue**: An error related to the `antialiasing=True` setting in `GaussianRasterizationSettings`

**Solution**:
This issue arises from updates to the Gaussian Splatting rasterizer: an installed version that predates the `antialiasing` option does not accept it. Reinstalling the module from its GitHub repository at the latest version should resolve the problem.
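
Before reinstalling, it can help to confirm whether the installed rasterizer knows about the option. This sketch assumes the module is importable as `diff_gaussian_rasterization` and exposes `GaussianRasterizationSettings` as a `typing.NamedTuple`. Both are assumptions about your installed version. `supports_setting` is a hypothetical helper.

```python
def supports_setting(settings_cls, name: str) -> bool:
    """Return True if a NamedTuple-style settings class declares the field `name`."""
    # NamedTuple classes list their declared fields in the `_fields` tuple.
    return name in getattr(settings_cls, "_fields", ())


if __name__ == "__main__":
    try:
        # Assumed module and class names; adjust to your installation.
        from diff_gaussian_rasterization import GaussianRasterizationSettings
    except ImportError:
        print("diff_gaussian_rasterization is not installed")
    else:
        ok = supports_setting(GaussianRasterizationSettings, "antialiasing")
        print("antialiasing supported" if ok else "outdated rasterizer: reinstall it")
```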

================================================
FILE: lib/__init__.py
================================================
from .models import *
from .mmutils import *
from .humanlrm_wrapper_sa_v1 import SapiensGS_SA_v1


================================================
FILE: lib/datasets/__init__.py
================================================
from .avatar_dataset import AvatarDataset
from .dataloader import DataModuleFromConfig

================================================
FILE: lib/datasets/avatar_dataset.py
================================================
import os
import random
import numpy as np
import torch
import json
import pickle

from torch.utils.data import Dataset
import torchvision.transforms.functional as F
import webdataset as wds
# from lib.utils.train_util import print

import cv2
import av

from omegaconf import OmegaConf, ListConfig

def load_pose(path):
    with open(path, 'rb') as f:
        pose_param = json.load(f)
    c2w = np.array(pose_param['cam_param'], dtype=np.float32).reshape(4,4)
    cam_center = c2w[:3, 3]
    w2c = np.linalg.inv(c2w)
    # pose[:,:2] *= -1
    # pose = np.loadtxt(path, dtype=np.float32, delimiter=' ').reshape(9, 4)
    return [torch.from_numpy(w2c), torch.from_numpy(cam_center)]

def load_npy(file_path):
    return np.load(file_path, allow_pickle=True)

def load_smpl(path, smpl_type='smpl'):
    filetype = path.split('.')[-1]
    with open(path, 'rb') as f:
        if filetype=='pkl':
            smpl_param_data = pickle.load(f)
        elif filetype == 'json':
            smpl_param_data = json.load(f)
        else:
            raise ValueError(f"Unsupported SMPL parameter file type: {filetype}")

    if smpl_type=='smpl':
        with open(os.path.join(os.path.split(path)[0][:-5], 'pose', '000_000.json'), 'rb') as f:
            tf_param = json.load(f)
        smpl_param = np.concatenate([np.array(tf_param['scale']).reshape(1, -1), np.array(tf_param['center'])[None], 
                    smpl_param_data['global_orient'], smpl_param_data['body_pose'].reshape(1, -1), smpl_param_data['betas']], axis=1)
    elif smpl_type == 'smplx':

        tf_param = np.load(os.path.join(os.path.dirname(os.path.dirname(path)), 'scale_offset.npy'), allow_pickle=True).item()
        # smpl_param = np.concatenate([np.array([tf_param['scale']]).reshape(1, -1), tf_param['offset'].reshape(1, -1), 
        smpl_param = np.concatenate([np.array([[1]]), np.array([[0,0,0]]), 
                    np.array(smpl_param_data['global_orient']).reshape(1, -1), np.array(smpl_param_data['body_pose']).reshape(1, -1), 
                    np.array(smpl_param_data['betas']).reshape(1, -1), np.array(smpl_param_data['left_hand_pose']).reshape(1, -1), np.array(smpl_param_data['right_hand_pose']).reshape(1, -1),
                    np.array(smpl_param_data['jaw_pose']).reshape(1, -1), np.array(smpl_param_data['leye_pose']).reshape(1, -1), np.array(smpl_param_data['reye_pose']).reshape(1, -1), 
                    np.array(smpl_param_data['expression']).reshape(1, -1)], axis=1)
    else:
        raise ValueError(f"Unsupported smpl_type: {smpl_type}")
    
    return torch.from_numpy(smpl_param.astype(np.float32)).reshape(-1)


class AvatarDataset(Dataset):
    def __init__(self,
                 data_prefix,
                 code_dir=None,
                #  code_only=False,
                 load_imgs=True,
                 load_norm=False,
                 specific_observation_idcs=None,
                 specific_observation_num=None,
                 first_is_front=False, # yy add  # If True, it will returns a random sampled batch with the front view in the first place
                 better_range=False, # yy add  # If True, the views will not be fully random, but will be selected by a better skip
                 if_include_video_ref_img= False,# yy add Define a variable to indicate whether to include reference images from the video
                 prob_include_video_ref_img= 0.2, # yy add Define a variable to specify the probability
                 allow_k_angles_near_the_front = 0, # yy add, if value > 0, the front view will be allowed to be selected from the range of [front_view - allow_k_angles_near_the_front, front_view + allow_k_angles_near_the_front]
                #  num_test_imgs=0, 
                 if_use_swap_face_v1=False, # yy add, if True, use the swap face v1
                 random_test_imgs=False,
                 scene_id_as_name=False,
                 cache_path=None,
                 cache_repeat=None, # be the same length with the cache_path
                 test_pose_override=None,
                 num_train_imgs=-1,
                 load_cond_data=True,
                 load_test_data=True, 
                 max_num_scenes=-1,  # for debug or testing
                #  radius=0.5,
                 radius=1.0,
                 img_res=[640, 896],
                 test_mode=False,
                 step=1,  # only for debug & visualization purpose
                 crop=False # randomly crop the image with upper body inputs
                 ):
        super(AvatarDataset, self).__init__()
        self.data_prefix = data_prefix
        self.code_dir = code_dir
        # self.code_only = code_only
        self.load_imgs = load_imgs
        self.load_norm = load_norm
        self.specific_observation_idcs = specific_observation_idcs
        self.specific_observation_num = specific_observation_num
        self.first_is_front = first_is_front
        


        self.if_include_video_ref_img= if_include_video_ref_img 
        self.prob_include_video_ref_img = prob_include_video_ref_img
        self.allow_k_angles_near_the_front = allow_k_angles_near_the_front 

        self.better_range = better_range
        # self.num_test_imgs = num_test_imgs
        self.random_test_imgs = random_test_imgs
        self.scene_id_as_name = scene_id_as_name
        self.cache_path = cache_path
        self.cache_repeat = cache_repeat
        self.test_pose_override = test_pose_override
        self.num_train_imgs = num_train_imgs
        self.load_cond_data = load_cond_data
        self.load_test_data = load_test_data
        self.max_num_scenes = max_num_scenes
        self.step = step

        self.if_use_swap_face_v1 = if_use_swap_face_v1
        # import ipdb; ipdb.set_trace()

        self.img_res = [int(i) for i in img_res]

        self.radius = torch.tensor([radius], dtype=torch.float32).expand(3)
        self.center = torch.zeros_like(self.radius)

        self.load_scenes()
        
        self.crop = crop


        self.test_poses = self.test_intrinsics = None

        self.defalut_focal = 1120 #40 * (self.img_res[0]/32) # focal 80mm, sensor 32mm

        self.default_fxy_cxy = torch.tensor([self.defalut_focal, self.defalut_focal,  self.img_res[1]//2, self.img_res[0]//2]).reshape(1, 4)

        self.test_mode = test_mode

        if self.test_mode:
            self.parse_scene = self.parse_scene_test


    def load_scenes(self):

        
        if isinstance(self.cache_path, ListConfig):

            cache_list = []
            case_num_per_dataset = 1000000000
            for ii, path in enumerate(self.cache_path):
                cache = np.load(path, allow_pickle=True)
                if self.cache_repeat is not None:
                    cache = np.repeat(cache, self.cache_repeat[ii], axis=0)
                print("done loading ", path)
                cache_list.extend(cache[:case_num_per_dataset]) 
            scenes = cache_list
            print(f"========= initialized {len(scenes)} scenes in total =========")

        else:
            if self.cache_path is not None and os.path.exists(self.cache_path):
                scenes = np.load(self.cache_path, allow_pickle=True)
                print("load ", self.cache_path)
            else:
                print(f"{self.cache_path} does not exist")
                raise FileNotFoundError(f"Cache file not found: {self.cache_path}")
            
        end = len(scenes)
        if self.max_num_scenes >= 0:
            end = min(end, self.max_num_scenes * self.step)
        
        self.scenes = scenes[:end:self.step]
        self.num_scenes = len(self.scenes)
        
    def parse_scene(self, scene_id):
        scene = self.scenes[scene_id]
        input_is_video = False # flag of if the input is video, some operations should be different
        # print(scene)
        # scene['video_path'] = "/data/jxlv/transformers/src/A_pose_MEN-Denim-id_00000089-01_7_additional/result.mp4"
        # scene['image_paths'] = None #"/data/jxlv/transformers/src/A_pose_MEN-Denim-id_00000089-01_7_additional/source_seg.png"
        # import pdb
        # pdb.set_trace()
        # =========== loading the params ===========
        param = np.load(scene['param_path'], allow_pickle=True).item()
        scene.update(param)
        print(scene.keys())

        # =========== loading the multi-view images ===========
        if scene['image_paths'] is None:
            input_is_video = True
            video_path = scene['video_path']
            try:
                if self.if_use_swap_face_v1:
                    image_paths_or_video = read_frames(scene['video_path'].replace('result.mp4', 'output.mp4'))
                else:
                    image_paths_or_video = read_frames(scene['video_path'])
            except Exception as e:
                print(f"Error: {e}")
                print(f"Error in reading the video : {scene['video_path'].replace('result.mp4', 'output.mp4')}")
                image_paths_or_video = read_frames(scene['video_path'])
            # if 'pose_animate_service_0727' in video_path or 'flux' in video_path:
            #     # move the first to the last
            #     # TODO fixed this bug with a better cameras parameters
            #     image_paths_or_video = image_paths_or_video[1:] + image_paths_or_video[0:1]
        if not input_is_video:
            image_paths_or_video = scene['image_paths']
        scene_name = f"{scene_id:0>4d}" #  image_paths[0].split('/')[-3]
        results = dict(
            scene_id=[scene_id],
            scene_name=
                '{:04d}'.format(scene_id) if self.scene_id_as_name else scene_name,
                # cpu_only=True
                )
        # import pdb; pdb.set_trace()
        # if not self.code_only:
        poses = scene['poses']
        smpl_params = scene['smpl_params']

        # if input_is_video:
        #     num_imgs = len(video)
        # else:
        num_imgs = len(image_paths_or_video)
        # front_view = num_imgs // 4
        # randonly / specificically select the views of output
        smplx_cam_rotate = smpl_params[4: 7] #get global orient # 1, 3, 63, 10
        # smpl_params[70:80] = torch.rand_like(smpl_params[70:80]); print("error !! need to delete this rand betas in avatarnet:287") # get betas
        front_view = find_front_camera_by_rotation(poses, smplx_cam_rotate) # inputs camera poses and smplx poses
        if self.allow_k_angles_near_the_front > 0:
            allow_n_views_near_the_front =  round(self.allow_k_angles_near_the_front / 360 * num_imgs)
            new_front_view = random.randint(-allow_n_views_near_the_front, allow_n_views_near_the_front) + front_view
            if new_front_view >= num_imgs:
                new_front_view = new_front_view - num_imgs
            elif new_front_view < 0:
                new_front_view = new_front_view + num_imgs
            front_view = new_front_view
            print("changes the front views ranges", front_view, "+-", allow_n_views_near_the_front)

        if self.specific_observation_idcs is None:  ######### if not specify views ########
            # if self.num_train_imgs >= 0:
            #     num_train_imgs = self.num_train_imgs
            # else:
            num_train_imgs = num_imgs
            if self.random_test_imgs: ###### randomly selected images with self.num_train_imgs ######
                cond_inds = random.sample(range(num_imgs), self.num_train_imgs)
            elif self.specific_observation_num: ###### randomly selected "specific_observation_num" images ######
                if self.first_is_front and self.specific_observation_num < 2:
                    # self.specific_observation_num = 2
                    cond_inds =torch.tensor([front_view, front_view]) # first for input, second for supervised
                elif self.better_range: # select views by a uniform distribution range
                    if self.first_is_front: # must include the front view
                        num_train_imgs = self.specific_observation_num - 2
                    else:
                        num_train_imgs = self.specific_observation_num
                    skip_range = num_imgs//num_train_imgs
                    # pick one random view from each bin [0, skip_range), [skip_range, 2*skip_range), ..., in random order
                    cond_inds = torch.randperm(num_train_imgs) * skip_range \
                            + torch.randint(low=0, high=skip_range, size=[num_train_imgs])
                    if self.first_is_front: # concat [the first view * 2] to the front of cond_inds
                        cond_inds = torch.cat([torch.tensor([front_view, front_view]), cond_inds])
                        
                else: # previous version, random views are sampled
                    cond_inds = torch.randperm(num_imgs)[:self.specific_observation_num]
            else:
                cond_inds = np.round(np.linspace(0, num_imgs - 1, num_train_imgs)).astype(np.int64)
        else:   ######### selected target views ########
            cond_inds = self.specific_observation_idcs


        test_inds = list(range(num_imgs))

        if self.specific_observation_num: # yy note: if specific_observation_num is not None, then remove the test_inds
            test_inds = []
        else:
            for cond_ind in cond_inds:
                test_inds.remove(cond_ind)
        cond_smpl_param_ref = torch.zeros([189])
        if_use_smpl_param_ref = torch.Tensor([1]) # use the reference SMPL by default
        if self.load_cond_data and len(cond_inds) > 0:
            # cond_imgs, cond_poses, cond_intrinsics, cond_img_paths, cond_smpl_param, cond_norm = gather_imgs(cond_inds)
            cond_imgs, cond_poses, cond_intrinsics, cond_img_paths, cond_smpl_param, cond_norm = \
                gather_imgs(cond_inds, poses, image_paths_or_video, smpl_params, load_imgs=self.load_imgs, load_norm=self.load_norm, center=self.center, radius=self.radius,
                input_is_video=input_is_video)
            cond_smpl_param_ref = cond_smpl_param.clone() # the smpl_param_ref for the reference images
            if cond_intrinsics.shape[-1] == 3: # the old data format, which contains the value of camera center instead of fxfycxcy
                cond_intrinsics = self.default_fxy_cxy.clone().repeat(cond_intrinsics.shape[0], 1)
            # import pdb; pdb.set_trace()
            # print("video_path", video_path)
            if input_is_video:
                cond_img_paths = [f"{video_path[:-4]}_{i:0>4d}.png" for i in range(self.specific_observation_num)] # replace the .mp4 suffix with _<index>.png
            
            
            if self.if_include_video_ref_img and input_is_video:
                # With probability prob_include_video_ref_img, replace the first condition frame with the reference image
                if np.random.rand() < self.prob_include_video_ref_img:
                    if 'image_ref' in scene:
                        ref_image_path = scene['image_ref']
                    else:
                        ref_image_path = video_path.replace(".mp4", ".jpg")
                    try:
                        # Read with cv2.IMREAD_UNCHANGED so that an alpha channel, if present, is preserved
                        img = cv2.imread(ref_image_path, cv2.IMREAD_UNCHANGED)
                        assert img is not None, f"Failed to read reference image: {ref_image_path}"
                        if img.ndim == 3 and img.shape[-1] == 4:
                            # Set fully transparent pixels to white and drop alpha, as gather_imgs does
                            img[img[..., 3] == 0] = [255, 255, 255, 255]
                            img = img[..., :3]
                        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                        img = torch.from_numpy(img.astype(np.float32) / 255)  # (h, w, 3)

                        # ======== loading the reference smplx for images ==========
                        load_ref_smplx = False
                        if load_ref_smplx:
                            # if "flux" in ref_image_path:
                            smplx_smplify_path = from_video_to_get_ref_smplx(video_path)
                            # load json and get values
                            with open(smplx_smplify_path) as f:
                                data = json.load(f)

                            RT = torch.cat([torch.Tensor(data['camera']['R']), torch.Tensor(data['camera']['t']).reshape(3, 1)], dim=1)
                            RT = torch.cat([RT, torch.Tensor([[0,0,0,1]])], dim=0)


                            intri = torch.Tensor(data['camera']['focal'] + data['camera']['princpt'])

                            intri[[3,2]] = intri[[2,3]]
                            intri = intri * self.default_fxy_cxy[0,-1] / intri[-1]
                            
                            # smpl_param_data is the SMPLify result loaded above, with keys
                            # (['root_pose', 'body_pose', 'jaw_pose', 'leye_pose', 'reye_pose', 'lhand_pose', 'rhand_pose', 'expr', 'trans', 'betas_save', 'nlf_smplx_betas', 'camera', 'img_path'])
                            smpl_param_data = data 

                            # Extract the required entries from the dict
                            global_orient = np.array(smpl_param_data['root_pose']).reshape(1, -1)
                            body_pose = np.array(smpl_param_data['body_pose']).reshape(1, -1)
                            shape = np.array(smpl_param_data['betas_save']).reshape(1, -1)[:, :10]
                            left_hand_pose = np.array(smpl_param_data['lhand_pose']).reshape(1, -1)
                            right_hand_pose = np.array(smpl_param_data['rhand_pose']).reshape(1, -1)

                            # smpl_param_ref = np.concatenate([np.array([[1.]]), np.array([[0, 0, 0]]),
                            smpl_param_ref = np.concatenate([np.array([[1.]]), np.array(smpl_param_data['trans']).reshape(1,3),
                                global_orient,body_pose,
                                shape, left_hand_pose, right_hand_pose,
                                np.array(smpl_param_data['jaw_pose']).reshape(1, -1), np.array(smpl_param_data['leye_pose']).reshape(1, -1), np.array(smpl_param_data['reye_pose']).reshape(1, -1),
                                np.array(smpl_param_data['expr']).reshape(1, -1)[:,:10]], axis=1)


                            cond_poses[0] = RT         # RT
                            cond_intrinsics[0]  = intri   # fxfycxcy
                            cond_smpl_param_ref = torch.Tensor(smpl_param_ref).reshape(-1)  # 189-dim: scale, transl, global_orient, body, betas, hands, jaw, eyes, expr
                            if_use_smpl_param_ref = torch.Tensor([1])  # use the smpl_param_ref

                        # overwrite some datas
                        cond_imgs[0] = img
                        cond_img_paths[0] = ref_image_path

                    except Exception as e:
                        # Keep the original first frame if the reference image cannot be used
                        print(f"Failed to use reference image {ref_image_path}: {e}")
                        
            # Randomly crop the first condition image for augmentation
            if self.crop:
                if random.random() < 0.5:
                    # Crop a band from near the top of the image (head) down to around mid-height (upper body)
                    crop_imgs = cond_imgs[0]
                    h, w, _ = crop_imgs.shape

                    # Random vertical offsets for the top (head) and bottom (upper-body) edges
                    random_offset_head = np.random.randint(-h // 7, -h // 8)
                    random_offset_body = np.random.randint(-h // 8, h // 8)

                    # Fixed reference points (x, y) for the head and upper body
                    head_joint = [w // 2, h // 7]
                    upper_body_joint = [w // 2, h // 2]

                    # Vertical crop bounds
                    head_y = int(head_joint[1]) + random_offset_head
                    body_y = int(upper_body_joint[1]) + random_offset_body

                    # Clamp the bounds to the image
                    head_y = max(0, min(h, head_y))
                    body_y = max(0, min(h, body_y))

                    # Crop height, with the width chosen to keep the 640:896 (w:h) aspect ratio
                    crop_height = body_y - head_y
                    crop_width = int(crop_height * 640 / 896)

                    # Center the crop horizontally while keeping it inside the image
                    start_x = max(0, min(w - crop_width, int(w // 2 - crop_width // 2)))
                    end_x = start_x + crop_width
                    start_y = max(0, head_y)
                    end_y = min(h, body_y)

                    # Crop the image and resize it back to (h, w)
                    cropped_img = crop_imgs[start_y:end_y, start_x:end_x]
                    padded_img = F.resize(cropped_img.permute(2, 0, 1), [h, w]).permute(1, 2, 0)

                    # Randomly rescale the result for further augmentation
                    cond_imgs[0] = random_scale_and_crop(padded_img, (0.8, 1.2))
                else:
                    cond_imgs[0] = random_scale_and_crop(cond_imgs[0], (0.8,1.1))

            
            results.update(
                cond_poses=cond_poses,
                cond_intrinsics=cond_intrinsics.to(torch.float32),
                cond_img_paths=cond_img_paths, 
                cond_smpl_param=cond_smpl_param,
                cond_smpl_param_ref=cond_smpl_param_ref,
                if_use_smpl_param_ref=if_use_smpl_param_ref)
            if cond_imgs is not None:
                results.update(cond_imgs=cond_imgs)
            if cond_norm is not None:
                results.update(cond_norm=cond_norm)

        if self.load_test_data and len(test_inds) > 0:
            test_imgs, test_poses, test_intrinsics, test_img_paths, test_smpl_param, test_norm = \
                    gather_imgs(test_inds, poses, image_paths_or_video, smpl_params, load_imgs=self.load_imgs, load_norm=self.load_norm, center=self.center, radius=self.radius)
            
            if test_intrinsics.shape[-1] == 3: # the old data format, which contains the value of camera center instead of fxfycxcy
                test_intrinsics = self.default_fxy_cxy.clone().repeat(test_intrinsics.shape[0], 1)

            results.update(
                test_poses=test_poses,
                test_intrinsics=test_intrinsics,
                test_img_paths=test_img_paths,
                test_smpl_param=test_smpl_param)
            if test_imgs is not None:
                results.update(test_imgs=test_imgs)
            if test_norm is not None:
                results.update(test_norm=test_norm)

    
        if self.test_pose_override is not None:
            results.update(test_poses=self.test_poses, test_intrinsics=self.test_intrinsics)
        return results

    def __len__(self):
        return self.num_scenes

    def __getitem__(self, scene_id):
        try:
            scene = self.parse_scene(scene_id)

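        except Exception as e:
            # Fall back to the first scene so that a single broken sample does not stop training
            print(f"ERROR in parsing scene {scene_id}: {e}")
            scene = self.parse_scene(0)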
        return scene
    
    def parse_scene_test(self, scene_id):
        scene = self.scenes[scene_id]
        input_is_video = False # flag of if the input is video, some operations should be different

        
        # =========== loading the params ===========
        param = np.load(scene['param_path'], allow_pickle=True).item()
        scene.update(param)
        if scene['image_paths'] is None:
            input_is_video = True
            video_path = scene['video_path']
            try:
                image_paths_or_video = read_frames(video_path)
            except Exception as e:
                # Retry once with the alternative output file name
                fallback_path = video_path.replace('result.mp4', 'output.mp4')
                print(f"Error reading {video_path}: {e}; retrying with {fallback_path}")
                image_paths_or_video = read_frames(fallback_path)

        if not input_is_video:
            image_paths_or_video = scene['image_paths']
        scene_name = f"{scene_id:0>4d}" #  image_paths[0].split('/')[-3]
        results = dict(
            scene_id=[scene_id],
            scene_name='{:04d}'.format(scene_id) if self.scene_id_as_name else scene_name,
        )
        
        # if not self.code_only:
        poses = scene['poses']
        smpl_params = scene['smpl_params']

        # if input_is_video:
        #     num_imgs = len(video)
        # else:
        num_imgs = len(image_paths_or_video)
        # front_view = num_imgs // 4
        # Select the output views (randomly or at specific indices)
        smplx_cam_rotate = smpl_params[4: 7]  # global_orient (layout: scale 1, transl 3, global_orient 3, body_pose 63, betas 10, ...)
        front_view = find_front_camera_by_rotation(poses, smplx_cam_rotate) # inputs camera poses and smplx poses
        if self.allow_k_angles_near_the_front > 0:
            allow_n_views_near_the_front =  round(self.allow_k_angles_near_the_front / 360 * num_imgs)
            new_front_view = random.randint(-allow_n_views_near_the_front, allow_n_views_near_the_front) + front_view
            if new_front_view >= num_imgs:
                new_front_view = new_front_view - num_imgs
            elif new_front_view < 0:
                new_front_view = new_front_view + num_imgs
            front_view = new_front_view
            print("changes the front views ranges", front_view, "+-", allow_n_views_near_the_front)

        num_train_imgs = num_imgs
        # The first index is the input (front) view; every view then follows as a supervision target
        test_inds = np.arange(num_imgs)
        cond_inds = np.concatenate([np.array([front_view]), test_inds]).astype(np.int64)
        test_inds = cond_inds.tolist()

        cond_smpl_param_ref = torch.zeros([189])
        if_use_smpl_param_ref = torch.Tensor([1])  # use the reference SMPL-X parameters by default
        if self.load_cond_data and len(cond_inds) > 0:
            # cond_imgs, cond_poses, cond_intrinsics, cond_img_paths, cond_smpl_param, cond_norm = gather_imgs(cond_inds)
            cond_imgs, cond_poses, cond_intrinsics, cond_img_paths, cond_smpl_param, cond_norm = \
                gather_imgs(cond_inds, poses, image_paths_or_video, smpl_params, load_imgs=self.load_imgs, load_norm=self.load_norm, center=self.center, radius=self.radius, \
                input_is_video=input_is_video)
            cond_smpl_param_ref = cond_smpl_param.clone() # the smpl_param_ref for the reference images
            if cond_intrinsics.shape[-1] == 3: # the old data format, which contains the value of camera center instead of fxfycxcy
                cond_intrinsics = self.default_fxy_cxy.clone().repeat(cond_intrinsics.shape[0], 1)

            if input_is_video:
                # Name each condition frame after the video (<video>_<index>.png), one name per condition view
                cond_img_paths = [f"{video_path[:-4]}_{i:0>4d}.png" for i in range(len(cond_inds))]

            results.update(
                cond_poses=cond_poses,
                cond_intrinsics=cond_intrinsics.to(torch.float32),
                cond_img_paths=cond_img_paths, 
                cond_smpl_param=cond_smpl_param,
                cond_smpl_param_ref=cond_smpl_param_ref,
                if_use_smpl_param_ref=if_use_smpl_param_ref)
            if cond_imgs is not None:
                results.update(cond_imgs=cond_imgs)
            if cond_norm is not None:
                results.update(cond_norm=cond_norm)

        if self.load_test_data and len(test_inds) > 0:
            test_imgs, test_poses, test_intrinsics, test_img_paths, test_smpl_param, test_norm = \
                    gather_imgs(test_inds, poses, image_paths_or_video, smpl_params, load_imgs=self.load_imgs, load_norm=self.load_norm, center=self.center, radius=self.radius, \
                    input_is_video=input_is_video)
            
            if test_intrinsics.shape[-1] == 3: # the old data format, which contains the value of camera center instead of fxfycxcy
                test_intrinsics = self.default_fxy_cxy.clone().repeat(test_intrinsics.shape[0], 1)

            results.update(
                test_poses=test_poses,
                test_intrinsics=test_intrinsics,
                test_img_paths=test_img_paths,
                test_smpl_param=test_smpl_param)
            if test_imgs is not None:
                results.update(test_imgs=test_imgs)
            if test_norm is not None:
                results.update(test_norm=test_norm)

    
        if self.test_pose_override is not None:
            results.update(test_poses=self.test_poses, test_intrinsics=self.test_intrinsics)
        return results


def gather_imgs(img_ids, poses, image_paths_or_video, smpl_params, load_imgs=True, load_norm=False, center=None, radius=None, input_is_video=False):
    imgs_list = [] if load_imgs else None
    norm_list = [] if load_norm else None
    poses_list = []
    cam_centers_list = []
    img_paths_list = []
   
    for img_id in img_ids:
        pose = poses[img_id][0]
        cam_centers_list.append((poses[img_id][1]).to(torch.float)) # (C)
        c2w = pose.to(torch.float)  # named c2w, but the stored value is actually w2c (R|T)
        cam_to_ndc = torch.cat(
            [c2w[:3, :3], (c2w[:3, 3:] - center[:, None]) / radius[:, None]], dim=-1)
        poses_list.append(
            torch.cat([
                cam_to_ndc,
                cam_to_ndc.new_tensor([[0.0, 0.0, 0.0, 1.0]])
            ], dim=-2))
        if input_is_video:
            img = image_paths_or_video[img_id]
            # Snap near-white pixels (all RGB channels > 250) to pure white
            mask_white = np.all(img[:, :, :3] > 250, axis=-1)
            img[mask_white] = [255, 255, 255]
            if load_imgs:
                imgs_list.append(torch.from_numpy(img.astype(np.float32) / 255))  # (h, w, 3)
        else:
            img_paths_list.append(image_paths_or_video[img_id])
            if load_imgs:
                # Read with cv2.IMREAD_UNCHANGED to preserve the alpha channel
                img = cv2.imread(image_paths_or_video[img_id], cv2.IMREAD_UNCHANGED)
                assert img is not None, f"Failed to read image: {image_paths_or_video[img_id]}"
                if img.ndim == 3 and img.shape[-1] == 4:
                    # Set the RGB values of fully transparent pixels to white (255, 255, 255), then drop alpha
                    img[img[..., 3] == 0] = [255, 255, 255, 255]
                    img = img[..., :3]
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                img = torch.from_numpy(img.astype(np.float32) / 255)  # (h, w, 3)
                imgs_list.append(img)
        if load_norm:  # not supported when the input is a video
            norm = cv2.imread(image_paths_or_video[img_id].replace('rgb', 'norm'), cv2.IMREAD_UNCHANGED)
            norm = cv2.cvtColor(norm, cv2.COLOR_BGR2RGB)
            norm = torch.from_numpy(norm.astype(np.float32) / 255)
            norm_list.append(norm)
    poses_list = torch.stack(poses_list, dim=0)  # (n, 4, 4)
    cam_centers_list = torch.stack(cam_centers_list, dim=0)
    if load_imgs:
        imgs_list = torch.stack(imgs_list, dim=0)  # (n, h, w, 3)
    if load_norm:
        norm_list = torch.stack(norm_list, dim=0)
    return imgs_list, poses_list, cam_centers_list, img_paths_list, smpl_params, norm_list
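
# Sketch (hypothetical helper, not used by the pipeline above): gather_imgs only whitens
# pixels whose alpha is exactly 0, so semi-transparent edge pixels keep their original
# colour. A common alternative is "over" compositing onto a white background, sketched
# here under the assumption that the input is a uint8 BGRA array as returned by
# cv2.imread(path, cv2.IMREAD_UNCHANGED). It returns float32 RGB values in [0, 1].
import numpy as np


def composite_bgra_on_white_sketch(img_bgra: np.ndarray) -> np.ndarray:
    bgr = img_bgra[..., :3].astype(np.float32) / 255.0
    alpha = img_bgra[..., 3:4].astype(np.float32) / 255.0
    # out = colour * alpha + white * (1 - alpha), with white = 1.0
    out_bgr = bgr * alpha + (1.0 - alpha)
    # Reverse the channel axis (BGR -> RGB); copy() removes the negative stride so torch.from_numpy accepts it
    return out_bgr[..., ::-1].copy()

# Usage: img = torch.from_numpy(composite_bgra_on_white_sketch(cv2.imread(path, cv2.IMREAD_UNCHANGED)))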

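from scipy.spatial.transform import Rotation as R


def calculate_angle(vector1, vector2):
    """Return the angle (in radians) between two 3D vectors."""
    unit_vector1 = vector1 / torch.linalg.norm(vector1)
    unit_vector2 = vector2 / torch.linalg.norm(vector2)
    # Clamp to [-1, 1] so that rounding errors cannot make arccos return NaN
    dot_product = torch.clamp(torch.dot(unit_vector1, unit_vector2), -1.0, 1.0)
    angle = torch.arccos(dot_product)
    return angle
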
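def axis_angle_to_rotation_matrix(axis_angle):
    """Convert axis-angle vector(s) to rotation matrices, returning the same type (torch or NumPy) as the input."""
    if_input_is_torch = torch.is_tensor(axis_angle)
    if if_input_is_torch:
        dtype_torch = axis_angle.dtype
        axis_angle = axis_angle.detach().cpu().numpy()
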
    r = R.from_rotvec(axis_angle)
    rotation_matrix = r.as_matrix()
    if if_input_is_torch:
        rotation_matrix = torch.from_numpy(rotation_matrix).to(dtype_torch)

    return rotation_matrix


def find_front_camera_by_rotation(poses, global_orient_human):
    """Return the index of the camera whose rotated +z axis forms the smallest angle with the body's rotated -z axis."""
    rotation_matrix = axis_angle_to_rotation_matrix(global_orient_human)
    front_direction = rotation_matrix @ rotation_matrix.new_tensor([0.0, 0.0, -1.0])  # facing direction of the body
    min_angle = float('inf')
    front_camera_idx = -1

    for idx, pose in enumerate(poses):
        rotation_matrix = pose[0][:3, :3]
        camera_direction = rotation_matrix @ rotation_matrix.new_tensor([0.0, 0.0, 1.0])  # viewing direction of the camera
        angle = calculate_angle(camera_direction, front_direction.to(camera_direction.dtype))
        if angle < min_angle:
            min_angle = angle
            front_camera_idx = idx

    return front_camera_idx
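
# Sketch (hypothetical helpers): the loop in find_front_camera_by_rotation can be written
# as one vectorised NumPy computation. It assumes `rotations` is an (N, 3, 3) array of the
# same pose[0][:3, :3] blocks used above and `global_orient` is a length-3 axis-angle
# vector; Rodrigues' formula stands in for scipy's Rotation.from_rotvec. Because arccos is
# decreasing on [-1, 1], the smallest angle is the largest cosine, so arccos is skipped.
import numpy as np


def rotvec_to_matrix_sketch(rotvec) -> np.ndarray:
    # Rodrigues' rotation formula: R = I + sin(t) K + (1 - cos(t)) K^2
    rotvec = np.asarray(rotvec, dtype=np.float64)
    angle = np.linalg.norm(rotvec)
    if angle < 1e-12:
        return np.eye(3)
    k = rotvec / angle
    K = np.array([[0.0, -k[2], k[1]],
                  [k[2], 0.0, -k[0]],
                  [-k[1], k[0], 0.0]])
    return np.eye(3) + np.sin(angle) * K + (1.0 - np.cos(angle)) * (K @ K)


def find_front_camera_vectorized_sketch(rotations, global_orient) -> int:
    body_front = rotvec_to_matrix_sketch(global_orient) @ np.array([0.0, 0.0, -1.0])
    body_front = body_front / np.linalg.norm(body_front)
    cam_dirs = np.asarray(rotations, dtype=np.float64) @ np.array([0.0, 0.0, 1.0])  # (N, 3)
    cam_dirs = cam_dirs / np.linalg.norm(cam_dirs, axis=1, keepdims=True)
    cosines = np.clip(cam_dirs @ body_front, -1.0, 1.0)
    # np.argmax returns the first maximum, matching the strict "<" of the loop version
    return int(np.argmax(cosines))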

def read_frames(video_path):
    """Decode every frame of the first video stream into a list of (H, W, 3) uint8 RGB arrays."""
    with av.open(video_path) as container:
        video_stream = next(s for s in container.streams if s.type == "video")
        frames = []
        for packet in container.demux(video_stream):
            for frame in packet.decode():
                frames.append(frame.to_rgb().to_ndarray())
    return frames


def prepare_camera(resolution_x, resolution_y, num_views=24, stides=1):
    """Return intrinsics K and camera-to-world poses on a horizontal ring around the origin.

    Poses follow the COLMAP convention (+z forward, +y down); `stides` is the step between view indices.
    """
    import math
    focal_length = 40
    sensor_width = 32

    # Convert the focal length to pixels (sensor_width is mapped onto resolution_y)
    focal_length = focal_length * (resolution_y / sensor_width)

    K = np.array(
        [[focal_length, 0, resolution_x//2],
        [0, focal_length, resolution_y//2],
        [0, 0, 1]]
    )
    def look_at(camera_position, target_position, up_vector):  # COLMAP convention: +z forward, +y down
        forward = (target_position - camera_position) / np.linalg.norm(target_position - camera_position)
        right = np.cross(up_vector, forward)
        right = right / np.linalg.norm(right)
        up = np.cross(forward, right)
        return np.column_stack((right, up, forward))

    camera_pose_list = []
    for frame_idx in range(0, num_views, stides):
        # Place the camera on a horizontal circle (phi = 90 degrees) around the origin
        camera_dist = 1.5
        phi = math.radians(90)
        theta = (frame_idx / num_views) * math.pi * 2
        camera_location = np.array([
            camera_dist * math.sin(phi) * math.cos(theta),
            camera_dist * math.cos(phi),
            -camera_dist * math.sin(phi) * math.sin(theta),
        ])
        camera_pose = np.eye(4)

        # Rotate the camera so that it looks at the origin
        target_position = np.array([0.0, 0.0, 0.0])
        up_vector = np.array([0.0, -1.0, 0.0])  # COLMAP: +y points down
        rotation_matrix = look_at(camera_location, target_position, up_vector)

        # Write the rotation and position into the camera-to-world pose
        camera_pose[:3, :3] = rotation_matrix
        camera_pose[:3, 3] = camera_location
        camera_pose_list.append(camera_pose)
    return K, camera_pose_list
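
# Sketch (hypothetical helpers): prepare_camera returns camera-to-world matrices whose
# rotation columns are the camera's right, up (+y, pointing down) and forward (+z) axes.
# To project world points with K, the pose is inverted to world-to-camera and the pinhole
# model is applied. orbit_c2w_sketch re-derives one pose the same way prepare_camera does
# for phi = 90 degrees; both helpers are illustrations, not part of the original code.
import numpy as np


def orbit_c2w_sketch(theta: float, dist: float = 1.5) -> np.ndarray:
    # Camera on a horizontal circle of radius `dist`, looking at the origin
    position = np.array([dist * np.cos(theta), 0.0, -dist * np.sin(theta)])
    forward = -position / np.linalg.norm(position)          # +z: towards the target
    right = np.cross(np.array([0.0, -1.0, 0.0]), forward)   # same up_vector as prepare_camera
    right = right / np.linalg.norm(right)
    up = np.cross(forward, right)                           # +y: equals world (0, -1, 0), i.e. down
    c2w = np.eye(4)
    c2w[:3, :3] = np.column_stack((right, up, forward))
    c2w[:3, 3] = position
    return c2w


def project_points_sketch(K: np.ndarray, c2w: np.ndarray, points_world: np.ndarray) -> np.ndarray:
    # Invert camera-to-world into world-to-camera, then apply the pinhole model u = K x / z
    w2c = np.linalg.inv(c2w)
    pts_h = np.concatenate([points_world, np.ones((len(points_world), 1))], axis=1)  # (N, 4)
    pts_cam = (w2c @ pts_h.T).T[:, :3]
    uv = (K @ pts_cam.T).T
    return uv[:, :2] / uv[:, 2:3]

# Usage: the look-at target (the origin) projects to the principal point (cx, cy) of K.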


def from_video_to_get_ref_smplx(video_path):
    """Map a generated video path to the JSON file holding its SMPLify-fitted SMPL-X parameters."""
    video_dir = os.path.dirname(video_path)

    # Swap the video directory for the corresponding smplx_smplify directory
    if "flux" in video_dir:
        smplify_dir = video_dir.replace('/videos/', '/smplx_smplify/')
    elif "DeepFashion" in video_dir:
        smplify_dir = video_dir.replace('/video/', '/smplx_smplify/').replace("A_pose_", "")
    else:
        raise ValueError(f"Unsupported video path for SMPL-X lookup: {video_path}")

    # The JSON file shares the directory's name: <smplify_dir>.json
    smplify_dir = smplify_dir.rstrip('/')
    json_path = smplify_dir + ".json"
    return json_path
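
# Sketch (hypothetical helper): the 189-dim SMPL-X vectors used in this file (e.g.
# cond_smpl_param_ref) appear to follow the concatenation order in parse_scene: a leading
# constant (1.0 there, assumed here to be a scale), transl, global_orient, body_pose, betas,
# left/right hand pose, jaw, left/right eye and expression. The sizes are inferred, not
# documented: global_orient = [4:7] and betas = [70:80] match the slicing in
# parse_scene_test, and the two hand poses are assumed to split the remaining 90 values evenly.
import numpy as np

SMPLX_189_LAYOUT_SKETCH = [
    ("scale", 1), ("transl", 3), ("global_orient", 3), ("body_pose", 63), ("betas", 10),
    ("left_hand_pose", 45), ("right_hand_pose", 45),
    ("jaw_pose", 3), ("leye_pose", 3), ("reye_pose", 3), ("expression", 10),
]


def split_smplx_param_sketch(vec) -> dict:
    # Split a flat 189-dim parameter vector into named slices of the flattened array
    vec = np.asarray(vec).reshape(-1)
    total = sum(size for _, size in SMPLX_189_LAYOUT_SKETCH)
    if vec.shape[0] != total:
        raise ValueError(f"expected {total} values, got {vec.shape[0]}")
    parts, start = {}, 0
    for name, size in SMPLX_189_LAYOUT_SKETCH:
        parts[name] = vec[start:start + size]
        start += size
    return parts

# Usage: split_smplx_param_sketch(cond_smpl_param_ref.numpy())["betas"]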

def random_scale_and_crop(image: torch.Tensor, scale_range=(0.8, 1.2)) -> torch.Tensor:
    """
    Randomly scale the input image and crop/pad to maintain original size.

    Args:
        image: Input image of shape [H, W, 3], either a torch tensor or a NumPy array,
            expected to hold values in [0, 1] (padding uses 1.0, i.e. white)
        scale_range: Range for scaling factor, default (0.8, 1.2)

    Returns:
        Scaled and cropped/padded image of shape [H, W, 3], of the same type as the input
    """
    is_numpy = False
    if not torch.is_tensor(image):
        image = torch.from_numpy(image)
        is_numpy = True
    # Image height and width
    h, w = image.shape[:2]

    # Random scale factor
    scale_factor = random.uniform(*scale_range)

    # New height and width
    new_h = int(h * scale_factor)
    new_w = int(w * scale_factor)

    # Resize with torchvision.transforms.functional.resize, which expects channels-first input
    scaled_image = F.resize(image.permute(2, 0, 1), [new_h, new_w]).permute(1, 2, 0)

    # If the scaled image is larger than the original, center-crop it back to (h, w)
    if new_h > h or new_w > w:
        top = (new_h - h) // 2
        left = (new_w - w) // 2
        scaled_image = scaled_image[top:top + h, left:left + w]
    else:
        # If it is smaller, paste it onto a white canvas: centered horizontally, aligned to the bottom vertically
        padded_image = torch.ones((h, w, 3), dtype=image.dtype)
        top = h - new_h
        left = (w - new_w) // 2
        padded_image[top:top + new_h, left:left + new_w] = scaled_image
        scaled_image = padded_image
    if is_numpy:
        scaled_image = scaled_image.numpy()
    return scaled_image
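
# Sketch (hypothetical helper): a deterministic NumPy version of random_scale_and_crop for
# a fixed `scale`, to make its geometry explicit. Upscaled images are centre-cropped back
# to (H, W); downscaled images are pasted onto a white (1.0) canvas, centred horizontally
# and aligned to the bottom edge. Nearest-neighbour sampling replaces torchvision's resize,
# so pixel values can differ slightly from the original function.
import numpy as np


def scale_and_pad_sketch(image: np.ndarray, scale: float) -> np.ndarray:
    h, w = image.shape[:2]
    new_h, new_w = int(h * scale), int(w * scale)
    # Nearest-neighbour source indices for each output row and column
    rows = np.minimum((np.arange(new_h) / scale).astype(int), h - 1)
    cols = np.minimum((np.arange(new_w) / scale).astype(int), w - 1)
    scaled = image[rows][:, cols]
    if new_h > h or new_w > w:
        top, left = (new_h - h) // 2, (new_w - w) // 2
        return scaled[top:top + h, left:left + w]
    out = np.ones((h, w, image.shape[2]), dtype=image.dtype)
    top, left = h - new_h, (w - new_w) // 2  # bottom-aligned, horizontally centred
    out[top:top + new_h, left:left + new_w] = scaled
    return out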

        

if __name__ == "__main__":


    import os

          

    params = {
        "data_prefix": None,
        "cache_path":  ListConfig([ 
            "./processed_data/deepfashion_train_145_local.npy",
            "./processed_data/flux_batch1_5000_train_145_local.npy"
        ]),
        "specific_observation_num": 5,
        "better_range": True,
        "first_is_front": True,
        "if_include_video_ref_img": True,
        "prob_include_video_ref_img": 0.5,
        "img_res": [640, 896],
        'test_mode': True
    }

    data = AvatarDataset(**params)

    sample = data[0]
    print(sample.keys())


    import torch.distributed as dist

    os.environ['RANK'] = '0'
    os.environ['WORLD_SIZE'] = '1'
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'

    dist.init_process_group(backend='nccl', rank=0, world_size=1)


    # test the batch loader
    from torch.utils.data import DataLoader
    dataloader = DataLoader(data, batch_size=10, shuffle=True, collate_fn=custom_collate_fn)


    from torch.utils.data.distributed import DistributedSampler
    import webdataset as wds
    # For training, pass sampler=DistributedSampler(data) instead
    sampler = None
    dataloader = wds.WebLoader(data, batch_size=10, num_workers=1, shuffle=False, sampler=sampler)

        
    try:
        for i, batch in enumerate(dataloader):
            print(batch.keys())
            # break
    except Exception as e:
        import traceback
        print("Caught an exception during dataloader iteration:")
        traceback.print_exc()


================================================
FILE: lib/datasets/dataloader.py
================================================

import os, sys
import json


import numpy as np
import webdataset as wds
import pytorch_lightning as pl
import torch
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler

file_dir = os.path.abspath(os.path.dirname(__file__))
project_root = os.path.join(file_dir, '..', '..')
sys.path.append(project_root)  
from lib.utils.train_util import instantiate_from_config
from torch.utils.data import DataLoader 

class DataModuleFromConfig(pl.LightningDataModule):
    def __init__(
        self, 
        batch_size=8, 
        num_workers=4, 
        train=None, 
        validation=None, 
        test=None, 
        **kwargs,
    ):
        super().__init__()

        self.batch_size = batch_size
        self.num_workers = num_workers

        self.dataset_configs = dict()
        if train is not None:
            self.dataset_configs['train'] = train
        if validation is not None:
            self.dataset_configs['validation'] = validation
        if test is not None:
            self.dataset_configs['test'] = test

    def setup(self, stage):
        self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)

    def train_dataloader(self):
        sampler = DistributedSampler(self.datasets['train']) if torch.distributed.is_initialized() else None
        return DataLoader(
            self.datasets['train'],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=(sampler is None),
            sampler=sampler,
            pin_memory=True,
            drop_last=True,
        )

    def val_dataloader(self):
        sampler = DistributedSampler(self.datasets['validation']) if torch.distributed.is_initialized() else None
        return DataLoader(
            self.datasets['validation'],
            batch_size=1,
            num_workers=self.num_workers,
            shuffle=False,
            sampler=sampler,
            pin_memory=True
        )

    def test_dataloader(self):
        return DataLoader(
            self.datasets['test'],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            pin_memory=True
        )
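
# Sketch (hypothetical, assumed schema): DataModuleFromConfig passes each dataset config to
# lib.utils.train_util.instantiate_from_config, whose implementation is not shown here.
# Many PyTorch Lightning codebases use a {"target": "package.module.Class", "params": {...}}
# schema; the helper below assumes that convention and is not the project's own function.
import importlib


def instantiate_from_config_sketch(config):
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    # Split "package.module.Class" into the module path and the attribute name
    module_name, class_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**config.get("params", dict()))

# Usage: instantiate_from_config_sketch({"target": "datetime.date", "params": {"year": 2024, "month": 1, "day": 2}})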


================================================
FILE: lib/humanlrm_wrapper_sa_v1.py
================================================

import os
import math
import json
from torch.optim import Adam
from torch.nn.parallel.distributed import DistributedDataParallel
import torch
import torch.nn.functional as F
from torchvision.transforms import InterpolationMode
from torchvision.utils import make_grid, save_image
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure
import pytorch_lightning as pl
from pytorch_lightning.utilities.grads import grad_norm
from einops import rearrange, repeat

from lib.utils.train_util import instantiate_from_config
from lib.ops.activation import TruncExp
import time
import matplotlib.pyplot as plt

from PIL import Image

import numpy as np 
from lib.utils.train_util import main_print

from typing import List, Optional, Tuple, Union
def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[torch.Tensor, int],
    theta: float = 10000.0,
    use_real=False,
    linear_factor=1.0,
    ntk_factor=1.0,
    repeat_interleave_real=True,
):
    """
    Precompute the rotary (RoPE) frequency tensor for the given dimension and positions.

    This function computes the rotation angles `pos * freq` for `dim // 2` frequencies derived
    from `theta` (scaled by `ntk_factor` and `linear_factor`). It returns either the complex
    exponentials (complex64) or their real cosine and sine parts.

    Args:
        dim (`int`): Dimension of the frequency tensor. Must be even.
        pos (`torch.Tensor` or `int`): Position indices for the frequency tensor. [S] or scalar
        theta (`float`, *optional*, defaults to 10000.0):
            Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (`bool`, *optional*):
            If True, return real part and imaginary part separately. Otherwise, return complex numbers.
        linear_factor (`float`, *optional*, defaults to 1.0):
            Scaling factor for the context extrapolation. Defaults to 1.0.
        ntk_factor (`float`, *optional*, defaults to 1.0):
            Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
        repeat_interleave_real (`bool`, *optional*, defaults to `True`):
            If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
            Otherwise, they are concatenated with themselves.
    Returns:
        If `use_real`, a `(cos, sin)` pair of tensors, each [S, D]; otherwise the complex
        frequency tensor (complex64) of shape [S, D/2].
    """
    assert dim % 2 == 0

    if isinstance(pos, int):
        pos = torch.arange(pos)
    # freqs is cast to the dtype of the positions below, so integer positions are promoted to
    # float first; otherwise every frequency below 1 would be truncated to zero
    t = pos if torch.is_floating_point(pos) else pos.float()  # [S]
    theta = theta * ntk_factor
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) / linear_factor  # [D/2]
    freqs = freqs.to(device=t.device, dtype=t.dtype)
    freqs = torch.outer(t, freqs).float()  # type: ignore   # [S, D/2]
    if use_real and repeat_interleave_real:
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
        return freqs_cos, freqs_sin
    elif use_real:
        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1)  # [S, D]
        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1)  # [S, D]
        return freqs_cos, freqs_sin
    else:
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64     # [S, D/2]
        return freqs_cis
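
# Sketch (illustrative only, not called anywhere in this repo): one common way to apply the
# real-valued tables returned by get_1d_rotary_pos_embed(..., use_real=True,
# repeat_interleave_real=True). Assumption: features are laid out as interleaved pairs
# (x0, x1), (x2, x3), ..., which is the layout that repeat_interleave(2) on cos/sin implies.
# Each pair is rotated by the angle pos * freq_k. The helper names below are hypothetical.
import torch


def _sketch_rope_tables(dim: int, num_pos: int, theta: float = 10000.0):
    """Rebuild the interleaved (cos, sin) tables, each of shape [S, D]."""
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))  # [D/2]
    angles = torch.outer(torch.arange(num_pos).float(), freqs)        # [S, D/2]
    return angles.cos().repeat_interleave(2, dim=1), angles.sin().repeat_interleave(2, dim=1)


def _sketch_apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate each interleaved pair (a, b) of x ([S, D]) to (a*cos - b*sin, b*cos + a*sin)."""
    pairs = x.reshape(*x.shape[:-1], -1, 2)                                   # [S, D/2, 2]
    x_rot = torch.stack([-pairs[..., 1], pairs[..., 0]], dim=-1).flatten(-2)  # (a, b) -> (-b, a)
    return x * cos + x_rot * sin

# Each pair is only rotated, so per-token norms are preserved and position 0 is left unchanged.
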
class FluxPosEmbed(torch.nn.Module):
    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
    def __init__(self, theta: int, axes_dim: List[int]):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
        # ids: [S, n_axes] positions. Axis i gets a rotary table of width axes_dim[i]; the per-axis
        # cos/sin tables are concatenated, giving two [S, sum(axes_dim)] tensors.
        # Note: self.theta is not forwarded, so the default theta=10000.0 is used for every axis.
        n_axes = ids.shape[-1]
        cos_out = []
        sin_out = []
        pos = ids.float()
        for i in range(n_axes):
            cos, sin = get_1d_rotary_pos_embed(
                self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True
            )
            cos_out.append(cos)
            sin_out.append(sin)
        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
        return freqs_cos, freqs_sin
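
# Sketch (illustrative only): what FluxPosEmbed computes for multi-axis ids. Each row of `ids`
# holds one position per axis, e.g. (y, x) for a 2-D token grid. Axis i contributes an
# interleaved cos/sin table of width axes_dim[i], and the tables are concatenated, so the output
# width is sum(axes_dim). `_sketch_multi_axis_rope` is a hypothetical stand-in that reimplements
# only the use_real=True, repeat_interleave_real=True path with the default theta.
import torch


def _sketch_multi_axis_rope(ids: torch.Tensor, axes_dim, theta: float = 10000.0):
    cos_out, sin_out = [], []
    pos = ids.float()
    for i, dim in enumerate(axes_dim):
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))  # [dim/2]
        angles = torch.outer(pos[:, i], freqs)                            # [S, dim/2]
        cos_out.append(angles.cos().repeat_interleave(2, dim=1))          # [S, dim]
        sin_out.append(angles.sin().repeat_interleave(2, dim=1))
    return torch.cat(cos_out, dim=-1), torch.cat(sin_out, dim=-1)        # [S, sum(axes_dim)]

# Example: a 2x3 (y, x) grid flattened to 6 tokens with axes_dim=[4, 8] yields two [6, 12]
# tables. Tokens in the same row share the first 4 columns (the y-axis part).
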
    
class SapiensGS_SA_v1(pl.LightningModule):
    
    def __init__(
        self,
        encoder=dict(type='mmpretrain.VisionTransformer'),
        neck=dict(type='mmpretrain.VisionTransformer'),
        decoder=dict(),
        diffusion_use_ema=True,
        freeze_decoder=False,
        image_cond=False,
        code_permute=None,
        code_reshape=None,
        autocast_dtype=None,
        ortho=True,
        return_norm=False,
        # reshape_type='reshape', # 'cnn'
        code_size=None,
        decoder_use_ema=None,
        bg_color=1,
        training_mode=None, # stage2's flag, default None for stage1 

        patch_size: int = 4,
            
        warmup_steps: int = 12_000,
        use_checkpoint: bool = True,
        lambda_depth_tv: float = 0.05,
        lambda_lpips: float = 2.0,
        lambda_mse: float = 1.0,
        lambda_l1: float=0, 
        lambda_ssim: float=0, 
        neck_learning_rate: float = 5e-4,
        decoder_learning_rate: float = 1e-3,
        encoder_learning_rate: float=0,
        max_steps: int = 100_000,
        loss_coef: float = 0.5,
        init_iter: int = 500,
        lambda_offset: int = 50,  # offset_weight: 50
        scale_weight: float = 0.01,
        is_debug: bool = False,  # debug mode: skip the LPIPS loss and dump debug images
        code_activation: dict = None,  # {'type': 'tanh'} selects tanh; other types use TruncExp
        output_hidden_states: bool = False,  # if True, the encoder also returns shallow Sapiens hidden states (skip connection for the neck)
        loss_weights_views: List = [],  # per-view weights for the MSE loss (normalized to sum to 1); if empty, all views are weighted equally
        **kwargs
    ):
        super(SapiensGS_SA_v1, self).__init__()
        ## ========== store the loss / optimizer hyper-parameters ========
        self.warmup_steps = warmup_steps
        self.use_checkpoint = use_checkpoint
        self.lambda_depth_tv = lambda_depth_tv
        self.lambda_lpips = lambda_lpips
        self.lambda_mse = lambda_mse
        self.lambda_l1 = lambda_l1
        self.lambda_ssim = lambda_ssim  
        self.neck_learning_rate = neck_learning_rate
        self.decoder_learning_rate = decoder_learning_rate
        self.encoder_learning_rate = encoder_learning_rate
        self.max_steps = max_steps

        self.loss_coef = loss_coef
        self.init_iter = init_iter
        self.lambda_offset = lambda_offset
        self.scale_weight = scale_weight

        self.is_debug = is_debug
        ## ========== end part ========

     
        
        self.code_size = code_size
        # Activation applied to the predicted UV code: tanh when {'type': 'tanh'} is given, TruncExp otherwise.
        if code_activation is not None and code_activation.get('type') == 'tanh':
            self.code_activation = torch.nn.Tanh()
        else:
            self.code_activation = TruncExp()
        self.decoder = instantiate_from_config(decoder)
        self.decoder_use_ema = decoder_use_ema
        if decoder_use_ema:
            raise NotImplementedError("decoder_use_ema has not been implemented")
        self.encoder = instantiate_from_config(encoder)

        self.code_size = code_reshape
        self.code_clip_range = [-1,1]
        
        # ============= begin config  =============
        # The neck (a transformer modeled on MAEPretrainDecoder) predicts one token per
        # patch_size x patch_size patch of the UV code, compressing the number of output tokens.
        self.patch_size = patch_size
        self.code_patch_size = self.patch_size
        self.num_patches_axis = code_reshape[-1]//self.patch_size # reshape it for the upsampling
        self.num_patches = self.num_patches_axis ** 2
        self.code_feat_dims = code_reshape[0] # only used for the upsampling of 'reshape' type 
        self.code_resolution = code_reshape[-1] # only used for the upsampling of 'reshape' type 

        self.reshape_type = self.decoder.reshape_type

        
        self.inputs_front_only = True
        self.render_loss_all_view = True
        self.if_include_video_ref_img = True
        
        self.training_mode = training_mode

        self.loss_weights_views = torch.Tensor(loss_weights_views).reshape(-1) / sum(loss_weights_views)  # normalize the weights
        
        

        # ========== neck and code-layout settings ===========
        self.neck =  instantiate_from_config(neck)

        self.ids_restore = torch.arange(0, self.num_patches).unsqueeze(0)
        self.freeze_decoder = freeze_decoder
        if self.freeze_decoder:
            self.decoder.requires_grad_(False)
            if self.decoder_use_ema:
                self.decoder_ema.requires_grad_(False)
        self.image_cond = image_cond
        self.code_permute = code_permute
        self.code_reshape = code_reshape
        self.code_reshape_inv = [self.code_size[axis] for axis in self.code_permute] if code_permute is not None \
            else self.code_size
        self.code_permute_inv = [self.code_permute.index(axis) for axis in range(len(self.code_permute))] \
            if code_permute is not None else None

        self.autocast_dtype = autocast_dtype
        self.ortho = ortho
        self.return_norm = return_norm

        # flag for the skip connection from the shallow Sapiens layers
        self.output_hidden_states = output_hidden_states

        # perceptual loss (LPIPS with a VGG backbone); only built when it is used
        if self.lambda_lpips > 0:
            self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg')
        else:
            self.lpips = None

        self.ssim = StructuralSimilarityIndexMeasure()
        
        
        self.validation_step_outputs = []
        self.validation_step_code_outputs = [] # saving the code
        self.validation_step_nvPose_outputs = []
        self.validation_metrics = []

        
        # Load a reference SMPL-X motion sequence; validation re-poses each avatar with it
        # (see forward_nvPose) to visualize novel poses.
        import json
        import numpy as np
        smplx_path = './work_dirs/demo_data/Ways_to_Catch_360_clip1.json'
        with open(smplx_path, 'r') as f:
            smplx_pose_param = json.load(f)
        smplx_param_list = []
        for par in smplx_pose_param['annotations']:
            k = {key: np.array(val) for key, val in par['smplx_params'].items()}
            betas = np.zeros(10)  # neutral shape; forward_nvPose only copies the pose/expression slices
            # 189 values: [1, zeros(3), root_orient(3), pose_body(63), betas(10),
            #              pose_hand(90), pose_jaw(3), zeros(6), face_expr(10)]
            smplx_param = np.concatenate([
                np.array([1]), np.array([0., 0., 0.]), np.array([0, -1, 0]) * k['root_orient'],
                k['pose_body'], betas,
                k['pose_hand'], k['pose_jaw'], np.zeros(6), k['face_expr'][:10]], axis=0).reshape(1, -1)
            smplx_param_list.append(smplx_param)
        smplx_params = np.concatenate(smplx_param_list, 0)
        # Non-persistent buffer: moves with the module's device and is not written to checkpoints.
        self.register_buffer('smplx_params', torch.Tensor(smplx_params), persistent=False)

    def get_default_smplx_params(self):
        """Return an A-pose SMPL-X parameter vector of shape (1, 189), same layout as self.smplx_params."""
        A_pose = torch.Tensor([[ 1.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.1047,  0.0000,  0.0000, -0.1047,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.7854,  0.0000,
            0.0000,  0.7854,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.7470,  1.0966,
            0.0169, -0.0534, -0.0212,  0.0782, -0.0348,  0.0260,  0.0060,  0.0118,
            -0.1117, -0.0429,  0.4164, -0.1088,  0.0660,  0.7562,  0.0964,  0.0909,
            0.1885,  0.1181, -0.0509,  0.5296,  0.1437, -0.0552,  0.7049,  0.0192,
            0.0923,  0.3379,  0.4570,  0.1963,  0.6255,  0.2147,  0.0660,  0.5069,
            0.3697,  0.0603,  0.0795,  0.1419,  0.0859,  0.6355,  0.3033,  0.0579,
            0.6314,  0.1761,  0.1321,  0.3734, -0.8510, -0.2769,  0.0915,  0.4998,
            -0.0266, -0.0529, -0.5356, -0.0460,  0.2774, -0.1117,  0.0429, -0.4164,
            -0.1088, -0.0660, -0.7562,  0.0964, -0.0909, -0.1885,  0.1181,  0.0509,
            -0.5296,  0.1437,  0.0552, -0.7049,  0.0192, -0.0923, -0.3379,  0.4570,
            -0.1963, -0.6255,  0.2147, -0.0660, -0.5069,  0.3697, -0.0603, -0.0795,
            0.1419, -0.0859, -0.6355,  0.3033, -0.0579, -0.6314,  0.1761, -0.1321,
            -0.3734, -0.8510,  0.2769, -0.0915,  0.4998,  0.0266,  0.0529, -0.5356,
            0.0460, -0.2774,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000,  0.0000,  0.0000,  0.0000]])
        return A_pose
    def forward_decoder(self, decoder, code, target_rgbs, cameras,   
            smpl_params=None, return_decoder_loss=False, init=False):
        decoder = self.decoder_ema if self.freeze_decoder and self.decoder_use_ema else self.decoder
        num_imgs = target_rgbs.shape[1]
        outputs = decoder(
            code, smpl_params, cameras,
            num_imgs, return_loss=return_decoder_loss, init=init, return_norm=False)
        return outputs

    def on_fit_start(self):
        if self.global_rank == 0:
            os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True)
            os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True)
            os.makedirs(os.path.join(self.logdir, 'images_val_code'), exist_ok=True)
    


    def forward(self, data):
        num_scenes = len(data['scene_id'])
        if 'cond_imgs' not in data:
            raise KeyError("forward() expects 'cond_imgs' (and the matching cond_* entries) in the batch")
        cond_imgs = data['cond_imgs']  # (num_scenes, num_imgs, h, w, 3)
        cond_intrinsics = data['cond_intrinsics']  # (num_scenes, num_imgs, 4), in [fx, fy, cx, cy]
        cond_poses = data['cond_poses']  # (num_scenes, num_imgs, 4, 4)
        smpl_params = data['cond_smpl_param']  # (num_scenes, c)
        num_scenes, num_imgs, h, w, _ = cond_imgs.size()
        cameras = torch.cat([cond_intrinsics, cond_poses.reshape(num_scenes, num_imgs, -1)], dim=-1)
        if self.if_include_video_ref_img:
            # View 0 is the input (reference) image: drop its camera so that only the
            # remaining target views are rendered and supervised.
            cameras = cameras[:, 1:]
            num_imgs = num_imgs - 1


        if self.inputs_front_only: # default setting to use the first image as input
            inputs_img_idx = [0]
        else:
            raise NotImplementedError("inputs_front_only is False")
        inputs_img = cond_imgs[:, inputs_img_idx[0], ...].permute([0, 3, 1, 2])  # (num_scenes, 3, h, w)
        
        target_imgs = cond_imgs[:, 1:]
        assert cameras.shape[1] == target_imgs.shape[1]


        if self.is_debug:
            # In debug mode, fall back to an all-zero UV code (e.g. after an OOM) instead of crashing.
            try:
                code = self.forward_image_to_uv(inputs_img, is_training=self.training)  # TODO: check this path during validation
            except Exception as e:
                main_print(e)
                code = torch.zeros([num_scenes, 32, 256, 256], dtype=inputs_img.dtype, device=inputs_img.device)
        else:
            code = self.forward_image_to_uv(inputs_img, is_training=self.training)  # TODO: check this path during validation

        decoder = self.decoder_ema if self.freeze_decoder and self.decoder_use_ema else self.decoder
        # uvmaps_decoder_gender's forward
        output = decoder(
            code, smpl_params, cameras,
            num_imgs, return_loss=False, init=(self.global_step < self.init_iter), return_norm=False) #(['scales', 'norm', 'image', 'offset'])
     
        output['code'] = code
        output['target_imgs'] = target_imgs
        output['inputs_img'] = cond_imgs[:,[0],...]
       
        # for visualization
        if self.global_rank == 0 and self.global_step % 200 == 0 and self.is_debug:
            overlay_imgs = 0.5 * target_imgs + 0.5 * output['image']
            overlay_imgs = rearrange(overlay_imgs, 'b n h w c -> b h n w c')
            overlay_imgs = rearrange(overlay_imgs, ' b h n w c -> (b h) (n w) c')
            overlay_imgs = overlay_imgs.to(torch.float32).detach().cpu().numpy()
            overlay_imgs = (overlay_imgs * 255).astype(np.uint8)
            Image.fromarray(overlay_imgs).save(f'debug_{self.global_step}.jpg')
        
        return output

    def forward_image_to_uv(self, inputs_img, is_training=True):
        '''
            inputs_img: torch.Tensor, (bs, 3, H, W)
            returns
            code: torch.Tensor, channel-first, e.g. (bs, 32, 256, 256)
        '''
        if self.encoder_learning_rate <= 0:
            # The encoder is frozen (not added to the optimizer), so skip building its autograd graph.
            with torch.no_grad():
                features_flatten = self.encoder(inputs_img, use_my_proces=True, output_hidden_states=self.output_hidden_states)
        else:
            features_flatten = self.encoder(inputs_img, use_my_proces=True, output_hidden_states=self.output_hidden_states)
        
        if self.ids_restore.device !=features_flatten.device:
            self.ids_restore = self.ids_restore.to(features_flatten.device)
        ids_restore = self.ids_restore.expand([features_flatten.shape[0], -1])
        uv_code =  self.neck(features_flatten, ids_restore)
        batch_size, token_num, dims_feature = uv_code.shape
        
        if self.reshape_type=='reshape':
            feature_map = uv_code.reshape(batch_size, self.num_patches_axis, self.num_patches_axis,\
                            self.code_feat_dims, self.code_patch_size, self.code_patch_size) # torch.Size([1, 64, 64, 32, 4, 4, ])  
            feature_map = feature_map.permute(0, 3, 1, 4, 2, 5)   # ([1, 32, 64, 4, 64, 4])
            feature_map = feature_map.reshape(batch_size, self.code_feat_dims,  self.code_resolution, self.code_resolution) # torch.Size([1, 32, 256, 256])
            code = feature_map # [1, 32, 256, 256]
        else:
            feature_map = uv_code.reshape(batch_size, self.num_patches_axis, self.num_patches_axis,dims_feature) # torch.Size([1, 64, 64, 512, ])  
            if isinstance(self.decoder, DistributedDataParallel):
                code = self.decoder.module.upsample_conv(feature_map.permute([0,3,1,2])) # torch.Size([1, 32, 256, 256])
            else:
                code = self.decoder.upsample_conv(feature_map.permute([0,3,1,2])) # torch.Size([1, 32, 256, 256])

        code = self.code_activation(code)
        return code

    def compute_loss(self, render_out):
        render_images = render_out['image']  # (B, N, H, W, 3), e.g. (1, 5, 896, 640, 3), range [0, 1]
        target_images = render_out['target_imgs'].to(render_images)
        if self.is_debug:
            # Dump a 50/50 overlay of renders and targets for a quick visual check.
            render_images_tmp = rearrange(render_images, 'b n h w c -> (b n) c h w')
            target_images_tmp = rearrange(target_images, 'b n h w c -> (b n) c h w')
            all_images = render_images_tmp * 0.5 + target_images_tmp * 0.5
            grid = make_grid(all_images, nrow=4, normalize=True, value_range=(0, 1))
            save_image(grid, "./debug.png")
            main_print("saving into ./debug.png")

        render_images = rearrange(render_images, 'b n h w c -> (b n) c h w') * 2.0 - 1.0
        target_images = rearrange(target_images, 'b n h w c -> (b n) c h w') * 2.0 - 1.0
        if self.lambda_mse<=0:
            loss_mse = 0
        else:
            if self.loss_weights_views.numel() != 0:
                b, n, _, _, _ = render_out['image'].shape
                loss_weights_views = self.loss_weights_views.unsqueeze(0).to(render_images.device)
                loss_weights_views = loss_weights_views.repeat(b,1).reshape(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
                loss_mse = weighted_mse_loss(render_images, target_images, loss_weights_views)
            else:
                loss_mse = F.mse_loss(render_images, target_images)

        if self.lambda_l1<=0:
            loss_l1 = 0
        else:
            loss_l1 = F.l1_loss(render_images, target_images)

        if self.lambda_ssim <= 0:
            loss_ssim = 0
        else:
            loss_ssim = 1 - self.ssim(render_images, target_images)
        if not self.is_debug:
            if self.lambda_lpips<=0:
                loss_lpips = 0
            else:
                if self.loss_weights_views.numel() != 0:
                    with torch.cuda.amp.autocast():
                        loss_lpips = self.lpips(render_images.clamp(-1, 1), target_images)
                else:
                    loss_lpips = 0
                    with torch.cuda.amp.autocast():
                        for img_idx in range(render_images.shape[0]):
                            loss_lpips += self.lpips(render_images[[img_idx]].clamp(-1, 1), target_images[[img_idx]])
                    loss_lpips /= render_images.shape[0]
                    
        else:
            loss_lpips = 0
        loss_gs_offset = render_out['offset']
        loss = loss_mse * self.lambda_mse \
            + loss_l1 * self.lambda_l1 \
            + loss_ssim * self.lambda_ssim \
            + loss_lpips * self.lambda_lpips \
            + loss_gs_offset * self.lambda_offset
        
        prefix = 'train'
        loss_dict = {}
        loss_dict.update({f'{prefix}/loss_mse': loss_mse})
        loss_dict.update({f'{prefix}/loss_lpips': loss_lpips})
        loss_dict.update({f'{prefix}/loss_gs_offset': loss_gs_offset})
        loss_dict.update({f'{prefix}/loss_ssim': loss_ssim})
        loss_dict.update({f'{prefix}/loss_l1': loss_l1})
        loss_dict.update({f'{prefix}/loss': loss})

        return loss, loss_dict
    
    def compute_metrics(self, render_out):
        # NOTE: all the rgb value range  is [0, 1]
        # render_out.keys = (['scales', 'norm', 'image', 'offset', 'code', 'target_imgs'])
        render_images = render_out['image'].clamp(0, 1) # .Size([1, 5, 896, 640, 3]), range [0, 1]
        target_images = render_out['target_imgs']
        if target_images.dtype!=render_images.dtype:
            target_images = target_images.to(render_images.dtype)

        render_images = rearrange(render_images, 'b n h w c -> (b n) c h w')
        target_images = rearrange(target_images, 'b n h w c -> (b n) c h w').to(render_images)

        mse = F.mse_loss(render_images, target_images).mean()
        psnr = 10 * torch.log10(1.0 / mse)
        ssim = self.ssim(render_images, target_images)
        
        render_images = render_images * 2.0 - 1.0
        target_images = target_images * 2.0 - 1.0

        if self.lambda_lpips<=0:
            lpips = torch.Tensor([0]).to(render_images.device).to(render_images.dtype)
        else:
            with torch.cuda.amp.autocast():
                lpips = self.lpips(render_images, target_images)

        metrics = {
            'val/mse': mse,
            'val/psnr': psnr,
            'val/ssim': ssim,
            'val/lpips': lpips,
        }
        return metrics

    def new_on_before_optimizer_step(self):
        """Log the total L2 gradient norm of the neck (called manually from training_step)."""
        norms = grad_norm(self.neck, norm_type=2)
        if 'grad_2.0_norm_total' in norms:
            self.log_dict({'grad_norm/lrm_generator': norms['grad_2.0_norm_total']})

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        render_out = self.forward(batch)

        metrics = self.compute_metrics(render_out)
        self.validation_metrics.append(metrics)
        render_images = render_out['image']
        render_images = rearrange(render_images, 'b n h w c -> b c h (n w)')
        gt_images = render_out['target_imgs']
        gt_images = rearrange(gt_images, 'b n h w c-> b c h (n w)')
        log_images = torch.cat([render_images, gt_images], dim=-2)
        self.validation_step_outputs.append(log_images)

        self.validation_step_code_outputs.append( render_out['code'])

        render_out_comb = self.forward_nvPose(batch, smplx_given=None)
        self.validation_step_nvPose_outputs.append(render_out_comb)
       
      
    def forward_nvPose(self, batch, smplx_given):
        '''
            Re-render the avatars under novel poses.
            smplx_given: torch.Tensor of shape (num_poses, 189), or None to use self.smplx_params.
            Returns images of shape (b, c, num_img * h, num_views * w): one row per pose, one column per rendered view.
        '''
        _, num_img, _, _ = batch['cond_poses'].shape
        # Unless poses are given, pick num_img evenly strided poses from the reference sequence.
        if smplx_given is None:
            step_pose = self.smplx_params.shape[0] // num_img
            smplx_given = self.smplx_params
        else:
            step_pose = 1
        orig_smpl_param = batch['cond_smpl_param'].clone()  # restored after re-posing
        render_out_list = []
        for i in range(num_img):
            target_pose = smplx_given[[i * step_pose]]
            batch['cond_smpl_param'][:, 7:70] = target_pose[:, 7:70]  # copy body_pose
            batch['cond_smpl_param'][:, 80:80+93] = target_pose[:, 80:80+93]  # copy pose_hand + pose_jaw
            batch['cond_smpl_param'][:, 179:189] = target_pose[:, 179:189]  # copy face expression
            render_out_new = self.forward(batch)
            render_out_list.append(render_out_new['image'])
        batch['cond_smpl_param'].copy_(orig_smpl_param)
        render_out_comb = torch.cat(render_out_list, dim=2)  # stack the poses along the H axis
        render_out_comb = rearrange(render_out_comb, 'b n h w c -> b c h (n w)')
        return render_out_comb

    
    def on_validation_epoch_end(self): #
        images = torch.cat(self.validation_step_outputs, dim=-1)
        all_images = self.all_gather(images).cpu()
        all_images = rearrange(all_images, 'r b c h w -> (r b) c h w')

        # nv pose
        images_pose = torch.cat(self.validation_step_nvPose_outputs, dim=-1)
        all_images_pose = self.all_gather(images_pose).cpu()
        all_images_pose = rearrange(all_images_pose, 'r b c h w -> (r b) c h w')

        if self.global_rank == 0:
            image_path = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png')

            grid = make_grid(all_images, nrow=1, normalize=True, value_range=(0, 1))
            save_image(grid, image_path)
            main_print(f"Saved image to {image_path}")

            metrics = {}
            for key in self.validation_metrics[0].keys():
                metrics[key] = torch.stack([m[key] for m in self.validation_metrics]).mean()
            self.log_dict(metrics, prog_bar=True, logger=True, on_step=False, on_epoch=True)


            # code for saving the nvPose images
            image_path_nvPose = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}_nvPose.png')
            grid_nvPose = make_grid(all_images_pose, nrow=1, normalize=True, value_range=(0, 1))
            save_image(grid_nvPose, image_path_nvPose)
            main_print(f"Saved image to {image_path_nvPose}")


            # save each UV code as a 4 x 8 grid of its channels (assumes 32 code channels)
            for i, code in enumerate(self.validation_step_code_outputs):
                image_path = os.path.join(self.logdir, 'images_val_code')
                
                num_scenes, num_chn, h, w = code.size()
                code_viz = code.reshape(num_scenes, 4, 8, h, w).to(torch.float32).cpu().numpy()
                code_viz = code_viz.transpose(0, 1, 3, 2, 4).reshape(num_scenes, 4 * h, 8 * w)
                for j, code_viz_single in enumerate(code_viz):
                    plt.imsave(os.path.join(image_path, f'val_{self.global_step:07d}_{i*num_scenes+j:04d}' + '.png'), code_viz_single,
                        vmin=self.code_clip_range[0], vmax=self.code_clip_range[1])
        self.validation_step_outputs.clear()
        self.validation_step_nvPose_outputs.clear()
        self.validation_metrics.clear()
        self.validation_step_code_outputs.clear()
    
    def configure_optimizers(self):
        # One Adam optimizer with separate parameter groups (and learning rates) for the neck and decoder.
        main_print("WARNING: currently only a single optimizer is supported for both the neck and the decoder")

        params = [
            {'params': self.neck.parameters(), 'lr': self.neck_learning_rate},
            {'params': self.decoder.parameters(), 'lr': self.decoder_learning_rate},
        ]
        # The encoder is fine-tuned only when a positive encoder learning rate is given.
        if self.encoder_learning_rate > 0:
            params.append({'params': self.encoder.parameters(), 'lr': self.encoder_learning_rate})
            main_print("============add the encoder into the optimizer============")
        optimizer = torch.optim.Adam(
            params
        )
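        # Multiplier on each group's base lr: half-cosine warm-up from eta_min to 1 over T_warmup steps,
        # then half-cosine decay back to eta_min at T_max, held there afterwards.
        T_warmup, T_max, eta_min = self.warmup_steps, self.max_steps, 0.001
        lr_lambda = lambda step: \
            eta_min + (1 - math.cos(math.pi * step / T_warmup)) * (1 - eta_min) * 0.5 if step < T_warmup else \
            eta_min + (1 + math.cos(math.pi * (min(step, T_max) - T_warmup) / (T_max - T_warmup))) * (1 - eta_min) * 0.5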
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
        return  {'optimizer': optimizer, 'lr_scheduler': scheduler}
    
    def training_step(self, batch, batch_idx):
        scheduler = self.lr_schedulers()
        scheduler.step()
        render_out = self.forward(batch)
        loss, loss_dict = self.compute_loss(render_out)


        self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
       
        if self.global_step % 200 == 0 and self.global_rank == 0:
            self.new_on_before_optimizer_step()  # log the gradient norm
            if self.if_include_video_ref_img and self.training:
                # prepend a blank render in the input view's slot so renders line up with [input, targets]
                render_images = torch.cat([torch.ones_like(render_out['image'][:, 0:1]), render_out['image']], dim=1)
                target_images = torch.cat([render_out['inputs_img'], render_out['target_imgs']], dim=1)
            else:
                render_images = render_out['image']
                target_images = render_out['target_imgs']

            target_images = rearrange(target_images, 'b n h w c -> b c h (n w)')
            render_images = rearrange(render_images, 'b n h w c -> b c h (n w)')

            # rows: targets, renders, 50/50 overlay
            grid = torch.cat([
                target_images, render_images, 0.5 * render_images + 0.5 * target_images,
            ], dim=-2)
            grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1))
           
            image_path = os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.jpg')
            save_image(grid, image_path)
            main_print(f"Saved image to {image_path}")

        return loss
        
    @torch.no_grad()
    def test_step(self, batch, batch_idx):
        # input_dict, render_gt = self.prepare_validation_batch_data(batch)
        render_out = self.forward(batch)
        render_gt = render_out['target_imgs']
        render_img = render_out['image']
        # Compute metrics
        metrics = self.compute_metrics(render_out)
        self.validation_metrics.append(metrics)
        
        # Save images
        target_images = rearrange(
            render_gt, 'b n h w c -> b c h (n w)')
        render_images = rearrange(
            render_img, 'b n h w c -> b c h (n w)')


        grid = torch.cat([
            target_images, render_images, 
        ], dim=-2)
        grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1))
        # self.logger.log_image('train/render', [grid], step=self.global_step)
        image_path = os.path.join(self.logdir, 'images_test', f'{batch_idx:07d}.png')
        save_image(grid, image_path)

        # code visualize
        code = render_out['code']
        self.decoder.visualize(code, batch['scene_name'],
                        os.path.dirname(image_path), code_range=self.code_clip_range)

        print(f"Saved image to {image_path}")
    
    def on_test_start(self):
        if self.global_rank == 0:
            os.makedirs(os.path.join(self.logdir, 'images_test'), exist_ok=True)
    
    def on_test_epoch_end(self):
        metrics = {}
        metrics_mean = {}
        metrics_var = {}
        for key in self.validation_metrics[0].keys():
            tmp = torch.stack([m[key] for m in self.validation_metrics]).cpu().numpy()
            metrics_mean[key] = tmp.mean()
            metrics_var[key] = tmp.var()

        # format each metric as "mean±variance"
        formatted_metrics = {}
        for key in metrics_mean.keys():
            formatted_metrics[key] = f"{metrics_mean[key]:.4f}±{metrics_var[key]:.4f}"

        for key in self.validation_metrics[0].keys():
            metrics[key] = torch.stack([m[key] for m in self.validation_metrics]).cpu().numpy().tolist()
        

        # saving into a dictionary
        final_dict = {"average": formatted_metrics,
                      'details': metrics}

        metric_path = os.path.join(self.logdir, f'metrics.json')
        with open(metric_path, 'w') as f:
            json.dump(final_dict, f, indent=4)
        print(f"Saved metrics to {metric_path}")
        
        for key in metrics.keys():
            metrics[key] = torch.stack([m[key] for m in self.validation_metrics]).mean()
        print(metrics)
        
        self.validation_metrics.clear()
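
# Sketch (illustrative only): the learning-rate multiplier built in
# SapiensGS_SA_v1.configure_optimizers, written as a plain function so its shape is easy to
# inspect. LambdaLR multiplies each parameter group's base lr by this value. It rises from
# eta_min to 1 over T_warmup steps (half-cosine warm-up), then decays back to eta_min at T_max
# (half-cosine decay). The min(step, T_max) clamp keeps it at eta_min past T_max; an unclamped
# cosine would start rising again. `_sketch_lr_multiplier` is a hypothetical name, and its
# defaults mirror warmup_steps=12_000, max_steps=100_000 and eta_min=0.001.
import math


def _sketch_lr_multiplier(step: int, T_warmup: int = 12_000, T_max: int = 100_000,
                          eta_min: float = 0.001) -> float:
    if step < T_warmup:
        return eta_min + (1 - math.cos(math.pi * step / T_warmup)) * (1 - eta_min) * 0.5
    step = min(step, T_max)
    return eta_min + (1 + math.cos(math.pi * (step - T_warmup) / (T_max - T_warmup))) * (1 - eta_min) * 0.5
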
    
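def weighted_mse_loss(render_images, target_images, weights):
    """Per-view weighted MSE.

    `weights` must broadcast against `(render_images - target_images) ** 2`,
    e.g. shape [(B*N), 1, 1, 1] for images of shape [(B*N), C, H, W].
    """
    squared_diff = (render_images - target_images) ** 2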
    weighted_squared_diff = squared_diff * weights
    loss_mse_weighted = weighted_squared_diff.mean()
    return loss_mse_weighted
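
# --- Illustrative sketch (hypothetical helper, not called by the training code) ---
# on_test_epoch_end() reports every metric as "mean±var". The second number is the population
# variance (NumPy's default ddof=0), not the standard deviation. This helper reproduces that
# formatting with the standard library only, assuming `values` is a non-empty list of floats.
import statistics


def _sketch_format_mean_var(values):
    mean = statistics.fmean(values)     # arithmetic mean
    var = statistics.pvariance(values)  # population variance, equal to np.var(values)
    return f"{mean:.4f}±{var:.4f}"

# Example: _sketch_format_mean_var([1.0, 2.0, 3.0]) gives "2.0000±0.6667".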

================================================
FILE: lib/mmutils/__init__.py
================================================
from .initialize import xavier_init, constant_init

================================================
FILE: lib/mmutils/initialize.py
================================================
import torch.nn as nn

def constant_init(module: nn.Module, val: float, bias: float = 0) -> None:
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.constant_(module.weight, val)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def xavier_init(module: nn.Module,
                gain: float = 1,
                bias: float = 0,
                distribution: str = 'normal') -> None:
    assert distribution in ['uniform', 'normal']
    if hasattr(module, 'weight') and module.weight is not None:
        if distribution == 'uniform':
            nn.init.xavier_uniform_(module.weight, gain=gain)
        else:
            nn.init.xavier_normal_(module.weight, gain=gain)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)
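
# --- Illustrative sketch (hypothetical helper, not used by the model) ---
# What xavier_init() above samples from, following the formulas documented for
# torch.nn.init.xavier_uniform_ / xavier_normal_. For a weight of shape
# [out_channels, in_channels, kh, kw], fan_in = in_channels * kh * kw and
# fan_out = out_channels * kh * kw. The uniform bound is gain * sqrt(6 / (fan_in + fan_out)),
# and the normal std is gain * sqrt(2 / (fan_in + fan_out)). Pure Python, no torch needed.
import math


def _sketch_xavier_stats(out_channels, in_channels, kernel_size=(1, 1), gain=1.0):
    receptive_field = kernel_size[0] * kernel_size[1]
    fan_in = in_channels * receptive_field
    fan_out = out_channels * receptive_field
    bound = gain * math.sqrt(6.0 / (fan_in + fan_out))  # xavier_uniform_: U(-bound, bound)
    std = gain * math.sqrt(2.0 / (fan_in + fan_out))    # xavier_normal_: N(0, std^2)
    return bound, std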

================================================
FILE: lib/models/__init__.py
================================================

from .decoders import *


================================================
FILE: lib/models/decoders/__init__.py
================================================

from .uvmaps_decoder_gender import UVNDecoder_gender

__all__ = ['UVNDecoder_gender']


================================================
FILE: lib/models/decoders/uvmaps_decoder_gender.py
================================================
import os
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum
from pytorch3d import ops

from lib.mmutils import xavier_init, constant_init
import numpy as np
import time
import cv2
import math
from simple_knn._C import distCUDA2
from pytorch3d.transforms import quaternion_to_matrix

from ..deformers import SMPLXDeformer_gender
from ..renderers import GRenderer, get_covariance, batch_rodrigues 
from lib.ops import TruncExp
import torchvision

from lib.utils.train_util import main_print
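
# --- Illustrative sketch (hypothetical helpers, not used by the model) ---
# forward() orients each Gaussian as R_def = T[:3, :3] @ init_rot @ rodrigues(rot) and takes the
# last column of R_def (the local z axis) as its normal. The Rodrigues' formula below is assumed
# to match `batch_rodrigues` from ..renderers (axis-angle -> rotation matrix); the real helper
# may differ in numerical details.
import torch


def _sketch_axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
    """Convert [N, 3] axis-angle vectors to [N, 3, 3] rotation matrices."""
    angle = axis_angle.norm(dim=1, keepdim=True).clamp_min(1e-8)  # [N, 1]
    axis = axis_angle / angle                                     # unit axes (zero for zero rotation)
    x, y, z = axis.unbind(dim=1)
    zeros = torch.zeros_like(x)
    # Skew-symmetric cross-product matrix K of each axis
    K = torch.stack([zeros, -z, y,
                     z, zeros, -x,
                     -y, x, zeros], dim=1).reshape(-1, 3, 3)
    sin = torch.sin(angle)[..., None]  # [N, 1, 1]
    cos = torch.cos(angle)[..., None]
    eye = torch.eye(3, dtype=axis_angle.dtype, device=axis_angle.device).expand_as(K)
    # R = I + sin(theta) K + (1 - cos(theta)) K^2
    return eye + sin * K + (1 - cos) * torch.bmm(K, K)


def _sketch_gaussian_normals(init_rot, axis_angle, joint_tfs):
    """Compose rest-pose rotation, predicted delta rotation and skinning transform; return normals."""
    R_delta = _sketch_axis_angle_to_matrix(axis_angle)  # [N, 3, 3]
    R = torch.bmm(init_rot, R_delta)                    # canonical orientation
    R_def = torch.bmm(joint_tfs[:, :3, :3], R)          # posed orientation
    return R_def[:, :, -1]                              # local z axis used as the normal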

def ensure_dtype(input_tensor, target_dtype=torch.float32):
    """
    Ensure tensor dtype matches target dtype.
    If not, convert it.
    """
    if input_tensor.dtype != target_dtype:
        input_tensor = input_tensor.to(dtype=target_dtype)
    return input_tensor
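
# --- Illustrative sketch (hypothetical helper, not used by the model) ---
# gaussian_render() sizes each Gaussian from its neighbourhood. This assumes distCUDA2
# (simple_knn) returns, as in the original 3D Gaussian Splatting code, the mean squared distance
# from each point to its 3 nearest neighbours. The decoded `radius` in [-1, 1] then rescales that
# distance by (radius + 1). Brute-force CPU equivalent: O(N^2) memory, so only for small N.
import torch


def _sketch_knn_scales(pcd: torch.Tensor, radius: torch.Tensor, k: int = 3) -> torch.Tensor:
    dist = torch.cdist(pcd, pcd)                           # [N, N] pairwise Euclidean distances
    dist.fill_diagonal_(float('inf'))                      # ignore each point's distance to itself
    knn_dist = dist.topk(k, dim=1, largest=False).values   # [N, k] nearest-neighbour distances
    dist2 = (knn_dist ** 2).mean(dim=1).clamp_min(1e-7)    # mean squared distance (assumed distCUDA2 semantics)
    base = dist2.sqrt()[:, None].repeat(1, 3)              # isotropic base scale per axis
    return (radius + 1) * base                             # factor in [0, 2] per axis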


class UVNDecoder_gender(nn.Module):

    activation_dict = {
        'relu': nn.ReLU,
        'silu': nn.SiLU,
        'softplus': nn.Softplus,
        'trunc_exp': TruncExp,
        'sigmoid': nn.Sigmoid}

    def __init__(self,
                 *args,
                 interp_mode='bilinear',
                 base_layers=[3 * 32, 128],
                 density_layers=[128, 1],
                 color_layers=[128, 128, 3],
                 offset_layers=[128, 3],
                 scale_layers=[128, 3],
                 radius_layers=[128, 3],
                 use_dir_enc=True,
                 dir_layers=None,
                 scene_base_size=None,
                 scene_rand_dims=(0, 1),
                 activation='silu',
                 sigma_activation='sigmoid',
                 sigmoid_saturation=0.001,
                 code_dropout=0.0,
                 flip_z=False,
                 extend_z=False,
                 gender='neutral',
                 multires=0,
                 bg_color=0,
                 image_size=1024,
                 superres=False,
                 focal=1280, # default focal length passed to the renderer
                 reshape_type=None, # 'cnn' or 'VitHead': build a module that upsamples the features into UV maps
                 fix_sigma=False, # if True, fix the opacity (sigma) of every Gaussian to 1
                 up_cnn_in_channels=None, # input channel count of the upsampling CNN (not used in this constructor)
                 vithead_param=None, # kwargs for the ViT head that decodes features into UV maps
                 is_sub2=False, # if True, load the sub2 UV map caches from work_dirs/cache_sub2
                 **kwargs):
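        super().__init__()
        # Copy the layer specs: the multires branch below extends them in place, which would
        # otherwise mutate the shared default lists across instances.
        base_layers = list(base_layers)
        color_layers = list(color_layers)
        self.interp_mode = interp_mode
        self.in_chn = base_layers[0]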
        self.use_dir_enc = use_dir_enc
        if scene_base_size is None:
            self.scene_base = None
        else:
            rand_size = [1 for _ in scene_base_size]
            for dim in scene_rand_dims:
                rand_size[dim] = scene_base_size[dim]
            init_base = torch.randn(rand_size).expand(scene_base_size).clone()
            self.scene_base = nn.Parameter(init_base)
        self.dir_encoder = None
        self.sigmoid_saturation = sigmoid_saturation
        self.deformer = SMPLXDeformer_gender(gender, is_sub2=is_sub2)

        self.renderer = GRenderer(image_size=image_size, bg_color=bg_color, f=focal)
        # Super-resolution is not implemented: `superres` is accepted but ignored
        self.superres = None
        self.gender = gender
        self.reshape_type = reshape_type
        if reshape_type == 'cnn':
            self.upsample_conv = torch.nn.ConvTranspose2d(512, 32, kernel_size=4, stride=4).cuda()
        elif reshape_type == 'VitHead': # changes the up block's layernorm into the feature channel norm instead of the full image norm
            from lib.models.decoders.vit_head import VitHead
            self.upsample_conv = VitHead(**vithead_param)
            # 256, 128, 128 -> 128, 256, 256 -> 64, 512, 512, ->32, 1024, 1024
        
        base_cache_dir = 'work_dirs/cache'
        if is_sub2:
            base_cache_dir = 'work_dirs/cache_sub2'
        if gender == 'neutral':
            select_uv = torch.as_tensor(np.load(base_cache_dir+'/init_uv_smplx_newNeutral.npy'))
            self.register_buffer('select_coord', select_uv.unsqueeze(0)*2.-1.)  # UV [0, 1] -> grid_sample range [-1, 1]

            init_pcd = torch.as_tensor(np.load(base_cache_dir+'/init_pcd_smplx_newNeutral.npy'))
            self.register_buffer('init_pcd', init_pcd.unsqueeze(0), persistent=False) # 0.9-- -1
        elif gender == 'male':
            # NOTE: init_uv_smplx_thu has not been created for the v_template; this branch is untested
            select_uv = torch.as_tensor(np.load(base_cache_dir+'/init_uv_smplx_thu.npy'))
            self.register_buffer('select_coord', select_uv.unsqueeze(0)*2.-1.)

            init_pcd = torch.as_tensor(np.load(base_cache_dir+'/init_pcd_smplx_thu.npy'))
            self.register_buffer('init_pcd', init_pcd.unsqueeze(0), persistent=False) # 0.9-- -1
        else:
            raise ValueError(f"Unsupported gender '{gender}': expected 'neutral' or 'male'")
        self.num_init = self.init_pcd.shape[1]
        main_print(f"Number of initial Gaussian points: {self.num_init}")

        self.multires = multires  # if > 0, append the cached position (3) and UV (2) maps as 5 extra input channels
        if multires > 0:
            uv_map = torch.as_tensor(np.load(base_cache_dir+'/init_uvmap_smplx_thu.npy'))
            pcd_map = torch.as_tensor(np.load(base_cache_dir+'/init_posmap_smplx_thu.npy'))
            input_coord = torch.cat([pcd_map, uv_map], dim=1)
            self.register_buffer('input_freq', input_coord, persistent=False)
            base_layers[0] += 5
            color_layers[0] += 5
        else:
            self.init_uv = None

        activation_layer = self.activation_dict[activation.lower()]


        base_net = []  # 3x3 conv layers decoding the geometry code
        for i in range(len(base_layers) - 1):
            base_net.append(nn.Conv2d(base_layers[i], base_layers[i + 1], 3, padding=1))
            if i != len(base_layers) - 2:
                base_net.append(nn.BatchNorm2d(base_layers[i+1]))
                base_net.append(activation_layer())
        self.base_net = nn.Sequential(*base_net)
        self.base_bn = nn.BatchNorm2d(base_layers[-1])
        self.base_activation = activation_layer()

        density_net = []  # 1x1 conv layers predicting density (opacity), followed by sigma_activation
        for i in range(len(density_layers) - 1):
            density_net.append(nn.Conv2d(density_layers[i], density_layers[i + 1], 1))
            if i != len(density_layers) - 2:
                density_net.append(nn.BatchNorm2d(density_layers[i+1]))
                density_net.append(activation_layer())
        density_net.append(self.activation_dict[sigma_activation.lower()]())
        self.density_net = nn.Sequential(*density_net)

        offset_net = []  # 1x1 conv layers predicting per-texel xyz offsets
        for i in range(len(offset_layers) - 1):
            offset_net.append(nn.Conv2d(offset_layers[i], offset_layers[i + 1], 1))
            if i != len(offset_layers) - 2:
                offset_net.append(nn.BatchNorm2d(offset_layers[i+1]))
                offset_net.append(activation_layer())
        self.offset_net = nn.Sequential(*offset_net)

        self.dir_net = None
        color_net = []  # 3x3 conv layers on the texture code; the sigmoid output holds rgb, radius and rotation
        for i in range(len(color_layers) - 2):
            color_net.append(nn.Conv2d(color_layers[i], color_layers[i + 1], kernel_size=3, padding=1))
            color_net.append(nn.BatchNorm2d(color_layers[i+1]))
            color_net.append(activation_layer())
        color_net.append(nn.Conv2d(color_layers[-2], color_layers[-1], kernel_size=1))
        color_net.append(nn.Sigmoid())
        self.color_net = nn.Sequential(*color_net)
        self.code_dropout = nn.Dropout2d(code_dropout) if code_dropout > 0 else None

        self.flip_z = flip_z
        self.extend_z = extend_z

        if self.gender == 'neutral':
            init_rot = torch.as_tensor(np.load(base_cache_dir+'/init_rot_smplx_newNeutral.npy'))
            self.register_buffer('init_rot', init_rot, persistent=False)

            face_mask = torch.as_tensor(np.load(base_cache_dir+'/face_mask_thu_newNeutral.npy'))
            self.register_buffer('face_mask', face_mask.unsqueeze(0), persistent=False)

            hands_mask = torch.as_tensor(np.load(base_cache_dir+'/hands_mask_thu_newNeutral.npy'))
            self.register_buffer('hands_mask', hands_mask.unsqueeze(0), persistent=False)

            outside_mask = torch.as_tensor(np.load(base_cache_dir+'/outside_mask_thu_newNeutral.npy'))
            self.register_buffer('outside_mask', outside_mask.unsqueeze(0), persistent=False)
        else:
            # NOTE: init_rot has not been created for the v_template; the *_thu caches are untested
            init_rot = torch.as_tensor(np.load(base_cache_dir+'/init_rot_smplx_thu.npy'))
            self.register_buffer('init_rot', init_rot, persistent=False)

            face_mask = torch.as_tensor(np.load(base_cache_dir+'/face_mask_thu.npy'))
            self.register_buffer('face_mask', face_mask.unsqueeze(0), persistent=False)

            hands_mask = torch.as_tensor(np.load(base_cache_dir+'/hands_mask_thu.npy'))
            self.register_buffer('hands_mask', hands_mask.unsqueeze(0), persistent=False)

            outside_mask = torch.as_tensor(np.load(base_cache_dir+'/outside_mask_thu.npy'))
            self.register_buffer('outside_mask', outside_mask.unsqueeze(0), persistent=False)

        self.iter = 0
        self.init_weights()
        self.if_rotate_gaussian = False
        self.fix_sigma = fix_sigma

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                xavier_init(m, distribution='uniform')
        if self.dir_net is not None:
            constant_init(self.dir_net[-1], 0)
        if self.offset_net is not None:
            self.offset_net[-1].weight.data.uniform_(-1e-5, 1e-5)
            self.offset_net[-1].bias.data.zero_()


    def extract_pcd(self, code, smpl_params, init=False, zeros_hands_off=False):
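        '''
        Decode per-Gaussian attributes from the UV code and pose them with the SMPL-X deformer.

        Args:
            B == num_scenes
            code (tensor or list): latent code, shape [B, C, H, W], or a [geo_code, tex_code] list
            smpl_params (tensor): SMPL parameters. shape: [B_pose, 189]
            init (bool): not used
            zeros_hands_off (bool): if True, zero the offsets of the hand points
        Returns:
            defm_pcd (tensor): deformed point cloud. shape: [B, N, B_pose, 3]
            sigmas, rgbs, offset, radius, rot (tensor): GS attributes. shape: [B, N, C]
            tfs (tensor): per-point deformation matrices, shape [B, N, ...]; the upper-left 3x3 block is the rotation
        '''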
        if isinstance(code, list):
            num_scenes, _, h, w = code[0].size()
        else:
            num_scenes, n_channels, h, w = code.size()
        init_pcd = self.init_pcd.repeat(num_scenes, 1, 1)  # T-posed space points, for computing the skinning weights
        
        sigmas, rgbs, radius, rot, offset = self._decode(code, init=init) #  the person-specify attributes of GS
        if self.fix_sigma:
            sigmas = torch.ones_like(sigmas)
        if zeros_hands_off:
            offset[self.hands_mask[...,None].expand(num_scenes, -1, 3)] = 0
        canon_pcd = init_pcd + offset
        
        self.deformer.prepare_deformer(smpl_params, num_scenes, device=canon_pcd.device)
        defm_pcd, tfs = self.deformer(canon_pcd, rot, mask=(self.face_mask+self.hands_mask+self.outside_mask), cano=False, if_rotate_gaussian=self.if_rotate_gaussian)
        return defm_pcd, sigmas, rgbs, offset, radius, tfs, rot

    def deform_pcd(self, code, smpl_params, init=False, zeros_hands_off=False, value=0.1):
        '''
        Same as extract_pcd, but takes already-decoded GS attributes instead of a latent code.

        Args:
            B == num_scenes
            code (tuple): decoded GS attributes (sigmas, rgbs, radius, rot, offset), as returned by _decode
            smpl_params (tensor): SMPL parameters. shape: [B_pose, 189]
            init (bool): not used
            zeros_hands_off (bool): if True, clamp the offsets of the hand points to [-value, value]
            value (float): clamp bound for the hand offsets
        Returns:
            defm_pcd (tensor): deformed point cloud. shape: [B, N, B_pose, 3]
            sigmas, rgbs, offset, radius, rot (tensor): GS attributes. shape: [B, N, C]
            tfs (tensor): per-point deformation matrices, shape [B, N, ...]; the upper-left 3x3 block is the rotation
        '''
        sigmas, rgbs, radius, rot, offset = code
        num_scenes = sigmas.shape[0]
        init_pcd = self.init_pcd.repeat(num_scenes, 1, 1)  # T-posed space points, for computing the skinning weights

        if self.fix_sigma:
            sigmas = torch.ones_like(sigmas)
        if zeros_hands_off:
            hand_sel = self.hands_mask[..., None].expand(num_scenes, -1, 3)
            offset[hand_sel] = torch.clamp(offset[hand_sel], -value, value)
        canon_pcd = init_pcd + offset
        self.deformer.prepare_deformer(smpl_params, num_scenes, device=canon_pcd.device)
        defm_pcd, tfs = self.deformer(canon_pcd, rot, mask=(self.face_mask+self.hands_mask+self.outside_mask), cano=False, if_rotate_gaussian=self.if_rotate_gaussian)
        return defm_pcd, sigmas, rgbs, offset, radius, tfs, rot


        
    def _sample_feature(self, results):
        '''Sample the UV attribute maps produced by _decode_feature at each Gaussian's UV coordinate.'''
        sigma = results['sigma']
        outputs = results['output']
        if isinstance(sigma, list):
            num_scenes, _, h, w = sigma[0].shape
            select_coord = self.select_coord.unsqueeze(1).repeat(num_scenes, 1, 1, 1)
        elif sigma.dim() == 4:
            num_scenes, n_channels, h, w = sigma.shape
            select_coord = self.select_coord.unsqueeze(1).repeat(num_scenes, 1, 1, 1)
        else:
            raise ValueError("expected a list of feature maps or a 4D tensor")
        output_attr = F.grid_sample(outputs, select_coord, mode=self.interp_mode, padding_mode='border', align_corners=False).reshape(num_scenes, 13, -1).permute(0, 2, 1)
        sigma, offset, rgbs, radius, rot = output_attr.split([1, 3, 3, 3, 3], dim=2)

        if self.sigmoid_saturation > 0:
            rgbs = rgbs * (1 + self.sigmoid_saturation * 2) - self.sigmoid_saturation
        
        radius = (radius - 0.5) * 2
        rot = (rot - 0.5) * np.pi

        return sigma, rgbs, radius, rot, offset

    def _decode_feature(self, point_code, init=False):
        if isinstance(point_code, list):
            num_scenes, _, h, w = point_code[0].shape
            geo_code, tex_code = point_code
            # select_coord = self.select_coord.unsqueeze(1).repeat(num_scenes, 1, 1, 1)
            if self.multires != 0:
                input_freq = self.input_freq.repeat(num_scenes, 1, 1, 1)
        elif point_code.dim() == 4:
            num_scenes, n_channels, h, w = point_code.shape
            # select_coord = self.select_coord.unsqueeze(1).repeat(num_scenes, 1, 1, 1)
            if self.multires != 0:
                input_freq = self.input_freq.repeat(num_scenes, 1, 1, 1)
            geo_code, tex_code = point_code.split(16, dim=1)
        else:
            raise ValueError("point_code must be a [geo_code, tex_code] list or a 4D tensor")

        base_in = geo_code if self.multires == 0 else torch.cat([geo_code, input_freq], dim=1)
        base_x = self.base_net(base_in)
        base_x_act = self.base_activation(self.base_bn(base_x))
    
        sigma = self.density_net(base_x_act)
        offset = self.offset_net(base_x_act)
        color_in = tex_code if self.multires == 0 else torch.cat([tex_code, input_freq], dim=1)
        rgbs_radius_rot = self.color_net(color_in)
        
        outputs = torch.cat([sigma, offset, rgbs_radius_rot], dim=1)  # sigma(1) + offset(3) + rgb/radius/rot(9)
        sigma, offset, rgbs, radius, rot = outputs.split([1, 3, 3, 3, 3], dim=1)
        results = {'output':outputs, 'sigma': sigma, 'offset': offset, 'rgbs': rgbs, 'radius': radius, 'rot': rot}

        return results

    def _decode(self, point_code, init=False):
        if isinstance(point_code, list):
            num_scenes, _, h, w = point_code[0].shape
            geo_code, tex_code = point_code
            select_coord = self.select_coord.unsqueeze(1).repeat(num_scenes, 1, 1, 1)
            if self.multires != 0:
                input_freq = self.input_freq.repeat(num_scenes, 1, 1, 1)
        elif point_code.dim() == 4:
            num_scenes, n_channels, h, w = point_code.shape
            select_coord = self.select_coord.unsqueeze(1).repeat(num_scenes, 1, 1, 1)
            if self.multires != 0:
                input_freq = self.input_freq.repeat(num_scenes, 1, 1, 1)
            geo_code, tex_code = point_code.split(16, dim=1)
        else:
            raise ValueError("point_code must be a [geo_code, tex_code] list or a 4D tensor")

        base_in = geo_code if self.multires == 0 else torch.cat([geo_code, input_freq], dim=1)
        base_x = self.base_net(base_in)
        base_x_act = self.base_activation(self.base_bn(base_x))
     
        sigma = self.density_net(base_x_act)
        offset = self.offset_net(base_x_act)
        color_in = tex_code if self.multires == 0 else torch.cat([tex_code, input_freq], dim=1)
        rgbs_radius_rot = self.color_net(color_in)
        
        outputs = torch.cat([sigma, offset, rgbs_radius_rot], dim=1)
        output_attr = F.grid_sample(outputs, select_coord, mode=self.interp_mode, padding_mode='border', align_corners=False).reshape(num_scenes, 13, -1).permute(0, 2, 1)
        sigma, offset, rgbs, radius, rot = output_attr.split([1, 3, 3, 3, 3], dim=2)

        if self.sigmoid_saturation > 0:
            rgbs = rgbs * (1 + self.sigmoid_saturation * 2) - self.sigmoid_saturation
        
        radius = (radius - 0.5) * 2
        rot = (rot - 0.5) * np.pi

        return sigma, rgbs, radius, rot, offset

    def gaussian_render(self, pcd, sigmas, rgbs, normals, rot, num_scenes, num_imgs, cameras, use_scale=False, radius=None,
                        return_norm=False, return_viz=False, mask=None):
        '''
        Render the Gaussians of a single scene into `num_imgs` images.

        Args:
            rot: per-Gaussian rotation matrices used to build the 3D covariances
            use_scale: must be True; Gaussian scales come from nearest-neighbour distances modulated by `radius`
            return_norm: return rendered normals (not implemented; an empty list is returned)
            return_viz: also return a per-point visibility mask
            mask: optional boolean mask selecting which Gaussians to render
        '''
        assert num_scenes == 1
        assert use_scale, "use_scale=False is not supported: the covariances are built from the scales"

        pcd = pcd.reshape(-1, 3)
        dist2 = torch.clamp_min(distCUDA2(pcd), 0.0000001)
        scales = torch.sqrt(dist2)[..., None].repeat(1, 3).detach()  # nearest-neighbour distance per point
        scale = (radius + 1) * scales  # radius lies in [-1, 1], so each axis is scaled by a factor in [0, 2]
        cov3D = get_covariance(scale, rot).reshape(-1, 6)

        images_all = []
        viz_masks = [] if return_viz else None
        norm_all = [] if return_norm else None

        if mask is not None:
            pcd = pcd[mask]
            rgbs = rgbs[mask]
            sigmas = sigmas[mask]
            cov3D = cov3D[mask]
            normals = normals[mask]

        for i in range(num_imgs):
            self.renderer.prepare(cameras[i])
            image = self.renderer.render_gaussian(means3D=pcd, colors_precomp=rgbs,
                rotations=None, opacities=sigmas, scales=None, cov3D_precomp=cov3D)
            images_all.append(image)
            if return_viz:
                # Render each point's xyz position as its colour, with full opacity
                viz_mask = self.renderer.render_gaussian(means3D=pcd, colors_precomp=pcd.clone(),
                    rotations=None, opacities=sigmas * 0 + 1, scales=None, cov3D_precomp=cov3D)
                viz_masks.append(viz_mask)

        images_all = torch.stack(images_all, dim=0).unsqueeze(0).permute(0, 1, 3, 4, 2)
        if return_viz:
            # A point is visible if its position appears among the rendered position colours
            # (squared distance < 1e-4, i.e. within 0.01), checked on every 10th pixel
            viz_masks = torch.stack(viz_masks, dim=0).unsqueeze(0).permute(0, 1, 3, 4, 2).reshape(1, -1, 3)
            dist_sq, _, _ = ops.knn_points(pcd.unsqueeze(0), viz_masks[:, ::10], K=1, return_nn=True)
            viz_masks = (dist_sq < 0.0001)[0]
        return images_all, norm_all, viz_masks, scale

    def visualize(self, code, scene_name, viz_dir, code_range=[-1, 1]):
        num_scenes, num_chn, h, w = code.size()
        code_viz = code.reshape(num_scenes, 4, 8, h, w).to(torch.float32).cpu().numpy()
        if not self.flip_z:
            code_viz = code_viz[..., ::-1, :]
        code_viz = code_viz.transpose(0, 1, 3, 2, 4).reshape(num_scenes, 4 * h, 8 * w)
        for code_single, code_viz_single, scene_name_single in zip(code, code_viz, scene_name):
            plt.imsave(os.path.join(viz_dir, 'a_scene_' + scene_name_single + '.png'), code_viz_single,
                       vmin=code_range[0], vmax=code_range[1])

    def forward(self, code, smpl_params, cameras, num_imgs,
                return_loss=False, return_norm=False, init=False, mask=None, zeros_hands_off=False):
        """
        Args:

          
            density_bitfield: Shape (num_scenes, griz_size**3 // 8)
            YY:
            grid_size, dt_gamma, perturb, T_thresh are deleted
            code: Shape (num_scenes, *code_size)
            cameras: Shape (num_scenes, num_imgs, 19(3+16))
            smpl_params: Shape (num_scenes, 189)

        """
        # import ipdb; ipdb.set_trace()
        if isinstance(code, list):
            num_scenes = len(code[0])
        else:
            num_scenes = len(code)
        assert num_scenes > 0
        self.iter+=1

        image = []
        scales = []
        norm = [] if return_norm else None
        viz_masks = [] if not self.training else None

        xyzs, sigmas, rgbs, offsets, radius, tfs, rot = self.extract_pcd(code, smpl_params, init=init, zeros_hands_off=zeros_hands_off)

        if zeros_hands_off:
            main_print('zeros_hands_off is on!')
            hand_sel = self.hands_mask[..., None].repeat(num_scenes, 1, 3)
            offsets[hand_sel] = 0
            rgbs[hand_sel] = 0
        R_delta = batch_rodrigues(rot.reshape(-1, 3))
        R = torch.bmm(self.init_rot.repeat(num_scenes, 1, 1), R_delta)
        R_def = torch.bmm(tfs.flatten(0, 1)[:, :3, :3], R)
        normals = (R_def[:, :, -1]).reshape(num_scenes, -1, 3)
        R_def_batch = R_def.reshape(num_scenes, -1, 3, 3)
      
        # The Gaussian rasterizer expects float32 inputs, so cast bfloat16 tensors up before rendering
        if xyzs.dtype == torch.bfloat16:
            cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius = [
                ensure_dtype(item, torch.float32)
                for item in (cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius)]

        # Render each scene separately (gaussian_render handles one scene at a time)
        for camera_single, R_def_single, pcd_single, rgbs_single, sigmas_single, normal_single, radius_single in zip(
                cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius):
            image_single, norm_single, viz_mask, scale = self.gaussian_render(
                pcd_single, sigmas_single, rgbs_single, normal_single, R_def_single, 1, num_imgs, camera_single,
                use_scale=True, radius=radius_single, return_norm=return_norm, return_viz=not self.training)
            image.append(image_single)
            scales.append(scale.unsqueeze(0))
            if return_norm:
                norm.append(norm_single)
            if not self.training:
                viz_masks.append(viz_mask)
        image = torch.cat(image, dim=0)
        scales = torch.cat(scales, dim=0)

        norm = torch.cat(norm, dim=0) if return_norm else None
        viz_masks = torch.cat(viz_masks, dim=0) if (not self.training) and viz_masks else None
        # Rendered outputs stay in float32 (they are not cast back to bfloat16)

        if self.training:
            # L2 regularizer on the offsets, plus an extra term averaged over the hand points only
            offset_dist = offsets ** 2
            weighted_offset = torch.mean(offset_dist) + torch.mean(offset_dist[self.hands_mask.repeat(num_scenes, 1)])
        else:
            weighted_offset = offsets
        

        results = dict(
            viz_masks=viz_masks,
            scales=scales,
            norm=norm,
            image=image,
            offset=weighted_offset)

        if return_loss:
            results.update(decoder_reg_loss=self.loss())

        return results
    
    def forward_render(self, code, cameras, num_imgs,
                return_loss=False, return_norm=False, init=False, mask=None, zeros_hands_off=False):
        """
        Args:

        
            density_bitfield: Shape (num_scenes, griz_size**3 // 8)
            YY:
            grid_size, dt_gamma, perturb, T_thresh are deleted
            code: Shape (num_scenes, *code_size)
            cameras: Shape (num_scenes, num_imgs, 19(3+16))
            smpl_params: Shape (num_scenes, 189)

        """
        image = []
        scales = []
        norm = [] if return_norm else None
        viz_masks = [] if not self.training else None

        xyzs, sigmas, rgbs, offsets, radius, tfs, rot =  code
        num_scenes = xyzs.shape[0]
        if zeros_hands_off:
            main_print('zeros_hands_off is on!')
            hand_sel = self.hands_mask[..., None].repeat(num_scenes, 1, 3)
            offsets[hand_sel] = 0
            rgbs[hand_sel] = 0
        R_delta = batch_rodrigues(rot.reshape(-1, 3))
        R = torch.bmm(self.init_rot.repeat(num_scenes, 1, 1), R_delta)
        R_def = torch.bmm(tfs.flatten(0, 1)[:, :3, :3], R)
        normals = (R_def[:, :, -1]).reshape(num_scenes, -1, 3)
        R_def_batch = R_def.reshape(num_scenes, -1, 3, 3)
        # The Gaussian rasterizer expects float32 inputs, so cast bfloat16 tensors up before rendering
        if xyzs.dtype == torch.bfloat16:
            cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius = [
                ensure_dtype(item, torch.float32)
                for item in (cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius)]

        # Render each scene separately (gaussian_render handles one scene at a time)
        for camera_single, R_def_single, pcd_single, rgbs_single, sigmas_single, normal_single, radius_single in zip(
                cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius):
            image_single, norm_single, viz_mask, scale = self.gaussian_render(
                pcd_single, sigmas_single, rgbs_single, normal_single, R_def_single, 1, num_imgs, camera_single,
                use_scale=True, radius=radius_single, return_norm=return_norm, return_viz=not self.training)
            image.append(image_single)
            scales.append(scale.unsqueeze(0))
            if return_norm:
                norm.append(norm_single)
            if not self.training:
                viz_masks.append(viz_mask)
        image = torch.cat(image, dim=0)
        scales = torch.cat(scales, dim=0)

        norm = torch.cat(norm, dim=0) if return_norm else None
        viz_masks = torch.cat(viz_masks, dim=0) if (not self.training) and viz_masks else None
        # Rendered outputs stay in float32 (they are not cast back to bfloat16)

        if self.training:
            # L2 regularizer on the offsets, plus an extra term averaged over the hand points only
            offset_dist = offsets ** 2
            weighted_offset = torch.mean(offset_dist) + torch.mean(offset_dist[self.hands_mask.repeat(num_scenes, 1)])
        else:
            weighted_offset = offsets
        

        results = dict(
            viz_masks=viz_masks,
            scales=scales,
            norm=norm,
            image=image,
            offset=weighted_offset)
        if return_loss:
            results.update(decoder_reg_loss=self.loss())

        return results
    
    
    def forward_testing_time(self, code, smpl_params, cameras, num_imgs,
                return_loss=False, return_norm=False, init=False, mask=None, zeros_hands_off=False):
        """
        Args:

          
            density_bitfield: Shape (num_scenes, griz_size**3 // 8)
            YY:
            grid_size, dt_gamma, perturb, T_thresh are deleted
            code: Shape (num_scenes, *code_size)
            cameras: Shape (num_scenes, num_imgs, 19(3+16))
            smpl_params: Shape (num_scenes, 189)

        """
        if isinstance(code, list):
            num_scenes = len(code[0])
        else:
            num_scenes = len(code)
        assert num_scenes > 0
        self.iter+=1

        image = []
        scales = []
        norm = [] if return_norm else None
        viz_masks = [] if not self.training else None
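        # CUDA kernels run asynchronously; synchronize so the wall-clock timings are meaningful
        torch.cuda.synchronize()
        start_time = time.time()
        xyzs, sigmas, rgbs, offsets, radius, tfs, rot = self.extract_pcd(code, smpl_params, init=init, zeros_hands_off=zeros_hands_off)
        torch.cuda.synchronize()
        end_time_to_3D = time.time()
        time_code_to_3d = end_time_to_3D - start_time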

        R_delta = batch_rodrigues(rot.reshape(-1, 3))
        R = torch.bmm(self.init_rot.repeat(num_scenes, 1, 1), R_delta)
        R_def = torch.bmm(tfs.flatten(0, 1)[:, :3, :3], R)
        normals = (R_def[:, :, -1]).reshape(num_scenes, -1, 3)
        R_def_batch = R_def.reshape(num_scenes, -1, 3, 3)
        for camera_single, R_def_single, pcd_single, rgbs_single, sigmas_single, normal_single, radius_single in zip(
                cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius):
            image_single, norm_single, viz_mask, scale = self.gaussian_render(
                pcd_single, sigmas_single, rgbs_single, normal_single, R_def_single, 1, num_imgs, camera_single,
                use_scale=True, radius=radius_single, return_norm=False, return_viz=not self.training)
            image.append(image_single)
            scales.append(scale.unsqueeze(0))
            if return_norm:
                norm.append(norm_single)
            if not self.training:
                viz_masks.append(viz_mask)
        image = torch.cat(image, dim=0)
        scales = torch.cat(scales, dim=0)

        norm = torch.cat(norm, dim=0) if return_norm else None
        viz_masks = torch.cat(viz_masks, dim=0) if (not self.training) and viz_masks else None

        torch.cuda.synchronize()
        time_3D_to_img = time.time() - end_time_to_3D

        results = dict(
            image=image)

        return results, time_code_to_3d, time_3D_to_img
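
# Sketch (hypothetical helpers, pure Python, a single Gaussian) of how
# forward_testing_time builds each Gaussian's posed rotation and normal.
# rodrigues_ref stands in for batch_rodrigues (axis-angle -> rotation matrix);
# R_tf and R_init are 3x3 nested lists standing in for the rotation block
# tfs[..., :3, :3] and for self.init_rot. Not used by the model.
import math


def rodrigues_ref(aa):
    theta = math.sqrt(sum(a * a for a in aa))
    if theta < 1e-8:
        return [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
    kx, ky, kz = (a / theta for a in aa)
    c, s, v = math.cos(theta), math.sin(theta), 1.0 - math.cos(theta)
    # R = cos(theta) I + sin(theta) [k]_x + (1 - cos(theta)) k k^T
    return [[c + kx * kx * v, kx * ky * v - kz * s, kx * kz * v + ky * s],
            [ky * kx * v + kz * s, c + ky * ky * v, ky * kz * v - kx * s],
            [kz * kx * v - ky * s, kz * ky * v + kx * s, c + kz * kz * v]]


def matmul3_ref(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(3)) for j in range(3)] for i in range(3)]


def posed_rotation_and_normal_ref(R_tf, R_init, aa):
    R = matmul3_ref(R_init, rodrigues_ref(aa))   # R = init_rot @ R_delta
    R_def = matmul3_ref(R_tf, R)                 # R_def = tfs[:3, :3] @ R
    normal = [R_def[i][2] for i in range(3)]     # third column, as in R_def[:, :, -1]
    return R_def, normal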

================================================
FILE: lib/models/decoders/vit_head.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Sequence, Tuple, Optional

class VitHead(nn.Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 deconv_out_channels: Optional[Sequence[int]] = (256, 256, 256),
                 deconv_kernel_sizes: Optional[Sequence[int]] = (4, 4, 4),
                 conv_out_channels: Optional[Sequence[int]] = None,
                 conv_kernel_sizes: Optional[Sequence[int]] = None,
                 ):
        super(VitHead, self).__init__()

        if deconv_out_channels:
            if deconv_kernel_sizes is None or len(deconv_out_channels) != len(deconv_kernel_sizes):
                raise ValueError(
                    '"deconv_out_channels" and "deconv_kernel_sizes" should '
                    'be integer sequences with the same length. Got '
                    f'mismatched lengths {deconv_out_channels} and '
                    f'{deconv_kernel_sizes}')

            self.deconv_layers = self._make_deconv_layers(
                in_channels=in_channels,
                layer_out_channels=deconv_out_channels,
                layer_kernel_sizes=deconv_kernel_sizes,
            )
            in_channels = deconv_out_channels[-1]
        else:
            self.deconv_layers = nn.Identity()

        if conv_out_channels:
            if conv_kernel_sizes is None or len(conv_out_channels) != len(conv_kernel_sizes):
                raise ValueError(
                    '"conv_out_channels" and "conv_kernel_sizes" should '
                    'be integer sequences with the same length. Got '
                    f'mismatched lengths {conv_out_channels} and '
                    f'{conv_kernel_sizes}')

            self.conv_layers = self._make_conv_layers(
                in_channels=in_channels,
                layer_out_channels=conv_out_channels,
                layer_kernel_sizes=conv_kernel_sizes)
            in_channels = conv_out_channels[-1]
        else:
            self.conv_layers = nn.Identity()

        self.cls_seg = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def _make_conv_layers(self, in_channels: int,
                          layer_out_channels: Sequence[int],
                          layer_kernel_sizes: Sequence[int]) -> nn.Module:
        """Create stride-1 Conv2d-InstanceNorm-SiLU blocks (size-preserving for odd kernel sizes)."""
        layers = []
        for out_channels, kernel_size in zip(layer_out_channels, layer_kernel_sizes):
            padding = (kernel_size - 1) // 2
            layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding))
            layers.append(nn.InstanceNorm2d(out_channels))
            layers.append(nn.SiLU(inplace=True))
            in_channels = out_channels

        return nn.Sequential(*layers)

    def _make_deconv_layers(self, in_channels: int,
                            layer_out_channels: Sequence[int],
                            layer_kernel_sizes: Sequence[int]) -> nn.Module:
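        """Create stride-2 ConvTranspose2d-InstanceNorm-SiLU blocks.

        Supported kernel sizes are 2, 3 and 4; padding and output_padding are
        chosen per size so that each block exactly doubles H and W.
        """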
        layers = []
        for out_channels, kernel_size in zip(layer_out_channels, layer_kernel_sizes):
            if kernel_size == 4:
                padding = 1
                output_padding = 0
            elif kernel_size == 3:
                padding = 1
                output_padding = 1
            elif kernel_size == 2:
                padding = 0
                output_padding = 0
            else:
                raise ValueError(f'Unsupported kernel size {kernel_size} for deconvolutional layers')

            layers.append(nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=2, padding=padding, output_padding=output_padding, bias=False))
            layers.append(nn.InstanceNorm2d(out_channels))
            layers.append(nn.SiLU(inplace=True))
            in_channels = out_channels

        return nn.Sequential(*layers)

    def forward(self, inputs):
        x = self.deconv_layers(inputs)
        x = self.conv_layers(x)
        out = self.cls_seg(x)
        return out
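

# Sketch (hypothetical helpers, not used by VitHead): the standard
# ConvTranspose2d output-size formula, used to check that every
# (kernel_size, padding, output_padding) triple chosen in _make_deconv_layers
# doubles the spatial size at stride 2.
def deconv_out_size(size, kernel_size, stride=2, padding=0, output_padding=0, dilation=1):
    # H_out = (H_in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1
    return (size - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1


# (padding, output_padding) per kernel size, mirroring _make_deconv_layers.
_DECONV_PADDING = {4: (1, 0), 3: (1, 1), 2: (0, 0)}


def head_out_size(size, deconv_kernel_sizes):
    # Spatial size after the deconv stack; odd-sized stride-1 convs and cls_seg keep it.
    for k in deconv_kernel_sizes:
        padding, output_padding = _DECONV_PADDING[k]
        size = deconv_out_size(size, k, 2, padding, output_padding)
    return size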

if __name__ == "__main__":

    # Example usage: four stride-2 deconvolutions upsample a 64x64 feature map to 1024x1024.
    model = VitHead(in_channels=1536, out_channels=21, deconv_out_channels=(768, 768, 512, 512),
                    deconv_kernel_sizes=(4, 4, 4, 4),
                    conv_out_channels=(512, 256, 128), conv_kernel_sizes=(1, 1, 1),
                    )
    inputs = torch.randn(1, 1536, 64, 64)
    outputs = model(inputs)
    print(outputs.shape)  # torch.Size([1, 21, 1024, 1024])

================================================
FILE: lib/models/deformers/__init__.py
================================================

from .smplx_deformer_gender import SMPLXDeformer_gender

__all__ = ['SMPLXDeformer_gender']

================================================
FILE: lib/models/deformers/fast_snarf/cuda/filter/filter.cpp
================================================
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <vector>
#include <c10/cuda/CUDAGuard.h>





void launch_filter(
  torch::Tensor &output, const torch::Tensor &x, const torch::Tensor &mask);

void filter(const torch::Tensor &x, const torch::Tensor &mask, torch::Tensor &output) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
  launch_filter(output, x, mask);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("filter", &filter);
}


================================================
FILE: lib/models/deformers/fast_snarf/cuda/filter/filter_kernel.cu
================================================
#include <vector>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/core/TensorBase.h>
#include <ATen/Dispatch.h>
#include <c10/macros/Macros.h>

#include "ATen/Functions.h"
#include "ATen/core/TensorAccessor.h"
#include "c10/cuda/CUDAException.h"
#include "c10/cuda/CUDAStream.h"

#include <chrono>

#define TensorAcc4R PackedTensorAccessor32<scalar_t,4,RestrictPtrTraits>
#define TensorAcc5R PackedTensorAccessor32<scalar_t,5,RestrictPtrTraits>

using namespace at;
using namespace at::cuda::detail;


template <typename scalar_t, typename index_t>
C10_LAUNCH_BOUNDS_1(512)
__global__ void filter(
    const index_t nthreads,
    PackedTensorAccessor32<scalar_t, 4, RestrictPtrTraits> x,
    PackedTensorAccessor32<bool, 3, RestrictPtrTraits> mask,
    PackedTensorAccessor32<bool, 3, RestrictPtrTraits> output) {

    index_t n_batch = mask.size(0);
    index_t n_point = mask.size(1);
    index_t n_init = mask.size(2);
    CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {

        // One thread per (batch, point) pair: index = i_batch * n_point + i_point.
        const index_t i_batch = index / n_point;
        const index_t i_point = index % n_point;


        for(index_t i = 0; i < n_init; i++) {
            if(!mask[i_batch][i_point][i]){
                output[i_batch][i_point][i] = false;
                continue;
            }
            scalar_t xi0 = x[i_batch][i_point][i][0];
            scalar_t xi1 = x[i_batch][i_point][i][1];
            scalar_t xi2 = x[i_batch][i_point][i][2];

            bool flag = true;
            for(index_t j = i+1; j < n_init; j++){
                if(!mask[i_batch][i_point][j]){
                    continue;
                }
                scalar_t d0 = xi0 - x[i_batch][i_point][j][0];
                scalar_t d1 = xi1 - x[i_batch][i_point][j][1];
                scalar_t d2 = xi2 - x[i_batch][i_point][j][2];

                scalar_t dist = d0*d0 + d1*d1 + d2*d2;
                if(dist<0.0001*0.0001){
                    flag=false;
                    break;
                }
            }

            output[i_batch][i_point][i] = flag;
        }

    }
}
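
// Hypothetical host reference (one batch element, one target point; not used by
// the extension) of the `filter` kernel above: candidate root i survives only if
// it is valid and no later valid candidate j > i lies within 1e-4 of it, so of
// each cluster of duplicate Broyden solutions only the last one is kept.
#include <array>
#include <cstddef>
#include <vector>

inline std::vector<bool> filter_duplicates_ref(
    const std::vector<std::array<double, 3>>& x, const std::vector<bool>& mask) {
  std::vector<bool> keep(x.size(), false);
  for (std::size_t i = 0; i < x.size(); ++i) {
    if (!mask[i]) continue;  // invalid candidates are never kept
    bool unique = true;
    for (std::size_t j = i + 1; j < x.size() && unique; ++j) {
      if (!mask[j]) continue;
      const double d0 = x[i][0] - x[j][0];
      const double d1 = x[i][1] - x[j][1];
      const double d2 = x[i][2] - x[j][2];
      // Same squared-distance test as the kernel: dist < 0.0001^2
      if (d0 * d0 + d1 * d1 + d2 * d2 < 1e-4 * 1e-4) unique = false;
    }
    keep[i] = unique;
  }
  return keep;
}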

void launch_filter(
    Tensor &output,
    const Tensor &x,
    const Tensor &mask) {

  // calculate #threads required
  int64_t B = output.size(0);
  int64_t N = output.size(1);

  int64_t count = B*N;
  if (count > 0) {
      AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filter", [&] {
            filter<scalar_t>
            <<<GET_BLOCKS(count, 512), 512, 0, at::cuda::getCurrentCUDAStream()>>>(
              static_cast<int>(count),
              x.packed_accessor32<scalar_t,4,RestrictPtrTraits>(),
              mask.packed_accessor32<bool,3,RestrictPtrTraits>(),
              output.packed_accessor32<bool,3,RestrictPtrTraits>());
          C10_CUDA_KERNEL_LAUNCH_CHECK();
      });
  }
}


================================================
FILE: lib/models/deformers/fast_snarf/cuda/fuse_kernel/fuse_cuda.cpp
================================================
#include "ATen/Functions.h"
#include "ATen/core/TensorBody.h"
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <vector>
#include <c10/cuda/CUDAGuard.h>



void launch_broyden_kernel(torch::Tensor &x,
                           const torch::Tensor &xd_tgt,
                           const torch::Tensor &grid,
                           const torch::Tensor &grid_J_inv,
                           const torch::Tensor &tfs,
                           const torch::Tensor &bone_ids,
                           bool align_corners,
                          //  torch::Tensor &J_inv,
                           torch::Tensor &is_valid,
                           const torch::Tensor& offset,
                           const torch::Tensor& scale,
                           float cvg_threshold,
                           float dvg_threshold);


void fuse_broyden(torch::Tensor &x,
                  const torch::Tensor &xd_tgt,
                  const torch::Tensor &grid,
                  const torch::Tensor &grid_J_inv,
                  const torch::Tensor &tfs,
                  const torch::Tensor &bone_ids,
                  bool align_corners,
                  // torch::Tensor& J_inv,
                  torch::Tensor &is_valid,
                  torch::Tensor& offset,
                  torch::Tensor& scale,
                  float cvg_threshold,
                  float dvg_threshold) {

  const at::cuda::OptionalCUDAGuard device_guard(device_of(x));

  launch_broyden_kernel(x, xd_tgt, grid, grid_J_inv, tfs, bone_ids, align_corners, is_valid, offset, scale, cvg_threshold, dvg_threshold);
  return;
}



PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("fuse_broyden", &fuse_broyden);
}

================================================
FILE: lib/models/deformers/fast_snarf/cuda/fuse_kernel/fuse_cuda_kernel.cu
================================================
#include "ATen/Functions.h"
#include "ATen/core/TensorAccessor.h"
#include "c10/cuda/CUDAException.h"
#include "c10/cuda/CUDAStream.h"

#include <ratio>
#include <vector>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/core/TensorBase.h>
#include <ATen/Dispatch.h>
#include <c10/macros/Macros.h>

#include <chrono>
using namespace std::chrono;

using namespace at;
using namespace at::cuda::detail;

template <typename scalar_t, typename index_t>
__device__ void fuse_J_inv_update(const index_t index,
                                  scalar_t* J_inv,
                                  scalar_t x0,
                                  scalar_t x1,
                                  scalar_t x2,
                                  scalar_t g0,
                                  scalar_t g1,
                                  scalar_t g2) {

    // Broyden's "good" rank-one update of the inverse Jacobian J_inv (row-major 3x3),
    // with x = delta_x and g = delta_gx:
    //   J_inv += (x - J_inv g) (x^T J_inv) / (x^T J_inv g)
    // After the update, J_inv maps delta_gx to delta_x (the secant condition).

    scalar_t J00 = J_inv[3*0 + 0];
    scalar_t J01 = J_inv[3*0 + 1];
    scalar_t J02 = J_inv[3*0 + 2];
    scalar_t J10 = J_inv[3*1 + 0];
    scalar_t J11 = J_inv[3*1 + 1];
    scalar_t J12 = J_inv[3*1 + 2];
    scalar_t J20 = J_inv[3*2 + 0];
    scalar_t J21 = J_inv[3*2 + 1];
    scalar_t J22 = J_inv[3*2 + 2];

    auto c0 = J00 * x0 + J10 * x1 + J20 * x2;
    auto c1 = J01 * x0 + J11 * x1 + J21 * x2;
    auto c2 = J02 * x0 + J12 * x1 + J22 * x2;

    auto s = c0 * g0 + c1 * g1 + c2 * g2;

    auto r0 = -J00 * g0 - J01 * g1 - J02 * g2;
    auto r1 = -J10 * g0 - J11 * g1 - J12 * g2;
    auto r2 = -J20 * g0 - J21 * g1 - J22 * g2;

    J_inv[3*0 + 0] += c0 * (r0 + x0) / s;
    J_inv[3*0 + 1] += c1 * (r0 + x0) / s;
    J_inv[3*0 + 2] += c2 * (r0 + x0) / s;
    J_inv[3*1 + 0] += c0 * (r1 + x1) / s;
    J_inv[3*1 + 1] += c1 * (r1 + x1) / s;
    J_inv[3*1 + 2] += c2 * (r1 + x1) / s;
    J_inv[3*2 + 0] += c0 * (r2 + x2) / s;
    J_inv[3*2 + 1] += c1 * (r2 + x2) / s;
    J_inv[3*2 + 2] += c2 * (r2 + x2) / s;

}
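
// Hypothetical host reference (double precision, not used by the extension) of
// fuse_J_inv_update, with H stored row-major as H[3*row + col]:
//   H <- H + (dx - H dg) (dx^T H) / (dx^T H dg)
// After the update H * dg == dx, which is what the accompanying check verifies.
#include <array>

inline void broyden_inverse_update_ref(std::array<double, 9>& H,
                                       const std::array<double, 3>& dx,
                                       const std::array<double, 3>& dg) {
  double c[3];   // c = H^T dx, i.e. the row vector dx^T H (c0..c2 in the kernel)
  double hg[3];  // hg = H dg (the kernel's r = -hg)
  for (int j = 0; j < 3; ++j) {
    c[j] = H[0 * 3 + j] * dx[0] + H[1 * 3 + j] * dx[1] + H[2 * 3 + j] * dx[2];
    hg[j] = H[j * 3 + 0] * dg[0] + H[j * 3 + 1] * dg[1] + H[j * 3 + 2] * dg[2];
  }
  const double s = c[0] * dg[0] + c[1] * dg[1] + c[2] * dg[2];  // dx^T H dg
  for (int i = 0; i < 3; ++i)
    for (int j = 0; j < 3; ++j)
      H[i * 3 + j] += (dx[i] - hg[i]) * c[j] / s;
}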


static __forceinline__ __device__
bool within_bounds_3d(int d, int h, int w, int D, int H, int W) {
  return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W;
}


template <typename scalar_t>
static __forceinline__ __device__
scalar_t grid_sampler_unnormalize(scalar_t coord, int size, bool align_corners) {
  if (align_corners) {
    // unnormalize coord from [-1, 1] to [0, size - 1]
    return ((coord + 1.f) / 2) * (size - 1);
  } else {
    // unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
    return ((coord + 1.f) * size - 1) / 2;
  }
}
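
// Hypothetical host copy of grid_sampler_unnormalize (not used by the extension),
// following the torch.nn.functional.grid_sample convention: with
// align_corners=true, -1 and 1 land on the centres of the first and last voxels
// (0 and size-1); with align_corners=false they land on the outer voxel edges
// (-0.5 and size-0.5).
inline double unnormalize_ref(double coord, int size, bool align_corners) {
  return align_corners ? ((coord + 1.0) / 2.0) * (size - 1)
                       : ((coord + 1.0) * size - 1.0) / 2.0;
}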


// Clips coordinates to between 0 and clip_limit - 1
template <typename scalar_t>
static __forceinline__ __device__
scalar_t clip_coordinates(scalar_t in, int clip_limit) {
  return ::min(static_cast<scalar_t>(clip_limit - 1), ::max(in, static_cast<scalar_t>(0)));
}


template<typename scalar_t>
static __forceinline__ __device__
scalar_t safe_downgrade_to_int_range(scalar_t x){
  // -100.0 does not have special meaning. This is just to make sure
  // it's not within_bounds_2d or within_bounds_3d, and does not cause
  // undefined behavior. See #35506.
  if (x > INT_MAX-1 || x < INT_MIN || !::isfinite(static_cast<double>(x)))
    return static_cast<scalar_t>(-100.0);
  return x;
}


template<typename scalar_t>
static __forceinline__ __device__
scalar_t compute_coordinates(scalar_t coord, int size, bool align_corners) {
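  // No clipping to the grid borders here (clip_coordinates is not applied):
  // out-of-range coordinates are kept so that within_bounds_3d rejects them, and
  // only non-finite or int-overflowing values are replaced by a sentinel.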
  coord = safe_downgrade_to_int_range(coord);
  return coord;
}

template <typename scalar_t>
static __forceinline__ __device__
scalar_t grid_sampler_compute_source_index(scalar_t coord, int size, bool align_corners) {
  coord = grid_sampler_unnormalize(coord, size, align_corners);
  coord = compute_coordinates(coord, size, align_corners);
  return coord;
}

template <typename scalar_t, typename index_t>
__device__ void grid_sampler_3d(
                                index_t i_batch,
                                TensorInfo<scalar_t, index_t> input,
                                scalar_t grid_x,
                                scalar_t grid_y,
                                scalar_t grid_z,
                                // TensorInfo<scalar_t, index_t> output,
                                PackedTensorAccessor32<scalar_t, 5> input_p, // [1, 3, 8, 32, 32]
                                scalar_t* output,
                                // PackedTensorAccessor32<scalar_t, 3> output_p, // [1800000, 3, 1]
                                bool align_corners,
                                bool nearest) {

  index_t C = input.sizes[1];
  index_t inp_D = input.sizes[2];
  index_t inp_H = input.sizes[3];
  index_t inp_W = input.sizes[4];
  
  index_t inp_sN = input.strides[0];
  index_t inp_sC = input.strides[1];
  index_t inp_sD = input.strides[2];
  index_t inp_sH = input.strides[3];
  index_t inp_sW = input.strides[4];

  // get the corresponding input x, y, z co-ordinates from grid
  scalar_t ix = grid_x;
  scalar_t iy = grid_y;
  scalar_t iz = grid_z;

  ix = grid_sampler_compute_source_index(ix, inp_W, align_corners);
  iy = grid_sampler_compute_source_index(iy, inp_H, align_corners);
  iz = grid_sampler_compute_source_index(iz, inp_D, align_corners);

  if(!nearest){
    // get corner pixel values from (x, y, z)
    // for 4d, we used north-east-south-west
    // for 5d, we add top-bottom
    index_t ix_tnw = static_cast<index_t>(::floor(ix));
    index_t iy_tnw = static_cast<index_t>(::floor(iy));
    index_t iz_tnw = static_cast<index_t>(::floor(iz));

    index_t ix_tne = ix_tnw + 1;
    index_t iy_tne = iy_tnw;
    index_t iz_tne = iz_tnw;

    index_t ix_tsw = ix_tnw;
    index_t iy_tsw = iy_tnw + 1;
    index_t iz_tsw = iz_tnw;

    index_t ix_tse = ix_tnw + 1;
    index_t iy_tse = iy_tnw + 1;
    index_t iz_tse = iz_tnw;

    index_t ix_bnw = ix_tnw;
    index_t iy_bnw = iy_tnw;
    index_t iz_bnw = iz_tnw + 1;

    index_t ix_bne = ix_tnw + 1;
    index_t iy_bne = iy_tnw;
    index_t iz_bne = iz_tnw + 1;

    index_t ix_bsw = ix_tnw;
    index_t iy_bsw = iy_tnw + 1;
    index_t iz_bsw = iz_tnw + 1;

    index_t ix_bse = ix_tnw + 1;
    index_t iy_bse = iy_tnw + 1;
    index_t iz_bse = iz_tnw + 1;

    // get surfaces to each neighbor:
    scalar_t tnw = (ix_bse - ix)    * (iy_bse - iy)    * (iz_bse - iz);
    scalar_t tne = (ix    - ix_bsw) * (iy_bsw - iy)    * (iz_bsw - iz);
    scalar_t tsw = (ix_bne - ix)    * (iy    - iy_bne) * (iz_bne - iz);
    scalar_t tse = (ix    - ix_bnw) * (iy    - iy_bnw) * (iz_bnw - iz);
    scalar_t bnw = (ix_tse - ix)    * (iy_tse - iy)    * (iz - iz_tse);
    scalar_t bne = (ix    - ix_tsw) * (iy_tsw - iy)    * (iz - iz_tsw);
    scalar_t bsw = (ix_tne - ix)    * (iy    - iy_tne) * (iz - iz_tne);
    scalar_t bse = (ix    - ix_tnw) * (iy    - iy_tnw) * (iz - iz_tnw);


    for (index_t xyz = 0; xyz < C; xyz++) {
      output[xyz] = 0;

      if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
        output[xyz] += input_p[i_batch][xyz][iz_tnw][iy_tnw][ix_tnw] * tnw;
        // *out_ptr_NCDHW += inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] * tnw;
      }
      if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
        output[xyz] += input_p[i_batch][xyz][iz_tne][iy_tne][ix_tne] * tne;
        // *out_ptr_NCDHW += inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] * tne;
      }
      if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
        output[xyz] += input_p[i_batch][xyz][iz_tsw][iy_tsw][ix_tsw] * tsw;
        // *out_ptr_NCDHW += inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] * tsw;
      }
      if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
        output[xyz] += input_p[i_batch][xyz][iz_tse][iy_tse][ix_tse] * tse;
        // *out_ptr_NCDHW += inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] * tse;
      }
      if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
        output[xyz] += input_p[i_batch][xyz][iz_bnw][iy_bnw][ix_bnw] * bnw;
        // *out_ptr_NCDHW += inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] * bnw;
      }
      if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
        output[xyz] += input_p[i_batch][xyz][iz_bne][iy_bne][ix_bne] * bne;
        // *out_ptr_NCDHW += inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] * bne;
      }
      if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
        output[xyz] += input_p[i_batch][xyz][iz_bsw][iy_bsw][ix_bsw] * bsw;
        // *out_ptr_NCDHW += inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] * bsw;
      }
      if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
        output[xyz] += input_p[i_batch][xyz][iz_bse][iy_bse][ix_bse] * bse;
        // *out_ptr_NCDHW += inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] * bse;
      }
    }
  } else {
    index_t ix_nearest = static_cast<index_t>(::round(ix));
    index_t iy_nearest = static_cast<index_t>(::round(iy));
    index_t iz_nearest = static_cast<index_t>(::round(iz));

    for (index_t xyz = 0; xyz < C; xyz++) {
      if (within_bounds_3d(iz_nearest, iy_nearest, ix_nearest, inp_D, inp_H, inp_W)) {
        output[xyz] = input_p[i_batch][xyz][iz_nearest][iy_nearest][ix_nearest];
      } else {
        output[xyz] = static_cast<scalar_t>(0);
      }
    }
  }
}
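
// Hypothetical host reference (not used by the extension) for the trilinear
// branch (nearest == false) of grid_sampler_3d, for one channel of a D x H x W
// grid stored as vol[(d*H + h)*W + w]. x, y, z are already-unnormalized voxel
// coordinates (x indexes W, y indexes H, z indexes D). Each of the 8 corners is
// weighted by the product of (1 - distance) along each axis; corners outside the
// grid contribute zero, as in the kernel.
#include <cmath>
#include <cstddef>
#include <vector>

inline double trilinear_ref(const std::vector<double>& vol, int D, int H, int W,
                            double x, double y, double z) {
  const int x0 = static_cast<int>(std::floor(x));
  const int y0 = static_cast<int>(std::floor(y));
  const int z0 = static_cast<int>(std::floor(z));
  double out = 0.0;
  for (int dz = 0; dz <= 1; ++dz)
    for (int dy = 0; dy <= 1; ++dy)
      for (int dx = 0; dx <= 1; ++dx) {
        const int xi = x0 + dx, yi = y0 + dy, zi = z0 + dz;
        if (xi < 0 || xi >= W || yi < 0 || yi >= H || zi < 0 || zi >= D) continue;
        const double weight = (dx ? x - x0 : 1.0 - (x - x0)) *
                              (dy ? y - y0 : 1.0 - (y - y0)) *
                              (dz ? z - z0 : 1.0 - (z - z0));
        out += weight * vol[(static_cast<std::size_t>(zi) * H + yi) * W + xi];
      }
  return out;
}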



template <typename scalar_t, typename index_t>
C10_LAUNCH_BOUNDS_1(512)
__global__ void broyden_kernel(
                               const index_t npoints,
                              const index_t n_batch,
                              const index_t n_point,
                              const index_t n_init,
                               TensorInfo<scalar_t, index_t> voxel_ti,
                               TensorInfo<scalar_t, index_t> voxel_J_ti,
                               PackedTensorAccessor32<scalar_t, 4> x, // shape=(N,200000, 9, 3)
                               PackedTensorAccessor32<scalar_t, 3> xd_tgt, // shape=(N,200000, 3)
                               PackedTensorAccessor32<scalar_t, 5> voxel, // shape=(N,3,8,32,32)
                               PackedTensorAccessor32<scalar_t, 5> grid_J_inv, // shape=(N,12,8,32,32): per-voxel 3x4 affine transform
                               PackedTensorAccessor32<scalar_t, 4> tfs, // shape=(N,24,4,4)
                               PackedTensorAccessor32<int, 1> bone_ids, // shape=(9)
                              //  PackedTensorAccessor32<scalar_t, 5> J_inv,// shape=(N,200000, 9, 9)
                               PackedTensorAccessor32<bool, 3> is_valid,// shape=(N,200000, 9)
                               PackedTensorAccessor32<scalar_t, 3> offset, // shape=(N, 1, 3) 
                               PackedTensorAccessor32<scalar_t, 3> scale, // shape=(N, 1, 3)
                               float cvg_threshold,
                               float dvg_threshold,
                               int N
                               ) 
{

  index_t index = blockIdx.x * blockDim.x + threadIdx.x;
  if(index >= npoints) return;

  
  const index_t i_batch = index / (n_point*n_init);
  const index_t i_point = (index % (n_point*n_init)) / n_init;
  const index_t i_init = (index %  (n_point*n_init)) % n_init;
 
  if(!is_valid[i_batch][i_point][i_init]){
    return;
  }

  scalar_t gx[3];
  scalar_t gx_new[3];

  scalar_t xd_tgt_index[3];

  xd_tgt_index[0] = xd_tgt[i_batch][i_point][0];
  xd_tgt_index[1] = xd_tgt[i_batch][i_point][1];
  xd_tgt_index[2] = xd_tgt[i_batch][i_point][2];

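  // Initial guess: map the deformed target into the local frame of bone i_bone,
  // x_l = R^T (xd_tgt - t), where R and t are the rotation block and translation
  // column of tfs[i_batch][i_bone] (the inverse transform when R is a pure rotation).
  scalar_t x_l[3];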

  int i_bone = bone_ids[i_init];
  scalar_t ixd = xd_tgt_index[0] - tfs[i_batch][i_bone][0][3]; 
  scalar_t iyd = xd_tgt_index[1] - tfs[i_batch][i_bone][1][3]; 
  scalar_t izd = xd_tgt_index[2] - tfs[i_batch][i_bone][2][3]; 
  x_l[0] = ixd * tfs[i_batch][i_bone][0][0]  
            + iyd * tfs[i_batch][i_bone][1][0] 
            + izd * tfs[i_batch][i_bone][2][0];


  x_l[1] = ixd * tfs[i_batch][i_bone][0][1]  
            + iyd * tfs[i_batch][i_bone][1][1] 
            + izd * tfs[i_batch][i_bone][2][1];


  x_l[2] = ixd * tfs[i_batch][i_bone][0][2]  
            + iyd * tfs[i_batch][i_bone][1][2] 
            + izd * tfs[i_batch][i_bone][2][2];


  // scalar_t ix = scale[0][0][0] * (x_l[0] + offset[0][0][0]);
  // scalar_t iy = scale[0][0][1] * (x_l[1] + offset[0][0][1]);
  // scalar_t iz = scale[0][0][2] * (x_l[2] + offset[0][0][2]);

  scalar_t J_local[12];
  grid_sampler_3d( i_batch, voxel_J_ti,
                    scale[0][0][0] * (x_l[0] + offset[0][0][0]),
                    scale[0][0][1] * (x_l[1] + offset[0][0][1]),
                    scale[0][0][2] * (x_l[2] + offset[0][0][2]),
                    grid_J_inv,
                    J_local,
                    true,
                    false);

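  // Initial inverse-Jacobian guess: the transpose of the 3x3 linear part of the
  // sampled 3x4 transform J_local (exact only when that part is a pure rotation).
  scalar_t J_inv_local[9];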
  J_inv_local[3*0 + 0] = J_local[4*0 + 0]; 
  J_inv_local[3*1 + 0] = J_local[4*0 + 1];
  J_inv_local[3*2 + 0] = J_local[4*0 + 2];
  J_inv_local[3*0 + 1] = J_local[4*1 + 0];
  J_inv_local[3*1 + 1] = J_local[4*1 + 1];
  J_inv_local[3*2 + 1] = J_local[4*1 + 2];
  J_inv_local[3*0 + 2] = J_local[4*2 + 0];
  J_inv_local[3*1 + 2] = J_local[4*2 + 1];
  J_inv_local[3*2 + 2] = J_local[4*2 + 2];
  
  for(int i=0; i<10; i++) {

    scalar_t J00 = J_inv_local[3*0 + 0]; 
    scalar_t J01 = J_inv_local[3*0 + 1];
    scalar_t J02 = J_inv_local[3*0 + 2];
    scalar_t J10 = J_inv_local[3*1 + 0];
    scalar_t J11 = J_inv_local[3*1 + 1];
    scalar_t J12 = J_inv_local[3*1 + 2];
    scalar_t J20 = J_inv_local[3*2 + 0];
    scalar_t J21 = J_inv_local[3*2 + 1];
    scalar_t J22 = J_inv_local[3*2 + 2];

    // gx = g(x)
    if (i==0){
      // grid_sampler_3d( i_batch, voxel_ti,
      //                   scale[0][0][0] * (x_l[0] + offset[0][0][0]),
      //                   scale[0][0][1] * (x_l[1] + offset[0][0][1]),
      //                   scale[0][0][2] * (x_l[2] + offset[0][0][2]),
      //                   voxel,
      //                   gx,
      //                   true,
      //                   false);
                        
      gx[0] = J_local[4*0+0] * x_l[0] + J_local[4*0+1] * x_l[1] + J_local[4*0+2] * x_l[2] + J_local[4*0+3];
      gx[1] = J_local[4*1+0] * x_l[0] + J_local[4*1+1] * x_l[1] + J_local[4*1+2] * x_l[2] + J_local[4*1+3];
      gx[2] = J_local[4*2+0] * x_l[0] + J_local[4*2+1] * x_l[1] + J_local[4*2+2] * x_l[2] + J_local[4*2+3];
                      
      gx[0] = gx[0] - xd_tgt_index[0];
      gx[1] = gx[1] - xd_tgt_index[1];
      gx[2] = gx[2] - xd_tgt_index[2];

    }
    else{
      gx[0] = gx_new[0];
      gx[1] = gx_new[1];
      gx[2] = gx_new[2];
    }

    // update = -J_inv @ gx
    scalar_t u0 = -J00*gx[0] + -J01*gx[1] + -J02*gx[2];
    scalar_t u1 = -J10*gx[0] + -J11*gx[1] + -J12*gx[2];
    scalar_t u2 = -J20*gx[0] + -J21*gx[1] + -J22*gx[2];

    // x += update
    x_l[0] += u0;
    x_l[1] += u1;
    x_l[2] += u2;

    scalar_t ix = scale[0][0][0] * (x_l[0] + offset[0][0][0]);
    scalar_t iy = scale[0][0][1] * (x_l[1] + offset[0][0][1]);
    scalar_t iz = scale[0][0][2] * (x_l[2] + offset[0][0][2]);

    // gx_new = g(x)
    grid_sampler_3d( i_batch, voxel_J_ti,
                      ix,
                      iy,
                      iz,
                      grid_J_inv,
                      J_local,
                      true,
                      false);
                      
    gx_new[0] = J_local[4*0+0] * x_l[0] + J_local[4*0+1] * x_l[1] + J_local[4*0+2] * x_l[2] + J_local[4*0+3] - xd_tgt_index[0];
    gx_new[1] = J_local[4*1+0] * x_l[0] + J_local[4*1+1] * x_l[1] + J_local[4*1+2] * x_l[2] + J_local[4*1+3] - xd_tgt_index[1];
    gx_new[2] = J_local[4*2+0] * x_l[0] + J_local[4*2+1] * x_l[1] + J_local[4*2+2] * x_l[2] + J_local[4*2+3] - xd_tgt_index[2];

      // grid_sampler_3d( i_batch, voxel_ti,
      //                   scale[0][0][0] * (x_l[0] + offset[0][0][0]),
      //                   scale[0][0][1] * (x_l[1] + offset[0][0][1]),
      //                   scale[0][0][2] * (x_l[2] + offset[0][0][2]),
      //                   voxel,
      //                   gx_new,
      //                   true,
      //                   false);

      // gx_new[0] = gx_new[0] - xd_tgt_index[0];
      // gx_new[1] = gx_new[1] - xd_tgt_index[1];
      // gx_new[2] = gx_new[2] - xd_tgt_index[2];
             
    // convergence checking
    scalar_t norm_gx = gx_new[0]*gx_new[0] + gx_new[1]*gx_new[1] + gx_new[2]*gx_new[2];

    // convergence/divergence criterion
    if(norm_gx < cvg_threshold*cvg_threshold) {

      bool is_valid_ = ix >= -1 && ix <= 1 && iy >= -1 && iy <= 1 && iz >= -1 && iz <= 1;

      is_valid[i_batch][i_point][i_init] = is_valid_;

      if (is_valid_){
      x[i_batch][i_point][i_init][0] = x_l[0];
      x[i_batch][i_point][i_init][1] = x_l[1];
      x[i_batch][i_point][i_init][2] = x_l[2];

      // J_inv[i_batch][i_point][i_init][0][0] = J00;
      // J_inv[i_batch][i_point][i_init][0][1] = J01;
      // J_inv[i_batch][i_point][i_init][0][2] = J02;
      // J_inv[i_batch][i_point][i_init][1][0] = J10;
      // J_inv[i_batch][i_point][i_init][1][1] = J11;
      // J_inv[i_batch][i_point][i_init][1][2] = J12;
      // J_inv[i_batch][i_point][i_init][2][0] = J20;
      // J_inv[i_batch][i_point][i_init][2][1] = J21;
      // J_inv[i_batch][i_point][i_init][2][2] = J22;
      }
      return;

    }
    else if(norm_gx > dvg_threshold*dvg_threshold) {
      is_valid[i_batch][i_point][i_init] = false;
      return;
    }
    

    // delta_x = update
    scalar_t delta_x_0 = u0;
    scalar_t delta_x_1 = u1;
    scalar_t delta_x_2 = u2;

    // delta_gx = gx_new - gx
    scalar_t delta_gx_0 = gx_new[0] - gx[0];
    scalar_t delta_gx_1 = gx_new[1] - gx[1];
    scalar_t delta_gx_2 = gx_new[2] - gx[2];

    fuse_J_inv_update(index, J_inv_local, delta_x_0, delta_x_1, delta_x_2, delta_gx_0, delta_gx_1, delta_gx_2);

  }
  is_valid[i_batch][i_point][i_init] = false;
  return;

}
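
// Hypothetical host-side sketch (double precision, one point, one bone; not used
// by the extension) of the root search in broyden_kernel. Assumptions: the
// trilinearly sampled 3x4 transform is replaced by a constant affine map A x + b,
// and the [-1, 1]^3 validity test is omitted. As in the kernel:
//   g(x) = A x + b - xd_tgt,  H_0 = A^T,  x <- x - H g(x),
// followed by the same Broyden update as fuse_J_inv_update.
#include <array>

using Vec3Ref = std::array<double, 3>;
using Mat3Ref = std::array<double, 9>;  // row-major 3x3

inline Vec3Ref matvec_ref(const Mat3Ref& M, const Vec3Ref& v) {
  return {M[0] * v[0] + M[1] * v[1] + M[2] * v[2],
          M[3] * v[0] + M[4] * v[1] + M[5] * v[2],
          M[6] * v[0] + M[7] * v[1] + M[8] * v[2]};
}

// Returns true once |g(x)| < cvg; x holds the initial guess on entry and the
// last iterate on exit.
inline bool broyden_affine_ref(const Mat3Ref& A, const Vec3Ref& b, const Vec3Ref& xd_tgt,
                               Vec3Ref& x, int max_iter, double cvg) {
  Mat3Ref H = {A[0], A[3], A[6], A[1], A[4], A[7], A[2], A[5], A[8]};  // H_0 = A^T
  auto residual = [&](const Vec3Ref& p) {
    Vec3Ref r = matvec_ref(A, p);
    for (int k = 0; k < 3; ++k) r[k] += b[k] - xd_tgt[k];
    return r;
  };
  Vec3Ref gx = residual(x);
  for (int it = 0; it < max_iter; ++it) {
    Vec3Ref u = matvec_ref(H, gx);  // update = -H g(x)
    for (int k = 0; k < 3; ++k) {
      u[k] = -u[k];
      x[k] += u[k];
    }
    const Vec3Ref gx_new = residual(x);
    const double n2 = gx_new[0] * gx_new[0] + gx_new[1] * gx_new[1] + gx_new[2] * gx_new[2];
    if (n2 < cvg * cvg) return true;  // same squared-norm test as the kernel
    // H += (dx - H dg)(dx^T H) / (dx^T H dg), with dx = u and dg = g_new - g
    Vec3Ref dg;
    for (int k = 0; k < 3; ++k) dg[k] = gx_new[k] - gx[k];
    const Vec3Ref Hdg = matvec_ref(H, dg);
    Vec3Ref c;  // c = H^T dx
    for (int j = 0; j < 3; ++j) c[j] = H[j] * u[0] + H[3 + j] * u[1] + H[6 + j] * u[2];
    const double s = c[0] * dg[0] + c[1] * dg[1] + c[2] * dg[2];
    for (int i = 0; i < 3; ++i)
      for (int j = 0; j < 3; ++j) H[3 * i + j] += (u[i] - Hdg[i]) * c[j] / s;
    gx = gx_new;
  }
  return false;
}

// For a pure rotation H_0 is the exact inverse and one step suffices; for a
// near-rigid A the rank-one updates correct H_0 over a few iterations.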


void launch_broyden_kernel(
                           Tensor &x,
                           const Tensor &xd_tgt,
    const Tensor &voxel,
    const Tensor &grid_J_inv,
    const Tensor &tfs,
    const Tensor &bone_ids,
    bool align_corners,
    // Tensor &J_inv,
    Tensor &is_valid,
    const Tensor& offset,
    const Tensor& scale,
    float cvg_threshold,
    float dvg_threshold) {

  // calculate #threads required
  int64_t n_batch = xd_tgt.size(0);
  int64_t n_point = xd_tgt.size(1);
  int64_t n_init = bone_ids.size(0);
  int64_t count = n_batch * n_point * n_init;


  if (count > 0) {
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fuse_kernel_cuda", [&] {
      broyden_kernel
        <<<GET_BLOCKS(count, 512), 512, 0,
        at::cuda::getCurrentCUDAStream()>>>(static_cast<int>(count),
                                            static_cast<int>(n_batch),
                                            static_cast<int>(n_point),
                                            static_cast<int>(n_init),
                                            getTensorInfo<scalar_t, int>(voxel),
                                            getTensorInfo<scalar_t, int>(grid_J_inv),
                                            x.packed_accessor32<scalar_t, 4>(),
                                            xd_tgt.packed_accessor32<scalar_t, 3>(),
                                            voxel.packed_accessor32<scalar_t, 5>(),
                                            grid_J_inv.packed_accessor32<scalar_t, 5>(),
                                            tfs.packed_accessor32<scalar_t, 4>(),
                                            bone_ids.packed_accessor32<int, 1>(),
                                            // J_inv.packed_accessor32<scalar_t, 5>(),
                                            is_valid.packed_accessor32<bool, 3>(),
                                            offset.packed_accessor32<scalar_t, 3>(),
                                            scale.packed_accessor32<scalar_t, 3>(),
                                            cvg_threshold,
                                            dvg_threshold,
                                            0);
      C10_CUDA_KERNEL_LAUNCH_CHECK();

    });
  }

  cudaDeviceSynchronize();
}
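The loop tail above takes the step `delta_x` (the `u` vector) and the residual change `delta_gx = gx_new - gx`, then calls `fuse_J_inv_update` to refresh the per-point inverse Jacobian. That function's body is not shown in this excerpt. Assuming it implements the standard "good Broyden" Sherman-Morrison update, the per-point iteration can be sketched in NumPy as below. `broyden_solve`, `max_steps`, and the norm-based threshold tests are illustrative assumptions, not the kernel's exact logic; the default thresholds simply reuse the `cvg_thresh=2e-4` and `dvg_thresh=1` defaults of `ForwardDeformer.broyden_cuda`.

```python
import numpy as np


def broyden_solve(f, x0, target, J_inv0, cvg_thresh=2e-4, dvg_thresh=1.0, max_steps=50):
    """Hypothetical sketch: solve f(x) = target with the 'good Broyden' method.

    J_inv0 is an initial guess for the inverse Jacobian of f.
    Returns (x, converged).
    """
    x = np.asarray(x0, dtype=float).copy()
    J_inv = np.asarray(J_inv0, dtype=float).copy()
    gx = f(x) - target                      # residual g(x) = f(x) - target
    for _ in range(max_steps):
        norm = np.linalg.norm(gx)
        if norm < cvg_thresh:               # converged
            return x, True
        if norm > dvg_thresh:               # diverged: treat as invalid
            return x, False
        delta_x = -J_inv @ gx               # quasi-Newton step
        x = x + delta_x
        gx_new = f(x) - target
        delta_gx = gx_new - gx
        # Sherman-Morrison rank-one update of the inverse Jacobian:
        # J_inv += (dx - J_inv dg) (dx^T J_inv) / (dx^T J_inv dg)
        J_inv_dg = J_inv @ delta_gx
        denom = delta_x @ J_inv_dg
        if abs(denom) > 1e-12:
            J_inv = J_inv + np.outer(delta_x - J_inv_dg, delta_x @ J_inv) / denom
        gx = gx_new
    return x, False                         # step budget exhausted, like is_valid = false
```

Updating the inverse Jacobian directly avoids a 3x3 solve per iteration, which is presumably why the kernel carries `J_inv_local` rather than the Jacobian itself.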



================================================
FILE: lib/models/deformers/fast_snarf/cuda/precompute/precompute.cpp
================================================
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <vector>
#include <c10/cuda/CUDAGuard.h>





void launch_precompute(const torch::Tensor &voxel_w, const torch::Tensor &tfs, torch::Tensor &voxel_d, torch::Tensor &voxel_J, const torch::Tensor &offset, const torch::Tensor &scale);

void precompute(const torch::Tensor &voxel_w, const torch::Tensor &tfs, torch::Tensor &voxel_d, torch::Tensor &voxel_J, const torch::Tensor &offset, const torch::Tensor &scale) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(voxel_w));

  launch_precompute(voxel_w, tfs, voxel_d, voxel_J, offset, scale);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("precompute", &precompute);
}


================================================
FILE: lib/models/deformers/fast_snarf/cuda/precompute/precompute_kernel.cu
================================================
#include "ATen/Functions.h"
#include "ATen/core/TensorAccessor.h"
#include "c10/cuda/CUDAException.h"
#include "c10/cuda/CUDAStream.h"

#include <ratio>
#include <vector>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/core/TensorBase.h>
#include <ATen/Dispatch.h>
#include <c10/macros/Macros.h>

#include <chrono>
using namespace std::chrono;

using namespace at;
using namespace at::cuda::detail;


template <typename scalar_t, typename index_t>
C10_LAUNCH_BOUNDS_1(512)
__global__ void precompute_kernel(
                              const index_t npoints,
                              const index_t d,
                              const index_t h,
                              const index_t w,
                               PackedTensorAccessor32<scalar_t, 5> voxel_w, // skinning weights, shape=(1, J, D, H, W); only batch 0 is read
                               PackedTensorAccessor32<scalar_t, 4> tfs,     // bone transforms, shape=(N, J, 4, 4)
                               PackedTensorAccessor32<scalar_t, 5> voxel_d, // output: deformed grid positions, shape=(N, 3, D, H, W)
                               PackedTensorAccessor32<scalar_t, 5> voxel_J, // output: blended 3x4 transforms, shape=(N, 12, D, H, W)
                               PackedTensorAccessor32<scalar_t, 3> offset,  // shape=(N, 1, 3); only index 0 is read
                               PackedTensorAccessor32<scalar_t, 3> scale    // shape=(N, 1, 3); only index 0 is read
                               ) 
{

  index_t index = blockIdx.x * blockDim.x + threadIdx.x;
  if(index >= npoints) return;

  index_t idx_b = index / (d*h*w);
  index_t idx_d = index % (d*h*w) / (h*w);
  index_t idx_h = index % (d*h*w) % (h*w) / w;
  index_t idx_w = index % (d*h*w) % (h*w) % w;

  // Map the voxel index to normalized grid coordinates in [-1, 1], then undo
  // the grid normalization x_norm = (x + offset) * scale.
  scalar_t coord_x = ( ((scalar_t)idx_w) / (w-1) * 2 -1) / scale[0][0][0] - offset[0][0][0];
  scalar_t coord_y = ( ((scalar_t)idx_h) / (h-1) * 2 -1) / scale[0][0][1] - offset[0][0][1];
  scalar_t coord_z = ( ((scalar_t)idx_d) / (d-1) * 2 -1) / scale[0][0][2] - offset[0][0][2];

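  // Blend the per-bone 3x4 affine transforms with this voxel's skinning
  // weights. The bone count is hard-coded to 24 (the SMPL joint count).
  scalar_t J[12];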

  for(index_t i0 = 0; i0 < 3; i0++){
    for(index_t i1 = 0; i1 < 4; i1++){
      J[i0*4 + i1] = 0;
      for(index_t j = 0; j < 24; j++){
        J[i0*4 + i1] += voxel_w[0][j][idx_d][idx_h][idx_w]*tfs[idx_b][j][i0][i1];
      }
    }
  }
  for(index_t i0 = 0; i0 < 3; i0++){
    for(index_t i1 = 0; i1 < 4; i1++){
      voxel_J[idx_b][i0*4 + i1][idx_d][idx_h][idx_w] = J[i0*4 + i1];
    }
  }

  // Deformed position of this canonical grid point: xd = J[:, :3] * x + J[:, 3].
  for(index_t i0 = 0; i0 < 3; i0++){
    scalar_t xi = J[i0*4 + 0]*coord_x + J[i0*4 + 1]*coord_y + J[i0*4 + 2]*coord_z + J[i0*4 + 3];
    voxel_d[idx_b][i0][idx_d][idx_h][idx_w] = xi;
  }
}

void launch_precompute(
    const Tensor &voxel_w,
    const Tensor &tfs,
    Tensor &voxel_d,
    Tensor &voxel_J,
    const Tensor &offset,
    const Tensor &scale) {

  // calculate #threads required
  int64_t n_batch = voxel_d.size(0);
  int64_t d = voxel_d.size(2);
  int64_t h = voxel_d.size(3);
  int64_t w = voxel_d.size(4);

  int64_t count = n_batch*d*h*w;


  if (count > 0) {
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(voxel_w.scalar_type(), "precompute", [&] {
      precompute_kernel
        <<<GET_BLOCKS(count, 512), 512, 0,
        at::cuda::getCurrentCUDAStream()>>>(static_cast<int>(count),
                                            static_cast<int>(d),
                                            static_cast<int>(h),
                                            static_cast<int>(w),
                                            voxel_w.packed_accessor32<scalar_t, 5>(),
                                            tfs.packed_accessor32<scalar_t, 4>(),
                                            voxel_d.packed_accessor32<scalar_t, 5>(),
                                            voxel_J.packed_accessor32<scalar_t, 5>(),
                                            offset.packed_accessor32<scalar_t, 3>(),
                                            scale.packed_accessor32<scalar_t, 3>()
                                            );

      C10_CUDA_KERNEL_LAUNCH_CHECK();

    });
  }

  cudaDeviceSynchronize();
}
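The precompute kernel can be mirrored in NumPy as a CPU reference for checking the CUDA output on small grids. This is a sketch, not part of the repository: `precompute_reference` is a hypothetical name, it assumes `voxel_w` has shape (1, J, D, H, W) and `tfs` has shape (B, J, 4, 4) as the kernel indexes them, and it blends all J bones rather than the kernel's hard-coded 24.

```python
import numpy as np


def precompute_reference(voxel_w, tfs, offset, scale):
    """Hypothetical CPU reference for precompute_kernel (not part of the repo).

    voxel_w: (1, J, D, H, W) skinning weights (only batch 0 is used, as in the kernel)
    tfs:     (B, J, 4, 4) bone transforms
    offset, scale: (1, 1, 3) grid normalization parameters
    Returns voxel_d (B, 3, D, H, W) and voxel_J (B, 12, D, H, W).
    """
    _, _, d, h, w = voxel_w.shape
    b = tfs.shape[0]
    # i / (n - 1) * 2 - 1 for i in range(n), i.e. align-corners grid coordinates
    zs = np.linspace(-1.0, 1.0, d)
    ys = np.linspace(-1.0, 1.0, h)
    xs = np.linspace(-1.0, 1.0, w)
    zz, yy, xx = np.meshgrid(zs, ys, xs, indexing="ij")   # each (D, H, W)
    coords = np.stack([xx, yy, zz], axis=-1)              # (D, H, W, 3), x along W
    # Undo the normalization x_norm = (x + offset) * scale
    coords = coords / scale[0, 0] - offset[0, 0]
    # Blend the top 3x4 block of every bone transform with the skinning weights
    J = np.einsum("jdhw,bjkl->bdhwkl", voxel_w[0], tfs[:, :, :3, :4])  # (B, D, H, W, 3, 4)
    # Apply the blended affine transform to each grid point
    voxel_d = np.einsum("bdhwkl,dhwl->bdhwk", J[..., :3], coords) + J[..., 3]
    voxel_d = voxel_d.transpose(0, 4, 1, 2, 3)                     # (B, 3, D, H, W)
    # Channel k * 4 + l holds J[k, l], matching the kernel's i0 * 4 + i1 layout
    voxel_J = J.reshape(b, d, h, w, 12).transpose(0, 4, 1, 2, 3)   # (B, 12, D, H, W)
    return voxel_d, voxel_J
```

To compare against the extension, one could fill `voxel_d` and `voxel_J` with `precompute_cuda.precompute(voxel_w, tfs, voxel_d, voxel_J, offset, scale)`, move them to the CPU, and check them against this reference with a float32-level tolerance.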


================================================
FILE: lib/models/deformers/fast_snarf/lib/model/deformer_smpl.py
================================================
import torch
from torch import einsum
import torch.nn.functional as F
import os

from torch.utils.cpp_extension import load

import fuse_cuda 
import filter_cuda
import precompute_cuda
import numpy as np


class ForwardDeformer(torch.nn.Module):
    """
    Tensor shape abbreviation:
        B: batch size
        N: number of points
        J: number of bones
        I: number of init
        D: space dimension
    """

    def __init__(self,  **kwargs):
        super().__init__()

        self.soft_blend = 20

        self.init_bones = [0, 1, 2, 4, 5, 12, 15, 16, 17, 18, 19]
        
        self.init_bones_cuda = torch.tensor(self.init_bones).int()
        
        self.global_scale = 1.2
        

    def forward(self, xd, cond, mask, tfs, eval_mode=False):
        """Given deformed point return its caonical correspondence
        Args:
            xd (tensor): deformed points in batch. shape: [B, N, D]
            cond (dict): conditional input.
            tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1]
        Returns:
            xc (tensor): canonical correspondences. shape: [B, N, I, D]
            others (dict): other useful outputs.
        """

        xc_opt, others = self.search(xd, cond, mask, tfs, eval_mode=True)


        if eval_mode:
            return xc_opt, others


    def precompute(self, tfs):

        b, c, d, h, w = tfs.shape[0], 3, self.resolution//4, self.resolution, self.resolution
        voxel_d = torch.zeros((b,3,d,h,w), device=tfs.device)
        voxel_J = torch.zeros((b,12,d,h,w), device=tfs.device)
        precompute_cuda.precompute(self.lbs_voxel_final, tfs, voxel_d, voxel_J, self.offset_kernel, self.scale_kernel)
        self.voxel_d = voxel_d
        self.voxel_J = voxel_J

    def search(self, xd, cond, mask, tfs, eval_mode=False):
        """Search correspondences.

        Args:
            xd (tensor): deformed points in batch. shape: [B, N, D]
            xc_init (tensor): deformed points in batch. shape: [B, N, I, D]
            cond (dict): conditional input.
            tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1]

        Returns:
            xc_opt (tensor): canonoical correspondences of xd. shape: [B, N, I, D]
            valid_ids (tensor): identifiers of converged points. [B, N, I]
        """

        # reshape to [B,?,D] for other functions

        # run broyden without grad
        with torch.no_grad():
            result = self.broyden_cuda(xd, self.voxel_d, self.voxel_J, tfs, mask)

        return result['result'], result

    def broyden_cuda(self,
                    xd_tgt,
                    voxel,
                    voxel_J_inv,
                    tfs,
                    mask,
                    cvg_thresh=2e-4,
                    dvg_thresh=1):
        """
        Args:
            g:     f: (N, 3, 1) -> (N, 3, 1)
            x:     (N, 3, 1)
            J_inv: (N, 3, 3)
        """
        b,n,_ = xd_tgt.shape
        n_init = self.init_bones_cuda.shape[0]

        xc_init_IN = torch.zeros((b,n,n_init,3),device=xd_tgt.device,dtype=torch.float)

        is_valid = mask.expand(b,n,n_init).clone()

        if self.init_bones_cuda.device != xd_tgt.device:
            self.init_bones_cuda = self.init_bones_cuda.to(xd_tgt.device)
        fuse_cuda.fuse_broyden(xc_init_IN, xd_tgt, voxel, voxel_J_inv, tfs, self.init_bones_cuda, True, is_valid, self.offset_kernel, self.scale_kernel, cvg_thresh, dvg_thresh)

        is_valid_new = torch.zeros_like(is_valid)
        filter_cuda.filter(xc_init_IN, is_valid, is_valid_new)

        return {"result": xc_init_IN, 'valid_ids': is_valid_new} #, 'J_inv': J_inv_init_IN}


    def forward_skinning(self, xc, cond, tfs, mask=None):
        """Canonical point -> deformed point

        Args:
            xc (tensor): canonical points in batch. shape: [B, N, D]
            cond (dict): conditional input.
            tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1]

        Returns:
   
SYMBOL INDEX (260 symbols across 29 files)

FILE: data_processing/prepare_cache.py
  function parse_args (line 15) | def parse_args():
  function prepare_dataset (line 45) | def prepare_dataset(video_dir, output_dir, prefix, max_total_videos=20000):

FILE: data_processing/visualize_samples.py
  function init_smplx_model (line 11) | def init_smplx_model():

FILE: dataset/visualize_samples.py
  function init_smplx_model (line 11) | def init_smplx_model():

FILE: lib/datasets/avatar_dataset.py
  function load_pose (line 20) | def load_pose(path):
  function load_npy (line 30) | def load_npy(file_path):
  function load_smpl (line 33) | def load_smpl(path, smpl_type='smpl'):
  class AvatarDataset (line 63) | class AvatarDataset(Dataset):
    method __init__ (line 64) | def __init__(self,
    method load_scenes (line 149) | def load_scenes(self):
    method parse_scene (line 180) | def parse_scene(self, scene_id):
    method __len__ (line 454) | def __len__(self):
    method __getitem__ (line 457) | def __getitem__(self, scene_id):
    method parse_scene_test (line 466) | def parse_scene_test(self, scene_id):
  function gather_imgs (line 583) | def gather_imgs(img_ids, poses, image_paths_or_video, smpl_params, load_...
  function calculate_angle (line 644) | def calculate_angle(vector1, vector2):
  function axis_angle_to_rotation_matrix (line 650) | def axis_angle_to_rotation_matrix(axis_angle):
  function find_front_camera_by_rotation (line 676) | def find_front_camera_by_rotation(poses, global_orient_human):
  function read_frames (line 693) | def read_frames(video_path):
  function prepare_camera (line 711) | def prepare_camera( resolution_x, resolution_y, num_views=24, stides=1):
  function from_video_to_get_ref_smplx (line 769) | def from_video_to_get_ref_smplx(video_path):
  function random_scale_and_crop (line 789) | def random_scale_and_crop(image: torch.Tensor, scale_range=(0.8, 1.2)) -...

FILE: lib/datasets/dataloader.py
  class DataModuleFromConfig (line 19) | class DataModuleFromConfig(pl.LightningDataModule):
    method __init__ (line 20) | def __init__(
    method setup (line 42) | def setup(self, stage):
    method train_dataloader (line 45) | def train_dataloader(self):
    method val_dataloader (line 57) | def val_dataloader(self):
    method test_dataloader (line 68) | def test_dataloader(self):

FILE: lib/humanlrm_wrapper_sa_v1.py
  function get_1d_rotary_pos_embed (line 28) | def get_1d_rotary_pos_embed(
  class FluxPosEmbed (line 81) | class FluxPosEmbed(torch.nn.Module):
    method __init__ (line 83) | def __init__(self, theta: int, axes_dim: [int]):
    method forward (line 88) | def forward(self, ids: torch.Tensor) -> torch.Tensor:
  class SapiensGS_SA_v1 (line 105) | class SapiensGS_SA_v1(pl.LightningModule):
    method __init__ (line 107) | def __init__(
    method get_default_smplx_params (line 276) | def get_default_smplx_params(self):
    method forward_decoder (line 302) | def forward_decoder(self, decoder, code, target_rgbs, cameras,
    method on_fit_start (line 311) | def on_fit_start(self):
    method forward (line 319) | def forward(self, data):
    method forward_image_to_uv (line 376) | def forward_image_to_uv(self, inputs_img, is_training=True):
    method compute_loss (line 410) | def compute_loss(self, render_out):
    method compute_metrics (line 481) | def compute_metrics(self, render_out):
    method new_on_before_optimizer_step (line 513) | def new_on_before_optimizer_step(self):
    method validation_step (line 519) | def validation_step(self, batch, batch_idx):
    method forward_nvPose (line 537) | def forward_nvPose(self, batch, smplx_given):
    method on_validation_epoch_end (line 563) | def on_validation_epoch_end(self): #
    method on_test_start (line 608) | def on_test_start(self):
    method on_test_epoch_end (line 612) | def on_test_epoch_end(self):
    method configure_optimizers (line 643) | def configure_optimizers(self):
    method training_step (line 665) | def training_step(self, batch, batch_idx):
    method test_step (line 701) | def test_step(self, batch, batch_idx):
    method on_test_start (line 732) | def on_test_start(self):
    method on_test_epoch_end (line 736) | def on_test_epoch_end(self):
  function weighted_mse_loss (line 769) | def weighted_mse_loss(render_images, target_images, weights):

FILE: lib/mmutils/initialize.py
  function constant_init (line 3) | def constant_init(module: nn.Module, val: float, bias: float = 0) -> None:
  function xavier_init (line 10) | def xavier_init(module: nn.Module,

FILE: lib/models/decoders/uvmaps_decoder_gender.py
  function ensure_dtype (line 25) | def ensure_dtype(input_tensor, target_dtype=torch.float32):
  class UVNDecoder_gender (line 35) | class UVNDecoder_gender(nn.Module):
    method __init__ (line 44) | def __init__(self,
    method init_weights (line 213) | def init_weights(self):
    method extract_pcd (line 224) | def extract_pcd(self, code, smpl_params, init=False, zeros_hands_off=F...
    method deform_pcd (line 253) | def deform_pcd(self, code, smpl_params, init=False, zeros_hands_off=Fa...
    method _sample_feature (line 280) | def _sample_feature(self,results,):
    method _decode_feature (line 303) | def _decode_feature(self, point_code, init=False):
    method _decode (line 334) | def _decode(self, point_code, init=False):
    method gaussian_render (line 371) | def gaussian_render(self, pcd, sigmas, rgbs, normals, rot, num_scenes,...
    method visualize (line 424) | def visualize(self, code, scene_name, viz_dir, code_range=[-1, 1]):
    method forward (line 434) | def forward(self, code, smpl_params, cameras, num_imgs,
    method forward_render (line 527) | def forward_render(self, code, cameras, num_imgs,
    method forward_testing_time (line 613) | def forward_testing_time(self, code, smpl_params, cameras, num_imgs,

FILE: lib/models/decoders/vit_head.py
  class VitHead (line 6) | class VitHead(nn.Module):
    method __init__ (line 7) | def __init__(self,
    method _make_conv_layers (line 52) | def _make_conv_layers(self, in_channels: int,
    method _make_deconv_layers (line 66) | def _make_deconv_layers(self, in_channels: int,
    method forward (line 91) | def forward(self, inputs):

FILE: lib/models/deformers/fast_snarf/cuda/filter/filter.cpp
  function filter (line 13) | void filter(const torch::Tensor &x, const torch::Tensor &mask, torch::Te...
  function PYBIND11_MODULE (line 18) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

FILE: lib/models/deformers/fast_snarf/cuda/fuse_kernel/fuse_cuda.cpp
  function fuse_broyden (line 25) | void fuse_broyden(torch::Tensor &x,
  function PYBIND11_MODULE (line 47) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

FILE: lib/models/deformers/fast_snarf/cuda/precompute/precompute.cpp
  function precompute (line 12) | void precompute(const torch::Tensor &voxel_w, const torch::Tensor &tfs, ...
  function PYBIND11_MODULE (line 18) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

FILE: lib/models/deformers/fast_snarf/lib/model/deformer_smpl.py
  class ForwardDeformer (line 14) | class ForwardDeformer(torch.nn.Module):
    method __init__ (line 24) | def __init__(self,  **kwargs):
    method forward (line 36) | def forward(self, xd, cond, mask, tfs, eval_mode=False):
    method precompute (line 54) | def precompute(self, tfs):
    method search (line 63) | def search(self, xd, cond, mask, tfs, eval_mode=False):
    method broyden_cuda (line 85) | def broyden_cuda(self,
    method forward_skinning (line 116) | def forward_skinning(self, xc, cond, tfs, mask=None):
    method switch_to_explicit (line 135) | def switch_to_explicit(self,resolution=32,smpl_verts=None, smpl_faces=...
    method update_lbs_voxel (line 189) | def update_lbs_voxel(self):
    method query_sdf_smpl (line 199) | def query_sdf_smpl(self, x, smpl_verts, smpl_faces, smpl_weights):
    method skinning_normal (line 221) | def skinning_normal(self, xc, normal, tfs, cond=None, mask=None, inver...
  function skinning_mask (line 242) | def skinning_mask(x, w, tfs, inverse=False):
  function skinning (line 271) | def skinning(x, w, tfs, inverse=False):
  function fast_inverse (line 300) | def fast_inverse(T):
  function bmv (line 317) | def bmv(m, v):
  function create_voxel_grid (line 320) | def create_voxel_grid(d, h, w, device='cuda'):
  function query_weights_smpl (line 329) | def query_weights_smpl(x, smpl_verts, smpl_weights):

FILE: lib/models/deformers/fast_snarf/lib/model/deformer_smplx.py
  class ForwardDeformer (line 14) | class ForwardDeformer(torch.nn.Module):
    method __init__ (line 24) | def __init__(self,  **kwargs):
    method forward_skinning (line 35) | def forward_skinning(self, xc, shape_offset, pose_offset, cond, tfs, t...
    method switch_to_explicit (line 78) | def switch_to_explicit(self,resolution=32,smpl_verts=None, smpl_faces=...
    method update_lbs_voxel (line 132) | def update_lbs_voxel(self):
    method query_sdf_smpl (line 142) | def query_sdf_smpl(self, x, smpl_verts, smpl_faces, smpl_weights):
    method skinning_normal (line 164) | def skinning_normal(self, xc, normal, tfs, cond=None, mask=None, inver...
  function skinning_mask (line 185) | def skinning_mask(x, w, tfs, inverse=False):
  function skinning (line 214) | def skinning(x, w, tfs, inverse=False):
  function fast_inverse (line 243) | def fast_inverse(T):
  function bmv (line 260) | def bmv(m, v):
  function create_voxel_grid (line 264) | def create_voxel_grid(d, h, w, device='cuda'):
  function query_weights_smpl (line 273) | def query_weights_smpl(x, smpl_verts, smpl_weights):

FILE: lib/models/deformers/smplx/body_models.py
  class SMPL (line 39) | class SMPL(nn.Module):
    method __init__ (line 44) | def __init__(
    method num_betas (line 260) | def num_betas(self):
    method num_expression_coeffs (line 264) | def num_expression_coeffs(self):
    method create_mean_pose (line 267) | def create_mean_pose(self, data_struct) -> Tensor:
    method name (line 270) | def name(self) -> str:
    method reset_params (line 274) | def reset_params(self, **params_dict) -> None:
    method get_num_verts (line 281) | def get_num_verts(self) -> int:
    method get_num_faces (line 284) | def get_num_faces(self) -> int:
    method extra_repr (line 287) | def extra_repr(self) -> str:
    method forward_shape (line 295) | def forward_shape(
    method forward (line 303) | def forward(
  class SMPLH (line 405) | class SMPLH(SMPL):
    method __init__ (line 412) | def __init__(
    method create_mean_pose (line 570) | def create_mean_pose(self, data_struct, flat_hand_mean=False):
    method name (line 582) | def name(self) -> str:
    method extra_repr (line 585) | def extra_repr(self):
    method forward (line 593) | def forward(
  class SMPLX (line 661) | class SMPLX(SMPLH):
    method __init__ (line 677) | def __init__(
    method name (line 859) | def name(self) -> str:
    method num_expression_coeffs (line 863) | def num_expression_coeffs(self):
    method create_mean_pose (line 866) | def create_mean_pose(self, data_struct, flat_hand_mean=False):
    method extra_repr (line 884) | def extra_repr(self):
    method forward (line 892) | def forward(

FILE: lib/models/deformers/smplx/lbs.py
  function find_dynamic_lmk_idx_and_bcoords (line 30) | def find_dynamic_lmk_idx_and_bcoords(
  function vertices2landmarks (line 108) | def vertices2landmarks(
  function lbs (line 152) | def lbs(
  function vertices2joints (line 251) | def vertices2joints(J_regressor: Tensor, vertices: Tensor) -> Tensor:
  function blend_shapes (line 271) | def blend_shapes(betas: Tensor, shape_disps: Tensor) -> Tensor:
  function batch_rodrigues (line 295) | def batch_rodrigues(
  function transform_mat (line 332) | def transform_mat(R: Tensor, t: Tensor) -> Tensor:
  function batch_rigid_transform (line 345) | def batch_rigid_transform(

FILE: lib/models/deformers/smplx/utils.py
  class ModelOutput (line 28) | class ModelOutput:
    method __getitem__ (line 36) | def __getitem__(self, key):
    method get (line 39) | def get(self, key, default=None):
    method __iter__ (line 42) | def __iter__(self):
    method keys (line 45) | def keys(self):
    method values (line 49) | def values(self):
    method items (line 53) | def items(self):
  class SMPLOutput (line 59) | class SMPLOutput(ModelOutput):
  class SMPLHOutput (line 69) | class SMPLHOutput(SMPLOutput):
  class SMPLXOutput (line 76) | class SMPLXOutput(SMPLHOutput):
  function find_joint_kin_chain (line 81) | def find_joint_kin_chain(joint_id, kinematic_tree):
  function to_tensor (line 90) | def to_tensor(
  class Struct (line 99) | class Struct(object):
    method __init__ (line 100) | def __init__(self, **kwargs):
  function to_np (line 105) | def to_np(array, dtype=np.float32):
  function rot_mat_to_euler (line 111) | def rot_mat_to_euler(rot_mats):

FILE: lib/models/deformers/smplx/vertex_joint_selector.py
  class VertexJointSelector (line 29) | class VertexJointSelector(nn.Module):
    method __init__ (line 31) | def __init__(self, vertex_ids=None,
    method forward (line 73) | def forward(self, vertices, joints):

FILE: lib/models/deformers/smplx_deformer_gender.py
  class SMPLXDeformer_gender (line 12) | class SMPLXDeformer_gender(torch.nn.Module):
    method __init__ (line 14) | def __init__(self, gender, is_sub2=False) -> None:
    method initialize (line 68) | def initialize(self):
    method forword_body_model (line 125) | def forword_body_model(self, smpl_params, point_pool=4):
    method prepare_deformer (line 150) | def prepare_deformer(self, smpl_params=None, num_scenes=1, device=None):
    method __call__ (line 215) | def __call__(self, pts_in, rot_in, mask=None, cano=True, offset_gs=Non...

FILE: lib/models/renderers/gau_renderer.py
  function batch_rodrigues (line 9) | def batch_rodrigues(rot_vecs, epsilon = 1e-8):
  function build_scaling_rotation (line 42) | def build_scaling_rotation(s, r, tfs):
  function strip_symmetric (line 54) | def strip_symmetric(sym):
  function strip_lowerdiag (line 57) | def strip_lowerdiag(L):
  function build_covariance_from_scaling_rotation (line 68) | def build_covariance_from_scaling_rotation(scaling, scaling_modifier, ro...
  function build_rotation (line 74) | def build_rotation(r):
  function get_covariance (line 97) | def get_covariance(scaling, rotation, scaling_modifier = 1):
  class GRenderer (line 105) | class GRenderer(nn.Module):
    method __init__ (line 106) | def __init__(self, image_size=256, anti_alias=False, f=5000, near=0.01...
    method prepare (line 126) | def prepare(self, cameras):
    method render_gaussian (line 165) | def render_gaussian(self, means3D, colors_precomp, rotations, opacitie...
  function get_view_matrix (line 195) | def get_view_matrix(R, t):
  function get_proj_yy (line 200) | def get_proj_yy(f, image_size, far, near):
  function get_proj_matrix (line 206) | def get_proj_matrix(fovY,fovX, z_near, z_far, z_sign):
  function get_fov (line 226) | def get_fov(focal, princpt, img_shape):

FILE: lib/models/sapiens/sapiens_wrapper_torchscipt.py
  function pretrain_forward (line 23) | def pretrain_forward(sp_lite, inputs: Tensor, layer_num: int, return_hid...
  class SapiensWrapper_ts (line 63) | class SapiensWrapper_ts(nn.Module):
    method __init__ (line 68) | def __init__(self,
    method forward (line 94) | def forward(self, image, use_my_proces=False, requires_grad=False, out...
    method _freeze (line 139) | def _freeze(self):

FILE: lib/models/transformer_sa/mae_decoder_v3_skip.py
  function get_2d_sincos_pos_embed (line 9) | def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
  function get_2d_sincos_pos_embed_from_grid (line 37) | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
  function get_1d_sincos_pos_embed_from_grid (line 49) | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
  class neck_SA_v3_skip (line 71) | class neck_SA_v3_skip(nn.Module):
    method __init__ (line 72) | def __init__(self, patch_size=4, in_chans=32, num_patches=196, embed_d...
    method initialize_weights (line 106) | def initialize_weights(self):
    method _init_weights (line 116) | def _init_weights(self, m):
    method forward_decoder (line 126) | def forward_decoder(self, in_features, ids_restore):
    method forward (line 160) | def forward(self, encoded_latent, ids_restore):

FILE: lib/ops/activation.py
  class _trunc_exp (line 8) | class _trunc_exp(Function):
    method forward (line 11) | def forward(ctx, x):
    method backward (line 18) | def backward(ctx, g):
  class TruncExp (line 26) | class TruncExp(nn.Module):
    method forward (line 29) | def forward(x):

FILE: lib/utils/infer_util.py
  function reset_first_frame_rotation (line 26) | def reset_first_frame_rotation(root_orient, trans):
  function rotation_matrix_to_rodrigues (line 73) | def rotation_matrix_to_rodrigues(rotation_matrices):
  function get_hand_pose_mean (line 82) | def get_hand_pose_mean():
  function load_smplify_json (line 105) | def load_smplify_json(smplx_smplify_path):
  function load_image (line 144) | def load_image(input_path, output_folder, image_frame_ratio=None):
  function prepare_camera (line 191) | def prepare_camera( resolution_x = 640, resolution_y = 640, focal_length...
  function construct_camera (line 239) | def construct_camera(K, cam_list, device='cuda'):
  function get_name_str (line 266) | def get_name_str(name):
  function load_smplx_from_npy (line 272) | def load_smplx_from_npy(smplx_path, device='cuda'):
  function add_root_rotate_to_smplx (line 303) | def add_root_rotate_to_smplx(smpl_tmp, frames_num=180, device='cuda'):
  function load_smplx_from_json (line 328) | def load_smplx_from_json(smplx_path, device='cuda'):
  function get_image_dimensions (line 354) | def get_image_dimensions(input_path):
  function construct_camera_from_motionx (line 358) | def construct_camera_from_motionx(smplx_path, device='cuda'):
  function remove_background (line 385) | def remove_background(image: PIL.Image.Image,
  function resize_foreground (line 399) | def resize_foreground(
  function images_to_video (line 440) | def images_to_video(
  function save_video (line 461) | def save_video(

FILE: lib/utils/mesh.py
  function dot (line 7) | def dot(x, y):
  function length (line 11) | def length(x, eps=1e-20):
  function safe_normalize (line 15) | def safe_normalize(x, eps=1e-20):
  class Mesh (line 18) | class Mesh:
    method __init__ (line 19) | def __init__(
    method load (line 47) | def load(cls, path=None, resize=True, renormal=True, retex=False, fron...
    method load_obj (line 100) | def load_obj(cls, path, albedo_path=None, device=None):
    method load_trimesh (line 238) | def load_trimesh(cls, path, device=None):
    method aabb (line 325) | def aabb(self):
    method auto_size (line 330) | def auto_size(self):
    method auto_normal (line 336) | def auto_normal(self):
    method auto_uv (line 359) | def auto_uv(self, cache_path=None, vmap=True):
    method align_v_to_vt (line 392) | def align_v_to_vt(self, vmapping=None):
    method to (line 407) | def to(self, device):
    method write (line 415) | def write(self, path):
    method write_ply (line 426) | def write_ply(self, path):
    method write_glb (line 435) | def write_glb(self, path):
    method write_obj (line 568) | def write_obj(self, path):

FILE: lib/utils/mesh_utils.py
  function gaussian_3d_coeff (line 6) | def gaussian_3d_coeff(xyzs, covs):
  function poisson_mesh_reconstruction (line 27) | def poisson_mesh_reconstruction(points, normals=None):
  function decimate_mesh (line 67) | def decimate_mesh(
  function clean_mesh (line 111) | def clean_mesh(
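`gaussian_3d_coeff(xyzs, covs)` in lib/utils/mesh_utils.py takes point offsets (`# xyzs: [N, 3]` per the preview). The comment describing `covs` is cut off. Suppose each Gaussian's symmetric 3 x 3 covariance is packed as its 6 upper-triangle entries. Then the per-point weight is exp(-0.5 x^T Sigma^-1 x). The sketch below computes that weight with `np.linalg.inv`. The 6-entry ordering (xx, xy, xz, yy, yz, zz) and the function name are assumptions, not the repository's code.

```python
import numpy as np

def gaussian_3d_coeff_sketch(xyzs, covs):
    """Unnormalised Gaussian weight exp(-0.5 * x^T Sigma^-1 x) for each point.

    xyzs: [N, 3] offsets from each Gaussian's centre.
    covs: [N, 6] upper triangle of Sigma as (xx, xy, xz, yy, yz, zz); assumed layout.
    """
    a, b, c, d, e, f = (covs[:, k] for k in range(6))
    sigma = np.stack([
        np.stack([a, b, c], axis=-1),
        np.stack([b, d, e], axis=-1),
        np.stack([c, e, f], axis=-1),
    ], axis=-2)  # [N, 3, 3] symmetric covariance matrices
    inv = np.linalg.inv(sigma)  # [N, 3, 3]
    power = -0.5 * np.einsum("ni,nij,nj->n", xyzs, inv, xyzs)  # Mahalanobis term
    return np.exp(power)
```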

FILE: lib/utils/train_util.py
  function main_print (line 7) | def main_print(*args):
  function count_params (line 12) | def count_params(model, verbose=False):
  function instantiate_from_config (line 19) | def instantiate_from_config(config):
  function get_obj_from_str (line 29) | def get_obj_from_str(string, reload=False):
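lib/utils/train_util.py imports `importlib`. configs/test_dataset.yaml names its data module with a `target:` dotted path (`lib.datasets.dataloader.DataModuleFromConfig`) and a `params:` mapping. This matches the widely used `instantiate_from_config` / `get_obj_from_str` pattern. The sketch below assumes IDOL follows that pattern; the actual function bodies are not shown in this extraction.

```python
import importlib

def get_obj_from_str(string, reload=False):
    # "package.module.Name" -> the object called Name in that module
    module, name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module))
    return getattr(importlib.import_module(module), name)

def instantiate_from_config(config):
    # config: mapping with a dotted "target" path and optional "params" kwargs
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))
```

Consider the dataset section of configs/test_dataset.yaml loaded as a dict, for example `{"target": "lib.datasets.dataloader.DataModuleFromConfig", "params": {"batch_size": 1, "num_workers": 2, ...}}` (abridged). The sketch would import `lib.datasets.dataloader` and call `DataModuleFromConfig` with `params` as keyword arguments.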

FILE: run_demo.py
  function parse_args (line 20) | def parse_args():
  function process_data_on_gpu (line 44) | def process_data_on_gpu(args, model, gpu_id, img_paths_list, smplx_ref_p...
  function setup_directories (line 237) | def setup_directories(base_path, config_name):
  function main (line 248) | def main():
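run_demo.py passes a `gpu_id` and an `img_paths_list` to `process_data_on_gpu`. Its preview pins `CUDA_VISIBLE_DEVICES` to "0", with an eight-GPU setting commented out. This extraction does not show how `main` divides the inputs. The hypothetical helper below sketches one even partition that would give each GPU worker a contiguous slice of the image list.

```python
def split_across_gpus(img_paths, num_gpus):
    """Hypothetical helper: split paths into num_gpus contiguous, near-equal
    chunks so that chunk k can be handed to a worker on GPU k."""
    if num_gpus < 1:
        raise ValueError("num_gpus must be >= 1")
    base, extra = divmod(len(img_paths), num_gpus)
    chunks, start = [], 0
    for gpu_id in range(num_gpus):
        size = base + (1 if gpu_id < extra else 0)  # the first `extra` chunks get one more item
        chunks.append(list(img_paths[start:start + size]))
        start += size
    return chunks
```

With the default single visible GPU, `num_gpus=1` keeps the whole list in one chunk.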

FILE: train.py
  function get_parser (line 34) | def get_parser(**parser_kwargs):
  class SetupCallback (line 134) | class SetupCallback(Callback):
    method __init__ (line 135) | def __init__(self, resume, logdir, ckptdir, cfgdir, config):
    method on_fit_start (line 143) | def on_fit_start(self, trainer, pl_module):
  class CodeSnapshot (line 156) | class CodeSnapshot(Callback):
    method __init__ (line 160) | def __init__(self, savedir, exclude_patterns=None):
    method get_file_list (line 167) | def get_file_list(self):
    method _match_pattern (line 194) | def _match_pattern(self, file_path, pattern):
    method save_code_snapshot (line 210) | def save_code_snapshot(self):
    method on_fit_start (line 218) | def on_fit_start(self, trainer, pl_module):
  function load_folder_ckpt (line 428) | def load_folder_ckpt(checkpoint_dir):
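`CodeSnapshot` in train.py accepts `exclude_patterns` and defines `_match_pattern(file_path, pattern)`. Presumably it uses this method to filter the output of `get_file_list` before saving a copy of the code, but the bodies are not shown. The sketch below is a hypothetical stand-in built on the standard-library `fnmatch`. A pattern ending in `/` excludes any path that contains that directory, and any other pattern is treated as a shell-style glob. The example patterns are taken from the repository's .gitignore preview.

```python
import fnmatch
import os

def match_pattern(file_path, pattern):
    """Hypothetical stand-in for CodeSnapshot._match_pattern."""
    path = file_path.replace(os.sep, "/")
    if pattern.endswith("/"):
        # directory pattern: match it as a whole path component anywhere in the path
        return f"/{pattern}" in f"/{path}"
    name = path.rsplit("/", 1)[-1]
    # glob pattern: try the full relative path, then just the file name
    return fnmatch.fnmatch(path, pattern) or fnmatch.fnmatch(name, pattern)

def filter_snapshot_files(paths, exclude_patterns):
    # keep only files that match none of the exclude patterns
    return [p for p in paths if not any(match_pattern(p, pat) for pat in exclude_patterns)]
```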
Condensed preview: 58 files, each showing path, character count, and a content snippet.
[
  {
    "path": ".gitignore",
    "chars": 325,
    "preview": "# Python cache files\r\n__pycache__/\r\n*.py[cod]\r\n*$py.class\r\n*.so\r\n\r\n# Project specific\r\nwork_dirs/\r\ntest/\r\nlib/models/def"
  },
  {
    "path": "README.md",
    "chars": 11054,
    "preview": "# **IDOL: Instant Photorealistic 3D Human Creation from a Single Image**  \r\n\r\n[![Website](https://img.shields.io/badge/P"
  },
  {
    "path": "configs/idol_debug.yaml",
    "chars": 4445,
    "preview": "\ndebug: True\n# code_size: [32, 256, 256]\ncode_size: [32, 1024, 1024]\nmodel:\n  # base_learning_rate: 2.0e-04 # yy Need to"
  },
  {
    "path": "configs/idol_v0.yaml",
    "chars": 4411,
    "preview": "\ndebug: True\n# code_size: [32, 256, 256]\ncode_size: [32, 1024, 1024]\nmodel:\n  # base_learning_rate: 2.0e-04 # yy Need to"
  },
  {
    "path": "configs/test_dataset.yaml",
    "chars": 1580,
    "preview": "\ndataset:\n  target: lib.datasets.dataloader.DataModuleFromConfig\n  params:\n    batch_size: 1 \n    num_workers: 2 \n    # "
  },
  {
    "path": "data_processing/prepare_cache.py",
    "chars": 6536,
    "preview": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n\n\"\"\"\nData preparation script for DeepFashion video dataset.\nThis script pr"
  },
  {
    "path": "data_processing/process_datasets.sh",
    "chars": 1901,
    "preview": "#!/bin/bash\n\n# Data processing script for multiple datasets\n# This script processes all specified datasets and saves the"
  },
  {
    "path": "data_processing/visualize_samples.py",
    "chars": 4000,
    "preview": "\nimport torch\nimport numpy as np\nimport os\nos.environ[\"PYOPENGL_PLATFORM\"] = \"osmesa\"\nimport smplx\nimport trimesh\nimport"
  },
  {
    "path": "dataset/README.md",
    "chars": 5160,
    "preview": "# 🌟 HuGe100K Dataset Documentation\r\n\r\n## 📊 Dataset Overview\r\nHuGe100K is a large-scale multi-view human dataset featurin"
  },
  {
    "path": "dataset/visualize_samples.py",
    "chars": 3986,
    "preview": "\nimport torch\nimport numpy as np\nimport os\nos.environ[\"PYOPENGL_PLATFORM\"] = \"osmesa\"\nimport smplx\nimport trimesh\nimport"
  },
  {
    "path": "env/README.md",
    "chars": 1756,
    "preview": "# Environment Setup Guide\r\n\r\n## Prerequisites\r\n\r\n- Python 3.10\r\n- CUDA 11.8\r\n- PyTorch 2.3.1\r\n\r\n## Installation Steps\r\n\r"
  },
  {
    "path": "lib/__init__.py",
    "chars": 97,
    "preview": "from .models import *\nfrom .mmutils import *\nfrom .humanlrm_wrapper_sa_v1 import SapiensGS_SA_v1\n"
  },
  {
    "path": "lib/datasets/__init__.py",
    "chars": 86,
    "preview": "from .avatar_dataset import AvatarDataset\nfrom .dataloader import DataModuleFromConfig"
  },
  {
    "path": "lib/datasets/avatar_dataset.py",
    "chars": 40679,
    "preview": "import os\nimport random\nimport numpy as np\nimport torch\nimport json\nimport pickle\n\nfrom torch.utils.data import Dataset\n"
  },
  {
    "path": "lib/datasets/dataloader.py",
    "chars": 2251,
    "preview": "\nimport os, sys\nimport json\n\n\nimport numpy as np\nimport webdataset as wds\nimport pytorch_lightning as pl\nimport torch\nfr"
  },
  {
    "path": "lib/humanlrm_wrapper_sa_v1.py",
    "chars": 35472,
    "preview": "\nimport os\nimport math\nimport json\nfrom torch.optim import Adam\nfrom torch.nn.parallel.distributed import DistributedDat"
  },
  {
    "path": "lib/mmutils/__init__.py",
    "chars": 50,
    "preview": "from .initialize import xavier_init, constant_init"
  },
  {
    "path": "lib/mmutils/initialize.py",
    "chars": 863,
    "preview": "import torch.nn as nn\n\ndef constant_init(module: nn.Module, val: float, bias: float = 0) -> None:\n    if hasattr(module,"
  },
  {
    "path": "lib/models/__init__.py",
    "chars": 25,
    "preview": "\nfrom .decoders import *\n"
  },
  {
    "path": "lib/models/decoders/__init__.py",
    "chars": 88,
    "preview": "\nfrom .uvmaps_decoder_gender import UVNDecoder_gender\n\n__all__ = [ 'UVNDecoder_gender']\n"
  },
  {
    "path": "lib/models/decoders/uvmaps_decoder_gender.py",
    "chars": 31646,
    "preview": "import os\nimport matplotlib.pyplot as plt\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch"
  },
  {
    "path": "lib/models/decoders/vit_head.py",
    "chars": 4542,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom typing import Sequence, Tuple, Optional\n\nclass V"
  },
  {
    "path": "lib/models/deformers/__init__.py",
    "chars": 92,
    "preview": "\nfrom .smplx_deformer_gender import SMPLXDeformer_gender\n\n__all__ = ['SMPLXDeformer_gender']"
  },
  {
    "path": "lib/models/deformers/fast_snarf/cuda/filter/filter.cpp",
    "chars": 467,
    "preview": "#include <torch/extension.h>\n#include <ATen/ATen.h>\n#include <vector>\n#include <c10/cuda/CUDAGuard.h>\n\n\n\n\n\nvoid launch_f"
  },
  {
    "path": "lib/models/deformers/fast_snarf/cuda/filter/filter_kernel.cu",
    "chars": 2906,
    "preview": "#include <vector>\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/detail/TensorInfo.cuh>\n#include <ATen/cuda/deta"
  },
  {
    "path": "lib/models/deformers/fast_snarf/cuda/fuse_kernel/fuse_cuda.cpp",
    "chars": 1715,
    "preview": "#include \"ATen/Functions.h\"\n#include \"ATen/core/TensorBody.h\"\n#include <torch/extension.h>\n#include <ATen/ATen.h>\n#inclu"
  },
  {
    "path": "lib/models/deformers/fast_snarf/cuda/fuse_kernel/fuse_cuda_kernel.cu",
    "chars": 20733,
    "preview": "#include \"ATen/Functions.h\"\n#include \"ATen/core/TensorAccessor.h\"\n#include \"c10/cuda/CUDAException.h\"\n#include \"c10/cuda"
  },
  {
    "path": "lib/models/deformers/fast_snarf/cuda/precompute/precompute.cpp",
    "chars": 695,
    "preview": "#include <torch/extension.h>\n#include <ATen/ATen.h>\n#include <vector>\n#include <c10/cuda/CUDAGuard.h>\n\n\n\n\n\nvoid launch_p"
  },
  {
    "path": "lib/models/deformers/fast_snarf/cuda/precompute/precompute_kernel.cu",
    "chars": 4204,
    "preview": "#include \"ATen/Functions.h\"\n#include \"ATen/core/TensorAccessor.h\"\n#include \"c10/cuda/CUDAException.h\"\n#include \"c10/cuda"
  },
  {
    "path": "lib/models/deformers/fast_snarf/lib/model/deformer_smpl.py",
    "chars": 11948,
    "preview": "import torch\nfrom torch import einsum\nimport torch.nn.functional as F\nimport os\n\nfrom torch.utils.cpp_extension import l"
  },
  {
    "path": "lib/models/deformers/fast_snarf/lib/model/deformer_smplx.py",
    "chars": 9971,
    "preview": "import torch\nfrom torch import einsum\nimport torch.nn.functional as F\nimport os\n\nfrom torch.utils.cpp_extension import l"
  },
  {
    "path": "lib/models/deformers/smplx/__init__.py",
    "chars": 728,
    "preview": "# -*- coding: utf-8 -*-\n\n# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is\n# holder of all propri"
  },
  {
    "path": "lib/models/deformers/smplx/body_models.py",
    "chars": 45117,
    "preview": "#  -*- coding: utf-8 -*-\n\n# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is\n# holder of all propr"
  },
  {
    "path": "lib/models/deformers/smplx/joint_names.py",
    "chars": 4896,
    "preview": "# -*- coding: utf-8 -*-\n\n# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is\n# holder of all propri"
  },
  {
    "path": "lib/models/deformers/smplx/lbs.py",
    "chars": 13913,
    "preview": "# -*- coding: utf-8 -*-\n\n# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is\n# holder of all propri"
  },
  {
    "path": "lib/models/deformers/smplx/utils.py",
    "chars": 3326,
    "preview": "# -*- coding: utf-8 -*-\n\n# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is\n# holder of all propri"
  },
  {
    "path": "lib/models/deformers/smplx/vertex_ids.py",
    "chars": 2204,
    "preview": "# -*- coding: utf-8 -*-\n\n# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is\n# holder of all propri"
  },
  {
    "path": "lib/models/deformers/smplx/vertex_joint_selector.py",
    "chars": 2702,
    "preview": "# -*- coding: utf-8 -*-\n\n# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is\n# holder of all propri"
  },
  {
    "path": "lib/models/deformers/smplx_deformer_gender.py",
    "chars": 14596,
    "preview": "# Modified from Deformer of AG3D\n\nfrom .fast_snarf.lib.model.deformer_smplx import ForwardDeformer, skinning\nfrom .smplx"
  },
  {
    "path": "lib/models/renderers/__init__.py",
    "chars": 93,
    "preview": "\nfrom .gau_renderer import GRenderer, get_covariance, batch_rodrigues\n__all__ = ['GRenderer']"
  },
  {
    "path": "lib/models/renderers/gau_renderer.py",
    "chars": 8192,
    "preview": "from diff_gaussian_rasterization import (\n    GaussianRasterizationSettings,\n    GaussianRasterizer,\n)\nimport torch\nimpo"
  },
  {
    "path": "lib/models/sapiens/__init__.py",
    "chars": 57,
    "preview": "from .sapiens_wrapper_torchscipt import SapiensWrapper_ts"
  },
  {
    "path": "lib/models/sapiens/sapiens_wrapper_torchscipt.py",
    "chars": 5789,
    "preview": "# Copyright (c) 2023, Zexin He\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use thi"
  },
  {
    "path": "lib/models/transformer_sa/__init__.py",
    "chars": 48,
    "preview": "from .mae_decoder_v3_skip import neck_SA_v3_skip"
  },
  {
    "path": "lib/models/transformer_sa/mae_decoder_v3_skip.py",
    "chars": 6383,
    "preview": "import torch\nimport torch.nn as nn\nimport numpy as np\nfrom timm.models.vision_transformer import PatchEmbed, Block, chec"
  },
  {
    "path": "lib/ops/__init__.py",
    "chars": 33,
    "preview": "from .activation import TruncExp\n"
  },
  {
    "path": "lib/ops/activation.py",
    "chars": 637,
    "preview": "import math\nimport torch\nimport torch.nn as nn\nfrom torch.autograd import Function\nfrom torch.cuda.amp import custom_bwd"
  },
  {
    "path": "lib/utils/infer_util.py",
    "chars": 18522,
    "preview": "import os\nimport imageio\nimport rembg\nimport torch\nimport numpy as np\nimport PIL.Image\nfrom PIL import Image\nfrom typing"
  },
  {
    "path": "lib/utils/mesh.py",
    "chars": 23321,
    "preview": "import os\nimport cv2\nimport torch\nimport trimesh\nimport numpy as np\n\ndef dot(x, y):\n    return torch.sum(x * y, -1, keep"
  },
  {
    "path": "lib/utils/mesh_utils.py",
    "chars": 5142,
    "preview": "import numpy as np\nimport pymeshlab as pml\nimport torch\n\n\ndef gaussian_3d_coeff(xyzs, covs):\n    # xyzs: [N, 3]\n    # co"
  },
  {
    "path": "lib/utils/train_util.py",
    "chars": 1009,
    "preview": "import importlib\n\nimport os\n\nfrom pytorch_lightning.utilities import rank_zero_only\n@rank_zero_only\ndef main_print(*args"
  },
  {
    "path": "run_demo.py",
    "chars": 12797,
    "preview": "import os\nos.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0,1,2,3,4,5,6,7\"\nimport argpa"
  },
  {
    "path": "scripts/download_files.sh",
    "chars": 713,
    "preview": "#!/bin/bash\n\n# Create necessary directories\nmkdir -p work_dirs/\nmkdir -p work_dirs/ckpt\n\n# Download files from HuggingFa"
  },
  {
    "path": "scripts/fetch_template.sh",
    "chars": 2969,
    "preview": "#!/bin/bash\nurle () { [[ \"${1}\" ]] || return 1; local LANG=C i x; for (( i = 0; i < ${#1}; i++ )); do x=\"${1:i:1}\"; [[ \""
  },
  {
    "path": "scripts/pip_install.sh",
    "chars": 1894,
    "preview": "#!/bin/bash\n\n# Complete environment setup process\n\n# Step 0: Ensure you create a Conda environment \n# and Activate the e"
  },
  {
    "path": "setup.py",
    "chars": 709,
    "preview": "from setuptools import setup\r\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\r\nimport os\r\ncuda_dir ="
  },
  {
    "path": "train.py",
    "chars": 16200,
    "preview": "import os, sys\n# os.environ[\"WANDB_MODE\"] = \"dryrun\" # default setting to save locally\nfrom lib.utils.train_util import "
  }
]
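lib/models/renderers/__init__.py exports `get_covariance` and `batch_rodrigues` alongside `GRenderer`, and the renderer module imports `diff_gaussian_rasterization`. The function bodies are not shown here. The unbatched NumPy sketch below assumes both follow the standard 3D Gaussian splatting construction. A rotation matrix is built from an axis-angle vector with Rodrigues' formula, and the covariance is Sigma = R S S^T R^T with S = diag(scales). The names `rodrigues` and `covariance` are hypothetical.

```python
import numpy as np

def rodrigues(axis_angle):
    """Axis-angle [3] -> rotation matrix [3, 3] via Rodrigues' formula."""
    theta = np.linalg.norm(axis_angle)
    if theta < 1e-8:
        return np.eye(3)  # near-zero rotation
    k = axis_angle / theta  # unit rotation axis
    K = np.array([[0.0, -k[2], k[1]],
                  [k[2], 0.0, -k[0]],
                  [-k[1], k[0], 0.0]])  # skew-symmetric cross-product matrix
    return np.eye(3) + np.sin(theta) * K + (1.0 - np.cos(theta)) * (K @ K)

def covariance(axis_angle, scales):
    """Sigma = R S S^T R^T with S = diag(scales)."""
    R = rodrigues(np.asarray(axis_angle, dtype=float))
    M = R @ np.diag(scales)
    return M @ M.T
```

For example, a quarter turn about z maps the x-axis onto the y-axis. So scales (2, 1, 1) produce a covariance of diag(1, 4, 1).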

// ... and 1 more file (not shown)

About this extraction

This page contains the full source code of the yiyuzhuang/IDOL GitHub repository, extracted and formatted as plain text. The extraction includes 58 files (396.2 KB), approximately 107.8k tokens, and a symbol index with 260 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract.
