[
  {
    "path": ".gitignore",
    "content": "# Python cache files\r\n__pycache__/\r\n*.py[cod]\r\n*$py.class\r\n*.so\r\n\r\n# Project specific\r\nwork_dirs/\r\ntest/\r\nlib/models/deformers/smplx/SMPLX/*\r\n\r\n# IDE\r\n.idea/\r\n.vscode/\r\n*.swp\r\n*.swo\r\n\r\n# Distribution / packaging\r\ndist/\r\nbuild/\r\n*.egg-info/\r\n\r\n# Jupyter Notebook\r\n.ipynb_checkpoints\r\n\r\n\r\n# Git related\r\n*.orig\r\n*.rej\r\n*.patch "
  },
  {
    "path": "README.md",
    "content": "# **IDOL: Instant Photorealistic 3D Human Creation from a Single Image**  \r\n\r\n[![Website](https://img.shields.io/badge/Project-Website-0073e6)](https://yiyuzhuang.github.io/IDOL/)\r\n[![Paper](https://img.shields.io/badge/arXiv-PDF-b31b1b)](https://arxiv.org/pdf/2412.14963)\r\n[![Live Demo](https://img.shields.io/badge/Live-Demo-34C759)](https://yiyuzhuang.github.io/IDOL/)\r\n[![License](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/licenses/MIT)\r\n\r\n\r\n---\r\n\r\n<p align=\"center\">\r\n  <img src=\"./asset/images/Teaser_v2.png\" alt=\"Teaser Image for IDOL\" width=\"85%\">\r\n</p>\r\n\r\n---\r\n\r\n## **Abstract**\r\n\r\nThis work introduces **IDOL**, a feed-forward, single-image human reconstruction framework that is fast, high-fidelity, and generalizable. \r\nLeveraging a large-scale dataset of 100K multi-view subjects, our method demonstrates exceptional generalizability and robustness in handling diverse human shapes, cross-domain data, severe viewpoints, and occlusions. \r\nWith a uniform structured representation, the reconstructed avatars are directly animatable and easily editable, providing a significant step forward for various applications in graphics, vision, and beyond.\r\n\r\nIn summary, this project introduces:\r\n\r\n- **IDOL**: A scalable pipeline for instant photorealistic 3D human reconstruction using a simple yet efficient feed-forward model.\r\n- **HuGe100K Dataset**: We develop a data generation pipeline and present \\datasetname, a large-scale multi-view human dataset featuring diverse attributes, high-fidelity, high-resolution appearances, and a well-aligned SMPL-X model.\r\n- **Application Support**: Enabling 3D human reconstruction and downstream tasks such as editing and animation.\r\n\r\n\r\n---\r\n## 📰 **News** \r\n\r\n- **2024-12-18**: Paper is now available on arXiv.\r\n- **2025-01-02**: The demo dataset containing 100 samples is now available for access. The remaining dataset is currently undergoing further cleaning and review.\r\n- **2025-03-01**: 🎉 Paper accepted by CVPR 2025.\r\n- **2025-03-01**: 🎉 We have released the inference code! Check out the [Code Release](#code-release) section for details.\r\n- **2025-04-01**: 🔥 Full HuGe100K dataset is now available! See the [Dataset Access](#dataset-demo-access) section.\r\n- **2025-04-05**: 🔥 Training code is now available! Check out the [Training Code](#training-code) section for details.\r\n\r\n## 🚧 **Project Status**   \r\n\r\nWe are actively working on releasing the following resources:  \r\n\r\n| Resource                    | Status              | Expected Release Date      |\r\n|-----------------------------|---------------------|----------------------------|\r\n| **Dataset Demo**            | ✅ Available        | **Now Live! (2025.01.02)**      |\r\n| **Inference Code**             | ✅ Available        | **Now Live! (2025.03.01)**   |\r\n| **Full Dataset Access**     | ✅ Available        | **Now Live! (2025.04.01)**   |\r\n| **Online Demo**             | 🚧 In Progress      | **Before April  2025**   |\r\n| **Training Code**                    | ✅ Available     | **Now Live! (2025.04.05)**   |\r\n\r\nStay tuned as we update this section with new releases! 
🚀  \r\n\r\n\r\n\r\n## 💻 **Code Release** \r\n\r\n### Installation & Environment Setup\r\n\r\nPlease refer to [env/README.md](env/README.md) for detailed environment setup instructions.\r\n\r\n### Quick Start\r\nRun demo with different modes:\r\n```bash\r\n# Reconstruct the input image\r\npython run_demo.py --render_mode reconstruct\r\n\r\n# Generate novel poses (animation)\r\npython run_demo.py --render_mode novel_pose\r\n\r\n# Generate 360-degree view\r\npython run_demo.py --render_mode novel_pose_A\r\n```\r\n\r\n### Training\r\n\r\n#### Data Preparation\r\n\r\n1. **Dataset Structure**: First, prepare your dataset with the following structure:\r\n   ```\r\n   dataset_root/\r\n   ├── deepfashion/\r\n   │   ├── image1/\r\n   │   │   ├── videos/\r\n   │   │   │   ├── xxx.mp4\r\n   │   │   │   └── xxx.jpg\r\n   │   │   └── param/\r\n   │   │       └── xxx.npy\r\n   │   └── image2/\r\n   │       ├── videos/\r\n   │       └── param/\r\n   └── flux_batch1_5000/\r\n       ├── image1/\r\n       │   ├── videos/\r\n       │   └── param/\r\n       └── image2/\r\n           ├── videos/\r\n           └── param/\r\n   ```\r\n\r\n2. **Process Dataset**: Run the data processing script to generate cache files:\r\n   ```bash\r\n   # Process the dataset and generate cache files\r\n   # Please modify the dataset path and the sample number in the script\r\n   bash data_processing/process_datasets.sh\r\n   ```\r\n\r\n   This will generate cache files in the `processed_data` directory:\r\n   - `deepfashion_train_140.npy`\r\n   - `deepfashion_val_10.npy`\r\n   - `deepfashion_test_50.npy`\r\n   - `flux_batch1_5000_train_140.npy`\r\n   - `flux_batch1_5000_val_10.npy`\r\n   - `flux_batch1_5000_test_50.npy`\r\n\r\n3. **Configure Cache Path**: Update the cache path in your config file (e.g., `configs/idol_v0.yaml`):\r\n   ```yaml\r\n     params:\r\n       cache_path: [\r\n         ./processed_data/deepfashion_train_140.npy,\r\n         ./processed_data/flux_batch1_5000_train_140.npy\r\n       ]\r\n   ```\r\n\r\n#### Training\r\n\r\n1. **Single-Node Training**: For single-node multi-GPU training:\r\n   ```bash\r\n   python train.py \\\r\n     --base configs/idol_v0.yaml \\\r\n     --num_nodes 1 \\\r\n     --gpus 0,1,2,3,4,5,6,7\r\n   ```\r\n\r\n2. **Multi-Node Training**: For multi-node training, specify additional parameters:\r\n   ```bash\r\n   python train.py \\\r\n     --base configs/idol_v0.yaml \\\r\n     --num_nodes <total_nodes> \\\r\n     --node_rank <current_node_rank> \\\r\n     --master_addr <master_node_ip> \\\r\n     --master_port <port_number> \\\r\n     --gpus 0,1,2,3,4,5,6,7\r\n   ```\r\n\r\n   Example for a 2-node setup:\r\n   ```bash\r\n   # On master node (node 0):   \r\n   python train.py --base configs/idol_v0.yaml --num_nodes 2 --node_rank 0 --master_addr 192.168.1.100 --master_port 29500 --gpus 0,1,2,3,4,5,6,7\r\n\r\n   # On worker node (node 1):\r\n   python train.py --base configs/idol_v0.yaml --num_nodes 2 --node_rank 1 --master_addr 192.168.1.100 --master_port 29500 --gpus 0,1,2,3,4,5,6,7\r\n   ```\r\n\r\n3. **Resume Training**: To resume training from a checkpoint:\r\n   ```bash\r\n   python train.py \\\r\n     --base configs/idol_v0.yaml \\\r\n     --resume PATH/TO/MODEL.ckpt \\\r\n     --num_nodes 1 \\\r\n     --gpus 0,1,2,3,4,5,6,7\r\n   ```\r\n\r\n4. 
**Test and Evaluate Metrics**:\r\n   ```bash\r\n   # --base: main config file (model)\r\n   # --test_sd: path to the .ckpt model you want to test\r\n   # --test_dataset: (optional) dataset config used specifically for testing\r\n   python train.py \\\r\n     --base configs/idol_v0.yaml \\\r\n     --num_nodes 1 \\\r\n     --gpus 0,1,2,3,4,5,6,7 \\\r\n     --test_sd /path/to/model_checkpoint.ckpt \\\r\n     --test_dataset ./configs/test_dataset.yaml\r\n   ```\r\n\r\n## Notes\r\n- Make sure all GPUs have enough memory for the selected batch size\r\n- For multi-node training, ensure network connectivity between nodes\r\n- Monitor training progress using the logging system\r\n- Adjust learning rate and other hyperparameters in the config file as needed\r\n\r\n\r\n## 🌐 **Key Links** \r\n\r\n- 📄 [**Paper on arXiv**](https://arxiv.org/pdf/2412.14963)  \r\n- 🌐 [**Project Website**](https://yiyuzhuang.github.io/IDOL/)  \r\n- 🚀 [**Live Demo**](https://your-live-demo-link.com) (Coming Soon!)  \r\n\r\n---\r\n\r\n## 📊 **Dataset Demo Access**   \r\n\r\nWe introduce **HuGe100K**, a large-scale multi-view human dataset, supporting 3D human reconstruction and animation research.  \r\n\r\n### ▶ **Watch the Demo Video**\r\n<p align=\"center\">\r\n  <img src=\"./asset/videos/dataset.gif\" alt=\"Dataset GIF\" width=\"85%\">\r\n</p>\r\n\r\n### 📋 **Dataset Documentation**\r\nFor detailed information about the dataset format, structure, and usage guidelines, please refer to our [Dataset Documentation](dataset/README.md).\r\n\r\n### 🚀 **Access the Dataset**   \r\n\r\n<div align=\"center\">\r\n  <p><strong>🔥 HuGe100K - The largest multi-view human dataset with 100,000+ subjects! 🔥</strong></p>\r\n  <p>High-resolution • Multi-view • Diverse poses • SMPL-X aligned</p>\r\n  \r\n\r\n  <a href=\"https://docs.google.com/forms/d/e/1FAIpQLSeVqrA9Mc_ODdcTZsB3GgrxgSNZk5deOzK4f64N72xlQFhvzQ/viewform?usp=dialog\">\r\n    <img src=\"https://img.shields.io/badge/Apply_for_Access-HuGe100K_Dataset-FF6B6B?style=for-the-badge&logo=googleforms&logoColor=white\" alt=\"Apply for Access\" width=\"300px\">\r\n  </a>\r\n  <p><i>Complete the form to get access credentials and download links!</i></p>\r\n</div>\r\n\r\n### ⚖️ **License and Attribution**\r\n\r\nThis dataset includes images derived from the **DeepFashion** dataset, originally provided by MMLAB at The Chinese University of Hong Kong. The use of DeepFashion images in this dataset has been explicitly authorized by the original authors solely for the purpose of creating and distributing this dataset. 
**Users must not further reproduce, distribute, sell, or commercially exploit any images or derived data originating from DeepFashion.** For any subsequent or separate use of the DeepFashion data, users must directly obtain authorization from MMLAB and comply with the original [DeepFashion License](https://mmlab.ie.cuhk.edu.hk/projects/DeepFashion.html).\r\n\r\n---\r\n\r\n## 📝 **Citation**   \r\n\r\nIf you find our work helpful, please cite us using the following BibTeX:\r\n\r\n```bibtex\r\n@article{zhuang2024idolinstant,                \r\n  title={IDOL: Instant Photorealistic 3D Human Creation from a Single Image}, \r\n  author={Yiyu Zhuang and Jiaxi Lv and Hao Wen and Qing Shuai and Ailing Zeng and Hao Zhu and Shifeng Chen and Yujiu Yang and Xun Cao and Wei Liu},\r\n  journal={arXiv preprint arXiv:2412.14963},\r\n  year={2024},\r\n  url={https://arxiv.org/abs/2412.14963}, \r\n}\r\n```\r\n\r\n\r\n\r\n## **License** \r\n\r\nThis project is licensed under the **MIT License**.\r\n\r\n- **Permissions**: This license grants permission to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the software.\r\n- **Condition**: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.\r\n- **Disclaimer**: The software is provided \"as is\", without warranty of any kind.\r\n\r\nFor more information, see the full license [here](https://opensource.org/licenses/MIT).\r\n\r\n## **Support Our Work** ⭐\r\n\r\nIf you find our work useful for your research or applications:\r\n\r\n- Please ⭐ **star our repository** to help us reach more people\r\n- Consider **citing our paper** in your publications (see [Citation](#citation) section)\r\n- Share our project with others who might benefit from it\r\n\r\nYour support helps us continue developing open-source research projects like this one!\r\n\r\n## 📚 **Acknowledgments**\r\n\r\nThis project is majorly built upon several excellent open-source projects:\r\n\r\n- [E3Gen](https://github.com/olivia23333/E3Gen): Efficient, Expressive and Editable Avatars Generation\r\n- [SAPIENS](https://github.com/facebookresearch/sapiens): High-resolution visual models for human-centric tasks\r\n- [GeoLRM](https://github.com/alibaba-yuanjing-aigclab/GeoLRM): Large Reconstruction Model for High-Quality 3D Generation\r\n- [3D Gaussian Splatting](https://github.com/graphdeco-inria/gaussian-splatting): Real-Time 3DGS Rendering\r\n\r\nWe thank all the authors for their contributions to the open-source community.\r\n"
  },
  {
    "path": "configs/idol_debug.yaml",
    "content": "\ndebug: True\n# code_size: [32, 256, 256]\ncode_size: [32, 1024, 1024]\nmodel:\n  # base_learning_rate: 2.0e-04 # yy Need to check\n  target: lib.SapiensGS_SA_v1\n  params:\n   # optimizer add\n    # use_bf16: true\n    max_steps: 100_000\n    warmup_steps: 10_000 #12_000\n    use_checkpoint: true\n    lambda_depth_tv: 0.05 \n    lambda_lpips: 0 #2.0\n    lambda_mse: 20 #1.0\n    lambda_offset: 1 #offset_weight: 50 mse 20, lpips 0.1\n    neck_learning_rate: 5e-4\n    decoder_learning_rate: 5e-4\n\n\n    output_hidden_states: true  # if True, will output the hidden states from sapiens shallow layer, for the neck decoder\n    \n    loss_coef: 0.5 \n    init_iter: 500\n    scale_weight: 0.01\n    smplx_path:  'work_dirs/demo_data/Ways_to_Catch_360_clip1.json'\n   \n    code_reshape:  [32, 96, 96] \n    patch_size: 1\n    code_activation:\n      type: tanh\n      mean: 0.0\n      std: 0.5\n      clip_range: 2\n    grid_size: 64\n    encoder:\n      target: lib.models.sapiens.SapiensWrapper_ts \n      params:\n        model_path:   work_dirs/ckpt/sapiens_1b_epoch_173_torchscript.pt2\n        # model_path: /apdcephfs_cq8/share_1367250/harriswen/projects/sapiens_convert/checkpoints//sapiens_1b_epoch_173_torchscript.pt2\n        layer_num: 40\n        img_size: [1024, 736]\n        freeze: True\n    neck:\n      target: lib.models.transformer_sa.neck_SA_v3_skip # TODO!! add a self attention version\n      params:\n        patch_size: 4 #4,\n        in_chans: 32  #32, # the uv code  dims\n        num_patches: 9216 #4096 #num_patches  #,#4096, # 16*16\n        embed_dim: 1536 # sapiens' latent dims # 1920 # 1920 for sapiens encoder2  #1024 # the feature extrators outputs\n        decoder_embed_dim: 128 # 1024\n        decoder_depth: 2 # 8\n        decoder_num_heads: 4 #16,\n        total_num_hidden_states: 12 \n        mlp_ratio: 4.\n    decoder:\n      target:  lib.models.decoders.UVNDecoder_gender \n      params:\n        interp_mode: bilinear\n        base_layers: [16, 64]\n        density_layers: [64, 1]\n        color_layers: [16, 128, 9]\n        offset_layers: [64, 3]\n        use_dir_enc: false\n        dir_layers: [16, 64]\n        activation: silu\n        bg_color: 1\n        sigma_activation: sigmoid\n        sigmoid_saturation: 0.001\n        gender: neutral\n        is_sub2: true ## update, make it into 10w gs points\n        multires: 0\n        image_size: [640, 896]\n        superres: false\n        focal: 1120\n        up_cnn_in_channels: 128 # be the same as decoder_embed_dim\n        reshape_type: VitHead\n        vithead_param:\n          in_channels: 128 # be the same as decoder_embed_dim\n          out_channels: 32\n          deconv_out_channels: [128, 64]\n          deconv_kernel_sizes: [4, 4]\n          conv_out_channels: [128, 128]\n          conv_kernel_sizes: [3, 3]\n        fix_sigma: true\n\ndataset:\n  target: lib.datasets.dataloader.DataModuleFromConfig\n  params:\n    batch_size: 1 #16 # 6 for lpips\n    num_workers: 1 #2\n    # working when in debug mode\n    debug_cache_path:./processed_data/flux_batch1_5000_test_50_local.npy\n\n    train: \n      target: lib.datasets.AvatarDataset\n      params:\n        data_prefix: None\n     \n        cache_path:  [\n          ./processed_data/deepfashion_train_140_local.npy,\n          ./processed_data/flux_batch1_5000_train_140_local.npy\n        ]\n\n        specific_observation_num: 5\n        better_range: true\n        first_is_front: true\n        if_include_video_ref_img: true  \n        
prob_include_video_ref_img: 0.5\n        img_res: [640, 896]\n    validation:\n      target: lib.datasets.AvatarDataset\n      params:\n        data_prefix: None\n        load_imgs: true\n        specific_observation_num: 3\n        better_range: true\n        first_is_front: true\n        img_res: [640, 896]\n        cache_path:       [\n        ./processed_data/flux_batch1_5000_test_50_local.npy,\n        #./processed_data/flux_batch1_5000_val_10_local.npy\n        ]\n\n\n\nlightning:\n  modelcheckpoint:\n    params:\n      every_n_train_steps:  4000 #2000\n      save_top_k: -1\n      save_last: true\n      monitor: 'train/loss_mse' # ADD this logging in the wrapper_sa\n      mode: \"min\"\n      filename: 'sample-synData-epoch{epoch:02d}-val_loss{val/loss:.2f}'\n  callbacks: {}\n  trainer:\n    num_sanity_val_steps: 1\n    accumulate_grad_batches: 1\n    gradient_clip_val: 10.0\n    max_steps: 80000\n    check_val_every_n_epoch: 1  ## check validation set every 1 training batches in the current epoch\n    benchmark: true\n    val_check_interval: 1.0"
  },
  {
    "path": "configs/idol_v0.yaml",
    "content": "\ndebug: True\n# code_size: [32, 256, 256]\ncode_size: [32, 1024, 1024]\nmodel:\n  # base_learning_rate: 2.0e-04 # yy Need to check\n  target: lib.SapiensGS_SA_v1\n  params:\n   # optimizer add\n    # use_bf16: true\n    max_steps: 100_000\n    warmup_steps: 10_000 #12_000\n    use_checkpoint: true\n    lambda_depth_tv: 0.05 \n    lambda_lpips: 10 #2.0\n    lambda_mse: 20 #1.0\n    lambda_offset: 1 #offset_weight: 50 mse 20, lpips 0.1\n    neck_learning_rate: 5e-4\n    decoder_learning_rate: 5e-4\n\n\n    output_hidden_states: true  # if True, will output the hidden states from sapiens shallow layer, for the neck decoder\n    \n    loss_coef: 0.5 \n    init_iter: 500\n    scale_weight: 0.01\n    smplx_path:  'work_dirs/demo_data/Ways_to_Catch_360_clip1.json'\n   \n    code_reshape:  [32, 96, 96] \n    patch_size: 1\n    code_activation:\n      type: tanh\n      mean: 0.0\n      std: 0.5\n      clip_range: 2\n    grid_size: 64\n    encoder:\n      target: lib.models.sapiens.SapiensWrapper_ts \n      params:\n        model_path:   work_dirs/ckpt/sapiens_1b_epoch_173_torchscript.pt2\n        # model_path: /apdcephfs_cq8/share_1367250/harriswen/projects/sapiens_convert/checkpoints//sapiens_1b_epoch_173_torchscript.pt2\n        layer_num: 40\n        img_size: [1024, 736]\n        freeze: True\n    neck:\n      target: lib.models.transformer_sa.neck_SA_v3_skip # TODO!! add a self attention version\n      params:\n        patch_size: 4 #4,\n        in_chans: 32  #32, # the uv code  dims\n        num_patches: 9216 #4096 #num_patches  #,#4096, # 16*16\n        embed_dim: 1536 # 1920 # 1920 for sapiens encoder2  #1024 # the feature extrators outputs\n        decoder_embed_dim: 1536 # 1024\n        decoder_depth: 16 # 8\n        decoder_num_heads: 16 #16,\n        total_num_hidden_states: 40 \n        mlp_ratio: 4.\n    decoder:\n      target:  lib.models.decoders.UVNDecoder_gender \n      params:\n        interp_mode: bilinear\n        base_layers: [16, 64]\n        density_layers: [64, 1]\n        color_layers: [16, 128, 9]\n        offset_layers: [64, 3]\n        use_dir_enc: false\n        dir_layers: [16, 64]\n        activation: silu\n        bg_color: 1\n        sigma_activation: sigmoid\n        sigmoid_saturation: 0.001\n        gender: neutral\n        is_sub2: true ## update, make it into 10w gs points\n        multires: 0\n        image_size: [640, 896]\n        superres: false\n        focal: 1120\n        up_cnn_in_channels: 1536 # be the same as decoder_embed_dim\n        reshape_type: VitHead\n        vithead_param:\n          in_channels: 1536 # be the same as decoder_embed_dim\n          out_channels: 32\n          deconv_out_channels: [512, 512, 512, 256]\n          deconv_kernel_sizes: [4, 4, 4, 4]\n          conv_out_channels: [128, 128]\n          conv_kernel_sizes: [3, 3]\n        fix_sigma: true\n\ndataset:\n  target: lib.datasets.dataloader.DataModuleFromConfig\n  params:\n    batch_size: 1 #16 # 6 for lpips\n    num_workers: 2 #2\n    # working when in debug mode\n    debug_cache_path:  ./processed_data/flux_batch1_5000_test_50_local.npy\n\n    train: \n      target: lib.datasets.AvatarDataset\n      params:\n        data_prefix: None\n     \n        cache_path:  [\n          ./processed_data/deepfashion_train_140_local.npy,\n          ./processed_data/flux_batch1_5000_train_140_local.npy\n        ]\n\n        specific_observation_num: 5\n        better_range: true\n        first_is_front: true\n        if_include_video_ref_img: true  \n        
prob_include_video_ref_img: 0.5\n        img_res: [640, 896]\n    validation:\n      target: lib.datasets.AvatarDataset\n      params:\n        data_prefix: None\n        load_imgs: true\n        specific_observation_num: 5\n        better_range: true\n        first_is_front: true\n        img_res: [640, 896]\n        cache_path:  [\n          ./processed_data/deepfashion_val_10_local.npy,\n          ./processed_data/flux_batch1_5000_val_10_local.npy\n        ]\n\n\n\nlightning:\n  modelcheckpoint:\n    params:\n      every_n_train_steps:  4000 #2000\n      save_top_k: -1\n      save_last: true\n      monitor: 'train/loss_mse' # ADD this logging in the wrapper_sa\n      mode: \"min\"\n      filename: 'sample-synData-epoch{epoch:02d}-val_loss{val/loss:.2f}'\n  callbacks: {}\n  trainer:\n    num_sanity_val_steps: 0\n    accumulate_grad_batches: 1\n    gradient_clip_val: 10.0\n    max_steps: 80000\n    check_val_every_n_epoch: 1  ## check validation set every 1 training batches in the current epoch\n    benchmark: true"
  },
  {
    "path": "configs/test_dataset.yaml",
    "content": "\ndataset:\n  target: lib.datasets.dataloader.DataModuleFromConfig\n  params:\n    batch_size: 1 \n    num_workers: 2 \n    # working when in debug mode\n    debug_cache_path:  ./processed_data/flux_batch1_5000_test_50_local.npy\n\n    train: \n      target: lib.datasets.AvatarDataset\n      params:\n        data_prefix: None\n     \n        cache_path:  [\n          ./processed_data/deepfashion_train_140_local.npy,\n          ./processed_data/flux_batch1_5000_train_140_local.npy\n        ]\n\n        specific_observation_num: 5\n        better_range: true\n        first_is_front: true\n        if_include_video_ref_img: true\n        prob_include_video_ref_img: 0.5                                                                                                                          \n        img_res: [640, 896]\n    validation:\n      target: lib.datasets.AvatarDataset\n      params:\n        data_prefix: None\n        load_imgs: true\n        specific_observation_num: 5\n        better_range: true\n        first_is_front: true\n        img_res: [640, 896]\n        cache_path:  [\n          ./processed_data/deepfashion_val_10_local.npy,\n          ./processed_data/flux_batch1_5000_val_10_local.npy\n        ]\n    test:\n      target: lib.datasets.AvatarDataset\n      params:\n        data_prefix: None\n        load_imgs: true\n        specific_observation_num: 5\n        better_range: true\n        first_is_front: true\n        img_res: [640, 896]\n        cache_path:  [\n          ./processed_data/deepfashion_test_50_local.npy,\n          ./processed_data/flux_batch1_5000_test_50_local.npy\n        ]"
  },
  {
    "path": "data_processing/prepare_cache.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n\n\"\"\"\nData preparation script for DeepFashion video dataset.\nThis script processes video files and their corresponding parameters,\nand splits the dataset into train/val/test sets.\n\"\"\"\n\nimport os\nimport numpy as np\nimport argparse\n\n\ndef parse_args():\n    \"\"\"Parse command line arguments.\"\"\"\n    parser = argparse.ArgumentParser(description=\"Prepare DeepFashion video dataset\")\n    parser.add_argument(\n        \"--video_dir\", \n        type=str, \n        default=\"/apdcephfs/private_harriswen/data/deepfashion/\",\n        help=\"Base directory containing imageX folders\"\n    )\n    parser.add_argument(\n        \"--output_dir\", \n        type=str, \n        default=\"./\", \n        help=\"Directory to save the processed data\"\n    )\n    parser.add_argument(\n        \"--prefix\", \n        type=str, \n        default=\"DeepFashion\", \n        help=\"Prefix for the output file names\"\n    )\n    parser.add_argument(\n        \"--max_videos\", \n        type=int, \n        default=20000, \n        help=\"Maximum number of videos to process (for creating smaller datasets)\"\n    )\n    return parser.parse_args()\n\n\ndef prepare_dataset(video_dir, output_dir, prefix, max_total_videos=20000):\n    \"\"\"\n    Prepare the DeepFashion dataset by processing videos and parameters.\n    \n    Args:\n        video_dir: Base directory containing imageX folders\n        output_dir: Directory to save processed data\n        prefix: Prefix for output filenames\n        max_total_videos: Maximum number of videos to process (default: 20000)\n    \"\"\"\n    # Find all imageX subdirectories\n    image_dirs = []\n    for item in os.listdir(video_dir):\n        if item.startswith(\"image\") and os.path.isdir(os.path.join(video_dir, item)):\n            image_dirs.append(item)\n    \n    image_dirs.sort()\n    print(f\"Found {len(image_dirs)} image directories: {image_dirs}\")\n    \n    # Collect all video files\n    all_video_files = []\n    all_param_files = []\n    all_dir_names = []\n    # import ipdb; ipdb.set_trace()\n    for image_dir in image_dirs:\n        videos_path = os.path.join(video_dir, image_dir, \"videos\")\n        params_path = os.path.join(video_dir, image_dir, \"param\")\n        \n        if not os.path.exists(videos_path):\n            print(f\"Warning: Videos directory not found in {image_dir}, skipping.\")\n            continue\n        \n        if not os.path.exists(params_path):\n            print(f\"Warning: Parameters directory not found in {image_dir}, skipping.\")\n            continue\n        \n        # Get list of video names in current directory\n        param_names = os.listdir(params_path)\n\n        # filter the files with .npy extension\n        param_names = [name for name in param_names if name.endswith(\".npy\")]\n        \n        for name in param_names:\n            video_path = os.path.join(videos_path, name.replace(\".npy\", \".mp4\"))\n            param_path = os.path.join(params_path, name)\n            \n            # Check if both video and parameter files exist\n            if not os.path.exists(video_path):\n                print(f\"Warning: Video file not found: {video_path}, skipping.\")\n                continue\n                \n            if not os.path.exists(param_path):\n                print(f\"Warning: Parameter file not found: {param_path}, skipping.\")\n                continue\n                \n            # Add to collection only if both 
files exist\n            all_video_files.append(video_path)\n            all_param_files.append(param_path)\n            all_dir_names.append(image_dir)\n    \n    total_videos = len(all_video_files)\n    print(f\"Total valid videos found: {total_videos}\")\n    \n    if total_videos == 0:\n        print(\"Error: No valid video-parameter pairs found. Please check your data paths.\")\n        return\n    \n    # Limit number of videos to process\n    if max_total_videos < total_videos:\n        # Randomly shuffle and select first max_total_videos\n        indices = list(range(total_videos))\n        np.random.shuffle(indices)\n        indices = indices[:max_total_videos]\n        \n        all_video_files = [all_video_files[i] for i in indices]\n        all_param_files = [all_param_files[i] for i in indices]\n        all_dir_names = [all_dir_names[i] for i in indices]\n        \n        print(f\"Limiting to {max_total_videos} videos\")\n    \n    # Process videos and parameters\n    scenes = []\n    processed_count = 0\n    skipped_count = 0\n    \n    for video_path, param_path, dir_name in zip(all_video_files, all_param_files, all_dir_names):\n        processed_count += 1\n        case_name = os.path.basename(video_path)\n        \n        print(f\"Processing {processed_count}/{len(all_video_files)}: {dir_name}/{case_name}\")\n        \n        try:\n            # Create scene dictionary\n            scenes.append(dict(\n                video_path=video_path,\n                image_paths=None, # only fill it for the data in a images sequence instead of a video\n                param_path=param_path,\n                image_ref=video_path.replace(\".mp4\", \".jpg\")\n            ))\n        except Exception as e:\n            print(f\"Error processing {video_path}: {e}\")\n            skipped_count += 1\n    \n    print(f\"Total scenes collected: {len(scenes)}\")\n    print(f\"Total scenes skipped: {skipped_count}\")\n    \n    if len(scenes) == 0:\n        print(\"Error: No scenes could be processed. Please check your data.\")\n        return\n    \n    # Split dataset\n    total_scenes = len(scenes)\n    test_scenes = scenes[-50:] if total_scenes > 50 else []\n    val_scenes = scenes[-60:-50] if total_scenes > 60 else []\n    train_scenes = scenes[:-60] if total_scenes > 60 else scenes\n    \n    # Save each split\n    splits = {\n        \"train\": train_scenes,\n        \"val\": val_scenes,\n        \"test\": test_scenes,\n        \"all\": scenes\n    }\n    \n    # Create output directory\n    os.makedirs(output_dir, exist_ok=True)\n    \n    # Save each split to separate file\n    for split_name, split_data in splits.items():\n        if not split_data:\n            continue\n            \n        cache_path = os.path.join(\n            output_dir, \n            f\"{prefix}_{split_name}_{len(split_data)}.npy\"\n        )\n        np.save(cache_path, split_data)\n        print(f\"Saved {split_name} split with {len(split_data)} samples to {cache_path}\")\n\n\nif __name__ == \"__main__\":\n    # Parse command line arguments\n    args = parse_args()\n    \n    # Prepare and save the dataset\n    prepare_dataset(args.video_dir, args.output_dir, args.prefix, args.max_videos)\n    print(f\"Done processing {args.video_dir} dataset\")\n"
  },
  {
    "path": "data_processing/process_datasets.sh",
    "content": "#!/bin/bash\n\n# Data processing script for multiple datasets\n# This script processes all specified datasets and saves the results to output directories\n\n# Define the list of dataset paths\nDATASET_PATHS=(\n    \"/PATH/TO/deepfashion\"\n    \"/PATH/TO/flux_batch1_5000\"\n    \"/PATH/TO/flux_batch2\"\n    # Add more dataset paths here as needed\n)\n\n# Output base directory for processed cache files\nOUTPUT_BASE_DIR=\"./processed_data\"\n\n# Maximum videos to process per dataset (set to a smaller number for testing)\n# if you want to process all videos, set MAX_VIDEOS to a very large number\nMAX_VIDEOS=200\n\n# Process each dataset\nfor DATASET_PATH in \"${DATASET_PATHS[@]}\"; do\n    # Extract dataset name from path (use the last directory name as prefix)\n    DATASET_NAME=$(basename \"$DATASET_PATH\")\n    \n    # Create output directory for this dataset\n    OUTPUT_DIR=\"${OUTPUT_BASE_DIR}\"\n    mkdir -p \"$OUTPUT_DIR\"\n    \n    echo \"===== Processing ${DATASET_NAME} Dataset =====\"\n    echo \"Source: ${DATASET_PATH}\"\n    echo \"Destination: ${OUTPUT_DIR}\"\n    \n    # Run the processing script\n    python data_processing/prepare_cache.py \\\n        --video_dir \"${DATASET_PATH}\" \\\n        --output_dir \"${OUTPUT_DIR}\" \\\n        --prefix \"${DATASET_NAME}\" \\\n        --max_videos \"${MAX_VIDEOS}\"\n    \n    # Check if processing was successful\n    if [ $? -ne 0 ]; then\n        echo \"Error processing ${DATASET_NAME} dataset\"\n        echo \"Continuing with next dataset...\"\n    else\n        echo \"Successfully processed ${DATASET_NAME} dataset\"\n    fi\n    \n    echo \"----------------------------------------\"\ndone\n\necho \"===== All datasets processing completed =====\"\necho \"Results saved to: ${OUTPUT_BASE_DIR}\"\n\n# List all processed datasets\necho \"Processed datasets:\"\nfor DATASET_PATH in \"${DATASET_PATHS[@]}\"; do\n    DATASET_NAME=$(basename \"$DATASET_PATH\")\n    echo \"- ${DATASET_NAME}: ${OUTPUT_BASE_DIR}/${DATASET_NAME}\"\ndone "
  },
  {
    "path": "data_processing/visualize_samples.py",
    "content": "\nimport torch\nimport numpy as np\nimport os\nos.environ[\"PYOPENGL_PLATFORM\"] = \"osmesa\"\nimport smplx\nimport trimesh\nimport pyrender\nimport imageio\n\ndef init_smplx_model():\n    \"\"\"Initialize the SMPL-X model with predefined settings.\"\"\"\n    body_model = smplx.SMPLX('PATH_TO_YOUR_SMPLX_FOLDER',\n                             gender=\"neutral\", \n                             create_body_pose=False, \n                             create_betas=False, \n                             create_global_orient=False, \n                             create_transl=False,\n                             create_expression=True,\n                             create_jaw_pose=True, \n                             create_leye_pose=True, \n                             create_reye_pose=True, \n                             create_right_hand_pose=False,\n                             create_left_hand_pose=False,\n                             use_pca=False,\n                             num_pca_comps=12,\n                             num_betas=10,\n                             flat_hand_mean=False)\n    return body_model\n\n# Load SMPL-X parameters\nparam_path =  \"./100samples/Apose/param/Argentina_male_buff_thermal wear_20~30 years old_1573.npy\"\nparam = np.load(param_path, allow_pickle=True).item()\n\n# Extract SMPL-X parameters\nsmpl_params = param['smpl_params'].reshape(1, -1)\nscale, transl, global_orient, pose, betas, left_hand_pose, right_hand_pose, jaw_pose, leye_pose, reye_pose, expression = torch.split(smpl_params, [1, 3, 3, 63, 10, 45, 45, 3, 3, 3, 10], dim=1)\n\n# Initialize SMPL-X model and generate vertices\ndevice = torch.device(\"cpu\")\nmodel = init_smplx_model().to(device)\noutput = model(global_orient=global_orient, body_pose=pose, betas=betas, left_hand_pose=left_hand_pose,\n               right_hand_pose=right_hand_pose, jaw_pose=jaw_pose, leye_pose=leye_pose, reye_pose=reye_pose,\n               expression=expression)\nvertices = output.vertices[0].detach().cpu().numpy()\nfaces = model.faces\n\n# Create a Trimesh and Pyrender mesh\nmesh = trimesh.Trimesh(vertices, faces)\nmesh_pyrender = pyrender.Mesh.from_trimesh(mesh)\n   \nrendered_images_list = []\n\n# Loop through multiple camera views\nfor idx in range(24):\n    scene = pyrender.Scene()\n    scene.add(mesh_pyrender)\n\n    # Load and process camera parameters\n    camera_params = param['poses']\n    intrinsic_params = camera_params[idx][1]  # fx, fy, cx, cy\n    extrinsic_params = camera_params[idx][0] # R|T\n\n    # Set up Pyrender camera\n    camera = pyrender.IntrinsicsCamera(fx=intrinsic_params[0], fy=intrinsic_params[1],\n                                       cx=intrinsic_params[2], cy=intrinsic_params[3])\n\n    # Convert COLMAP coordinates to Pyrender-compatible transformation\n    extrinsic_params_inv = torch.inverse(extrinsic_params.clone())\n    scale_factor = extrinsic_params_inv[:3, :3].norm(dim=1)\n    extrinsic_params_inv[:3, 1:3] = -extrinsic_params_inv[:3, 1:3]\n    extrinsic_params_inv[3, :3] = 0\n\n    # Add camera and lighting\n    scene.add(camera, pose=extrinsic_params_inv)\n    light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=10.0)\n    scene.add(light, pose=extrinsic_params_inv)\n\n    # Render the scene\n    renderer = pyrender.OffscreenRenderer(640, 896)\n    color, depth = renderer.render(scene)\n    rendered_images_list.append(color)\n    renderer.delete()\n\n# Save rendered images as a video\nrendered_images = 
np.stack(rendered_images_list)\nimageio.mimwrite('rendered_results.mp4', rendered_images, fps=15)\nprint(\"Rendered results saved as rendered_results.mp4\")\n\n# Load an existing video and test alignment\nvideo_path = param_path.replace(\"param\", \"videos\").replace(\"npy\", \"mp4\")\ninput_video = imageio.get_reader(video_path)\ninput_frames = [frame for frame in input_video]\nblended_frames = [(0.5 * frame + 0.5 * render_frame).astype(np.uint8) for render_frame, frame in zip(rendered_images, input_frames)]\nimageio.mimwrite('aligned_results.mp4', blended_frames, fps=15)\nprint(\"Blended video saved as aligned_results.mp4\")\n"
  },
  {
    "path": "dataset/README.md",
    "content": "# 🌟 HuGe100K Dataset Documentation\r\n\r\n## 📊 Dataset Overview\r\nHuGe100K is a large-scale multi-view human dataset featuring diverse attributes, high-fidelity appearances, and well-aligned SMPL-X models.\r\n\r\n## 📁 File Format and Structure\r\n\r\nThe dataset is organized with the following structure:\r\n\r\n```\r\nHuGe100K/\r\n├── flux_batch1/\r\n│   ├── images[0...9]/            #  different batch of images\r\n│   │   ├── videos/               # Folder for .mp4 and .jpg files\r\n│   │   │   ├── Algeria_female_average_high fashion_50~60 years old_844.jpg\r\n│   │   │   ├── Algeria_female_average_high fashion_50~60 years old_844.mp4\r\n│   │   │   └── ... (more .jpg and .mp4)\r\n│   │   └── param/               # Folder for parameter files (.npy)\r\n│   │       ├── Algeria_female_average_high fashion_50~60 years old_844.npy\r\n│   │       └── ... (more .npy files)\r\n├── flux_batch2/\r\n│   └── ... (similar structure with images[0...9])\r\n├── flux_batch3/\r\n│   └── ... (similar structure with images[0...9])\r\n├── flux_batch4/\r\n│   └── ... (similar structure with images[0...9])\r\n├── flux_batch5/\r\n│   └── ... (similar structure with images[0...9])\r\n├── flux_batch6/\r\n│   └── ... (similar structure with images[0...9])\r\n├── flux_batch7/\r\n│   └── ... (similar structure with images[0...9])\r\n├── flux_batch8/\r\n│   └── ... (similar structure with images[0...9])\r\n├── flux_batch9/\r\n│   └── ... (similar structure with images[0...9])\r\n└── deepfashion/\r\n    └── ... (similar structure with images[0...9])\r\n```\r\n\r\nWhere:\r\n- Each `images[X]` folder contains:\r\n  - `videos/`: Reference images and generatedvideo files\r\n  - `param/`: Camera and body pose parameters\r\n- **flux_batch1 through flux_batch7**: Contains subjects in A-pose\r\n- **flux_batch8 and flux_batch9**: Contains subjects in diverse poses \r\n- **deepfashion**: Contains subjects in A-pose (derived from the DeepFashion dataset)\r\n\r\n### File Naming Convention\r\nFiles follow the naming pattern: `Area_Gender_BodyType_Clothing_Age_ID.extension`\r\n\r\nFor example:\r\n- `Algeria_female_average_high fashion_50~60 years old_844.jpg`: Reference image of an Algerian female with average build in high fashion clothing\r\n- `Algeria_female_average_high fashion_50~60 years old_844.npy`: Parameter file for the same subject\r\n\r\n### 📸 Sample Visualization\r\n\r\n<div style=\"display: flex; align-items: center; justify-content: center; gap: 10px; flex-wrap: nowrap; width: 100%;\">\r\n  <img src=\"sample/videos/Kenya_female_fit_streetwear_50~60 years old_1501.jpg\" alt=\"Kenya Female Fit Streetwear Image\" style=\"max-width: 45%; width: 45%; height: auto;\">\r\n  <span style=\"font-weight: bold;\"> =MVChamp=> </span>\r\n  <!-- <video autoplay loop muted playsinline style=\"max-width: 45%; width: 45%; height: auto;\">\r\n    <source src=\"sample/videos/Kenya_female_fit_streetwear_50~60 years old_1501.gif\" type=\"video/mp4\">\r\n    Your browser does not support the video tag.\r\n  </video> -->\r\n    <img src=\"sample/videos/Kenya_female_fit_streetwear_50~60 years old_1501.gif\" alt=\"Kenya Female Fit Streetwear Image\" style=\"max-width: 45%; width: 45%; height: auto;\">\r\n</div>\r\n\r\n\r\n\r\n## 📈 Dataset Statistics\r\n\r\n- **Total Subjects**: 100,000+\r\n- **Views per Subject**: Multiple viewpoints covering 360° in 24 views\r\n- **Pose Types**: A-pose and diverse poses\r\n\r\n## 🔍 Visualizing the Dataset\r\n\r\nFor visualization and data parsing examples, please refer to our provided 
script:\r\n`visualize_samples.py`. This script demonstrates how to:\r\n\r\n- Load the SMPL-X parameters from `.npy` files (see the loading example at the end of this document)\r\n- Render the 3D human model from multiple camera views\r\n- Compare rendered results with the original video frames\r\n\r\nRequirements for visualization:\r\n- SMPL-X model (download from [official website](https://smpl-x.is.tue.mpg.de/))\r\n- Python packages: `pyrender`, `trimesh`, `smplx`, `numpy`, `torch`\r\n\r\nExample usage:\r\n```bash\r\npython visualize_samples.py\r\n```\r\n\r\nThe script will generate:\r\n- `rendered_results.mp4`: Rendered views of the 3D model\r\n- `aligned_results.mp4`: Blended visualization of rendered model with original frames\r\n\r\n## 📋 Usage Guidelines\r\n\r\n1. **Research Purposes Only**: This dataset is intended for academic and research purposes.\r\n2. **Citation Required**: If you use this dataset in your research, please cite our paper.\r\n3. **No Commercial Use**: Commercial use is not permitted without explicit permission from us (yiyu.zhuang@smail.nju.edu.cn).\r\n4. **DeepFashion Derivatives**: See the License and Attribution section below for special requirements.\r\n\r\n## ⚖️ License and Attribution (DeepFashion)\r\n\r\nThis dataset includes images derived from the **DeepFashion** dataset, originally provided by MMLAB at The Chinese University of Hong Kong. The use of DeepFashion images in this dataset has been explicitly authorized by the original authors solely for the purpose of creating and distributing this dataset. **Users must not further reproduce, distribute, sell, or commercially exploit any images or derived data originating from DeepFashion.** For any subsequent or separate use of the DeepFashion data, users must directly obtain authorization from MMLAB and comply with the original [DeepFashion License](https://mmlab.ie.cuhk.edu.hk/projects/DeepFashion.html).
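\r\n\r\n## 🧪 Example: Inspecting a Parameter File\r\n\r\nAs a quick sanity check before running the full visualization script, the minimal sketch below loads a single `.npy` parameter file and splits it into its SMPL-X and per-view camera components. The key names and split sizes mirror those used in `visualize_samples.py`; the file path is a placeholder that you should point at any downloaded sample.\r\n\r\n```python\r\nimport numpy as np\r\nimport torch\r\n\r\n# Placeholder path: point this at any param/*.npy file from the dataset\r\nparam = np.load(\"PATH/TO/param/SUBJECT.npy\", allow_pickle=True).item()\r\nprint(param.keys())  # expect entries such as 'smpl_params' and 'poses'\r\n\r\n# SMPL-X parameters are stored as one flat vector; the split order follows visualize_samples.py\r\nsmpl = torch.as_tensor(param['smpl_params']).reshape(1, -1)\r\n(scale, transl, global_orient, body_pose, betas, left_hand_pose, right_hand_pose,\r\n jaw_pose, leye_pose, reye_pose, expression) = torch.split(\r\n    smpl, [1, 3, 3, 63, 10, 45, 45, 3, 3, 3, 10], dim=1)\r\n\r\n# Per-view camera parameters: index [0] is the extrinsic (R|T), index [1] holds (fx, fy, cx, cy)\r\nextrinsic, intrinsics = param['poses'][0][0], param['poses'][0][1]\r\nprint(betas.shape, intrinsics)\r\n```\r\n"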
  },
  {
    "path": "dataset/visualize_samples.py",
    "content": "\nimport torch\nimport numpy as np\nimport os\nos.environ[\"PYOPENGL_PLATFORM\"] = \"osmesa\"\nimport smplx\nimport trimesh\nimport pyrender\nimport imageio\n\ndef init_smplx_model():\n    \"\"\"Initialize the SMPL-X model with predefined settings.\"\"\"\n    body_model = smplx.SMPLX('PATH_TO_YOUR_SMPLX_FOLDER',\n                             gender=\"neutral\", \n                             create_body_pose=False, \n                             create_betas=False, \n                             create_global_orient=False, \n                             create_transl=False,\n                             create_expression=True,\n                             create_jaw_pose=True, \n                             create_leye_pose=True, \n                             create_reye_pose=True, \n                             create_right_hand_pose=False,\n                             create_left_hand_pose=False,\n                             use_pca=False,\n                             num_pca_comps=12,\n                             num_betas=10,\n                             flat_hand_mean=False)\n    return body_model\n\n# Load SMPL-X parameters\nparam_path =  \"./samples/param/Kenya_female_fit_streetwear_50~60 years old_1501.npy\"\nparam = np.load(param_path, allow_pickle=True).item()\n\n# Extract SMPL-X parameters\nsmpl_params = param['smpl_params'].reshape(1, -1)\nscale, transl, global_orient, pose, betas, left_hand_pose, right_hand_pose, jaw_pose, leye_pose, reye_pose, expression = torch.split(smpl_params, [1, 3, 3, 63, 10, 45, 45, 3, 3, 3, 10], dim=1)\n\n# Initialize SMPL-X model and generate vertices\ndevice = torch.device(\"cpu\")\nmodel = init_smplx_model().to(device)\noutput = model(global_orient=global_orient, body_pose=pose, betas=betas, left_hand_pose=left_hand_pose,\n               right_hand_pose=right_hand_pose, jaw_pose=jaw_pose, leye_pose=leye_pose, reye_pose=reye_pose,\n               expression=expression)\nvertices = output.vertices[0].detach().cpu().numpy()\nfaces = model.faces\n\n# Create a Trimesh and Pyrender mesh\nmesh = trimesh.Trimesh(vertices, faces)\nmesh_pyrender = pyrender.Mesh.from_trimesh(mesh)\n   \nrendered_images_list = []\n\n# Loop through multiple camera views\nfor idx in range(24):\n    scene = pyrender.Scene()\n    scene.add(mesh_pyrender)\n\n    # Load and process camera parameters\n    camera_params = param['poses']\n    intrinsic_params = camera_params[idx][1]  # fx, fy, cx, cy\n    extrinsic_params = camera_params[idx][0] # R|T\n\n    # Set up Pyrender camera\n    camera = pyrender.IntrinsicsCamera(fx=intrinsic_params[0], fy=intrinsic_params[1],\n                                       cx=intrinsic_params[2], cy=intrinsic_params[3])\n\n    # Convert COLMAP coordinates to Pyrender-compatible transformation\n    extrinsic_params_inv = torch.inverse(extrinsic_params.clone())\n    scale_factor = extrinsic_params_inv[:3, :3].norm(dim=1)\n    extrinsic_params_inv[:3, 1:3] = -extrinsic_params_inv[:3, 1:3]\n    extrinsic_params_inv[3, :3] = 0\n\n    # Add camera and lighting\n    scene.add(camera, pose=extrinsic_params_inv)\n    light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=10.0)\n    scene.add(light, pose=extrinsic_params_inv)\n\n    # Render the scene\n    renderer = pyrender.OffscreenRenderer(640, 896)\n    color, depth = renderer.render(scene)\n    rendered_images_list.append(color)\n    renderer.delete()\n\n# Save rendered images as a video\nrendered_images = 
np.stack(rendered_images_list)\nimageio.mimwrite('rendered_results.mp4', rendered_images, fps=15)\nprint(\"Rendered results saved as rendered_results.mp4\")\n\n# Load an existing video and test alignment\nvideo_path = param_path.replace(\"param\", \"videos\").replace(\"npy\", \"mp4\")\ninput_video = imageio.get_reader(video_path)\ninput_frames = [frame for frame in input_video]\nblended_frames = [(0.5 * frame + 0.5 * render_frame).astype(np.uint8) for render_frame, frame in zip(rendered_images, input_frames)]\nimageio.mimwrite('aligned_results.mp4', blended_frames, fps=15)\nprint(\"Blended video saved as aligned_results.mp4\")\n"
  },
  {
    "path": "env/README.md",
    "content": "# Environment Setup Guide\r\n\r\n## Prerequisites\r\n\r\n- Python 3.10\r\n- CUDA 11.8\r\n- PyTorch 2.3.1\r\n\r\n## Installation Steps\r\n\r\n### 1. Environment Preparation\r\n\r\nFirst, create and activate a conda environment:\r\n```bash\r\nconda create -n idol python=3.10\r\nconda activate idol\r\n```\r\n\r\n\r\nInstall all dependencies:\r\n```bash\r\nbash scripts/pip_install.sh\r\n```\r\n\r\n### 2. Download Required Models\r\n\r\nBefore proceeding, please register on:\r\n- [SMPL-X website](https://smpl-x.is.tue.mpg.de/)\r\n- [FLAME website](https://flame.is.tue.mpg.de/)\r\n\r\nThen download the template files:\r\n```bash\r\nbash scripts/fetch_template.sh\r\n```\r\n\r\n### 3. Download Pretrained Models and caches with:\r\n```bash\r\nbash scripts/download_files.sh      # download pretrained models\r\n```\r\n\r\nOr mannually download the following models from HuggingFace:\r\n- [IDOL Model Checkpoint](https://huggingface.co/yiyuzhuang/IDOL/blob/main/model.ckpt)\r\n- [Sapiens Pretrained Model](https://huggingface.co/yiyuzhuang/IDOL/blob/main/sapiens_1b_epoch_173_torchscript.pt2)\r\n\r\n## System Requirements\r\n\r\n- **GPU**: NVIDIA GPU with CUDA 11.8 support\r\n- **GPU RAM**: Recommended 24GB+\r\n- **Storage**: At least 15GB free space\r\n\r\n## Common Issues & Solutions\r\n**Issue**: \r\n```\r\nImportError: libGL.so.1: cannot open shared object file: No such file or directory\r\n```\r\nwhen importing OpenCV (`import cv2`)\r\n\r\n**Solution**:\r\n```bash\r\n# For Ubuntu/Debian\r\nsudo apt-get install libgl1-mesa-glx\r\n```\r\n\r\n### 2. Gaussian Splatting Antialiasing Issue\r\n**Issue**: Error related to `antialiasing=True` setting in `GaussianRasterizationSettings`\r\n\r\n**Solution**:\r\nThis issue arises due to updates in the Gaussian Splatting repository. Reinstalling the module from the GitHub repository with the latest version should resolve the problem."
  },
  {
    "path": "lib/__init__.py",
    "content": "from .models import *\nfrom .mmutils import *\nfrom .humanlrm_wrapper_sa_v1 import SapiensGS_SA_v1\n"
  },
  {
    "path": "lib/datasets/__init__.py",
    "content": "from .avatar_dataset import AvatarDataset\nfrom .dataloader import DataModuleFromConfig"
  },
  {
    "path": "lib/datasets/avatar_dataset.py",
    "content": "import os\nimport random\nimport numpy as np\nimport torch\nimport json\nimport pickle\n\nfrom torch.utils.data import Dataset\nimport torchvision.transforms.functional as F\nimport pickle\nfrom torch.utils.data import Dataset\nimport webdataset as wds\n# from lib.utils.train_util import print\n\nimport cv2\nimport av\n\nfrom omegaconf import OmegaConf, ListConfig\n\ndef load_pose(path):\n    with open(path, 'rb') as f:\n        pose_param = json.load(f)\n    c2w = np.array(pose_param['cam_param'], dtype=np.float32).reshape(4,4)\n    cam_center = c2w[:3, 3]\n    w2c = np.linalg.inv(c2w)\n    # pose[:,:2] *= -1\n    # pose = np.loadtxt(path, dtype=np.float32, delimiter=' ').reshape(9, 4)\n    return [torch.from_numpy(w2c), torch.from_numpy(cam_center)]\n\ndef load_npy(file_path):\n    return np.load(file_path, allow_pickle=True)\n\ndef load_smpl(path, smpl_type='smpl'):\n    filetype = path.split('.')[-1]\n    with open(path, 'rb') as f:\n        if filetype=='pkl':\n            smpl_param_data = pickle.load(f)\n        elif filetype == 'json':\n            smpl_param_data = json.load(f)\n        else:\n            assert False\n\n    if smpl_type=='smpl':\n        with open(os.path.join(os.path.split(path)[0][:-5], 'pose', '000_000.json'), 'rb') as f:\n            tf_param = json.load(f)\n        smpl_param = np.concatenate([np.array(tf_param['scale']).reshape(1, -1), np.array(tf_param['center'])[None], \n                    smpl_param_data['global_orient'], smpl_param_data['body_pose'].reshape(1, -1), smpl_param_data['betas']], axis=1)\n    elif smpl_type == 'smplx':\n\n        tf_param = np.load(os.path.join(os.path.dirname(os.path.dirname(path)), 'scale_offset.npy'), allow_pickle=True).item()\n        # smpl_param = np.concatenate([np.array([tf_param['scale']]).reshape(1, -1), tf_param['offset'].reshape(1, -1), \n        smpl_param = np.concatenate([np.array([[1]]), np.array([[0,0,0]]), \n                    np.array(smpl_param_data['global_orient']).reshape(1, -1), np.array(smpl_param_data['body_pose']).reshape(1, -1), \n                    np.array(smpl_param_data['betas']).reshape(1, -1), np.array(smpl_param_data['left_hand_pose']).reshape(1, -1), np.array(smpl_param_data['right_hand_pose']).reshape(1, -1),\n                    np.array(smpl_param_data['jaw_pose']).reshape(1, -1), np.array(smpl_param_data['leye_pose']).reshape(1, -1), np.array(smpl_param_data['reye_pose']).reshape(1, -1), \n                    np.array(smpl_param_data['expression']).reshape(1, -1)], axis=1)\n    else:\n        assert False\n    \n    return torch.from_numpy(smpl_param.astype(np.float32)).reshape(-1)\n\n\nclass AvatarDataset(Dataset):\n    def __init__(self,\n                 data_prefix,\n                 code_dir=None,\n                #  code_only=False,\n                 load_imgs=True,\n                 load_norm=False,\n                 specific_observation_idcs=None,\n                 specific_observation_num=None,\n                 first_is_front=False, # yy add  # If True, it will returns a random sampled batch with the front view in the first place\n                 better_range=False, # yy add  # If True, the views will not be fully random, but will be selected by a better skip\n                 if_include_video_ref_img= False,# yy add Define a variable to indicate whether to include reference images from the video\n                 prob_include_video_ref_img= 0.2, # yy add Define a variable to specify the probability\n                 allow_k_angles_near_the_front = 0, # yy 
add, if value > 0, the front view will be allowed to be selected from the range of [front_view - allow_k_angles_near_the_front, front_view + allow_k_angles_near_the_front]\n                #  num_test_imgs=0, \n                 if_use_swap_face_v1=False, # yy add, if True, use the swap face v1\n                 random_test_imgs=False,\n                 scene_id_as_name=False,\n                 cache_path=None,\n                 cache_repeat=None, # be the same length with the cache_path\n                 test_pose_override=None,\n                 num_train_imgs=-1,\n                 load_cond_data=True,\n                 load_test_data=True, \n                 max_num_scenes=-1,  # for debug or testing\n                #  radius=0.5,\n                 radius=1.0,\n                 img_res=[640, 896],\n                 test_mode=False,\n                 step=1,  # only for debug & visualization purpose\n                 crop=False # randomly crop the image with upper body inputs\n                 ):\n        super(AvatarDataset, self).__init__()\n        self.data_prefix = data_prefix\n        self.code_dir = code_dir\n        # self.code_only = code_only\n        self.load_imgs = load_imgs\n        self.load_norm = load_norm\n        self.specific_observation_idcs = specific_observation_idcs\n        self.specific_observation_num = specific_observation_num\n        self.first_is_front = first_is_front\n        \n\n\n        self.if_include_video_ref_img= if_include_video_ref_img \n        self.prob_include_video_ref_img = prob_include_video_ref_img\n        self.allow_k_angles_near_the_front = allow_k_angles_near_the_front \n\n        self.better_range = better_range\n        # self.num_test_imgs = num_test_imgs\n        self.random_test_imgs = random_test_imgs\n        self.scene_id_as_name = scene_id_as_name\n        self.cache_path = cache_path\n        self.cache_repeat = cache_repeat\n        self.test_pose_override = test_pose_override\n        self.num_train_imgs = num_train_imgs\n        self.load_cond_data = load_cond_data\n        self.load_test_data = load_test_data\n        self.max_num_scenes = max_num_scenes\n        self.step = step\n\n        self.if_use_swap_face_v1 = if_use_swap_face_v1\n        # import ipdb; ipdb.set_trace()\n\n        self.img_res = [int(i) for i in img_res]\n\n        self.radius = torch.tensor([radius], dtype=torch.float32).expand(3)\n        self.center = torch.zeros_like(self.radius)\n\n        self.load_scenes()\n        \n        self.crop = crop\n\n\n        self.test_poses = self.test_intrinsics = None\n\n        self.defalut_focal = 1120 #40 * (self.img_res[0]/32) # focal 80mm, sensor 32mm\n\n        self.default_fxy_cxy = torch.tensor([self.defalut_focal, self.defalut_focal,  self.img_res[1]//2, self.img_res[0]//2]).reshape(1, 4)\n\n        self.test_mode = test_mode\n\n        if self.test_mode:\n            self.parse_scene = self.parse_scene_test\n\n\n    def load_scenes(self):\n\n        \n        if  isinstance(self.cache_path, ListConfig):\n\n            cache_list = []\n            case_num_per_dataset = 1000000000\n            for ii, path in enumerate(self.cache_path):\n                cache = np.load(path, allow_pickle=True)\n                if self.cache_repeat is not None:\n                    cache = np.repeat(cache, self.cache_repeat[ii], axis=0)\n                print(\"done loading \", path)\n                cache_list.extend(cache[:case_num_per_dataset]) \n            scenes = cache_list\n            
print(f\"=========intialized totally {len(scenes)} scenes===========\")\n\n        else:\n            if self.cache_path is not None and os.path.exists(self.cache_path):\n                scenes = np.load(self.cache_path, allow_pickle=True)\n                print(\"load \", self.cache_path)\n            else:\n                print(f\"{self.cache_path} is not exist\")\n                raise  FileNotFoundError(f\"maybe {self.cache_path} is not exist\")\n            \n        end = len(scenes)\n        if self.max_num_scenes >= 0:\n            end = min(end, self.max_num_scenes * self.step)\n        \n        self.scenes = scenes[:end:self.step]\n        self.num_scenes = len(self.scenes)\n        \n    def parse_scene(self, scene_id):\n        scene = self.scenes[scene_id]\n        input_is_video = False # flag of if the input is video, some operations should be different\n        # print(scene)\n        # scene['video_path'] = \"/data/jxlv/transformers/src/A_pose_MEN-Denim-id_00000089-01_7_additional/result.mp4\"\n        # scene['image_paths'] = None #\"/data/jxlv/transformers/src/A_pose_MEN-Denim-id_00000089-01_7_additional/source_seg.png\"\n        # import pdb\n        # pdb.set_trace()\n        # =========== loading the params ===========\n        param = np.load(scene['param_path'], allow_pickle=True).item()\n        scene.update(param)\n        print(scene.keys())\n\n        # =========== loading the multi-view images ===========\n        if scene['image_paths'] is None:\n            input_is_video = True\n            video_path = scene['video_path']\n            try:\n                if self.if_use_swap_face_v1:\n                    image_paths_or_video = read_frames(scene['video_path'].replace('result.mp4', 'output.mp4'))\n                else:\n                    image_paths_or_video = read_frames(scene['video_path'])\n            except Exception as e:\n                print(f\"Error: {e}\")\n                print(f\"Error in reading the video : {scene['video_path'].replace('result.mp4', 'output.mp4')}\")\n                image_paths_or_video = read_frames(scene['video_path'])\n            # if 'pose_animate_service_0727' in video_path or 'flux' in video_path:\n            #     # move the first to the last\n            #     # TODO fixed this bug with a better cameras parameters\n            #     image_paths_or_video = image_paths_or_video[1:] + image_paths_or_video[0:1]\n        if not input_is_video:\n            image_paths_or_video = scene['image_paths']\n        scene_name = f\"{scene_id:0>4d}\" #  image_paths[0].split('/')[-3]\n        results = dict(\n            scene_id=[scene_id],\n            scene_name=\n                '{:04d}'.format(scene_id) if self.scene_id_as_name else scene_name,\n                # cpu_only=True\n                )\n        # import pdb; pdb.set_trace()\n        # if not self.code_only:\n        poses = scene['poses']\n        smpl_params = scene['smpl_params']\n\n        # if input_is_video:\n        #     num_imgs = len(video)\n        # else:\n        num_imgs = len(image_paths_or_video)\n        # front_view = num_imgs // 4\n        # randonly / specificically select the views of output\n        smplx_cam_rotate = smpl_params[4: 7] #get global orient # 1, 3, 63, 10\n        # smpl_params[70:80] = torch.rand_like(smpl_params[70:80]); print(\"error !! 
need to delete this rand betas in avatarnet:287\") # get betas\n        front_view = find_front_camera_by_rotation(poses, smplx_cam_rotate) # inputs camera poses and smplx poses\n        if self.allow_k_angles_near_the_front > 0:\n            allow_n_views_near_the_front =  round(self.allow_k_angles_near_the_front / 360 * num_imgs)\n            new_front_view = random.randint(-allow_n_views_near_the_front, allow_n_views_near_the_front) + front_view\n            if new_front_view >= num_imgs:\n                new_front_view = new_front_view - num_imgs\n            elif new_front_view < 0:\n                new_front_view = new_front_view + num_imgs\n            front_view = new_front_view\n            print(\"changes the front views ranges\", front_view, \"+-\", allow_n_views_near_the_front)\n\n        if self.specific_observation_idcs is None:  ######### if not specify views ########\n            # if self.num_train_imgs >= 0:\n            #     num_train_imgs = self.num_train_imgs\n            # else:\n            num_train_imgs = num_imgs\n            if self.random_test_imgs: ###### randomly selected images with self.num_train_imgs ######\n                cond_inds = random.sample(range(num_imgs), self.num_train_imgs)\n            elif self.specific_observation_num: ###### randomly selected \"specific_observation_num\" images ######\n                if self.first_is_front and self.specific_observation_num < 2:\n                    # self.specific_observation_num = 2\n                    cond_inds =torch.tensor([front_view, front_view]) # first for input, second for supervised\n                elif self.better_range: # select views by a uniform distribution range\n                    if self.first_is_front: # must include the front view\n                        num_train_imgs = self.specific_observation_num - 2\n                    else:\n                        num_train_imgs = self.specific_observation_num\n                    skip_range = num_imgs//num_train_imgs\n                    # select random views from each range of [skip_range] seperate from [0, skip_range, 2*skip_range, ...], \n                    cond_inds = torch.randperm(num_train_imgs) * skip_range \\\n                            + torch.randint(low=0, high=skip_range, size=[num_train_imgs])\n                    if self.first_is_front: # concat [the first view * 2] to the front of cond_inds\n                        cond_inds = torch.cat([torch.tensor([front_view, front_view]), cond_inds])\n                        \n                else: # previous version, random views are sampled\n                    cond_inds = torch.randperm(num_imgs)[:self.specific_observation_num]\n            else:\n                cond_inds = np.round(np.linspace(0, num_imgs - 1, num_train_imgs)).astype(np.int64)\n        else:   ######### selected target views ########\n            cond_inds = self.specific_observation_idcs\n\n\n        test_inds = list(range(num_imgs))\n\n        if self.specific_observation_num: # yy note: if specific_observation_num is not None, then remove the test_inds\n            test_inds = []\n        else:\n            for cond_ind in cond_inds:\n                test_inds.remove(cond_ind)\n        cond_smpl_param_ref = torch.zeros([189])\n        if_use_smpl_param_ref = torch.Tensor([1]) # 默认使用ref smpl, \n        if self.load_cond_data and len(cond_inds) > 0:\n            # cond_imgs, cond_poses, cond_intrinsics, cond_img_paths, cond_smpl_param, cond_norm = gather_imgs(cond_inds)\n            cond_imgs, cond_poses, 
cond_intrinsics, cond_img_paths, cond_smpl_param, cond_norm = \\\n                gather_imgs(cond_inds, poses, image_paths_or_video, smpl_params, load_imgs=self.load_imgs, load_norm=self.load_norm, center=self.center, radius=self.radius,\n                input_is_video=input_is_video)\n            cond_smpl_param_ref = cond_smpl_param.clone() # the smpl_param_ref for the reference images\n            if cond_intrinsics.shape[-1] == 3: # the old data format, which contains the value of camera center instead of fxfycxcy\n                cond_intrinsics = self.default_fxy_cxy.clone().repeat(cond_intrinsics.shape[0], 1)\n            # import pdb; pdb.set_trace()\n            # print(\"video_path\", video_path)\n            if input_is_video:\n                cond_img_paths = [f\"{video_path[:-4]}_{i:0>4d}.png\" for i in  range(self.specific_observation_num)] # Replace the .mp4 into index.png\n            \n            \n            if self.if_include_video_ref_img and input_is_video:\n                # 设置一个随机数，如果小于某个概率，那么替换第一张图为另一个图片\n                if np.random.rand() < self.prob_include_video_ref_img:\n                    if 'image_ref' in scene:\n                        ref_image_path = scene['image_ref']\n                        print(\"ref_image_path\",ref_image_path)\n                    else:\n                        ref_image_path = video_path.replace(\".mp4\", \".jpg\")\n                    # if \"flux\" in ref_image_path: # temperaturelly supports the inputs from the flux\n                    try:\n                        # replacement_img_path =  ref_image_path\n                        # replacement_img = load_image(ref_image_path)  # 假设有一个函数可以加载图片\n                        # 使用cv2.IMREAD_UNCHANGED标志读取图片，以保留alpha通道\n                        img = cv2.imread(ref_image_path, cv2.IMREAD_UNCHANGED)\n                        assert img is not None, f\"img is None, {ref_image_path}\"\n                        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n                        img = torch.from_numpy(img.astype(np.float32) / 255)  # (h, w, 3)\n                            # print(\"img.shape\", img.shape)\n                            # print(\"cond_imgs.shape\", cond_imgs.shape)\n                        # test_img_paths[0] = ref_image_path\n                        # results.update(test_imgs=test_imgs, test_img_paths=test_img_paths)\n\n                        # ======== loading the reference smplx for images ==========\n                        load_ref_smplx = False\n                        if load_ref_smplx:\n                            # if \"flux\" in ref_image_path:\n                            smplx_smplify_path = from_video_to_get_ref_smplx(video_path)\n                            # load json and get values\n                            with open(smplx_smplify_path) as f:\n                                data = json.load(f)\n\n                            RT = torch.concatenate([ torch.Tensor(data['camera']['R']), torch.Tensor(data['camera']['t']).reshape(3,1)], dim=1)\n                            RT = torch.cat([RT, torch.Tensor([[0,0,0,1]])], dim=0)\n\n\n                            intri = torch.Tensor(data['camera']['focal'] + data['camera']['princpt'])\n\n                            intri[[3,2]] = intri[[2,3]]\n                            intri = intri * self.default_fxy_cxy[0,-1] / intri[-1]\n                            \n                            # 假设 smpl_param_data 是已经加载好的数据 \n                            # (['root_pose', 'body_pose', 'jaw_pose', 'leye_pose', 'reye_pose', 
'lhand_pose', 'rhand_pose', 'expr', 'trans', 'betas_save', 'nlf_smplx_betas', 'camera', 'img_path'])\n                            smpl_param_data = data \n\n                            # 从字典中提取所需的数据\n                            global_orient = np.array(smpl_param_data['root_pose']).reshape(1, -1)\n                            body_pose = np.array(smpl_param_data['body_pose']).reshape(1, -1)\n                            shape = np.array(smpl_param_data['betas_save']).reshape(1, -1)[:, :10]\n                            left_hand_pose = np.array(smpl_param_data['lhand_pose']).reshape(1, -1)\n                            right_hand_pose = np.array(smpl_param_data['rhand_pose']).reshape(1, -1)\n\n                            # smpl_param_ref = np.concatenate([np.array([[1.]]), np.array([[0, 0, 0]]),\n                            smpl_param_ref = np.concatenate([np.array([[1.]]), np.array(smpl_param_data['trans']).reshape(1,3),\n                                global_orient,body_pose,\n                                shape, left_hand_pose, right_hand_pose,\n                                np.array(smpl_param_data['jaw_pose']).reshape(1, -1), np.array(smpl_param_data['leye_pose']).reshape(1, -1), np.array(smpl_param_data['reye_pose']).reshape(1, -1),\n                                np.array(smpl_param_data['expr']).reshape(1, -1)[:,:10]], axis=1)\n\n\n                            cond_poses[0] = RT         # RT\n                            cond_intrinsics[0]  = intri   # fxfycxcy\n                            cond_smpl_param_ref =  torch.Tensor(smpl_param_ref).reshape(-1)      # 189, combines \n                            if_use_smpl_param_ref = torch.Tensor([1])  # use the smpl_param_ref\n\n                        # overwrite some datas\n                        cond_imgs[0] = img\n                        cond_img_paths[0] = ref_image_path\n\n                    except (FileNotFoundError, json.JSONDecodeError, KeyError, Exception) as e:\n                        # 记录错误信息到日志文件\n                        # with open(self.log_file_path, 'a') as log_file:\n                        #     log_file.write(f\"{datetime.datetime.now()} {video_path} \\n  - An error occurred: {str(e)}\\n\")\n                        print(f\"An error occurred: {e}\")\n                        \n            ''' randomly crop the first image for augmentation'''\n            if self.crop:\n                # print(\"crop\", cond_imgs[0].shape)\n                if random.random() < 0.5:\n                    # cond_imgs[0] = F.crop(cond_imgs[0], 0, 0, 512, 512)\n                    crop_imgs = cond_imgs[0]\n                    # 图像尺寸\n                    h, w, _ = crop_imgs.shape\n\n                    # 随机偏移量\n                    random_offset_head = np.random.randint(-h//7, -h//8)\n                    random_offset_body = np.random.randint(-h // 8, h // 8)\n\n                    # head_joint, upper_body_joint\n                    head_joint = [ w//2, h//7,]\n                    upper_body_joint = [w//2, h//2, ]\n\n                    # 计算裁剪区域\n                    head_y = int(head_joint[1]) + random_offset_head\n                    body_y = int(upper_body_joint[1]) + random_offset_body\n\n                    # 确保裁剪区域在图像范围内\n                    head_y = max(0, min(h, head_y))\n                    body_y = max(0, min(h, body_y))\n\n                    # 裁剪区域的高度和宽度\n                    crop_height = body_y - head_y\n                    crop_width =int(crop_height * 640 / 896)\n\n                    # 确保裁剪区域在图像范围内\n                    start_x 
= max(0, min(w - crop_width, int(w // 2 - crop_width // 2)))\n                    end_x = start_x + crop_width\n                    start_y = max(0, head_y)\n                    end_y = min(h, body_y)\n\n                    # 裁剪图像\n                    cropped_img = crop_imgs[start_y:end_y, start_x:end_x]\n\n            \n                    padded_img = F.resize(cropped_img.permute(2, 0, 1), [h, w]).permute(1, 2, 0)\n\n                    # save this img for debug\n                    # Image.fromarray((padded_img.numpy() * 255).astype(np.uint8)).save(f\"debug_crop.png\")\n                    # rescale the image for augmentation\n                    cond_imgs[0] = random_scale_and_crop(padded_img, (0.8,1.2))\n                else:\n                    cond_imgs[0] = random_scale_and_crop(cond_imgs[0], (0.8,1.1))\n\n            \n            results.update(\n                cond_poses=cond_poses,\n                cond_intrinsics=cond_intrinsics.to(torch.float32),\n                cond_img_paths=cond_img_paths, \n                cond_smpl_param=cond_smpl_param,\n                cond_smpl_param_ref=cond_smpl_param_ref,\n                if_use_smpl_param_ref=if_use_smpl_param_ref)\n            if cond_imgs is not None:\n                results.update(cond_imgs=cond_imgs)\n            if cond_norm is not None:\n                results.update(cond_norm=cond_norm)\n\n        if self.load_test_data and len(test_inds) > 0:\n            test_imgs, test_poses, test_intrinsics, test_img_paths, test_smpl_param, test_norm = \\\n                    gather_imgs(test_inds, poses, image_paths_or_video, smpl_params, load_imgs=self.load_imgs, load_norm=self.load_norm, center=self.center, radius=self.radius)\n            \n            if test_intrinsics.shape[-1] == 3: # the old data format, which contains the value of camera center instead of fxfycxcy\n                test_intrinsics = self.default_fxy_cxy.clone().repeat(test_intrinsics.shape[0], 1)\n\n            results.update(\n                test_poses=test_poses,\n                test_intrinsics=test_intrinsics,\n                test_img_paths=test_img_paths,\n                test_smpl_param=test_smpl_param)\n            if test_imgs is not None:\n                results.update(test_imgs=test_imgs)\n            if test_norm is not None:\n                results.update(test_norm=test_norm)\n\n    \n        if self.test_pose_override is not None:\n            results.update(test_poses=self.test_poses, test_intrinsics=self.test_intrinsics)\n        return results\n\n    def __len__(self):\n        return self.num_scenes\n\n    def __getitem__(self, scene_id):\n        try:\n            scene = self.parse_scene(scene_id)\n\n        except:\n            print(\"ERROR in parsing \", scene_id)\n            scene = self.parse_scene(0)\n        return scene\n    \n    def parse_scene_test(self, scene_id):\n        scene = self.scenes[scene_id]\n        input_is_video = False # flag of if the input is video, some operations should be different\n\n        \n        # =========== loading the params ===========\n        param = np.load(scene['param_path'], allow_pickle=True).item()\n        scene.update(param)\n        print(scene.keys())\n        # import ipdb; ipdb.set_trace()\n        if scene['image_paths'] is None:\n            input_is_video = True\n            video_path = scene['video_path']\n            try:\n                image_paths_or_video = read_frames(scene['video_path'])\n            except Exception as e:\n                print(f\"Error: {e}\")\n  
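# (added note) log the decoding failure and retry reading the original video path once more.\n  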
              print(f\"Error in reading the video : {scene['video_path'].replace('result.mp4', 'output.mp4')}\")\n                image_paths_or_video = read_frames(scene['video_path'])\n\n        if not input_is_video:\n            image_paths_or_video = scene['image_paths']\n        scene_name = f\"{scene_id:0>4d}\" #  image_paths[0].split('/')[-3]\n        results = dict(\n            scene_id=[scene_id],\n            scene_name=\n                '{:04d}'.format(scene_id) if self.scene_id_as_name else scene_name,\n                # cpu_only=True\n                )\n        \n        # if not self.code_only:\n        poses = scene['poses']\n        smpl_params = scene['smpl_params']\n\n        # if input_is_video:\n        #     num_imgs = len(video)\n        # else:\n        num_imgs = len(image_paths_or_video)\n        # front_view = num_imgs // 4\n        # randonly / specificically select the views of output\n        smplx_cam_rotate = smpl_params[4: 7] #get global orient # 1, 3, 63, 10\n        # smpl_params[70:80] = torch.rand_like(smpl_params[70:80]); print(\"error !! need to delete this rand betas in avatarnet:287\") # get betas\n        front_view = find_front_camera_by_rotation(poses, smplx_cam_rotate) # inputs camera poses and smplx poses\n        if self.allow_k_angles_near_the_front > 0:\n            allow_n_views_near_the_front =  round(self.allow_k_angles_near_the_front / 360 * num_imgs)\n            new_front_view = random.randint(-allow_n_views_near_the_front, allow_n_views_near_the_front) + front_view\n            if new_front_view >= num_imgs:\n                new_front_view = new_front_view - num_imgs\n            elif new_front_view < 0:\n                new_front_view = new_front_view + num_imgs\n            front_view = new_front_view\n            print(\"changes the front views ranges\", front_view, \"+-\", allow_n_views_near_the_front)\n\n        num_train_imgs = num_imgs\n        test_inds = torch.Tensor( list(range(num_imgs)))\n        cond_inds = np.concatenate([np.array([front_view]),test_inds]).astype(np.int64) # first for input, second for supervised\n        test_inds = cond_inds.tolist()\n        # if self.specific_observation_num: # yy note: if specific_observation_num is not None, then remove the test_inds\n        #     test_inds = []\n        # else:\n        # for cond_ind in cond_inds:\n        #     test_inds.remove(cond_ind)\n        # cond_inds = cond_inds\n\n        cond_smpl_param_ref = torch.zeros([189])\n        if_use_smpl_param_ref = torch.Tensor([1]) # 默认使用ref smpl, \n        if self.load_cond_data and len(cond_inds) > 0:\n            # cond_imgs, cond_poses, cond_intrinsics, cond_img_paths, cond_smpl_param, cond_norm = gather_imgs(cond_inds)\n            cond_imgs, cond_poses, cond_intrinsics, cond_img_paths, cond_smpl_param, cond_norm = \\\n                gather_imgs(cond_inds, poses, image_paths_or_video, smpl_params, load_imgs=self.load_imgs, load_norm=self.load_norm, center=self.center, radius=self.radius, \\\n                input_is_video=input_is_video)\n            cond_smpl_param_ref = cond_smpl_param.clone() # the smpl_param_ref for the reference images\n            if cond_intrinsics.shape[-1] == 3: # the old data format, which contains the value of camera center instead of fxfycxcy\n                cond_intrinsics = self.default_fxy_cxy.clone().repeat(cond_intrinsics.shape[0], 1)\n\n            if input_is_video:\n                cond_img_paths = [f\"{video_path[:-4]}_{i:0>4d}.png\" for i in  
range(self.specific_observation_num)] # replace the .mp4 suffix with per-frame index.png names\n            \n                    \n            \n            results.update(\n                cond_poses=cond_poses,\n                cond_intrinsics=cond_intrinsics.to(torch.float32),\n                cond_img_paths=cond_img_paths, \n                cond_smpl_param=cond_smpl_param,\n                cond_smpl_param_ref=cond_smpl_param_ref,\n                if_use_smpl_param_ref=if_use_smpl_param_ref)\n            if cond_imgs is not None:\n                results.update(cond_imgs=cond_imgs)\n            if cond_norm is not None:\n                results.update(cond_norm=cond_norm)\n\n        if self.load_test_data and len(test_inds) > 0:\n            print(\"input_is_video\", input_is_video)\n            test_imgs, test_poses, test_intrinsics, test_img_paths, test_smpl_param, test_norm = \\\n                    gather_imgs(test_inds, poses, image_paths_or_video, smpl_params, load_imgs=self.load_imgs, load_norm=self.load_norm, center=self.center, radius=self.radius, \\\n                    input_is_video=input_is_video)\n            \n            if test_intrinsics.shape[-1] == 3: # the old data format, which contains the value of camera center instead of fxfycxcy\n                test_intrinsics = self.default_fxy_cxy.clone().repeat(test_intrinsics.shape[0], 1)\n\n            results.update(\n                test_poses=test_poses,\n                test_intrinsics=test_intrinsics,\n                test_img_paths=test_img_paths,\n                test_smpl_param=test_smpl_param)\n            if test_imgs is not None:\n                results.update(test_imgs=test_imgs)\n            if test_norm is not None:\n                results.update(test_norm=test_norm)\n\n    \n        if self.test_pose_override is not None:\n            results.update(test_poses=self.test_poses, test_intrinsics=self.test_intrinsics)\n        return results\n\n\ndef gather_imgs(img_ids, poses, image_paths_or_video, smpl_params, load_imgs=True, load_norm=False, center=None, radius=None, input_is_video=False):\n    imgs_list = [] if load_imgs else None\n    norm_list = [] if load_norm else None\n    poses_list = []\n    cam_centers_list = []\n    img_paths_list = []\n   \n    for img_id in img_ids:\n        pose = poses[img_id][0]\n        cam_centers_list.append((poses[img_id][1]).to(torch.float)) # (C)\n        c2w = pose.to(torch.float)#torch.FloatTensor(pose) # named c2w, but the stored values are actually w2c (R|T)\n        cam_to_ndc = torch.cat(\n            [c2w[:3, :3], (c2w[:3, 3:] - center[:, None]) / radius[:, None]], dim=-1)\n        poses_list.append(\n            torch.cat([\n                cam_to_ndc,\n                cam_to_ndc.new_tensor([[0.0, 0.0, 0.0, 1.0]])\n            ], dim=-2))\n        if input_is_video:\n            # img_paths_list.append(video[img_id])\n            img = image_paths_or_video[img_id]\n            # for img in imgs: # adjust near-white pixels (all channels > 250) to pure white\n            mask_white = np.all(img[:,:,:3] > 250, axis=-1, keepdims=False)\n            # Image.fromarray(img).save(f\"debug.png\")\n            # Image.fromarray(mask_white).save(f\"debug_mask.png\")\n            img[mask_white] = [255, 255, 255]\n            # Image.fromarray(img).save(f\"debug_afmask.png\")\n            img = torch.from_numpy(img.astype(np.float32) / 255)  # (h, w, 3)\n           \n            imgs_list.append(img)\n        else:\n            img_paths_list.append(image_paths_or_video[img_id])\n            if load_imgs:\n               
 # img = mmcv.imread(image_paths[img_id], channel_order='rgb')\n\n                # read with cv2.IMREAD_UNCHANGED to keep the alpha channel\n                print(\"Loading, .......\", image_paths_or_video[img_id])\n                img = cv2.imread(image_paths_or_video[img_id], cv2.IMREAD_UNCHANGED)\n                print(\"img.shape\", img.shape)\n                # set fully transparent pixels to white (255, 255, 255)\n                img[img[..., 3] == 0] = [255, 255, 255, 255]\n                img = img[..., :3]\n                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n                img = torch.from_numpy(img.astype(np.float32) / 255)  # (h, w, 3)\n                imgs_list.append(img)\n        if load_norm: # video input is not supported for normal maps yet\n            norm = cv2.imread(image_paths_or_video[img_id].replace('rgb', 'norm'), cv2.IMREAD_UNCHANGED)\n            norm = cv2.cvtColor(norm, cv2.COLOR_BGR2RGB)\n            norm = torch.from_numpy(norm.astype(np.float32) / 255)\n            norm_list.append(norm)\n    poses_list = torch.stack(poses_list, dim=0)  # (n, 4, 4)\n    cam_centers_list = torch.stack(cam_centers_list, dim=0)\n    if load_imgs:\n        imgs_list = torch.stack(imgs_list, dim=0)  # (n, h, w, 3)\n    if load_norm:\n        norm_list = torch.stack(norm_list, dim=0)\n    return imgs_list, poses_list, cam_centers_list, img_paths_list, smpl_params, norm_list\n\nfrom scipy.spatial.transform import Rotation as R\ndef calculate_angle(vector1, vector2):\n    unit_vector1 = vector1 / torch.linalg.norm(vector1)\n    unit_vector2 = vector2 / torch.linalg.norm(vector2)\n    dot_product = torch.dot(unit_vector1, unit_vector2)\n    angle = torch.arccos(dot_product)\n    return angle\ndef axis_angle_to_rotation_matrix(axis_angle):\n    if_input_is_torch = torch.is_tensor(axis_angle)\n    if if_input_is_torch:\n        dtype_torch = axis_angle.dtype\n        axis_angle = axis_angle.numpy()\n        \n    r = R.from_rotvec(axis_angle)\n    rotation_matrix = r.as_matrix()\n    if if_input_is_torch:\n        rotation_matrix = torch.from_numpy(rotation_matrix).to(dtype_torch)\n\n    return rotation_matrix\n# def find_front_camera_by_global_orient(global_orient, camera_direction):\n#     front_direction = np.array([0, 0, -1])  # front direction of the human body\n#     min_angle = float('inf')\n#     front_camera_idx = -1\n\n#     for idx, global_orient in enumerate(global_orient_list):\n        \n\n#         angle = calculate_angle(body_direction, front_direction)\n#         if angle < min_angle:\n#             min_angle = angle\n#             front_camera_idx = idx\n\n#     return front_camera_idx\ndef find_front_camera_by_rotation(poses, global_orient_human):\n    # front_direction = global_orient_human  # front direction of the human body\n    rotation_matrix = axis_angle_to_rotation_matrix(global_orient_human)\n    front_direction = rotation_matrix @ torch.Tensor([0, 0, -1])  # facing direction of the human body\n    min_angle = float('inf')\n    front_camera_idx = -1\n\n    for idx, pose in enumerate(poses):\n        rotation_matrix = pose[0][:3, :3]\n        camera_direction = rotation_matrix @ torch.Tensor([0, 0, 1])  # viewing direction of the camera\n        angle = calculate_angle(camera_direction, front_direction).to(camera_direction.dtype)\n        if angle < min_angle:\n            min_angle = angle\n            front_camera_idx = idx\n\n    return front_camera_idx\n\ndef read_frames(video_path):\n    container = 
av.open(video_path)\n\n    video_stream = next(s for s in container.streams if s.type == \"video\")\n    frames = []\n    for packet in container.demux(video_stream):\n        for frame in packet.decode():\n            # image = Image.frombytes(\n            #     \"RGB\",\n            #     (frame.width, frame.height),\n            #     frame.to_rgb().to_ndarray(),\n            # )\n            image =  frame.to_rgb().to_ndarray()\n            frames.append(image)\n\n    return frames\n\n\ndef prepare_camera( resolution_x, resolution_y, num_views=24, stides=1):\n    # resolution_x = 640\n    # resolution_y = 896\n    import math\n    focal_length = 40 #80\n    sensor_width = 32\n\n    # # 创建 Pyrender 相机\n    # camera = pyrender.PerspectiveCamera(yfov=fov, aspectRatio=aspect_ratio)\n    focal_length = focal_length * (resolution_y/sensor_width)\n\n    K = np.array(\n        [[focal_length, 0, resolution_x//2],\n        [0, focal_length, resolution_y//2],\n        [0, 0, 1]]\n    )\n    # print(\"update!! the camera intrisic is error 0819\")\n    def look_at(camera_position, target_position, up_vector):  # colmap +z forward, +y down\n        forward = -(camera_position - target_position) / np.linalg.norm(camera_position - target_position)\n        right = np.cross(up_vector, forward)\n        up = np.cross(forward, right)\n        return np.column_stack((right, up, forward))\n    camera_pose_list = []\n    for frame_idx in range(0, num_views, stides):\n        # 设置相机的位置和方向\n        camera_dist = 1.5 #3 #1.2 * 2\n        phi = math.radians(90)\n        theta = (frame_idx / num_views) * math.pi * 2\n        camera_location = np.array(\n            [camera_dist * math.sin(phi) * math.cos(theta),\n            \n            camera_dist * math.cos(phi),\n            -camera_dist * math.sin(phi) * math.sin(theta),]\n            )\n        # print(camera_location)\n        camera_pose = np.eye(4)\n        camera_pose[:3, 3] = camera_location\n        # print(\"camera_location\", camera_location)\n\n        # from smplx import look_at\n\n\n        # 设置相机位置和目标位置\n        camera_position = camera_location\n        target_position = np.array([0.0, 0.0, 0.0])\n\n        # 计算相机的旋转矩阵，使其朝向目标\n        # up_vector = np.array([0.0, 1.0, 0.0])\n        up_vector = np.array([0.0, -1.0, 0.0]) # colmap\n        rotation_matrix = look_at(camera_position, target_position, up_vector)\n\n        # 更新相机的位置和旋转\n        camera_pose[:3, :3] = rotation_matrix\n        camera_pose[:3, 3] = camera_position\n        camera_pose_list.append(camera_pose)\n    return K, camera_pose_list\n\n\ndef from_video_to_get_ref_smplx(video_path):\n    # 分解路径\n    video_dir = os.path.dirname(video_path)\n    video_name = video_dir.split(\"/\")[-1] # 视频文件夹名称\n    \n    # 替换视频目录为 smplify 目录\n    if \"flux\" in video_dir:\n        smplify_dir = video_dir.replace('/videos/', '/smplx_smplify/')\n    elif \"DeepFashion\" in video_dir:\n        smplify_dir = video_dir.replace('/video/', '/smplx_smplify/').replace(\"A_pose_\", \"\")\n    \n    # # 获取视频文件名（不包括扩展名）\n    # video_name = os.path.splitext(video_file)[0]\n    \n    # 构建 JSON 文件路径\n    if smplify_dir[-1] == '/': smplify_dir = smplify_dir[:-1]\n    json_path = smplify_dir+\".json\" #os.path.join(smplify_dir, f\"{video_name}.json\")\n    \n    return json_path\n\ndef random_scale_and_crop(image: torch.Tensor, scale_range=(0.8, 1.2)) -> torch.Tensor:\n    \"\"\"\n    Randomly scale the input image and crop/pad to maintain original size.\n\n    Args:\n        image: Input image tensor of 
shape [H, W, 3]\n        scale_range: Range for scaling factor, default (0.8, 1.2)\n\n    Returns:\n        Scaled and cropped/padded image tensor of shape [H, W, 3]\n    \"\"\"\n    is_numpy = False\n    if not torch.is_tensor(image):\n        image = torch.from_numpy(image)\n        is_numpy = True\n    # 获取图像的高度和宽度\n    h, w = image.shape[:2]\n\n    # 生成随机缩放因子\n    scale_factor = random.uniform(*scale_range)\n\n    # 计算新的高度和宽度\n    new_h = int(h * scale_factor)\n    new_w = int(w * scale_factor)\n\n    # 使用 torchvision.transforms.functional.resize 进行缩放\n    scaled_image = F.resize(image.permute(2, 0, 1), [new_h, new_w]).permute(1, 2, 0)\n\n    # 如果缩放后的图像比原图大，进行居中裁剪\n    if new_h > h or new_w > w:\n        top = (new_h - h) // 2\n        left = (new_w - w) // 2\n        scaled_image = scaled_image[top:top + h, left:left + w]\n    else:\n        # 如果缩放后的图像比原图小，进行居中填充\n        padded_image = torch.ones((h, w, 3), dtype=image.dtype)\n        top = h-new_h #(h - new_h) // 2 # H不应该居中\n        left = (w - new_w) // 2\n        padded_image[top:top + new_h, left:left + new_w] = scaled_image\n        scaled_image = padded_image\n    if is_numpy:\n        scaled_image = scaled_image.numpy()\n    return scaled_image\n\n        \n\nif __name__ == \"__main__\":\n\n\n    import os\n\n          \n\n    params = {\n        \"data_prefix\": None,\n        \"cache_path\":  ListConfig([ \n            \"./processed_data/deepfashion_train_145_local.npy\",\n            \"./processed_data/flux_batch1_5000_train_145_local.npy\"\n        ]),\n        \"specific_observation_num\": 5,\n        \"better_range\": True,\n        \"first_is_front\": True,\n        \"if_include_video_ref_img\": True,\n        \"prob_include_video_ref_img\": 0.5,\n        \"img_res\": [640, 896],\n        'test_mode': True\n    }\n\n    data = AvatarDataset(**params)\n\n    sample = data[0]\n    print(sample.keys())\n\n\n    import os\n    import torch.distributed as dist\n\n    os.environ['RANK'] = '0'\n    os.environ['WORLD_SIZE'] = '1'\n    os.environ['MASTER_ADDR'] = '127.0.0.1'\n    os.environ['MASTER_PORT'] = '29500'\n\n    dist.init_process_group(backend='nccl', rank=0, world_size=1)\n\n\n    # test the batch loader\n    from torch.utils.data import DataLoader\n    dataloader = DataLoader(data, batch_size=10, shuffle=True, collate_fn=custom_collate_fn)\n\n\n    from torch.utils.data.distributed import DistributedSampler\n    import webdataset as wds\n    # sampler = DistributedSampler(data) # training  is true!~\n    sampler = None\n    dataloader = wds.WebLoader(data, batch_size=10, num_workers=1, shuffle=False, sampler=sampler,  )\n\n        \n    try:\n        for i, batch in enumerate(dataloader):\n            print(batch.keys())\n            # break\n    except Exception as e:\n        import traceback\n        print(\"Caught an exception during dataloader iteration:\")\n        traceback.print_exc()\n"
  },
  {
    "path": "lib/datasets/dataloader.py",
    "content": "\nimport os, sys\nimport json\n\n\nimport numpy as np\nimport webdataset as wds\nimport pytorch_lightning as pl\nimport torch\nfrom torch.utils.data import Dataset\nfrom torch.utils.data.distributed import DistributedSampler\n\nfile_dir = os.path.abspath(os.path.dirname(__file__))\nproject_root = os.path.join(file_dir, '..', '..')\nsys.path.append(project_root)  \nfrom lib.utils.train_util import instantiate_from_config\nfrom torch.utils.data import DataLoader \n\nclass DataModuleFromConfig(pl.LightningDataModule):\n    def __init__(\n        self, \n        batch_size=8, \n        num_workers=4, \n        train=None, \n        validation=None, \n        test=None, \n        **kwargs,\n    ):\n        super().__init__()\n\n        self.batch_size = batch_size\n        self.num_workers = num_workers\n\n        self.dataset_configs = dict()\n        if train is not None:\n            self.dataset_configs['train'] = train\n        if validation is not None:\n            self.dataset_configs['validation'] = validation\n        if test is not None:\n            self.dataset_configs['test'] = test\n\n    def setup(self, stage):\n        self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)\n\n    def train_dataloader(self):\n        sampler = DistributedSampler(self.datasets['train']) if torch.distributed.is_initialized() else None\n        return DataLoader(\n            self.datasets['train'],\n            batch_size=self.batch_size,\n            num_workers=self.num_workers,\n            shuffle=(sampler is None),\n            sampler=sampler,\n            pin_memory=True,\n            drop_last=True,\n        )\n\n    def val_dataloader(self):\n        sampler = DistributedSampler(self.datasets['validation']) if torch.distributed.is_initialized() else None\n        return DataLoader(\n            self.datasets['validation'],\n            batch_size=1,\n            num_workers=self.num_workers,\n            shuffle=False,\n            sampler=sampler,\n            pin_memory=True\n        )\n\n    def test_dataloader(self):\n        return DataLoader(\n            self.datasets['test'],\n            batch_size=self.batch_size,\n            num_workers=self.num_workers,\n            shuffle=False,\n            pin_memory=True\n        )\n"
  },
  {
    "path": "lib/humanlrm_wrapper_sa_v1.py",
    "content": "\nimport os\nimport math\nimport json\nfrom torch.optim import Adam\nfrom torch.nn.parallel.distributed import DistributedDataParallel\nimport torch\nimport torch.nn.functional as F\nfrom torchvision.transforms import InterpolationMode\nfrom torchvision.utils import make_grid, save_image\nfrom torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity\nfrom torchmetrics.image.ssim import StructuralSimilarityIndexMeasure\nimport pytorch_lightning as pl\nfrom pytorch_lightning.utilities.grads import grad_norm\nfrom einops import rearrange, repeat\n\nfrom lib.utils.train_util import instantiate_from_config\nfrom lib.ops.activation import TruncExp\nimport time\nimport matplotlib.pyplot as plt\n\nfrom PIL import Image\n\nimport numpy as np \nfrom lib.utils.train_util import main_print\n\nfrom typing import List, Optional, Tuple, Union\ndef get_1d_rotary_pos_embed(\n    dim: int,\n    pos: Union[torch.Tensor, int],\n    theta: float = 10000.0,\n    use_real=False,\n    linear_factor=1.0,\n    ntk_factor=1.0,\n    repeat_interleave_real=True,\n):\n    \"\"\"\n    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.\n\n    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end\n    index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64\n    data type.\n\n    Args:\n        dim (`int`): Dimension of the frequency tensor.\n        pos (`torch.Tensor` or `int`): Position indices for the frequency tensor. [S] or scalar\n        theta (`float`, *optional*, defaults to 10000.0):\n            Scaling factor for frequency computation. Defaults to 10000.0.\n        use_real (`bool`, *optional*):\n            If True, return real part and imaginary part separately. Otherwise, return complex numbers.\n        linear_factor (`float`, *optional*, defaults to 1.0):\n            Scaling factor for the context extrapolation. Defaults to 1.0.\n        ntk_factor (`float`, *optional*, defaults to 1.0):\n            Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.\n        repeat_interleave_real (`bool`, *optional*, defaults to `True`):\n            If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.\n            Otherwise, they are concateanted with themselves.\n    Returns:\n        `torch.Tensor`: Precomputed frequency tensor with complex exponentials. 
[S, D/2]\n    \"\"\"\n    assert dim % 2 == 0\n\n    if isinstance(pos, int):\n        pos = torch.arange(pos)\n    theta = theta * ntk_factor\n    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) / linear_factor  # [D/2]\n    t = pos # torch.from_numpy(pos).to(freqs.device)  # type: ignore  # [S]\n    freqs = freqs.to(device=t.device, dtype=t.dtype)  # type: ignore\n    freqs = torch.outer(t, freqs).float()  # type: ignore   # [S, D/2]\n    if use_real and repeat_interleave_real:\n        freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]\n        freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]\n        return freqs_cos, freqs_sin\n    elif use_real:\n        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1)  # [S, D]\n        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1)  # [S, D]\n        return freqs_cos, freqs_sin\n    else:\n        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64     # [S, D/2]\n        return freqs_cis\nclass FluxPosEmbed(torch.nn.Module):\n    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11\n    def __init__(self, theta: int, axes_dim: [int]):\n        super().__init__()\n        self.theta = theta\n        self.axes_dim = axes_dim\n\n    def forward(self, ids: torch.Tensor) -> torch.Tensor:\n        n_axes = ids.shape[-1]\n        cos_out = []\n        sin_out = []\n        pos = ids.float()\n        is_mps = ids.device.type == \"mps\"\n        freqs_dtype = torch.float32 if is_mps else torch.float64\n        for i in range(n_axes):\n            cos, sin = get_1d_rotary_pos_embed(\n                self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True#, freqs_dtype=freqs_dtype\n            )\n            cos_out.append(cos)\n            sin_out.append(sin)\n        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)\n        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)\n        return freqs_cos, freqs_sin\n    \nclass SapiensGS_SA_v1(pl.LightningModule):\n    \n    def __init__(\n        self,\n        encoder=dict(type='mmpretrain.VisionTransformer'),\n        neck=dict(type='mmpretrain.VisionTransformer'),\n        decoder=dict(),\n        diffusion_use_ema=True,\n        freeze_decoder=False,\n        image_cond=False,\n        code_permute=None,\n        code_reshape=None,\n        autocast_dtype=None,\n        ortho=True,\n        return_norm=False,\n        # reshape_type='reshape', # 'cnn'\n        code_size=None,\n        decoder_use_ema=None,\n        bg_color=1,\n        training_mode=None, # stage2's flag, default None for stage1 \n\n        patch_size: int = 4,\n            \n        warmup_steps: int = 12_000,\n        use_checkpoint: bool = True,\n        lambda_depth_tv: float = 0.05,\n        lambda_lpips: float = 2.0,\n        lambda_mse: float = 1.0,\n        lambda_l1: float=0, \n        lambda_ssim: float=0, \n        neck_learning_rate: float = 5e-4,\n        decoder_learning_rate: float = 1e-3,\n        encoder_learning_rate: float=0,\n        max_steps: int = 100_000,\n        loss_coef: float = 0.5,\n        init_iter: int = 500,\n        lambda_offset: int = 50,  # offset_weight: 50\n        scale_weight: float = 0.01,\n        is_debug: bool = False, # if debug, then it will not returns lpips\n        code_activation: dict=None,\n        output_hidden_states: bool=False, # if True, will output the hidden states from sapiens shallow 
layer, for the neck decoder\n        loss_weights_views: List = [], # the loss weights for the views, if empty, will use the same weights for all the views\n        **kwargs\n    ):\n        super(SapiensGS_SA_v1, self).__init__()\n        ## ========== part -- Add the code to save this parameters for optimizers ========\n        self.warmup_steps = warmup_steps\n        self.use_checkpoint = use_checkpoint\n        self.lambda_depth_tv = lambda_depth_tv\n        self.lambda_lpips = lambda_lpips\n        self.lambda_mse = lambda_mse\n        self.lambda_l1 = lambda_l1\n        self.lambda_ssim = lambda_ssim  \n        self.neck_learning_rate = neck_learning_rate\n        self.decoder_learning_rate = decoder_learning_rate\n        self.encoder_learning_rate = encoder_learning_rate\n        self.max_steps = max_steps\n\n        self.loss_coef = loss_coef\n        self.init_iter = init_iter\n        self.lambda_offset = lambda_offset\n        self.scale_weight = scale_weight\n\n        self.is_debug = is_debug\n        ## ========== end part ========\n\n     \n        \n        self.code_size = code_size\n        if code_activation['type'] == 'tanh':\n            self.code_activation = torch.nn.Tanh()\n        else:\n            self.code_activation = TruncExp() #build_module(code_activation)\n        # self.grid_size = grid_size\n        self.decoder = instantiate_from_config(decoder)\n        self.decoder_use_ema = decoder_use_ema\n        if decoder_use_ema:\n            raise NotImplementedError(\"decoder_use_ema has not been implemented\")\n            if self.decoder_use_ema:\n                self.decoder_ema = deepcopy(self.decoder)\n        self.encoder = instantiate_from_config(encoder)\n        # get_obj_from_str(config[\"target\"])\n\n        self.code_size = code_reshape\n        self.code_clip_range = [-1,1]\n        \n        # ============= begin config  =============\n        # transformer from class MAEPretrainDecoder(BaseModule):\n         # compress the token number of the uv code\n        self.patch_size = patch_size\n        self.code_patch_size = self.patch_size\n        self.num_patches_axis = code_reshape[-1]//self.patch_size # reshape it for the upsampling\n        self.num_patches = self.num_patches_axis ** 2\n        self.code_feat_dims = code_reshape[0] # only used for the upsampling of 'reshape' type \n        self.code_resolution = code_reshape[-1] # only used for the upsampling of 'reshape' type \n\n        self.reshape_type = self.decoder.reshape_type\n\n        \n        self.inputs_front_only = True\n        self.render_loss_all_view = True\n        self.if_include_video_ref_img = True\n        \n        self.training_mode = training_mode\n\n        self.loss_weights_views = torch.Tensor(loss_weights_views).reshape(-1) / sum(loss_weights_views)  # normalize the weights\n        \n        \n\n        # ========== config meaning ===========\n        self.neck =  instantiate_from_config(neck)\n\n        self.ids_restore = torch.arange(0, self.num_patches).unsqueeze(0)\n        self.freeze_decoder = freeze_decoder\n        if self.freeze_decoder:\n            self.decoder.requires_grad_(False)\n            if self.decoder_use_ema:\n                self.decoder_ema.requires_grad_(False)\n        self.image_cond = image_cond\n        self.code_permute = code_permute\n        self.code_reshape = code_reshape\n        self.code_reshape_inv = [self.code_size[axis] for axis in self.code_permute] if code_permute is not None \\\n            else self.code_size\n        
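# (added note) code_permute / code_reshape describe how the flat UV token output is rearranged into the code tensor; the *_inv fields below invert that mapping when needed.\n        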
self.code_permute_inv = [self.code_permute.index(axis) for axis in range(len(self.code_permute))] \\\n            if code_permute is not None else None\n\n        self.autocast_dtype = autocast_dtype\n        self.ortho = ortho\n        self.return_norm = return_norm\n\n        '''add a flag for the skip connection from sapiens shallow layer'''\n        self.output_hidden_states = output_hidden_states\n\n        ''' add the in-the-wild images visualization'''\n        if self.lambda_lpips > 0:\n            self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg')\n        else:\n            self.lpips = None\n\n        self.ssim = StructuralSimilarityIndexMeasure()\n        \n        \n        self.validation_step_outputs = []\n        self.validation_step_code_outputs = [] # saving the code\n        self.validation_step_nvPose_outputs = []\n        self.validation_metrics = []\n\n        \n        # loading the smplx for the nv pose\n        import json\n        import numpy as np\n        # evaluate the animation\n        smplx_path = './work_dirs/demo_data/Ways_to_Catch_360_clip1.json'\n        with open(smplx_path, 'r') as f:\n            smplx_pose_param = json.load(f)\n        smplx_param_list = []\n        for par in smplx_pose_param['annotations']:\n            k = par['smplx_params']\n            for i in k.keys():\n                k[i] = np.array(k[i])\n            left_hands = np.array([1.4624, -0.1615,  0.1361,  1.3851, -0.2597,  0.0247, -0.0683, -0.4478,\n                -0.6652, -0.7290,  0.0084, -0.4818])\n            betas = torch.zeros((10))\n            smplx_param = \\\n                np.concatenate([np.array([1]), np.array([0,0.,0]),  np.array([0, -1, 0])*k['root_orient'], \\\n                                k['pose_body'],betas, \\\n                                    k['pose_hand'], k['pose_jaw'], np.zeros(6), k['face_expr'][:10]], axis=0).reshape(1,-1)\n            # print(smplx_param.shape)\n            smplx_param_list.append(smplx_param)\n        smplx_params = np.concatenate(smplx_param_list, 0)\n        self.smplx_params = torch.Tensor(smplx_params).cuda()\n    def get_default_smplx_params(self):\n        A_pose = torch.Tensor([[ 1.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,\n            0.0000,  0.1047,  0.0000,  0.0000, -0.1047,  0.0000,  0.0000,  0.0000,\n            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,\n            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,\n            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,\n            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,\n            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.7854,  0.0000,\n            0.0000,  0.7854,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,\n            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.7470,  1.0966,\n            0.0169, -0.0534, -0.0212,  0.0782, -0.0348,  0.0260,  0.0060,  0.0118,\n            -0.1117, -0.0429,  0.4164, -0.1088,  0.0660,  0.7562,  0.0964,  0.0909,\n            0.1885,  0.1181, -0.0509,  0.5296,  0.1437, -0.0552,  0.7049,  0.0192,\n            0.0923,  0.3379,  0.4570,  0.1963,  0.6255,  0.2147,  0.0660,  0.5069,\n            0.3697,  0.0603,  0.0795,  0.1419,  0.0859,  0.6355,  0.3033,  0.0579,\n            0.6314,  0.1761,  0.1321,  0.3734, -0.8510, -0.2769,  0.0915,  0.4998,\n            -0.0266, -0.0529, -0.5356, -0.0460,  0.2774, -0.1117,  0.0429, -0.4164,\n            
-0.1088, -0.0660, -0.7562,  0.0964, -0.0909, -0.1885,  0.1181,  0.0509,\n            -0.5296,  0.1437,  0.0552, -0.7049,  0.0192, -0.0923, -0.3379,  0.4570,\n            -0.1963, -0.6255,  0.2147, -0.0660, -0.5069,  0.3697, -0.0603, -0.0795,\n            0.1419, -0.0859, -0.6355,  0.3033, -0.0579, -0.6314,  0.1761, -0.1321,\n            -0.3734, -0.8510,  0.2769, -0.0915,  0.4998,  0.0266,  0.0529, -0.5356,\n            0.0460, -0.2774,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,\n            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,\n            0.0000,  0.0000,  0.0000,  0.0000,  0.0000]])\n        return A_pose\n    def forward_decoder(self, decoder, code, target_rgbs, cameras,   \n            smpl_params=None, return_decoder_loss=False, init=False):\n        decoder = self.decoder_ema if self.freeze_decoder and self.decoder_use_ema else self.decoder\n        num_imgs = target_rgbs.shape[1]\n        outputs = decoder(\n            code, smpl_params, cameras,\n            num_imgs, return_loss=return_decoder_loss, init=init, return_norm=False)\n        return outputs\n\n    def on_fit_start(self):\n        if self.global_rank == 0:\n            os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True)\n            os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True)\n            os.makedirs(os.path.join(self.logdir, 'images_val_code'), exist_ok=True)\n    \n\n\n    def forward(self, data):\n        # print(\"iter\")\n        \n        num_scenes = len(data['scene_id'])  # 8\n        if 'cond_imgs' in data:\n            cond_imgs = data['cond_imgs']  # (num_scenes, num_imgs, h, w, 3)\n            cond_intrinsics = data['cond_intrinsics']  # (num_scenes, num_imgs, 4), in [fx, fy, cx, cy]\n            cond_poses = data['cond_poses'] # (num_scenes, num_imgs, 4, 4)\n            smpl_params = data['cond_smpl_param']  # (num_scenes, c)\n            # if 'cond_norm' in data:cond_norm = data['cond_norm'] else:  cond_norm = None\n            num_scenes, num_imgs, h, w, _ = cond_imgs.size()\n            cameras = torch.cat([cond_intrinsics, cond_poses.reshape(num_scenes, num_imgs, -1)], dim=-1)\n            if self.if_include_video_ref_img: # new! 
we render to all view, including the input image; # don't compute this loss\n                cameras = cameras[:,1:]\n                num_imgs = num_imgs - 1\n\n\n        if self.inputs_front_only: # default setting to use the first image as input\n            inputs_img_idx = [0]\n        else:\n            raise NotImplementedError(\"inputs_front_only is False\")\n        inputs_img = cond_imgs[:,inputs_img_idx[0],...].permute([0,3,1,2]) # \n        \n        target_imgs = cond_imgs[:, 1:]\n        assert cameras.shape[1] == target_imgs.shape[1]\n\n\n        if self.is_debug:\n            try:\n                code = self.forward_image_to_uv(inputs_img, is_training=self.training) #TODO check where the validation\n            except Exception as e: # OOM\n                main_print(e)\n                code = torch.zeros([num_scenes, 32, 256, 256]).to(inputs_img.dtype).to(inputs_img.device)\n        else:\n            code = self.forward_image_to_uv(inputs_img, is_training=self.training) #TODO check where the validation\n\n        decoder = self.decoder_ema if self.freeze_decoder and self.decoder_use_ema else self.decoder\n        # uvmaps_decoder_gender's forward\n        output = decoder(\n            code, smpl_params, cameras,\n            num_imgs, return_loss=False, init=(self.global_step < self.init_iter), return_norm=False) #(['scales', 'norm', 'image', 'offset'])\n     \n        output['code'] = code\n        output['target_imgs'] = target_imgs\n        output['inputs_img'] = cond_imgs[:,[0],...]\n       \n        # for visualization\n        if self.global_rank == 0 and self.global_step % 200 == 0 and self.is_debug:\n            overlay_imgs = 0.5 * target_imgs + 0.5 * output['image']\n            overlay_imgs = rearrange(overlay_imgs, 'b n h w c -> b h n w c')\n            overlay_imgs = rearrange(overlay_imgs, ' b h n w c -> (b h) (n w) c')\n            overlay_imgs = overlay_imgs.to(torch.float32).detach().cpu().numpy()\n            overlay_imgs = (overlay_imgs * 255).astype(np.uint8)\n            Image.fromarray(overlay_imgs).save(f'debug_{self.global_step}.jpg')\n        \n        return output\n\n    def forward_image_to_uv(self, inputs_img, is_training=True):\n        '''\n            inputs_img: torch.Tensor, bs, 3, H, W\n            return\n            code : bs, 256, 256, 32\n        '''\n        if self.decoder_learning_rate <= 0:\n            with torch.no_grad():\n                features_flatten =  self.encoder(inputs_img, use_my_proces=True, output_hidden_states=self.output_hidden_states) \n        else:\n            features_flatten =  self.encoder(inputs_img, use_my_proces=True, output_hidden_states=self.output_hidden_states) \n        \n        if self.ids_restore.device !=features_flatten.device:\n            self.ids_restore = self.ids_restore.to(features_flatten.device)\n        ids_restore = self.ids_restore.expand([features_flatten.shape[0], -1])\n        uv_code =  self.neck(features_flatten, ids_restore)\n        batch_size, token_num, dims_feature = uv_code.shape\n        \n        if self.reshape_type=='reshape':\n            feature_map = uv_code.reshape(batch_size, self.num_patches_axis, self.num_patches_axis,\\\n                            self.code_feat_dims, self.code_patch_size, self.code_patch_size) # torch.Size([1, 64, 64, 32, 4, 4, ])  \n            feature_map = feature_map.permute(0, 3, 1, 4, 2, 5)   # ([1, 32, 64, 4, 64, 4])\n            feature_map = feature_map.reshape(batch_size, self.code_feat_dims,  self.code_resolution, 
self.code_resolution) # torch.Size([1, 32, 256, 256])\n            code = feature_map # [1, 32, 256, 256]\n        else:\n            feature_map = uv_code.reshape(batch_size, self.num_patches_axis, self.num_patches_axis,dims_feature) # torch.Size([1, 64, 64, 512, ])  \n            if isinstance(self.decoder, DistributedDataParallel):\n                code = self.decoder.module.upsample_conv(feature_map.permute([0,3,1,2])) # torch.Size([1, 32, 256, 256])\n            else:\n                code = self.decoder.upsample_conv(feature_map.permute([0,3,1,2])) # torch.Size([1, 32, 256, 256])\n\n        code = self.code_activation(code)\n        return code\n\n    def compute_loss(self, render_out):\n        render_images = render_out['image'] # .Size([1, 5, 896, 640, 3]), range [0, 1]\n        target_images = render_out['target_imgs']\n        target_images  =target_images.to(render_images)\n        if self.is_debug:\n            render_images_tmp= rearrange(render_images, 'b n h w c -> (b n) c h w')\n            target_images_tmp = rearrange(target_images, 'b n h w c -> (b n) c h w')\n            all_images = torch.cat([render_images_tmp, target_images_tmp], dim=2)\n            all_images = render_images_tmp*0.5 + target_images_tmp*0.5\n            grid = make_grid(all_images, nrow=4, normalize=True, value_range=(0, 1))\n            save_image(grid, \"./debug.png\")\n            main_print(\"saving into ./debug.png\")\n           \n\n        render_images = rearrange(render_images, 'b n h w c -> (b n) c h w') * 2.0 - 1.0\n        target_images = rearrange(target_images, 'b n h w c -> (b n) c h w') * 2.0 - 1.0\n        if self.lambda_mse<=0:\n            loss_mse = 0\n        else:\n            if self.loss_weights_views.numel() != 0:\n                b, n, _, _, _ = render_out['image'].shape\n                loss_weights_views = self.loss_weights_views.unsqueeze(0).to(render_images.device)\n                loss_weights_views = loss_weights_views.repeat(b,1).reshape(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)\n                loss_mse = weighted_mse_loss(render_images, target_images, loss_weights_views)\n                main_print(\"weighted sum mse\")\n            else:\n                loss_mse = F.mse_loss(render_images, target_images)\n\n        if self.lambda_l1<=0:\n            loss_l1 = 0\n        else:\n            loss_l1 = F.l1_loss(render_images, target_images)\n\n        if self.lambda_ssim <= 0:\n            loss_ssim = 0\n        else:\n            loss_ssim = 1 - self.ssim(render_images, target_images)\n        if not self.is_debug:\n            if self.lambda_lpips<=0:\n                loss_lpips = 0\n            else:\n                if self.loss_weights_views.numel() != 0:\n                    with torch.cuda.amp.autocast():\n                        loss_lpips = self.lpips(render_images.clamp(-1, 1), target_images)\n                else:\n                    loss_lpips = 0\n                    with torch.cuda.amp.autocast():\n                        for img_idx in range(render_images.shape[0]):\n                            loss_lpips += self.lpips(render_images[[img_idx]].clamp(-1, 1), target_images[[img_idx]])\n                    loss_lpips /= render_images.shape[0]\n                    \n        else:\n            loss_lpips = 0\n        loss_gs_offset = render_out['offset']\n        loss = loss_mse * self.lambda_mse \\\n            + loss_l1 * self.lambda_l1 \\\n            + loss_ssim * self.lambda_ssim \\\n            + loss_lpips * self.lambda_lpips \\\n            + 
loss_gs_offset * self.lambda_offset\n        \n        prefix = 'train'\n        loss_dict = {}\n        loss_dict.update({f'{prefix}/loss_mse': loss_mse})\n        loss_dict.update({f'{prefix}/loss_lpips': loss_lpips})\n        loss_dict.update({f'{prefix}/loss_gs_offset': loss_gs_offset})\n        loss_dict.update({f'{prefix}/loss_ssim': loss_ssim})\n        loss_dict.update({f'{prefix}/loss_l1': loss_l1})\n        loss_dict.update({f'{prefix}/loss': loss})\n\n        return loss, loss_dict\n    \n    def compute_metrics(self, render_out):\n        # NOTE: all the rgb value range  is [0, 1]\n        # render_out.keys = (['scales', 'norm', 'image', 'offset', 'code', 'target_imgs'])\n        render_images = render_out['image'].clamp(0, 1) # .Size([1, 5, 896, 640, 3]), range [0, 1]\n        target_images = render_out['target_imgs']\n        if target_images.dtype!=render_images.dtype:\n            target_images = target_images.to(render_images.dtype)\n\n        render_images = rearrange(render_images, 'b n h w c -> (b n) c h w')\n        target_images = rearrange(target_images, 'b n h w c -> (b n) c h w').to(render_images)\n\n        mse = F.mse_loss(render_images, target_images).mean()\n        psnr = 10 * torch.log10(1.0 / mse)\n        ssim = self.ssim(render_images, target_images)\n        \n        render_images = render_images * 2.0 - 1.0\n        target_images = target_images * 2.0 - 1.0\n\n        if self.lambda_lpips<=0:\n            lpips = torch.Tensor([0]).to(render_images.device).to(render_images.dtype)\n        else:\n            with torch.cuda.amp.autocast():\n                lpips = self.lpips(render_images, target_images)\n\n        metrics = {\n            'val/mse': mse,\n            'val/pnsr': psnr,\n            'val/ssim': ssim,\n            'val/lpips': lpips,\n        }\n        return metrics\n\n    def new_on_before_optimizer_step(self):\n        norms = grad_norm(self.neck, norm_type=2)\n        if 'grad_2.0_norm_total' in norms:\n            self.log_dict({'grad_norm/lrm_generator': norms['grad_2.0_norm_total']})\n\n    @torch.no_grad()\n    def validation_step(self, batch, batch_idx):\n        render_out = self.forward(batch)\n\n        metrics = self.compute_metrics(render_out)\n        self.validation_metrics.append(metrics)\n        render_images = render_out['image']\n        render_images = rearrange(render_images, 'b n h w c -> b c h (n w)')\n        gt_images = render_out['target_imgs']\n        gt_images = rearrange(gt_images, 'b n h w c-> b c h (n w)')\n        log_images = torch.cat([render_images, gt_images], dim=-2)\n        self.validation_step_outputs.append(log_images)\n\n        self.validation_step_code_outputs.append( render_out['code'])\n\n        render_out_comb = self.forward_nvPose(batch, smplx_given=None)\n        self.validation_step_nvPose_outputs.append(render_out_comb)\n       \n      \n    def forward_nvPose(self, batch, smplx_given):\n        '''\n            smplx_given: torch.Tensor, bs, 189\n            it will returns images with cameras_num * poses_num\n        '''\n        _, num_img, _,_ = batch['cond_poses'].shape\n        # write a code to seperately input the smplx_params\n        if smplx_given == None:\n            step_pose = self.smplx_params.shape[0] // num_img\n            smplx_given = self.smplx_params\n        else:\n            step_pose = 1\n        render_out_list = []\n        for i in range(num_img):  \n            target_pose = smplx_given[[i*step_pose]]\n            bk = batch['cond_smpl_param'].clone()\n   
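# (added note) overwrite only the pose-dependent slices of the 189-dim SMPL-X parameter vector (body pose, hand/jaw pose, expression) with the target pose; betas and the global entries keep the reconstructed subject's values.\n   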
         batch['cond_smpl_param'][:, 7:70] = target_pose[:, 7:70] # copy body_pose\n            batch['cond_smpl_param'][:, 80:80+93] = target_pose[:, 80:80+93]# copy pose_hand + pose_jaw\n            batch['cond_smpl_param'][:, 179:189] = target_pose[:, 179:189]# copy face expression\n            render_out_new = self.forward(batch)\n            render_out_list.append(render_out_new['image'])\n        render_out_comb = torch.cat(render_out_list, dim=2) # stack in the H axis\n        render_out_comb = rearrange(render_out_comb, 'b n h w c -> b c h (n w)')\n        return render_out_comb\n\n    \n    def on_validation_epoch_end(self): #\n        images = torch.cat(self.validation_step_outputs, dim=-1)\n        all_images = self.all_gather(images).cpu()\n        all_images = rearrange(all_images, 'r b c h w -> (r b) c h w')\n\n        # nv pose\n        images_pose = torch.cat(self.validation_step_nvPose_outputs, dim=-1)\n        all_images_pose = self.all_gather(images_pose).cpu()\n        all_images_pose = rearrange(all_images_pose, 'r b c h w -> (r b) c h w')\n\n        if self.global_rank == 0:\n            image_path = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png')\n\n            grid = make_grid(all_images, nrow=1, normalize=True, value_range=(0, 1))\n            save_image(grid, image_path)\n            main_print(f\"Saved image to {image_path}\")\n\n            metrics = {}\n            for key in self.validation_metrics[0].keys():\n                metrics[key] = torch.stack([m[key] for m in self.validation_metrics]).mean()\n            self.log_dict(metrics, prog_bar=True, logger=True, on_step=False, on_epoch=True)\n\n\n            # code for saving the nvPose images\n            image_path_nvPose = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}_nvPose.png')\n            grid_nvPose = make_grid(all_images_pose, nrow=1, normalize=True, value_range=(0, 1))\n            save_image(grid_nvPose, image_path_nvPose)\n            main_print(f\"Saved image to {image_path_nvPose}\")\n\n\n            # code for saving the code images\n            for i, code in enumerate(self.validation_step_code_outputs):\n                image_path = os.path.join(self.logdir, 'images_val_code')\n                \n                num_scenes, num_chn, h, w = code.size()\n                code_viz = code.reshape(num_scenes, 4, 8, h, w).to(torch.float32).cpu().numpy()\n                code_viz = code_viz.transpose(0, 1, 3, 2, 4).reshape(num_scenes, 4 * h, 8 * w)\n                for j, code_viz_single in enumerate(code_viz):\n                    plt.imsave(os.path.join(image_path, f'val_{self.global_step:07d}_{i*num_scenes+j:04d}' + '.png'), code_viz_single,\n                        vmin=self.code_clip_range[0], vmax=self.code_clip_range[1])\n        self.validation_step_outputs.clear()\n        self.validation_step_nvPose_outputs.clear()\n        self.validation_metrics.clear()\n        self.validation_step_code_outputs.clear()\n    \n    def on_test_start(self):\n        if self.global_rank == 0:\n            os.makedirs(os.path.join(self.logdir, 'images_test'), exist_ok=True)\n    \n    def on_test_epoch_end(self):\n        metrics = {}\n        metrics_mean = {}\n        metrics_var = {}\n        for key in self.validation_metrics[0].keys():\n            tmp = torch.stack([m[key] for m in self.validation_metrics]).cpu().numpy()\n            metrics_mean[key] = tmp.mean()\n            metrics_var[key] = tmp.var()\n\n        formatted_metrics = {}\n        for 
key in metrics_mean.keys():\n            formatted_metrics[key] = f\"{metrics_mean[key]:.4f}±{metrics_var[key]:.4f}\"\n\n        for key in self.validation_metrics[0].keys():\n            metrics[key] = torch.stack([m[key] for m in self.validation_metrics]).cpu().numpy().tolist()\n        \n\n        final_dict = {\"average\": formatted_metrics,\n                      'details': metrics}\n\n        metric_path = os.path.join(self.logdir, f'metrics.json')\n        with open(metric_path, 'w') as f:\n            json.dump(final_dict, f, indent=4)\n        main_print(f\"Saved metrics to {metric_path}\")\n        \n        for key in metrics.keys():\n            metrics[key] = torch.stack([m[key] for m in self.validation_metrics]).mean()\n        main_print(metrics)\n        \n        self.validation_metrics.clear()\n    \n    def configure_optimizers(self):\n        # define the optimizer and the scheduler for neck and decoder \n        main_print(\"WARNING currently, we only support the single optimizer for both neck and decoder\")\n        \n        learning_rate = self.neck_learning_rate\n        params= [\n            {'params': self.neck.parameters(), 'lr': self.neck_learning_rate, },\n            {'params': self.decoder.parameters(), 'lr': self.decoder_learning_rate},\n        ]\n        if hasattr(self, \"encoder_learning_rate\") and self.encoder_learning_rate>0:\n            params.append({'params': self.encoder.parameters(), 'lr': self.encoder_learning_rate})\n            main_print(\"============add the encoder into the optimizer============\")\n        optimizer = torch.optim.Adam(\n            params\n        )\n        T_warmup, T_max, eta_min = self.warmup_steps, self.max_steps, 0.001\n        lr_lambda = lambda step: \\\n            eta_min + (1 - math.cos(math.pi * step / T_warmup)) * (1 - eta_min) * 0.5 if step < T_warmup else \\\n            eta_min + (1 + math.cos(math.pi * (step - T_warmup) / (T_max - T_warmup))) * (1 - eta_min) * 0.5\n        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)\n        return  {'optimizer': optimizer, 'lr_scheduler': scheduler}\n    \n    def training_step(self, batch, batch_idx):\n        scheduler = self.lr_schedulers()\n        scheduler.step()\n        render_gt = None #? 
\n        render_out = self.forward(batch)\n        loss, loss_dict = self.compute_loss(render_out)\n\n\n        self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)\n       \n        if self.global_step % 200 == 0 and self.global_rank == 0:\n            self.new_on_before_optimizer_step() # log the norm\n        if self.global_step % 200 == 0 and self.global_rank == 0:\n            if self.if_include_video_ref_img and self.training:\n                render_images = torch.cat([ torch.ones_like(render_out['image'][:,0:1]), render_out['image']], dim=1)\n                target_images = torch.cat([ render_out['inputs_img'], render_out['target_imgs']], dim=1)\n\n            target_images = rearrange(\n                target_images, 'b n h w c -> b c h (n w)')\n            render_images = rearrange(\n                render_images, 'b n  h w c-> b c h (n w)')\n            \n\n            grid = torch.cat([\n                target_images, render_images, 0.5*render_images + 0.5*target_images,\n               \n            ], dim=-2)\n            grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1))\n           \n            image_path = os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.jpg')\n            save_image(grid, image_path)\n            main_print(f\"Saved image to {image_path}\")\n\n        return loss\n        \n    @torch.no_grad()\n    def test_step(self, batch, batch_idx):\n        # input_dict, render_gt = self.prepare_validation_batch_data(batch)\n        render_out = self.forward(batch)\n        render_gt = render_out['target_imgs']\n        render_img = render_out['image']\n        # Compute metrics\n        metrics = self.compute_metrics(render_out)\n        self.validation_metrics.append(metrics)\n        \n        # Save images\n        target_images = rearrange(\n            render_gt, 'b n h w c -> b c h (n w)')\n        render_images = rearrange(\n            render_img, 'b n h w c -> b c h (n w)')\n\n\n        grid = torch.cat([\n            target_images, render_images, \n        ], dim=-2)\n        grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1))\n        # self.logger.log_image('train/render', [grid], step=self.global_step)\n        image_path = os.path.join(self.logdir, 'images_test', f'{batch_idx:07d}.png')\n        save_image(grid, image_path)\n\n        # code visualize\n        code = render_out['code']\n        self.decoder.visualize(code, batch['scene_name'],\n                        os.path.dirname(image_path), code_range=self.code_clip_range)\n\n        print(f\"Saved image to {image_path}\")\n    \n    def on_test_start(self):\n        if self.global_rank == 0:\n            os.makedirs(os.path.join(self.logdir, 'images_test'), exist_ok=True)\n    \n    def on_test_epoch_end(self):\n        metrics = {}\n        metrics_mean = {}\n        metrics_var = {}\n        for key in self.validation_metrics[0].keys():\n            tmp = torch.stack([m[key] for m in self.validation_metrics]).cpu().numpy()\n            metrics_mean[key] = tmp.mean()\n            metrics_var[key] = tmp.var()\n\n        # trans format into \"mean±var\" \n        formatted_metrics = {}\n        for key in metrics_mean.keys():\n            formatted_metrics[key] = f\"{metrics_mean[key]:.4f}±{metrics_var[key]:.4f}\"\n\n        for key in self.validation_metrics[0].keys():\n            metrics[key] = torch.stack([m[key] for m in self.validation_metrics]).cpu().numpy().tolist()\n 
       \n\n        # saving into a dictionary\n        final_dict = {\"average\": formatted_metrics,\n                      'details': metrics}\n\n        metric_path = os.path.join(self.logdir, f'metrics.json')\n        with open(metric_path, 'w') as f:\n            json.dump(final_dict, f, indent=4)\n        print(f\"Saved metrics to {metric_path}\")\n        \n        for key in metrics.keys():\n            metrics[key] = torch.stack([m[key] for m in self.validation_metrics]).mean()\n        print(metrics)\n        \n        self.validation_metrics.clear()\n    \ndef weighted_mse_loss(render_images, target_images, weights):\n    squared_diff = (render_images - target_images) ** 2\n    main_print(squared_diff.shape, weights.shape)\n    weighted_squared_diff = squared_diff * weights\n    loss_mse_weighted = weighted_squared_diff.mean()\n    return loss_mse_weighted"
  },
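The LightningModule above combines a cosine warm-up and a cosine decay into a single `LambdaLR` multiplier in `configure_optimizers`. Below is a minimal standalone sketch of that schedule; the `T_warmup`/`T_max` values are hypothetical, while `eta_min = 0.001` matches the hard-coded value in the module.

```python
import math
import torch

# Hypothetical horizon values; in the module they come from self.warmup_steps
# and self.max_steps, with eta_min fixed to 0.001.
T_warmup, T_max, eta_min = 1_000, 100_000, 0.001

def lr_multiplier(step: int) -> float:
    if step < T_warmup:
        # cosine ramp from eta_min up to 1 over the warm-up steps
        return eta_min + 0.5 * (1 - math.cos(math.pi * step / T_warmup)) * (1 - eta_min)
    # cosine decay from 1 back down to eta_min by T_max
    return eta_min + 0.5 * (1 + math.cos(math.pi * (step - T_warmup) / (T_max - T_warmup))) * (1 - eta_min)

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_multiplier)

for step in (0, T_warmup // 2, T_warmup, T_max):
    print(step, lr_multiplier(step))  # multiplier applied to each param group's base lr
```

Note that the module steps this scheduler manually at the top of `training_step` (`self.lr_schedulers().step()`) rather than relying on Lightning's automatic per-epoch stepping.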
  {
    "path": "lib/mmutils/__init__.py",
    "content": "from .initialize import xavier_init, constant_init"
  },
  {
    "path": "lib/mmutils/initialize.py",
    "content": "import torch.nn as nn\n\ndef constant_init(module: nn.Module, val: float, bias: float = 0) -> None:\n    if hasattr(module, 'weight') and module.weight is not None:\n        nn.init.constant_(module.weight, val)\n    if hasattr(module, 'bias') and module.bias is not None:\n        nn.init.constant_(module.bias, bias)\n\n\ndef xavier_init(module: nn.Module,\n                gain: float = 1,\n                bias: float = 0,\n                distribution: str = 'normal') -> None:\n    assert distribution in ['uniform', 'normal']\n    if hasattr(module, 'weight') and module.weight is not None:\n        if distribution == 'uniform':\n            nn.init.xavier_uniform_(module.weight, gain=gain)\n        else:\n            nn.init.xavier_normal_(module.weight, gain=gain)\n    if hasattr(module, 'bias') and module.bias is not None:\n        nn.init.constant_(module.bias, bias)"
  },
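For context, `xavier_init` and `constant_init` above are small mmcv-style helpers. A minimal usage sketch is below; the module choices here are illustrative only (in the decoder, `init_weights` applies `xavier_init` to its `nn.Linear` layers).

```python
import torch.nn as nn
from lib.mmutils import xavier_init, constant_init

# Illustrative modules: a 1x1 conv head and a batch-norm layer.
head = nn.Conv2d(128, 3, kernel_size=1)
xavier_init(head, gain=1, bias=0, distribution='uniform')  # Xavier-uniform weights, zero bias

norm = nn.BatchNorm2d(128)
constant_init(norm, val=1, bias=0)                          # unit scale, zero shift
```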
  {
    "path": "lib/models/__init__.py",
    "content": "\nfrom .decoders import *\n"
  },
  {
    "path": "lib/models/decoders/__init__.py",
    "content": "\nfrom .uvmaps_decoder_gender import UVNDecoder_gender\n\n__all__ = [ 'UVNDecoder_gender']\n"
  },
  {
    "path": "lib/models/decoders/uvmaps_decoder_gender.py",
    "content": "import os\nimport matplotlib.pyplot as plt\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch import einsum\nfrom pytorch3d import ops\n\nfrom lib.mmutils import xavier_init, constant_init\nimport numpy as np\nimport time\nimport cv2\nimport math\nfrom simple_knn._C import distCUDA2\nfrom pytorch3d.transforms import quaternion_to_matrix\n\nfrom ..deformers import SMPLXDeformer_gender\nfrom ..renderers import GRenderer, get_covariance, batch_rodrigues \nfrom lib.ops import TruncExp\nimport torchvision\n\nfrom lib.utils.train_util import main_print\n\ndef ensure_dtype(input_tensor, target_dtype=torch.float32):\n    \"\"\"\n    Ensure tensor dtype matches target dtype.\n    If not, convert it.\n    \"\"\"\n    if input_tensor.dtype != target_dtype:\n        input_tensor = input_tensor.to(dtype=target_dtype)\n    return input_tensor\n\n\nclass UVNDecoder_gender(nn.Module):\n\n    activation_dict = {\n        'relu': nn.ReLU,\n        'silu': nn.SiLU,\n        'softplus': nn.Softplus,\n        'trunc_exp': TruncExp,\n        'sigmoid': nn.Sigmoid}\n\n    def __init__(self,\n                 *args,\n                 interp_mode='bilinear',\n                 base_layers=[3 * 32, 128],\n                 density_layers=[128, 1],\n                 color_layers=[128, 128, 3],\n                 offset_layers=[128, 3],\n                 scale_layers=[128, 3],\n                 radius_layers=[128, 3],\n                 use_dir_enc=True,\n                 dir_layers=None,\n                 scene_base_size=None,\n                 scene_rand_dims=(0, 1),\n                 activation='silu',\n                 sigma_activation='sigmoid',\n                 sigmoid_saturation=0.001,\n                 code_dropout=0.0,\n                 flip_z=False,\n                 extend_z=False,\n                 gender='neutral',\n                 multires=0,\n                 bg_color=0,\n                 image_size=1024,\n                 superres=False,\n                 focal = 1280, # the default focal defination\n                 reshape_type=None, # if true, it will create a cnn layers to upsample the uv features\n                 fix_sigma=False, #  if true, the density of GS will be fixed\n                 up_cnn_in_channels = None, #  the channel number of the upsample cnn\n                 vithead_param=None, # the vit head for decode to uv features\n                 is_sub2=False, # if true, will use the sub2 uv map\n                 **kwargs):\n        super().__init__()\n        self.interp_mode = interp_mode\n        self.in_chn = base_layers[0]\n        self.use_dir_enc = use_dir_enc\n        if scene_base_size is None:\n            self.scene_base = None\n        else:\n            rand_size = [1 for _ in scene_base_size]\n            for dim in scene_rand_dims:\n                rand_size[dim] = scene_base_size[dim]\n            init_base = torch.randn(rand_size).expand(scene_base_size).clone()\n            self.scene_base = nn.Parameter(init_base)\n        self.dir_encoder = None\n        self.sigmoid_saturation = sigmoid_saturation\n        self.deformer = SMPLXDeformer_gender(gender, is_sub2=is_sub2)\n\n        self.renderer = GRenderer(image_size=image_size, bg_color=bg_color, f=focal)\n        if superres:\n            self.superres = None\n        else:\n            self.superres = None\n        self.gender= gender\n        self.reshape_type = reshape_type\n        if reshape_type=='cnn':\n            self.upsample_conv = torch.nn.ConvTranspose2d(512, 
32, kernel_size=4, stride=4,).cuda()\n                                                 \n        elif reshape_type == 'VitHead': # changes the up block's layernorm into the feature channel norm instead of the full image norm\n            from lib.models.decoders.vit_head import VitHead\n            self.upsample_conv = VitHead(**vithead_param)\n            # 256, 128, 128 -> 128, 256, 256 -> 64, 512, 512, ->32, 1024, 1024\n        \n        base_cache_dir = 'work_dirs/cache'   \n        if is_sub2:\n            base_cache_dir = 'work_dirs/cache_sub2'\n            # main_print(\"!!!!!!!!!!!!!!!!!!! using the sub2 uv map !!!!!!!!!!!!!!!!!!!\")\n        if gender == 'neutral':\n            select_uv = torch.as_tensor(np.load(base_cache_dir+'/init_uv_smplx_newNeutral.npy'))\n            self.register_buffer('select_coord', select_uv.unsqueeze(0)*2.-1.)\n\n            init_pcd = torch.as_tensor(np.load(base_cache_dir+'/init_pcd_smplx_newNeutral.npy'))\n            self.register_buffer('init_pcd', init_pcd.unsqueeze(0), persistent=False) # 0.9-- -1\n        elif gender == 'male':\n            assert NotImplementedError(\"Haven't create the init_uv_smplx_thu in v_template\")\n            select_uv = torch.as_tensor(np.load(base_cache_dir+'/init_uv_smplx_thu.npy'))\n            self.register_buffer('select_coord', select_uv.unsqueeze(0)*2.-1.)\n\n            init_pcd = torch.as_tensor(np.load(base_cache_dir+'/init_pcd_smplx_thu.npy'))\n            self.register_buffer('init_pcd', init_pcd.unsqueeze(0), persistent=False) # 0.9-- -1\n        self.num_init = self.init_pcd.shape[1]\n        main_print(f\"!!!!!!!!!!!!!!!!!!! cur points number are {self.num_init} !!!!!!!!!!!!!!!!!!!\")\n\n        self.init_pcd = self.init_pcd \n\n        self.multires = multires # 0 Haven't \n        if multires > 0:\n            uv_map = torch.as_tensor(np.load(base_cache_dir+'/init_uvmap_smplx_thu.npy'))\n            pcd_map = torch.as_tensor(np.load(base_cache_dir+'/init_posmap_smplx_thu.npy'))\n            input_coord = torch.cat([pcd_map, uv_map], dim=1)\n            self.register_buffer('input_freq', input_coord, persistent=False)\n            base_layers[0] += 5\n            color_layers[0] += 5\n        else:\n            self.init_uv = None\n\n        activation_layer = self.activation_dict[activation.lower()]\n\n\n        base_net = [] # linear (in=18, out=64, bias=True)\n        for i in range(len(base_layers) - 1):\n            base_net.append(nn.Conv2d(base_layers[i], base_layers[i + 1], 3, padding=1))\n            if i != len(base_layers) - 2:\n                base_net.append(nn.BatchNorm2d(base_layers[i+1]))\n                base_net.append(activation_layer())\n        self.base_net = nn.Sequential(*base_net)\n        self.base_bn = nn.BatchNorm2d(base_layers[-1])\n        self.base_activation = activation_layer()\n\n        density_net = [] # linear(in=64, out=1, bias=True), sigmoid\n        for i in range(len(density_layers) - 1):\n            density_net.append(nn.Conv2d(density_layers[i], density_layers[i + 1], 1))\n            if i != len(density_layers) - 2:\n                density_net.append(nn.BatchNorm2d(density_layers[i+1]))\n                density_net.append(activation_layer())\n        density_net.append(self.activation_dict[sigma_activation.lower()]())\n        self.density_net = nn.Sequential(*density_net)\n\n        offset_net = [] # linear(in=64, out=1, bias=True), sigmoid\n        for i in range(len(offset_layers) - 1):\n            offset_net.append(nn.Conv2d(offset_layers[i], 
offset_layers[i + 1], 1))\n            if i != len(offset_layers) - 2:\n                offset_net.append(nn.BatchNorm2d(offset_layers[i+1]))\n                offset_net.append(activation_layer())\n        self.offset_net = nn.Sequential(*offset_net)\n\n        self.dir_net = None\n        color_net = [] # linear(in=64, out=3, bias=True), sigmoid\n        for i in range(len(color_layers) - 2):\n            color_net.append(nn.Conv2d(color_layers[i], color_layers[i + 1], kernel_size=3, padding=1))\n            color_net.append(nn.BatchNorm2d(color_layers[i+1]))\n            color_net.append(activation_layer())\n        color_net.append(nn.Conv2d(color_layers[-2], color_layers[-1], kernel_size=1))\n        color_net.append(nn.Sigmoid())\n        self.color_net = nn.Sequential(*color_net)\n        self.code_dropout = nn.Dropout2d(code_dropout) if code_dropout > 0 else None\n\n        self.flip_z = flip_z\n        self.extend_z = extend_z\n\n        if self.gender == 'neutral':\n            init_rot = torch.as_tensor(np.load(base_cache_dir+'/init_rot_smplx_newNeutral.npy'))\n            self.register_buffer('init_rot', init_rot, persistent=False)\n\n            face_mask = torch.as_tensor(np.load(base_cache_dir+'/face_mask_thu_newNeutral.npy'))\n            self.register_buffer('face_mask', face_mask.unsqueeze(0), persistent=False)\n\n            hands_mask = torch.as_tensor(np.load(base_cache_dir+'/hands_mask_thu_newNeutral.npy'))\n            self.register_buffer('hands_mask', hands_mask.unsqueeze(0), persistent=False)\n\n            outside_mask = torch.as_tensor(np.load(base_cache_dir+'/outside_mask_thu_newNeutral.npy'))\n            self.register_buffer('outside_mask', outside_mask.unsqueeze(0), persistent=False)\n        else:\n            assert NotImplementedError(\"Haven't create the init_rot in v_template\")\n            init_rot = torch.as_tensor(np.load(base_cache_dir+'/init_rot_smplx_thu.npy'))\n            self.register_buffer('init_rot', init_rot, persistent=False)\n\n            face_mask = torch.as_tensor(np.load(base_cache_dir+'/face_mask_thu.npy'))\n            self.register_buffer('face_mask', face_mask.unsqueeze(0), persistent=False)\n\n            hands_mask = torch.as_tensor(np.load(base_cache_dir+'/hands_mask_thu.npy'))\n            self.register_buffer('hands_mask', hands_mask.unsqueeze(0), persistent=False)\n\n            outside_mask = torch.as_tensor(np.load(base_cache_dir+'/outside_mask_thu.npy'))\n            self.register_buffer('outside_mask', outside_mask.unsqueeze(0), persistent=False)\n\n        self.iter = 0\n        self.init_weights()\n        self.if_rotate_gaussian = False\n        self.fix_sigma = fix_sigma\n    def init_weights(self):\n        for m in self.modules():\n            if isinstance(m, nn.Linear):\n                xavier_init(m, distribution='uniform')\n        if self.dir_net is not None:\n            constant_init(self.dir_net[-1], 0)\n        if self.offset_net is not None:\n            self.offset_net[-1].weight.data.uniform_(-1e-5, 1e-5)\n            self.offset_net[-1].bias.data.zero_()\n\n\n    def extract_pcd(self, code, smpl_params, init=False, zeros_hands_off=False):\n        '''\n        Args:\n            B == num_scenes\n            code (tensor): latent code. shape: [B, C, H, W]\n            smpl_params (tensor): SMPL parameters. shape: [B_pose, 189]\n            init (bool): Not used\n        Returns:\n            defm_pcd (tensor): deformed point cloud. 
shape: [B, N, B_pose, 3]\n            sigmas, rgbs, offset, radius, rot(tensor): GS attributes. shape: [B, N, C]\n            tfs(tensor): deformation matrics. shape: [B, N, C]\n        '''\n        if isinstance(code, list):\n            num_scenes, _, h, w = code[0].size()\n        else:\n            num_scenes, n_channels, h, w = code.size()\n        init_pcd = self.init_pcd.repeat(num_scenes, 1, 1)  # T-posed space points, for computing the skinning weights\n        \n        sigmas, rgbs, radius, rot, offset = self._decode(code, init=init) #  the person-specify attributes of GS\n        if self.fix_sigma:\n            sigmas = torch.ones_like(sigmas)\n        if zeros_hands_off:\n            offset[self.hands_mask[...,None].expand(num_scenes, -1, 3)] = 0\n        canon_pcd = init_pcd + offset\n        \n        self.deformer.prepare_deformer(smpl_params, num_scenes, device=canon_pcd.device)\n        defm_pcd, tfs = self.deformer(canon_pcd, rot, mask=(self.face_mask+self.hands_mask+self.outside_mask), cano=False, if_rotate_gaussian=self.if_rotate_gaussian)\n        return defm_pcd, sigmas, rgbs, offset, radius, tfs, rot\n\n    def deform_pcd(self, code, smpl_params, init=False, zeros_hands_off=False, value=0.1):\n        '''\n        Args:\n            B == num_scenes\n            code (List): list of data\n            smpl_params (tensor): SMPL parameters. shape: [B_pose, 189]\n            init (bool): Not used\n        Returns:\n            defm_pcd (tensor): deformed point cloud. shape: [B, N, B_pose, 3]\n            sigmas, rgbs, offset, radius, rot(tensor): GS attributes. shape: [B, N, C]\n            tfs(tensor): deformation matrics. shape: [B, N, C]\n        '''\n        sigmas, rgbs, radius, rot, offset = code\n        num_scenes = sigmas.shape[0]\n        init_pcd = self.init_pcd.repeat(num_scenes, 1, 1)  #T-posed space points, for computing the skinning weights\n\n        if self.fix_sigma:\n            sigmas = torch.ones_like(sigmas)\n        if zeros_hands_off:\n            offset[self.hands_mask[...,None].expand(num_scenes, -1, 3)] = torch.clamp(offset[self.hands_mask[...,None].expand(num_scenes, -1, 3)], -value, value)\n        canon_pcd = init_pcd + offset\n        self.deformer.prepare_deformer(smpl_params, num_scenes, device=canon_pcd.device)\n        defm_pcd, tfs = self.deformer(canon_pcd, rot, mask=(self.face_mask+self.hands_mask+self.outside_mask), cano=False, if_rotate_gaussian=self.if_rotate_gaussian)\n        return defm_pcd, sigmas, rgbs, offset, radius, tfs, rot\n\n\n        \n    def _sample_feature(self,results,):\n        # outputs, sigma_uv, offset_uv, rgbs_uv, radius_uv, rot_uv = results['output'], results['sigma'], results['offset'], results['rgbs'], results['radius'], results['rot']\n        sigma = results['sigma']\n        outputs = results['output']\n        if isinstance(sigma, list):\n            num_scenes, _, h, w = sigma[0].shape\n            select_coord = self.select_coord.unsqueeze(1).repeat(num_scenes, 1, 1, 1)\n        elif sigma.dim() == 4:\n            num_scenes, n_channels, h, w = sigma.shape\n            select_coord = self.select_coord.unsqueeze(1).repeat(num_scenes, 1, 1, 1)\n        else:\n            assert False\n        output_attr = F.grid_sample(outputs, select_coord, mode=self.interp_mode, padding_mode='border', align_corners=False).reshape(num_scenes, 13, -1).permute(0, 2, 1)\n        sigma, offset, rgbs, radius, rot = output_attr.split([1, 3, 3, 3, 3], dim=2)\n\n        if self.sigmoid_saturation > 0:\n            rgbs = 
rgbs * (1 + self.sigmoid_saturation * 2) - self.sigmoid_saturation\n        \n        radius = (radius - 0.5) * 2\n        rot = (rot - 0.5) * np.pi\n\n        return sigma, rgbs, radius, rot, offset\n\n    def _decode_feature(self, point_code, init=False):\n        if isinstance(point_code, list):\n            num_scenes, _, h, w = point_code[0].shape\n            geo_code, tex_code = point_code\n            # select_coord = self.select_coord.unsqueeze(1).repeat(num_scenes, 1, 1, 1)\n            if self.multires != 0:\n                input_freq = self.input_freq.repeat(num_scenes, 1, 1, 1)\n        elif point_code.dim() == 4:\n            num_scenes, n_channels, h, w = point_code.shape\n            # select_coord = self.select_coord.unsqueeze(1).repeat(num_scenes, 1, 1, 1)\n            if self.multires != 0:\n                input_freq = self.input_freq.repeat(num_scenes, 1, 1, 1)\n            geo_code, tex_code = point_code.split(16, dim=1)\n        else:\n            assert False\n\n        base_in = geo_code if self.multires == 0 else torch.cat([geo_code, input_freq], dim=1)\n        base_x = self.base_net(base_in)\n        base_x_act = self.base_activation(self.base_bn(base_x))\n    \n        sigma = self.density_net(base_x_act)\n        offset = self.offset_net(base_x_act)\n        color_in = tex_code if self.multires == 0 else torch.cat([tex_code, input_freq], dim=1)\n        rgbs_radius_rot = self.color_net(color_in)\n        \n        outputs = torch.cat([sigma, offset, rgbs_radius_rot], dim=1)\n        main_print(outputs.shape)\n        sigma, offset, rgbs, radius, rot = outputs.split([1, 3, 3, 3, 3], dim=1)\n        results = {'output':outputs, 'sigma': sigma, 'offset': offset, 'rgbs': rgbs, 'radius': radius, 'rot': rot}\n\n        return results\n    def _decode(self, point_code, init=False):\n        if isinstance(point_code, list):\n            num_scenes, _, h, w = point_code[0].shape\n            geo_code, tex_code = point_code\n            select_coord = self.select_coord.unsqueeze(1).repeat(num_scenes, 1, 1, 1)\n            if self.multires != 0:\n                input_freq = self.input_freq.repeat(num_scenes, 1, 1, 1)\n        elif point_code.dim() == 4:\n            num_scenes, n_channels, h, w = point_code.shape\n            select_coord = self.select_coord.unsqueeze(1).repeat(num_scenes, 1, 1, 1)\n            if self.multires != 0:\n                input_freq = self.input_freq.repeat(num_scenes, 1, 1, 1)\n            geo_code, tex_code = point_code.split(16, dim=1)\n        else:\n            assert False\n\n        base_in = geo_code if self.multires == 0 else torch.cat([geo_code, input_freq], dim=1)\n        base_x = self.base_net(base_in)\n        base_x_act = self.base_activation(self.base_bn(base_x))\n     \n        sigma = self.density_net(base_x_act)\n        offset = self.offset_net(base_x_act)\n        color_in = tex_code if self.multires == 0 else torch.cat([tex_code, input_freq], dim=1)\n        rgbs_radius_rot = self.color_net(color_in)\n        \n        outputs = torch.cat([sigma, offset, rgbs_radius_rot], dim=1)\n        output_attr = F.grid_sample(outputs, select_coord, mode=self.interp_mode, padding_mode='border', align_corners=False).reshape(num_scenes, 13, -1).permute(0, 2, 1)\n        sigma, offset, rgbs, radius, rot = output_attr.split([1, 3, 3, 3, 3], dim=2)\n\n        if self.sigmoid_saturation > 0:\n            rgbs = rgbs * (1 + self.sigmoid_saturation * 2) - self.sigmoid_saturation\n        \n        radius = (radius - 0.5) * 2\n        rot 
= (rot - 0.5) * np.pi\n\n        return sigma, rgbs, radius, rot, offset\n\n    def gaussian_render(self, pcd, sigmas, rgbs, normals, rot, num_scenes, num_imgs, cameras, use_scale=False, radius=None, \\\n                         return_norm=False, return_viz=False, mask=None):\n        # add mask or visible points to images or select ind to images\n        '''\n           render the gaussian to images\n           return_norm: return the normals of the gaussian (haven't been used)\n           return_viz: return the mask of the gaussian\n           mask: the mask of the gaussian\n        '''\n        assert num_scenes == 1\n        \n        pcd = pcd.reshape(-1, 3)\n        if use_scale: \n            dist2 = distCUDA2(pcd)\n            dist2 = torch.clamp_min((dist2), 0.0000001)\n            scales = torch.sqrt(dist2)[...,None].repeat(1, 3).detach() # distence between different points\n            scale = (radius+1)*scales  # scaling_modifier # radius[-1--1], scale of GS\n            cov3D = get_covariance(scale, rot).reshape(-1, 6) # inputs rot is the rotations\n       \n        images_all = []\n        viz_masks = [] if return_viz else None\n        norm_all = [] if return_norm else None\n\n        if mask != None:\n            pcd = pcd[mask]\n            rgbs = rgbs[mask]\n            sigmas = sigmas[mask]\n            cov3D = cov3D[mask]\n            normals = normals[mask]\n        if 1:\n            for i in range(num_imgs):\n                self.renderer.prepare(cameras[i])\n\n                image = self.renderer.render_gaussian(means3D=pcd, colors_precomp=rgbs, \n                    rotations=None, opacities=sigmas, scales=None, cov3D_precomp=cov3D)\n                images_all.append(image)\n                if return_viz:\n                    viz_mask = self.renderer.render_gaussian(means3D=pcd, colors_precomp=pcd.clone(), \n                        rotations=None, opacities=sigmas*0+1, scales=None, cov3D_precomp=cov3D)\n                    viz_masks.append(viz_mask)\n          \n\n        images_all = torch.stack(images_all, dim=0).unsqueeze(0).permute(0, 1, 3, 4, 2)\n        if return_viz:\n            viz_masks = torch.stack(viz_masks, dim=0).unsqueeze(0).permute(0, 1, 3, 4, 2).reshape(1, -1, 3)\n            dist_sq, idx, neighbors = ops.knn_points(pcd.unsqueeze(0), viz_masks[:, ::10], K=1, return_nn=True)\n            viz_masks = (dist_sq < 0.0001)[0]\n         # ===== END the original code for batch size = 1 =====\n        if use_scale:\n            return images_all, norm_all, viz_masks, scale\n        else:\n            return images_all, norm_all, viz_masks, None\n\n    def visualize(self, code, scene_name, viz_dir, code_range=[-1, 1]):\n        num_scenes, num_chn, h, w = code.size()\n        code_viz = code.reshape(num_scenes, 4, 8, h, w).to(torch.float32).cpu().numpy()\n        if not self.flip_z:\n            code_viz = code_viz[..., ::-1, :]\n        code_viz = code_viz.transpose(0, 1, 3, 2, 4).reshape(num_scenes, 4 * h, 8 * w)\n        for code_single, code_viz_single, scene_name_single in zip(code, code_viz, scene_name):\n            plt.imsave(os.path.join(viz_dir, 'a_scene_' + scene_name_single + '.png'), code_viz_single,\n                       vmin=code_range[0], vmax=code_range[1])\n\n    def forward(self, code, smpl_params, cameras, num_imgs,\n                return_loss=False, return_norm=False, init=False, mask=None, zeros_hands_off=False):\n        \"\"\"\n        Args:\n\n          \n            density_bitfield: Shape (num_scenes, griz_size**3 // 8)\n     
       YY:\n            grid_size, dt_gamma, perturb, T_thresh are deleted\n            code: Shape (num_scenes, *code_size)\n            cameras: Shape (num_scenes, num_imgs, 19(3+16))\n            smpl_params: Shape (num_scenes, 189)\n\n        \"\"\"\n        # import ipdb; ipdb.set_trace()\n        if isinstance(code, list):\n            num_scenes = len(code[0])\n        else:\n            num_scenes = len(code)\n        assert num_scenes > 0\n        self.iter+=1\n\n        image = []\n        scales = []\n        norm = [] if return_norm else None\n        viz_masks = [] if not self.training else None\n\n        xyzs, sigmas, rgbs, offsets, radius, tfs, rot = self.extract_pcd(code, smpl_params, init=init, zeros_hands_off=zeros_hands_off)\n\n        if zeros_hands_off:\n            main_print('zeros_hands_off is on!')\n            main_print('zeros_hands_off is on!')\n            offsets[self.hands_mask[...,None].repeat(num_scenes, 1, 3)] = 0\n            rgbs[self.hands_mask[...,None].repeat(num_scenes, 1, 3)] = 0\n            rgbs[self.hands_mask[...,None].repeat(num_scenes, 1, 3)] = 0\n        R_delta = batch_rodrigues(rot.reshape(-1, 3))\n        R = torch.bmm(self.init_rot.repeat(num_scenes, 1, 1), R_delta)\n        R_def = torch.bmm(tfs.flatten(0, 1)[:, :3, :3], R)\n        normals = (R_def[:, :, -1]).reshape(num_scenes, -1, 3)\n        R_def_batch = R_def.reshape(num_scenes, -1, 3, 3)\n      \n        return_to_bfloat16 = True if xyzs.dtype==torch.bfloat16 else False ####### ============ translate the output to BF16 =================\n        # return_to_bfloat16 = False # I don't want to trans it back to bf16\n        if return_to_bfloat16:\n              main_print(\"changes the return_to_bfloat16\")\n              cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius = [ensure_dtype(item, torch.float32) for item in (cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius)]\n        # with torch.amp.autocast(enabled=False, device_type='cuda'):\n        if 1:\n            for camera_single, R_def_single, pcd_single, rgbs_single, sigmas_single, normal_single, radius_single in zip(cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius):\n                image_single, norm_single,  viz_mask, scale = self.gaussian_render(pcd_single, sigmas_single, rgbs_single, normal_single, R_def_single, 1, num_imgs, camera_single, use_scale=True, \\\n                                                                        radius=radius_single, return_norm=return_norm, return_viz=not self.training)\n                image.append(image_single)\n                scales.append(scale.unsqueeze(0))\n                if return_norm:\n                    norm.append(norm_single)\n                if not self.training:\n                    viz_masks.append(viz_mask)\n        image = torch.cat(image, dim=0)\n        scales = torch.cat(scales, dim=0)\n\n        norm = torch.cat(norm, dim=0) if return_norm else None\n        viz_masks = torch.cat(viz_masks, dim=0) if  (not self.training) and viz_masks else None\n\n      \n        main_print(\"not trans the rendered results to float16\")\n        if False:\n            image = image.to(torch.bfloat16)\n            scales = scales.to(torch.bfloat16)\n            if return_norm:\n                norm = norm.to(torch.bfloat16)\n            if viz_masks is not None:\n                viz_masks = viz_masks.to(torch.bfloat16)\n            offsets = offsets.to(torch.bfloat16)\n\n        if self.training:\n            offset_dist = offsets ** 2\n            
weighted_offset = torch.mean(offset_dist) + torch.mean(offset_dist[self.hands_mask.repeat(num_scenes, 1)]) #+ torch.mean(offset_dist[self.face_mask.repeat(num_scenes, 1)])\n        else:\n            weighted_offset = offsets\n        \n\n        results = dict(\n            viz_masks=viz_masks,\n            scales=scales,\n            norm=norm,\n            image=image,\n            offset=weighted_offset)\n\n        if return_loss:\n            results.update(decoder_reg_loss=self.loss())\n\n        return results\n    \n    def forward_render(self, code, cameras, num_imgs,\n                return_loss=False, return_norm=False, init=False, mask=None, zeros_hands_off=False):\n        \"\"\"\n        Args:\n\n        \n            density_bitfield: Shape (num_scenes, griz_size**3 // 8)\n            YY:\n            grid_size, dt_gamma, perturb, T_thresh are deleted\n            code: Shape (num_scenes, *code_size)\n            cameras: Shape (num_scenes, num_imgs, 19(3+16))\n            smpl_params: Shape (num_scenes, 189)\n\n        \"\"\"\n        image = []\n        scales = []\n        norm = [] if return_norm else None\n        viz_masks = [] if not self.training else None\n\n        xyzs, sigmas, rgbs, offsets, radius, tfs, rot =  code\n        num_scenes = xyzs.shape[0]\n        if zeros_hands_off:\n            main_print('zeros_hands_off is on!')\n            main_print('zeros_hands_off is on!')\n            offsets[self.hands_mask[...,None].repeat(num_scenes, 1, 3)] = 0\n            rgbs[self.hands_mask[...,None].repeat(num_scenes, 1, 3)] = 0\n            rgbs[self.hands_mask[...,None].repeat(num_scenes, 1, 3)] = 0\n        R_delta = batch_rodrigues(rot.reshape(-1, 3))\n        R = torch.bmm(self.init_rot.repeat(num_scenes, 1, 1), R_delta)\n        R_def = torch.bmm(tfs.flatten(0, 1)[:, :3, :3], R)\n        normals = (R_def[:, :, -1]).reshape(num_scenes, -1, 3)\n        R_def_batch = R_def.reshape(num_scenes, -1, 3, 3)\n        # import ipdb; ipdb.set_trace()\n    \n        return_to_bfloat16 = True if xyzs.dtype==torch.bfloat16 else False ####### ============ translate the output to BF16 =================\n        # return_to_bfloat16 = False # I don't want to trans it back to bf16\n        if return_to_bfloat16:\n            main_print(\"changes the return_to_bfloat16\")\n            cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius = [ensure_dtype(item, torch.float32) for item in (cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius)]\n\n        if 1:\n            for camera_single, R_def_single, pcd_single, rgbs_single, sigmas_single, normal_single, radius_single in zip(cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius):\n                image_single, norm_single,  viz_mask, scale = self.gaussian_render(pcd_single, sigmas_single, rgbs_single, normal_single, R_def_single, 1, num_imgs, camera_single, use_scale=True, \\\n                                                                        radius=radius_single, return_norm=return_norm, return_viz=not self.training)\n                image.append(image_single)\n                scales.append(scale.unsqueeze(0))\n                if return_norm:\n                    norm.append(norm_single)\n                if not self.training:\n                    viz_masks.append(viz_mask)\n        image = torch.cat(image, dim=0)\n        scales = torch.cat(scales, dim=0)\n\n        norm = torch.cat(norm, dim=0) if return_norm else None\n        viz_masks = torch.cat(viz_masks, dim=0) if  (not self.training) and viz_masks 
else None\n\n\n        main_print(\"not trans the rendered results to float16\")\n        if False:\n            image = image.to(torch.bfloat16)\n            scales = scales.to(torch.bfloat16)\n            if return_norm:\n                norm = norm.to(torch.bfloat16)\n            if viz_masks is not None:\n                viz_masks = viz_masks.to(torch.bfloat16)\n            offsets = offsets.to(torch.bfloat16)\n\n        if self.training:\n            offset_dist = offsets ** 2\n            weighted_offset = torch.mean(offset_dist) + torch.mean(offset_dist[self.hands_mask.repeat(num_scenes, 1)]) #+ torch.mean(offset_dist[self.face_mask.repeat(num_scenes, 1)])\n        else:\n            weighted_offset = offsets\n        \n\n        results = dict(\n            viz_masks=viz_masks,\n            scales=scales,\n            norm=norm,\n            image=image,\n            offset=weighted_offset)\n        if return_loss:\n            results.update(decoder_reg_loss=self.loss())\n\n        return results\n    \n    \n    def forward_testing_time(self, code, smpl_params, cameras, num_imgs,\n                return_loss=False, return_norm=False, init=False, mask=None, zeros_hands_off=False):\n        \"\"\"\n        Args:\n\n          \n            density_bitfield: Shape (num_scenes, griz_size**3 // 8)\n            YY:\n            grid_size, dt_gamma, perturb, T_thresh are deleted\n            code: Shape (num_scenes, *code_size)\n            cameras: Shape (num_scenes, num_imgs, 19(3+16))\n            smpl_params: Shape (num_scenes, 189)\n\n        \"\"\"\n        if isinstance(code, list):\n            num_scenes = len(code[0])\n        else:\n            num_scenes = len(code)\n        assert num_scenes > 0\n        self.iter+=1\n\n        image = []\n        scales = []\n        norm = [] if return_norm else None\n        viz_masks = [] if not self.training else None\n        start_time = time.time()\n        xyzs, sigmas, rgbs, offsets, radius, tfs, rot = self.extract_pcd(code, smpl_params, init=init, zeros_hands_off=zeros_hands_off)\n        end_time_to_3D = time.time()\n        time_code_to_3d = end_time_to_3D- start_time \n\n        R_delta = batch_rodrigues(rot.reshape(-1, 3))\n        R = torch.bmm(self.init_rot.repeat(num_scenes, 1, 1), R_delta)\n        R_def = torch.bmm(tfs.flatten(0, 1)[:, :3, :3], R)\n        normals = (R_def[:, :, -1]).reshape(num_scenes, -1, 3)\n        R_def_batch = R_def.reshape(num_scenes, -1, 3, 3)\n        if 1:\n            for camera_single, R_def_single, pcd_single, rgbs_single, sigmas_single, normal_single, radius_single in zip(cameras, R_def_batch, xyzs, rgbs, sigmas, normals, radius):\n                image_single, norm_single,  viz_mask, scale = self.gaussian_render(pcd_single, sigmas_single, rgbs_single, normal_single, R_def_single, 1, num_imgs, camera_single, use_scale=True, \\\n                                                                        radius=radius_single, return_norm=False, return_viz=not self.training)\n                image.append(image_single)\n                scales.append(scale.unsqueeze(0))\n                if return_norm:\n                    norm.append(norm_single)\n                if not self.training:\n                    viz_masks.append(viz_mask)\n        image = torch.cat(image, dim=0)\n        scales = torch.cat(scales, dim=0)\n\n        norm = torch.cat(norm, dim=0) if return_norm else None\n        viz_masks = torch.cat(viz_masks, dim=0) if  (not self.training) and viz_masks else None\n\n        time_3D_to_img 
= time.time() - end_time_to_3D\n\n\n        if False:\n            image = image.to(torch.bfloat16)\n            scales = scales.to(torch.bfloat16)\n            if return_norm:\n                norm = norm.to(torch.bfloat16)\n            if viz_masks is not None:\n                viz_masks = viz_masks.to(torch.bfloat16)\n            offsets = offsets.to(torch.bfloat16)\n\n        results = dict(\n            image=image)\n\n        return results, time_code_to_3d, time_3D_to_img"
  },
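The core of `UVNDecoder_gender._decode` is decoding a 13-channel UV attribute map and sampling it at fixed per-Gaussian UV coordinates with `F.grid_sample`, then splitting the result into Gaussian attributes. A self-contained sketch of that sampling-and-split step, with stand-in tensors and illustrative sizes, is:

```python
import torch
import torch.nn.functional as F

# Illustrative sizes; in the decoder the 13 channels are sigma (1), offset (3)
# and the color_net output (colour / radius / rotation), and select_coord is a
# buffer precomputed from the SMPL-X UV layout, rescaled to [-1, 1].
num_scenes, num_points, uv_res = 2, 4096, 256
outputs = torch.rand(num_scenes, 13, uv_res, uv_res)             # decoded UV attribute maps
select_coord = torch.rand(num_scenes, 1, num_points, 2) * 2 - 1  # per-Gaussian UV coords

attr = F.grid_sample(outputs, select_coord, mode='bilinear',
                     padding_mode='border', align_corners=False)  # (B, 13, 1, N)
attr = attr.reshape(num_scenes, 13, -1).permute(0, 2, 1)          # (B, N, 13)
sigma, offset, rgbs, radius, rot = attr.split([1, 3, 3, 3, 3], dim=2)

radius = (radius - 0.5) * 2        # remap to [-1, 1] before scaling the Gaussians
rot = (rot - 0.5) * torch.pi       # per-Gaussian rotation offset (axis-angle)
print(sigma.shape, offset.shape, rgbs.shape, radius.shape, rot.shape)
```

The sampled offsets are added to the T-posed template points (`init_pcd`) before the SMPL-X deformer poses them, and the remaining attributes parameterize each Gaussian for rendering.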
  {
    "path": "lib/models/decoders/vit_head.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom typing import Sequence, Tuple, Optional\n\nclass VitHead(nn.Module):\n    def __init__(self,\n                 in_channels: int,\n                 out_channels: int,\n                 deconv_out_channels: Optional[Sequence[int]] = (256, 256, 256),\n                 deconv_kernel_sizes: Optional[Sequence[int]] = (4, 4, 4),\n                 conv_out_channels: Optional[Sequence[int]] = None,\n                 conv_kernel_sizes: Optional[Sequence[int]] = None,\n                 ):\n        super(VitHead, self).__init__()\n\n        if deconv_out_channels:\n            if deconv_kernel_sizes is None or len(deconv_out_channels) != len(deconv_kernel_sizes):\n                raise ValueError(\n                    '\"deconv_out_channels\" and \"deconv_kernel_sizes\" should '\n                    'be integer sequences with the same length. Got '\n                    f'mismatched lengths {deconv_out_channels} and '\n                    f'{deconv_kernel_sizes}')\n\n            self.deconv_layers = self._make_deconv_layers(\n                in_channels=in_channels,\n                layer_out_channels=deconv_out_channels,\n                layer_kernel_sizes=deconv_kernel_sizes,\n            )\n            in_channels = deconv_out_channels[-1]\n        else:\n            self.deconv_layers = nn.Identity()\n\n        if conv_out_channels:\n            if conv_kernel_sizes is None or len(conv_out_channels) != len(conv_kernel_sizes):\n                raise ValueError(\n                    '\"conv_out_channels\" and \"conv_kernel_sizes\" should '\n                    'be integer sequences with the same length. Got '\n                    f'mismatched lengths {conv_out_channels} and '\n                    f'{conv_kernel_sizes}')\n\n            self.conv_layers = self._make_conv_layers(\n                in_channels=in_channels,\n                layer_out_channels=conv_out_channels,\n                layer_kernel_sizes=conv_kernel_sizes)\n            in_channels = conv_out_channels[-1]\n        else:\n            self.conv_layers = nn.Identity()\n\n        self.cls_seg = nn.Conv2d(in_channels, out_channels, kernel_size=1)\n\n    def _make_conv_layers(self, in_channels: int,\n                          layer_out_channels: Sequence[int],\n                          layer_kernel_sizes: Sequence[int]) -> nn.Module:\n        \"\"\"Create convolutional layers by given parameters.\"\"\"\n        layers = []\n        for out_channels, kernel_size in zip(layer_out_channels, layer_kernel_sizes):\n            padding = (kernel_size - 1) // 2\n            layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding))\n            layers.append(nn.InstanceNorm2d(out_channels))\n            layers.append(nn.SiLU(inplace=True))\n            in_channels = out_channels\n\n        return nn.Sequential(*layers)\n\n    def _make_deconv_layers(self, in_channels: int,\n                            layer_out_channels: Sequence[int],\n                            layer_kernel_sizes: Sequence[int]) -> nn.Module:\n        \"\"\"Create deconvolutional layers by given parameters.\"\"\"\n        layers = []\n        for out_channels, kernel_size in zip(layer_out_channels, layer_kernel_sizes):\n            if kernel_size == 4:\n                padding = 1\n                output_padding = 0\n            elif kernel_size == 3:\n                padding = 1\n                output_padding = 1\n            elif 
kernel_size == 2:\n                padding = 0\n                output_padding = 0\n            else:\n                raise ValueError(f'Unsupported kernel size {kernel_size} for deconvolutional layers')\n\n            layers.append(nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=2, padding=padding, output_padding=output_padding, bias=False))\n            layers.append(nn.InstanceNorm2d(out_channels))\n            layers.append(nn.SiLU(inplace=True))\n            in_channels = out_channels\n\n        return nn.Sequential(*layers)\n\n    def forward(self, inputs):\n        x = self.deconv_layers(inputs)\n        x = self.conv_layers(x)\n        out = self.cls_seg(x)\n        return out\n\nif __name__ == \"__main__\":\n\n    # Example usage:\n    model = VitHead(in_channels=1536, out_channels=21, deconv_out_channels=(768, 768, 512, 512),\n                    deconv_kernel_sizes=(4, 4, 4, 4),\n                    conv_out_channels=(512, 256, 128), conv_kernel_sizes=(1, 1, 1),\n                    )\n    inputs = torch.randn(1, 1536, 64, 64)\n    outputs = model(inputs)\n    print(outputs.shape)"
  },
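Each stride-2 deconvolution in `VitHead` (kernel 4, padding 1) doubles the spatial resolution, so the default three-deconv configuration upsamples by 8x before the final 1x1 prediction layer. A quick shape check with illustrative channel counts:

```python
import torch
from lib.models.decoders.vit_head import VitHead

# Illustrative channel sizes; only the upsampling factor is the point here.
head = VitHead(in_channels=32, out_channels=13)   # default deconvs: (256, 256, 256), kernels (4, 4, 4)
x = torch.randn(1, 32, 32, 32)
print(head(x).shape)                              # torch.Size([1, 13, 256, 256]) -- 8x upsampling
```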
  {
    "path": "lib/models/deformers/__init__.py",
    "content": "\nfrom .smplx_deformer_gender import SMPLXDeformer_gender\n\n__all__ = ['SMPLXDeformer_gender']"
  },
  {
    "path": "lib/models/deformers/fast_snarf/cuda/filter/filter.cpp",
    "content": "#include <torch/extension.h>\n#include <ATen/ATen.h>\n#include <vector>\n#include <c10/cuda/CUDAGuard.h>\n\n\n\n\n\nvoid launch_filter(\n  torch::Tensor &output, const torch::Tensor &x, const torch::Tensor &mask);\n\nvoid filter(const torch::Tensor &x, const torch::Tensor &mask, torch::Tensor &output) {\n  const at::cuda::OptionalCUDAGuard device_guard(device_of(x));\n  launch_filter(output, x, mask);\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"filter\", &filter);\n}\n"
  },
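The binding above exposes a single `filter` function over the CUDA kernel in the next file. One common way such an extension is compiled is a JIT build with `torch.utils.cpp_extension.load`; the sketch below assumes that route (the project may instead ship a `setup.py` build), and the tensor shapes follow the kernel's accessors (`x`: B x P x I x 3 float, `mask`/`output`: B x P x I bool).

```python
import torch
from torch.utils.cpp_extension import load

# Illustrative JIT build; requires a CUDA toolchain at runtime.
filter_ext = load(
    name='filter_cuda',
    sources=[
        'lib/models/deformers/fast_snarf/cuda/filter/filter.cpp',
        'lib/models/deformers/fast_snarf/cuda/filter/filter_kernel.cu',
    ],
    verbose=True,
)

B, P, I = 1, 1024, 8
x = torch.randn(B, P, I, 3, device='cuda')
mask = torch.ones(B, P, I, dtype=torch.bool, device='cuda')
out = torch.zeros_like(mask)
filter_ext.filter(x, mask, out)   # writes the de-duplication mask into `out`
```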
  {
    "path": "lib/models/deformers/fast_snarf/cuda/filter/filter_kernel.cu",
    "content": "#include <vector>\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/detail/TensorInfo.cuh>\n#include <ATen/cuda/detail/IndexUtils.cuh>\n#include <ATen/cuda/detail/KernelUtils.h>\n#include <ATen/core/TensorBase.h>\n#include <ATen/Dispatch.h>\n#include <c10/macros/Macros.h>\n\n#include \"ATen/Functions.h\"\n#include \"ATen/core/TensorAccessor.h\"\n#include \"c10/cuda/CUDAException.h\"\n#include \"c10/cuda/CUDAStream.h\"\n\n#include <chrono>\n\n#define TensorAcc4R PackedTensorAccessor32<scalar_t,4,RestrictPtrTraits>\n#define TensorAcc5R PackedTensorAccessor32<scalar_t,5,RestrictPtrTraits>\n\nusing namespace at;\nusing namespace at::cuda::detail;\n\n\ntemplate <typename scalar_t, typename index_t>\nC10_LAUNCH_BOUNDS_1(512)\n__global__ void filter(\n    const index_t nthreads,\n    PackedTensorAccessor32<scalar_t, 4, RestrictPtrTraits> x,\n    PackedTensorAccessor32<bool, 3, RestrictPtrTraits> mask,\n    PackedTensorAccessor32<bool, 3, RestrictPtrTraits> output) {\n\n    index_t n_batch = mask.size(0);\n    index_t n_point = mask.size(1);\n    index_t n_init = mask.size(2);\n    CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {\n\n        const index_t i_batch = index / (n_batch*n_point);\n        const index_t i_point = index % (n_batch*n_point);\n\n\n        for(index_t i = 0; i < n_init; i++) {\n            if(!mask[i_batch][i_point][i]){\n                output[i_batch][i_point][i] = false;\n                continue;\n            }\n            scalar_t xi0 = x[i_batch][i_point][i][0];\n            scalar_t xi1 = x[i_batch][i_point][i][1];\n            scalar_t xi2 = x[i_batch][i_point][i][2];\n\n            bool flag = true;\n            for(index_t j = i+1; j < n_init; j++){\n                if(!mask[i_batch][i_point][j]){\n                    continue;\n                }\n                scalar_t d0 = xi0 - x[i_batch][i_point][j][0];\n                scalar_t d1 = xi1 - x[i_batch][i_point][j][1];\n                scalar_t d2 = xi2 - x[i_batch][i_point][j][2];\n\n                scalar_t dist = d0*d0 + d1*d1 + d2*d2;\n                if(dist<0.0001*0.0001){\n                    flag=false;\n                    break;\n                }\n            }\n\n            output[i_batch][i_point][i] = flag;\n        }\n\n    }\n}\n\nvoid launch_filter(\n    Tensor &output,\n    const Tensor &x,\n    const Tensor &mask) {\n\n  // calculate #threads required\n  int64_t B = output.size(0);\n  int64_t N = output.size(1);\n\n  int64_t count = B*N;\n  if (count > 0) {\n      AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), \"filter\", [&] {\n            filter<scalar_t>\n            <<<GET_BLOCKS(count, 512), 512, 0, at::cuda::getCurrentCUDAStream()>>>(\n              static_cast<int>(count),\n              x.packed_accessor32<scalar_t,4,RestrictPtrTraits>(),\n              mask.packed_accessor32<bool,3,RestrictPtrTraits>(),\n              output.packed_accessor32<bool,3,RestrictPtrTraits>());\n          C10_CUDA_KERNEL_LAUNCH_CHECK();\n      });\n  }\n}\n"
  },
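Reading the kernel loop: a valid candidate `i` is kept only if no *later* valid candidate `j > i` lies within 1e-4 of it (squared distance below `0.0001 * 0.0001`), i.e. near-duplicate correspondences are collapsed onto the last occurrence. A pure-PyTorch reference of that logic (illustrative, unoptimized) is:

```python
import torch

def filter_reference(x: torch.Tensor, mask: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
    """Reference for the CUDA filter kernel.

    x:    (B, P, I, 3) candidate correspondences per query point
    mask: (B, P, I)    validity of each candidate
    Returns a bool mask keeping a valid candidate i only if no later valid
    candidate j > i lies within `eps` of it.
    """
    B, P, I, _ = x.shape
    xf = x.reshape(B * P, I, 3)
    d2 = torch.cdist(xf, xf).pow(2).reshape(B, P, I, I)                       # pairwise squared distances
    later = torch.triu(torch.ones(I, I, dtype=torch.bool, device=x.device), diagonal=1)
    near_dup = (d2 < eps * eps) & later & mask[:, :, None, :]                 # valid j > i within eps
    return mask & ~near_dup.any(dim=-1)
```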
  {
    "path": "lib/models/deformers/fast_snarf/cuda/fuse_kernel/fuse_cuda.cpp",
    "content": "#include \"ATen/Functions.h\"\n#include \"ATen/core/TensorBody.h\"\n#include <torch/extension.h>\n#include <ATen/ATen.h>\n#include <vector>\n#include <c10/cuda/CUDAGuard.h>\n\n\n\nvoid launch_broyden_kernel(torch::Tensor &x,\n                           const torch::Tensor &xd_tgt,\n                           const torch::Tensor &grid,\n                           const torch::Tensor &grid_J_inv,\n                           const torch::Tensor &tfs,\n                           const torch::Tensor &bone_ids,\n                           bool align_corners,\n                          //  torch::Tensor &J_inv,\n                           torch::Tensor &is_valid,\n                           const torch::Tensor& offset,\n                           const torch::Tensor& scale,\n                           float cvg_threshold,\n                           float dvg_threshold);\n\n\nvoid fuse_broyden(torch::Tensor &x,\n                  const torch::Tensor &xd_tgt,\n                  const torch::Tensor &grid,\n                  const torch::Tensor &grid_J_inv,\n                  const torch::Tensor &tfs,\n                  const torch::Tensor &bone_ids,\n                  bool align_corners,\n                  // torch::Tensor& J_inv,\n                  torch::Tensor &is_valid,\n                  torch::Tensor& offset,\n                  torch::Tensor& scale,\n                  float cvg_threshold,\n                  float dvg_threshold) {\n\n  const at::cuda::OptionalCUDAGuard device_guard(device_of(x));\n\n  launch_broyden_kernel(x, xd_tgt, grid, grid_J_inv, tfs, bone_ids, align_corners, is_valid, offset, scale, cvg_threshold, dvg_threshold);\n  return;\n}\n\n\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"fuse_broyden\", &fuse_broyden);\n}"
  },
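The `fuse_broyden` binding above drives an iterative root solver on the GPU; the kernel in the next file maintains a per-point inverse-Jacobian estimate through `fuse_J_inv_update`, which is the rank-one "good Broyden" update. As a reading aid only, a small PyTorch sketch of the same update for a single 3x3 inverse Jacobian (argument names here are mine, not the kernel's):

```python
import torch

def broyden_good_update(J_inv: torch.Tensor, delta_x: torch.Tensor, delta_g: torch.Tensor) -> torch.Tensor:
    """Good-Broyden rank-one update of an inverse Jacobian.

    J_inv:   (3, 3) current inverse-Jacobian estimate
    delta_x: (3,)   change in the iterate
    delta_g: (3,)   change in the residual
    """
    c = J_inv.t() @ delta_x            # delta_x^T J_inv, as a column
    s = torch.dot(c, delta_g)          # scalar denominator delta_x^T J_inv delta_g
    r = delta_x - J_inv @ delta_g      # delta_x - J_inv delta_g
    return J_inv + torch.outer(r, c) / s
```

The kernel applies this formula element-wise for each candidate point while iterating toward the canonical correspondence, stopping on the `cvg_threshold`/`dvg_threshold` convergence and divergence tests passed through the binding.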
  {
    "path": "lib/models/deformers/fast_snarf/cuda/fuse_kernel/fuse_cuda_kernel.cu",
    "content": "#include \"ATen/Functions.h\"\n#include \"ATen/core/TensorAccessor.h\"\n#include \"c10/cuda/CUDAException.h\"\n#include \"c10/cuda/CUDAStream.h\"\n\n#include <ratio>\n#include <vector>\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/detail/TensorInfo.cuh>\n#include <ATen/cuda/detail/IndexUtils.cuh>\n#include <ATen/cuda/detail/KernelUtils.h>\n#include <ATen/core/TensorBase.h>\n#include <ATen/Dispatch.h>\n#include <c10/macros/Macros.h>\n\n#include <chrono>\nusing namespace std::chrono;\n\nusing namespace at;\nusing namespace at::cuda::detail;\n\ntemplate <typename scalar_t, typename index_t>\n__device__ void fuse_J_inv_update(const index_t index,\n                                  scalar_t* J_inv,\n                                  scalar_t x0,\n                                  scalar_t x1,\n                                  scalar_t x2,\n                                  scalar_t g0,\n                                  scalar_t g1,\n                                  scalar_t g2) {\n\n    // index_t s_J = J_inv.strides[0];\n    // index_t s_x = delta_x.strides[0];\n    // index_t s_g = delta_gx.strides[0];\n    index_t s_J = 9;\n    index_t s_x = 3;\n    index_t s_g = 3;\n\n    scalar_t J00 = J_inv[3*0 + 0];\n    scalar_t J01 = J_inv[3*0 + 1];\n    scalar_t J02 = J_inv[3*0 + 2];\n    scalar_t J10 = J_inv[3*1 + 0];\n    scalar_t J11 = J_inv[3*1 + 1];\n    scalar_t J12 = J_inv[3*1 + 2];\n    scalar_t J20 = J_inv[3*2 + 0];\n    scalar_t J21 = J_inv[3*2 + 1];\n    scalar_t J22 = J_inv[3*2 + 2];\n\n    auto c0 = J00 * x0 + J10 * x1 + J20 * x2;\n    auto c1 = J01 * x0 + J11 * x1 + J21 * x2;\n    auto c2 = J02 * x0 + J12 * x1 + J22 * x2;\n\n    auto s = c0 * g0 + c1 * g1 + c2 * g2;\n\n    auto r0 = -J00 * g0 - J01 * g1 - J02 * g2;\n    auto r1 = -J10 * g0 - J11 * g1 - J12 * g2;\n    auto r2 = -J20 * g0 - J21 * g1 - J22 * g2;\n\n    J_inv[3*0 + 0] += c0 * (r0 + x0) / s;\n    J_inv[3*0 + 1] += c1 * (r0 + x0) / s;\n    J_inv[3*0 + 2] += c2 * (r0 + x0) / s;\n    J_inv[3*1 + 0] += c0 * (r1 + x1) / s;\n    J_inv[3*1 + 1] += c1 * (r1 + x1) / s;\n    J_inv[3*1 + 2] += c2 * (r1 + x1) / s;\n    J_inv[3*2 + 0] += c0 * (r2 + x2) / s;\n    J_inv[3*2 + 1] += c1 * (r2 + x2) / s;\n    J_inv[3*2 + 2] += c2 * (r2 + x2) / s;\n\n}\n\n\nstatic __forceinline__ __device__\nbool within_bounds_3d(int d, int h, int w, int D, int H, int W) {\n  return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W;\n}\n\n\ntemplate <typename scalar_t>\nstatic __forceinline__ __device__\nscalar_t grid_sampler_unnormalize(scalar_t coord, int size, bool align_corners) {\n  if (align_corners) {\n    // unnormalize coord from [-1, 1] to [0, size - 1]\n    return ((coord + 1.f) / 2) * (size - 1);\n  } else {\n    // unnormalize coord from [-1, 1] to [-0.5, size - 0.5]\n    return ((coord + 1.f) * size - 1) / 2;\n  }\n}\n\n\n// Clips coordinates to between 0 and clip_limit - 1\ntemplate <typename scalar_t>\nstatic __forceinline__ __device__\nscalar_t clip_coordinates(scalar_t in, int clip_limit) {\n  return ::min(static_cast<scalar_t>(clip_limit - 1), ::max(in, static_cast<scalar_t>(0)));\n}\n\n\ntemplate<typename scalar_t>\nstatic __forceinline__ __device__\nscalar_t safe_downgrade_to_int_range(scalar_t x){\n  // -100.0 does not have special meaning. This is just to make sure\n  // it's not within_bounds_2d or within_bounds_3d, and does not cause\n  // undefined behavior. 
See #35506.\n  if (x > INT_MAX-1 || x < INT_MIN || !::isfinite(static_cast<double>(x)))\n    return static_cast<scalar_t>(-100.0);\n  return x;\n}\n\n\ntemplate<typename scalar_t>\nstatic __forceinline__ __device__\nscalar_t compute_coordinates(scalar_t coord, int size, bool align_corners) {\n  // clip coordinates to image borders\n  // coord = clip_coordinates(coord, size);\n  coord = safe_downgrade_to_int_range(coord);\n  return coord;\n}\n\ntemplate <typename scalar_t>\nstatic __forceinline__ __device__\nscalar_t grid_sampler_compute_source_index(scalar_t coord, int size, bool align_corners) {\n  coord = grid_sampler_unnormalize(coord, size, align_corners);\n  coord = compute_coordinates(coord, size, align_corners);\n  return coord;\n}\n\ntemplate <typename scalar_t, typename index_t>\n__device__ void grid_sampler_3d(\n                                index_t i_batch,\n                                TensorInfo<scalar_t, index_t> input,\n                                scalar_t grid_x,\n                                scalar_t grid_y,\n                                scalar_t grid_z,\n                                // TensorInfo<scalar_t, index_t> output,\n                                PackedTensorAccessor32<scalar_t, 5> input_p, // [1, 3, 8, 32, 32]\n                                scalar_t* output,\n                                // PackedTensorAccessor32<scalar_t, 3> output_p, // [1800000, 3, 1]\n                                bool align_corners,\n                                bool nearest) {\n\n  index_t C = input.sizes[1];\n  index_t inp_D = input.sizes[2];\n  index_t inp_H = input.sizes[3];\n  index_t inp_W = input.sizes[4];\n  // broyden x.sizes=[1800000,3,1]\n  \n  index_t inp_sN = input.strides[0];\n  index_t inp_sC = input.strides[1];\n  index_t inp_sD = input.strides[2];\n  index_t inp_sH = input.strides[3];\n  index_t inp_sW = input.strides[4];\n\n  index_t out_sC = 1; //output size is same as grid size...\n\n  // return;\n\n  // get the corresponding input x, y, z co-ordinates from grid\n  scalar_t ix = grid_x;\n  scalar_t iy = grid_y;\n  scalar_t iz = grid_z;\n\n  // c0 ix,iy,iz=-0.848051,0.592726,0.259927\n  // c1 ix,iy,iz=2.355216,24.687256,4.409743\n\n  ix = grid_sampler_compute_source_index(ix, inp_W, align_corners);\n  iy = grid_sampler_compute_source_index(iy, inp_H, align_corners);\n  iz = grid_sampler_compute_source_index(iz, inp_D, align_corners);\n\n  if(!nearest){\n    // get corner pixel values from (x, y, z)\n    // for 4d, we used north-east-south-west\n    // for 5d, we add top-bottom\n    index_t ix_tnw = static_cast<index_t>(::floor(ix));\n    index_t iy_tnw = static_cast<index_t>(::floor(iy));\n    index_t iz_tnw = static_cast<index_t>(::floor(iz));\n\n    index_t ix_tne = ix_tnw + 1;\n    index_t iy_tne = iy_tnw;\n    index_t iz_tne = iz_tnw;\n\n    index_t ix_tsw = ix_tnw;\n    index_t iy_tsw = iy_tnw + 1;\n    index_t iz_tsw = iz_tnw;\n\n    index_t ix_tse = ix_tnw + 1;\n    index_t iy_tse = iy_tnw + 1;\n    index_t iz_tse = iz_tnw;\n\n    index_t ix_bnw = ix_tnw;\n    index_t iy_bnw = iy_tnw;\n    index_t iz_bnw = iz_tnw + 1;\n\n    index_t ix_bne = ix_tnw + 1;\n    index_t iy_bne = iy_tnw;\n    index_t iz_bne = iz_tnw + 1;\n\n    index_t ix_bsw = ix_tnw;\n    index_t iy_bsw = iy_tnw + 1;\n    index_t iz_bsw = iz_tnw + 1;\n\n    index_t ix_bse = ix_tnw + 1;\n    index_t iy_bse = iy_tnw + 1;\n    index_t iz_bse = iz_tnw + 1;\n\n    // get surfaces to each neighbor:\n    scalar_t tnw = (ix_bse - ix)    * (iy_bse - iy)    * (iz_bse - iz);\n    
scalar_t tne = (ix    - ix_bsw) * (iy_bsw - iy)    * (iz_bsw - iz);\n    scalar_t tsw = (ix_bne - ix)    * (iy    - iy_bne) * (iz_bne - iz);\n    scalar_t tse = (ix    - ix_bnw) * (iy    - iy_bnw) * (iz_bnw - iz);\n    scalar_t bnw = (ix_tse - ix)    * (iy_tse - iy)    * (iz - iz_tse);\n    scalar_t bne = (ix    - ix_tsw) * (iy_tsw - iy)    * (iz - iz_tsw);\n    scalar_t bsw = (ix_tne - ix)    * (iy    - iy_tne) * (iz - iz_tne);\n    scalar_t bse = (ix    - ix_tnw) * (iy    - iy_tnw) * (iz - iz_tnw);\n\n\n    for (index_t xyz = 0; xyz < C; xyz++) {\n      output[xyz] = 0;\n\n      if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {\n        output[xyz] += input_p[i_batch][xyz][iz_tnw][iy_tnw][ix_tnw] * tnw;\n        // *out_ptr_NCDHW += inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] * tnw;\n      }\n      if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {\n        output[xyz] += input_p[i_batch][xyz][iz_tne][iy_tne][ix_tne] * tne;\n        // *out_ptr_NCDHW += inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] * tne;\n      }\n      if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {\n        output[xyz] += input_p[i_batch][xyz][iz_tsw][iy_tsw][ix_tsw] * tsw;\n        // *out_ptr_NCDHW += inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] * tsw;\n      }\n      if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {\n        output[xyz] += input_p[i_batch][xyz][iz_tse][iy_tse][ix_tse] * tse;\n        // *out_ptr_NCDHW += inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] * tse;\n      }\n      if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {\n        output[xyz] += input_p[i_batch][xyz][iz_bnw][iy_bnw][ix_bnw] * bnw;\n        // *out_ptr_NCDHW += inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] * bnw;\n      }\n      if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {\n        output[xyz] += input_p[i_batch][xyz][iz_bne][iy_bne][ix_bne] * bne;\n        // *out_ptr_NCDHW += inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] * bne;\n      }\n      if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {\n        output[xyz] += input_p[i_batch][xyz][iz_bsw][iy_bsw][ix_bsw] * bsw;\n        // *out_ptr_NCDHW += inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] * bsw;\n      }\n      if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {\n        output[xyz] += input_p[i_batch][xyz][iz_bse][iy_bse][ix_bse] * bse;\n        // *out_ptr_NCDHW += inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] * bse;\n      }\n    }\n    }\n    else{\n        index_t ix_nearest = static_cast<index_t>(::round(ix));\n        index_t iy_nearest = static_cast<index_t>(::round(iy));\n        index_t iz_nearest = static_cast<index_t>(::round(iz));\n\n        for (index_t xyz = 0; xyz < C; xyz++) {\n          if (within_bounds_3d(iz_nearest, iy_nearest, ix_nearest, inp_D, inp_H, inp_W)) {\n            output[xyz] = input_p[i_batch][xyz][iz_nearest][iy_nearest][ix_nearest]; \n          } else {\n            output[xyz] = static_cast<scalar_t>(0);\n          }\n        }\n      }\n}\n\n\n\ntemplate <typename scalar_t, typename index_t>\nC10_LAUNCH_BOUNDS_1(512)\n__global__ void broyden_kernel(\n                               const index_t npoints,\n                              const index_t n_batch,\n                              const index_t n_point,\n                              
const index_t n_init,\n                               TensorInfo<scalar_t, index_t> voxel_ti,\n                               TensorInfo<scalar_t, index_t> voxel_J_ti,\n                               PackedTensorAccessor32<scalar_t, 4> x, // shape=(N,200000, 9, 3)\n                               PackedTensorAccessor32<scalar_t, 3> xd_tgt, // shape=(N,200000, 3)\n                               PackedTensorAccessor32<scalar_t, 5> voxel, // shape=(N,3,8,32,32)\n                               PackedTensorAccessor32<scalar_t, 5> grid_J_inv, // shape=(N,9,8,32,32)\n                               PackedTensorAccessor32<scalar_t, 4> tfs, // shape=(N,24,4,4)\n                               PackedTensorAccessor32<int, 1> bone_ids, // shape=(9)\n                              //  PackedTensorAccessor32<scalar_t, 5> J_inv,// shape=(N,200000, 9, 9)\n                               PackedTensorAccessor32<bool, 3> is_valid,// shape=(N,200000, 9)\n                               PackedTensorAccessor32<scalar_t, 3> offset, // shape=(N, 1, 3) \n                               PackedTensorAccessor32<scalar_t, 3> scale, // shape=(N, 1, 3)\n                               float cvg_threshold,\n                               float dvg_threshold,\n                               int N\n                               ) \n{\n\n  index_t index = blockIdx.x * blockDim.x + threadIdx.x;\n  if(index >= npoints) return;\n\n  \n  const index_t i_batch = index / (n_point*n_init);\n  const index_t i_point = (index % (n_point*n_init)) / n_init;\n  const index_t i_init = (index %  (n_point*n_init)) % n_init;\n \n  if(!is_valid[i_batch][i_point][i_init]){\n    return;\n  }\n\n  scalar_t gx[3];\n  scalar_t gx_new[3];\n\n  scalar_t xd_tgt_index[3];\n\n  xd_tgt_index[0] = xd_tgt[i_batch][i_point][0];\n  xd_tgt_index[1] = xd_tgt[i_batch][i_point][1];\n  xd_tgt_index[2] = xd_tgt[i_batch][i_point][2];\n\n  scalar_t x_l[3];\n\n  int i_bone = bone_ids[i_init];\n  scalar_t ixd = xd_tgt_index[0] - tfs[i_batch][i_bone][0][3]; \n  scalar_t iyd = xd_tgt_index[1] - tfs[i_batch][i_bone][1][3]; \n  scalar_t izd = xd_tgt_index[2] - tfs[i_batch][i_bone][2][3]; \n  x_l[0] = ixd * tfs[i_batch][i_bone][0][0]  \n            + iyd * tfs[i_batch][i_bone][1][0] \n            + izd * tfs[i_batch][i_bone][2][0];\n\n\n  x_l[1] = ixd * tfs[i_batch][i_bone][0][1]  \n            + iyd * tfs[i_batch][i_bone][1][1] \n            + izd * tfs[i_batch][i_bone][2][1];\n\n\n  x_l[2] = ixd * tfs[i_batch][i_bone][0][2]  \n            + iyd * tfs[i_batch][i_bone][1][2] \n            + izd * tfs[i_batch][i_bone][2][2];\n\n\n  // scalar_t ix = scale[0][0][0] * (x_l[0] + offset[0][0][0]);\n  // scalar_t iy = scale[0][0][1] * (x_l[1] + offset[0][0][1]);\n  // scalar_t iz = scale[0][0][2] * (x_l[2] + offset[0][0][2]);\n\n  scalar_t J_local[12];\n  grid_sampler_3d( i_batch, voxel_J_ti,\n                    scale[0][0][0] * (x_l[0] + offset[0][0][0]),\n                    scale[0][0][1] * (x_l[1] + offset[0][0][1]),\n                    scale[0][0][2] * (x_l[2] + offset[0][0][2]),\n                    grid_J_inv,\n                    J_local,\n                    true,\n                    false);\n\n  scalar_t J_inv_local[9];\n  J_inv_local[3*0 + 0] = J_local[4*0 + 0]; \n  J_inv_local[3*1 + 0] = J_local[4*0 + 1];\n  J_inv_local[3*2 + 0] = J_local[4*0 + 2];\n  J_inv_local[3*0 + 1] = J_local[4*1 + 0];\n  J_inv_local[3*1 + 1] = J_local[4*1 + 1];\n  J_inv_local[3*2 + 1] = J_local[4*1 + 2];\n  J_inv_local[3*0 + 2] = J_local[4*2 + 0];\n  J_inv_local[3*1 + 2] = J_local[4*2 + 1];\n  
J_inv_local[3*2 + 2] = J_local[4*2 + 2];\n  \n  for(int i=0; i<10; i++) {\n\n    scalar_t J00 = J_inv_local[3*0 + 0]; \n    scalar_t J01 = J_inv_local[3*0 + 1];\n    scalar_t J02 = J_inv_local[3*0 + 2];\n    scalar_t J10 = J_inv_local[3*1 + 0];\n    scalar_t J11 = J_inv_local[3*1 + 1];\n    scalar_t J12 = J_inv_local[3*1 + 2];\n    scalar_t J20 = J_inv_local[3*2 + 0];\n    scalar_t J21 = J_inv_local[3*2 + 1];\n    scalar_t J22 = J_inv_local[3*2 + 2];\n\n    // gx = g(x)\n    if (i==0){\n      // grid_sampler_3d( i_batch, voxel_ti,\n      //                   scale[0][0][0] * (x_l[0] + offset[0][0][0]),\n      //                   scale[0][0][1] * (x_l[1] + offset[0][0][1]),\n      //                   scale[0][0][2] * (x_l[2] + offset[0][0][2]),\n      //                   voxel,\n      //                   gx,\n      //                   true,\n      //                   false);\n                        \n      gx[0] = J_local[4*0+0] * x_l[0] + J_local[4*0+1] * x_l[1] + J_local[4*0+2] * x_l[2] + J_local[4*0+3];\n      gx[1] = J_local[4*1+0] * x_l[0] + J_local[4*1+1] * x_l[1] + J_local[4*1+2] * x_l[2] + J_local[4*1+3];\n      gx[2] = J_local[4*2+0] * x_l[0] + J_local[4*2+1] * x_l[1] + J_local[4*2+2] * x_l[2] + J_local[4*2+3];\n                      \n      gx[0] = gx[0] - xd_tgt_index[0];\n      gx[1] = gx[1] - xd_tgt_index[1];\n      gx[2] = gx[2] - xd_tgt_index[2];\n\n    }\n    else{\n      gx[0] = gx_new[0];\n      gx[1] = gx_new[1];\n      gx[2] = gx_new[2];\n    }\n\n    // update = -J_inv @ gx\n    scalar_t u0 = -J00*gx[0] + -J01*gx[1] + -J02*gx[2];\n    scalar_t u1 = -J10*gx[0] + -J11*gx[1] + -J12*gx[2];\n    scalar_t u2 = -J20*gx[0] + -J21*gx[1] + -J22*gx[2];\n\n    // x += update\n    x_l[0] += u0;\n    x_l[1] += u1;\n    x_l[2] += u2;\n\n    scalar_t ix = scale[0][0][0] * (x_l[0] + offset[0][0][0]);\n    scalar_t iy = scale[0][0][1] * (x_l[1] + offset[0][0][1]);\n    scalar_t iz = scale[0][0][2] * (x_l[2] + offset[0][0][2]);\n\n    // gx_new = g(x)\n    grid_sampler_3d( i_batch, voxel_J_ti,\n                      ix,\n                      iy,\n                      iz,\n                      grid_J_inv,\n                      J_local,\n                      true,\n                      false);\n                      \n    gx_new[0] = J_local[4*0+0] * x_l[0] + J_local[4*0+1] * x_l[1] + J_local[4*0+2] * x_l[2] + J_local[4*0+3] - xd_tgt_index[0];\n    gx_new[1] = J_local[4*1+0] * x_l[0] + J_local[4*1+1] * x_l[1] + J_local[4*1+2] * x_l[2] + J_local[4*1+3] - xd_tgt_index[1];\n    gx_new[2] = J_local[4*2+0] * x_l[0] + J_local[4*2+1] * x_l[1] + J_local[4*2+2] * x_l[2] + J_local[4*2+3] - xd_tgt_index[2];\n\n      // grid_sampler_3d( i_batch, voxel_ti,\n      //                   scale[0][0][0] * (x_l[0] + offset[0][0][0]),\n      //                   scale[0][0][1] * (x_l[1] + offset[0][0][1]),\n      //                   scale[0][0][2] * (x_l[2] + offset[0][0][2]),\n      //                   voxel,\n      //                   gx_new,\n      //                   true,\n      //                   false);\n\n      // gx_new[0] = gx_new[0] - xd_tgt_index[0];\n      // gx_new[1] = gx_new[1] - xd_tgt_index[1];\n      // gx_new[2] = gx_new[2] - xd_tgt_index[2];\n             \n    // convergence checking\n    scalar_t norm_gx = gx_new[0]*gx_new[0] + gx_new[1]*gx_new[1] + gx_new[2]*gx_new[2];\n\n    // convergence/divergence criterion\n    if(norm_gx < cvg_threshold*cvg_threshold) {\n\n      bool is_valid_ = ix >= -1 && ix <= 1 && iy >= -1 && iy <= 1 && iz >= -1 && iz <= 1;\n\n      
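// accept the converged root only when it still lies inside the normalized [-1, 1]^3 sampling grid\n      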
is_valid[i_batch][i_point][i_init] = is_valid_;\n\n      if (is_valid_){\n      x[i_batch][i_point][i_init][0] = x_l[0];\n      x[i_batch][i_point][i_init][1] = x_l[1];\n      x[i_batch][i_point][i_init][2] = x_l[2];\n\n      // J_inv[i_batch][i_point][i_init][0][0] = J00;\n      // J_inv[i_batch][i_point][i_init][0][1] = J01;\n      // J_inv[i_batch][i_point][i_init][0][2] = J02;\n      // J_inv[i_batch][i_point][i_init][1][0] = J10;\n      // J_inv[i_batch][i_point][i_init][1][1] = J11;\n      // J_inv[i_batch][i_point][i_init][1][2] = J12;\n      // J_inv[i_batch][i_point][i_init][2][0] = J20;\n      // J_inv[i_batch][i_point][i_init][2][1] = J21;\n      // J_inv[i_batch][i_point][i_init][2][2] = J22;\n      }\n      return;\n\n    }\n    else if(norm_gx > dvg_threshold*dvg_threshold) {\n      is_valid[i_batch][i_point][i_init] = false;\n      return;\n    }\n    \n\n    // delta_x = update\n    scalar_t delta_x_0 = u0;\n    scalar_t delta_x_1 = u1;\n    scalar_t delta_x_2 = u2;\n\n    // delta_gx = gx_new - gx\n    scalar_t delta_gx_0 = gx_new[0] - gx[0];\n    scalar_t delta_gx_1 = gx_new[1] - gx[1];\n    scalar_t delta_gx_2 = gx_new[2] - gx[2];\n\n    fuse_J_inv_update(index, J_inv_local, delta_x_0, delta_x_1, delta_x_2, delta_gx_0, delta_gx_1, delta_gx_2);\n\n  }\n  is_valid[i_batch][i_point][i_init] = false;\n  return;\n\n}\n\n\nvoid launch_broyden_kernel(\n                           Tensor &x,\n                           const Tensor &xd_tgt,\n    const Tensor &voxel,\n    const Tensor &grid_J_inv,\n    const Tensor &tfs,\n    const Tensor &bone_ids,\n    bool align_corners,\n    // Tensor &J_inv,\n    Tensor &is_valid,\n    const Tensor& offset,\n    const Tensor& scale,\n    float cvg_threshold,\n    float dvg_threshold) {\n\n  // calculate #threads required\n  int64_t n_batch = xd_tgt.size(0);\n  int64_t n_point = xd_tgt.size(1);\n  int64_t n_init = bone_ids.size(0);\n  int64_t count = n_batch * n_point * n_init;\n\n\n  if (count > 0) {\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), \"fuse_kernel_cuda\", [&] {\n      broyden_kernel\n        <<<GET_BLOCKS(count, 512), 512, 0,\n        at::cuda::getCurrentCUDAStream()>>>(static_cast<int>(count),\n                                            static_cast<int>(n_batch),\n                                            static_cast<int>(n_point),\n                                            static_cast<int>(n_init),\n                                            getTensorInfo<scalar_t, int>(voxel),\n                                            getTensorInfo<scalar_t, int>(grid_J_inv),\n                                            x.packed_accessor32<scalar_t, 4>(),\n                                            xd_tgt.packed_accessor32<scalar_t, 3>(),\n                                            voxel.packed_accessor32<scalar_t, 5>(),\n                                            grid_J_inv.packed_accessor32<scalar_t, 5>(),\n                                            tfs.packed_accessor32<scalar_t, 4>(),\n                                            bone_ids.packed_accessor32<int, 1>(),\n                                            // J_inv.packed_accessor32<scalar_t, 5>(),\n                                            is_valid.packed_accessor32<bool, 3>(),\n                                            offset.packed_accessor32<scalar_t, 3>(),\n                                            scale.packed_accessor32<scalar_t, 3>(),\n                                            cvg_threshold,\n                                            
dvg_threshold,\n                                            /*N=*/0);\n      C10_CUDA_KERNEL_LAUNCH_CHECK();\n\n    });\n  }\n\n  cudaDeviceSynchronize();\n}\n\n"
  },
  {
    "path": "lib/models/deformers/fast_snarf/cuda/precompute/precompute.cpp",
    "content": "#include <torch/extension.h>\n#include <ATen/ATen.h>\n#include <vector>\n#include <c10/cuda/CUDAGuard.h>\n\n\n\n\n\nvoid launch_precompute(const torch::Tensor &voxel_w, const torch::Tensor &tfs, torch::Tensor &voxel_d, torch::Tensor &voxel_J, const torch::Tensor &offset, const torch::Tensor &scale);\n\nvoid precompute(const torch::Tensor &voxel_w, const torch::Tensor &tfs, torch::Tensor &voxel_d, torch::Tensor &voxel_J, const torch::Tensor &offset, const torch::Tensor &scale) {\n  const at::cuda::OptionalCUDAGuard device_guard(device_of(voxel_w));\n\n  launch_precompute(voxel_w, tfs, voxel_d, voxel_J, offset, scale);\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"precompute\", &precompute);\n}\n"
  },
  {
    "path": "lib/models/deformers/fast_snarf/cuda/precompute/precompute_kernel.cu",
    "content": "#include \"ATen/Functions.h\"\n#include \"ATen/core/TensorAccessor.h\"\n#include \"c10/cuda/CUDAException.h\"\n#include \"c10/cuda/CUDAStream.h\"\n\n#include <ratio>\n#include <vector>\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/detail/TensorInfo.cuh>\n#include <ATen/cuda/detail/IndexUtils.cuh>\n#include <ATen/cuda/detail/KernelUtils.h>\n#include <ATen/core/TensorBase.h>\n#include <ATen/Dispatch.h>\n#include <c10/macros/Macros.h>\n\n#include <chrono>\nusing namespace std::chrono;\n\nusing namespace at;\nusing namespace at::cuda::detail;\n\n\ntemplate <typename scalar_t, typename index_t>\nC10_LAUNCH_BOUNDS_1(512)\n__global__ void precompute_kernel(\n                              const index_t npoints,\n                              const index_t d,\n                              const index_t h,\n                              const index_t w,\n                               PackedTensorAccessor32<scalar_t, 5> voxel_w, // shape=(N,200000, 9, 3)\n                               PackedTensorAccessor32<scalar_t, 4> tfs, // shape=(N,200000, 3)\n                               PackedTensorAccessor32<scalar_t, 5> voxel_d, // shape=(N,3,8,32,32)\n                               PackedTensorAccessor32<scalar_t, 5> voxel_J, // shape=(N,9,8,32,32)\n                               PackedTensorAccessor32<scalar_t, 3> offset, // shape=(N, 1, 3) \n                               PackedTensorAccessor32<scalar_t, 3> scale // shape=(N, 1, 3)\n                               ) \n{\n\n  index_t index = blockIdx.x * blockDim.x + threadIdx.x;\n  if(index >= npoints) return;\n\n  index_t idx_b = index / (d*h*w);\n  index_t idx_d = index % (d*h*w) / (h*w);\n  index_t idx_h = index % (d*h*w) % (h*w) / w;\n  index_t idx_w = index % (d*h*w) % (h*w) % w;\n\n  scalar_t coord_x = ( ((scalar_t)idx_w) / (w-1) * 2 -1) / scale[0][0][0] - offset[0][0][0];\n  scalar_t coord_y = ( ((scalar_t)idx_h) / (h-1) * 2 -1) / scale[0][0][1] - offset[0][0][1];\n  scalar_t coord_z = ( ((scalar_t)idx_d) / (d-1) * 2 -1) / scale[0][0][2] - offset[0][0][2];\n\n  scalar_t J[12];\n\n  for(index_t i0 = 0; i0 < 3; i0++){\n    for(index_t i1 = 0; i1 < 4; i1++){\n      J[i0*4 + i1] = 0;\n      for(index_t j = 0; j < 24; j++){\n        J[i0*4 + i1] += voxel_w[0][j][idx_d][idx_h][idx_w]*tfs[idx_b][j][i0][i1];\n      }\n    }\n  }\n  for(index_t i0 = 0; i0 < 3; i0++){\n    for(index_t i1 = 0; i1 < 4; i1++){\n      voxel_J[idx_b][i0*4 + i1][idx_d][idx_h][idx_w] = J[i0*4 + i1];\n    }\n  }\n\n  for(index_t i0 = 0; i0 < 3; i0++){\n    scalar_t xi = J[i0*4 + 0]*coord_x + J[i0*4 + 1]*coord_y + J[i0*4 + 2]*coord_z + J[i0*4 + 3];\n    voxel_d[idx_b][i0][idx_d][idx_h][idx_w] = xi;\n  }\n}\n\nvoid launch_precompute(\n                           const Tensor &voxel_w,\n                           const Tensor &tfs,\n                           Tensor &voxel_d,\n                           Tensor &voxel_J,\n                           const Tensor &offset,\n                           const Tensor &scale\n                            ) {\n\n  // calculate #threads required\n  int64_t n_batch = voxel_d.size(0);\n  int64_t d = voxel_d.size(2);\n  int64_t h = voxel_d.size(3);\n  int64_t w = voxel_d.size(4);\n\n  int64_t count = n_batch*d*h*w;\n\n\n  if (count > 0) {\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(voxel_w.scalar_type(), \"precompute\", [&] {\n      precompute_kernel\n        <<<GET_BLOCKS(count, 512), 512, 0,\n        at::cuda::getCurrentCUDAStream()>>>(static_cast<int>(count),\n                                            
static_cast<int>(d),\n                                            static_cast<int>(h),\n                                            static_cast<int>(w),\n                                            voxel_w.packed_accessor32<scalar_t, 5>(),\n                                            tfs.packed_accessor32<scalar_t, 4>(),\n                                            voxel_d.packed_accessor32<scalar_t, 5>(),\n                                            voxel_J.packed_accessor32<scalar_t, 5>(),\n                                            offset.packed_accessor32<scalar_t, 3>(),\n                                            scale.packed_accessor32<scalar_t, 3>()\n                                            );\n\n      C10_CUDA_KERNEL_LAUNCH_CHECK();\n\n    });\n  }\n\n  cudaDeviceSynchronize();\n}\n"
  },
  {
    "path": "lib/models/deformers/fast_snarf/lib/model/deformer_smpl.py",
    "content": "import torch\nfrom torch import einsum\nimport torch.nn.functional as F\nimport os\n\nfrom torch.utils.cpp_extension import load\n\nimport fuse_cuda \nimport filter_cuda\nimport precompute_cuda\nimport numpy as np\n\n\nclass ForwardDeformer(torch.nn.Module):\n    \"\"\"\n    Tensor shape abbreviation:\n        B: batch size\n        N: number of points\n        J: number of bones\n        I: number of init\n        D: space dimension\n    \"\"\"\n\n    def __init__(self,  **kwargs):\n        super().__init__()\n\n        self.soft_blend = 20\n\n        self.init_bones = [0, 1, 2, 4, 5, 12, 15, 16, 17, 18, 19]\n        \n        self.init_bones_cuda = torch.tensor(self.init_bones).int()\n        \n        self.global_scale = 1.2\n        \n\n    def forward(self, xd, cond, mask, tfs, eval_mode=False):\n        \"\"\"Given deformed point return its caonical correspondence\n        Args:\n            xd (tensor): deformed points in batch. shape: [B, N, D]\n            cond (dict): conditional input.\n            tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1]\n        Returns:\n            xc (tensor): canonical correspondences. shape: [B, N, I, D]\n            others (dict): other useful outputs.\n        \"\"\"\n\n        xc_opt, others = self.search(xd, cond, mask, tfs, eval_mode=True)\n\n\n        if eval_mode:\n            return xc_opt, others\n\n\n    def precompute(self, tfs):\n\n        b, c, d, h, w = tfs.shape[0], 3, self.resolution//4, self.resolution, self.resolution\n        voxel_d = torch.zeros((b,3,d,h,w), device=tfs.device)\n        voxel_J = torch.zeros((b,12,d,h,w), device=tfs.device)\n        precompute_cuda.precompute(self.lbs_voxel_final, tfs, voxel_d, voxel_J, self.offset_kernel, self.scale_kernel)\n        self.voxel_d = voxel_d\n        self.voxel_J = voxel_J\n\n    def search(self, xd, cond, mask, tfs, eval_mode=False):\n        \"\"\"Search correspondences.\n\n        Args:\n            xd (tensor): deformed points in batch. shape: [B, N, D]\n            xc_init (tensor): deformed points in batch. shape: [B, N, I, D]\n            cond (dict): conditional input.\n            tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1]\n\n        Returns:\n            xc_opt (tensor): canonoical correspondences of xd. shape: [B, N, I, D]\n            valid_ids (tensor): identifiers of converged points. 
[B, N, I]\n        \"\"\"\n\n        # reshape to [B,?,D] for other functions\n\n        # run broyden without grad\n        with torch.no_grad():\n            result = self.broyden_cuda(xd, self.voxel_d, self.voxel_J, tfs, mask)\n\n        return result['result'], result\n\n    def broyden_cuda(self,\n                    xd_tgt,\n                    voxel,\n                    voxel_J_inv,\n                    tfs,\n                    mask,\n                    cvg_thresh=2e-4,\n                    dvg_thresh=1):\n        \"\"\"\n        Args:\n            g:     f: (N, 3, 1) -> (N, 3, 1)\n            x:     (N, 3, 1)\n            J_inv: (N, 3, 3)\n        \"\"\"\n        b,n,_ = xd_tgt.shape\n        n_init = self.init_bones_cuda.shape[0]\n\n        xc_init_IN = torch.zeros((b,n,n_init,3),device=xd_tgt.device,dtype=torch.float)\n\n        is_valid = mask.expand(b,n,n_init).clone()\n\n        if self.init_bones_cuda.device != xd_tgt.device:\n            self.init_bones_cuda = self.init_bones_cuda.to(xd_tgt.device)\n        fuse_cuda.fuse_broyden(xc_init_IN, xd_tgt, voxel, voxel_J_inv, tfs, self.init_bones_cuda, True, is_valid, self.offset_kernel, self.scale_kernel, cvg_thresh, dvg_thresh)\n\n        is_valid_new = torch.zeros_like(is_valid)\n        filter_cuda.filter(xc_init_IN, is_valid, is_valid_new)\n\n        return {\"result\": xc_init_IN, 'valid_ids': is_valid_new} #, 'J_inv': J_inv_init_IN}\n\n\n    def forward_skinning(self, xc, cond, tfs, mask=None):\n        \"\"\"Canonical point -> deformed point\n\n        Args:\n            xc (tensor): canonoical points in batch. shape: [B, N, D]\n            cond (dict): conditional input.\n            tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1]\n\n        Returns:\n            xd (tensor): deformed point. shape: [B, N, D]\n        \"\"\"\n\n        w = self.query_weights(xc, cond, mask=mask)\n        b,n,_ = xc.shape\n        xd = torch.zeros((b,n,3), device=xc.device, dtype=torch.float)\n        w_tf = torch.eye(4, device=xc.device, dtype=torch.float).reshape(1, 1, 4, 4).expand(b, n, -1, -1).clone()\n        xd[mask, :], w_tf[mask, :] = skinning_mask(xc[mask,:], w[mask,:], tfs, inverse=False)\n        return xd, w_tf\n\n    def switch_to_explicit(self,resolution=32,smpl_verts=None, smpl_faces=None, smpl_weights=None, use_smpl=False):\n        \n        self.resolution = resolution\n        # convert to voxel grid\n        # device = self.device\n    \n        b, c, d, h, w = 1, 24, resolution//4, resolution, resolution\n        self.ratio = h/d\n        grid = create_voxel_grid(d, h, w)\n        device = grid.device\n\n        gt_bbox = torch.cat([smpl_verts.min(dim=1).values, smpl_verts.max(dim=1).values], dim=0).to(device)\n        \n        offset = (gt_bbox[0] + gt_bbox[1])[None,None,:] * 0.5\n        scale = (gt_bbox[1] - gt_bbox[0]).max()/2 * self.global_scale\n\n        self.register_buffer('scale', scale)\n        self.register_buffer('offset', offset)\n\n        self.register_buffer('offset_kernel', -self.offset)\n        scale_kernel = torch.zeros_like(self.offset)\n        scale_kernel[...] 
= 1./self.scale\n        scale_kernel[:,:,-1] = scale_kernel[:,:,-1] * self.ratio\n        self.register_buffer('scale_kernel', scale_kernel)\n        \n        def normalize(x):\n            x_normalized = (x+self.offset_kernel)*self.scale_kernel\n            return x_normalized\n\n        def denormalize(x):\n            x_denormalized = x.clone() #/self.global_scale\n            x_denormalized[..., -1] = x_denormalized[..., -1]/self.ratio\n            x_denormalized *= self.scale\n            x_denormalized += self.offset\n\n            return x_denormalized\n\n        self.normalize = normalize\n        self.denormalize = denormalize\n\n        grid_denorm = self.denormalize(grid)\n\n        weights = query_weights_smpl(grid_denorm, smpl_verts=smpl_verts.detach().clone(), smpl_weights=smpl_weights.detach().clone()).detach().clone()\n\n        self.register_buffer('lbs_voxel_final', weights.detach())\n        self.register_buffer('grid_denorm',grid_denorm)\n\n        def query_weights( xc, cond=None, mask=None, mode='bilinear'):\n            w = F.grid_sample(self.lbs_voxel_final.expand(xc.shape[0],-1,-1,-1,-1), self.normalize(xc).unsqueeze(2).unsqueeze(2),align_corners=True, mode=mode,padding_mode='border')\n            w = w.squeeze(-1).squeeze(-1).permute(0,2,1)\n            return w\n    \n        self.query_weights = query_weights\n\n    def update_lbs_voxel(self):\n        self.lbs_voxel_final = F.softmax( self.lbs_voxel*20,dim=1)\n        def query_weights( xc, cond=None, mask=None, mode='bilinear'):\n            w = F.grid_sample(self.lbs_voxel_final.expand(xc.shape[0],-1,-1,-1,-1), self.normalize(xc).unsqueeze(2).unsqueeze(2),align_corners=True, mode=mode,padding_mode='border')\n            w = w.squeeze(-1).squeeze(-1).permute(0,2,1)\n            return w\n\n        self.query_weights = query_weights\n\n\n    def query_sdf_smpl(self, x, smpl_verts, smpl_faces, smpl_weights):\n        \n        device = x.device\n\n        resolution=128\n        b, c, d, h, w = 1, 24, resolution//4, resolution, resolution\n        grid = create_voxel_grid(d, h, w, device)\n        grid = self.denormalize(grid)\n\n        import trimesh\n        mesh = trimesh.Trimesh(vertices=smpl_verts.data.cpu().numpy()[0], faces=smpl_faces.data.cpu().numpy())\n        BVH = cubvh.cuBVH(mesh.vertices, mesh.faces)\n    \n        sdf, face_id, uvw = BVH.signed_distance(grid, return_uvw=True, mode='watertight') # [N], [N], [N, 3]\n\n        sdf = sdf.reshape(1, -1, 1)\n        b, c, d, h, w = 1, 1, resolution//4, resolution, resolution\n\n        sdf = -sdf.permute(0,2,1).reshape(b,c,d,h,w)\n\n        return sdf.detach()\n\n    def skinning_normal(self, xc, normal, tfs, cond=None, mask=None, inverse=False):\n        ''' skinning normals\n        \n        Args:\n            x (tensor): canonical points. shape: [B, N, D]\n            normal (tensor): canonical normals. shape: [B, N, D]\n            tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1]\n        Returns:\n            posed normal (tensor): posed normals. 
shape: [B, N, D]\n            \n        '''\n        if xc.ndim == 2:\n            xc = xc.unsqueeze(0)\n        if normal.ndim == 2:\n            normal = normal.unsqueeze(0)\n        w = self.query_weights(xc, cond, mask=mask)\n        p_h = F.pad(normal, (0, 1), value=0)\n        p_h = torch.einsum('bpn, bnij, bpj->bpi', w, tfs, p_h)\n\n        return p_h[:, :, :3]\n    \ndef skinning_mask(x, w, tfs, inverse=False):\n    \"\"\"Linear blend skinning\n\n    Args:\n        x (tensor): canonical points. shape: [B, N, D]\n        w (tensor): conditional input. [B, N, J]\n        tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1]\n    Returns:\n        x (tensor): skinned points. shape: [B, N, D]\n    \"\"\"\n    x_h = F.pad(x, (0, 1), value=1.0)\n    p,n = w.shape\n\n    if inverse:\n        # p:n_point, n:n_bone, i,k: n_dim+1\n\n        w_tf = einsum(\"bpn,bnij->bpij\", w, tfs)\n\n        x_h = x_h.view(b,p,1,4).expand(b,p,4,4)\n        x_h = (fast_inverse(w_tf)*x_h).sum(-1)\n\n    else:\n        w_tf = einsum(\"pn,nij->pij\", w, tfs.squeeze(0))\n\n        x_h = x_h.view(p,1,4).expand(p,4,4)\n        x_h = (w_tf*x_h).sum(-1)\n\n    return x_h[:, :3], w_tf\n\ndef skinning(x, w, tfs, inverse=False):\n    \"\"\"Linear blend skinning\n\n    Args:\n        x (tensor): canonical points. shape: [B, N, D]\n        w (tensor): conditional input. [B, N, J]\n        tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1]\n    Returns:\n        x (tensor): skinned points. shape: [B, N, D]\n    \"\"\"\n    x_h = F.pad(x, (0, 1), value=1.0)\n    b,p,n = w.shape\n\n    if inverse:\n        # p:n_point, n:n_bone, i,k: n_dim+1\n        w_tf = einsum(\"bpn,bnij->bpij\", w, fast_inverse(tfs))\n\n        x_h = x_h.view(b,p,1,4).expand(b,p,4,4)\n        # x_h = (fast_inverse(w_tf)*x_h).sum(-1)\n        x_h = (w_tf*x_h).sum(-1)\n\n    else:\n        w_tf = einsum(\"bpn,bnij->bpij\", w, tfs)\n\n        x_h = x_h.view(b,p,1,4).expand(b,p,4,4)\n        x_h = (w_tf*x_h).sum(-1)\n\n    return x_h[:, :, :3]\n\ndef fast_inverse(T):\n\n    shape = T.shape\n\n    T = T.reshape(-1,4,4)\n    R = T[:, :3,:3]\n    t = T[:, :3,3].unsqueeze(-1)\n\n    R_inv = R.transpose(1,2)\n    t_inv = -bmv(R_inv,t)\n\n    T_inv = T\n    T_inv[:,:3,:3] = R_inv\n    T_inv[:,:3,3] = t_inv.squeeze(-1)\n    \n    return T_inv.reshape(shape)\n\ndef bmv(m, v):\n    return (m*v.transpose(-1,-2).expand(-1,3,-1)).sum(-1,keepdim=True)\n\ndef create_voxel_grid(d, h, w, device='cuda'):\n    x_range = (torch.linspace(-1,1,steps=w,device=device)).view(1, 1, 1, w).expand(1, d, h, w)  # [1, H, W, D]\n    y_range = (torch.linspace(-1,1,steps=h,device=device)).view(1, 1, h, 1).expand(1, d, h, w)  # [1, H, W, D]\n    z_range = (torch.linspace(-1,1,steps=d,device=device)).view(1, d, 1, 1).expand(1, d, h, w)  # [1, H, W, D]\n    grid = torch.cat((x_range, y_range, z_range), dim=0).reshape(1, 3,-1).permute(0,2,1)\n\n    return grid\n\n\ndef query_weights_smpl(x, smpl_verts, smpl_weights):\n    import pytorch3d.ops as ops\n\n    device = smpl_weights.device\n    distance_batch, index_batch, neighbor_points  = ops.knn_points(x.to(device),smpl_verts.to(device).detach(),K=10,return_nn=True)\n\n    # neighbor_points = neighbor_points[0]\n    distance_batch = distance_batch[0].sqrt().clamp_(0.00003,0.1)\n    index_batch = index_batch[0]\n    \n    # GPU_id = index_batch.get_device()\n    # print(GPU_id)\n    weights = smpl_weights[0,index_batch]\n\n   \n    ws=1./distance_batch\n    ws=ws/ws.sum(-1,keepdim=True)\n    weights = 
(ws[:,:,None]*weights).sum(1)[None]\n\n    # NOTE: this hard-coded resolution must match the grid resolution passed to switch_to_explicit,\n    # otherwise the reshape below will fail\n    resolution = 64\n\n    b, c, d, h, w = 1, 24, resolution//4, resolution, resolution\n    weights = weights.permute(0,2,1).reshape(b,c,d,h,w)\n\n    return weights.detach()"
  },
  {
    "path": "lib/models/deformers/fast_snarf/lib/model/deformer_smplx.py",
    "content": "import torch\nfrom torch import einsum\nimport torch.nn.functional as F\nimport os\n\nfrom torch.utils.cpp_extension import load\n\nimport fuse_cuda \nimport filter_cuda\nimport precompute_cuda\nimport numpy as np\n\n\nclass ForwardDeformer(torch.nn.Module):\n    \"\"\"\n    Tensor shape abbreviation:\n        B: batch size\n        N: number of points\n        J: number of bones\n        I: number of init\n        D: space dimension\n    \"\"\"\n\n    def __init__(self,  **kwargs):\n        super().__init__()\n\n        self.soft_blend = 20\n\n        self.init_bones = [0, 1, 2, 4, 5, 12, 15, 16, 17, 18, 19]\n        \n        self.init_bones_cuda = torch.tensor(self.init_bones).int()\n        \n        self.global_scale = 1.2\n\n    def forward_skinning(self, xc, shape_offset, pose_offset, cond, tfs, tfs_inv, poseoff_ori, lbsw=None, mask=None, pts_query_lbs=None):\n        \"\"\"Canonical point -> deformed point\n        B == num_scenes\n        Args:\n            xc (tensor): canonoical points in batch. shape: [B, N, D]\n            cond (dict): conditional input.\n            tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1] template -> posed\n            tfs_inv (tensor): inverse bone transformation matrices. shape: [B, J, D+1, D+1] T-posed -> template\n            pts_query_lbs (tensor):  canonoical points in batch. shape: [B, N, D] , for Lbs weights\n\n        Returns:\n            xd (tensor): deformed point. shape: [B, N, D]\n        \n        #\n        \"\"\"\n        if pts_query_lbs==None:\n            w = self.query_weights(xc, cond)\n        else: \n            w = self.query_weights(pts_query_lbs, cond)\n        \n        w[:, mask[0]] = lbsw[mask]\n        \n        '''\n        # saving the points xc[1, 50200, 3] with the colored mask [1, 50200] in a ply\n        import trimesh\n        mesh = trimesh.Trimesh(vertices=xc[0].detach().cpu().numpy())\n        mesh.visual.vertex_colors = mask[0].cpu().numpy().reshape(-1,1) * np.array([[255,0,0,255]])\n        mesh.export(\"./xc_mask.ply\")\n        '''\n        b,n,_ = xc.shape\n        xc_cano, w_tf_inv = skinning(xc, w, tfs_inv.expand(b, -1, -1, -1), inverse=False) # T pose space -> template space\n        xc_cano_ori = xc_cano - poseoff_ori.expand(b, -1, -1)\n\n        xc_shape = xc_cano_ori + shape_offset + pose_offset   # template space, use the given points in the template space to forwarding\n        # b,n,_ = xc_shape.shape\n        x_deform, w_tf = skinning(xc_shape, w, tfs, inverse=False)\n        w_tf_all = w_tf @ w_tf_inv.expand(b, -1, -1, -1) # from T-pose to posed space\n\n        \n        return x_deform, w_tf_all\n    \n    \n\n    def switch_to_explicit(self,resolution=32,smpl_verts=None, smpl_faces=None, smpl_weights=None, use_smpl=False):\n        \n        self.resolution = resolution\n        # convert to voxel grid\n    \n        b, c, d, h, w = 1, 55, resolution//4, resolution, resolution\n        \n        self.ratio = h/d\n        grid = create_voxel_grid(d, h, w)\n        device = grid.device\n\n        gt_bbox = torch.cat([smpl_verts.min(dim=1).values, smpl_verts.max(dim=1).values], dim=0).to(device)\n        \n        offset = (gt_bbox[0] + gt_bbox[1])[None,None,:] * 0.5\n        scale = (gt_bbox[1] - gt_bbox[0]).max()/2 * self.global_scale\n\n        self.register_buffer('scale', scale)\n        self.register_buffer('offset', offset)\n\n        self.register_buffer('offset_kernel', -self.offset)\n        scale_kernel = torch.zeros_like(self.offset)\n        
scale_kernel[...] = 1./self.scale\n        scale_kernel[:,:,-1] = scale_kernel[:,:,-1] * self.ratio\n        self.register_buffer('scale_kernel', scale_kernel)\n        \n        def normalize(x):\n            x_normalized = (x+self.offset_kernel)*self.scale_kernel\n            return x_normalized\n\n        def denormalize(x):\n            x_denormalized = x.clone() \n            x_denormalized[..., -1] = x_denormalized[..., -1]/self.ratio\n            x_denormalized *= self.scale\n            x_denormalized += self.offset\n\n            return x_denormalized\n\n        self.normalize = normalize\n        self.denormalize = denormalize\n\n        grid_denorm = self.denormalize(grid)\n\n        weights = query_weights_smpl(grid_denorm, smpl_verts=smpl_verts.detach().clone(), smpl_weights=smpl_weights.detach().clone()).detach().clone()\n\n        self.register_buffer('lbs_voxel_final', weights.detach())\n        self.register_buffer('grid_denorm',grid_denorm)\n\n        def query_weights( xc, cond=None, mask=None, mode='bilinear'):\n            w = F.grid_sample(self.lbs_voxel_final.expand(xc.shape[0],-1,-1,-1,-1), self.normalize(xc).unsqueeze(2).unsqueeze(2),align_corners=True, mode=mode,padding_mode='border')\n            w = w.squeeze(-1).squeeze(-1).permute(0,2,1)\n            return w\n    \n        self.query_weights = query_weights\n\n    def update_lbs_voxel(self):\n        self.lbs_voxel_final = F.softmax( self.lbs_voxel*20,dim=1)\n        def query_weights( xc, cond=None, mask=None, mode='bilinear'):\n            w = F.grid_sample(self.lbs_voxel_final.expand(xc.shape[0],-1,-1,-1,-1), self.normalize(xc).unsqueeze(2).unsqueeze(2),align_corners=True, mode=mode,padding_mode='border')\n            w = w.squeeze(-1).squeeze(-1).permute(0,2,1)\n            return w\n\n        self.query_weights = query_weights\n\n\n    def query_sdf_smpl(self, x, smpl_verts, smpl_faces, smpl_weights):\n        \n        device = x.device\n\n        resolution=128\n        b, c, d, h, w = 1, 24, resolution//4, resolution, resolution\n        grid = create_voxel_grid(d, h, w, device)\n        grid = self.denormalize(grid)\n\n        import trimesh\n        mesh = trimesh.Trimesh(vertices=smpl_verts.data.cpu().numpy()[0], faces=smpl_faces.data.cpu().numpy())\n        BVH = cubvh.cuBVH(mesh.vertices, mesh.faces)\n    \n        sdf, face_id, uvw = BVH.signed_distance(grid, return_uvw=True, mode='watertight') # [N], [N], [N, 3]\n\n        sdf = sdf.reshape(1, -1, 1)\n        b, c, d, h, w = 1, 1, resolution//4, resolution, resolution\n\n        sdf = -sdf.permute(0,2,1).reshape(b,c,d,h,w)\n\n        return sdf.detach()\n\n    def skinning_normal(self, xc, normal, tfs, cond=None, mask=None, inverse=False):\n        ''' skinning normals\n        \n        Args:\n            x (tensor): canonical points. shape: [B, N, D]\n            normal (tensor): canonical normals. shape: [B, N, D]\n            tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1]\n        Returns:\n            posed normal (tensor): posed normals. 
shape: [B, N, D]\n            \n        '''\n        if xc.ndim == 2:\n            xc = xc.unsqueeze(0)\n        if normal.ndim == 2:\n            normal = normal.unsqueeze(0)\n        w = self.query_weights(xc, cond, mask=mask)\n        p_h = F.pad(normal, (0, 1), value=0)\n        p_h = torch.einsum('bpn, bnij, bpj->bpi', w, tfs, p_h)\n\n        return p_h[:, :, :3]\n    \ndef skinning_mask(x, w, tfs, inverse=False):\n    \"\"\"Linear blend skinning\n\n    Args:\n        x (tensor): canonical points. shape: [B, N, D]\n        w (tensor): conditional input. [B, N, J]\n        tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1]\n    Returns:\n        x (tensor): skinned points. shape: [B, N, D]\n    \"\"\"\n    x_h = F.pad(x, (0, 1), value=1.0)\n    p,n = w.shape\n\n    if inverse:\n        # p:n_point, n:n_bone, i,k: n_dim+1\n\n        w_tf = einsum(\"bpn,bnij->bpij\", w, tfs)\n\n        x_h = x_h.view(b,p,1,4).expand(b,p,4,4)\n        x_h = (fast_inverse(w_tf)*x_h).sum(-1)\n\n    else:\n        w_tf = einsum(\"pn,nij->pij\", w, tfs.squeeze(0))\n\n        x_h = x_h.view(p,1,4).expand(p,4,4)\n        x_h = (w_tf*x_h).sum(-1)\n\n    return x_h[:, :3], w_tf\n\ndef skinning(x, w, tfs, inverse=False):\n    \"\"\"Linear blend skinning\n\n    Args:\n        x (tensor): canonical points. shape: [B, N, D]\n        w (tensor): conditional input. [B, N, J]\n        tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1]\n    Returns:\n        x (tensor): skinned points. shape: [B, N, D]\n    \"\"\"\n    x_h = F.pad(x, (0, 1), value=1.0)\n    b,p,n = w.shape\n\n    if inverse:\n        # p:n_point, n:n_bone, i,k: n_dim+1\n        w_tf = einsum(\"bpn,bnij->bpij\", w, fast_inverse(tfs))\n\n        x_h = x_h.view(b,p,1,4).expand(b,p,4,4)\n        # x_h = (fast_inverse(w_tf)*x_h).sum(-1)\n        x_h = (w_tf*x_h).sum(-1)\n\n    else:\n        w_tf = einsum(\"bpn,bnij->bpij\", w, tfs)\n\n        x_h = x_h.view(b,p,1,4).expand(b,p,4,4)\n        x_h = (w_tf*x_h).sum(-1)\n\n    return x_h[:, :, :3], w_tf\n\ndef fast_inverse(T):\n\n    shape = T.shape\n\n    T = T.reshape(-1,4,4)\n    R = T[:, :3,:3]\n    t = T[:, :3,3].unsqueeze(-1)\n\n    R_inv = R.transpose(1,2)\n    t_inv = -bmv(R_inv,t)\n\n    T_inv = T\n    T_inv[:,:3,:3] = R_inv\n    T_inv[:,:3,3] = t_inv.squeeze(-1)\n    \n    return T_inv.reshape(shape)\n\ndef bmv(m, v):\n    return (m*v.transpose(-1,-2).expand(-1,3,-1)).sum(-1,keepdim=True)\n\n\ndef create_voxel_grid(d, h, w, device='cuda'):\n    x_range = (torch.linspace(-1,1,steps=w,device=device)).view(1, 1, 1, w).expand(1, d, h, w)  # [1, H, W, D]\n    y_range = (torch.linspace(-1,1,steps=h,device=device)).view(1, 1, h, 1).expand(1, d, h, w)  # [1, H, W, D]\n    z_range = (torch.linspace(-1,1,steps=d,device=device)).view(1, d, 1, 1).expand(1, d, h, w)  # [1, H, W, D]\n    grid = torch.cat((x_range, y_range, z_range), dim=0).reshape(1, 3,-1).permute(0,2,1)\n\n    return grid\n\n\ndef query_weights_smpl(x, smpl_verts, smpl_weights):\n    import pytorch3d.ops as ops\n\n    device = smpl_weights.device\n    distance_batch, index_batch, neighbor_points  = ops.knn_points(x.to(device),smpl_verts.to(device).detach(),K=10,return_nn=True)\n\n    # neighbor_points = neighbor_points[0]\n    distance_batch = distance_batch[0].sqrt().clamp_(0.00003,0.1)\n    index_batch = index_batch[0]\n    \n    weights = smpl_weights[0,index_batch]\n   \n    ws=1./distance_batch\n    ws=ws/ws.sum(-1,keepdim=True)\n    weights = (ws[:,:,None]*weights).sum(1)[None]\n\n    resolution = 64\n\n    b, 
c, d, h, w = 1, 55, resolution//4, resolution, resolution\n    weights = weights.permute(0,2,1).reshape(b,c,d,h,w)\n\n    return weights.detach()"
  },
  {
    "path": "lib/models/deformers/smplx/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n\n# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is\n# holder of all proprietary rights on this computer program.\n# You can only use this computer program if you have closed\n# a license agreement with MPG or you get the right to use the computer\n# program from someone who is authorized to grant you that right.\n# Any use of the computer program without a valid license is prohibited and\n# liable to prosecution.\n#\n# Copyright©2019 Max-Planck-Gesellschaft zur Förderung\n# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute\n# for Intelligent Systems. All rights reserved.\n#\n# Contact: ps-license@tuebingen.mpg.de\n\nfrom .body_models import (\n    SMPL, SMPLX\n)\n"
  },
  {
    "path": "lib/models/deformers/smplx/body_models.py",
    "content": "#  -*- coding: utf-8 -*-\n\n# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is\n# holder of all proprietary rights on this computer program.\n# You can only use this computer program if you have closed\n# a license agreement with MPG or you get the right to use the computer\n# program from someone who is authorized to grant you that right.\n# Any use of the computer program without a valid license is prohibited and\n# liable to prosecution.\n#\n# Copyright©2019 Max-Planck-Gesellschaft zur Förderung\n# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute\n# for Intelligent Systems. All rights reserved.\n#\n# Contact: ps-license@tuebingen.mpg.de\n\nfrom typing import Optional, Dict, Union\nimport os\nimport os.path as osp\n\nimport pickle\n\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\n\nfrom .lbs import (\n    lbs, blend_shapes, vertices2landmarks)\n\nfrom .vertex_ids import vertex_ids as VERTEX_IDS\nfrom .utils import (\n    Struct, to_np, to_tensor, Tensor, Array,\n    SMPLOutput, SMPLHOutput,\n    SMPLXOutput)\nfrom .vertex_joint_selector import VertexJointSelector\n\n\nclass SMPL(nn.Module):\n    NUM_JOINTS = 23\n    NUM_BODY_JOINTS = 23\n    SHAPE_SPACE_DIM = 300\n\n    def __init__(\n        self, model_path: str,\n        kid_template_path: str = '',\n        data_struct: Optional[Struct] = None,\n        create_betas: bool = True,\n        betas: Optional[Tensor] = None,\n        num_betas: int = 10,\n        create_global_orient: bool = True,\n        global_orient: Optional[Tensor] = None,\n        create_body_pose: bool = True,\n        body_pose: Optional[Tensor] = None,\n        create_transl: bool = True,\n        transl: Optional[Tensor] = None,\n        dtype=torch.float32,\n        batch_size: int = 1,\n        joint_mapper=None,\n        gender: str = 'neutral',\n        age: str = 'adult',\n        vertex_ids: Dict[str, int] = None,\n        v_template: Optional[Union[Tensor, Array]] = None,\n        **kwargs\n    ) -> None:\n        ''' SMPL model constructor\n\n            Parameters\n            ----------\n            model_path: str\n                The path to the folder or to the file where the model\n                parameters are stored\n            data_struct: Strct\n                A struct object. If given, then the parameters of the model are\n                read from the object. Otherwise, the model tries to read the\n                parameters from the given `model_path`. (default = None)\n            create_global_orient: bool, optional\n                Flag for creating a member variable for the global orientation\n                of the body. 
(default = True)\n            global_orient: torch.tensor, optional, Bx3\n                The default value for the global orientation variable.\n                (default = None)\n            create_body_pose: bool, optional\n                Flag for creating a member variable for the pose of the body.\n                (default = True)\n            body_pose: torch.tensor, optional, Bx(Body Joints * 3)\n                The default value for the body pose variable.\n                (default = None)\n            num_betas: int, optional\n                Number of shape components to use\n                (default = 10).\n            create_betas: bool, optional\n                Flag for creating a member variable for the shape space\n                (default = True).\n            betas: torch.tensor, optional, Bx10\n                The default value for the shape member variable.\n                (default = None)\n            create_transl: bool, optional\n                Flag for creating a member variable for the translation\n                of the body. (default = True)\n            transl: torch.tensor, optional, Bx3\n                The default value for the transl variable.\n                (default = None)\n            dtype: torch.dtype, optional\n                The data type for the created variables\n            batch_size: int, optional\n                The batch size used for creating the member variables\n            joint_mapper: object, optional\n                An object that re-maps the joints. Useful if one wants to\n                re-order the SMPL joints to some other convention (e.g. MSCOCO)\n                (default = None)\n            gender: str, optional\n                Which gender to load\n            vertex_ids: dict, optional\n                A dictionary containing the indices of the extra vertices that\n                will be selected\n        '''\n\n        self.gender = gender\n        self.age = age\n\n        if data_struct is None:\n            if osp.isdir(model_path):\n                model_fn = 'SMPL_{}.{ext}'.format(gender.upper(), ext='pkl')\n                smpl_path = os.path.join(model_path, model_fn)\n            else:\n                smpl_path = model_path\n            assert osp.exists(smpl_path), 'Path {} does not exist!'.format(\n                smpl_path)\n\n            with open(smpl_path, 'rb') as smpl_file:\n                data_struct = Struct(**pickle.load(smpl_file,\n                                                   encoding='latin1'))\n\n        super(SMPL, self).__init__()\n        self.batch_size = batch_size\n        shapedirs = data_struct.shapedirs\n        if (shapedirs.shape[-1] < self.SHAPE_SPACE_DIM):\n            print(f'WARNING: You are using a {self.name()} model, with only'\n                  ' 10 shape coefficients.')\n            num_betas = min(num_betas, 10)\n        else:\n            num_betas = min(num_betas, self.SHAPE_SPACE_DIM)\n\n        if self.age=='kid':\n            v_template_smil = np.load(kid_template_path)\n            v_template_smil -= np.mean(v_template_smil, axis=0)\n            v_template_diff = np.expand_dims(v_template_smil - data_struct.v_template, axis=2)\n            shapedirs = np.concatenate((shapedirs[:, :, :num_betas], v_template_diff), axis=2)\n            num_betas = num_betas + 1\n\n        self._num_betas = num_betas\n        shapedirs = shapedirs[:, :, :num_betas]\n        # The shape components\n        self.register_buffer(\n            'shapedirs',\n            
to_tensor(to_np(shapedirs), dtype=dtype))\n\n        if vertex_ids is None:\n            # SMPL and SMPL-H share the same topology, so any extra joints can\n            # be drawn from the same place\n            vertex_ids = VERTEX_IDS['smplh']\n\n        self.dtype = dtype\n\n        self.joint_mapper = joint_mapper\n\n        self.vertex_joint_selector = VertexJointSelector(\n            vertex_ids=vertex_ids, **kwargs)\n\n        self.faces = data_struct.f\n        self.register_buffer('faces_tensor',\n                             to_tensor(to_np(self.faces, dtype=np.int64),\n                                       dtype=torch.long))\n\n        if create_betas:\n            if betas is None:\n                default_betas = torch.zeros(\n                    [batch_size, self.num_betas], dtype=dtype)\n            else:\n                if torch.is_tensor(betas):\n                    default_betas = betas.clone().detach()\n                else:\n                    default_betas = torch.tensor(betas, dtype=dtype)\n\n            self.register_parameter(\n                'betas', nn.Parameter(default_betas, requires_grad=True))\n\n        # The tensor that contains the global rotation of the model\n        # It is separated from the pose of the joints in case we wish to\n        # optimize only over one of them\n        if create_global_orient:\n            if global_orient is None:\n                default_global_orient = torch.zeros(\n                    [batch_size, 3], dtype=dtype)\n            else:\n                if torch.is_tensor(global_orient):\n                    default_global_orient = global_orient.clone().detach()\n                else:\n                    default_global_orient = torch.tensor(\n                        global_orient, dtype=dtype)\n\n            global_orient = nn.Parameter(default_global_orient,\n                                         requires_grad=True)\n            self.register_parameter('global_orient', global_orient)\n\n        if create_body_pose:\n            if body_pose is None:\n                default_body_pose = torch.zeros(\n                    [batch_size, self.NUM_BODY_JOINTS * 3], dtype=dtype)\n            else:\n                if torch.is_tensor(body_pose):\n                    default_body_pose = body_pose.clone().detach()\n                else:\n                    default_body_pose = torch.tensor(body_pose,\n                                                     dtype=dtype)\n            self.register_parameter(\n                'body_pose',\n                nn.Parameter(default_body_pose, requires_grad=True))\n\n        if create_transl:\n            if transl is None:\n                default_transl = torch.zeros([batch_size, 3],\n                                             dtype=dtype,\n                                             requires_grad=True)\n            else:\n                default_transl = torch.tensor(transl, dtype=dtype)\n            self.register_parameter(\n                'transl', nn.Parameter(default_transl, requires_grad=True))\n\n                            \n                                       \n        if v_template is None:\n            v_template = data_struct.v_template\n        if not torch.is_tensor(v_template):\n            v_template = to_tensor(to_np(v_template), dtype=dtype)\n        # The vertices of the template model\n        self.register_buffer('v_template', v_template)\n\n        j_regressor = to_tensor(to_np(\n            data_struct.J_regressor), dtype=dtype)\n        
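# J_regressor maps shaped template vertices to joint locations, shape: [num_joints, num_vertices]\n        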
self.register_buffer('J_regressor', j_regressor)\n\n        # Pose blend shape basis: 6890 x 3 x 207, reshaped to 6890*3 x 207\n        num_pose_basis = data_struct.posedirs.shape[-1]\n        # 207 x 20670\n        posedirs = np.reshape(data_struct.posedirs, [-1, num_pose_basis]).T\n        self.register_buffer('posedirs',\n                             to_tensor(to_np(posedirs), dtype=dtype))\n\n        # indices of parents for each joints\n        parents = to_tensor(to_np(data_struct.kintree_table[0])).long()\n        parents[0] = -1\n        self.register_buffer('parents', parents)\n\n        lbs_weights = to_tensor(to_np(data_struct.weights), dtype=dtype)\n        self.register_buffer('lbs_weights', lbs_weights)\n\n    @property\n    def num_betas(self):\n        return self._num_betas\n\n    @property\n    def num_expression_coeffs(self):\n        return 0\n\n    def create_mean_pose(self, data_struct) -> Tensor:\n        pass\n\n    def name(self) -> str:\n        return 'SMPL'\n\n    @torch.no_grad()\n    def reset_params(self, **params_dict) -> None:\n        for param_name, param in self.named_parameters():\n            if param_name in params_dict:\n                param[:] = torch.tensor(params_dict[param_name])\n            else:\n                param.fill_(0)\n\n    def get_num_verts(self) -> int:\n        return self.v_template.shape[0]\n\n    def get_num_faces(self) -> int:\n        return self.faces.shape[0]\n\n    def extra_repr(self) -> str:\n        msg = [\n            f'Gender: {self.gender.upper()}',\n            f'Number of joints: {self.J_regressor.shape[0]}',\n            f'Betas: {self.num_betas}',\n        ]\n        return '\\n'.join(msg)\n\n    def forward_shape(\n        self,\n        betas: Optional[Tensor] = None,\n    ) -> SMPLOutput:\n        betas = betas if betas is not None else self.betas\n        v_shaped = self.v_template + blend_shapes(betas, self.shapedirs)\n        return SMPLOutput(vertices=v_shaped, betas=betas, v_shaped=v_shaped)\n\n    def forward(\n        self,\n        betas: Optional[Tensor] = None,\n        body_pose: Optional[Tensor] = None,\n        global_orient: Optional[Tensor] = None,\n        transl: Optional[Tensor] = None,\n        return_verts=True,\n        return_full_pose: bool = False,\n        pose2rot: bool = True,\n        scale: Optional[Tensor] = None,\n        **kwargs\n    ) -> SMPLOutput:\n        ''' Forward pass for the SMPL model\n\n            Parameters\n            ----------\n            global_orient: torch.tensor, optional, shape Bx3\n                If given, ignore the member variable and use it as the global\n                rotation of the body. Useful if someone wishes to predicts this\n                with an external model. (default=None)\n            betas: torch.tensor, optional, shape BxN_b\n                If given, ignore the member variable `betas` and use it\n                instead. For example, it can used if shape parameters\n                `betas` are predicted from some external model.\n                (default=None)\n            body_pose: torch.tensor, optional, shape Bx(J*3)\n                If given, ignore the member variable `body_pose` and use it\n                instead. For example, it can used if someone predicts the\n                pose of the body joints are predicted from some external model.\n                It should be a tensor that contains joint rotations in\n                axis-angle format. 
(default=None)\n            transl: torch.tensor, optional, shape Bx3\n                If given, ignore the member variable `transl` and use it\n                instead. For example, it can used if the translation\n                `transl` is predicted from some external model.\n                (default=None)\n            return_verts: bool, optional\n                Return the vertices. (default=True)\n            return_full_pose: bool, optional\n                Returns the full axis-angle pose vector (default=False)\n\n            Returns\n            -------\n        '''\n        # If no shape and pose parameters are passed along, then use the\n        # ones from the module\n        global_orient = (global_orient if global_orient is not None else\n                         self.global_orient)\n        body_pose = body_pose if body_pose is not None else self.body_pose\n        betas = betas if betas is not None else self.betas\n\n        apply_trans = transl is not None or hasattr(self, 'transl')\n        \n        if transl is None and hasattr(self, 'transl'):\n            transl = self.transl\n        \n        full_pose = torch.cat([global_orient, body_pose], dim=1)\n\n        \n        scale = scale if scale is not None else torch.ones([global_orient.shape[0], 1], dtype=global_orient.dtype, device = global_orient.device)\n        \n        batch_size = max(betas.shape[0], global_orient.shape[0],\n                         body_pose.shape[0])\n\n        if betas.shape[0] != batch_size:\n            num_repeats = int(batch_size / betas.shape[0])\n            betas = betas.expand(num_repeats, -1)\n\n        vertices, joints, A, T, shape_offset, pose_offset = lbs(\n            betas, full_pose, self.v_template, self.shapedirs, self.posedirs,\n            self.J_regressor, self.parents, self.lbs_weights, pose2rot=pose2rot)\n\n        joints = self.vertex_joint_selector(vertices, joints)\n        # Map the joints to the current dataset\n        if self.joint_mapper is not None:\n            joints = self.joint_mapper(joints)\n\n        if apply_trans:\n            joints += transl.unsqueeze(dim=1)\n            vertices += transl.unsqueeze(dim=1)\n            A[..., :3, 3] += transl.unsqueeze(dim=1)\n            T[..., :3, 3] += transl.unsqueeze(dim=1)\n\n\n        joints = joints * (scale.reshape(-1,1,1))\n        vertices = vertices * (scale.reshape(-1,1,1))\n\n        A[..., :3,:3] = A[..., :3,:3] * (scale.reshape(-1, 1,1,1))\n        T[..., :3,:3] = T[..., :3,:3] * (scale.reshape(-1,1,1,1))\n        \n        output = SMPLOutput(vertices=vertices if return_verts else None,\n                            global_orient=global_orient,\n                            body_pose=body_pose,\n                            joints=joints,\n                            betas=betas,\n                            full_pose=full_pose if return_full_pose else None,\n                            A=A,\n                            T=T,\n                            shape_offset=shape_offset,\n                            pose_offset=pose_offset)\n        return output\n\nclass SMPLH(SMPL):\n\n    # The hand joints are replaced by MANO\n    NUM_BODY_JOINTS = SMPL.NUM_JOINTS - 2\n    NUM_HAND_JOINTS = 15\n    NUM_JOINTS = NUM_BODY_JOINTS + 2 * NUM_HAND_JOINTS\n\n    def __init__(\n        self, model_path,\n        kid_template_path: str = '',\n        data_struct: Optional[Struct] = None,\n        create_left_hand_pose: bool = True,\n        left_hand_pose: Optional[Tensor] = None,\n        create_right_hand_pose: 
bool = True,\n        right_hand_pose: Optional[Tensor] = None,\n        use_pca: bool = True,\n        num_pca_comps: int = 6,\n        num_betas=16,\n        flat_hand_mean: bool = False,\n        batch_size: int = 1,\n        gender: str = 'neutral',\n        age: str = 'adult',\n        dtype=torch.float32,\n        vertex_ids=None,\n        use_compressed: bool = True,\n        ext: str = 'pkl',\n        **kwargs\n    ) -> None:\n        ''' SMPLH model constructor\n\n            Parameters\n            ----------\n            model_path: str\n                The path to the folder or to the file where the model\n                parameters are stored\n            data_struct: Strct\n                A struct object. If given, then the parameters of the model are\n                read from the object. Otherwise, the model tries to read the\n                parameters from the given `model_path`. (default = None)\n            create_left_hand_pose: bool, optional\n                Flag for creating a member variable for the pose of the left\n                hand. (default = True)\n            left_hand_pose: torch.tensor, optional, BxP\n                The default value for the left hand pose member variable.\n                (default = None)\n            create_right_hand_pose: bool, optional\n                Flag for creating a member variable for the pose of the right\n                hand. (default = True)\n            right_hand_pose: torch.tensor, optional, BxP\n                The default value for the right hand pose member variable.\n                (default = None)\n            num_pca_comps: int, optional\n                The number of PCA components to use for each hand.\n                (default = 6)\n            flat_hand_mean: bool, optional\n                If False, then the pose of the hand is initialized to False.\n            batch_size: int, optional\n                The batch size used for creating the member variables\n            gender: str, optional\n                Which gender to load\n            dtype: torch.dtype, optional\n                The data type for the created variables\n            vertex_ids: dict, optional\n                A dictionary containing the indices of the extra vertices that\n                will be selected\n        '''\n\n        self.num_pca_comps = num_pca_comps\n        # If no data structure is passed, then load the data from the given\n        # model folder\n        if data_struct is None:\n            # Load the model\n            if osp.isdir(model_path):\n                model_fn = 'SMPLH_{}.{ext}'.format(gender.upper(), ext=ext)\n                smplh_path = os.path.join(model_path, model_fn)\n            else:\n                smplh_path = model_path\n            assert osp.exists(smplh_path), 'Path {} does not exist!'.format(\n                smplh_path)\n\n            if ext == 'pkl':\n                with open(smplh_path, 'rb') as smplh_file:\n                    model_data = pickle.load(smplh_file, encoding='latin1')\n            elif ext == 'npz':\n                model_data = np.load(smplh_path, allow_pickle=True)\n            else:\n                raise ValueError('Unknown extension: {}'.format(ext))\n            data_struct = Struct(**model_data)\n\n        if vertex_ids is None:\n            vertex_ids = VERTEX_IDS['smplh']\n\n        super(SMPLH, self).__init__(\n            model_path=model_path,\n            kid_template_path=kid_template_path,\n            data_struct=data_struct,\n            
num_betas=num_betas,\n            batch_size=batch_size, vertex_ids=vertex_ids, gender=gender, age=age,\n            use_compressed=use_compressed, dtype=dtype, ext=ext, **kwargs)\n\n        self.use_pca = use_pca\n        self.num_pca_comps = num_pca_comps\n        self.flat_hand_mean = flat_hand_mean\n\n        left_hand_components = data_struct.hands_componentsl[:num_pca_comps]\n        right_hand_components = data_struct.hands_componentsr[:num_pca_comps]\n\n        self.np_left_hand_components = left_hand_components\n        self.np_right_hand_components = right_hand_components\n        if self.use_pca:\n            self.register_buffer(\n                'left_hand_components',\n                torch.tensor(left_hand_components, dtype=dtype))\n            self.register_buffer(\n                'right_hand_components',\n                torch.tensor(right_hand_components, dtype=dtype))\n\n        if self.flat_hand_mean:\n            left_hand_mean = np.zeros_like(data_struct.hands_meanl)\n        else:\n            left_hand_mean = data_struct.hands_meanl\n\n        if self.flat_hand_mean:\n            right_hand_mean = np.zeros_like(data_struct.hands_meanr)\n        else:\n            right_hand_mean = data_struct.hands_meanr\n\n        self.register_buffer('left_hand_mean',\n                             to_tensor(left_hand_mean, dtype=self.dtype))\n        self.register_buffer('right_hand_mean',\n                             to_tensor(right_hand_mean, dtype=self.dtype))\n\n        # Create the buffers for the pose of the left hand\n        hand_pose_dim = num_pca_comps if use_pca else 3 * self.NUM_HAND_JOINTS\n        if create_left_hand_pose:\n            if left_hand_pose is None:\n                default_lhand_pose = torch.zeros([batch_size, hand_pose_dim],\n                                                 dtype=dtype)\n            else:\n                default_lhand_pose = torch.tensor(left_hand_pose, dtype=dtype)\n\n            left_hand_pose_param = nn.Parameter(default_lhand_pose,\n                                                requires_grad=True)\n            self.register_parameter('left_hand_pose',\n                                    left_hand_pose_param)\n\n        if create_right_hand_pose:\n            if right_hand_pose is None:\n                default_rhand_pose = torch.zeros([batch_size, hand_pose_dim],\n                                                 dtype=dtype)\n            else:\n                default_rhand_pose = torch.tensor(right_hand_pose, dtype=dtype)\n\n            right_hand_pose_param = nn.Parameter(default_rhand_pose,\n                                                 requires_grad=True)\n            self.register_parameter('right_hand_pose',\n                                    right_hand_pose_param)\n\n        # Create the buffer for the mean pose.\n        pose_mean_tensor = self.create_mean_pose(\n            data_struct, flat_hand_mean=flat_hand_mean)\n        if not torch.is_tensor(pose_mean_tensor):\n            pose_mean_tensor = torch.tensor(pose_mean_tensor, dtype=dtype)\n        self.register_buffer('pose_mean', pose_mean_tensor)\n\n    def create_mean_pose(self, data_struct, flat_hand_mean=False):\n        # Create the array for the mean pose. 
If flat_hand is false, then use\n        # the mean that is given by the data, rather than the flat open hand\n        global_orient_mean = torch.zeros([3], dtype=self.dtype)\n        body_pose_mean = torch.zeros([self.NUM_BODY_JOINTS * 3],\n                                     dtype=self.dtype)\n\n        pose_mean = torch.cat([global_orient_mean, body_pose_mean,\n                               self.left_hand_mean,\n                               self.right_hand_mean], dim=0)\n        return pose_mean\n\n    def name(self) -> str:\n        return 'SMPL+H'\n\n    def extra_repr(self):\n        msg = super(SMPLH, self).extra_repr()\n        msg = [msg]\n        if self.use_pca:\n            msg.append(f'Number of PCA components: {self.num_pca_comps}')\n        msg.append(f'Flat hand mean: {self.flat_hand_mean}')\n        return '\\n'.join(msg)\n\n    def forward(\n        self,\n        betas: Optional[Tensor] = None,\n        global_orient: Optional[Tensor] = None,\n        body_pose: Optional[Tensor] = None,\n        left_hand_pose: Optional[Tensor] = None,\n        right_hand_pose: Optional[Tensor] = None,\n        transl: Optional[Tensor] = None,\n        return_verts: bool = True,\n        return_full_pose: bool = False,\n        pose2rot: bool = True,\n        **kwargs\n    ) -> SMPLHOutput:\n        '''\n        '''\n\n        # If no shape and pose parameters are passed along, then use the\n        # ones from the module\n        global_orient = (global_orient if global_orient is not None else\n                         self.global_orient)\n        body_pose = body_pose if body_pose is not None else self.body_pose\n        betas = betas if betas is not None else self.betas\n        left_hand_pose = (left_hand_pose if left_hand_pose is not None else\n                          self.left_hand_pose)\n        right_hand_pose = (right_hand_pose if right_hand_pose is not None else\n                           self.right_hand_pose)\n\n        apply_trans = transl is not None or hasattr(self, 'transl')\n        if transl is None:\n            if hasattr(self, 'transl'):\n                transl = self.transl\n\n        if self.use_pca:\n            left_hand_pose = torch.einsum(\n                'bi,ij->bj', [left_hand_pose, self.left_hand_components])\n            right_hand_pose = torch.einsum(\n                'bi,ij->bj', [right_hand_pose, self.right_hand_components])\n\n        full_pose = torch.cat([global_orient, body_pose,\n                               left_hand_pose,\n                               right_hand_pose], dim=1)\n        full_pose += self.pose_mean\n\n        vertices, joints = lbs(betas, full_pose, self.v_template,\n                               self.shapedirs, self.posedirs,\n                               self.J_regressor, self.parents,\n                               self.lbs_weights, pose2rot=pose2rot)\n\n        # Add any extra joints that might be needed\n        joints = self.vertex_joint_selector(vertices, joints)\n        if self.joint_mapper is not None:\n            joints = self.joint_mapper(joints)\n\n        if apply_trans:\n            joints += transl.unsqueeze(dim=1)\n            vertices += transl.unsqueeze(dim=1)\n\n        output = SMPLHOutput(vertices=vertices if return_verts else None,\n                             joints=joints,\n                             betas=betas,\n                             global_orient=global_orient,\n                             body_pose=body_pose,\n                             left_hand_pose=left_hand_pose,\n         
                    right_hand_pose=right_hand_pose,\n                             full_pose=full_pose if return_full_pose else None)\n\n        return output\n\nclass SMPLX(SMPLH):\n    '''\n    SMPL-X (SMPL eXpressive) is a unified body model, with shape parameters\n    trained jointly for the face, hands and body.\n    SMPL-X uses standard vertex based linear blend skinning with learned\n    corrective blend shapes, has N=10475 vertices and K=54 joints,\n    which includes joints for the neck, jaw, eyeballs and fingers.\n    '''\n\n    NUM_BODY_JOINTS = SMPLH.NUM_BODY_JOINTS\n    NUM_HAND_JOINTS = 15\n    NUM_FACE_JOINTS = 3\n    NUM_JOINTS = NUM_BODY_JOINTS + 2 * NUM_HAND_JOINTS + NUM_FACE_JOINTS\n    EXPRESSION_SPACE_DIM = 100\n    NECK_IDX = 12\n\n    def __init__(\n        self, model_path: str,\n        kid_template_path: str = '',\n        num_expression_coeffs: int = 10,\n        create_expression: bool = True,\n        expression: Optional[Tensor] = None,\n        create_jaw_pose: bool = True,\n        jaw_pose: Optional[Tensor] = None,\n        create_leye_pose: bool = True,\n        leye_pose: Optional[Tensor] = None,\n        create_reye_pose=True,\n        reye_pose: Optional[Tensor] = None,\n        use_face_contour: bool = False,\n        batch_size: int = 1,\n        gender: str = 'neutral',\n        age: str = 'adult',\n        dtype=torch.float32,\n        ext: str = 'npz',\n        **kwargs\n    ) -> None:\n        ''' SMPLX model constructor\n\n            Parameters\n            ----------\n            model_path: str\n                The path to the folder or to the file where the model\n                parameters are stored\n            num_expression_coeffs: int, optional\n                Number of expression components to use\n                (default = 10).\n            create_expression: bool, optional\n                Flag for creating a member variable for the expression space\n                (default = True).\n            expression: torch.tensor, optional, Bx10\n                The default value for the expression member variable.\n                (default = None)\n            create_jaw_pose: bool, optional\n                Flag for creating a member variable for the jaw pose.\n                (default = False)\n            jaw_pose: torch.tensor, optional, Bx3\n                The default value for the jaw pose variable.\n                (default = None)\n            create_leye_pose: bool, optional\n                Flag for creating a member variable for the left eye pose.\n                (default = False)\n            leye_pose: torch.tensor, optional, Bx10\n                The default value for the left eye pose variable.\n                (default = None)\n            create_reye_pose: bool, optional\n                Flag for creating a member variable for the right eye pose.\n                (default = False)\n            reye_pose: torch.tensor, optional, Bx10\n                The default value for the right eye pose variable.\n                (default = None)\n            use_face_contour: bool, optional\n                Whether to compute the keypoints that form the facial contour\n            batch_size: int, optional\n                The batch size used for creating the member variables\n            gender: str, optional\n                Which gender to load\n            dtype: torch.dtype\n                The data type for the created variables\n        '''\n\n        # Load the model\n        if osp.isdir(model_path):\n            model_fn = 
'SMPLX_{}.{ext}'.format(gender.upper(), ext=ext)\n            smplx_path = os.path.join(model_path, model_fn)\n        else:\n            smplx_path = model_path\n        assert osp.exists(smplx_path), 'Path {} does not exist!'.format(\n            smplx_path)\n\n        if ext == 'pkl':\n            with open(smplx_path, 'rb') as smplx_file:\n                model_data = pickle.load(smplx_file, encoding='latin1')\n        elif ext == 'npz':\n            model_data = np.load(smplx_path, allow_pickle=True)\n        else:\n            raise ValueError('Unknown extension: {}'.format(ext))\n\n        data_struct = Struct(**model_data)\n\n        super(SMPLX, self).__init__(\n            model_path=model_path,\n            kid_template_path=kid_template_path,\n            data_struct=data_struct,\n            dtype=dtype,\n            batch_size=batch_size,\n            vertex_ids=VERTEX_IDS['smplx'],\n            gender=gender, age=age, ext=ext,\n            **kwargs)\n\n        lmk_faces_idx = data_struct.lmk_faces_idx\n        self.register_buffer('lmk_faces_idx',\n                             torch.tensor(lmk_faces_idx, dtype=torch.long))\n        lmk_bary_coords = data_struct.lmk_bary_coords\n        self.register_buffer('lmk_bary_coords',\n                             torch.tensor(lmk_bary_coords, dtype=dtype))\n\n        self.bone_parents = to_np(data_struct.kintree_table[0])\n\n        self.use_face_contour = use_face_contour\n        if self.use_face_contour:\n            dynamic_lmk_faces_idx = data_struct.dynamic_lmk_faces_idx\n            dynamic_lmk_faces_idx = torch.tensor(\n                dynamic_lmk_faces_idx,\n                dtype=torch.long)\n            self.register_buffer('dynamic_lmk_faces_idx',\n                                 dynamic_lmk_faces_idx)\n\n            dynamic_lmk_bary_coords = data_struct.dynamic_lmk_bary_coords\n            dynamic_lmk_bary_coords = torch.tensor(\n                dynamic_lmk_bary_coords, dtype=dtype)\n            self.register_buffer('dynamic_lmk_bary_coords',\n                                 dynamic_lmk_bary_coords)\n\n            neck_kin_chain = find_joint_kin_chain(self.NECK_IDX, self.parents)\n            self.register_buffer(\n                'neck_kin_chain',\n                torch.tensor(neck_kin_chain, dtype=torch.long))\n\n        if create_jaw_pose:\n            if jaw_pose is None:\n                default_jaw_pose = torch.zeros([batch_size, 3], dtype=dtype)\n            else:\n                default_jaw_pose = torch.tensor(jaw_pose, dtype=dtype)\n            jaw_pose_param = nn.Parameter(default_jaw_pose,\n                                          requires_grad=True)\n            self.register_parameter('jaw_pose', jaw_pose_param)\n\n        if create_leye_pose:\n            if leye_pose is None:\n                default_leye_pose = torch.zeros([batch_size, 3], dtype=dtype)\n            else:\n                default_leye_pose = torch.tensor(leye_pose, dtype=dtype)\n            leye_pose_param = nn.Parameter(default_leye_pose,\n                                           requires_grad=True)\n            self.register_parameter('leye_pose', leye_pose_param)\n\n        if create_reye_pose:\n            if reye_pose is None:\n                default_reye_pose = torch.zeros([batch_size, 3], dtype=dtype)\n            else:\n                default_reye_pose = torch.tensor(reye_pose, dtype=dtype)\n            reye_pose_param = nn.Parameter(default_reye_pose,\n                                           requires_grad=True)\n         
   self.register_parameter('reye_pose', reye_pose_param)\n\n        shapedirs = data_struct.shapedirs\n        \n        if len(shapedirs.shape) < 3:\n            shapedirs = shapedirs[:, :, None]\n        if (shapedirs.shape[-1] < self.SHAPE_SPACE_DIM +\n                self.EXPRESSION_SPACE_DIM):\n            print(f'WARNING: You are using a {self.name()} model, with only'\n                  ' 10 shape and 10 expression coefficients.')\n            expr_start_idx = 10\n            expr_end_idx = 20\n            num_expression_coeffs = min(num_expression_coeffs, 10)\n        else:\n            expr_start_idx = self.SHAPE_SPACE_DIM\n            expr_end_idx = self.SHAPE_SPACE_DIM + num_expression_coeffs\n            num_expression_coeffs = min(\n                num_expression_coeffs, self.EXPRESSION_SPACE_DIM)\n\n        self._num_expression_coeffs = num_expression_coeffs\n\n        expr_dirs = shapedirs[:, :, expr_start_idx:expr_end_idx]\n        self.register_buffer(\n            'expr_dirs', to_tensor(to_np(expr_dirs), dtype=dtype))\n\n        if create_expression:\n            if expression is None:\n                default_expression = torch.zeros(\n                    [batch_size, self.num_expression_coeffs], dtype=dtype)\n            else:\n                default_expression = torch.tensor(expression, dtype=dtype)\n            expression_param = nn.Parameter(default_expression,\n                                            requires_grad=True)\n            self.register_parameter('expression', expression_param)\n\n    def name(self) -> str:\n        return 'SMPL-X'\n\n    @property\n    def num_expression_coeffs(self):\n        return self._num_expression_coeffs\n\n    def create_mean_pose(self, data_struct, flat_hand_mean=False):\n        # Create the array for the mean pose. 
If flat_hand is false, then use\n        # the mean that is given by the data, rather than the flat open hand\n        global_orient_mean = torch.zeros([3], dtype=self.dtype)\n        body_pose_mean = torch.zeros([self.NUM_BODY_JOINTS * 3],\n                                     dtype=self.dtype)\n        jaw_pose_mean = torch.zeros([3], dtype=self.dtype)\n        leye_pose_mean = torch.zeros([3], dtype=self.dtype)\n        reye_pose_mean = torch.zeros([3], dtype=self.dtype)\n\n        pose_mean = np.concatenate([global_orient_mean, body_pose_mean,\n                                    jaw_pose_mean,\n                                    leye_pose_mean, reye_pose_mean,\n                                    self.left_hand_mean, self.right_hand_mean],\n                                   axis=0)\n\n        return pose_mean\n\n    def extra_repr(self):\n        msg = super(SMPLX, self).extra_repr()\n        msg = [\n            msg,\n            f'Number of Expression Coefficients: {self.num_expression_coeffs}'\n        ]\n        return '\\n'.join(msg)\n\n    def forward(\n        self,\n        betas: Optional[Tensor] = None,\n        global_orient: Optional[Tensor] = None,\n        body_pose: Optional[Tensor] = None,\n        left_hand_pose: Optional[Tensor] = None,\n        right_hand_pose: Optional[Tensor] = None,\n        transl: Optional[Tensor] = None,\n        expression: Optional[Tensor] = None,\n        jaw_pose: Optional[Tensor] = None,\n        leye_pose: Optional[Tensor] = None,\n        reye_pose: Optional[Tensor] = None,\n        return_verts: bool = True,\n        return_full_pose: bool = False,\n        pose2rot: bool = True,\n        return_shaped: bool = True,\n        use_pca:bool = True, # specify where to use pca 12 for hands' pose\n        **kwargs\n    ) -> SMPLXOutput:\n        '''\n        Forward pass for the SMPLX model\n\n            Parameters\n            ----------\n            global_orient: torch.tensor, optional, shape Bx3\n                If given, ignore the member variable and use it as the global\n                rotation of the body. Useful if someone wishes to predicts this\n                with an external model. (default=None)\n            betas: torch.tensor, optional, shape BxN_b\n                If given, ignore the member variable `betas` and use it\n                instead. For example, it can used if shape parameters\n                `betas` are predicted from some external model.\n                (default=None)\n            expression: torch.tensor, optional, shape BxN_e\n                If given, ignore the member variable `expression` and use it\n                instead. For example, it can used if expression parameters\n                `expression` are predicted from some external model.\n            body_pose: torch.tensor, optional, shape Bx(J*3)\n                If given, ignore the member variable `body_pose` and use it\n                instead. For example, it can used if someone predicts the\n                pose of the body joints are predicted from some external model.\n                It should be a tensor that contains joint rotations in\n                axis-angle format. (default=None)\n            left_hand_pose: torch.tensor, optional, shape BxP\n                If given, ignore the member variable `left_hand_pose` and\n                use this instead. 
It should either contain PCA coefficients or\n                joint rotations in axis-angle format.\n            right_hand_pose: torch.tensor, optional, shape BxP\n                If given, ignore the member variable `right_hand_pose` and\n                use this instead. It should either contain PCA coefficients or\n                joint rotations in axis-angle format.\n            jaw_pose: torch.tensor, optional, shape Bx3\n                If given, ignore the member variable `jaw_pose` and\n                use this instead. It should either joint rotations in\n                axis-angle format.\n            transl: torch.tensor, optional, shape Bx3\n                If given, ignore the member variable `transl` and use it\n                instead. For example, it can used if the translation\n                `transl` is predicted from some external model.\n                (default=None)\n            return_verts: bool, optional\n                Return the vertices. (default=True)\n            return_full_pose: bool, optional\n                Returns the full axis-angle pose vector (default=False)\n\n            Returns\n            -------\n                output: ModelOutput\n                A named tuple of type `ModelOutput`\n        '''\n\n        # If no shape and pose parameters are passed along, then use the\n        # ones from the module\n        if global_orient is None:\n            assert False\n        global_orient = (global_orient if global_orient is not None else\n                         self.global_orient)\n        body_pose = body_pose if body_pose is not None else self.body_pose\n        betas = betas if betas is not None else self.betas\n\n        left_hand_pose = (left_hand_pose if left_hand_pose is not None else\n                          self.left_hand_pose)\n        right_hand_pose = (right_hand_pose if right_hand_pose is not None else\n                           self.right_hand_pose)\n        jaw_pose = jaw_pose if jaw_pose is not None else self.jaw_pose\n        leye_pose = leye_pose if leye_pose is not None else self.leye_pose\n        reye_pose = reye_pose if reye_pose is not None else self.reye_pose\n        expression = expression if expression is not None else self.expression\n\n        apply_trans = transl is not None or hasattr(self, 'transl')\n        if transl is None:\n            if hasattr(self, 'transl'):\n                transl = self.transl\n\n        if self.use_pca and use_pca:\n            left_hand_pose = torch.einsum(\n                'bi,ij->bj', [left_hand_pose, self.left_hand_components])\n            right_hand_pose = torch.einsum(\n                'bi,ij->bj', [right_hand_pose, self.right_hand_components])\n\n        full_pose = torch.cat([global_orient.reshape(-1, 1, 3),\n                               body_pose.reshape(-1, self.NUM_BODY_JOINTS, 3),\n                               jaw_pose.reshape(-1, 1, 3),\n                               leye_pose.reshape(-1, 1, 3),\n                               reye_pose.reshape(-1, 1, 3),\n                               left_hand_pose.reshape(-1, 15, 3),\n                               right_hand_pose.reshape(-1, 15, 3)],\n                              dim=1).reshape(-1, 165)\n\n        # Add the mean pose of the model. 
Does not affect the body, only the\n        # hands when flat_hand_mean == False\n        full_pose += self.pose_mean\n\n        batch_size = max(betas.shape[0], global_orient.shape[0],\n                         body_pose.shape[0])\n        # Concatenate the shape and expression coefficients\n        scale = int(batch_size / betas.shape[0])\n        if scale > 1:\n            betas = betas.expand(scale, -1)\n        shape_components = torch.cat([betas, expression], dim=-1)\n\n        shapedirs = torch.cat([self.shapedirs, self.expr_dirs], dim=-1)\n        \n        vertices, joints_smpl, A, T, shape_offset, pose_offset, pose_feature = lbs(shape_components, full_pose, self.v_template, \n                                            shapedirs, self.posedirs,\n                                            self.J_regressor, self.parents, \n                                            self.lbs_weights, pose2rot=pose2rot)\n        \n        lmk_faces_idx = self.lmk_faces_idx.unsqueeze(\n            dim=0).expand(batch_size, -1).contiguous()\n        lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat(\n            self.batch_size, 1, 1)\n        if self.use_face_contour:\n            lmk_idx_and_bcoords = find_dynamic_lmk_idx_and_bcoords(\n                vertices, full_pose, self.dynamic_lmk_faces_idx,\n                self.dynamic_lmk_bary_coords,\n                self.neck_kin_chain,\n                pose2rot=True,\n            )\n            dyn_lmk_faces_idx, dyn_lmk_bary_coords = lmk_idx_and_bcoords\n\n            lmk_faces_idx = torch.cat([lmk_faces_idx,\n                                       dyn_lmk_faces_idx], 1)\n            lmk_bary_coords = torch.cat(\n                [lmk_bary_coords.expand(batch_size, -1, -1),\n                 dyn_lmk_bary_coords], 1)\n\n        landmarks = vertices2landmarks(vertices, self.faces_tensor,\n                                       lmk_faces_idx,\n                                       lmk_bary_coords)\n\n        # Add any extra joints that might be needed\n        joints = self.vertex_joint_selector(vertices, joints_smpl)\n        # Add the landmarks to the joints\n        joints = torch.cat([joints, landmarks], dim=1)\n        # Map the joints to the current dataset\n\n        if self.joint_mapper is not None:\n            joints = self.joint_mapper(joints=joints, vertices=vertices)\n\n        if apply_trans:\n            joints_smpl += transl.unsqueeze(dim=1)\n            joints += transl.unsqueeze(dim=1)\n            vertices += transl.unsqueeze(dim=1)\n            A[..., :3, 3] += transl.unsqueeze(dim=1)\n            T[..., :3, 3] += transl.unsqueeze(dim=1)\n\n        v_shaped = None\n        if return_shaped:\n            v_shaped = self.v_template + blend_shapes(betas, self.shapedirs)\n        else:\n            v_shaped = Tensor(0)\n        output = SMPLXOutput(vertices=vertices if return_verts else None,\n                              joints=joints_smpl,\n                              betas=shape_components,\n                              expression=expression,\n                              global_orient=global_orient,\n                              body_pose=body_pose,\n                              v_shaped=v_shaped,\n                              full_pose=full_pose if return_full_pose else None,\n                              A=A,\n                              T=T, shape_offset=shape_offset, pose_offset=pose_offset, pose_feature=pose_feature)\n        return output"
  },
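  {
    "path": "examples/smplx_forward_sketch.py",
    "content": "# Illustrative usage sketch, not a file from the original IDOL release.\n# It shows how the modified SMPL-X model above can be driven directly,\n# mirroring the settings used by SMPLXDeformer_gender. The model directory,\n# ext choice and parameter values are assumptions: the SMPL-X .pkl files must\n# first be downloaded into lib/models/deformers/smplx/SMPLX/ and the script\n# run from the repository root so that lib/ is importable.\n\nimport torch\n\nfrom lib.models.deformers.smplx import SMPLX\n\n# 10 shape betas and 12 PCA hand components, as in the deformer. The create_*\n# flags keep their defaults, so the module owns zero-initialised parameters\n# for anything we do not pass explicitly.\nbody_model = SMPLX(\n    'lib/models/deformers/smplx/SMPLX',\n    gender='neutral',\n    num_betas=10,\n    use_pca=True,\n    num_pca_comps=12,\n    flat_hand_mean=False,\n    ext='pkl',\n)\n\nbatch_size = 1\nbetas = torch.zeros([batch_size, 10])          # mean body shape\nglobal_orient = torch.zeros([batch_size, 3])   # required: forward() asserts on it\nbody_pose = torch.zeros([batch_size, 21 * 3])  # 21 body joints, axis-angle\n\nwith torch.no_grad():\n    output = body_model(\n        betas=betas,\n        global_orient=global_orient,\n        body_pose=body_pose,\n        return_verts=True,\n        return_full_pose=True,\n    )\n\n# Besides vertices and joints, this repo's forward() also returns the per-joint\n# (A) and per-vertex (T) rigid transforms plus the blend-shape offsets that the\n# deformer consumes.\nprint(output.vertices.shape)   # (1, 10475, 3)\nprint(output.joints.shape)     # (1, 55, 3)\nprint(output.A.shape)          # (1, 55, 4, 4)\nprint(output.T.shape)          # (1, 10475, 4, 4)\n"
  },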
  {
    "path": "lib/models/deformers/smplx/joint_names.py",
    "content": "# -*- coding: utf-8 -*-\n\n# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is\n# holder of all proprietary rights on this computer program.\n# You can only use this computer program if you have closed\n# a license agreement with MPG or you get the right to use the computer\n# program from someone who is authorized to grant you that right.\n# Any use of the computer program without a valid license is prohibited and\n# liable to prosecution.\n#\n# Copyright©2019 Max-Planck-Gesellschaft zur Förderung\n# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute\n# for Intelligent Systems. All rights reserved.\n#\n# Contact: ps-license@tuebingen.mpg.de\n\nJOINT_NAMES = [\n    'pelvis',\n    'left_hip',\n    'right_hip',\n    'spine1',\n    'left_knee',\n    'right_knee',\n    'spine2',\n    'left_ankle',\n    'right_ankle',\n    'spine3',\n    'left_foot',\n    'right_foot',\n    'neck',\n    'left_collar',\n    'right_collar',\n    'head',\n    'left_shoulder',\n    'right_shoulder',\n    'left_elbow',\n    'right_elbow',\n    'left_wrist',\n    'right_wrist',\n    'jaw',\n    'left_eye_smplhf',\n    'right_eye_smplhf',\n    'left_index1',\n    'left_index2',\n    'left_index3',\n    'left_middle1',\n    'left_middle2',\n    'left_middle3',\n    'left_pinky1',\n    'left_pinky2',\n    'left_pinky3',\n    'left_ring1',\n    'left_ring2',\n    'left_ring3',\n    'left_thumb1',\n    'left_thumb2',\n    'left_thumb3',\n    'right_index1',\n    'right_index2',\n    'right_index3',\n    'right_middle1',\n    'right_middle2',\n    'right_middle3',\n    'right_pinky1',\n    'right_pinky2',\n    'right_pinky3',\n    'right_ring1',\n    'right_ring2',\n    'right_ring3',\n    'right_thumb1',\n    'right_thumb2',\n    'right_thumb3',\n    'nose',\n    'right_eye',\n    'left_eye',\n    'right_ear',\n    'left_ear',\n    'left_big_toe',\n    'left_small_toe',\n    'left_heel',\n    'right_big_toe',\n    'right_small_toe',\n    'right_heel',\n    'left_thumb',\n    'left_index',\n    'left_middle',\n    'left_ring',\n    'left_pinky',\n    'right_thumb',\n    'right_index',\n    'right_middle',\n    'right_ring',\n    'right_pinky',\n    'right_eye_brow1',\n    'right_eye_brow2',\n    'right_eye_brow3',\n    'right_eye_brow4',\n    'right_eye_brow5',\n    'left_eye_brow5',\n    'left_eye_brow4',\n    'left_eye_brow3',\n    'left_eye_brow2',\n    'left_eye_brow1',\n    'nose1',\n    'nose2',\n    'nose3',\n    'nose4',\n    'right_nose_2',\n    'right_nose_1',\n    'nose_middle',\n    'left_nose_1',\n    'left_nose_2',\n    'right_eye1',\n    'right_eye2',\n    'right_eye3',\n    'right_eye4',\n    'right_eye5',\n    'right_eye6',\n    'left_eye4',\n    'left_eye3',\n    'left_eye2',\n    'left_eye1',\n    'left_eye6',\n    'left_eye5',\n    'right_mouth_1',\n    'right_mouth_2',\n    'right_mouth_3',\n    'mouth_top',\n    'left_mouth_3',\n    'left_mouth_2',\n    'left_mouth_1',\n    'left_mouth_5',  # 59 in OpenPose output\n    'left_mouth_4',  # 58 in OpenPose output\n    'mouth_bottom',\n    'right_mouth_4',\n    'right_mouth_5',\n    'right_lip_1',\n    'right_lip_2',\n    'lip_top',\n    'left_lip_2',\n    'left_lip_1',\n    'left_lip_3',\n    'lip_bottom',\n    'right_lip_3',\n    # Face contour\n    'right_contour_1',\n    'right_contour_2',\n    'right_contour_3',\n    'right_contour_4',\n    'right_contour_5',\n    'right_contour_6',\n    'right_contour_7',\n    'right_contour_8',\n    'contour_middle',\n    'left_contour_8',\n    
'left_contour_7',\n    'left_contour_6',\n    'left_contour_5',\n    'left_contour_4',\n    'left_contour_3',\n    'left_contour_2',\n    'left_contour_1',\n]\n\n\nSMPLH_JOINT_NAMES = [\n    'pelvis',\n    'left_hip',\n    'right_hip',\n    'spine1',\n    'left_knee',\n    'right_knee',\n    'spine2',\n    'left_ankle',\n    'right_ankle',\n    'spine3',\n    'left_foot',\n    'right_foot',\n    'neck',\n    'left_collar',\n    'right_collar',\n    'head',\n    'left_shoulder',\n    'right_shoulder',\n    'left_elbow',\n    'right_elbow',\n    'left_wrist',\n    'right_wrist',\n    'left_index1',\n    'left_index2',\n    'left_index3',\n    'left_middle1',\n    'left_middle2',\n    'left_middle3',\n    'left_pinky1',\n    'left_pinky2',\n    'left_pinky3',\n    'left_ring1',\n    'left_ring2',\n    'left_ring3',\n    'left_thumb1',\n    'left_thumb2',\n    'left_thumb3',\n    'right_index1',\n    'right_index2',\n    'right_index3',\n    'right_middle1',\n    'right_middle2',\n    'right_middle3',\n    'right_pinky1',\n    'right_pinky2',\n    'right_pinky3',\n    'right_ring1',\n    'right_ring2',\n    'right_ring3',\n    'right_thumb1',\n    'right_thumb2',\n    'right_thumb3',\n    'nose',\n    'right_eye',\n    'left_eye',\n    'right_ear',\n    'left_ear',\n    'left_big_toe',\n    'left_small_toe',\n    'left_heel',\n    'right_big_toe',\n    'right_small_toe',\n    'right_heel',\n    'left_thumb',\n    'left_index',\n    'left_middle',\n    'left_ring',\n    'left_pinky',\n    'right_thumb',\n    'right_index',\n    'right_middle',\n    'right_ring',\n    'right_pinky',\n]\n"
  },
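  {
    "path": "examples/joint_names_lookup_sketch.py",
    "content": "# Illustrative sketch, not part of the original release: the JOINT_NAMES list\n# above is most useful as a lookup table from semantic joint names to row\n# indices of a joints tensor that follows the SMPL-X + OpenPose ordering.\n\nimport torch\n\nfrom lib.models.deformers.smplx.joint_names import JOINT_NAMES, SMPLH_JOINT_NAMES\n\n# Build the name -> index map once.\nJOINT_IDS = {name: idx for idx, name in enumerate(JOINT_NAMES)}\n\n\ndef select_joints(joints, names):\n    # joints: (B, J, 3) tensor whose rows are assumed to follow JOINT_NAMES.\n    idxs = torch.tensor([JOINT_IDS[n] for n in names], dtype=torch.long,\n                        device=joints.device)\n    return joints.index_select(1, idxs)\n\n\nif __name__ == '__main__':\n    dummy_joints = torch.zeros(2, len(JOINT_NAMES), 3)  # stand-in for model output\n    wrists = select_joints(dummy_joints, ['left_wrist', 'right_wrist'])\n    print(wrists.shape)            # (2, 2, 3)\n    print(JOINT_IDS['pelvis'])     # 0\n    print(len(SMPLH_JOINT_NAMES))  # 73 joints in the SMPL+H naming\n"
  },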
  {
    "path": "lib/models/deformers/smplx/lbs.py",
    "content": "# -*- coding: utf-8 -*-\n\n# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is\n# holder of all proprietary rights on this computer program.\n# You can only use this computer program if you have closed\n# a license agreement with MPG or you get the right to use the computer\n# program from someone who is authorized to grant you that right.\n# Any use of the computer program without a valid license is prohibited and\n# liable to prosecution.\n#\n# Copyright©2019 Max-Planck-Gesellschaft zur Förderung\n# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute\n# for Intelligent Systems. All rights reserved.\n#\n# Contact: ps-license@tuebingen.mpg.de\n\nfrom __future__ import absolute_import\nfrom __future__ import print_function\nfrom __future__ import division\n\nfrom typing import Tuple, List\nimport numpy as np\n\nimport torch\nimport torch.nn.functional as F\n\nfrom .utils import rot_mat_to_euler, Tensor\n\n\ndef find_dynamic_lmk_idx_and_bcoords(\n    vertices: Tensor,\n    pose: Tensor,\n    dynamic_lmk_faces_idx: Tensor,\n    dynamic_lmk_b_coords: Tensor,\n    neck_kin_chain: List[int],\n    pose2rot: bool = True,\n) -> Tuple[Tensor, Tensor]:\n    ''' Compute the faces, barycentric coordinates for the dynamic landmarks\n\n\n        To do so, we first compute the rotation of the neck around the y-axis\n        and then use a pre-computed look-up table to find the faces and the\n        barycentric coordinates that will be used.\n\n        Special thanks to Soubhik Sanyal (soubhik.sanyal@tuebingen.mpg.de)\n        for providing the original TensorFlow implementation and for the LUT.\n\n        Parameters\n        ----------\n        vertices: torch.tensor BxVx3, dtype = torch.float32\n            The tensor of input vertices\n        pose: torch.tensor Bx(Jx3), dtype = torch.float32\n            The current pose of the body model\n        dynamic_lmk_faces_idx: torch.tensor L, dtype = torch.long\n            The look-up table from neck rotation to faces\n        dynamic_lmk_b_coords: torch.tensor Lx3, dtype = torch.float32\n            The look-up table from neck rotation to barycentric coordinates\n        neck_kin_chain: list\n            A python list that contains the indices of the joints that form the\n            kinematic chain of the neck.\n        dtype: torch.dtype, optional\n\n        Returns\n        -------\n        dyn_lmk_faces_idx: torch.tensor, dtype = torch.long\n            A tensor of size BxL that contains the indices of the faces that\n            will be used to compute the current dynamic landmarks.\n        dyn_lmk_b_coords: torch.tensor, dtype = torch.float32\n            A tensor of size BxL that contains the indices of the faces that\n            will be used to compute the current dynamic landmarks.\n    '''\n\n    dtype = vertices.dtype\n    batch_size = vertices.shape[0]\n\n    if pose2rot:\n        aa_pose = torch.index_select(pose.view(batch_size, -1, 3), 1,\n                                     neck_kin_chain)\n        rot_mats = batch_rodrigues(\n            aa_pose.view(-1, 3)).view(batch_size, -1, 3, 3)\n    else:\n        rot_mats = torch.index_select(\n            pose.view(batch_size, -1, 3, 3), 1, neck_kin_chain)\n\n    rel_rot_mat = torch.eye(\n        3, device=vertices.device, dtype=dtype).unsqueeze_(dim=0).repeat(\n            batch_size, 1, 1)\n    for idx in range(len(neck_kin_chain)):\n        rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat)\n\n    y_rot_angle = torch.round(\n   
     torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi,\n                    max=39)).to(dtype=torch.long)\n    neg_mask = y_rot_angle.lt(0).to(dtype=torch.long)\n    mask = y_rot_angle.lt(-39).to(dtype=torch.long)\n    neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle)\n    y_rot_angle = (neg_mask * neg_vals +\n                   (1 - neg_mask) * y_rot_angle)\n\n    dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx,\n                                           0, y_rot_angle)\n    dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords,\n                                          0, y_rot_angle)\n\n    return dyn_lmk_faces_idx, dyn_lmk_b_coords\n\n\ndef vertices2landmarks(\n    vertices: Tensor,\n    faces: Tensor,\n    lmk_faces_idx: Tensor,\n    lmk_bary_coords: Tensor\n) -> Tensor:\n    ''' Calculates landmarks by barycentric interpolation\n\n        Parameters\n        ----------\n        vertices: torch.tensor BxVx3, dtype = torch.float32\n            The tensor of input vertices\n        faces: torch.tensor Fx3, dtype = torch.long\n            The faces of the mesh\n        lmk_faces_idx: torch.tensor L, dtype = torch.long\n            The tensor with the indices of the faces used to calculate the\n            landmarks.\n        lmk_bary_coords: torch.tensor Lx3, dtype = torch.float32\n            The tensor of barycentric coordinates that are used to interpolate\n            the landmarks\n\n        Returns\n        -------\n        landmarks: torch.tensor BxLx3, dtype = torch.float32\n            The coordinates of the landmarks for each mesh in the batch\n    '''\n    # Extract the indices of the vertices for each face\n    # BxLx3\n    batch_size, num_verts = vertices.shape[:2]\n    device = vertices.device\n\n    lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view(\n        batch_size, -1, 3)\n\n    lmk_faces += torch.arange(\n        batch_size, dtype=torch.long, device=device).view(-1, 1, 1) * num_verts\n\n    lmk_vertices = vertices.view(-1, 3)[lmk_faces].view(\n        batch_size, -1, 3, 3)\n\n    landmarks = torch.einsum('blfi,blf->bli', [lmk_vertices, lmk_bary_coords])\n    return landmarks\n\n\ndef lbs(\n    betas: Tensor,\n    pose: Tensor,\n    v_template: Tensor,\n    shapedirs: Tensor,\n    posedirs: Tensor,\n    J_regressor: Tensor,\n    parents: Tensor,\n    lbs_weights: Tensor,\n    pose2rot: bool = True,\n) -> Tuple[Tensor, Tensor]:\n    ''' Performs Linear Blend Skinning with the given shape and pose parameters\n\n        Parameters\n        ----------\n        betas : torch.tensor BxNB\n            The tensor of shape parameters\n        pose : torch.tensor Bx(J + 1) * 3\n            The pose parameters in axis-angle format\n        v_template torch.tensor BxVx3\n            The template mesh that will be deformed\n        shapedirs : torch.tensor 1xNB\n            The tensor of PCA shape displacements\n        posedirs : torch.tensor Px(V * 3)\n            The pose PCA coefficients\n        J_regressor : torch.tensor JxV\n            The regressor array that is used to calculate the joints from\n            the position of the vertices\n        parents: torch.tensor J\n            The array that describes the kinematic tree for the model\n        lbs_weights: torch.tensor N x V x (J + 1)\n            The linear blend skinning weights that represent how much the\n            rotation matrix of each part affects each vertex\n        pose2rot: bool, optional\n            Flag on whether to convert the input pose tensor 
to rotation\n            matrices. The default value is True. If False, then the pose tensor\n            should already contain rotation matrices and have a size of\n            Bx(J + 1)x9\n        dtype: torch.dtype, optional\n\n        Returns\n        -------\n        verts: torch.tensor BxVx3\n            The vertices of the mesh after applying the shape and pose\n            displacements.\n        joints: torch.tensor BxJx3\n            The joints of the model\n    '''\n\n    batch_size = max(betas.shape[0], pose.shape[0])\n    device, dtype = betas.device, betas.dtype\n\n    # Add shape contribution\n    shape_offset = blend_shapes(betas, shapedirs)\n    v_shaped = v_template + shape_offset\n\n    # Get the joints\n    # NxJx3 array\n    J = vertices2joints(J_regressor, v_shaped)\n\n    # 3. Add pose blend shapes\n    # N x J x 3 x 3\n    ident = torch.eye(3, dtype=dtype, device=device)\n    if pose2rot:\n        rot_mats = batch_rodrigues(pose.view(-1, 3)).view(\n            [batch_size, -1, 3, 3])\n\n        pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1])\n        # (N x P) x (P, V * 3) -> N x V x 3\n        pose_offsets = torch.matmul(\n            pose_feature, posedirs).view(batch_size, -1, 3)\n    else:\n        pose_feature = pose[:, 1:].view(batch_size, -1, 3, 3) - ident\n        rot_mats = pose.view(batch_size, -1, 3, 3)\n\n        pose_offsets = torch.matmul(pose_feature.view(batch_size, -1),\n                                    posedirs).view(batch_size, -1, 3)\n\n    v_posed = pose_offsets + v_shaped\n    # 4. Get the global joint location\n    J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype)\n\n    # 5. Do skinning:\n    # W is N x V x (J + 1)\n    W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1])\n    # (N x V x (J + 1)) x (N x (J + 1) x 16)\n    num_joints = J_regressor.shape[0]\n    T = torch.matmul(W, A.view(batch_size, num_joints, 16)) \\\n        .view(batch_size, -1, 4, 4)\n\n    homogen_coord = torch.ones([batch_size, v_posed.shape[1], 1],\n                               dtype=dtype, device=device)\n    v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2)\n    v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1))\n\n    verts = v_homo[:, :, :3, 0]\n    return verts, J_transformed, A, T, shape_offset, pose_offsets, pose_feature\n\n\ndef vertices2joints(J_regressor: Tensor, vertices: Tensor) -> Tensor:\n    ''' Calculates the 3D joint locations from the vertices\n\n    Parameters\n    ----------\n    J_regressor : torch.tensor JxV\n        The regressor array that is used to calculate the joints from the\n        position of the vertices\n    vertices : torch.tensor BxVx3\n        The tensor of mesh vertices\n\n    Returns\n    -------\n    torch.tensor BxJx3\n        The location of the joints\n    '''\n\n    return torch.einsum('bik,ji->bjk', [vertices, J_regressor])\n\n\ndef blend_shapes(betas: Tensor, shape_disps: Tensor) -> Tensor:\n    ''' Calculates the per vertex displacement due to the blend shapes\n\n\n    Parameters\n    ----------\n    betas : torch.tensor Bx(num_betas)\n        Blend shape coefficients\n    shape_disps: torch.tensor Vx3x(num_betas)\n        Blend shapes\n\n    Returns\n    -------\n    torch.tensor BxVx3\n        The per-vertex displacement due to shape deformation\n    '''\n\n    # Displacement[b, m, k] = sum_{l} betas[b, l] * shape_disps[m, k, l]\n    # i.e. 
Multiply each shape displacement by its corresponding beta and\n    # then sum them.\n    blend_shape = torch.einsum('bl,mkl->bmk', [betas, shape_disps])\n    return blend_shape\n\n\ndef batch_rodrigues(\n    rot_vecs: Tensor,\n    epsilon: float = 1e-8,\n) -> Tensor:\n    ''' Calculates the rotation matrices for a batch of rotation vectors\n        Parameters\n        ----------\n        rot_vecs: torch.tensor Nx3\n            array of N axis-angle vectors\n        Returns\n        -------\n        R: torch.tensor Nx3x3\n            The rotation matrices for the given axis-angle parameters\n    '''\n\n    batch_size = rot_vecs.shape[0]\n    device, dtype = rot_vecs.device, rot_vecs.dtype\n\n    angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True)\n    rot_dir = rot_vecs / angle\n\n    cos = torch.unsqueeze(torch.cos(angle), dim=1)\n    sin = torch.unsqueeze(torch.sin(angle), dim=1)\n\n    # Bx1 arrays\n    rx, ry, rz = torch.split(rot_dir, 1, dim=1)\n    K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device)\n\n    zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)\n    K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \\\n        .view((batch_size, 3, 3))\n\n    ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)\n    rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)\n    return rot_mat\n\n\ndef transform_mat(R: Tensor, t: Tensor) -> Tensor:\n    ''' Creates a batch of transformation matrices\n        Args:\n            - R: Bx3x3 array of a batch of rotation matrices\n            - t: Bx3x1 array of a batch of translation vectors\n        Returns:\n            - T: Bx4x4 Transformation matrix\n    '''\n    # No padding left or right, only add an extra row\n    return torch.cat([F.pad(R, [0, 0, 0, 1]),\n                      F.pad(t, [0, 0, 0, 1], value=1)], dim=2)\n\n\ndef batch_rigid_transform(\n    rot_mats: Tensor,\n    joints: Tensor,\n    parents: Tensor,\n    dtype=torch.float32\n) -> Tensor:\n    \"\"\"\n    Applies a batch of rigid transformations to the joints\n\n    Parameters\n    ----------\n    rot_mats : torch.tensor BxNx3x3\n        Tensor of rotation matrices\n    joints : torch.tensor BxNx3\n        Locations of joints\n    parents : torch.tensor BxN\n        The kinematic tree of each object\n    dtype : torch.dtype, optional:\n        The data type of the created tensors, the default is torch.float32\n\n    Returns\n    -------\n    posed_joints : torch.tensor BxNx3\n        The locations of the joints after applying the pose rotations\n    rel_transforms : torch.tensor BxNx4x4\n        The relative (with respect to the root joint) rigid transformations\n        for all the joints\n    \"\"\"\n\n    joints = torch.unsqueeze(joints, dim=-1)\n\n    rel_joints = joints.clone()\n    rel_joints[:, 1:] -= joints[:, parents[1:]]\n\n    transforms_mat = transform_mat(\n        rot_mats.reshape(-1, 3, 3),\n        rel_joints.reshape(-1, 3, 1)).reshape(-1, joints.shape[1], 4, 4)\n\n    transform_chain = [transforms_mat[:, 0]]\n    for i in range(1, parents.shape[0]):\n        # Subtract the joint location at the rest pose\n        # No need for rotation, since it's identity when at rest\n        curr_res = torch.matmul(transform_chain[parents[i]],\n                                transforms_mat[:, i])\n        transform_chain.append(curr_res)\n\n    transforms = torch.stack(transform_chain, dim=1)\n\n    # The last column of the transformations contains the posed joints\n    posed_joints = transforms[:, :, :3, 
3]\n\n    joints_homogen = F.pad(joints, [0, 0, 0, 1])\n\n    rel_transforms = transforms - F.pad(\n        torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0])\n\n    return posed_joints, rel_transforms\n"
  },
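  {
    "path": "examples/lbs_helpers_sketch.py",
    "content": "# Illustrative sketch, not part of the original release: exercising two\n# self-contained helpers from lbs.py. batch_rodrigues converts axis-angle\n# vectors to rotation matrices and transform_mat packs rotations and\n# translations into homogeneous 4x4 transforms; together they are the building\n# blocks of batch_rigid_transform.\n\nimport math\n\nimport torch\n\nfrom lib.models.deformers.smplx.lbs import batch_rodrigues, transform_mat\n\n# Two axis-angle vectors: the identity and a 90 degree turn about +z.\nrot_vecs = torch.tensor([[0.0, 0.0, 0.0],\n                         [0.0, 0.0, math.pi / 2]])\nR = batch_rodrigues(rot_vecs)     # (2, 3, 3) rotation matrices\n\nt = torch.zeros(2, 3, 1)\nt[1, 0, 0] = 1.0                  # translate the second frame along +x\nT = transform_mat(R, t)           # (2, 4, 4) homogeneous transforms\n\npoint = torch.tensor([1.0, 0.0, 0.0, 1.0])\nprint(T[1] @ point)               # ~[1, 1, 0, 1]: rotate x to y, then shift by +x\n"
  },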
  {
    "path": "lib/models/deformers/smplx/utils.py",
    "content": "# -*- coding: utf-8 -*-\n\n# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is\n# holder of all proprietary rights on this computer program.\n# You can only use this computer program if you have closed\n# a license agreement with MPG or you get the right to use the computer\n# program from someone who is authorized to grant you that right.\n# Any use of the computer program without a valid license is prohibited and\n# liable to prosecution.\n#\n# Copyright©2019 Max-Planck-Gesellschaft zur Förderung\n# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute\n# for Intelligent Systems. All rights reserved.\n#\n# Contact: ps-license@tuebingen.mpg.de\n\nfrom optparse import Option\nfrom typing import NewType, Union, Optional\nfrom dataclasses import dataclass, asdict, fields\nimport numpy as np\nimport torch\n\nTensor = NewType('Tensor', torch.Tensor)\nArray = NewType('Array', np.ndarray)\n\n\n@dataclass\nclass ModelOutput:\n    vertices: Optional[Tensor] = None\n    joints: Optional[Tensor] = None\n    full_pose: Optional[Tensor] = None\n    global_orient: Optional[Tensor] = None\n    transl: Optional[Tensor] = None\n    v_shaped: Optional[Tensor] = None\n\n    def __getitem__(self, key):\n        return getattr(self, key)\n\n    def get(self, key, default=None):\n        return getattr(self, key, default)\n\n    def __iter__(self):\n        return self.keys()\n\n    def keys(self):\n        keys = [t.name for t in fields(self)]\n        return iter(keys)\n\n    def values(self):\n        values = [getattr(self, t.name) for t in fields(self)]\n        return iter(values)\n\n    def items(self):\n        data = [(t.name, getattr(self, t.name)) for t in fields(self)]\n        return iter(data)\n\n\n@dataclass\nclass SMPLOutput(ModelOutput):\n    betas: Optional[Tensor] = None\n    body_pose: Optional[Tensor] = None\n    T: Optional[Tensor] = None\n    A: Optional[Tensor] = None\n    shape_offset: Optional[Tensor] = None\n    pose_offset: Optional[Tensor] = None\n    pose_feature: Optional[Tensor] = None\n\n@dataclass\nclass SMPLHOutput(SMPLOutput):\n    left_hand_pose: Optional[Tensor] = None\n    right_hand_pose: Optional[Tensor] = None\n    transl: Optional[Tensor] = None\n\n\n@dataclass\nclass SMPLXOutput(SMPLHOutput):\n    expression: Optional[Tensor] = None\n    jaw_pose: Optional[Tensor] = None\n\n\ndef find_joint_kin_chain(joint_id, kinematic_tree):\n    kin_chain = []\n    curr_idx = joint_id\n    while curr_idx != -1:\n        kin_chain.append(curr_idx)\n        curr_idx = kinematic_tree[curr_idx]\n    return kin_chain\n\n\ndef to_tensor(\n        array: Union[Array, Tensor], dtype=torch.float32\n) -> Tensor:\n    if torch.is_tensor(array):\n        return array\n    else:\n        return torch.tensor(array, dtype=dtype)\n\n\nclass Struct(object):\n    def __init__(self, **kwargs):\n        for key, val in kwargs.items():\n            setattr(self, key, val)\n\n\ndef to_np(array, dtype=np.float32):\n    if 'scipy.sparse' in str(type(array)):\n        array = array.todense()\n    return np.array(array, dtype=dtype)\n\n\ndef rot_mat_to_euler(rot_mats):\n    # Calculates rotation matrix to euler angles\n    # Careful for extreme cases of eular angles like [0.0, pi, 0.0]\n\n    sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] +\n                    rot_mats[:, 1, 0] * rot_mats[:, 1, 0])\n    return torch.atan2(-rot_mats[:, 2, 0], sy)\n"
  },
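  {
    "path": "examples/utils_helpers_sketch.py",
    "content": "# Illustrative sketch, not part of the original release: small demonstrations\n# of the helpers defined in utils.py -- Struct for attribute access over a\n# loaded model dictionary, to_np / to_tensor for conversions, and the\n# dict-like interface shared by the ModelOutput dataclasses.\n\nimport numpy as np\nimport torch\n\nfrom lib.models.deformers.smplx.utils import SMPLXOutput, Struct, to_np, to_tensor\n\n# Struct turns a plain dict (e.g. a loaded .pkl / .npz model file) into an\n# object whose entries are attributes, which is how the body models read\n# their parameters.\ndata_struct = Struct(v_template=np.zeros((10, 3)))\nprint(data_struct.v_template.shape)   # (10, 3)\n\n# to_np densifies scipy sparse matrices and casts to float32; to_tensor wraps\n# arrays as torch tensors and passes existing tensors through unchanged.\nverts = to_tensor(to_np(data_struct.v_template))\nprint(verts.dtype)                    # torch.float32\n\n# Every ModelOutput subclass supports both attribute and key access.\nout = SMPLXOutput(vertices=verts.unsqueeze(0), betas=torch.zeros(1, 10))\nprint(out['vertices'].shape)          # (1, 10, 3)\nprint(sorted(k for k, v in out.items() if v is not None))   # ['betas', 'vertices']\n"
  },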
  {
    "path": "lib/models/deformers/smplx/vertex_ids.py",
    "content": "# -*- coding: utf-8 -*-\n\n# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is\n# holder of all proprietary rights on this computer program.\n# You can only use this computer program if you have closed\n# a license agreement with MPG or you get the right to use the computer\n# program from someone who is authorized to grant you that right.\n# Any use of the computer program without a valid license is prohibited and\n# liable to prosecution.\n#\n# Copyright©2019 Max-Planck-Gesellschaft zur Förderung\n# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute\n# for Intelligent Systems. All rights reserved.\n#\n# Contact: ps-license@tuebingen.mpg.de\n\nfrom __future__ import print_function\nfrom __future__ import absolute_import\nfrom __future__ import division\n\n# Joint name to vertex mapping. SMPL/SMPL-H/SMPL-X vertices that correspond to\n# MSCOCO and OpenPose joints\nvertex_ids = {\n    'smplh': {\n        'nose':\t\t    332,\n        'reye':\t\t    6260,\n        'leye':\t\t    2800,\n        'rear':\t\t    4071,\n        'lear':\t\t    583,\n        'rthumb':\t\t6191,\n        'rindex':\t\t5782,\n        'rmiddle':\t\t5905,\n        'rring':\t\t6016,\n        'rpinky':\t\t6133,\n        'lthumb':\t\t2746,\n        'lindex':\t\t2319,\n        'lmiddle':\t\t2445,\n        'lring':\t\t2556,\n        'lpinky':\t\t2673,\n        'LBigToe':\t\t3216,\n        'LSmallToe':\t3226,\n        'LHeel':\t\t3387,\n        'RBigToe':\t\t6617,\n        'RSmallToe':    6624,\n        'RHeel':\t\t6787\n    },\n    'smplx': {\n        'nose':\t\t    9120,\n        'reye':\t\t    9929,\n        'leye':\t\t    9448,\n        'rear':\t\t    616,\n        'lear':\t\t    6,\n        'rthumb':\t\t8079,\n        'rindex':\t\t7669,\n        'rmiddle':\t\t7794,\n        'rring':\t\t7905,\n        'rpinky':\t\t8022,\n        'lthumb':\t\t5361,\n        'lindex':\t\t4933,\n        'lmiddle':\t\t5058,\n        'lring':\t\t5169,\n        'lpinky':\t\t5286,\n        'LBigToe':\t\t5770,\n        'LSmallToe':    5780,\n        'LHeel':\t\t8846,\n        'RBigToe':\t\t8463,\n        'RSmallToe': \t8474,\n        'RHeel':  \t\t8635\n    },\n    'mano': {\n            'thumb':\t\t744,\n            'index':\t\t320,\n            'middle':\t\t443,\n            'ring':\t\t    554,\n            'pinky':\t\t671,\n        }\n}\n"
  },
  {
    "path": "lib/models/deformers/smplx/vertex_joint_selector.py",
    "content": "# -*- coding: utf-8 -*-\n\n# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is\n# holder of all proprietary rights on this computer program.\n# You can only use this computer program if you have closed\n# a license agreement with MPG or you get the right to use the computer\n# program from someone who is authorized to grant you that right.\n# Any use of the computer program without a valid license is prohibited and\n# liable to prosecution.\n#\n# Copyright©2019 Max-Planck-Gesellschaft zur Förderung\n# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute\n# for Intelligent Systems. All rights reserved.\n#\n# Contact: ps-license@tuebingen.mpg.de\n\nfrom __future__ import absolute_import\nfrom __future__ import print_function\nfrom __future__ import division\n\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\n\nfrom .utils import to_tensor\n\n\nclass VertexJointSelector(nn.Module):\n\n    def __init__(self, vertex_ids=None,\n                 use_hands=False,\n                 use_feet_keypoints=False, **kwargs):\n        super(VertexJointSelector, self).__init__()\n\n        extra_joints_idxs = []\n\n        face_keyp_idxs = np.array([\n            vertex_ids['nose'],\n            vertex_ids['reye'],\n            vertex_ids['leye'],\n            vertex_ids['rear'],\n            vertex_ids['lear']], dtype=np.int64)\n\n        extra_joints_idxs = np.concatenate([extra_joints_idxs,\n                                            face_keyp_idxs])\n\n        if use_feet_keypoints:\n            feet_keyp_idxs = np.array([vertex_ids['LBigToe'],\n                                       vertex_ids['LSmallToe'],\n                                       vertex_ids['LHeel'],\n                                       vertex_ids['RBigToe'],\n                                       vertex_ids['RSmallToe'],\n                                       vertex_ids['RHeel']], dtype=np.int32)\n\n            extra_joints_idxs = np.concatenate(\n                [extra_joints_idxs, feet_keyp_idxs])\n\n        if use_hands:\n            self.tip_names = ['thumb', 'index', 'middle', 'ring', 'pinky']\n\n            tips_idxs = []\n            for hand_id in ['l', 'r']:\n                for tip_name in self.tip_names:\n                    tips_idxs.append(vertex_ids[hand_id + tip_name])\n\n            extra_joints_idxs = np.concatenate(\n                [extra_joints_idxs, tips_idxs])\n\n        self.register_buffer('extra_joints_idxs',\n                             to_tensor(extra_joints_idxs, dtype=torch.long))\n\n    def forward(self, vertices, joints):\n        extra_joints = torch.index_select(vertices, 1, self.extra_joints_idxs)\n        joints = torch.cat([joints, extra_joints], dim=1)\n\n        return joints\n"
  },
  {
    "path": "lib/models/deformers/smplx_deformer_gender.py",
    "content": "# Modified from Deformer of AG3D\n\nfrom .fast_snarf.lib.model.deformer_smplx import ForwardDeformer, skinning\nfrom .smplx import SMPLX\nimport torch\nfrom pytorch3d import ops\nimport numpy as np\nimport pickle\nimport json\n\nfrom pytorch3d.transforms import quaternion_to_matrix, matrix_to_quaternion\nclass SMPLXDeformer_gender(torch.nn.Module):\n    \n    def __init__(self, gender, is_sub2=False) -> None: \n        super().__init__()\n        self.body_model = SMPLX('lib/models/deformers/smplx/SMPLX', gender=gender, \\\n                                create_body_pose=False, \\\n                                create_betas=False, \\\n                                create_global_orient=False, \\\n                                create_transl=False,\n                                create_expression=False,\n                                create_jaw_pose=False,\n                                create_leye_pose=False,\n                                create_reye_pose=False,\n                                create_right_hand_pose=False,\n                                create_left_hand_pose=False,\n                                use_pca=True,\n                                num_pca_comps=12,\n                                num_betas=10,\n                                flat_hand_mean=False,ext='pkl')\n        self.deformer = ForwardDeformer()\n        \n        self.threshold = 0.12\n\n                \n        base_cache_dir = 'work_dirs/cache'   \n        if is_sub2:\n            base_cache_dir = 'work_dirs/cache_sub2'\n\n        if gender == 'neutral':\n            init_spdir_neutral = torch.as_tensor(np.load(base_cache_dir+'/init_spdir_smplx_thu_newNeutral.npy'))\n            self.register_buffer('init_spdir', init_spdir_neutral, persistent=False)\n                    \n            init_podir_neutral = torch.as_tensor(np.load(base_cache_dir+'/init_podir_smplx_thu_newNeutral.npy'))\n            self.register_buffer('init_podir', init_podir_neutral, persistent=False)\n\n            init_lbs_weights = torch.as_tensor(np.load(base_cache_dir+'/init_lbsw_smplx_thu_newNeutral.npy'))\n            self.register_buffer('init_lbsw', init_lbs_weights.unsqueeze(0), persistent=False)\n            init_faces = torch.as_tensor(np.load(base_cache_dir+'/init_faces_smplx_newNeutral.npy'))\n            self.register_buffer('init_faces', init_faces.unsqueeze(0), persistent=False)\n\n        elif gender == 'male':\n            init_spdir_male = torch.as_tensor(np.load(base_cache_dir+'/init_spdir_smplx_thu_newMale.npy'))\n            self.register_buffer('init_spdir', init_spdir_male, persistent=False)\n\n            init_podir_male = torch.as_tensor(np.load(base_cache_dir+'/init_podir_smplx_thu_newMale.npy'))\n            self.register_buffer('init_podir', init_podir_male, persistent=False)\n            init_lbs_weights = torch.as_tensor(np.load(base_cache_dir+'/init_lbsw_smplx_thu_newMale.npy'))\n            self.register_buffer('init_lbsw', init_lbs_weights.unsqueeze(0), persistent=False)\n\n        \n            init_faces = torch.as_tensor(np.load(base_cache_dir+'/init_faces_smplx_neuMale.npy'))\n            self.register_buffer('init_faces', init_faces.unsqueeze(0), persistent=False)\n\n        self.initialize()\n        self.initialized = True\n\n    def initialize(self):\n        '''\n         Will only be called once, used to initialize lbs volume\n        '''\n        batch_size = 1\n        device = self.body_model.posedirs.device\n        # canonical space is defined in t-pose 
/ star-pose\n        body_pose_t = torch.zeros((batch_size, 63)).to(device)\n\n        jaw_pose_t = torch.zeros((batch_size, 3)).to(device)\n\n        ##flat_hand_mean = False\n        left_hand_pose_t = torch.tensor([1.4624, -0.1615,  0.1361,  1.3851, -0.2597,  0.0247, -0.0683, -0.4478,\n         -0.6652, -0.7290,  0.0084, -0.4818]).unsqueeze(0).to(device)\n        right_hand_pose_t = torch.tensor([1.4624, -0.1615,  0.1361,  1.3851, -0.2597,  0.0247, -0.0683, -0.4478,\n         -0.6652, -0.7290,  0.0084, -0.4818]).unsqueeze(0).to(device)\n         \n        ## flat_hand_mean = True\n        leye_pose_t = torch.zeros((batch_size, 3)).to(device)\n        reye_pose_t = torch.zeros((batch_size, 3)).to(device)\n        expression_t = torch.zeros((batch_size, 10)).to(device)\n\n        global_orient = torch.zeros((batch_size, 3)).to(device)\n        \n        betas = torch.zeros((batch_size, 10)).to(device)\n        smpl_outputs = self.body_model(betas=betas, body_pose=body_pose_t, jaw_pose=jaw_pose_t, \n                                        left_hand_pose=left_hand_pose_t, right_hand_pose=right_hand_pose_t,\n                                        leye_pose=leye_pose_t, reye_pose=reye_pose_t, expression=expression_t,\n                                        transl=None, global_orient=global_orient) \n        \n        tfs_inv_t = torch.inverse(smpl_outputs.A.float().detach()) # from template to posed space\n        vs_template = smpl_outputs.vertices\n        smpl_faces = torch.as_tensor(self.body_model.faces.astype(np.int64))\n        pose_offset_cano = torch.matmul(smpl_outputs.pose_feature, self.init_podir).reshape(1, -1, 3)\n        pose_offset_cano = torch.cat([pose_offset_cano[:, self.init_faces[..., i]] for i in range(3)], dim=1).mean(1)\n        self.register_buffer('tfs_inv_t', tfs_inv_t, persistent=False)\n        self.register_buffer('vs_template', vs_template, persistent=False)\n        self.register_buffer('smpl_faces', smpl_faces, persistent=False)\n        self.register_buffer('pose_offset_cano', pose_offset_cano, persistent=False)\n\n        # initialize SNARF\n        smpl_verts = smpl_outputs.vertices.float().detach().clone()\n\n        self.deformer.switch_to_explicit(resolution=64,\n                                         smpl_verts=smpl_verts,\n                                         smpl_faces=self.smpl_faces,\n                                         smpl_weights=self.body_model.lbs_weights.clone()[None].detach(),\n                                         use_smpl=True)\n\n        self.dtype = torch.float32\n        self.deformer.lbs_voxel_final = self.deformer.lbs_voxel_final.type(self.dtype)\n        self.deformer.grid_denorm = self.deformer.grid_denorm.type(self.dtype)\n        self.deformer.scale = self.deformer.scale.type(self.dtype)\n        self.deformer.offset = self.deformer.offset.type(self.dtype)\n        self.deformer.scale_kernel = self.deformer.scale_kernel.type(self.dtype)\n        self.deformer.offset_kernel = self.deformer.offset_kernel.type(self.dtype)\n\n    def forword_body_model(self, smpl_params, point_pool=4):\n        batchsize = smpl_params.shape[0]\n        if_use_pca=True\n        if smpl_params.shape[1] == 123:\n            scale, transl, global_orient, pose, betas, left_hand_pose, right_hand_pose, jaw_pose, leye_pose, reye_pose, expression = torch.split(smpl_params, [1, 3, 3, 63, 10, 12, 12, 3, 3, 3, 10], dim=1)\n        else: # not use pca 12 , 189\n            scale, transl, global_orient, pose, betas, left_hand_pose, right_hand_pose, 
jaw_pose, leye_pose, reye_pose, expression = torch.split(smpl_params, [1, 3, 3, 63, 10, 45, 45, 3, 3, 3, 10], dim=1)\n            if_use_pca = False\n        smpl_params = {\n            'betas': betas.reshape(-1, 10),\n            'expression': expression.reshape(-1, 10),\n            'body_pose': pose.reshape(-1, 63),\n            'left_hand_pose': left_hand_pose.reshape(batchsize, -1),\n            'right_hand_pose': right_hand_pose.reshape(batchsize, -1),\n            'jaw_pose': jaw_pose.reshape(-1, 3),\n            'leye_pose': leye_pose.reshape(-1, 3),\n            'reye_pose': reye_pose.reshape(-1, 3),\n            'global_orient': global_orient.reshape(-1, 3),\n            'transl': transl.reshape(-1, 3),\n            'scale': scale.reshape(-1, 1)\n        }\n        \n        device = smpl_params[\"betas\"].device\n        smpl_outputs = self.body_model(**smpl_params, use_pca=if_use_pca)\n        return smpl_outputs\n    def prepare_deformer(self, smpl_params=None, num_scenes=1, device=None):\n        if smpl_params is None:\n            smpl_params = torch.zeros((num_scenes, 120)).to(device)\n            scale, global_orient, pose, betas, left_hand_pose, right_hand_pose, jaw_pose, leye_pose, reye_pose, expression = torch.split(smpl_params, [1, 3, 63, 10, 12, 12, 3, 3, 3, 10], dim=1)\n            left_hand_pose = torch.tensor([1.4624, -0.1615,  0.1361,  1.3851, -0.2597,  0.0247, -0.0683, -0.4478,\n                -0.6652, -0.7290,  0.0084, -0.4818]).unsqueeze(0).to(device).repeat(num_scenes, 1)\n            right_hand_pose = torch.tensor([1.4624, -0.1615,  0.1361,  1.3851, -0.2597,  0.0247, -0.0683, -0.4478,\n                -0.6652, -0.7290,  0.0084, -0.4818]).unsqueeze(0).to(device).repeat(num_scenes, 1)\n         \n            smpl_params = {\n                'betas': betas,\n                'expression': expression,\n                'body_pose': pose,\n                'left_hand_pose': left_hand_pose,\n                'right_hand_pose': right_hand_pose,\n                'jaw_pose': jaw_pose,\n                'leye_pose': leye_pose,\n                'reye_pose': reye_pose,\n                'global_orient': global_orient,\n                'transl': None,\n                'scale': None,\n            }\n            \n        else:\n            batchsize = smpl_params.shape[0]\n            if_use_pca=True\n            if smpl_params.shape[1] == 123:\n                scale, transl, global_orient, pose, betas, left_hand_pose, right_hand_pose, jaw_pose, leye_pose, reye_pose, expression = torch.split(smpl_params, [1, 3, 3, 63, 10, 12, 12, 3, 3, 3, 10], dim=1)\n            else: # not use pca 12 , 165\n                scale, transl, global_orient, pose, betas, left_hand_pose, right_hand_pose, jaw_pose, leye_pose, reye_pose, expression = torch.split(smpl_params, [1, 3, 3, 63, 10, 45, 45, 3, 3, 3, 10], dim=1)\n                if_use_pca = False\n            smpl_params = {\n                'betas': betas.reshape(-1, 10),\n                'expression': expression.reshape(-1, 10),\n                'body_pose': pose.reshape(-1, 63),\n                'left_hand_pose': left_hand_pose.reshape(batchsize, -1),\n                'right_hand_pose': right_hand_pose.reshape(batchsize, -1),\n                'jaw_pose': jaw_pose.reshape(-1, 3),\n                'leye_pose': leye_pose.reshape(-1, 3),\n                'reye_pose': reye_pose.reshape(-1, 3),\n                'global_orient': global_orient.reshape(-1, 3),\n                'transl': transl.reshape(-1, 3),\n                'scale': 
scale.reshape(-1, 1)\n            }\n        \n        device = smpl_params[\"betas\"].device\n\n        if not self.initialized:\n            self.initialize(smpl_params[\"betas\"])\n            self.initialized = True\n    \n        smpl_outputs = self.body_model(**smpl_params, use_pca=if_use_pca)\n\n\n        self.smpl_outputs = smpl_outputs\n\n        tfs = (smpl_outputs.A) @ self.tfs_inv_t.expand(smpl_outputs.A.shape[0],-1,-1,-1)\n\n        self.tfs = tfs # self.tfs_A @ self.tfs_inv_t\n        self.tfs_A = smpl_outputs.A \n        # X_posed = smpl_outputs.A @ X_template, and (self.tfs_inv_t) @ X_tposed = X_template; \n        # so X_posed = (smpl_outputs.A @ self.tfs_inv_t) @ X_tposed == equal to ==> self.tfs_A @ self.tfs_inv_t @ X_tposed\n        self.shape_offset = torch.einsum('bl,mkl->bmk', [smpl_outputs.betas, self.init_spdir]) #  betas-torch.Size([1, 20]) ; init_spdir-([25254, 3, 20])\n        self.pose_offset = torch.matmul(smpl_outputs.pose_feature, self.init_podir).reshape(self.shape_offset.shape) # batch_size, ([1, 25254, 3])\n\n    def __call__(self, pts_in, rot_in, mask=None, cano=True, offset_gs=None, if_rotate_gaussian=False):\n        '''\n            to calculate the skinning results\n            pts_in (tensor, [bs, N, 3]): the canonical space points + offset_gs, represented a batch of clothed human\n            rot_in (tensor, [bs, N, 3]): the canonical space gaussians points' rotation\n            mask (tensor, [bs, N]): the mask of the vertices (face, hands), 1 for the vertices that use the skinning weights from template directly\n            cono (bool): if True, return the input pts directly\n            offset_gs (tensor, [bs, N, 3]): the estimated offset of the vertices in the canonical space\n\n            use some of the attributes from the \"prepare_deformer\" to calculate the skinning, including:\n            pose_offset[bs_pose, N, 3]\n            shape_offset[bs_pose, N, 3]\n\n        '''\n        pts = pts_in.clone()\n        rot = rot_in.clone()\n\n        if cano:\n            return pts, None\n        else:\n            init_faces = self.init_faces\n           \n        b, n, _ = pts.shape\n\n        smpl_nn = False\n\n        if smpl_nn:\n            # deformer based on SMPL nearest neighbor search\n            \n            k = 1\n            dist_sq, idx, neighbors = ops.knn_points(pts, self.smpl_outputs.vertices.float().expand(b, -1, -1), K=k, return_nn=True)\n\n            \n            dist = dist_sq.sqrt().clamp_(0.00003, 0.1)\n            weights = self.body_model.lbs_weights.clone()[idx]\n\n\n            ws=1./dist\n            ws=ws/ws.sum(-1,keepdim=True)\n            weights = (ws[..., None]*weights).sum(2).detach()\n\n            shape_offset = torch.cat([self.shape_offset[:, init_faces[..., i]] for i in range(3)], dim=1).mean(1)\n            pts += shape_offset\n            pts_cano_all, w_tf = skinning(pts, weights, self.tfs, inverse=False)\n            pts_cano_all = pts_cano_all.unsqueeze(2)\n            \n        else:\n            # defromer based on fast-SNARF\n            shape_offset = torch.cat([self.shape_offset[:, init_faces[..., i]] for i in range(3)], dim=1).mean(1)\n            pose_offset = torch.cat([self.pose_offset[:, init_faces[..., i]] for i in range(3)], dim=1).mean(1)\n            \n            pts_query_lbs = pts.detach() # T_pose + gs_offset\n\n            pts_cano_all, w_tf = self.deformer.forward_skinning(pts, shape_offset, pose_offset, cond=None, tfs=self.tfs_A, tfs_inv=self.tfs_inv_t, \\\n                      
                                          poseoff_ori=self.pose_offset_cano, lbsw=self.init_lbsw, mask=mask)\n\n        pts_cano_all = pts_cano_all.reshape(b, n, -1, 3)\n\n        if if_rotate_gaussian:\n            # rotate the gaussian points\n            # pts_cano_all =  rot\n            # rot_mats = quaternion_to_matrix(rot)\n            # rot_mats = torch.einsum('nxy,nyz->nxz', w_tf[..., :3, :3], rot_mats)\n            # rot_res = matrix_to_quaternion(rot_mats)\n            # return pts_cano_all, w_tf.clone(), rot_res\n            raise NotImplementedError(\"Code is not correct!\")\n           \n        \n        assert pts_in.dim() != 2\n        \n        return pts_cano_all, w_tf.clone()"
  },
  {
    "path": "lib/models/renderers/__init__.py",
    "content": "\nfrom .gau_renderer import GRenderer, get_covariance, batch_rodrigues\n__all__ = ['GRenderer']"
  },
  {
    "path": "lib/models/renderers/gau_renderer.py",
    "content": "from diff_gaussian_rasterization import (\n    GaussianRasterizationSettings,\n    GaussianRasterizer,\n)\nimport torch\nimport torch.nn as nn\n\n\ndef batch_rodrigues(rot_vecs, epsilon = 1e-8):\n    ''' Calculates the rotation matrices for a batch of rotation vectors\n        Parameters\n        ----------\n        rot_vecs: torch.tensor Nx3\n            array of N axis-angle vectors\n        Returns\n        -------\n        R: torch.tensor Nx3x3\n            The rotation matrices for the given axis-angle parameters\n    '''\n\n    batch_size = rot_vecs.shape[0]\n    device, dtype = rot_vecs.device, rot_vecs.dtype\n\n    angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True)\n    rot_dir = rot_vecs / angle\n\n    cos = torch.unsqueeze(torch.cos(angle), dim=1)\n    sin = torch.unsqueeze(torch.sin(angle), dim=1)\n\n    # Bx1 arrays\n    rx, ry, rz = torch.split(rot_dir, 1, dim=1)\n    K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device)\n\n    zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)\n    K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \\\n        .view((batch_size, 3, 3))\n\n    ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)\n    rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)\n    return rot_mat\n\ndef build_scaling_rotation(s, r, tfs):\n    L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device=s.device)\n    R = build_rotation(r)\n    R_ = R\n\n    L[:,0,0] = s[:,0]\n    L[:,1,1] = s[:,1]\n    L[:,2,2] = s[:,2]\n\n    L = R_ @ L\n    return L\n\ndef strip_symmetric(sym):\n    return strip_lowerdiag(sym)\n\ndef strip_lowerdiag(L):\n    uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device=L.device)\n\n    uncertainty[:, 0] = L[:, 0, 0]\n    uncertainty[:, 1] = L[:, 0, 1]\n    uncertainty[:, 2] = L[:, 0, 2]\n    uncertainty[:, 3] = L[:, 1, 1]\n    uncertainty[:, 4] = L[:, 1, 2]\n    uncertainty[:, 5] = L[:, 2, 2]\n    return uncertainty\n\ndef build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation, tfs):\n    L = build_scaling_rotation(scaling_modifier * scaling, rotation, tfs)\n    actual_covariance = L @ L.transpose(1, 2)\n    symm = strip_symmetric(actual_covariance)\n    return symm\n\ndef build_rotation(r):\n    norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])\n\n    q = r / norm[:, None]\n\n    R = torch.zeros((q.size(0), 3, 3), device=q.device)\n\n    r = q[:, 0]\n    x = q[:, 1]\n    y = q[:, 2]\n    z = q[:, 3]\n\n    R[:, 0, 0] = 1 - 2 * (y*y + z*z)\n    R[:, 0, 1] = 2 * (x*y - r*z)\n    R[:, 0, 2] = 2 * (x*z + r*y)\n    R[:, 1, 0] = 2 * (x*y + r*z)\n    R[:, 1, 1] = 1 - 2 * (x*x + z*z)\n    R[:, 1, 2] = 2 * (y*z - r*x)\n    R[:, 2, 0] = 2 * (x*z - r*y)\n    R[:, 2, 1] = 2 * (y*z + r*x)\n    R[:, 2, 2] = 1 - 2 * (x*x + y*y)\n    return R\n\ndef get_covariance(scaling, rotation, scaling_modifier = 1):\n    L = torch.zeros_like(rotation)\n    L[:, 0, 0] = scaling[:, 0]\n    L[:, 1, 1] = scaling[:, 1]\n    L[:, 2, 2] = scaling[:, 2]\n    actual_covariance = rotation @ (L**2) @ rotation.permute(0, 2, 1)\n    return strip_symmetric(actual_covariance)\nimport math\nclass GRenderer(nn.Module):\n    def __init__(self, image_size=256, anti_alias=False, f=5000, near=0.01, far=40, bg_color=0):\n        super().__init__()\n        self.anti_alias = anti_alias\n        self.image_size = image_size\n        self.tanfov = 2 * math.atan(self.image_size[0] / (2 * f)) \n        if bg_color == 0:\n            bg = 
torch.tensor([0, 0, 0], dtype=torch.float32)\n        else:\n            bg = torch.tensor([1, 1, 1], dtype=torch.float32)\n\n        self.register_buffer('bg', bg)\n        \n        opengl_proj = torch.tensor([[2 * f / self.image_size[0], 0.0, 0.0, 0.0],\n                                    [0.0, 2 * f / self.image_size[1], 0.0, 0.0],\n                                    [0.0, 0.0, far / (far - near), -(far * near) / (far - near)],\n                                    [0.0, 0.0, 1.0, 0.0]]).float().unsqueeze(0).transpose(1, 2)\n        self.register_buffer('opengl_proj', opengl_proj)\n\n        if anti_alias: image_size = [s*2 for s in image_size]\n        \n    def prepare(self, cameras):\n        if cameras.shape[-1] == 20: # use the new format: intrisic(fx, fy, cx, cy) + extrinsic(RT)\n            w2c = cameras[4:].reshape(4, 4)\n            cam_center = torch.inverse(w2c)[:3, 3]\n            intrisics = cameras[:4]\n\n            fov = get_fov(intrisics[0:2], intrisics[2].item(), self.image_size)    \n            tanfovx =  fov[1] \n            tanfovy = fov[1]\n            w2c = w2c.unsqueeze(0).transpose(1, 2)\n            proj_matrix = get_proj_yy(intrisics[0], self.image_size, 100, 0.01).to(torch.float32).to(intrisics.device)\n            full_proj = torch.bmm(w2c, proj_matrix).to(torch.float32)\n\n        elif cameras.shape[-1] == 19: # [:3] C, [3: ] RT\n            cam_center = cameras[:3] # C\n            w2c = cameras[3:].reshape(4, 4)\n            w2c = w2c.unsqueeze(0).transpose(1, 2) # RT\n            full_proj = w2c.bmm(self.opengl_proj).to(torch.float32)\n            self.full_proj = full_proj\n            tanfovx = self.tanfov\n            tanfovy = self.tanfov\n\n        self.raster_settings = GaussianRasterizationSettings(\n            image_height=self.image_size[1],\n            image_width=self.image_size[0],\n            tanfovx=tanfovx,\n            tanfovy=tanfovy,\n            bg=self.bg.to(cameras.dtype),\n            scale_modifier=1.0,\n            viewmatrix=w2c,\n            projmatrix=full_proj,\n            sh_degree=0,\n            campos=cam_center,\n            prefiltered=False,\n            debug=False,\n            antialiasing=True # NEW version of GS\n        )\n        self.rasterizer = GaussianRasterizer(raster_settings=self.raster_settings)\n        \n    def render_gaussian(self, means3D, colors_precomp, rotations, opacities, scales, cov3D_precomp=None):\n        '''\n        mode: normal, phong, texture\n        '''\n        screenspace_points = (\n            torch.zeros_like(\n                means3D, \n                dtype=means3D.dtype,\n                requires_grad=True,\n                device=means3D.device,\n            )\n            + 0\n        )\n\n        try:\n            screenspace_points.retain_grad()\n        except:\n            pass\n\n        if cov3D_precomp != None:\n            image, _, _= self.rasterizer(means3D=means3D, colors_precomp=colors_precomp, \\\n                opacities=opacities, means2D=screenspace_points, cov3D_precomp=cov3D_precomp)\n        else:\n            image, _, _ = self.rasterizer(means3D=means3D, colors_precomp=colors_precomp, \\\n                rotations=torch.nn.functional.normalize(rotations), opacities=opacities, scales=scales, \\\n                means2D=screenspace_points)\n            \n        return  image\n    \n\ndef get_view_matrix(R, t):\n    Rt = torch.cat((R, t.view(3,1)),1)\n    view_matrix = torch.cat((Rt, torch.FloatTensor([0,0,0,1]).cuda().view(1,4)))\n    return 
view_matrix\n\ndef get_proj_yy(f, image_size, far, near):\n    opengl_proj = torch.tensor([[2 * f / image_size[0], 0.0, 0.0, 0.0],\n                            [0.0, 2 * f / image_size[1], 0.0, 0.0],\n                            [0.0, 0.0, far / (far - near), -(far * near) / (far - near)],\n                            [0.0, 0.0, 1.0, 0.0]]).float().unsqueeze(0).transpose(1, 2)\n    return opengl_proj\ndef get_proj_matrix(fovY,fovX, z_near, z_far, z_sign):\n    tanHalfFovY = math.tan((fovY / 2))\n    tanHalfFovX = math.tan((fovX / 2))\n\n    top = tanHalfFovY * z_near\n    bottom = -top\n    right = tanHalfFovX * z_near\n    left = -right\n    z_sign = 1.0\n\n    proj_matrix = torch.zeros(4, 4).float().cuda()\n    proj_matrix[0, 0] = 2.0 * z_near / (right - left)\n    proj_matrix[1, 1] = 2.0 * z_near / (top - bottom)\n    proj_matrix[0, 2] = (right + left) / (right - left)\n    proj_matrix[1, 2] = (top + bottom) / (top - bottom)\n    proj_matrix[3, 2] = z_sign\n    proj_matrix[2, 2] = z_sign * z_far / (z_far - z_near)\n    proj_matrix[2, 3] = -(z_far * z_near) / (z_far - z_near)\n    return proj_matrix\n\ndef get_fov(focal, princpt, img_shape):\n    fov_x = 2 * torch.atan(img_shape[1] / (2 * focal[0]))\n    fov_y = 2 * torch.atan(img_shape[0] / (2 * focal[1]))\n    fov = torch.FloatTensor([fov_x, fov_y]).cuda()\n    return fov"
  },
  {
    "path": "lib/models/sapiens/__init__.py",
    "content": "from .sapiens_wrapper_torchscipt import SapiensWrapper_ts"
  },
  {
    "path": "lib/models/sapiens/sapiens_wrapper_torchscipt.py",
    "content": "# Copyright (c) 2023, Zexin He\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nimport torch.nn as nn\nfrom transformers import Dinov2Backbone\nfrom torchvision import transforms\nfrom einops import rearrange\n\nfrom torch import Tensor\n\ndef pretrain_forward(sp_lite, inputs: Tensor, layer_num: int, return_hidden_states=False) -> Tensor:\n    B = inputs.size(0)\n    patch_embed_output, _50, _51, _52, _53 = sp_lite.backbone.patch_embed(inputs)\n    cls_token = sp_lite.backbone.cls_token.expand(B, -1, -1)\n    x = torch.cat([cls_token, patch_embed_output], dim=1)\n\n    cls_pos_embed, patch_pos_embed = sp_lite.backbone.pos_embed[:, 0:1, :], sp_lite.backbone.pos_embed[:, 1:, :]\n    \n    dim = cls_pos_embed.shape[-1]\n    #64x64\n    patch_pos_embed = patch_pos_embed.reshape(-1, 64, 64, dim)\n    patch_pos_embed_ = patch_pos_embed.permute(0, 3, 1, 2)\n    patch_pos_embed = torch.nn.functional.interpolate(\n        patch_pos_embed_,\n        size = (_52, _53),\n        mode=\"bicubic\",\n        align_corners=False,\n    )\n\n    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(-1, _52 * _53, dim)\n    patch_pos_embed =  torch.cat([cls_pos_embed, patch_pos_embed], dim=1)\n    \n    x = x + patch_pos_embed\n    x = sp_lite.backbone.drop_after_pos(x)\n    if return_hidden_states:\n        hidden_states = []\n        hidden_states.append(x)\n    for i in range(layer_num):\n        x = getattr(sp_lite.backbone.layers, str(i))(x)\n        hidden_states.append(x)\n\n    x = sp_lite.backbone.ln1(x)\n    \n    cls_output = x[:, 0]  # Assuming class token is at index 0\n    patch_tokens = x[:, 1:]  # Remaining are patch tokens\n\n    output = patch_tokens.view(B, _52, _53, -1).permute(0, 3, 1, 2)\n    if return_hidden_states:\n        return output, hidden_states\n    return output\nclass SapiensWrapper_ts(nn.Module):\n\n    \"\"\"\n    Sapiens wrapper using huggingface transformer implementation.\n    \"\"\"\n    def __init__(self,\n                 model_path: str = 'facebook/dinov2-base',\n                 freeze=True,\n                 img_size=None,\n                 layer_num=None):\n        super().__init__()\n        if layer_num == None:\n            if \"0.3b\" in model_path:\n                self.layer_num = 24\n            else:\n                self.layer_num = 48\n        else:\n            self.layer_num = layer_num\n        self.model = torch.jit.load(model_path)\n        if img_size is None:\n            self.my_processor = transforms.Compose([\n                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n                ])\n        else:\n            self.my_processor = transforms.Compose([\n                transforms.Resize(size=img_size),\n                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n                ])\n        self.interpolate_pos_encoding=True\n        if freeze:\n            self._freeze()\n    def forward(self, image, use_my_proces=False, 
requires_grad=False, output_hidden_states=False):\n        # image: [B, N, C, H, W]\n        # RGB image with [0,1] scale and properly sized\n        if image.ndim == 5:\n            B, N, _, H, W = image.shape\n            mv = True\n            image = image.flatten(0, 1)\n        if image.ndim == 4:\n            N, _, H, W = image.shape\n            B = None\n        else:\n\n            raise NotImplementedError\n        device = image.device\n        if not use_my_proces:\n            inputs = self.image_processor(image, return_tensors=\"pt\") \n            inputs['pixel_values'] = inputs['pixel_values'].to(device) \n        else:\n            inputs = self.my_processor(image)\n            inputs = {'pixel_values': inputs}\n           \n\n        if requires_grad==False:\n            with torch.no_grad():\n                outputs = pretrain_forward(self.model, inputs['pixel_values'], layer_num=self.layer_num, return_hidden_states=output_hidden_states)\n        else:\n            outputs = pretrain_forward(self.model, inputs['pixel_values'], layer_num=self.layer_num, return_hidden_states=output_hidden_states)\n        last_feature_map = outputs[0]\n\n\n        if not output_hidden_states:\n            if B is  None: # dim = 5 \n                last_feature_map = rearrange(last_feature_map, 'n dim h w -> n (h w) dim') # N, N_tk, C\n            else:\n                last_feature_map = rearrange(last_feature_map, 'bn  dim h w -> bn (h w) dim')\n                last_feature_map = last_feature_map.reshape(B, N, last_feature_map.shape[-2], last_feature_map.shape[-1])\n\n        if output_hidden_states:\n            hidden_states = torch.stack(outputs[1], 0).permute(1, 0, 2, 3) # N, N_layer, N_tk, C\n            hidden_states = hidden_states[:, :, 1:,:]  # N, N_layer, N_tk, C\n        if output_hidden_states:\n            return hidden_states\n        else:\n            return last_feature_map\n\n    def _freeze(self):\n        print(f\"======== Freezing DinoWrapper ========\")\n        self.model.eval()\n        for name, param in self.model.named_parameters():\n            param.requires_grad = False\n\nif __name__ == \"__main__\":\n    model = SapiensWrapper_ts()\n    model.eval()\n    image = torch.rand(1, 3, 896, 640)\n    output = model(image, use_my_proces=True,  output_hidden_states=True)\n    output =  pretrain_forward(model, image, layer_num=24)\n    print(output)\n    print(\"done\")"
  },
  {
    "path": "lib/models/transformer_sa/__init__.py",
    "content": "from .mae_decoder_v3_skip import neck_SA_v3_skip"
  },
  {
    "path": "lib/models/transformer_sa/mae_decoder_v3_skip.py",
    "content": "import torch\nimport torch.nn as nn\nimport numpy as np\nfrom timm.models.vision_transformer import PatchEmbed, Block, checkpoint_seq\nfrom typing import Union\n\n\n\ndef get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):\n    \"\"\"\n    Create 2D sin/cos positional embeddings.\n\n    Args:\n        embed_dim (`int`):\n            Embedding dimension.\n        grid_size (`int`):\n            The grid height and width.\n        add_cls_token (`bool`, *optional*, defaults to `False`):\n            Whether or not to add a classification (CLS) token.\n\n    Returns:\n        (`torch.FloatTensor` of shape (grid_size*grid_size, embed_dim) or (1+grid_size*grid_size, embed_dim): the\n        position embeddings (with or without classification token)\n    \"\"\"\n    grid_h = np.arange(grid_size, dtype=np.float32)\n    grid_w = np.arange(grid_size, dtype=np.float32)\n    grid = np.meshgrid(grid_w, grid_h)  # here w goes first\n    grid = np.stack(grid, axis=0)\n\n    grid = grid.reshape([2, 1, grid_size, grid_size])\n    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)\n    if add_cls_token:\n        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)\n    return pos_embed\n\n\ndef get_2d_sincos_pos_embed_from_grid(embed_dim, grid):\n    if embed_dim % 2 != 0:\n        raise ValueError(\"embed_dim must be even\")\n\n    # use half of dimensions to encode grid_h\n    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)\n    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)\n\n    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)\n    return emb\n\n\ndef get_1d_sincos_pos_embed_from_grid(embed_dim, pos):\n    \"\"\"\n    embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)\n    \"\"\"\n    if embed_dim % 2 != 0:\n        raise ValueError(\"embed_dim must be even\")\n\n    omega = np.arange(embed_dim // 2, dtype=float)\n    omega /= embed_dim / 2.0\n    omega = 1.0 / 10000**omega  # (D/2,)\n\n    pos = pos.reshape(-1)  # (M,)\n    out = np.einsum(\"m,d->md\", pos, omega)  # (M, D/2), outer product\n\n    emb_sin = np.sin(out)  # (M, D/2)\n    emb_cos = np.cos(out)  # (M, D/2)\n\n    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)\n    return emb\n\n\n    \nclass neck_SA_v3_skip(nn.Module):\n    def __init__(self, patch_size=4, in_chans=32, num_patches=196, embed_dim=1024, decoder_embed_dim=512, \n                 decoder_depth=8, decoder_num_heads=16, mlp_ratio=4, norm_layer=nn.LayerNorm, total_num_hidden_states=25, connect_mode:Union['uniform', 'zeros', 'shadow']='uniform', if_checkpoint_seq=False):\n        super().__init__()\n        self.num_patches = num_patches\n        # Decoder-specific\n\n        self.if_checkpoint_seq = if_checkpoint_seq # to save the memory\n\n\n        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))\n\n        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches, decoder_embed_dim), requires_grad=True)  # fixed sin-cos embedding\n\n        self.decoder_blocks_depart = nn.ModuleList([\n            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True,  norm_layer=norm_layer) #qk_scale=None\n            for i in range(decoder_depth)])\n\n        self.decoder_norm = norm_layer(decoder_embed_dim)\n        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch\n\n        if 
connect_mode == 'uniform':\n            skip = total_num_hidden_states// (decoder_depth-1)\n            self.select_hidden_states = [skip*i for i in range(decoder_depth)] # for 25\n            self.select_hidden_states[-1] = total_num_hidden_states - 1 \n            self.select_hidden_states = self.select_hidden_states[::-1] # inverse the order\n        elif connect_mode == 'zeros':\n            # print('!!!!!!!!!!! zeros !!!!!!!!!!!!!!!!')\n            self.select_hidden_states = [0, 0, 0, 0, 0, 0]\n        self.decoder_embed = nn.ModuleList([\n            nn.Linear(embed_dim, decoder_embed_dim, bias=True)\n            for _ in range(decoder_depth) \n        ])\n        self.initialize_weights()\n\n    def initialize_weights(self):\n        # Initialization\n        # Initialize (and freeze) pos_embed by sin-cos embedding\n        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int((self.num_patches)**.5), add_cls_token=False)\n        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))\n\n        # Initialize nn.Linear and nn.LayerNorm\n        self.apply(self._init_weights)\n       \n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            # We use xavier_uniform following official JAX ViT:\n            torch.nn.init.xavier_uniform_(m.weight)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    def forward_decoder(self, in_features, ids_restore):\n        # Embed tokens\n        B, N_l, N_f, C = in_features.shape\n        select_in_features = in_features[:, self.select_hidden_states, :, :]\n        # parallelly embed the hidden states wiht jit\n        forks = [torch.jit.fork(self.decoder_embed[i], select_in_features[:, i]) for i in range(len(self.select_hidden_states))]\n        x_list = [torch.jit.wait(fork) for fork in forks]\n        x_all_states = torch.stack(x_list) # N_l, B, N_feat, C\n\n\n        # Add pos embed\n        mask_tokens = self.mask_token.repeat(B, ids_restore.shape[1], 1) # B, N_q, C\n        query_x = mask_tokens + self.decoder_pos_embed\n\n        # Append mask tokens to sequence\n        x = torch.zeros_like(x_all_states[0])\n        x = torch.cat([x, query_x], dim=1)  # no cls token # B, N_q+N_f, C\n        # # Apply Transformer blocks # v0\n\n        # Apply Transformer blocks # v1\n        for i, blk in enumerate( self.decoder_blocks_depart):\n            x_add = x_all_states[i]\n            x[:, :N_f, :] += x_add # add the hidden states\n            x = blk(x)\n\n        x = self.decoder_norm(x)\n\n\n        \n        x = x[:, -self.num_patches:, :]\n        x_reshaped = x\n\n        return x_reshaped\n\n    def forward(self, encoded_latent, ids_restore):\n        decoded_output = self.forward_decoder(encoded_latent, ids_restore)\n        return decoded_output"
  },
  {
    "path": "lib/ops/__init__.py",
    "content": "from .activation import TruncExp\n"
  },
  {
    "path": "lib/ops/activation.py",
    "content": "import math\nimport torch\nimport torch.nn as nn\nfrom torch.autograd import Function\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\n\nclass _trunc_exp(Function):\n    @staticmethod\n    @custom_fwd(cast_inputs=torch.float32)  # cast to float32\n    def forward(ctx, x):\n        exp_x = torch.exp(x)\n        ctx.save_for_backward(exp_x)\n        return exp_x\n\n    @staticmethod\n    @custom_bwd\n    def backward(ctx, g):\n        exp_x = ctx.saved_tensors[0]\n        return g * exp_x.clamp(min=1e-6, max=1e6)\n\n\ntrunc_exp = _trunc_exp.apply\n\n\nclass TruncExp(nn.Module):\n\n    @staticmethod\n    def forward(x):\n        return _trunc_exp.apply(x)\n"
  },
  {
    "path": "lib/utils/infer_util.py",
    "content": "import os\nimport imageio\nimport rembg\nimport torch\nimport numpy as np\nimport PIL.Image\nfrom PIL import Image\nfrom typing import Any\nimport json\n\nfrom pathlib import Path\nfrom torchvision.transforms import ToTensor\nfrom rembg import remove  # For background removal\nfrom pytorch3d.transforms import axis_angle_to_matrix, matrix_to_axis_angle\nfrom lib.models.deformers.smplx.lbs import batch_rodrigues\nimport cv2\nfrom PIL import Image\nimport numpy as np\n\nimport json\n# import random\nimport math\n# import av\n\n\ndef reset_first_frame_rotation(root_orient, trans):\n    \"\"\"\n    Set the root_orient rotation matrix of the first frame to the identity matrix (no rotation),\n    keep the relative rotation relationships of other frames, and adjust trans accordingly.\n\n    Parameters:\n        root_orient: Tensor of shape (N, 3), representing the axis-angle parameters for N frames.\n        trans: Tensor of shape (N, 3), representing the translation parameters for N frames.\n\n    Returns:\n        new_root_orient: Tensor of shape (N, 3), adjusted axis-angle parameters.\n        new_trans: Tensor of shape (N, 3), adjusted translation parameters.\n    \"\"\"\n    # Convert the root_orient of the first frame to a rotation matrix\n    R_0 = axis_angle_to_matrix(root_orient[0:1])  # Shape: (1, 3, 3)\n\n    # Compute the inverse of the first frame's rotation matrix\n    R_0_inv = torch.inverse(R_0)  # Shape: (1, 3, 3)\n\n    # Initialize lists for new root_orient and trans\n    new_root_orient = []\n    new_trans = []\n\n    for i in range(root_orient.shape[0]):\n        # Rotation matrix of the current frame\n        R_i = axis_angle_to_matrix(root_orient[i:i+1])  # Shape: (1, 3, 3)\n        R_new = torch.matmul(R_0_inv, R_i)  # Shape: (1, 3, 3)\n        \n        # Convert the rotation matrix back to axis-angle representation\n        axis_angle_new = matrix_to_axis_angle(R_new)  # Shape: (1, 3)\n        new_root_orient.append(axis_angle_new)\n        \n        # Adjust the translation for the current frame\n        trans_i = trans[i:i+1]  # Shape: (1, 3)\n        trans_new = torch.matmul(R_0_inv, trans_i.T).T  # Shape: (1, 3)\n        new_trans.append(trans_new)\n    \n    # Stack the results of new_root_orient and new_trans\n    new_root_orient = torch.cat(new_root_orient, dim=0)  # Shape: (N, 3)\n    new_trans = torch.cat(new_trans, dim=0)  # Shape: (N, 3)\n    \n    # Adjust the new translations relative to the first frame\n    new_trans = new_trans - new_trans[[0], :]  \n\n    return new_root_orient, new_trans \n\nfrom scipy.spatial.transform import Rotation\ndef rotation_matrix_to_rodrigues(rotation_matrices):\n    # reshape rotation_matrices to (-1, 3, 3)\n    reshaped_matrices = rotation_matrices.reshape(-1, 3, 3)\n    rotation = Rotation.from_matrix(reshaped_matrices)\n    rodrigues_vectors = rotation.as_rotvec()\n    return rodrigues_vectors\n\n\n\ndef get_hand_pose_mean():\n    import numpy as np\n    hand_pose_mean=  np.array([[ 0.11167871,  0.04289218, -0.41644183,  0.10881133, -0.06598568,\n        -0.75622   , -0.09639297, -0.09091566, -0.18845929, -0.11809504,\n         0.05094385, -0.5295845 , -0.14369841,  0.0552417 , -0.7048571 ,\n        -0.01918292, -0.09233685, -0.3379135 , -0.45703298, -0.19628395,\n        -0.6254575 , -0.21465237, -0.06599829, -0.50689423, -0.36972436,\n        -0.06034463, -0.07949023, -0.1418697 , -0.08585263, -0.63552827,\n        -0.3033416 , -0.05788098, -0.6313892 , -0.17612089, -0.13209307,\n        -0.37335458,  
0.8509643 ,  0.27692273, -0.09154807, -0.49983943,\n         0.02655647,  0.05288088,  0.5355592 ,  0.04596104, -0.27735803,\n         0.11167871, -0.04289218,  0.41644183,  0.10881133,  0.06598568,\n         0.75622   , -0.09639297,  0.09091566,  0.18845929, -0.11809504,\n        -0.05094385,  0.5295845 , -0.14369841, -0.0552417 ,  0.7048571 ,\n        -0.01918292,  0.09233685,  0.3379135 , -0.45703298,  0.19628395,\n         0.6254575 , -0.21465237,  0.06599829,  0.50689423, -0.36972436,\n         0.06034463,  0.07949023, -0.1418697 ,  0.08585263,  0.63552827,\n        -0.3033416 ,  0.05788098,  0.6313892 , -0.17612089,  0.13209307,\n         0.37335458,  0.8509643 , -0.27692273,  0.09154807, -0.49983943,\n        -0.02655647, -0.05288088,  0.5355592 , -0.04596104,  0.27735803]])\n    return hand_pose_mean\n\n\ndef load_smplify_json(smplx_smplify_path):\n    with open(smplx_smplify_path) as f:\n        data = json.load(f)\n    \n    # Prepare camera transformation matrix (R | t)\n    RT = torch.concatenate([torch.Tensor(data['camera']['R']), torch.Tensor(data['camera']['t']).reshape(3, 1) * 2], dim=1)\n    RT = torch.cat([RT, torch.Tensor([[0, 0, 0, 1]])], dim=0)\n\n    # Create intrinsic parameters tensor\n    intri = torch.Tensor(data['camera']['focal'] + data['camera']['princpt'])\n    # intri[[3, 2]] = intri[[2, 3]]\n\n    # # Set default focal length and image resolution\n    # default_focal = 1120  # Default focal length\n    # img_res = [640, 896]\n    # default_fxy_cxy = torch.tensor([default_focal, default_focal, img_res[1] // 2, img_res[0] // 2]).reshape(1, 4)\n\n    # # Adjust intrinsic parameters based on default focal and resolution\n    # intri = intri * default_fxy_cxy[0, -2] / intri[-2]  \n    # intri[-2:] = default_fxy_cxy[0, -2:]  # Force consistent image width and height\n\n    # Extract SMPL parameters from data\n    smpl_param_data = data \n    global_orient = np.array(smpl_param_data['root_pose']).reshape(1, -1)\n    body_pose = np.array(smpl_param_data['body_pose']).reshape(1, -1)\n    shape = np.array(smpl_param_data['betas_save']).reshape(1, -1)[:, :10]\n    left_hand_pose = np.array(smpl_param_data['lhand_pose']).reshape(1, -1)\n    right_hand_pose = np.array(smpl_param_data['rhand_pose']).reshape(1, -1)\n\n    # Concatenate all parameters into a single tensor for SMPL model\n    smpl_param_ref = np.concatenate([np.array([[1.]]), np.array(smpl_param_data['trans']).reshape(1, 3),\n        global_orient, body_pose, shape, left_hand_pose, right_hand_pose,\n        np.array(smpl_param_data['jaw_pose']).reshape(1, -1),\n        np.zeros_like(np.array(smpl_param_data['leye_pose']).reshape(1, -1)),\n        np.zeros_like(np.array(smpl_param_data['reye_pose']).reshape(1, -1)),\n        np.zeros_like(np.array(smpl_param_data['expr']).reshape(1, -1)[:, :10])], axis=1)\n\n    return RT, intri, torch.Tensor(smpl_param_ref).reshape(-1)  # Return transformation, intrinsic, and SMPL parameters\n\ndef load_image(input_path, output_folder, image_frame_ratio=None):\n    input_img_path = Path(input_path)\n\n    vids = []\n    save_path = os.path.join(output_folder, f\"{input_img_path.name}\")\n    print(f\"Processing: {save_path}\")\n    image = Image.open(input_img_path)\n\n    if image.mode == \"RGBA\":\n        pass\n    else:\n        # remove bg\n        image = remove(image.convert(\"RGBA\"), alpha_matting=True)\n\n    # resize object in frame\n    image_arr = np.array(image)\n    in_w, in_h = image_arr.shape[:2]\n    ret, mask = cv2.threshold(\n        
np.array(image.split()[-1]), 0, 255, cv2.THRESH_BINARY\n    )\n    x, y, w, h = cv2.boundingRect(mask)\n    max_size = max(w, h)\n    side_len = (\n        int(max_size / image_frame_ratio)\n        if image_frame_ratio is not None\n        else int(max_size / 0.85)\n    )\n    padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)\n    center = side_len // 2\n    padded_image[\n        center - h // 2 : center - h // 2 + h,\n        center - w // 2 : center - w // 2 + w,\n    ] = image_arr[y : y + h, x : x + w]\n    rgba = Image.fromarray(padded_image).resize((896, 896), Image.LANCZOS)\n    # crop the width into 640 in the center\n    rgba = rgba.crop([128, 0, 640+128, 896])\n    # white bg\n    rgba_arr = np.array(rgba) / 255.0\n    rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:])\n    input_image = Image.fromarray((rgb * 255).astype(np.uint8))\n\n    image = ToTensor()(input_image)\n\n    return image\n\n\n\ndef prepare_camera( resolution_x = 640, resolution_y = 640, focal_length = 600,sensor_width = 32,  camera_dist = 20, num_views=1, stides=1):\n        \n    def look_at(camera_position, target_position, up_vector):  # colmap +z forward, +y down\n        forward = -(camera_position - target_position) / np.linalg.norm(camera_position - target_position)\n        right = np.cross(up_vector, forward)\n        up = np.cross(forward, right)\n        return np.column_stack((right, up, forward))\n\n    # set the intrisics\n    focal_length = focal_length * (resolution_y/sensor_width)\n\n    K = np.array(\n        [[focal_length, 0, resolution_x//2],\n        [0, focal_length, resolution_y//2],\n        [0, 0, 1]]\n    )\n\n    # set the extrisics\n    camera_pose_list = []\n    for frame_idx in range(0, num_views, stides):\n\n        phi = math.radians(90)\n        theta = (3 / 4) * math.pi * 2\n        camera_location = np.array(\n            [camera_dist * math.sin(phi) * math.cos(theta),\n            \n            camera_dist * math.cos(phi),\n            -camera_dist * math.sin(phi) * math.sin(theta),]\n            )\n        # print(camera_location)\n        camera_pose = np.eye(4)\n        camera_pose[:3, 3] = camera_location\n\n        # Set camera position and target position\n        camera_position = camera_location\n        target_position = np.array([0.0, 0.0, 0.0])\n\n        # Compute the camera's rotation matrix to look at the target\n        up_vector = np.array([0.0, -1.0, 0.0]) # colmap\n        rotation_matrix = look_at(camera_position, target_position, up_vector)\n\n        # Update camera position and rotation\n        camera_pose[:3, :3] = rotation_matrix\n        camera_pose[:3, 3] = camera_position\n        camera_pose_list.append(camera_pose)\n    return K, camera_pose_list\n\n\ndef construct_camera(K, cam_list, device='cuda'):\n    num_imgs = len(cam_list)\n    front_idx = num_imgs//4*3\n    cam_list = cam_list[front_idx:] + cam_list[:front_idx] \n    cam_raw = np.array(cam_list)\n    cam_raw[:, :3, 3] = cam_raw[:, :3, 3] \n    cam = np.linalg.inv(cam_raw)\n    cam = torch.Tensor(cam)\n    intrics = torch.Tensor([K[0,0],K[1,1], K[0,2], K[1,2]]).reshape(-1)\n    scale = 0.5\n    # diffrent from the synthetic data, the scale is process first\n    # trans from (3,) to (batch_size, 3,1)\n    trans = [0, 0.2, 0] #in the center\n    trans_bt = torch.Tensor(trans).reshape(1, 3, 1).expand(cam.shape[0], 3, 1)\n    cam[:,:3,3] = cam[:,:3,3] + torch.bmm(cam[:,:3,:3], trans_bt).reshape(-1, 3) # T = Rt+T torch.Size([24, 3, 1])\n    cam[:,:3,:3] = 
cam[:,:3,:3] * scale  # R = sR\n    cam_c2w = torch.inverse(cam)\n    cam_w2c = cam\n    poses = []\n    for i_cam in range(cam.shape[0]):\n        poses.append( torch.concat([\n            (intrics.reshape(-1)).to(torch.float32), #C ! # C ? T 理论上要给C\n            (cam_w2c[i_cam]).to(torch.float32).reshape(-1), # RT  #Rt|C ? RT 理论上要给RT\n        ], dim=0))\n    cameras = torch.stack(poses).to(device) # [N, 19]\n    return cameras\n\ndef get_name_str(name):\n    path_ = os.path.basename(os.path.dirname(name)) + os.path.basename(name)\n    return path_\n\n\n\ndef load_smplx_from_npy(smplx_path, device='cuda'):\n    hand_mean = get_hand_pose_mean().reshape(-1)\n    smplx_pose_param = np.load(smplx_path, allow_pickle=True)\n    # if \"person1\" in smplx_pose_param:\n    #     smplx_pose_param = smplx_pose_param['person1']\n    smplx_pose_param = {\n        'root_orient': smplx_pose_param[:, :3],  # controls the global root orientation\n        'pose_body': smplx_pose_param[:, 3:3+63],  # controls the body\n        'pose_hand': smplx_pose_param[:, 66:66+90],  # controls the finger articulation\n        'pose_jaw': smplx_pose_param[:, 66+90:66+93],  # controls the yaw pose\n        'face_expr': smplx_pose_param[:, 159:159+50],  # controls the face expression\n        'face_shape': smplx_pose_param[:, 209:209+100],  # controls the face shape\n        'trans': smplx_pose_param[:, 309:309+3],  # controls the global body position\n        'betas': smplx_pose_param[:, 312:],  # controls the body shape. Body shape is static\n    }\n\n    smplx_param_list = []\n    for i in range(1, 1800, 1):\n        # for i in k.keys():\n        #     k[i] = np.array(k[i])\n        left_hands = np.array([1.4624, -0.1615,  0.1361,  1.3851, -0.2597,  0.0247, -0.0683, -0.4478,\n            -0.6652, -0.7290,  0.0084, -0.4818])\n        betas = torch.zeros((10))\n        smplx_param = \\\n            np.concatenate([np.array([1]), smplx_pose_param['trans'][i], smplx_pose_param['root_orient'][i], \\\n                            smplx_pose_param['pose_body'][i],betas, \\\n                                smplx_pose_param['pose_hand'][i]-hand_mean, smplx_pose_param['pose_jaw'][i], np.zeros(6), smplx_pose_param['face_expr'][i][:10]], axis=0).reshape(1,-1)\n        smplx_param_list.append(smplx_param)\n    smplx_params = np.concatenate(smplx_param_list, 0)\n    smpl_params = torch.Tensor(smplx_params).to(device)\n    return smpl_params\ndef add_root_rotate_to_smplx(smpl_tmp, frames_num=180, device='cuda'):\n    from cv2 import Rodrigues\n    initial_matrix = batch_rodrigues(smpl_tmp.reshape(1,189)[:, 4:7]).cpu().numpy().copy()\n    # Rotate a rotation matrix by 360 degrees around the y-axis.\n    # frames_num = 180\n    all_smpl = []\n    # Combine the rotations\n    all_smpl = []\n    for idx_f in range(frames_num):\n        new_smpl = smpl_tmp.clone() \n        angle = 360//frames_num * idx_f\n        y_angle = np.radians(angle)\n        y_rotation_matrix = np.array([\n            [ np.cos(y_angle),0,  np.sin(y_angle)],\n            [0,  1, 0],\n            [-np.sin(y_angle), 0, np.cos(y_angle)],\n        ])\n        final_matrix = y_rotation_matrix[None] @ initial_matrix\n        \n        new_smpl[4:7] = torch.Tensor(rotation_matrix_to_rodrigues(torch.Tensor(final_matrix))).to(device)\n        all_smpl.append(new_smpl)\n    all_smpl = torch.stack(all_smpl, 0)\n    smpl_params = all_smpl.to(device)\n    return smpl_params\n\ndef load_smplx_from_json(smplx_path, device='cuda'):\n    # format of motion-x\n    hand_mean = 
get_hand_pose_mean().reshape(-1)\n    with open(smplx_path, 'r') as f:\n        smplx_pose_param = json.load(f)\n    smplx_param_list = []\n    for par in smplx_pose_param['annotations']:\n        k = par['smplx_params']\n        for i in k.keys():\n            k[i] = np.array(k[i])\n\n        betas = torch.zeros((10))\n        # #########   wrist pose fix ################\n        smplx_param = \\\n            np.concatenate([np.array([1]), k['trans'], \n                            k['root_orient']*np.array([1, 1, 1]), \\\n                            k['pose_body'],betas, \\\n                            k['pose_hand']-hand_mean, k['pose_jaw'], np.zeros(6), np.zeros_like(k['face_expr'][:10])], axis=0).reshape(1,-1)\n        smplx_param_list.append(smplx_param)\n\n\n    smplx_params = np.concatenate(smplx_param_list, 0)\n    print(smplx_params.shape)\n    smpl_params = torch.Tensor(smplx_params).to(device)\n    return smpl_params\n\ndef get_image_dimensions(input_path):\n    with Image.open(input_path) as img:\n        return img.height, img.width  \n\ndef construct_camera_from_motionx(smplx_path, device='cuda'):\n    with open(smplx_path, 'r') as f:\n        smplx_pose_param = json.load(f)\n    cam_exts = []\n    cam_ints = []\n    for par in smplx_pose_param['annotations']:\n        cam = par['cam_params']\n        R = np.array(cam['cam_R'])\n        K = np.array(cam['intrins'])\n        T = np.array(cam['cam_T']) \n        cam['cam_T'][1] = -cam['cam_T'][1]\n        cam['cam_T'][2] = -cam['cam_T'][2]\n        extrix = np.eye(4)\n        extrix[:3, :3] = R\n        extrix[:3,3] = T\n        cam_exts.append(extrix)\n        intrix = K\n        cam_ints.append(intrix)\n\n    # target N,20\n    cam_exts_array = np.array(cam_exts)\n\n    cam_exts_stack = torch.Tensor(cam_exts_array).to(device).reshape(-1, 16)\n    cam_ints_stack = torch.Tensor(cam_ints).to(device).reshape(-1, 4)\n    cameras = torch.cat([cam_ints_stack, cam_exts_stack], dim=-1).reshape(-1,1, 20)\n    return cameras\n\ndef remove_background(image: PIL.Image.Image,\n    rembg_session: Any = None,\n    force: bool = False,\n    **rembg_kwargs,\n) -> PIL.Image.Image:\n    do_remove = True\n    if image.mode == \"RGBA\" and image.getextrema()[3][0] < 255:\n        do_remove = False\n    do_remove = do_remove or force\n    if do_remove:\n        image = rembg.remove(image, session=rembg_session, **rembg_kwargs)\n    return image\n\n\ndef resize_foreground(\n    image: PIL.Image.Image,\n    ratio: float,\n) -> PIL.Image.Image:\n    image = np.array(image)\n    assert image.shape[-1] == 4\n    alpha = np.where(image[..., 3] > 0)\n    y1, y2, x1, x2 = (\n        alpha[0].min(),\n        alpha[0].max(),\n        alpha[1].min(),\n        alpha[1].max(),\n    )\n    # crop the foreground\n    fg = image[y1:y2, x1:x2]\n    # pad to square\n    size = max(fg.shape[0], fg.shape[1])\n    ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2\n    ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0\n    new_image = np.pad(\n        fg,\n        ((ph0, ph1), (pw0, pw1), (0, 0)),\n        mode=\"constant\",\n        constant_values=((0, 0), (0, 0), (0, 0)),\n    )\n\n    # compute padding according to the ratio\n    new_size = int(new_image.shape[0] / ratio)\n    # pad to size, double side\n    ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2\n    ph1, pw1 = new_size - size - ph0, new_size - size - pw0\n    new_image = np.pad(\n        new_image,\n        ((ph0, ph1), (pw0, pw1), (0, 0)),\n        
mode=\"constant\",\n        constant_values=((0, 0), (0, 0), (0, 0)),\n    )\n    new_image = PIL.Image.fromarray(new_image)\n    return new_image\n\n\ndef images_to_video(\n    images: torch.Tensor, \n    output_path: str, \n    fps: int = 30,\n) -> None:\n    # images: (N, C, H, W)\n    video_dir = os.path.dirname(output_path)\n    video_name = os.path.basename(output_path)\n    os.makedirs(video_dir, exist_ok=True)\n\n    frames = []\n    for i in range(len(images)):\n        frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)\n        assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \\\n            f\"Frame shape mismatch: {frame.shape} vs {images.shape}\"\n        assert frame.min() >= 0 and frame.max() <= 255, \\\n            f\"Frame value out of range: {frame.min()} ~ {frame.max()}\"\n        frames.append(frame)\n    imageio.mimwrite(output_path, np.stack(frames), fps=fps, quality=10)\n\n\ndef save_video(\n    frames: torch.Tensor,\n    output_path: str,\n    fps: int = 30,\n) -> None:\n    # images: (N, C, H, W)\n    frames = [(frame.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) for frame in frames]\n    writer = imageio.get_writer(output_path, fps=fps)\n    for frame in frames:\n        writer.append_data(frame)\n    writer.close()"
  },
  {
    "path": "lib/utils/mesh.py",
    "content": "import os\nimport cv2\nimport torch\nimport trimesh\nimport numpy as np\n\ndef dot(x, y):\n    return torch.sum(x * y, -1, keepdim=True)\n\n\ndef length(x, eps=1e-20):\n    return torch.sqrt(torch.clamp(dot(x, x), min=eps))\n\n\ndef safe_normalize(x, eps=1e-20):\n    return x / length(x, eps)\n\nclass Mesh:\n    def __init__(\n        self,\n        v=None,\n        f=None,\n        vn=None,\n        fn=None,\n        vt=None,\n        ft=None,\n        albedo=None,\n        vc=None, # vertex color\n        device=None,\n    ):\n        self.device = device\n        self.v = v\n        self.vn = vn\n        self.vt = vt\n        self.f = f\n        self.fn = fn\n        self.ft = ft\n        # only support a single albedo\n        self.albedo = albedo\n        # support vertex color is no albedo\n        self.vc = vc\n\n        self.ori_center = 0\n        self.ori_scale = 1\n\n    @classmethod\n    def load(cls, path=None, resize=True, renormal=True, retex=False, front_dir='+z', **kwargs):\n        # assume init with kwargs\n        if path is None:\n            mesh = cls(**kwargs)\n        # obj supports face uv\n        elif path.endswith(\".obj\"):\n            mesh = cls.load_obj(path, **kwargs)\n        # trimesh only supports vertex uv, but can load more formats\n        else:\n            mesh = cls.load_trimesh(path, **kwargs)\n\n        print(f\"[Mesh loading] v: {mesh.v.shape}, f: {mesh.f.shape}\")\n        # auto-normalize\n        if resize:\n            mesh.auto_size()\n        # auto-fix normal\n        if renormal or mesh.vn is None:\n            mesh.auto_normal()\n            print(f\"[Mesh loading] vn: {mesh.vn.shape}, fn: {mesh.fn.shape}\")\n        # auto-fix texcoords\n        if retex or (mesh.albedo is not None and mesh.vt is None):\n            mesh.auto_uv(cache_path=path)\n            print(f\"[Mesh loading] vt: {mesh.vt.shape}, ft: {mesh.ft.shape}\")\n\n        # rotate front dir to +z\n        if front_dir != \"+z\":\n            # axis switch\n            if \"-z\" in front_dir:\n                T = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, -1]], device=mesh.device, dtype=torch.float32)\n            elif \"+x\" in front_dir:\n                T = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=mesh.device, dtype=torch.float32)\n            elif \"-x\" in front_dir:\n                T = torch.tensor([[0, 0, -1], [0, 1, 0], [1, 0, 0]], device=mesh.device, dtype=torch.float32)\n            elif \"+y\" in front_dir:\n                T = torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]], device=mesh.device, dtype=torch.float32)\n            elif \"-y\" in front_dir:\n                T = torch.tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]], device=mesh.device, dtype=torch.float32)\n            else:\n                T = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)\n            # rotation (how many 90 degrees)\n            if '1' in front_dir:\n                T @= torch.tensor([[0, -1, 0], [1, 0, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32) \n            elif '2' in front_dir:\n                T @= torch.tensor([[1, 0, 0], [0, -1, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32) \n            elif '3' in front_dir:\n                T @= torch.tensor([[0, 1, 0], [-1, 0, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32) \n            mesh.v @= T\n            mesh.vn @= T\n\n        return mesh\n\n    # load from obj file\n    @classmethod\n    def load_obj(cls, path, 
albedo_path=None, device=None):\n        assert os.path.splitext(path)[-1] == \".obj\"\n\n        mesh = cls()\n\n        # device\n        if device is None:\n            device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n        mesh.device = device\n\n        # load obj\n        with open(path, \"r\") as f:\n            lines = f.readlines()\n\n        def parse_f_v(fv):\n            xs = [int(x) - 1 if x != \"\" else -1 for x in fv.split(\"/\")]\n            xs.extend([-1] * (3 - len(xs)))\n            return xs[0], xs[1], xs[2]\n\n        # NOTE: we ignore usemtl, and assume the mesh ONLY uses one material (first in mtl)\n        vertices, texcoords, normals = [], [], []\n        faces, tfaces, nfaces = [], [], []\n        mtl_path = None\n\n        for line in lines:\n            split_line = line.split()\n            # empty line\n            if len(split_line) == 0:\n                continue\n            prefix = split_line[0].lower()\n            # mtllib\n            if prefix == \"mtllib\":\n                mtl_path = split_line[1]\n            # usemtl\n            elif prefix == \"usemtl\":\n                pass # ignored\n            # v/vn/vt\n            elif prefix == \"v\":\n                vertices.append([float(v) for v in split_line[1:]])\n            elif prefix == \"vn\":\n                normals.append([float(v) for v in split_line[1:]])\n            elif prefix == \"vt\":\n                val = [float(v) for v in split_line[1:]]\n                texcoords.append([val[0], 1.0 - val[1]])\n            elif prefix == \"f\":\n                vs = split_line[1:]\n                nv = len(vs)\n                v0, t0, n0 = parse_f_v(vs[0])\n                for i in range(nv - 2):  # triangulate (assume vertices are ordered)\n                    v1, t1, n1 = parse_f_v(vs[i + 1])\n                    v2, t2, n2 = parse_f_v(vs[i + 2])\n                    faces.append([v0, v1, v2])\n                    tfaces.append([t0, t1, t2])\n                    nfaces.append([n0, n1, n2])\n\n        mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device)\n        mesh.vt = (\n            torch.tensor(texcoords, dtype=torch.float32, device=device)\n            if len(texcoords) > 0\n            else None\n        )\n        mesh.vn = (\n            torch.tensor(normals, dtype=torch.float32, device=device)\n            if len(normals) > 0\n            else None\n        )\n\n        mesh.f = torch.tensor(faces, dtype=torch.int32, device=device)\n        mesh.ft = (\n            torch.tensor(tfaces, dtype=torch.int32, device=device)\n            if len(texcoords) > 0\n            else None\n        )\n        mesh.fn = (\n            torch.tensor(nfaces, dtype=torch.int32, device=device)\n            if len(normals) > 0\n            else None\n        )\n\n        # see if there is vertex color\n        use_vertex_color = False\n        if mesh.v.shape[1] == 6:\n            use_vertex_color = True\n            mesh.vc = mesh.v[:, 3:]\n            mesh.v = mesh.v[:, :3]\n            print(f\"[load_obj] use vertex color: {mesh.vc.shape}\")\n\n        # try to load texture image\n        if not use_vertex_color:\n            # try to retrieve mtl file\n            mtl_path_candidates = []\n            if mtl_path is not None:\n                mtl_path_candidates.append(mtl_path)\n                mtl_path_candidates.append(os.path.join(os.path.dirname(path), mtl_path))\n            mtl_path_candidates.append(path.replace(\".obj\", \".mtl\"))\n\n            
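# The mtl candidates above are tried in order; the first path that exists on disk is used.\n            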
mtl_path = None\n            for candidate in mtl_path_candidates:\n                if os.path.exists(candidate):\n                    mtl_path = candidate\n                    break\n            \n            # if albedo_path is not provided, try retrieve it from mtl\n            if mtl_path is not None and albedo_path is None:\n                with open(mtl_path, \"r\") as f:\n                    lines = f.readlines()\n                for line in lines:\n                    split_line = line.split()\n                    # empty line\n                    if len(split_line) == 0:\n                        continue\n                    prefix = split_line[0]\n                    # NOTE: simply use the first map_Kd as albedo!\n                    if \"map_Kd\" in prefix:\n                        albedo_path = os.path.join(os.path.dirname(path), split_line[1])\n                        print(f\"[load_obj] use texture from: {albedo_path}\")\n                        break\n            \n            # still not found albedo_path, or the path doesn't exist\n            if albedo_path is None or not os.path.exists(albedo_path):\n                # init an empty texture\n                print(f\"[load_obj] init empty albedo!\")\n                # albedo = np.random.rand(1024, 1024, 3).astype(np.float32)\n                albedo = np.ones((1024, 1024, 3), dtype=np.float32) * np.array([0.5, 0.5, 0.5])  # default color\n            else:\n                albedo = cv2.imread(albedo_path, cv2.IMREAD_UNCHANGED)\n                albedo = cv2.cvtColor(albedo, cv2.COLOR_BGR2RGB)\n                albedo = albedo.astype(np.float32) / 255\n                print(f\"[load_obj] load texture: {albedo.shape}\")\n\n\n\n            mesh.albedo = torch.tensor(albedo, dtype=torch.float32, device=device)\n\n        return mesh\n\n    @classmethod\n    def load_trimesh(cls, path, device=None):\n        mesh = cls()\n\n        # device\n        if device is None:\n            device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n        mesh.device = device\n\n        # use trimesh to load ply/glb, assume only has one single RootMesh...\n        _data = trimesh.load(path)\n        if isinstance(_data, trimesh.Scene):\n            if len(_data.geometry) == 1:\n                _mesh = list(_data.geometry.values())[0]\n            else:\n                # manual concat, will lose texture\n                _concat = []\n                for g in _data.geometry.values():\n                    if isinstance(g, trimesh.Trimesh):\n                        _concat.append(g)\n                _mesh = trimesh.util.concatenate(_concat)\n        else:\n            _mesh = _data\n        \n        if _mesh.visual.kind == 'vertex':\n            vertex_colors = _mesh.visual.vertex_colors\n            vertex_colors = np.array(vertex_colors[..., :3]).astype(np.float32) / 255\n            mesh.vc = torch.tensor(vertex_colors, dtype=torch.float32, device=device)\n            print(f\"[load_trimesh] use vertex color: {mesh.vc.shape}\")\n        elif _mesh.visual.kind == 'texture':\n            _material = _mesh.visual.material\n            if isinstance(_material, trimesh.visual.material.PBRMaterial):\n                texture = np.array(_material.baseColorTexture).astype(np.float32) / 255\n            elif isinstance(_material, trimesh.visual.material.SimpleMaterial):\n                texture = np.array(_material.to_pbr().baseColorTexture).astype(np.float32) / 255\n            else:\n                raise 
NotImplementedError(f\"material type {type(_material)} not supported!\")\n            mesh.albedo = torch.tensor(texture, dtype=torch.float32, device=device)\n            print(f\"[load_trimesh] load texture: {texture.shape}\")\n        else:\n            texture = np.ones((1024, 1024, 3), dtype=np.float32) * np.array([0.5, 0.5, 0.5])\n            mesh.albedo = torch.tensor(texture, dtype=torch.float32, device=device)\n            print(f\"[load_trimesh] failed to load texture.\")\n\n        vertices = _mesh.vertices\n\n        try:\n            texcoords = _mesh.visual.uv\n            texcoords[:, 1] = 1 - texcoords[:, 1]\n        except Exception as e:\n            texcoords = None\n\n        try:\n            normals = _mesh.vertex_normals\n        except Exception as e:\n            normals = None\n\n        # trimesh only support vertex uv...\n        faces = tfaces = nfaces = _mesh.faces\n\n        mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device)\n        mesh.vt = (\n            torch.tensor(texcoords, dtype=torch.float32, device=device)\n            if texcoords is not None\n            else None\n        )\n        mesh.vn = (\n            torch.tensor(normals, dtype=torch.float32, device=device)\n            if normals is not None\n            else None\n        )\n\n        mesh.f = torch.tensor(faces, dtype=torch.int32, device=device)\n        mesh.ft = (\n            torch.tensor(tfaces, dtype=torch.int32, device=device)\n            if texcoords is not None\n            else None\n        )\n        mesh.fn = (\n            torch.tensor(nfaces, dtype=torch.int32, device=device)\n            if normals is not None\n            else None\n        )\n\n        return mesh\n\n    # aabb\n    def aabb(self):\n        return torch.min(self.v, dim=0).values, torch.max(self.v, dim=0).values\n\n    # unit size\n    @torch.no_grad()\n    def auto_size(self):\n        vmin, vmax = self.aabb()\n        self.ori_center = (vmax + vmin) / 2\n        self.ori_scale = 1.2 / torch.max(vmax - vmin).item()\n        self.v = (self.v - self.ori_center) * self.ori_scale\n\n    def auto_normal(self):\n        i0, i1, i2 = self.f[:, 0].long(), self.f[:, 1].long(), self.f[:, 2].long()\n        v0, v1, v2 = self.v[i0, :], self.v[i1, :], self.v[i2, :]\n\n        face_normals = torch.cross(v1 - v0, v2 - v0, dim=1)\n\n        # Splat face normals to vertices\n        vn = torch.zeros_like(self.v)\n        vn.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)\n        vn.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)\n        vn.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)\n\n        # Normalize, replace zero (degenerated) normals with some default value\n        vn = torch.where(\n            dot(vn, vn) > 1e-20,\n            vn,\n            torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device),\n        )\n        vn = safe_normalize(vn)\n\n        self.vn = vn\n        self.fn = self.f\n\n    def auto_uv(self, cache_path=None, vmap=True):\n        # try to load cache\n        if cache_path is not None:\n            cache_path = os.path.splitext(cache_path)[0] + \"_uv.npz\"\n        if cache_path is not None and os.path.exists(cache_path):\n            data = np.load(cache_path)\n            vt_np, ft_np, vmapping = data[\"vt\"], data[\"ft\"], data[\"vmapping\"]\n        else:\n            import xatlas\n\n            v_np = self.v.detach().cpu().numpy()\n            f_np = self.f.detach().int().cpu().numpy()\n            atlas = 
xatlas.Atlas()\n            atlas.add_mesh(v_np, f_np)\n            chart_options = xatlas.ChartOptions()\n            # chart_options.max_iterations = 4\n            atlas.generate(chart_options=chart_options)\n            vmapping, ft_np, vt_np = atlas[0]  # [N], [M, 3], [N, 2]\n\n            # save to cache\n            if cache_path is not None:\n                np.savez(cache_path, vt=vt_np, ft=ft_np, vmapping=vmapping)\n        \n        vt = torch.from_numpy(vt_np.astype(np.float32)).to(self.device)\n        ft = torch.from_numpy(ft_np.astype(np.int32)).to(self.device)\n        self.vt = vt\n        self.ft = ft\n\n        if vmap:\n            # remap v/f to vt/ft, so each v correspond to a unique vt. (necessary for gltf)\n            vmapping = torch.from_numpy(vmapping.astype(np.int64)).long().to(self.device)\n            self.align_v_to_vt(vmapping)\n    \n    def align_v_to_vt(self, vmapping=None):\n        # remap v/f and vn/vn to vt/ft.\n        if vmapping is None:\n            ft = self.ft.view(-1).long()\n            f = self.f.view(-1).long()\n            vmapping = torch.zeros(self.vt.shape[0], dtype=torch.long, device=self.device)\n            vmapping[ft] = f # scatter, randomly choose one if index is not unique\n\n        self.v = self.v[vmapping]\n        self.f = self.ft\n        # assume fn == f\n        if self.vn is not None:\n            self.vn = self.vn[vmapping]\n            self.fn = self.ft\n\n    def to(self, device):\n        self.device = device\n        for name in [\"v\", \"f\", \"vn\", \"fn\", \"vt\", \"ft\", \"albedo\"]:\n            tensor = getattr(self, name)\n            if tensor is not None:\n                setattr(self, name, tensor.to(device))\n        return self\n    \n    def write(self, path):\n        if path.endswith(\".ply\"):\n            self.write_ply(path)\n        elif path.endswith(\".obj\"):\n            self.write_obj(path)\n        elif path.endswith(\".glb\") or path.endswith(\".gltf\"):\n            self.write_glb(path)\n        else:\n            raise NotImplementedError(f\"format {path} not supported!\")\n    \n    # write to ply file (only geom)\n    def write_ply(self, path):\n\n        v_np = self.v.detach().cpu().numpy()\n        f_np = self.f.detach().cpu().numpy()\n\n        _mesh = trimesh.Trimesh(vertices=v_np, faces=f_np)\n        _mesh.export(path)\n\n    # write to gltf/glb file (geom + texture)\n    def write_glb(self, path):\n\n        assert self.vn is not None and self.vt is not None # should be improved to support export without texture...\n\n        # assert self.v.shape[0] == self.vn.shape[0] and self.v.shape[0] == self.vt.shape[0]\n        if self.v.shape[0] != self.vt.shape[0]:\n            self.align_v_to_vt()\n\n        # assume f == fn == ft\n\n        import pygltflib\n\n        f_np = self.f.detach().cpu().numpy().astype(np.uint32)\n        v_np = self.v.detach().cpu().numpy().astype(np.float32)\n        # vn_np = self.vn.detach().cpu().numpy().astype(np.float32)\n        vt_np = self.vt.detach().cpu().numpy().astype(np.float32)\n\n        albedo = self.albedo.detach().cpu().numpy()\n        albedo = (albedo * 255).astype(np.uint8)\n        albedo = cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR)\n\n        f_np_blob = f_np.flatten().tobytes()\n        v_np_blob = v_np.tobytes()\n        # vn_np_blob = vn_np.tobytes()\n        vt_np_blob = vt_np.tobytes()\n        albedo_blob = cv2.imencode('.png', albedo)[1].tobytes()\n\n        gltf = pygltflib.GLTF2(\n            scene=0,\n            
scenes=[pygltflib.Scene(nodes=[0])],\n            nodes=[pygltflib.Node(mesh=0)],\n            meshes=[pygltflib.Mesh(primitives=[\n                pygltflib.Primitive(\n                    # indices to accessors (0 is triangles)\n                    attributes=pygltflib.Attributes(\n                        POSITION=1, TEXCOORD_0=2, \n                    ),\n                    indices=0, material=0,\n                )\n            ])],\n            materials=[\n                pygltflib.Material(\n                    pbrMetallicRoughness=pygltflib.PbrMetallicRoughness(\n                        baseColorTexture=pygltflib.TextureInfo(index=0, texCoord=0),\n                        metallicFactor=0.0,\n                        roughnessFactor=1.0,\n                    ),\n                    alphaCutoff=0,\n                    doubleSided=True,\n                )\n            ],\n            textures=[\n                pygltflib.Texture(sampler=0, source=0),\n            ],\n            samplers=[\n                pygltflib.Sampler(magFilter=pygltflib.LINEAR, minFilter=pygltflib.LINEAR_MIPMAP_LINEAR, wrapS=pygltflib.REPEAT, wrapT=pygltflib.REPEAT),\n            ],\n            images=[\n                # use embedded (buffer) image\n                pygltflib.Image(bufferView=3, mimeType=\"image/png\"),\n            ],\n            buffers=[\n                pygltflib.Buffer(byteLength=len(f_np_blob) + len(v_np_blob) + len(vt_np_blob) + len(albedo_blob))\n            ],\n            # buffer view (based on dtype)\n            bufferViews=[\n                # triangles; as flatten (element) array\n                pygltflib.BufferView(\n                    buffer=0,\n                    byteLength=len(f_np_blob),\n                    target=pygltflib.ELEMENT_ARRAY_BUFFER, # GL_ELEMENT_ARRAY_BUFFER (34963)\n                ),\n                # positions; as vec3 array\n                pygltflib.BufferView(\n                    buffer=0,\n                    byteOffset=len(f_np_blob),\n                    byteLength=len(v_np_blob),\n                    byteStride=12, # vec3\n                    target=pygltflib.ARRAY_BUFFER, # GL_ARRAY_BUFFER (34962)\n                ),\n                # texcoords; as vec2 array\n                pygltflib.BufferView(\n                    buffer=0,\n                    byteOffset=len(f_np_blob) + len(v_np_blob),\n                    byteLength=len(vt_np_blob),\n                    byteStride=8, # vec2\n                    target=pygltflib.ARRAY_BUFFER,\n                ),\n                # texture; as none target\n                pygltflib.BufferView(\n                    buffer=0,\n                    byteOffset=len(f_np_blob) + len(v_np_blob) + len(vt_np_blob),\n                    byteLength=len(albedo_blob),\n                ),\n            ],\n            accessors=[\n                # 0 = triangles\n                pygltflib.Accessor(\n                    bufferView=0,\n                    componentType=pygltflib.UNSIGNED_INT, # GL_UNSIGNED_INT (5125)\n                    count=f_np.size,\n                    type=pygltflib.SCALAR,\n                    max=[int(f_np.max())],\n                    min=[int(f_np.min())],\n                ),\n                # 1 = positions\n                pygltflib.Accessor(\n                    bufferView=1,\n                    componentType=pygltflib.FLOAT, # GL_FLOAT (5126)\n                    count=len(v_np),\n                    type=pygltflib.VEC3,\n                    max=v_np.max(axis=0).tolist(),\n                  
  min=v_np.min(axis=0).tolist(),\n                ),\n                # 2 = texcoords\n                pygltflib.Accessor(\n                    bufferView=2,\n                    componentType=pygltflib.FLOAT,\n                    count=len(vt_np),\n                    type=pygltflib.VEC2,\n                    max=vt_np.max(axis=0).tolist(),\n                    min=vt_np.min(axis=0).tolist(),\n                ),\n            ],\n        )\n\n        # set actual data\n        gltf.set_binary_blob(f_np_blob + v_np_blob + vt_np_blob + albedo_blob)\n\n        # glb = b\"\".join(gltf.save_to_bytes())\n        gltf.save(path)\n\n    # write to obj file (geom + texture)\n    def write_obj(self, path):\n\n        mtl_path = path.replace(\".obj\", \".mtl\")\n        albedo_path = path.replace(\".obj\", \"_albedo.png\")\n\n        v_np = self.v.detach().cpu().numpy()\n        vt_np = self.vt.detach().cpu().numpy() if self.vt is not None else None\n        vn_np = self.vn.detach().cpu().numpy() if self.vn is not None else None\n        f_np = self.f.detach().cpu().numpy()\n        ft_np = self.ft.detach().cpu().numpy() if self.ft is not None else None\n        fn_np = self.fn.detach().cpu().numpy() if self.fn is not None else None\n\n        with open(path, \"w\") as fp:\n            fp.write(f\"mtllib {os.path.basename(mtl_path)} \\n\")\n\n            for v in v_np:\n                fp.write(f\"v {v[0]} {v[1]} {v[2]} \\n\")\n\n            if vt_np is not None:\n                for v in vt_np:\n                    fp.write(f\"vt {v[0]} {1 - v[1]} \\n\")\n\n            if vn_np is not None:\n                for v in vn_np:\n                    fp.write(f\"vn {v[0]} {v[1]} {v[2]} \\n\")\n\n            fp.write(f\"usemtl defaultMat \\n\")\n            for i in range(len(f_np)):\n                fp.write(\n                    f'f {f_np[i, 0] + 1}/{ft_np[i, 0] + 1 if ft_np is not None else \"\"}/{fn_np[i, 0] + 1 if fn_np is not None else \"\"} \\\n                             {f_np[i, 1] + 1}/{ft_np[i, 1] + 1 if ft_np is not None else \"\"}/{fn_np[i, 1] + 1 if fn_np is not None else \"\"} \\\n                             {f_np[i, 2] + 1}/{ft_np[i, 2] + 1 if ft_np is not None else \"\"}/{fn_np[i, 2] + 1 if fn_np is not None else \"\"} \\n'\n                )\n\n        with open(mtl_path, \"w\") as fp:\n            fp.write(f\"newmtl defaultMat \\n\")\n            fp.write(f\"Ka 1 1 1 \\n\")\n            fp.write(f\"Kd 1 1 1 \\n\")\n            fp.write(f\"Ks 0 0 0 \\n\")\n            fp.write(f\"Tr 1 \\n\")\n            fp.write(f\"illum 1 \\n\")\n            fp.write(f\"Ns 0 \\n\")\n            fp.write(f\"map_Kd {os.path.basename(albedo_path)} \\n\")\n\n        albedo = self.albedo.detach().cpu().numpy()\n        albedo = (albedo * 255).astype(np.uint8)\n        cv2.imwrite(albedo_path, cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR))"
  },
  {
    "path": "lib/utils/mesh_utils.py",
    "content": "import numpy as np\nimport pymeshlab as pml\nimport torch\n\n\ndef gaussian_3d_coeff(xyzs, covs):\n    # xyzs: [N, 3]\n    # covs: [N, 6]\n    x, y, z = xyzs[:, 0], xyzs[:, 1], xyzs[:, 2]\n    a, b, c, d, e, f = covs[:, 0], covs[:, 1], covs[:, 2], covs[:, 3], covs[:, 4], covs[:, 5]\n\n    # eps must be small enough !!!\n    inv_det = 1 / (a * d * f + 2 * e * c * b - e**2 * a - c**2 * d - b**2 * f + 1e-24)\n    inv_a = (d * f - e**2) * inv_det\n    inv_b = (e * c - b * f) * inv_det\n    inv_c = (e * b - c * d) * inv_det\n    inv_d = (a * f - c**2) * inv_det\n    inv_e = (b * c - e * a) * inv_det\n    inv_f = (a * d - b**2) * inv_det\n\n    power = -0.5 * (x**2 * inv_a + y**2 * inv_d + z**2 * inv_f) - x * y * inv_b - x * z * inv_c - y * z * inv_e\n\n    power[power > 0] = -1e10 # abnormal values... make weights 0\n        \n    return torch.exp(power)\n\ndef poisson_mesh_reconstruction(points, normals=None):\n    # points/normals: [N, 3] np.ndarray\n\n    import open3d as o3d\n    \n\n    pcd = o3d.geometry.PointCloud()\n    pcd.points = o3d.utility.Vector3dVector(points)\n\n    # outlier removal\n    pcd, ind = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=10)\n\n    # normals\n    if normals is None:\n        pcd.estimate_normals()\n    else:\n        pcd.normals = o3d.utility.Vector3dVector(normals[ind])\n\n    # visualize\n    o3d.visualization.draw_geometries([pcd], point_show_normal=False)\n\n    mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(\n        pcd, depth=9\n    )\n    vertices_to_remove = densities < np.quantile(densities, 0.1)\n    mesh.remove_vertices_by_mask(vertices_to_remove)\n\n    # visualize\n    o3d.visualization.draw_geometries([mesh])\n\n    vertices = np.asarray(mesh.vertices)\n    triangles = np.asarray(mesh.triangles)\n\n    print(\n        f\"[INFO] poisson mesh reconstruction: {points.shape} --> {vertices.shape} / {triangles.shape}\"\n    )\n\n    return vertices, triangles\n\n\ndef decimate_mesh(\n    verts, faces, target, backend=\"pymeshlab\", remesh=False, optimalplacement=True\n):\n    # optimalplacement: default is True, but for flat mesh must turn False to prevent spike artifect.\n\n    _ori_vert_shape = verts.shape\n    _ori_face_shape = faces.shape\n\n    if backend == \"pyfqmr\":\n        import pyfqmr\n\n        solver = pyfqmr.Simplify()\n        solver.setMesh(verts, faces)\n        solver.simplify_mesh(target_count=target, preserve_border=False, verbose=False)\n        verts, faces, normals = solver.getMesh()\n    else:\n        m = pml.Mesh(verts, faces)\n        ms = pml.MeshSet()\n        ms.add_mesh(m, \"mesh\")  # will copy!\n\n        # filters\n        # ms.meshing_decimation_clustering(threshold=pml.PercentageValue(1))\n        ms.meshing_decimation_quadric_edge_collapse(\n            targetfacenum=int(target), optimalplacement=optimalplacement\n        )\n\n        if remesh:\n            # ms.apply_coord_taubin_smoothing()\n            ms.meshing_isotropic_explicit_remeshing(\n                iterations=3, targetlen=pml.PercentageValue(1)\n            )\n\n        # extract mesh\n        m = ms.current_mesh()\n        verts = m.vertex_matrix()\n        faces = m.face_matrix()\n\n    print(\n        f\"[INFO] mesh decimation: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}\"\n    )\n\n    return verts, faces\n\n\ndef clean_mesh(\n    verts,\n    faces,\n    v_pct=1,\n    min_f=64,\n    min_d=20,\n    repair=True,\n    remesh=True,\n    
remesh_size=0.01,\n):\n    # verts: [N, 3]\n    # faces: [N, 3]\n    import pymeshlab as pml\n    # Compatibility shim: newer pymeshlab renamed Percentage -> PercentageValue and\n    # AbsoluteValue -> PureValue; alias whichever name is missing so both versions work.\n    if not hasattr(pml, 'PercentageValue'):\n        pml.PercentageValue = pml.Percentage\n    if not hasattr(pml, 'PureValue'):\n        pml.PureValue = pml.AbsoluteValue\n    _ori_vert_shape = verts.shape\n    _ori_face_shape = faces.shape\n\n    m = pml.Mesh(verts, faces)\n    ms = pml.MeshSet()\n    ms.add_mesh(m, \"mesh\")  # will copy!\n\n    # filters\n    ms.meshing_remove_unreferenced_vertices()  # verts not refed by any faces\n\n    if v_pct > 0:\n        ms.meshing_merge_close_vertices(\n            threshold=pml.PercentageValue(v_pct)\n        )  # 1/10000 of bounding box diagonal\n\n    ms.meshing_remove_duplicate_faces()  # faces defined by the same verts\n    ms.meshing_remove_null_faces()  # faces with area == 0\n\n    if min_d > 0:\n        ms.meshing_remove_connected_component_by_diameter(\n            mincomponentdiag=pml.PercentageValue(min_d)\n        )\n\n    if min_f > 0:\n        ms.meshing_remove_connected_component_by_face_number(mincomponentsize=min_f)\n\n    if repair:\n        # ms.meshing_remove_t_vertices(method=0, threshold=40, repeat=True)\n        ms.meshing_repair_non_manifold_edges(method=0)\n        ms.meshing_repair_non_manifold_vertices(vertdispratio=0)\n\n    if remesh:\n        # ms.apply_coord_taubin_smoothing()\n        ms.meshing_isotropic_explicit_remeshing(\n            iterations=3, targetlen=pml.PureValue(remesh_size)\n        )\n\n    # extract mesh\n    m = ms.current_mesh()\n    verts = m.vertex_matrix()\n    faces = m.face_matrix()\n\n    print(\n        f\"[INFO] mesh cleaning: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}\"\n    )\n\n    return verts, faces"
  },
  {
    "path": "lib/utils/train_util.py",
    "content": "import importlib\n\nimport os\n\nfrom pytorch_lightning.utilities import rank_zero_only\n@rank_zero_only\ndef main_print(*args):\n    print(*args)\n\n\n\ndef count_params(model, verbose=False):\n    total_params = sum(p.numel() for p in model.parameters())\n    if verbose:\n        main_print(f\"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.\")\n    return total_params\n\n\ndef instantiate_from_config(config):\n    if not \"target\" in config:\n        if config == '__is_first_stage__':\n            return None\n        elif config == \"__is_unconditional__\":\n            return None\n        raise KeyError(\"Expected key `target` to instantiate.\")\n    return get_obj_from_str(config[\"target\"])(**config.get(\"params\", dict()))\n\n\ndef get_obj_from_str(string, reload=False):\n    main_print(string)\n    module, cls = string.rsplit(\".\", 1)\n    if reload:\n        module_imp = importlib.import_module(module)\n        importlib.reload(module_imp)\n    return getattr(importlib.import_module(module, package=None), cls)\n"
  },
  {
    "path": "run_demo.py",
    "content": "import os\nos.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0,1,2,3,4,5,6,7\"\nimport argparse\nimport torch\nfrom tqdm import tqdm\nfrom torchvision.transforms import v2\nfrom pytorch_lightning import seed_everything\nfrom omegaconf import OmegaConf\nfrom tqdm import tqdm\nfrom einops import rearrange\nfrom lib.utils.infer_util import *\nfrom lib.utils.train_util import instantiate_from_config\nimport torchvision\nimport json\n###############################################################################\n# Arguments.\n###############################################################################\n\ndef parse_args():\n    \"\"\"Parse command line arguments\"\"\"\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--config', type=str, help='Path to config file.', required=False)\n    parser.add_argument('--input_path', type=str, help='Path to input image or directory.', required=False)\n    parser.add_argument('--resume_path', type=str, help='Path to saved ckpt.', required=False)\n    parser.add_argument('--output_path', type=str, default='outputs/', help='Output directory.')\n    parser.add_argument('--seed', type=int, default=42, help='Random seed for sampling.')\n    parser.add_argument('--distance', type=float, default=1.5, help='Render distance.')\n    parser.add_argument('--no_rembg', action='store_true', help='Do not remove input background.')\n    parser.add_argument('--render_mode', type=str, default='novel_pose', \n                    choices=['novel_pose', 'reconstruct', 'novel_pose_A'],\n                    help='Rendering mode: novel_pose (animation), reconstruct (reconstruction), or novel_pose_A (360-degree view with A-pose)')\n\n    return parser.parse_args()\n\n###############################################################################\n# Stage 0: Configuration.\n###############################################################################\n\ndevice = torch.device('cuda')\n\n\n\ndef process_data_on_gpu(args, model, gpu_id, img_paths_list, smplx_ref_path_list, smplx_path_driven_list):\n    torch.cuda.set_device(gpu_id)\n    model = model.cuda()  \n    image_plist = []\n\n    \n    render_mode =  args.render_mode\n    \n\n    cam_idx = 0 # 12 # fixed cameras and changes pose for novel poses\n    num_imgs = 60\n    if_load_betas = True\n\n   \n    if_use_video_cam = False  # If the SMPLX sequence provides camera parameters, this can be set to true.\n    if_uniform_coordinates = True # Normalize the SMPL-X sequence for the purpose of driving.\n\n\n    for input_path, smplx_ref_path, smplx_path in tqdm(zip(img_paths_list, smplx_ref_path_list, smplx_path_driven_list), total = len(img_paths_list)):\n        print(f\"Processing: {input_path}\")\n\n        args.input_path = input_path\n        args.input_path_smpl = smplx_ref_path\n\n        # get a name for results\n        name = get_name_str(args.input_path) + get_name_str(smplx_path)\n\n        ###############################################################################\n        # Stage 1: Parameters loading\n        ###############################################################################\n  \n        ''' # Stage 1.1: SMPLX loading (Beta)'''\n        if args.input_path_smpl is not None:\n            # smplx = np.load(args.input_path_smpl, allow_pickle=True).item()\n            smplx = json.load(open(args.input_path_smpl))\n            if \"shapes\" in smplx.keys():\n                smplx['betas'] = smplx['shapes']\n            else:\n         
       smplx['betas'] = smplx['betas_save']\n        smpl_params = torch.zeros(1, 189).to(device)\n        if if_load_betas:\n            smpl_params[:, 70:80] = torch.Tensor(smplx['betas']).to(device)\n\n        ''' # Stage 1.2: SMPLX loading (Pose)'''\n        # animation\n        if render_mode in ['novel_pose'] : \n\n            if smplx_path.endswith(\".npy\"):\n                smpl_params = load_smplx_from_npy(smplx_path)\n                smpl_params[:, 70:80] = torch.Tensor(smplx['betas']).to(device)\n                    \n                # ========= Note: If the video camera is not used, center everything at the origin ========\n                if if_uniform_coordinates:\n                    print(\"''' Ending --- Adjusting root orientation angles '''\")\n                    \n                    # Extract root orientation and translation from SMPL parameters\n                    root_orient = smpl_params[:, 4:7]  \n                    trans = smpl_params[:, 1:4]  \n                    \n                    # Reset the first frame's rotation and adjust translations\n                    new_root_orient, new_trans = reset_first_frame_rotation(root_orient, trans)\n                    \n                    # Update the root orientation and translation in the SMPL parameters\n                    smpl_params[:, 4:7] = new_root_orient\n                    smpl_params[:, 1:4] = new_trans.squeeze()  # Apply the new translation\n\n                \n            elif smplx_path.endswith(\".json\"):  \n                ''' for motion-x input '''\n                smpl_params = load_smplx_from_json(smplx_path)\n                smpl_params[:, 70:80] = torch.Tensor(smplx['betas']).to(device)\n                if_use_video_cam = True \n                \n\n        elif render_mode in ['reconstruct']:\n            RT_rec, intri_rec, smpl_rec = load_smplify_json(smplx_ref_path)\n            \n            H_rec, W_rec = get_image_dimensions(input_path)\n\n            '''Apply root rotation for a full 360-degree view of the object'''\n            if_add_root_rotate = True\n            if if_add_root_rotate == True:\n                \n                smpl_params = add_root_rotate_to_smplx(smpl_rec, num_imgs)\n                print(\" '''ending ---  invert the root angles'''\")\n            else:\n                smpl_params = smpl_params.to(device)\n                num_imgs = 1\n\n        elif render_mode in ['novel_pose_A']:\n            smpl_params = model.get_default_smplx_params().squeeze()\n            smpl_params = smpl_params.to(device)\n            smpl_params = add_root_rotate_to_smplx(smpl_params.clone(), num_imgs)\n            smpl_params[:, 70:80] = torch.Tensor(smplx['betas']).to(device)\n\n        else:\n            raise NotImplementedError(f\"Render mode '{render_mode}' is not supported.\")\n\n        '''# Stage 1.3: Image loading '''\n        image = load_image(args.input_path, args.output_folders['ref'])\n        H,W = 896,640\n        image_bs = image.unsqueeze(0).to(device)\n        num_imgs = 180\n\n        ''' # Stage 1.4 Camera loading'''\n        if not if_use_video_cam:\n            # prepare cameras\n            K, cam_list = prepare_camera(resolution_x=H, resolution_y=W, num_views=num_imgs, stides=1)\n            cameras = construct_camera(K, cam_list)\n\n            if render_mode == 'novel_pose': # if poses are changed, cameras will be fixed\n                intrics = torch.Tensor([K[0,0],K[1,1], 256, 256]).reshape(-1)\n                model.decoder.renderer.image_size = [512, 
512] \n                \n                assert cameras.shape[-1] == 20\n                cameras[:, :4] = intrics\n                cameras = cameras[cam_idx:cam_idx+1]\n                num_imgs = smpl_params.shape[0]\n                cameras = cameras.repeat(num_imgs, 1)\n                cameras = cameras[:, None, :] # length of the pose sequences \n                print(\"modify the render images's resolution into 512x512 \")\n\n            elif render_mode in ['reconstruct']: # using reference smplify's smplx and camera\n                cameras = torch.concat([intri_rec.reshape(-1,4), RT_rec.reshape(-1, 16)], dim=1)\n                # H, W = int(intri_rec[2] * 2), int(intri_rec[3] * 2)\n                model.decoder.renderer.image_size = [W_rec, H_rec]; print(f\"modify the render images's resolution into {H_rec}x{W_rec}\")\n                cameras = cameras.reshape(1,1,20).expand(num_imgs,1,-1)\n                cameras = cameras.cuda()\n                \n            elif render_mode == 'novel_pose_A':\n                model.decoder.renderer.image_size = [W, H] \n                cameras = cameras[0].reshape(1,1,20).expand(num_imgs,1,-1)\n\n        elif if_use_video_cam: # for the animation with motion-x\n            cameras = construct_camera_from_motionx(smplx_path)\n            H, W = 2*cameras[0, 0, [3]].int().item(), 2*cameras[0,0, [2]].int().item()\n            model.decoder.renderer.image_size = [W, H]; print(f\"modify the render images's resolution into {H}x{W}\")\n            # model.decoder.renderer = \n\n        ###############################################################################\n        # Stage 2: Reconstruction.\n        ###############################################################################\n      \n        sample = image_bs[[0]] # N, 3, H, W,\n        # if if_use_dataset:\n        #     sample = rearrange(sample, 'b h w c -> b c h w') # N, 3, H, W,\n\n        image_path_idx = os.path.join(args.output_folders['ref'], f'{name}_ref.jpg')\n        torchvision.utils.save_image( sample[0], image_path_idx)\n\n        with torch.no_grad():\n            # get latents\n            code = model.forward_image_to_uv(sample, is_training=False)\n\n        with torch.no_grad():\n            output_list = []\n            num_imgs_batch = 5\n            total_frames = min(smpl_params.shape[0],300)\n            res_uv = None\n            for i in tqdm(range(0, total_frames, num_imgs_batch)):\n                if i+num_imgs_batch > total_frames:\n                    num_imgs_batch = total_frames - i\n                code_bt = code.expand(num_imgs_batch, -1, -1, -1)\n                # cameras_bt = cameras.expand(num_imgs_batch, -1, -1)\n                cameras_bt = cameras[i:i+num_imgs_batch]\n\n                if render_mode in ['reconstruct', 'novel_pose_A'] and res_uv is not None:\n                    pass \n                else:\n                    res_uv = model.decoder._decode_feature(code_bt) # Decouple UV attributes\n                    res_points = model.decoder._sample_feature(res_uv) # Sampling\n                # Animate\n                res_def_points = model.decoder.deform_pcd(res_points, smpl_params[i:i+num_imgs_batch].to(code_bt.dtype), zeros_hands_off=True, value=0.02) \n                output = model.decoder.forward_render(res_def_points, cameras_bt.to(code_bt.dtype), num_imgs=1)\n                image = output[\"image\"][:, 0].cpu().to(torch.float32)\n\n                print(\"output shape \", output[\"image\"][:, 0].shape)\n                
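# Accumulate the CPU copy of this batch's rendered frames; they are concatenated and saved as a video after the loop.\n                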
output_list.append(image) # [:, 0] stands to get the all scenes (poses)\n                del output\n\n            output = torch.concatenate(output_list, 0)\n            frames = rearrange(output, \"b h w c -> b c h w\")#.cpu().numpy()\n\n            video_path_idx = os.path.join(args.output_folders['video'], f'{name}.mp4')\n\n            save_video(\n                frames[:,:4,...].to(torch.float32),\n                video_path_idx,\n            )\n            image_plist.append(frames)\n            print(\"saving into \", video_path_idx)\n    return image_plist\n\ndef setup_directories(base_path, config_name):\n    \"\"\"Create output directories for results\"\"\"\n    dirs = {\n        'image': os.path.join(base_path, config_name, 'images'),\n        'video': os.path.join(base_path, config_name),\n        'ref': os.path.join(base_path, config_name)\n    }\n    for path in dirs.values():\n        os.makedirs(path, exist_ok=True)\n    return dirs\n\ndef main():\n    \"\"\"Main execution function\"\"\"\n    # Parse arguments and set random seed\n    args = parse_args()\n\n    args.config = \"configs/idol_v0.yaml\"\n    args.resume_path = \"work_dirs/ckpt/model.ckpt\"\n\n    config = OmegaConf.load(args.config)\n    config_name = os.path.basename(args.config).replace('.yaml', '')\n    model_config = config.model\n\n    resume_path =  args.resume_path\n    # Initialize model\n    model = instantiate_from_config(model_config)\n    model.encoder = model.encoder.to(torch.bfloat16) ; print(\"moving encoder to bf16\")\n    model = model.__class__.load_from_checkpoint(resume_path, **config.model.params)\n    model = model.to(device)\n    model = model.eval()\n\n    # Setup input paths\n    img_paths_list = ['work_dirs/demo_data/4.jpg']\n    smplx_ref_path_list = ['work_dirs/demo_data/4.json']\n    smplx_path_driven_list = ['work_dirs/demo_data/Ways_to_Catch_360_clip1.json']  \n    # smplx_path_driven_list = ['work_dirs/demo_data/finedance-5-144.npy.npy']  \n\n    # Setup output directories\n    # args.output_path = \"./test/\"\n    # args.render_mode = 'reconstruct' # 'novel_pose_A' #'reconstruct' #'novel_pose' \n        \n    # make output directories\n    args.output_folders = setup_directories(args.output_path, config_name)\n\n    # Process data\n    image_plist = process_data_on_gpu(\n        args,\n        model, 0, \n        img_paths_list, \n        smplx_ref_path_list, \n        smplx_path_driven_list\n    )\n\n    return image_plist\n\nif __name__ == \"__main__\":\n    main()\n\n\n"
  },
  {
    "path": "scripts/download_files.sh",
    "content": "#!/bin/bash\n\n# Create necessary directories\nmkdir -p work_dirs/\nmkdir -p work_dirs/ckpt\n\n# Download files from HuggingFace\necho \"Downloading model files...\"\nwget https://huggingface.co/yiyuzhuang/IDOL/resolve/main/model.ckpt -O work_dirs/ckpt/model.ckpt\nwget https://huggingface.co/yiyuzhuang/IDOL/resolve/main/sapiens_1b_epoch_173_torchscript.pt2 -O work_dirs/ckpt/sapiens_1b_epoch_173_torchscript.pt2\nwget https://huggingface.co/yiyuzhuang/IDOL/resolve/main/cache_sub2.zip -O work_dirs/cache_sub2.zip\n\n# Unzip cache file\necho \"Extracting cache files...\"\nunzip -o work_dirs/cache_sub2.zip -d work_dirs/\nrm work_dirs/cache_sub2.zip  # Remove zip file after extraction\n\necho \"Download and extraction completed!\"\n\n\n"
  },
  {
    "path": "scripts/fetch_template.sh",
    "content": "#!/bin/bash\nurle () { [[ \"${1}\" ]] || return 1; local LANG=C i x; for (( i = 0; i < ${#1}; i++ )); do x=\"${1:i:1}\"; [[ \"${x}\" == [a-zA-Z0-9.~-] ]] && echo -n \"${x}\" || printf '%%%02X' \"'${x}\"; done; echo; }\n\nmkdir -p lib/models/deformers/smplx/SMPLX\n\n# username and password input\necho -e \"\\nYou need to register at https://smpl-x.is.tue.mpg.de/, according to Installation Instruction.\"\nread -p \"Username (SMPL-X):\" username\nread -p \"Password (SMPL-X):\" password\nusername=$(urle $username)\npassword=$(urle $password)\n\n# SMPLX \necho -e \"\\nDownloading SMPL-X model...\"\nwget --post-data \"username=$username&password=$password\" 'https://download.is.tue.mpg.de/download.php?domain=smplx&sfile=models_smplx_v1_1.zip' -O 'models_smplx_v1_1.zip' --no-check-certificate --continue\nunzip models_smplx_v1_1.zip -d lib/models/deformers/smplx/SMPLX\nmv lib/models/deformers/smplx/SMPLX/models/smplx/* lib/models/deformers/smplx/SMPLX\nrm -rf lib/models/deformers/smplx/SMPLX/models\nrm -f models_smplx_v1_1.zip\n\nmkdir -p work_dirs/cache/template\n\ncd work_dirs/cache/template\necho -e \"\\nDownloading SMPL-X segmentation info...\"\nwget https://github.com/Meshcapade/wiki/blob/main/assets/SMPL_body_segmentation/smplx/smplx_vert_segmentation.json\n\necho -e \"\\nDownloading SMPL-X UV info...\"\nwget --post-data \"username=$username&password=$password\" 'https://download.is.tue.mpg.de/download.php?domain=smplx&sfile=smplx_uv.zip' -O './smplx_uv.zip' --no-check-certificate --continue\nunzip smplx_uv.zip -d ./smplx_uv\nmv smplx_uv/smplx_uv.obj ./\nrm -f smplx_uv.zip\nrm -rf smplx_uv\n\necho -e \"\\nDownloading SMPL-X FLAME correspondence info...\"\nwget --post-data \"username=$username&password=$password\" 'https://download.is.tue.mpg.de/download.php?domain=smplx&sfile=smplx_mano_flame_correspondences.zip' -O './smplx_mano_flame_correspondences.zip' --no-check-certificate --continue\nunzip smplx_mano_flame_correspondences.zip -d ./smplx_mano_flame_correspondences\nmv smplx_mano_flame_correspondences/SMPL-X__FLAME_vertex_ids.npy ./\nrm -f smplx_mano_flame_correspondences.zip\nrm -rf smplx_mano_flame_correspondences\n\necho -e \"\\nDownloading FLAME template from neural-head-avatars repo...\"\nwget https://raw.githubusercontent.com/philgras/neural-head-avatars/main/assets/flame/head_template_mesh_mouth.obj\n\necho -e \"\\nDownloading FLAME template from DECA repo...\"\nwget https://raw.githubusercontent.com/yfeng95/DECA/master/data/head_template.obj\n\necho -e \"\\nYou need to register at http://flame.is.tue.mpg.de/, according to Installation Instruction.\"\nread -p \"Username (FLAME):\" username\nread -p \"Password (FLAME):\" password\nusername=$(urle $username)\npassword=$(urle $password)\n\necho -e \"\\nDownloading FLAME segmentation info...\"\nwget 'https://files.is.tue.mpg.de/tbolkart/FLAME/FLAME_masks.zip' -O 'FLAME_masks.zip' --no-check-certificate --continue\nunzip FLAME_masks.zip -d ./FLAME_masks\nmv FLAME_masks/FLAME_masks.pkl ./\nrm -f FLAME_masks.zip\nrm -rf FLAME_masks\n\ncd ../../..\n\necho -e \"\\n Finish\"\n\n\n\n"
  },
  {
    "path": "scripts/pip_install.sh",
    "content": "#!/bin/bash\n\n# Complete environment setup process\n\n# Step 0: Ensure you create a Conda environment \n# and Activate the environment\n# conda activate idol\n\n# Step 1: Install Pytorch with CUDA:\npip install torch==2.3.1+cu118 torchvision==0.18.1+cu118 torchaudio==2.3.1+cu118 \\\n--index-url https://download.pytorch.org/whl/cu118\n\n# Step 2: Use pip to install additional dependencies\npip_packages=(\n    \"absl-py==2.1.0\"\n    \"accelerate==0.29.1\"\n    \"addict==2.4.0\"\n    \"albumentations==1.4.17\"\n    \"bitsandbytes\"\n    \"deepspeed==0.15.1\"\n    \"diffusers==0.20.2\"\n    \"einops==0.8.0\"\n    \"fastapi==0.111.0\"\n    \"gradio==3.41.2\"\n    \"matplotlib==3.8.4\"\n    \"numpy==1.26.3\"\n    \"opencv-python==4.9.0.80\"\n    \"pandas==2.2.2\"\n    \"pillow==10.3.0\"\n    \"scikit-image==0.23.2\"\n    \"scipy==1.13.0\"\n    \"timm==0.9.16\"\n    \"transformers==4.40.1\"\n    \"pytorch-lightning==2.3.1\"\n    \"omegaconf==2.3.0\"\n    \"av\"\n    \"webdataset\"\n    \"omegaconf\"\n    \"rembg==2.0.57\"\n    \"tensorboard\"\n)\n\nInstall pip packages in bulk\nfor package in \"${pip_packages[@]}\"\ndo\n    pip install \"$package\"\ndone\n\n\n# Create submodule directory if it doesn't exist\nmkdir -p submodule\ncd submodule\n\n# Step 3: Install PyTorch3D\ngit clone https://github.com/facebookresearch/pytorch3d.git\ncd pytorch3d\ngit checkout v0.7.7  \npip install -e .\ncd ..\n\n# Step 4: Install Simple-KNN\ngit clone https://gitlab.inria.fr/bkerbl/simple-knn.git \ncd simple-knn\npip install -e .\ncd ..\n\n# Step 5: Install Gaussian Splatting\ngit clone https://github.com/graphdeco-inria/gaussian-splatting --recursive\ncd gaussian-splatting/submodules/diff-gaussian-rasterization\npython setup.py develop\ncd ../../..\n\n# Step 6: Install Sapiens\ngit clone https://github.com/facebookresearch/sapiens\ncd sapiens/engine\npip install -e .\ncd ../pretrain\npip install -e .\ncd ../../..\n\n# Step 7: Install deformation module\npython setup.py develop\n\necho \"idol environment setup completed!\"\n"
  },
  {
    "path": "setup.py",
    "content": "from setuptools import setup\r\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\r\nimport os\r\ncuda_dir =  \"lib/models/deformers/fast_snarf/cuda\"\r\n\r\nsetup(\r\n    name='fuse',\r\n    ext_modules=[\r\n        CUDAExtension('fuse_cuda', \r\n        [f'{cuda_dir}/fuse_kernel/fuse_cuda.cpp',\r\n        f'{cuda_dir}/fuse_kernel/fuse_cuda_kernel.cu']),\r\n        CUDAExtension('filter_cuda', \r\n        [f'{cuda_dir}/filter/filter.cpp',\r\n        f'{cuda_dir}/filter/filter_kernel.cu']),\r\n        CUDAExtension('precompute_cuda', \r\n        [f'{cuda_dir}/precompute/precompute.cpp',\r\n        f'{cuda_dir}/precompute/precompute_kernel.cu'])\r\n    ],\r\n    cmdclass={\r\n        'build_ext': BuildExtension\r\n    })\r\n\r\n"
  },
  {
    "path": "train.py",
    "content": "import os, sys\n# os.environ[\"WANDB_MODE\"] = \"dryrun\" # default setting to save locally\nfrom lib.utils.train_util import main_print\nimport torch\n# Check GPU information\nif torch.cuda.is_available():\n    gpu_info = torch.cuda.get_device_name()\n    if \"H20\" in gpu_info or \"H800\" in gpu_info:\n        os.environ[\"NCCL_SOCKET_IFNAME\"] = \"bond1\" # for H20  # If using H20 GPU, set network interface\n        main_print(\"changing the network interface to bond1\")\n    if \"H800\" in gpu_info:\n        # Set precision for matrix multiplication\n        torch.set_float32_matmul_precision('medium')  # or 'high'\n    \n\nimport argparse\nimport shutil\nimport subprocess\nfrom omegaconf import OmegaConf\n\nimport torch\nfrom pytorch_lightning import seed_everything\nfrom pytorch_lightning.trainer import Trainer\nfrom pytorch_lightning.strategies import DDPStrategy\nfrom pytorch_lightning.strategies import DeepSpeedStrategy\nfrom pytorch_lightning.callbacks import Callback\nfrom pytorch_lightning.utilities import rank_zero_only\n\nfrom lib.utils.train_util import instantiate_from_config\n\nfrom pytorch_lightning import loggers as pl_loggers\n\n\ndef get_parser(**parser_kwargs):\n    def str2bool(v):\n        if isinstance(v, bool):\n            return v\n        if v.lower() in (\"yes\", \"true\", \"t\", \"y\", \"1\"):\n            return True\n        elif v.lower() in (\"no\", \"false\", \"f\", \"n\", \"0\"):\n            return False\n        else:\n            raise argparse.ArgumentTypeError(\"Boolean value expected.\")\n\n    parser = argparse.ArgumentParser(**parser_kwargs)\n    parser.add_argument(\n        \"-r\",\n        \"--resume\",\n        type=str,\n        default=None,\n        help=\"resume from checkpoint\",\n    )\n    parser.add_argument(\n        \"--resume_weights_only\",\n        action=\"store_true\",\n        help=\"only resume model weights\",\n    )\n    parser.add_argument(\n        \"--resume_not_loading_decoder\",\n        action=\"store_true\",\n        help=\"only resume model weights excepts decoder\",\n    )\n    # parser.add_argument(\n    #     \"--custom_loading_for_PA\",\n    #     action=\"store_true\",\n    #     help=\"customly loading the PA network\",\n    # )\n    parser.add_argument(\n        \"-b\",\n        \"--base\",\n        type=str,\n        default=\"base_config.yaml\",\n        help=\"path to base configs\",\n    )\n    parser.add_argument(\n        \"-n\",\n        \"--name\",\n        type=str,\n        default=\"\",\n        help=\"experiment name\",\n    )\n    parser.add_argument(\n        \"--num_nodes\",\n        type=int,\n        default=1,\n        help=\"number of nodes to use\",\n    )\n    parser.add_argument(\n        \"--gpus\",\n        type=str,\n        default=\"0,\",\n        help=\"gpu ids to use\",\n    )\n    parser.add_argument(\n        \"-s\",\n        \"--seed\",\n        type=int,\n        default=42,\n        help=\"seed for seed_everything\",\n    )\n    parser.add_argument(\n        \"-l\",\n        \"--logdir\",\n        type=str,\n        default=\"logs\",\n        help=\"directory for logging data\",\n    )\n    parser.add_argument(\n        \"--test_sd\",\n        type=str,\n        default=\"\",\n        help=\"path to state dict for testing\",\n    )\n    parser.add_argument(\n        \"--test_dataset\",\n        type=str,\n        default=\"./configs/test_dataset.yaml\",\n        help=\"path to state dict for testing\",\n    )\n    parser.add_argument(\n        
\"--is_debug\",\n        action=\"store_true\",\n        help=\"flag to specify if in debug mode, if true, it will returns more results\",\n    )\n    parser.add_argument(\n        \"--training_mode\",\n        type=str,\n        default=None,\n        help=\"flag to specify the training strategy\",\n    )\n    return parser\n\n\nclass SetupCallback(Callback):\n    def __init__(self, resume, logdir, ckptdir, cfgdir, config):\n        super().__init__()\n        self.resume = resume\n        self.logdir = logdir\n        self.ckptdir = ckptdir\n        self.cfgdir = cfgdir\n        self.config = config\n\n    def on_fit_start(self, trainer, pl_module):\n        if trainer.global_rank == 0:\n            # Create logdirs and save configs\n            os.makedirs(self.logdir, exist_ok=True)\n            os.makedirs(self.ckptdir, exist_ok=True)\n            os.makedirs(self.cfgdir, exist_ok=True)\n\n            main_print(\"Project config\")\n            main_print(OmegaConf.to_yaml(self.config))\n            OmegaConf.save(self.config,\n                           os.path.join(self.cfgdir, \"project.yaml\"))\n\n\nclass CodeSnapshot(Callback):\n    \"\"\"\n    Modified from https://github.com/threestudio-project/threestudio/blob/main/threestudio/utils/callbacks.py#L60\n    \"\"\"\n    def __init__(self, savedir, exclude_patterns=None):\n        self.savedir = savedir\n        # Default excluded files and folders patterns\n        self.exclude_patterns = exclude_patterns or [\n            \"*.mp4\", \"*.npy\", \"work_dirs/*\", \"processed_data/*\", \"logs/*\"\n        ]\n\n    def get_file_list(self):\n        # Get git tracked files, excluding configs directory\n        tracked_files = subprocess.check_output(\n            'git ls-files -- \":!:configs/*\"', shell=True\n        ).splitlines()\n        \n        # Get untracked but not ignored files\n        untracked_files = subprocess.check_output(\n            \"git ls-files --others --exclude-standard\", shell=True\n        ).splitlines()\n        \n        # Merge file lists and decode\n        all_files = [b.decode() for b in set(tracked_files) | set(untracked_files)]\n        \n        # Apply exclusion pattern filtering\n        filtered_files = []\n        for file_path in all_files:\n            should_exclude = False\n            for pattern in self.exclude_patterns:\n                if self._match_pattern(file_path, pattern):\n                    should_exclude = True\n                    break\n            if not should_exclude:\n                filtered_files.append(file_path)\n                \n        return filtered_files\n    \n    def _match_pattern(self, file_path, pattern):\n        \"\"\"Check if file path matches the given pattern\"\"\"\n        # Handle directory wildcard patterns (e.g., work_dirs/*)\n        if pattern.endswith('/*'):\n            dir_prefix = pattern[:-1]  # Remove '*'\n            return file_path.startswith(dir_prefix)\n        \n        # Handle file extension patterns (e.g., *.mp4)\n        if pattern.startswith('*'):\n            ext = pattern[1:]  # Get extension part\n            return file_path.endswith(ext)\n        \n        # Exact matching\n        return file_path == pattern\n\n    @rank_zero_only\n    def save_code_snapshot(self):\n        os.makedirs(self.savedir, exist_ok=True)\n        for f in self.get_file_list():\n            if not os.path.exists(f) or os.path.isdir(f):\n                continue\n            os.makedirs(os.path.join(self.savedir, os.path.dirname(f)), 
exist_ok=True)\n            shutil.copyfile(f, os.path.join(self.savedir, f))\n\n    def on_fit_start(self, trainer, pl_module):\n        try:\n            self.save_code_snapshot()\n        except Exception:\n            main_print(\n                \"Code snapshot is not saved. Please make sure you have git installed and are in a git repository.\"\n            )\n\n\nif __name__ == \"__main__\":\n    # add cwd for convenience and to make classes in this file available when\n    # running as `python main.py`\n    sys.path.append(os.getcwd())\n\n    parser = get_parser()\n    opt, unknown = parser.parse_known_args()\n\n    cfg_fname = os.path.split(opt.base)[-1]\n    cfg_name = os.path.splitext(cfg_fname)[0]\n    exp_name = \"-\" + opt.name if opt.name != \"\" else \"\"\n    logdir = os.path.join(opt.logdir, cfg_name+exp_name)\n\n    # init configs\n    config = OmegaConf.load(opt.base)\n    lightning_config = config.lightning\n    trainer_config = lightning_config.trainer\n\n    # modify some config for debug mode\n    if opt.is_debug:\n        lightning_config['trainer']['val_check_interval'] = 1\n        exp_name = 'debug'\n        logdir = os.path.join(opt.logdir, cfg_name+exp_name)\n        config.model.params['is_debug'] = True\n        config.dataset.batch_size = 1\n        config.dataset.num_workers = 1\n        config.dataset.params.train.params.cache_path = config.dataset.params.debug_cache_path\n\n    ckptdir = os.path.join(logdir, \"checkpoints\")\n    cfgdir = os.path.join(logdir, \"configs\")\n    codedir = os.path.join(logdir, \"code\")\n    seed_everything(opt.seed)\n\n    main_print(f\"Running on GPUs {opt.gpus}\")\n    ngpu = len(opt.gpus.strip(\",\").split(','))\n    trainer_config['devices'] = ngpu\n\n    trainer_opt = argparse.Namespace(**trainer_config)\n    lightning_config.trainer = trainer_config\n\n    # testing setting\n    if len(opt.test_sd) > 0:\n        config_dataset = OmegaConf.load(opt.test_dataset)\n        config.dataset = config_dataset.dataset\n\n    precision_config = {'precision': \"bf16\"}\n\n    # model\n    model = instantiate_from_config(config.model)\n    if precision_config['precision'] == \"bf16\":\n        model.encoder = model.encoder.to(torch.bfloat16)\n    if opt.resume and opt.resume_weights_only:\n        if opt.resume_not_loading_decoder:\n            main_print(\"======== Loading model weights except the decoder ========\")\n            # Load the complete state dictionary\n            state_dict = torch.load(opt.resume, map_location='cpu')['state_dict']\n            # Keep only the parts to load (drop encoder, decoder and lpips weights)\n            new_state_dict = {k: v for k, v in state_dict.items() if not (k.startswith('encoder') or k.startswith('decoder') or k.startswith('lpips'))}\n            # Load the remaining state dictionary\n            model.load_state_dict(new_state_dict, strict=False)\n            del state_dict\n        else:\n            # Otherwise load all weights except the LPIPS module\n            with torch.amp.autocast(device_type='cpu'):\n                state_dict = torch.load(opt.resume, map_location='cpu')['state_dict']\n                main_print([k for k in state_dict.keys() if not k.startswith('lpips')])\n                new_state_dict = {k: v for k, v in state_dict.items() if not k.startswith('lpips')}\n                model.load_state_dict(new_state_dict, strict=False)\n        model = model.to('cuda')\n\n    model.logdir = logdir\n\n    # trainer and callbacks\n    trainer_kwargs = dict()\n\n    # logger\n    param_log = {\n        'save_dir': logdir,\n        'name': 
cfg_name+exp_name,\n    }\n    trainer_kwargs[\"logger\"] = [\n        pl_loggers.TensorBoardLogger(**param_log),\n        pl_loggers.CSVLogger(**param_log),\n    ]\n\n    # model checkpoint\n    default_modelckpt_cfg = {\n        \"target\": \"pytorch_lightning.callbacks.ModelCheckpoint\",\n        \"params\": {\n            \"dirpath\": ckptdir,\n            \"filename\": \"{step:08}\",\n            \"verbose\": True,\n            \"save_last\": True,\n            \"every_n_train_steps\": 5000,\n            \"save_top_k\": -1,  # save all checkpoints\n        }\n    }\n\n    if \"modelcheckpoint\" in lightning_config:\n        modelckpt_cfg = lightning_config.modelcheckpoint\n    else:\n        modelckpt_cfg = OmegaConf.create()\n    modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)\n\n    # callbacks\n    default_callbacks_cfg = {\n        \"setup_callback\": {\n            \"target\": \"train.SetupCallback\",\n            \"params\": {\n                \"resume\": opt.resume,\n                \"logdir\": logdir,\n                \"ckptdir\": ckptdir,\n                \"cfgdir\": cfgdir,\n                \"config\": config,\n            }\n        },\n        \"learning_rate_logger\": {\n            \"target\": \"pytorch_lightning.callbacks.LearningRateMonitor\",\n            \"params\": {\n                \"logging_interval\": \"step\",\n            }\n        },\n        \"code_snapshot\": {\n            \"target\": \"train.CodeSnapshot\",\n            \"params\": {\n                \"savedir\": codedir,\n            }\n        },\n    }\n    default_callbacks_cfg[\"checkpoint_callback\"] = modelckpt_cfg\n\n    if \"callbacks\" in lightning_config:\n        callbacks_cfg = lightning_config.callbacks\n    else:\n        callbacks_cfg = OmegaConf.create()\n    callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)\n\n    trainer_kwargs[\"callbacks\"] = [\n        instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]\n\n    training_mode = \"DDP\" if opt.training_mode is None else opt.training_mode\n\n    if training_mode == 'DDP':\n        trainer_kwargs[\"strategy\"] = DDPStrategy(find_unused_parameters=False, static_graph=True)\n    elif training_mode == 'ZERO':\n        # DeepSpeed ZeRO, configured via the JSON config file\n        strategy = DeepSpeedStrategy(config='./configs/deepspeed_config.json')\n        trainer_kwargs[\"strategy\"] = strategy\n    elif training_mode == 'FSDP':\n        from pytorch_lightning.strategies import FSDPStrategy\n        fsdp_strategy = FSDPStrategy(\n            auto_wrap_policy=None,\n            activation_checkpointing_policy=None,\n            cpu_offload=False,  # Whether to offload model parameters to CPU\n            limit_all_gathers=False,  # Whether to limit all-gather operations\n            sync_module_states=True,  # Whether to synchronize module states\n            # use_sharded_checkpoint=True,  # Whether to use sharded checkpoints\n            mixed_precision='bf16',  # Mixed precision training, default is 'bf16'\n        )\n        trainer_kwargs[\"strategy\"] = fsdp_strategy\n    main_print(f\" ............ 
using training strategy {training_mode} ...........\")\n\n    trainer = Trainer(**precision_config, **trainer_config, **trainer_kwargs, num_nodes=opt.num_nodes)\n    trainer.logdir = logdir\n\n    # data\n    data = instantiate_from_config(config.dataset)\n    data.prepare_data()\n    data.setup(\"fit\")\n\n    # configure learning rate\n    base_lr = config.model.params.neck_learning_rate\n    if 'accumulate_grad_batches' in lightning_config.trainer:\n        accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches\n    else:\n        accumulate_grad_batches = 1\n    main_print(f\"accumulate_grad_batches = {accumulate_grad_batches}\")\n    lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches\n    model.learning_rate = base_lr\n    main_print(\"++++ NOT USING LR SCALING ++++\")\n    main_print(f\"Setting learning rate to {model.learning_rate:.2e}\")\n\n    if len(opt.test_sd) > 0:\n        with torch.amp.autocast(device_type='cpu'):\n            def load_folder_ckpt(checkpoint_dir):\n                # For DeepSpeed checkpoints: gather all .pt shard files\n                pt_files = [os.path.join(checkpoint_dir, f) for f in os.listdir(checkpoint_dir) if f.endswith('.pt')]\n                # Initialize the model state dictionary\n                model_state_dict = {}\n                # Load each .pt file and merge it into the model state dictionary\n                for pt_file in pt_files:\n                    state_dict = torch.load(pt_file, map_location='cpu')\n                    model_state_dict.update(state_dict)\n                return model_state_dict\n            if os.path.isdir(opt.test_sd):\n                state_dict = load_folder_ckpt(opt.test_sd+\"/checkpoint\")\n                # Load checkpoint\n                success = model.load_checkpoint(opt.test_sd, load_optimizer_states=True, load_lr_scheduler_states=True)\n            else:\n                state_dict = torch.load(opt.test_sd, map_location='cpu')['state_dict']\n            main_print([k for k in state_dict.keys() if not k.startswith('lpips')])\n            new_state_dict = {k: v for k, v in state_dict.items() if not k.startswith('lpips')}\n            new_state_dict = {k: v for k, v in new_state_dict.items() if not k.startswith('encoder')}\n            model.load_state_dict(new_state_dict, strict=False)\n            main_print(f\"======== Testing: loading weights from {opt.test_sd} ========\")\n        model = model.to('cuda')\n\n        with torch.no_grad():\n            trainer.test(model, data)\n\n    else:\n        # run training loop\n        try:\n            if opt.resume and not opt.resume_weights_only:\n                trainer.fit(model, data, ckpt_path=opt.resume)\n            else:\n                trainer.fit(model, data)\n        except Exception as e:\n            main_print(f\"An error occurred: {e}\")\n            torch.cuda.empty_cache()\n"
  }
]