[
  {
    "path": ".gitignore",
    "content": "__pycache__/\ncheckpoint/\ndatasets/\nevaluations/\nfull_models/\npretrained_renderer/\nwandb/\n"
  },
  {
    "path": "LICENSE",
    "content": "Copyright (C) 2022 Roy Or-El, Xuan Luo, Mengyi Shan, Eli Shechtman, Jeong Joon Park, Ira Kemelmacher-Shlizerman.\nAll rights reserved.\nLicensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).\n\nThe software is made available under Creative Commons BY-NC-SA 4.0 license\nby University of Washington. You can use, redistribute, and adapt it\nfor non-commercial purposes, as long as you (a) give appropriate credit\nby citing our paper, (b) indicate any changes that you've made,\nand (c) distribute any derivative works under the same license.\n\nTHE AUTHORS DISCLAIM ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL\nIMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE.\nIN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL\nDAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,\nWHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING\nOUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# StyleSDF: High-Resolution 3D-Consistent Image and Geometry Generation\n### [Project Page](https://stylesdf.github.io/) | [Paper](https://arxiv.org/pdf/2112.11427.pdf) | [HuggingFace Demo](https://huggingface.co/spaces/SerdarHelli/StyleSDF-3D)\n\n[![Explore in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/royorel/StyleSDF/blob/main/StyleSDF_demo.ipynb)<br>\n\n[Roy Or-El](https://homes.cs.washington.edu/~royorel/)<sup>1</sup> ,\n[Xuan Luo](https://roxanneluo.github.io/)<sup>1</sup>,\n[Mengyi Shan](https://shanmy.github.io/)<sup>1</sup>,\n[Eli Shechtman](https://research.adobe.com/person/eli-shechtman/)<sup>2</sup>,\n[Jeong Joon Park](https://jjparkcv.github.io/)<sup>3</sup>,\n[Ira Kemelmacher-Shlizerman](https://www.irakemelmacher.com/)<sup>1</sup><br>\n<sup>1</sup>University of Washington, <sup>2</sup>Adobe Research, <sup>3</sup>Stanford University\n\n<div align=\"center\">\n<img src=./assets/teaser.png>\n</div>\n\n## Updates\n12/26/2022: A new HuggingFace demo is now available. Special thanks to [@SerdarHelli](https://github.com/SerdarHelli) for the implementation.<br>\n3/27/2022: Fixed a bug in the sphere initialization code (init_forward function was missing, see commit [0bd8741](https://github.com/royorel/StyleSDF/commit/0bd8741f26048d26160a495a9056b5f1da1a60a1)).<br>\n3/22/2022: **Added training files**.<br>\n3/9/2022: Fixed a bug in the calculation of the mean w vector (see commit [d4dd17d](https://github.com/royorel/StyleSDF/commit/d4dd17de09fd58adefc7ed49487476af6018894f)).<br>\n3/4/2022: Testing code and Colab demo were released. **Training files will be released soon.**\n\n## Overview\nStyleSDF is a 3D-aware GAN, aimed at solving two main challenges:\n1. High-resolution, view-consistent generation of the RGB images.\n2. Generating detailed 3D shapes.\n\nStyleSDF is trained only on single-view RGB data. The 3D geometry is learned implicitly with an SDF-based volume renderer.<br>\n\nThis code is the official PyTorch implementation of the paper:\n> **StyleSDF: High-Resolution 3D-Consistent Image and Geometry Generation**<br>\n> Roy Or-El, Xuan Luo, Mengyi Shan, Eli Shechtman, Jeong Joon Park, Ira Kemelmacher-Shlizerman<br>\n> CVPR 2022<br>\n> https://arxiv.org/pdf/2112.11427.pdf\n\n## Abstract\nWe introduce a high resolution, 3D-consistent image and\nshape generation technique which we call StyleSDF. Our\nmethod is trained on single-view RGB data only, and stands\non the shoulders of StyleGAN2 for image generation, while\nsolving two main challenges in 3D-aware GANs: 1) high-resolution,\nview-consistent generation of the RGB images,\nand 2) detailed 3D shape. We achieve this by merging a\nSDF-based 3D representation with a style-based 2D generator.\nOur 3D implicit network renders low-resolution feature\nmaps, from which the style-based network generates\nview-consistent, 1024×1024 images. Notably, our SDFbased\n3D modeling defines detailed 3D surfaces, leading\nto consistent volume rendering. 
Our method shows higher\nquality results compared to the state of the art in terms of visual\nand geometric quality.\n\n## Prerequisites\nYou must have a **GPU with CUDA support** in order to run the code.\n\nThis code requires **PyTorch**, **PyTorch3D** and **torchvision** to be installed; please go to [PyTorch.org](https://pytorch.org/) and [PyTorch3d.org](https://pytorch3d.org/) for installation info.<br>\nWe tested our code on Python 3.8.5, PyTorch 1.9.0, PyTorch3D 0.6.1 and torchvision 0.10.0.\n\nThe following packages should also be installed:\n1. lmdb\n2. numpy\n3. ninja\n4. pillow\n5. requests\n6. tqdm\n7. scipy\n8. skimage\n9. skvideo\n10. trimesh[easy]\n11. configargparse\n12. munch\n13. wandb (optional)\n\nIf any of these packages are not installed on your computer, you can install them using the supplied `requirements.txt` file:<br>\n```pip install -r requirements.txt```\n\n## Download Pre-trained Models\nThe pre-trained models can be downloaded by running `python download_models.py`.\n\n## Quick Demo\nYou can explore our method in Google Colab [![Explore in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/royorel/StyleSDF/blob/main/StyleSDF_demo.ipynb).\n\nAlternatively, you can download the pretrained models by running:<br>\n`python download_models.py`\n\nTo generate human faces from the model pre-trained on FFHQ, run:<br>\n`python generate_shapes_and_images.py --expname ffhq1024x1024 --size 1024 --identities NUMBER_OF_FACES`\n\nTo generate animal faces from the model pre-trained on AFHQ, run:<br>\n`python generate_shapes_and_images.py --expname afhq512x512 --size 512 --identities NUMBER_OF_FACES`\n\n## Generating images and meshes\nTo generate images and meshes from a trained model, run:\n`python generate_shapes_and_images.py --expname NAME_OF_TRAINED_MODEL --size MODEL_OUTPUT_SIZE --identities NUMBER_OF_FACES`\n\nThe script will generate an RGB image, a mesh generated from the depth map, and the mesh extracted with Marching cubes.\n\n### Optional flags for image and mesh generation\n```\n  --no_surface_renderings          When true, only RGB outputs will be generated. Otherwise, both RGB and depth renderings will be generated. (default: false)\n  --fixed_camera_angles            When true, the generator will render identities from a fixed set of camera angles. (default: false)\n```\n\n## Generating Videos\nTo generate videos from a trained model, run: <br>\n`python render_video.py --expname NAME_OF_TRAINED_MODEL --size MODEL_OUTPUT_SIZE --identities NUMBER_OF_FACES`.\n\nThis script will generate an RGB video as well as a depth map video for each identity. The average processing time per video is ~5-10 minutes on an RTX2080 Ti GPU.\n\n### Optional flags for video rendering\n```\n  --no_surface_videos          When true, only an RGB video will be generated when running render_video.py. Otherwise, both RGB and depth videos will be generated. Setting this flag cuts the processing time per video. (default: false)\n  --azim_video                 When true, the camera trajectory will travel along the azimuth direction. Otherwise, the camera will travel along an ellipsoid trajectory. (default: ellipsoid)\n  --project_noise              When true, use geometry-aware noise projection to reduce flickering effects (see supplementary section C.1 in the paper). Warning: processing time significantly increases with this flag to ~20 minutes per video. (default: false)\n```\n\n## Training\n### Preparing your Dataset\nIf you wish to train a model from scratch, first you need to convert your dataset to an lmdb format. Run:<br>\n`python prepare_data.py --out_path OUTPUT_LMDB_PATH --n_worker N_WORKER --size SIZE1,SIZE2,SIZE3,... INPUT_DATASET_PATH`\n\n### Training the volume renderer\n#### Training scripts\nTo train the volume renderer on FFHQ run: `bash ./scripts/train_ffhq_vol_renderer.sh`. <br>\nTo train the volume renderer on AFHQ run: `bash ./scripts/train_afhq_vol_renderer.sh`.\n\n* The scripts above use distributed training. To train the models on a single GPU (not recommended) remove `-m torch.distributed.launch --nproc_per_node NUM_GPUS` from the script.\n\n#### Training on a new dataset\nTo train the volume renderer on a new dataset, run:<br>\n`python train_volume_renderer.py --batch BATCH_SIZE --chunk CHUNK_SIZE --expname EXPERIMENT_NAME --dataset_path DATASET_PATH`\n\nIdeally, `CHUNK_SIZE` should be the same as `BATCH_SIZE`, but on most GPUs it will likely cause an out of memory error. In such a case, reduce `CHUNK_SIZE` to perform gradient accumulation.
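\n\nFor example (the sizes here are purely illustrative, adjust them to your GPU), if a batch of 8 does not fit in memory, the following run keeps an effective batch size of 8 while computing gradients over sub-batches of 2:<br>\n`python train_volume_renderer.py --batch 8 --chunk 2 --expname EXPERIMENT_NAME --dataset_path DATASET_PATH`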
\n\n**Important note**: The best way to monitor the SDF convergence is to look at the \"Beta value\" graph on wandb. Convergence is successful once beta reaches values below approx. 3*10<sup>-3</sup>. If the SDF is not converging, increase the R1 regularization weight. Another helpful option (to a lesser degree) is to decrease the weight of the minimal surface regularization.  \n\n#### Distributed training\nIf you have multiple GPUs, you can train your model on multiple instances by running:<br>\n`python -m torch.distributed.launch --nproc_per_node NUM_GPUS train_volume_renderer.py --batch BATCH_SIZE --chunk CHUNK_SIZE --expname EXPERIMENT_NAME --dataset_path DATASET_PATH`\n\n### Training the full pipeline\n#### Training scripts\nTo train the full pipeline on FFHQ run: `bash ./scripts/train_ffhq_full_pipeline_1024x1024.sh`. <br>\nTo train the full pipeline on AFHQ run: `bash ./scripts/train_afhq_full_pipeline_512x512.sh`.\n\n* The scripts above assume that the volume renderer model was already trained. **Do not run them from scratch.**\n* The scripts above use distributed training. To train the models on a single GPU (not recommended) remove `-m torch.distributed.launch --nproc_per_node NUM_GPUS` from the script.\n\n#### Training on a new dataset\nTo train the full pipeline on a new dataset, **first train the volume renderer separately**. <br>\nAfter the volume renderer training is finished, run:<br>\n`python train_full_pipeline.py --batch BATCH_SIZE --chunk CHUNK_SIZE --expname EXPERIMENT_NAME --size OUTPUT_SIZE`\n\nIdeally, `CHUNK_SIZE` should be the same as `BATCH_SIZE`, but on most GPUs it will likely cause an out of memory error. In such a case, reduce `CHUNK_SIZE` to perform gradient accumulation.\n\n#### Distributed training\nIf you have multiple GPUs, you can train your model on multiple instances by running:<br>\n`python -m torch.distributed.launch --nproc_per_node NUM_GPUS train_full_pipeline.py --batch BATCH_SIZE --chunk CHUNK_SIZE --expname EXPERIMENT_NAME --size OUTPUT_SIZE`\n\nHere, **BATCH_SIZE represents the batch per GPU**, not the overall batch size.
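\n\nFor example (the GPU count and batch size are illustrative), launching on 4 GPUs with `--batch 2` trains with an overall batch size of 8 (4 GPUs × 2 per GPU):<br>\n`python -m torch.distributed.launch --nproc_per_node 4 train_full_pipeline.py --batch 2 --chunk 2 --expname EXPERIMENT_NAME --size OUTPUT_SIZE`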
\n\n### Optional training flags\n```\nTraining regime options:\n  --iter                 Total number of training iterations. (default: 300,000)\n  --wandb                Use weights and biases logging. (default: False)\n  --r1                   Weight of the r1 regularization. (default: 10.0)\n  --view_lambda          Weight of the viewpoint regularization. (Equation 6, default: 15)\n  --eikonal_lambda       Weight of the eikonal regularization. (Equation 7, default: 0.1)\n  --min_surf_lambda      Weight of the minimal surface regularization. (Equation 8, default: 0.05)\n\nCamera options:\n  --uniform              When true, the camera position is sampled from a uniform distribution. (default: gaussian)\n  --azim                 Camera azimuth angle std (gaussian)/range (uniform) in Radians. (default: 0.3 Rad.)\n  --elev                 Camera elevation angle std (gaussian)/range (uniform) in Radians. (default: 0.15 Rad.)\n  --fov                  Camera field of view half angle in **Degrees**. (default: 6 Deg.)\n  --dist_radius          Radius of points sampling distance from the origin. Determines the near and far fields. (default: 0.12)\n```\n\n## Citation\nIf you use this code for your research, please cite our paper.\n```\n@InProceedings{orel2022stylesdf,\n  title={Style{SDF}: {H}igh-{R}esolution {3D}-{C}onsistent {I}mage and {G}eometry {G}eneration},\n  author    = {Or-El, Roy and \n               Luo, Xuan and \n               Shan, Mengyi and \n               Shechtman, Eli and \n               Park, Jeong Joon and \n               Kemelmacher-Shlizerman, Ira},\n  booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},\n  month     = {June},\n  year      = {2022},\n  pages     = {13503-13513}\n}\n```\n\n## Acknowledgments\nThis code is inspired by rosinality's [StyleGAN2-PyTorch](https://github.com/rosinality/stylegan2-pytorch) and Yen-Chen Lin's [NeRF-pytorch](https://github.com/yenchenlin/nerf-pytorch).\n\nA special thanks to [@SerdarHelli](https://github.com/SerdarHelli) for implementing the HuggingFace demo.\n"
  },
  {
    "path": "StyleSDF_demo.ipynb",
    "content": "{\n  \"nbformat\": 4,\n  \"nbformat_minor\": 0,\n  \"metadata\": {\n    \"colab\": {\n      \"name\": \"StyleSDF_demo.ipynb\",\n      \"private_outputs\": true,\n      \"provenance\": [],\n      \"collapsed_sections\": [],\n      \"include_colab_link\": true\n    },\n    \"kernelspec\": {\n      \"name\": \"python3\",\n      \"display_name\": \"Python 3\"\n    },\n    \"language_info\": {\n      \"name\": \"python\"\n    },\n    \"accelerator\": \"GPU\"\n  },\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"view-in-github\",\n        \"colab_type\": \"text\"\n      },\n      \"source\": [\n        \"<a href=\\\"https://colab.research.google.com/github/royorel/StyleSDF/blob/main/StyleSDF_demo.ipynb\\\" target=\\\"_parent\\\"><img src=\\\"https://colab.research.google.com/assets/colab-badge.svg\\\" alt=\\\"Open In Colab\\\"/></a>\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"#StyleSDF Demo\\n\",\n        \"\\n\",\n        \"This Colab notebook demonstrates the capabilities of the StyleSDF 3D-aware GAN architecture proposed in our paper.\\n\",\n        \"\\n\",\n        \"This colab generates images with their correspondinig 3D meshes\"\n      ],\n      \"metadata\": {\n        \"id\": \"b86fxLSo1gqI\"\n      }\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"First, let's download the github repository and install all dependencies.\"\n      ],\n      \"metadata\": {\n        \"id\": \"8wSAkifH2Bk5\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"!git clone https://github.com/royorel/StyleSDF.git\\n\",\n        \"%cd StyleSDF\\n\",\n        \"!pip3 install -r requirements.txt\"\n      ],\n      \"metadata\": {\n        \"id\": \"1GTFRig12CuH\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"And install pytorch3D...\"\n      ],\n      \"metadata\": {\n        \"id\": \"GwFJ8oLY4i8d\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"!pip install -U fvcore\\n\",\n        \"import sys\\n\",\n        \"import torch\\n\",\n        \"pyt_version_str=torch.__version__.split(\\\"+\\\")[0].replace(\\\".\\\", \\\"\\\")\\n\",\n        \"version_str=\\\"\\\".join([\\n\",\n        \"    f\\\"py3{sys.version_info.minor}_cu\\\",\\n\",\n        \"    torch.version.cuda.replace(\\\".\\\",\\\"\\\"),\\n\",\n        \"    f\\\"_pyt{pyt_version_str}\\\"\\n\",\n        \"])\\n\",\n        \"!pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\"\n      ],\n      \"metadata\": {\n        \"id\": \"zl3Vpddz3ols\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"Now let's download the pretrained models for FFHQ and AFHQ.\"\n      ],\n      \"metadata\": {\n        \"id\": \"OgMcslVS5vbC\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"!python download_models.py\"\n      ],\n      \"metadata\": {\n        \"id\": \"r1iDkz7r5wnO\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"Here, we import libraries and set options.\\n\",\n        \"\\n\",\n        \"Note: this might 
take a while (approx. 1-2 minutes) since CUDA kernels need to be compiled.\"\n      ],\n      \"metadata\": {\n        \"id\": \"F2JUyGq4JDa8\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"import os\\n\",\n        \"import torch\\n\",\n        \"import trimesh\\n\",\n        \"import numpy as np\\n\",\n        \"from munch import *\\n\",\n        \"from options import BaseOptions\\n\",\n        \"from model import Generator\\n\",\n        \"from generate_shapes_and_images import generate\\n\",\n        \"from render_video import render_video\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"torch.random.manual_seed(321)\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"device = \\\"cuda\\\"\\n\",\n        \"opt = BaseOptions().parse()\\n\",\n        \"opt.camera.uniform = True\\n\",\n        \"opt.model.is_test = True\\n\",\n        \"opt.model.freeze_renderer = False\\n\",\n        \"opt.rendering.offset_sampling = True\\n\",\n        \"opt.rendering.static_viewdirs = True\\n\",\n        \"opt.rendering.force_background = True\\n\",\n        \"opt.rendering.perturb = 0\\n\",\n        \"opt.inference.renderer_output_size = opt.model.renderer_spatial_output_dim\\n\",\n        \"opt.inference.style_dim = opt.model.style_dim\\n\",\n        \"opt.inference.project_noise = opt.model.project_noise\"\n      ],\n      \"metadata\": {\n        \"id\": \"Qfamt8J0JGn5\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"id\": \"TLOiDzgKtYyi\"\n      },\n      \"source\": [\n        \"Don't worry about this message above, \\n\",\n        \"```\\n\",\n        \"usage: ipykernel_launcher.py [-h] [--dataset_path DATASET_PATH]\\n\",\n        \"                             [--config CONFIG] [--expname EXPNAME]\\n\",\n        \"                             [--ckpt CKPT] [--continue_training]\\n\",\n        \"                             ...\\n\",\n        \"                             ...\\n\",\n        \"ipykernel_launcher.py: error: unrecognized arguments: -f /root/.local/share/jupyter/runtime/kernel-c9d47a98-bdba-4a5f-9f0a-e1437c7228b6.json\\n\",\n        \"```\\n\",\n        \"everything is perfectly fine...\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"Here, we define our model.\\n\",\n        \"\\n\",\n        \"Set the options below according to your choosing:\\n\",\n        \"1. If you plan to try the method for the AFHQ dataset (animal faces), change `model_type` to 'afhq'. Default: `ffhq` (human faces).\\n\",\n        \"2. If you wish to turn off depth rendering and marching cubes extraction and generate only RGB images, set `opt.inference.no_surface_renderings = True`. Default: `False`.\\n\",\n        \"3. If you wish to generate the image from a specific set of viewpoints, set `opt.inference.fixed_camera_angles = True`. Default: `False`.\\n\",\n        \"4. Set the number of identities you wish to create in `opt.inference.identities`. Default: `4`.\\n\",\n        \"5. Select the number of views per identity in `opt.inference.num_views_per_id`,<br>\\n\",\n        \"   (Only applicable when `opt.inference.fixed_camera_angles` is false). Default: `1`. 
\"\n      ],\n      \"metadata\": {\n        \"id\": \"40dk1eEJo9MF\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"# User options\\n\",\n        \"model_type = 'ffhq' # Whether to load the FFHQ or AFHQ model\\n\",\n        \"opt.inference.no_surface_renderings = False # When true, only RGB images will be created\\n\",\n        \"opt.inference.fixed_camera_angles = False # When true, each identity will be rendered from a specific set of 13 viewpoints. Otherwise, random views are generated\\n\",\n        \"opt.inference.identities = 4 # Number of identities to generate\\n\",\n        \"opt.inference.num_views_per_id = 1 # Number of viewpoints generated per identity. This option is ignored if opt.inference.fixed_camera_angles is true.\\n\",\n        \"\\n\",\n        \"# Load saved model\\n\",\n        \"if model_type == 'ffhq':\\n\",\n        \"    model_path = 'ffhq1024x1024.pt'\\n\",\n        \"    opt.model.size = 1024\\n\",\n        \"    opt.experiment.expname = 'ffhq1024x1024'\\n\",\n        \"else:\\n\",\n        \"    opt.inference.camera.azim = 0.15\\n\",\n        \"    model_path = 'afhq512x512.pt'\\n\",\n        \"    opt.model.size = 512\\n\",\n        \"    opt.experiment.expname = 'afhq512x512'\\n\",\n        \"\\n\",\n        \"# Create results directory\\n\",\n        \"result_model_dir = 'final_model'\\n\",\n        \"results_dir_basename = os.path.join(opt.inference.results_dir, opt.experiment.expname)\\n\",\n        \"opt.inference.results_dst_dir = os.path.join(results_dir_basename, result_model_dir)\\n\",\n        \"if opt.inference.fixed_camera_angles:\\n\",\n        \"    opt.inference.results_dst_dir = os.path.join(opt.inference.results_dst_dir, 'fixed_angles')\\n\",\n        \"else:\\n\",\n        \"    opt.inference.results_dst_dir = os.path.join(opt.inference.results_dst_dir, 'random_angles')\\n\",\n        \"\\n\",\n        \"os.makedirs(opt.inference.results_dst_dir, exist_ok=True)\\n\",\n        \"os.makedirs(os.path.join(opt.inference.results_dst_dir, 'images'), exist_ok=True)\\n\",\n        \"if not opt.inference.no_surface_renderings:\\n\",\n        \"    os.makedirs(os.path.join(opt.inference.results_dst_dir, 'depth_map_meshes'), exist_ok=True)\\n\",\n        \"    os.makedirs(os.path.join(opt.inference.results_dst_dir, 'marching_cubes_meshes'), exist_ok=True)\\n\",\n        \"\\n\",\n        \"opt.inference.camera = opt.camera\\n\",\n        \"opt.inference.size = opt.model.size\\n\",\n        \"checkpoint_path = os.path.join('full_models', model_path)\\n\",\n        \"checkpoint = torch.load(checkpoint_path)\\n\",\n        \"\\n\",\n        \"# Load image generation model\\n\",\n        \"g_ema = Generator(opt.model, opt.rendering).to(device)\\n\",\n        \"pretrained_weights_dict = checkpoint[\\\"g_ema\\\"]\\n\",\n        \"model_dict = g_ema.state_dict()\\n\",\n        \"for k, v in pretrained_weights_dict.items():\\n\",\n        \"    if v.size() == model_dict[k].size():\\n\",\n        \"        model_dict[k] = v\\n\",\n        \"\\n\",\n        \"g_ema.load_state_dict(model_dict)\\n\",\n        \"\\n\",\n        \"# Load a second volume renderer that extracts surfaces at 128x128x128 (or higher) for better surface resolution\\n\",\n        \"if not opt.inference.no_surface_renderings:\\n\",\n        \"    opt['surf_extraction'] = Munch()\\n\",\n        \"    opt.surf_extraction.rendering = opt.rendering\\n\",\n        \"    opt.surf_extraction.model = opt.model.copy()\\n\",\n        \"    
opt.surf_extraction.model.renderer_spatial_output_dim = 128\\n\",\n        \"    opt.surf_extraction.rendering.N_samples = opt.surf_extraction.model.renderer_spatial_output_dim\\n\",\n        \"    opt.surf_extraction.rendering.return_xyz = True\\n\",\n        \"    opt.surf_extraction.rendering.return_sdf = True\\n\",\n        \"    surface_g_ema = Generator(opt.surf_extraction.model, opt.surf_extraction.rendering, full_pipeline=False).to(device)\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"    # Load weights to surface extractor\\n\",\n        \"    surface_extractor_dict = surface_g_ema.state_dict()\\n\",\n        \"    for k, v in pretrained_weights_dict.items():\\n\",\n        \"        if k in surface_extractor_dict.keys() and v.size() == surface_extractor_dict[k].size():\\n\",\n        \"            surface_extractor_dict[k] = v\\n\",\n        \"\\n\",\n        \"    surface_g_ema.load_state_dict(surface_extractor_dict)\\n\",\n        \"else:\\n\",\n        \"    surface_g_ema = None\\n\",\n        \"\\n\",\n        \"# Get the mean latent vector for g_ema\\n\",\n        \"if opt.inference.truncation_ratio < 1:\\n\",\n        \"    with torch.no_grad():\\n\",\n        \"        mean_latent = g_ema.mean_latent(opt.inference.truncation_mean, device)\\n\",\n        \"else:\\n\",\n        \"    surface_mean_latent = None\\n\",\n        \"\\n\",\n        \"# Get the mean latent vector for surface_g_ema\\n\",\n        \"if not opt.inference.no_surface_renderings:\\n\",\n        \"    surface_mean_latent = mean_latent[0]\\n\",\n        \"else:\\n\",\n        \"    surface_mean_latent = None\"\n      ],\n      \"metadata\": {\n        \"id\": \"CUcWipIlpINT\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"## Generating images and meshes\\n\",\n        \"\\n\",\n        \"Finally, we run the network. 
The results will be saved to `evaluations/[model_name]/final_model/[fixed/random]_angles`, according to the selected setup.\"\n      ],\n      \"metadata\": {\n        \"id\": \"N9pqjPDYCwIJ\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"generate(opt.inference, g_ema, surface_g_ema, device, mean_latent, surface_mean_latent)\"\n      ],\n      \"metadata\": {\n        \"id\": \"UG4hZgigDfG7\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"Now let's examine the results\\n\",\n        \"\\n\",\n        \"Tip: for better mesh visualization, we recommend dowwnloading the result meshes and view them with Meshlab.\\n\",\n        \"\\n\",\n        \"Meshes loaction is: `evaluations/[model_name]/final_model/[fixed/random]_angles/[depth_map/marching_cubes]_meshes`.\"\n      ],\n      \"metadata\": {\n        \"id\": \"obfjXuDxNuZl\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"from PIL import Image\\n\",\n        \"from trimesh.viewer.notebook import scene_to_html as mesh2html\\n\",\n        \"from IPython.display import HTML as viewer_html\\n\",\n        \"\\n\",\n        \"# First let's look at the images\\n\",\n        \"img_dir = os.path.join(opt.inference.results_dst_dir,'images')\\n\",\n        \"im_list = sorted([entry for entry in os.listdir(img_dir) if 'thumb' not in entry])\\n\",\n        \"img = Image.new('RGB', (256 * len(im_list), 256))\\n\",\n        \"for i, im_file in enumerate(im_list):\\n\",\n        \"    im_path = os.path.join(img_dir, im_file)\\n\",\n        \"    curr_img = Image.open(im_path).resize((256,256)) # the displayed image is scaled to fit to the screen\\n\",\n        \"    img.paste(curr_img, (256 * i, 0))\\n\",\n        \"\\n\",\n        \"display(img)\\n\",\n        \"\\n\",\n        \"# And now, we'll move on to display the marching cubes and depth map meshes\\n\",\n        \"\\n\",\n        \"marching_cubes_meshes_dir = os.path.join(opt.inference.results_dst_dir,'marching_cubes_meshes')\\n\",\n        \"marching_cubes_meshes_list = sorted([os.path.join(marching_cubes_meshes_dir, entry) for entry in os.listdir(marching_cubes_meshes_dir) if 'obj' in entry])\\n\",\n        \"depth_map_meshes_dir = os.path.join(opt.inference.results_dst_dir,'depth_map_meshes')\\n\",\n        \"depth_map_meshes_list = sorted([os.path.join(depth_map_meshes_dir, entry) for entry in os.listdir(depth_map_meshes_dir) if 'obj' in entry])\\n\",\n        \"for i, mesh_files in enumerate(zip(marching_cubes_meshes_list, depth_map_meshes_list)):\\n\",\n        \"    mc_mesh_file, dm_mesh_file = mesh_files[0], mesh_files[1]\\n\",\n        \"    marching_cubes_mesh = trimesh.Scene(trimesh.load_mesh(mc_mesh_file, 'obj'))  \\n\",\n        \"    curr_mc_html = mesh2html(marching_cubes_mesh).replace('\\\"', '&quot;')\\n\",\n        \"    display(viewer_html(' '.join(['<iframe srcdoc=\\\"{srcdoc}\\\"',\\n\",\n        \"                            'width=\\\"{width}px\\\" height=\\\"{height}px\\\"',\\n\",\n        \"                            'style=\\\"border:none;\\\"></iframe>']).format(\\n\",\n        \"                            srcdoc=curr_mc_html, height=256, width=256)))\\n\",\n        \"    depth_map_mesh = trimesh.Scene(trimesh.load_mesh(dm_mesh_file, 'obj'))  \\n\",\n        \"    curr_dm_html = mesh2html(depth_map_mesh).replace('\\\"', '&quot;')\\n\",\n        \"    display(viewer_html(' '.join(['<iframe 
srcdoc=\\\"{srcdoc}\\\"',\\n\",\n        \"                            'width=\\\"{width}px\\\" height=\\\"{height}px\\\"',\\n\",\n        \"                            'style=\\\"border:none;\\\"></iframe>']).format(\\n\",\n        \"                            srcdoc=curr_dm_html, height=256, width=256)))\"\n      ],\n      \"metadata\": {\n        \"id\": \"k3jXCVT7N2YZ\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"## Generating videos\\n\",\n        \"\\n\",\n        \"Additionally, we can also render videos. The results will be saved to `evaluations/[model_name]/final_model/videos`.\\n\",\n        \"\\n\",\n        \"Set the options below according to your choosing:\\n\",\n        \"1. If you wish to generate only RGB videos, set `opt.inference.no_surface_renderings = True`. Default: `False`.\\n\",\n        \"2. Set the camera trajectory. To travel along the azimuth direction set `opt.inference.azim_video = True`, to travel in an ellipsoid trajectory set `opt.inference.azim_video = False`. Default: `False`.\\n\",\n        \"\\n\",\n        \"###Important Note: \\n\",\n        \" - Processing time for videos when `opt.inference.no_surface_renderings = False` is very lengthy (~ 15-20 minutes per video). Rendering each depth frame for the depth videos is very slow.<br>\\n\",\n        \" - Processing time for videos when `opt.inference.no_surface_renderings = True` is much faster (~ 1-2 minutes per video)\"\n      ],\n      \"metadata\": {\n        \"id\": \"Cxq3c6anN4D0\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"# Options\\n\",\n        \"opt.inference.no_surface_renderings = True # When true, only RGB videos will be created\\n\",\n        \"opt.inference.azim_video = True # When true, the camera trajectory will travel along the azimuth direction. Otherwise, the camera will travel along an ellipsoid trajectory.\\n\",\n        \"\\n\",\n        \"opt.inference.results_dst_dir = os.path.join(os.path.split(opt.inference.results_dst_dir)[0], 'videos')\\n\",\n        \"os.makedirs(opt.inference.results_dst_dir, exist_ok=True)\\n\",\n        \"render_video(opt.inference, g_ema, surface_g_ema, device, mean_latent, surface_mean_latent)\"\n      ],\n      \"metadata\": {\n        \"id\": \"nblhnZgcOST8\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"Let's watch the result videos.\\n\",\n        \"\\n\",\n        \"The output video files are relatively large, so it might take a while (about 1-2 minutes) for all of them to be loaded. 
\"\n      ],\n      \"metadata\": {\n        \"id\": \"jkBYvlaeTGu1\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"%%script bash --bg\\n\",\n        \"python3 -m https.server 8000\"\n      ],\n      \"metadata\": {\n        \"id\": \"s-A9oIFXnYdM\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"# change ffhq1024x1024 to afhq512x512 if you are working on the AFHQ model\\n\",\n        \"%%html\\n\",\n        \"<div>\\n\",\n        \"  <video width=256 controls><source src=\\\"https://localhost:8000/evaluations/ffhq1024x1024/final_model/videos/sample_video_0_azim.mp4\\\" type=\\\"video/mp4\\\"></video>\\n\",\n        \"  <video width=256 controls><source src=\\\"https://localhost:8000/evaluations/ffhq1024x1024/final_model/videos/sample_video_1_azim.mp4\\\" type=\\\"video/mp4\\\"></video>\\n\",\n        \"  <video width=256 controls><source src=\\\"https://localhost:8000/evaluations/ffhq1024x1024/final_model/videos/sample_video_2_azim.mp4\\\" type=\\\"video/mp4\\\"></video>\\n\",\n        \"  <video width=256 controls><source src=\\\"https://localhost:8000/evaluations/ffhq1024x1024/final_model/videos/sample_video_3_azim.mp4\\\" type=\\\"video/mp4\\\"></video>\\n\",\n        \"</div>\"\n      ],\n      \"metadata\": {\n        \"id\": \"Rz8tq-yYnZMD\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"source\": [\n        \"An alternative way to view the videos with python code. \\n\",\n        \"It loads the videos faster, but very often it crashes the notebook since the video file are too large.\\n\",\n        \"\\n\",\n        \"**It is not recommended to view the files this way**.\\n\",\n        \"\\n\",\n        \"If the notebook does crash, you can also refresh the webpage and manually download the videos.<br>\\n\",\n        \"The videos are located in `evaluations/<model_name>/final_model/videos`\"\n      ],\n      \"metadata\": {\n        \"id\": \"KP_Nrl8yqL65\"\n      }\n    },\n    {\n      \"cell_type\": \"code\",\n      \"source\": [\n        \"# from base64 import b64encode\\n\",\n        \"\\n\",\n        \"# videos_dir = opt.inference.results_dst_dir\\n\",\n        \"# videos_list = sorted([os.path.join(videos_dir, entry) for entry in os.listdir(videos_dir) if 'mp4' in entry])\\n\",\n        \"# for i, video_file in enumerate(videos_list):\\n\",\n        \"#     if i != 1:\\n\",\n        \"#         continue\\n\",\n        \"#     mp4 = open(video_file,'rb').read()\\n\",\n        \"#     data_url = \\\"data:video/mp4;base64,\\\" + b64encode(mp4).decode()\\n\",\n        \"#     display(viewer_html(\\\"\\\"\\\"<video width={0} controls>\\n\",\n        \"#                                 <source src=\\\"{1}\\\" type=\\\"{2}\\\">\\n\",\n        \"#                           </video>\\\"\\\"\\\".format(256, data_url, \\\"video/mp4\\\")))\"\n      ],\n      \"metadata\": {\n        \"id\": \"LzsQTq1hTNSH\"\n      },\n      \"execution_count\": null,\n      \"outputs\": []\n    }\n  ]\n}"
  },
  {
    "path": "dataset.py",
    "content": "import os\nimport csv\nimport lmdb\nimport random\nimport numpy as np\nimport torchvision.transforms.functional as TF\nfrom PIL import Image\nfrom io import BytesIO\nfrom torch.utils.data import Dataset\n\n\nclass MultiResolutionDataset(Dataset):\n    def __init__(self, path, transform, resolution=256, nerf_resolution=64):\n        self.env = lmdb.open(\n            path,\n            max_readers=32,\n            readonly=True,\n            lock=False,\n            readahead=False,\n            meminit=False,\n        )\n\n        if not self.env:\n            raise IOError('Cannot open lmdb dataset', path)\n\n        with self.env.begin(write=False) as txn:\n            self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))\n\n        self.resolution = resolution\n        self.nerf_resolution = nerf_resolution\n        self.transform = transform\n\n    def __len__(self):\n        return self.length\n\n    def __getitem__(self, index):\n        with self.env.begin(write=False) as txn:\n            key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8')\n            img_bytes = txn.get(key)\n\n        buffer = BytesIO(img_bytes)\n        img = Image.open(buffer)\n        if random.random() > 0.5:\n            img = TF.hflip(img)\n\n        thumb_img = img.resize((self.nerf_resolution, self.nerf_resolution), Image.HAMMING)\n        img = self.transform(img)\n        thumb_img = self.transform(thumb_img)\n\n        return img, thumb_img\n"
  },
  {
    "path": "distributed.py",
    "content": "import math\nimport pickle\n\nimport torch\nfrom torch import distributed as dist\nfrom torch.utils.data.sampler import Sampler\n\n\ndef get_rank():\n    if not dist.is_available():\n        return 0\n\n    if not dist.is_initialized():\n        return 0\n\n    return dist.get_rank()\n\n\ndef synchronize():\n    if not dist.is_available():\n        return\n\n    if not dist.is_initialized():\n        return\n\n    world_size = dist.get_world_size()\n\n    if world_size == 1:\n        return\n\n    dist.barrier()\n\n\ndef get_world_size():\n    if not dist.is_available():\n        return 1\n\n    if not dist.is_initialized():\n        return 1\n\n    return dist.get_world_size()\n\n\ndef reduce_sum(tensor):\n    if not dist.is_available():\n        return tensor\n\n    if not dist.is_initialized():\n        return tensor\n\n    tensor = tensor.clone()\n    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)\n\n    return tensor\n\n\ndef gather_grad(params):\n    world_size = get_world_size()\n\n    if world_size == 1:\n        return\n\n    for param in params:\n        if param.grad is not None:\n            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)\n            param.grad.data.div_(world_size)\n\n\ndef all_gather(data):\n    world_size = get_world_size()\n\n    if world_size == 1:\n        return [data]\n\n    buffer = pickle.dumps(data)\n    storage = torch.ByteStorage.from_buffer(buffer)\n    tensor = torch.ByteTensor(storage).to('cuda')\n\n    local_size = torch.IntTensor([tensor.numel()]).to('cuda')\n    size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)]\n    dist.all_gather(size_list, local_size)\n    size_list = [int(size.item()) for size in size_list]\n    max_size = max(size_list)\n\n    tensor_list = []\n    for _ in size_list:\n        tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda'))\n\n    if local_size != max_size:\n        padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda')\n        tensor = torch.cat((tensor, padding), 0)\n\n    dist.all_gather(tensor_list, tensor)\n\n    data_list = []\n\n    for size, tensor in zip(size_list, tensor_list):\n        buffer = tensor.cpu().numpy().tobytes()[:size]\n        data_list.append(pickle.loads(buffer))\n\n    return data_list\n\n\ndef reduce_loss_dict(loss_dict):\n    world_size = get_world_size()\n\n    if world_size < 2:\n        return loss_dict\n\n    with torch.no_grad():\n        keys = []\n        losses = []\n\n        for k in sorted(loss_dict.keys()):\n            keys.append(k)\n            losses.append(loss_dict[k])\n\n        losses = torch.stack(losses, 0)\n        dist.reduce(losses, dst=0)\n\n        if dist.get_rank() == 0:\n            losses /= world_size\n\n        reduced_losses = {k: v for k, v in zip(keys, losses)}\n\n    return reduced_losses\n"
  },
  {
    "path": "download_models.py",
    "content": "import os\nimport html\nimport glob\nimport uuid\nimport hashlib\nimport requests\nfrom tqdm import tqdm\nfrom pdb import set_trace as st\n\n\nffhq_full_model_spec = dict(file_url='https://drive.google.com/uc?id=13s_dH768zJ3IHUjySVbD1DqcMolqKlBi',\n                            alt_url='', file_size=202570217, file_md5='1ab9522157e537351fcf4faed9c92abb',\n                            file_path='full_models/ffhq1024x1024.pt',)\nffhq_volume_renderer_spec = dict(file_url='https://drive.google.com/uc?id=1zzB-ACuas7lSAln8pDnIqlWOEg869CCK',\n                                 alt_url='', file_size=63736197, file_md5='fe62f26032ccc8f04e1101b07bcd7462',\n                                 file_path='pretrained_renderer/ffhq_vol_renderer.pt',)\nafhq_full_model_spec = dict(file_url='https://drive.google.com/uc?id=1jZcV5__EPS56JBllRmUfUjMHExql_eSq',\n                            alt_url='', file_size=184908743, file_md5='91eeaab2da5c0d2134c04bb56ac5aeb6',\n                            file_path='full_models/afhq512x512.pt',)\nafhq_volume_renderer_spec = dict(file_url='https://drive.google.com/uc?id=1xhZjuJt_teghAQEoJevAxrU5_8VqqX42',\n                                 alt_url='', file_size=63736197, file_md5='cb3beb6cd3c43d9119356400165e7f26',\n                                 file_path='pretrained_renderer/afhq_vol_renderer.pt',)\nvolume_renderer_sphere_init_spec = dict(file_url='https://drive.google.com/uc?id=1CcYWbHFrJyb4u5rBFx_ww-ceYWd309tk',\n                                        alt_url='', file_size=63736197, file_md5='a6be653435b0e49633e6838f18ff4df6',\n                                        file_path='pretrained_renderer/sphere_init.pt',)\n\n\ndef download_pretrained_models():\n    print('Downloading sphere initialized volume renderer')\n    with requests.Session() as session:\n        try:\n            download_file(session, volume_renderer_sphere_init_spec)\n        except:\n            print('Google Drive download failed.\\n' \\\n                  'Trying do download from alternate server')\n            download_file(session, volume_renderer_sphere_init_spec, use_alt_url=True)\n\n    print('Downloading FFHQ pretrained volume renderer')\n    with requests.Session() as session:\n        try:\n            download_file(session, ffhq_volume_renderer_spec)\n        except:\n            print('Google Drive download failed.\\n' \\\n                  'Trying do download from alternate server')\n            download_file(session, ffhq_volume_renderer_spec, use_alt_url=True)\n\n    print('Downloading FFHQ full model (1024x1024)')\n    with requests.Session() as session:\n        try:\n            download_file(session, ffhq_full_model_spec)\n        except:\n            print('Google Drive download failed.\\n' \\\n                  'Trying do download from alternate server')\n            download_file(session, ffhq_full_model_spec, use_alt_url=True)\n\n    print('Done!')\n\n    print('Downloading AFHQ pretrained volume renderer')\n    with requests.Session() as session:\n        try:\n            download_file(session, afhq_volume_renderer_spec)\n        except:\n            print('Google Drive download failed.\\n' \\\n                  'Trying do download from alternate server')\n            download_file(session, afhq_volume_renderer_spec, use_alt_url=True)\n\n    print('Done!')\n\n    print('Downloading Downloading AFHQ full model (512x512)')\n    with requests.Session() as session:\n        try:\n            download_file(session, afhq_full_model_spec)\n        except:\n            
print('Google Drive download failed.\\n' \\\n                  'Trying to download from alternate server')\n            download_file(session, afhq_full_model_spec, use_alt_url=True)\n\n    print('Done!')\n\ndef download_file(session, file_spec, use_alt_url=False, chunk_size=128, num_attempts=10):\n    file_path = file_spec['file_path']\n    if use_alt_url:\n        file_url = file_spec['alt_url']\n    else:\n        file_url = file_spec['file_url']\n\n    file_dir = os.path.dirname(file_path)\n    tmp_path = file_path + '.tmp.' + uuid.uuid4().hex\n    if file_dir:\n        os.makedirs(file_dir, exist_ok=True)\n\n    progress_bar = tqdm(total=file_spec['file_size'], unit='B', unit_scale=True)\n    for attempts_left in reversed(range(num_attempts)):\n        data_size = 0\n        progress_bar.reset()\n        try:\n            # Download.\n            data_md5 = hashlib.md5()\n            with session.get(file_url, stream=True) as res:\n                res.raise_for_status()\n                with open(tmp_path, 'wb') as f:\n                    for chunk in res.iter_content(chunk_size=chunk_size<<10):\n                        progress_bar.update(len(chunk))\n                        f.write(chunk)\n                        data_size += len(chunk)\n                        data_md5.update(chunk)\n\n            # Validate.\n            if 'file_size' in file_spec and data_size != file_spec['file_size']:\n                raise IOError('Incorrect file size', file_path)\n            if 'file_md5' in file_spec and data_md5.hexdigest() != file_spec['file_md5']:\n                raise IOError('Incorrect file MD5', file_path)\n            break\n\n        except:\n            # Last attempt => raise error.\n            if not attempts_left:\n                raise\n\n            # Handle Google Drive virus checker nag.\n            if data_size > 0 and data_size < 8192:\n                with open(tmp_path, 'rb') as f:\n                    data = f.read()\n                links = [html.unescape(link) for link in data.decode('utf-8').split('\"') if 'confirm=t' in link]\n                if len(links) == 1:\n                    file_url = requests.compat.urljoin(file_url, links[0])\n                    continue\n\n    progress_bar.close()\n\n    # Rename temp file to the correct name.\n    os.replace(tmp_path, file_path) # atomic\n\n    # Attempt to clean up any leftover temps.\n    for filename in glob.glob(file_path + '.tmp.*'):\n        try:\n            os.remove(filename)\n        except:\n            pass\n\nif __name__ == \"__main__\":\n    download_pretrained_models()\n"
  },
  {
    "path": "generate_shapes_and_images.py",
    "content": "import os\nimport torch\nimport trimesh\nimport numpy as np\nfrom munch import *\nfrom PIL import Image\nfrom tqdm import tqdm\nfrom torch.nn import functional as F\nfrom torch.utils import data\nfrom torchvision import utils\nfrom torchvision import transforms\nfrom skimage.measure import marching_cubes\nfrom scipy.spatial import Delaunay\nfrom options import BaseOptions\nfrom model import Generator\nfrom utils import (\n    generate_camera_params,\n    align_volume,\n    extract_mesh_with_marching_cubes,\n    xyz2mesh,\n)\n\n\ntorch.random.manual_seed(1234)\n\n\ndef generate(opt, g_ema, surface_g_ema, device, mean_latent, surface_mean_latent):\n    g_ema.eval()\n    if not opt.no_surface_renderings:\n        surface_g_ema.eval()\n\n    # set camera angles\n    if opt.fixed_camera_angles:\n        # These can be changed to any other specific viewpoints.\n        # You can add or remove viewpoints as you wish\n        locations = torch.tensor([[0, 0],\n                                  [-1.5 * opt.camera.azim, 0],\n                                  [-1 * opt.camera.azim, 0],\n                                  [-0.5 * opt.camera.azim, 0],\n                                  [0.5 * opt.camera.azim, 0],\n                                  [1 * opt.camera.azim, 0],\n                                  [1.5 * opt.camera.azim, 0],\n                                  [0, -1.5 * opt.camera.elev],\n                                  [0, -1 * opt.camera.elev],\n                                  [0, -0.5 * opt.camera.elev],\n                                  [0, 0.5 * opt.camera.elev],\n                                  [0, 1 * opt.camera.elev],\n                                  [0, 1.5 * opt.camera.elev]], device=device)\n        # For zooming in/out change the values of fov\n        # (This can be defined for each view separately via a custom tensor\n        # like the locations tensor above. 
Tensor shape should be [locations.shape[0],1])\n        # reasonable values are [0.75 * opt.camera.fov, 1.25 * opt.camera.fov]\n        fov = opt.camera.fov * torch.ones((locations.shape[0],1), device=device)\n        num_viewdirs = locations.shape[0]\n    else: # draw random camera angles\n        locations = None\n        # fov = None\n        fov = opt.camera.fov\n        num_viewdirs = opt.num_views_per_id\n\n    # generate images\n    for i in tqdm(range(opt.identities)):\n        with torch.no_grad():\n            chunk = 8\n            sample_z = torch.randn(1, opt.style_dim, device=device).repeat(num_viewdirs,1)\n            sample_cam_extrinsics, sample_focals, sample_near, sample_far, sample_locations = \\\n            generate_camera_params(opt.renderer_output_size, device, batch=num_viewdirs,\n                                   locations=locations, #input_fov=fov,\n                                   uniform=opt.camera.uniform, azim_range=opt.camera.azim,\n                                   elev_range=opt.camera.elev, fov_ang=fov,\n                                   dist_radius=opt.camera.dist_radius)\n            rgb_images = torch.Tensor(0, 3, opt.size, opt.size)\n            rgb_images_thumbs = torch.Tensor(0, 3, opt.renderer_output_size, opt.renderer_output_size)\n            for j in range(0, num_viewdirs, chunk):\n                out = g_ema([sample_z[j:j+chunk]],\n                            sample_cam_extrinsics[j:j+chunk],\n                            sample_focals[j:j+chunk],\n                            sample_near[j:j+chunk],\n                            sample_far[j:j+chunk],\n                            truncation=opt.truncation_ratio,\n                            truncation_latent=mean_latent)\n\n                rgb_images = torch.cat([rgb_images, out[0].cpu()], 0)\n                rgb_images_thumbs = torch.cat([rgb_images_thumbs, out[1].cpu()], 0)\n\n            utils.save_image(rgb_images,\n                os.path.join(opt.results_dst_dir, 'images','{}.png'.format(str(i).zfill(7))),\n                nrow=num_viewdirs,\n                normalize=True,\n                padding=0,\n                value_range=(-1, 1),)\n\n            utils.save_image(rgb_images_thumbs,\n                os.path.join(opt.results_dst_dir, 'images','{}_thumb.png'.format(str(i).zfill(7))),\n                nrow=num_viewdirs,\n                normalize=True,\n                padding=0,\n                value_range=(-1, 1),)\n\n            # this is done to fit to RTX2080 RAM size (11GB)\n            del out\n            torch.cuda.empty_cache()\n\n            if not opt.no_surface_renderings:\n                surface_chunk = 1\n                scale = surface_g_ema.renderer.out_im_res / g_ema.renderer.out_im_res\n                surface_sample_focals = sample_focals * scale\n                for j in range(0, num_viewdirs, surface_chunk):\n                    surface_out = surface_g_ema([sample_z[j:j+surface_chunk]],\n                                                sample_cam_extrinsics[j:j+surface_chunk],\n                                                surface_sample_focals[j:j+surface_chunk],\n                                                sample_near[j:j+surface_chunk],\n                                                sample_far[j:j+surface_chunk],\n                                                truncation=opt.truncation_ratio,\n                                                truncation_latent=surface_mean_latent,\n                                                return_sdf=True,\n   
                                             return_xyz=True)\n\n                    xyz = surface_out[2].cpu()\n                    sdf = surface_out[3].cpu()\n\n                    # this is done to fit to RTX2080 RAM size (11GB)\n                    del surface_out\n                    torch.cuda.empty_cache()\n\n                    # mesh extractions are done one at a time\n                    for k in range(surface_chunk):\n                        curr_locations = sample_locations[j:j+surface_chunk]\n                        loc_str = '_azim{}_elev{}'.format(int(curr_locations[k,0] * 180 / np.pi),\n                                                          int(curr_locations[k,1] * 180 / np.pi))\n\n                        # Save depth outputs as meshes\n                        depth_mesh_filename = os.path.join(opt.results_dst_dir,'depth_map_meshes','sample_{}_depth_mesh{}.obj'.format(i, loc_str))\n                        depth_mesh = xyz2mesh(xyz[k:k+surface_chunk])\n                        if depth_mesh != None:\n                            with open(depth_mesh_filename, 'w') as f:\n                                depth_mesh.export(f,file_type='obj')\n\n                        # extract full geometry with marching cubes\n                        if j == 0:\n                            try:\n                                frostum_aligned_sdf = align_volume(sdf)\n                                marching_cubes_mesh = extract_mesh_with_marching_cubes(frostum_aligned_sdf[k:k+surface_chunk])\n                            except ValueError:\n                                marching_cubes_mesh = None\n                                print('Marching cubes extraction failed.')\n                                print('Please check whether the SDF values are all larger (or all smaller) than 0.')\n\n                            if marching_cubes_mesh != None:\n                                marching_cubes_mesh_filename = os.path.join(opt.results_dst_dir,'marching_cubes_meshes','sample_{}_marching_cubes_mesh{}.obj'.format(i, loc_str))\n                                with open(marching_cubes_mesh_filename, 'w') as f:\n                                    marching_cubes_mesh.export(f,file_type='obj')\n\n\nif __name__ == \"__main__\":\n    device = \"cuda\"\n    opt = BaseOptions().parse()\n    opt.model.is_test = True\n    opt.model.freeze_renderer = False\n    opt.rendering.offset_sampling = True\n    opt.rendering.static_viewdirs = True\n    opt.rendering.force_background = True\n    opt.rendering.perturb = 0\n    opt.inference.size = opt.model.size\n    opt.inference.camera = opt.camera\n    opt.inference.renderer_output_size = opt.model.renderer_spatial_output_dim\n    opt.inference.style_dim = opt.model.style_dim\n    opt.inference.project_noise = opt.model.project_noise\n    opt.inference.return_xyz = opt.rendering.return_xyz\n\n    # find checkpoint directory\n    # check if there's a fully trained model\n    checkpoints_dir = 'full_models'\n    checkpoint_path = os.path.join(checkpoints_dir, opt.experiment.expname + '.pt')\n    if os.path.isfile(checkpoint_path):\n        # define results directory name\n        result_model_dir = 'final_model'\n    else:\n        checkpoints_dir = os.path.join('checkpoint', opt.experiment.expname, 'full_pipeline')\n        checkpoint_path = os.path.join(checkpoints_dir,\n                                       'models_{}.pt'.format(opt.experiment.ckpt.zfill(7)))\n        # define results directory name\n        result_model_dir = 
'iter_{}'.format(opt.experiment.ckpt.zfill(7))\n\n    # create results directory\n    results_dir_basename = os.path.join(opt.inference.results_dir, opt.experiment.expname)\n    opt.inference.results_dst_dir = os.path.join(results_dir_basename, result_model_dir)\n    if opt.inference.fixed_camera_angles:\n        opt.inference.results_dst_dir = os.path.join(opt.inference.results_dst_dir, 'fixed_angles')\n    else:\n        opt.inference.results_dst_dir = os.path.join(opt.inference.results_dst_dir, 'random_angles')\n    os.makedirs(opt.inference.results_dst_dir, exist_ok=True)\n    os.makedirs(os.path.join(opt.inference.results_dst_dir, 'images'), exist_ok=True)\n    if not opt.inference.no_surface_renderings:\n        os.makedirs(os.path.join(opt.inference.results_dst_dir, 'depth_map_meshes'), exist_ok=True)\n        os.makedirs(os.path.join(opt.inference.results_dst_dir, 'marching_cubes_meshes'), exist_ok=True)\n\n    # load saved model\n    checkpoint = torch.load(checkpoint_path)\n\n    # load image generation model\n    g_ema = Generator(opt.model, opt.rendering).to(device)\n    pretrained_weights_dict = checkpoint[\"g_ema\"]\n    model_dict = g_ema.state_dict()\n    for k, v in pretrained_weights_dict.items():\n        if v.size() == model_dict[k].size():\n            model_dict[k] = v\n\n    g_ema.load_state_dict(model_dict)\n\n    # load a second volume renderer that extracts surfaces at 128x128x128 (or higher) for better surface resolution\n    if not opt.inference.no_surface_renderings:\n        opt['surf_extraction'] = Munch()\n        opt.surf_extraction.rendering = opt.rendering\n        opt.surf_extraction.model = opt.model.copy()\n        opt.surf_extraction.model.renderer_spatial_output_dim = 128\n        opt.surf_extraction.rendering.N_samples = opt.surf_extraction.model.renderer_spatial_output_dim\n        opt.surf_extraction.rendering.return_xyz = True\n        opt.surf_extraction.rendering.return_sdf = True\n        surface_g_ema = Generator(opt.surf_extraction.model, opt.surf_extraction.rendering, full_pipeline=False).to(device)\n\n        # Load weights to surface extractor\n        surface_extractor_dict = surface_g_ema.state_dict()\n        for k, v in pretrained_weights_dict.items():\n            if k in surface_extractor_dict.keys() and v.size() == surface_extractor_dict[k].size():\n                surface_extractor_dict[k] = v\n\n        surface_g_ema.load_state_dict(surface_extractor_dict)\n    else:\n        surface_g_ema = None\n\n    # get the mean latent vector for g_ema\n    if opt.inference.truncation_ratio < 1:\n        with torch.no_grad():\n            mean_latent = g_ema.mean_latent(opt.inference.truncation_mean, device)\n    else:\n        surface_mean_latent = None\n\n    # get the mean latent vector for surface_g_ema\n    if not opt.inference.no_surface_renderings:\n        surface_mean_latent = mean_latent[0]\n    else:\n        surface_mean_latent = None\n\n    generate(opt.inference, g_ema, surface_g_ema, device, mean_latent, surface_mean_latent)\n"
  },
  {
    "path": "losses.py",
    "content": "import math\nimport torch\nfrom torch import autograd\nfrom torch.nn import functional as F\n\n\ndef viewpoints_loss(viewpoint_pred, viewpoint_target):\n    loss = F.smooth_l1_loss(viewpoint_pred, viewpoint_target)\n\n    return loss\n\n\ndef eikonal_loss(eikonal_term, sdf=None, beta=100):\n    if eikonal_term == None:\n        eikonal_loss = 0\n    else:\n        eikonal_loss = ((eikonal_term.norm(dim=-1) - 1) ** 2).mean()\n\n    if sdf == None:\n        minimal_surface_loss = torch.tensor(0.0, device=eikonal_term.device)\n    else:\n        minimal_surface_loss = torch.exp(-beta * torch.abs(sdf)).mean()\n\n    return eikonal_loss, minimal_surface_loss\n\n\ndef d_logistic_loss(real_pred, fake_pred):\n    real_loss = F.softplus(-real_pred)\n    fake_loss = F.softplus(fake_pred)\n\n    return real_loss.mean() + fake_loss.mean()\n\n\ndef d_r1_loss(real_pred, real_img):\n    grad_real, = autograd.grad(outputs=real_pred.sum(),\n                               inputs=real_img,\n                               create_graph=True)\n\n    grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()\n    return grad_penalty\n\n\ndef g_nonsaturating_loss(fake_pred):\n    loss = F.softplus(-fake_pred).mean()\n\n    return loss\n\n\ndef g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):\n    noise = torch.randn_like(fake_img) / math.sqrt(\n        fake_img.shape[2] * fake_img.shape[3]\n    )\n    grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents,\n                         create_graph=True, only_inputs=True)[0]\n    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))\n\n    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)\n\n    path_penalty = (path_lengths - path_mean).pow(2).mean()\n\n    return path_penalty, path_mean.detach(), path_lengths\n"
  },
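  {
    "path": "examples/losses_usage_sketch.py",
    "content": "# Illustrative usage sketch (not part of the original StyleSDF release).\n# It shows one plausible way the regularizers in losses.py can be combined into a generator\n# objective. The file path, tensor shapes and lambda weights below are hypothetical; the actual\n# coefficients are defined by the repo's training options and scripts.\nimport torch\n\nfrom losses import g_nonsaturating_loss, eikonal_loss\n\n# dummy stand-ins: fake discriminator logits, eikonal gradients and SDF samples\nfake_pred = torch.randn(4, 1)\neikonal_term = torch.randn(4, 24, 3)\nsdf = torch.randn(4, 24, 1)\n\nadv_loss = g_nonsaturating_loss(fake_pred)\neik_loss, surf_loss = eikonal_loss(eikonal_term, sdf=sdf, beta=100)\n\n# hypothetical weighting, for illustration only\nlambda_eik, lambda_surf = 0.1, 0.05\ntotal = adv_loss + lambda_eik * eik_loss + lambda_surf * surf_loss\nprint(float(total))\n"
  },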
  {
    "path": "model.py",
    "content": "import math\nimport random\nimport trimesh\nimport torch\nimport numpy as np\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom volume_renderer import VolumeFeatureRenderer\nfrom op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d\nfrom pdb import set_trace as st\n\nfrom utils import (\n    create_cameras,\n    create_mesh_renderer,\n    add_textures,\n    create_depth_mesh_renderer,\n)\nfrom pytorch3d.renderer import TexturesUV\nfrom pytorch3d.structures import Meshes\nfrom pytorch3d.renderer import look_at_view_transform, FoVPerspectiveCameras\nfrom pytorch3d.transforms import matrix_to_euler_angles\n\n\nclass PixelNorm(nn.Module):\n    def __init__(self):\n        super().__init__()\n\n    def forward(self, input):\n        return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)\n\n\nclass MappingLinear(nn.Module):\n    def __init__(self, in_dim, out_dim, bias=True, activation=None, is_last=False):\n        super().__init__()\n        if is_last:\n            weight_std = 0.25\n        else:\n            weight_std = 1\n\n        self.weight = nn.Parameter(weight_std * nn.init.kaiming_normal_(torch.empty(out_dim, in_dim), a=0.2, mode='fan_in', nonlinearity='leaky_relu'))\n\n        if bias:\n            self.bias = nn.Parameter(nn.init.uniform_(torch.empty(out_dim), a=-np.sqrt(1/in_dim), b=np.sqrt(1/in_dim)))\n        else:\n            self.bias = None\n\n        self.activation = activation\n\n    def forward(self, input):\n        if self.activation != None:\n            out = F.linear(input, self.weight)\n            out = fused_leaky_relu(out, self.bias, scale=1)\n        else:\n            out = F.linear(input, self.weight, bias=self.bias)\n\n        return out\n\n    def __repr__(self):\n        return (\n            f\"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})\"\n        )\n\n\ndef make_kernel(k):\n    k = torch.tensor(k, dtype=torch.float32)\n\n    if k.ndim == 1:\n        k = k[None, :] * k[:, None]\n\n    k /= k.sum()\n\n    return k\n\n\nclass Upsample(nn.Module):\n    def __init__(self, kernel, factor=2):\n        super().__init__()\n\n        self.factor = factor\n        kernel = make_kernel(kernel) * (factor ** 2)\n        self.register_buffer(\"kernel\", kernel)\n\n        p = kernel.shape[0] - factor\n\n        pad0 = (p + 1) // 2 + factor - 1\n        pad1 = p // 2\n\n        self.pad = (pad0, pad1)\n\n    def forward(self, input):\n        out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)\n\n        return out\n\n\nclass Downsample(nn.Module):\n    def __init__(self, kernel, factor=2):\n        super().__init__()\n\n        self.factor = factor\n        kernel = make_kernel(kernel)\n        self.register_buffer(\"kernel\", kernel)\n\n        p = kernel.shape[0] - factor\n\n        pad0 = (p + 1) // 2\n        pad1 = p // 2\n\n        self.pad = (pad0, pad1)\n\n    def forward(self, input):\n        out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)\n\n        return out\n\n\nclass Blur(nn.Module):\n    def __init__(self, kernel, pad, upsample_factor=1):\n        super().__init__()\n\n        kernel = make_kernel(kernel)\n\n        if upsample_factor > 1:\n            kernel = kernel * (upsample_factor ** 2)\n\n        self.register_buffer(\"kernel\", kernel)\n\n        self.pad = pad\n\n    def forward(self, input):\n        out = upfirdn2d(input, self.kernel, pad=self.pad)\n\n        return out\n\n\nclass EqualConv2d(nn.Module):\n    
def __init__(\n        self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True\n    ):\n        super().__init__()\n\n        self.weight = nn.Parameter(\n            torch.randn(out_channel, in_channel, kernel_size, kernel_size)\n        )\n        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)\n\n        self.stride = stride\n        self.padding = padding\n\n        if bias:\n            self.bias = nn.Parameter(torch.zeros(out_channel))\n\n        else:\n            self.bias = None\n\n    def forward(self, input):\n        out = F.conv2d(\n            input,\n            self.weight * self.scale,\n            bias=self.bias,\n            stride=self.stride,\n            padding=self.padding,\n        )\n\n        return out\n\n    def __repr__(self):\n        return (\n            f\"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},\"\n            f\" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})\"\n        )\n\n\nclass EqualLinear(nn.Module):\n    def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1,\n                 activation=None):\n        super().__init__()\n\n        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))\n\n        if bias:\n            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))\n\n        else:\n            self.bias = None\n\n        self.activation = activation\n\n        self.scale = (1 / math.sqrt(in_dim)) * lr_mul\n        self.lr_mul = lr_mul\n\n    def forward(self, input):\n        if self.activation:\n            out = F.linear(input, self.weight * self.scale)\n            out = fused_leaky_relu(out, self.bias * self.lr_mul)\n\n        else:\n            out = F.linear(input, self.weight * self.scale,\n                           bias=self.bias * self.lr_mul)\n\n        return out\n\n    def __repr__(self):\n        return (\n            f\"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})\"\n        )\n\n\nclass ModulatedConv2d(nn.Module):\n    def __init__(self, in_channel, out_channel, kernel_size, style_dim, demodulate=True,\n                 upsample=False, downsample=False, blur_kernel=[1, 3, 3, 1]):\n        super().__init__()\n\n        self.eps = 1e-8\n        self.kernel_size = kernel_size\n        self.in_channel = in_channel\n        self.out_channel = out_channel\n        self.upsample = upsample\n        self.downsample = downsample\n\n        if upsample:\n            factor = 2\n            p = (len(blur_kernel) - factor) - (kernel_size - 1)\n            pad0 = (p + 1) // 2 + factor - 1\n            pad1 = p // 2 + 1\n\n            self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)\n\n        if downsample:\n            factor = 2\n            p = (len(blur_kernel) - factor) + (kernel_size - 1)\n            pad0 = (p + 1) // 2\n            pad1 = p // 2\n\n            self.blur = Blur(blur_kernel, pad=(pad0, pad1))\n\n        fan_in = in_channel * kernel_size ** 2\n        self.scale = 1 / math.sqrt(fan_in)\n        self.padding = kernel_size // 2\n\n        self.weight = nn.Parameter(\n            torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)\n        )\n\n        self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)\n\n        self.demodulate = demodulate\n\n    def __repr__(self):\n        return (\n            f\"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, \"\n            
f\"upsample={self.upsample}, downsample={self.downsample})\"\n        )\n\n    def forward(self, input, style):\n        batch, in_channel, height, width = input.shape\n        style = self.modulation(style).view(batch, 1, in_channel, 1, 1)\n\n        weight = self.scale * self.weight * style\n\n        if self.demodulate:\n            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)\n            weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)\n\n        weight = weight.view(\n            batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size\n        )\n\n        if self.upsample:\n            input = input.view(1, batch * in_channel, height, width)\n            weight = weight.view(\n                batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size\n            )\n            weight = weight.transpose(1, 2).reshape(\n                batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size\n            )\n            out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)\n            _, _, height, width = out.shape\n            out = out.view(batch, self.out_channel, height, width)\n            out = self.blur(out)\n\n        elif self.downsample:\n            input = self.blur(input)\n            _, _, height, width = input.shape\n            input = input.view(1, batch * in_channel, height, width)\n            out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)\n            _, _, height, width = out.shape\n            out = out.view(batch, self.out_channel, height, width)\n\n        else:\n            input = input.view(1, batch * in_channel, height, width)\n            out = F.conv2d(input, weight, padding=self.padding, groups=batch)\n            _, _, height, width = out.shape\n            out = out.view(batch, self.out_channel, height, width)\n\n        return out\n\n\nclass NoiseInjection(nn.Module):\n    def __init__(self, project=False):\n        super().__init__()\n        self.project = project\n        self.weight = nn.Parameter(torch.zeros(1))\n        self.prev_noise = None\n        self.mesh_fn = None\n        self.vert_noise = None\n\n    def create_pytorch_mesh(self, trimesh):\n        v=trimesh.vertices; f=trimesh.faces\n        verts = torch.from_numpy(np.asarray(v)).to(torch.float32).cuda()\n        mesh_pytorch = Meshes(\n            verts=[verts],\n            faces = [torch.from_numpy(np.asarray(f)).to(torch.float32).cuda()],\n            textures=None\n        )\n        if self.vert_noise == None or verts.shape[0] != self.vert_noise.shape[1]:\n            self.vert_noise = torch.ones_like(verts)[:,0:1].cpu().normal_().expand(-1,3).unsqueeze(0)\n\n        mesh_pytorch = add_textures(meshes=mesh_pytorch, vertex_colors=self.vert_noise.to(verts.device))\n\n        return mesh_pytorch\n\n    def load_mc_mesh(self, filename, resolution=128, im_res=64):\n        import trimesh\n\n        mc_tri=trimesh.load_mesh(filename)\n        v=mc_tri.vertices; f=mc_tri.faces\n        mesh2=trimesh.base.Trimesh(vertices=v, faces=f)\n        if im_res==64 or im_res==128:\n            pytorch3d_mesh = self.create_pytorch_mesh(mesh2)\n            return pytorch3d_mesh\n        v,f = trimesh.remesh.subdivide(v,f)\n        mesh2_subdiv = trimesh.base.Trimesh(vertices=v, faces=f)\n        if im_res==256:\n            pytorch3d_mesh = self.create_pytorch_mesh(mesh2_subdiv);\n            return pytorch3d_mesh\n        v,f = 
trimesh.remesh.subdivide(mesh2_subdiv.vertices,mesh2_subdiv.faces)\n        mesh3_subdiv = trimesh.base.Trimesh(vertices=v, faces=f)\n        if im_res==256:\n            pytorch3d_mesh = self.create_pytorch_mesh(mesh3_subdiv);\n            return pytorch3d_mesh\n        v,f = trimesh.remesh.subdivide(mesh3_subdiv.vertices,mesh3_subdiv.faces)\n        mesh4_subdiv = trimesh.base.Trimesh(vertices=v, faces=f)\n\n        pytorch3d_mesh = self.create_pytorch_mesh(mesh4_subdiv)\n\n        return pytorch3d_mesh\n\n    def project_noise(self, noise, transform, mesh_path=None):\n        batch, _, height, width = noise.shape\n        assert(batch == 1)  # assuming during inference batch size is 1\n\n        angles = matrix_to_euler_angles(transform[0:1,:,:3], \"ZYX\")\n        azim = float(angles[0][1])\n        elev = float(-angles[0][2])\n\n        cameras = create_cameras(azim=azim*180/np.pi,elev=elev*180/np.pi,fov=12.,dist=1)\n\n        renderer = create_depth_mesh_renderer(cameras, image_size=height,\n                specular_color=((0,0,0),), ambient_color=((1.,1.,1.),),diffuse_color=((0,0,0),))\n\n\n        if self.mesh_fn is None or self.mesh_fn != mesh_path:\n            self.mesh_fn = mesh_path\n\n        pytorch3d_mesh = self.load_mc_mesh(mesh_path, im_res=height)\n        rgb, depth = renderer(pytorch3d_mesh)\n\n        depth_max = depth.max(-1)[0].view(-1) # (NxN)\n        depth_valid = depth_max > 0.\n        if self.prev_noise is None:\n            self.prev_noise = noise\n        noise_copy = self.prev_noise.clone()\n        noise_copy.view(-1)[depth_valid] = rgb[0,:,:,0].view(-1)[depth_valid]\n        noise_copy = noise_copy.reshape(1,1,height,height)  # 1x1xNxN\n\n        return noise_copy\n\n\n    def forward(self, image, noise=None, transform=None, mesh_path=None):\n        batch, _, height, width = image.shape\n        if noise is None:\n            noise = image.new_empty(batch, 1, height, width).normal_()\n        elif self.project:\n            noise = self.project_noise(noise, transform, mesh_path=mesh_path)\n\n        return image + self.weight * noise\n\n\nclass StyledConv(nn.Module):\n    def __init__(self, in_channel, out_channel, kernel_size, style_dim,\n                 upsample=False, blur_kernel=[1, 3, 3, 1], project_noise=False):\n        super().__init__()\n\n        self.conv = ModulatedConv2d(\n            in_channel,\n            out_channel,\n            kernel_size,\n            style_dim,\n            upsample=upsample,\n            blur_kernel=blur_kernel,\n        )\n\n        self.noise = NoiseInjection(project=project_noise)\n        self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))\n        self.activate = FusedLeakyReLU(out_channel)\n\n    def forward(self, input, style, noise=None, transform=None, mesh_path=None):\n        out = self.conv(input, style)\n        out = self.noise(out, noise=noise, transform=transform, mesh_path=mesh_path)\n        out = self.activate(out)\n\n        return out\n\n\nclass ToRGB(nn.Module):\n    def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):\n        super().__init__()\n\n        self.upsample = upsample\n        out_channels = 3\n        if upsample:\n            self.upsample = Upsample(blur_kernel)\n\n        self.conv = ModulatedConv2d(in_channel, out_channels, 1, style_dim, demodulate=False)\n        self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))\n\n    def forward(self, input, style, skip=None):\n        out = self.conv(input, style)\n        out = out + 
self.bias\n\n        if skip is not None:\n            if self.upsample:\n                skip = self.upsample(skip)\n\n            out = out + skip\n\n        return out\n\n\nclass ConvLayer(nn.Sequential):\n    def __init__(self, in_channel, out_channel, kernel_size, downsample=False,\n                 blur_kernel=[1, 3, 3, 1], bias=True, activate=True):\n        layers = []\n\n        if downsample:\n            factor = 2\n            p = (len(blur_kernel) - factor) + (kernel_size - 1)\n            pad0 = (p + 1) // 2\n            pad1 = p // 2\n\n            layers.append(Blur(blur_kernel, pad=(pad0, pad1)))\n\n            stride = 2\n            self.padding = 0\n\n        else:\n            stride = 1\n            self.padding = kernel_size // 2\n\n        layers.append(\n            EqualConv2d(\n                in_channel,\n                out_channel,\n                kernel_size,\n                padding=self.padding,\n                stride=stride,\n                bias=bias and not activate,\n            )\n        )\n\n        if activate:\n            layers.append(FusedLeakyReLU(out_channel, bias=bias))\n\n        super().__init__(*layers)\n\n\nclass Decoder(nn.Module):\n    def __init__(self, model_opt, blur_kernel=[1, 3, 3, 1]):\n        super().__init__()\n        # decoder mapping network\n        self.size = model_opt.size\n        self.style_dim = model_opt.style_dim * 2\n        thumb_im_size = model_opt.renderer_spatial_output_dim\n\n        layers = [PixelNorm(),\n                   EqualLinear(\n                       self.style_dim // 2, self.style_dim, lr_mul=model_opt.lr_mapping, activation=\"fused_lrelu\"\n                   )]\n\n        for i in range(4):\n            layers.append(\n                EqualLinear(\n                    self.style_dim, self.style_dim, lr_mul=model_opt.lr_mapping, activation=\"fused_lrelu\"\n                )\n            )\n\n        self.style = nn.Sequential(*layers)\n\n        # decoder network\n        self.channels = {\n            4: 512,\n            8: 512,\n            16: 512,\n            32: 512,\n            64: 256 * model_opt.channel_multiplier,\n            128: 128 * model_opt.channel_multiplier,\n            256: 64 * model_opt.channel_multiplier,\n            512: 32 * model_opt.channel_multiplier,\n            1024: 16 * model_opt.channel_multiplier,\n        }\n\n        decoder_in_size = model_opt.renderer_spatial_output_dim\n\n        # image decoder\n        self.log_size = int(math.log(self.size, 2))\n        self.log_in_size = int(math.log(decoder_in_size, 2))\n\n        self.conv1 = StyledConv(\n            model_opt.feature_encoder_in_channels,\n            self.channels[decoder_in_size], 3, self.style_dim, blur_kernel=blur_kernel,\n            project_noise=model_opt.project_noise)\n\n        self.to_rgb1 = ToRGB(self.channels[decoder_in_size], self.style_dim, upsample=False)\n\n        self.num_layers = (self.log_size - self.log_in_size) * 2 + 1\n\n        self.convs = nn.ModuleList()\n        self.upsamples = nn.ModuleList()\n        self.to_rgbs = nn.ModuleList()\n        self.noises = nn.Module()\n\n        in_channel = self.channels[decoder_in_size]\n\n        for layer_idx in range(self.num_layers):\n            res = (layer_idx + 2 * self.log_in_size + 1) // 2\n            shape = [1, 1, 2 ** (res), 2 ** (res)]\n            self.noises.register_buffer(f\"noise_{layer_idx}\", torch.randn(*shape))\n\n        for i in range(self.log_in_size+1, self.log_size + 1):\n            out_channel = 
self.channels[2 ** i]\n\n            self.convs.append(\n                StyledConv(in_channel, out_channel, 3, self.style_dim, upsample=True,\n                           blur_kernel=blur_kernel, project_noise=model_opt.project_noise)\n            )\n\n            self.convs.append(\n                StyledConv(out_channel, out_channel, 3, self.style_dim,\n                           blur_kernel=blur_kernel, project_noise=model_opt.project_noise)\n            )\n\n            self.to_rgbs.append(ToRGB(out_channel, self.style_dim))\n\n            in_channel = out_channel\n\n        self.n_latent = (self.log_size - self.log_in_size) * 2 + 2\n\n    def mean_latent(self, renderer_latent):\n        latent = self.style(renderer_latent).mean(0, keepdim=True)\n\n        return latent\n\n    def get_latent(self, input):\n        return self.style(input)\n\n    def styles_and_noise_forward(self, styles, noise, inject_index=None, truncation=1,\n                                 truncation_latent=None, input_is_latent=False,\n                                 randomize_noise=True):\n        if not input_is_latent:\n            styles = [self.style(s) for s in styles]\n\n        if noise is None:\n            if randomize_noise:\n                noise = [None] * self.num_layers\n            else:\n                noise = [\n                    getattr(self.noises, f\"noise_{i}\") for i in range(self.num_layers)\n                ]\n\n        if (truncation < 1):\n            style_t = []\n\n            for style in styles:\n                style_t.append(\n                    truncation_latent[1] + truncation * (style - truncation_latent[1])\n                )\n\n            styles = style_t\n\n        if len(styles) < 2:\n            inject_index = self.n_latent\n\n            if styles[0].ndim < 3:\n                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)\n\n            else:\n                latent = styles[0]\n        else:\n            if inject_index is None:\n                inject_index = random.randint(1, self.n_latent - 1)\n\n            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)\n            latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)\n\n            latent = torch.cat([latent, latent2], 1)\n\n        return latent, noise\n\n    def forward(self, features, styles, rgbd_in=None, transform=None,\n                return_latents=False, inject_index=None, truncation=1,\n                truncation_latent=None, input_is_latent=False, noise=None,\n                randomize_noise=True, mesh_path=None):\n        latent, noise = self.styles_and_noise_forward(styles, noise, inject_index, truncation,\n                                                      truncation_latent, input_is_latent,\n                                                      randomize_noise)\n\n        out = self.conv1(features, latent[:, 0], noise=noise[0],\n                         transform=transform, mesh_path=mesh_path)\n\n        skip = self.to_rgb1(out, latent[:, 1], skip=rgbd_in)\n\n        i = 1\n        for conv1, conv2, noise1, noise2, to_rgb in zip(\n            self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs\n        ):\n            out = conv1(out, latent[:, i], noise=noise1,\n                           transform=transform, mesh_path=mesh_path)\n            out = conv2(out, latent[:, i + 1], noise=noise2,\n                                   transform=transform, mesh_path=mesh_path)\n            skip = to_rgb(out, latent[:, i + 2], 
skip=skip)\n\n            i += 2\n\n        out_latent = latent if return_latents else None\n        image = skip\n\n        return image, out_latent\n\n\nclass Generator(nn.Module):\n    def __init__(self, model_opt, renderer_opt, blur_kernel=[1, 3, 3, 1], ema=False, full_pipeline=True):\n        super().__init__()\n        self.size = model_opt.size\n        self.style_dim = model_opt.style_dim\n        self.num_layers = 1\n        self.train_renderer = not model_opt.freeze_renderer\n        self.full_pipeline = full_pipeline\n        model_opt.feature_encoder_in_channels = renderer_opt.width\n\n        if ema or 'is_test' in model_opt.keys():\n            self.is_train = False\n        else:\n            self.is_train = True\n\n        # volume renderer mapping_network\n        layers = []\n        for i in range(3):\n            layers.append(\n                MappingLinear(self.style_dim, self.style_dim, activation=\"fused_lrelu\")\n            )\n\n        self.style = nn.Sequential(*layers)\n\n        # volume renderer\n        thumb_im_size = model_opt.renderer_spatial_output_dim\n        self.renderer = VolumeFeatureRenderer(renderer_opt, style_dim=self.style_dim,\n                                              out_im_res=thumb_im_size)\n\n        if self.full_pipeline:\n            self.decoder = Decoder(model_opt)\n\n    def make_noise(self):\n        device = self.input.input.device\n\n        noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]\n\n        for i in range(3, self.log_size + 1):\n            for _ in range(2):\n                noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))\n\n        return noises\n\n    def mean_latent(self, n_latent, device):\n        latent_in = torch.randn(n_latent, self.style_dim, device=device)\n        renderer_latent = self.style(latent_in)\n        renderer_latent_mean = renderer_latent.mean(0, keepdim=True)\n        if self.full_pipeline:\n            decoder_latent_mean = self.decoder.mean_latent(renderer_latent)\n        else:\n            decoder_latent_mean = None\n\n        return [renderer_latent_mean, decoder_latent_mean]\n\n    def get_latent(self, input):\n        return self.style(input)\n\n    def styles_and_noise_forward(self, styles, inject_index=None, truncation=1,\n                                 truncation_latent=None, input_is_latent=False):\n        if not input_is_latent:\n            styles = [self.style(s) for s in styles]\n\n        if truncation < 1:\n            style_t = []\n\n            for style in styles:\n                style_t.append(\n                    truncation_latent[0] + truncation * (style - truncation_latent[0])\n                )\n\n            styles = style_t\n\n        return styles\n\n    def init_forward(self, styles, cam_poses, focals, near=0.88, far=1.12):\n        latent = self.styles_and_noise_forward(styles)\n\n        sdf, target_values = self.renderer.mlp_init_pass(cam_poses, focals, near, far, styles=latent[0])\n\n        return sdf, target_values\n\n    def forward(self, styles, cam_poses, focals, near=0.88, far=1.12, return_latents=False,\n                inject_index=None, truncation=1, truncation_latent=None,\n                input_is_latent=False, noise=None, randomize_noise=True,\n                return_sdf=False, return_xyz=False, return_eikonal=False,\n                project_noise=False, mesh_path=None):\n\n        # do not calculate renderer gradients if renderer weights are frozen\n        with torch.set_grad_enabled(self.is_train and 
self.train_renderer):\n            latent = self.styles_and_noise_forward(styles, inject_index, truncation,\n                                                   truncation_latent, input_is_latent)\n\n            thumb_rgb, features, sdf, mask, xyz, eikonal_term = self.renderer(cam_poses, focals, near, far, styles=latent[0], return_eikonal=return_eikonal)\n\n        if self.full_pipeline:\n            rgb, decoder_latent = self.decoder(features, latent,\n                                               transform=cam_poses if project_noise else None,\n                                               return_latents=return_latents,\n                                               inject_index=inject_index, truncation=truncation,\n                                               truncation_latent=truncation_latent, noise=noise,\n                                               input_is_latent=input_is_latent, randomize_noise=randomize_noise,\n                                               mesh_path=mesh_path)\n\n        else:\n            rgb = None\n\n        if return_latents:\n            return rgb, decoder_latent\n        else:\n            out = (rgb, thumb_rgb)\n            if return_xyz:\n                out += (xyz,)\n            if return_sdf:\n                out += (sdf,)\n            if return_eikonal:\n                out += (eikonal_term,)\n            if return_xyz:\n                out += (mask,)\n\n            return out\n\n############# Volume Renderer Building Blocks & Discriminator ##################\nclass VolumeRenderDiscConv2d(nn.Module):\n    def __init__(self, in_channels, out_channels, kernel_size, stride=1,\n                 padding=0, bias=True, activate=False):\n        super(VolumeRenderDiscConv2d, self).__init__()\n\n        self.conv = nn.Conv2d(in_channels, out_channels,\n                              kernel_size, stride, padding, bias=bias and not activate)\n\n        self.activate = activate\n        if self.activate:\n            self.activation = FusedLeakyReLU(out_channels, bias=bias, scale=1)\n            bias_init_coef = np.sqrt(1 / (in_channels * kernel_size * kernel_size))\n            nn.init.uniform_(self.activation.bias, a=-bias_init_coef, b=bias_init_coef)\n\n\n    def forward(self, input):\n        \"\"\"\n        input_tensor_shape: (N, C_in,H,W)\n        output_tensor_shape: (N,C_out,H_out,W_out）\n        :return: Conv2d + activation Result\n        \"\"\"\n        out = self.conv(input)\n        if self.activate:\n            out = self.activation(out)\n\n        return out\n\n\nclass AddCoords(nn.Module):\n    def __init__(self):\n        super(AddCoords, self).__init__()\n\n    def forward(self, input_tensor):\n        \"\"\"\n        :param input_tensor: shape (N, C_in, H, W)\n        :return:\n        \"\"\"\n        batch_size_shape, channel_in_shape, dim_y, dim_x = input_tensor.shape\n        xx_channel = torch.arange(dim_x, dtype=torch.float32, device=input_tensor.device).repeat(1,1,dim_y,1)\n        yy_channel = torch.arange(dim_y, dtype=torch.float32, device=input_tensor.device).repeat(1,1,dim_x,1).transpose(2,3)\n\n        xx_channel = xx_channel / (dim_x - 1)\n        yy_channel = yy_channel / (dim_y - 1)\n\n        xx_channel = xx_channel * 2 - 1\n        yy_channel = yy_channel * 2 - 1\n\n        xx_channel = xx_channel.repeat(batch_size_shape, 1, 1, 1)\n        yy_channel = yy_channel.repeat(batch_size_shape, 1, 1, 1)\n        out = torch.cat([input_tensor, yy_channel, xx_channel], dim=1)\n\n        return out\n\n\nclass 
CoordConv2d(nn.Module):\n    def __init__(self, in_channels, out_channels, kernel_size, stride=1,\n                 padding=0, bias=True):\n        super(CoordConv2d, self).__init__()\n\n        self.addcoords = AddCoords()\n        self.conv = nn.Conv2d(in_channels + 2, out_channels,\n                              kernel_size, stride=stride, padding=padding, bias=bias)\n\n    def forward(self, input_tensor):\n        \"\"\"\n        input_tensor_shape: (N, C_in,H,W)\n        output_tensor_shape: N,C_out,H_out,W_out）\n        :return: CoordConv2d Result\n        \"\"\"\n        out = self.addcoords(input_tensor)\n        out = self.conv(out)\n\n        return out\n\n\nclass CoordConvLayer(nn.Module):\n    def __init__(self, in_channel, out_channel, kernel_size, bias=True, activate=True):\n        super(CoordConvLayer, self).__init__()\n        layers = []\n        stride = 1\n        self.activate = activate\n        self.padding = kernel_size // 2 if kernel_size > 2 else 0\n\n        self.conv = CoordConv2d(in_channel, out_channel, kernel_size,\n                                padding=self.padding, stride=stride,\n                                bias=bias and not activate)\n\n        if activate:\n            self.activation = FusedLeakyReLU(out_channel, bias=bias, scale=1)\n\n        bias_init_coef = np.sqrt(1 / (in_channel * kernel_size * kernel_size))\n        nn.init.uniform_(self.activation.bias, a=-bias_init_coef, b=bias_init_coef)\n\n    def forward(self, input):\n        out = self.conv(input)\n        if self.activate:\n            out = self.activation(out)\n\n        return out\n\n\nclass VolumeRenderResBlock(nn.Module):\n    def __init__(self, in_channel, out_channel):\n        super().__init__()\n\n        self.conv1 = CoordConvLayer(in_channel, out_channel, 3)\n        self.conv2 = CoordConvLayer(out_channel, out_channel, 3)\n        self.pooling = nn.AvgPool2d(2)\n        self.downsample = nn.AvgPool2d(2)\n        if out_channel != in_channel:\n            self.skip = VolumeRenderDiscConv2d(in_channel, out_channel, 1)\n        else:\n            self.skip = None\n\n    def forward(self, input):\n        out = self.conv1(input)\n        out = self.conv2(out)\n        out = self.pooling(out)\n\n        downsample_in = self.downsample(input)\n        if self.skip != None:\n            skip_in = self.skip(downsample_in)\n        else:\n            skip_in = downsample_in\n\n        out = (out + skip_in) / math.sqrt(2)\n\n        return out\n\n\nclass VolumeRenderDiscriminator(nn.Module):\n    def __init__(self, opt):\n        super().__init__()\n        init_size = opt.renderer_spatial_output_dim\n        self.viewpoint_loss = not opt.no_viewpoint_loss\n        final_out_channel = 3 if self.viewpoint_loss else 1\n        channels = {\n            2: 400,\n            4: 400,\n            8: 400,\n            16: 400,\n            32: 256,\n            64: 128,\n            128: 64,\n        }\n\n        convs = [VolumeRenderDiscConv2d(3, channels[init_size], 1, activate=True)]\n\n        log_size = int(math.log(init_size, 2))\n\n        in_channel = channels[init_size]\n\n        for i in range(log_size-1, 0, -1):\n            out_channel = channels[2 ** i]\n\n            convs.append(VolumeRenderResBlock(in_channel, out_channel))\n\n            in_channel = out_channel\n\n        self.convs = nn.Sequential(*convs)\n\n        self.final_conv = VolumeRenderDiscConv2d(in_channel, final_out_channel, 2)\n\n    def forward(self, input):\n        out = self.convs(input)\n        out 
= self.final_conv(out)\n        gan_preds = out[:,0:1]\n        gan_preds = gan_preds.view(-1, 1)\n        if self.viewpoint_loss:\n            viewpoints_preds = out[:,1:]\n            viewpoints_preds = viewpoints_preds.view(-1,2)\n        else:\n            viewpoints_preds = None\n\n        return gan_preds, viewpoints_preds\n\n######################### StyleGAN Discriminator ########################\nclass ResBlock(nn.Module):\n    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1], merge=False):\n        super().__init__()\n\n        self.conv1 = ConvLayer(2 * in_channel if merge else in_channel, in_channel, 3)\n        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)\n        self.skip = ConvLayer(2 * in_channel if merge else in_channel, out_channel,\n                              1, downsample=True, activate=False, bias=False)\n\n    def forward(self, input):\n        out = self.conv1(input)\n        out = self.conv2(out)\n        out = (out + self.skip(input)) / math.sqrt(2)\n\n        return out\n\n\nclass Discriminator(nn.Module):\n    def __init__(self, opt, blur_kernel=[1, 3, 3, 1]):\n        super().__init__()\n        init_size = opt.size\n\n        channels = {\n            4: 512,\n            8: 512,\n            16: 512,\n            32: 512,\n            64: 256 * opt.channel_multiplier,\n            128: 128 * opt.channel_multiplier,\n            256: 64 * opt.channel_multiplier,\n            512: 32 * opt.channel_multiplier,\n            1024: 16 * opt.channel_multiplier,\n        }\n\n        convs = [ConvLayer(3, channels[init_size], 1)]\n\n        log_size = int(math.log(init_size, 2))\n\n        in_channel = channels[init_size]\n\n        for i in range(log_size, 2, -1):\n            out_channel = channels[2 ** (i - 1)]\n\n            convs.append(ResBlock(in_channel, out_channel, blur_kernel))\n\n            in_channel = out_channel\n\n        self.convs = nn.Sequential(*convs)\n\n        self.stddev_group = 4\n        self.stddev_feat = 1\n\n        # minibatch discrimination\n        in_channel += 1\n\n        self.final_conv = ConvLayer(in_channel, channels[4], 3)\n        self.final_linear = nn.Sequential(\n            EqualLinear(channels[4] * 4 * 4, channels[4], activation=\"fused_lrelu\"),\n            EqualLinear(channels[4], 1),\n        )\n\n    def forward(self, input):\n        out = self.convs(input)\n\n        # minibatch discrimination\n        batch, channel, height, width = out.shape\n        group = min(batch, self.stddev_group)\n        if batch % group != 0:\n            group = 3 if batch % 3 == 0 else 2\n\n        stddev = out.view(\n            group, -1, self.stddev_feat, channel // self.stddev_feat, height, width\n        )\n        stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)\n        stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)\n        stddev = stddev.repeat(group, 1, height, width)\n        final_out = torch.cat([out, stddev], 1)\n\n        # final layers\n        final_out = self.final_conv(final_out)\n        final_out = final_out.view(batch, -1)\n        final_out = self.final_linear(final_out)\n        gan_preds = final_out[:,:1]\n\n        return gan_preds\n"
  },
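  {
    "path": "examples/modulated_conv_sketch.py",
    "content": "# Illustrative sketch (not part of the original StyleSDF release).\n# It re-implements, in plain PyTorch, the modulation / demodulation math of the\n# ModulatedConv2d.forward branch in model.py that has no up/downsampling, so the operation can\n# be inspected without building the custom CUDA ops or installing pytorch3d. All shapes and\n# names below are hypothetical.\nimport math\nimport torch\nimport torch.nn.functional as F\n\nbatch, in_ch, out_ch, ksize, h, w = 2, 8, 16, 3, 32, 32\nx = torch.randn(batch, in_ch, h, w)\nweight = torch.randn(1, out_ch, in_ch, ksize, ksize)\nstyle = torch.randn(batch, 1, in_ch, 1, 1)   # stand-in for the EqualLinear modulation output\nscale = 1 / math.sqrt(in_ch * ksize ** 2)\n\n# modulate: scale each input channel of the kernel by the per-sample style\nw_mod = scale * weight * style\n# demodulate: renormalize every output filter (StyleGAN2-style)\ndemod = torch.rsqrt(w_mod.pow(2).sum([2, 3, 4]) + 1e-8)\nw_mod = w_mod * demod.view(batch, out_ch, 1, 1, 1)\n\n# grouped conv applies a different modulated kernel to every sample in the batch\nw_mod = w_mod.view(batch * out_ch, in_ch, ksize, ksize)\nout = F.conv2d(x.view(1, batch * in_ch, h, w), w_mod, padding=ksize // 2, groups=batch)\nout = out.view(batch, out_ch, h, w)\nprint(out.shape)  # torch.Size([2, 16, 32, 32])\n"
  },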
  {
    "path": "op/__init__.py",
    "content": "from .fused_act import FusedLeakyReLU, fused_leaky_relu\nfrom .upfirdn2d import upfirdn2d\n"
  },
  {
    "path": "op/fused_act.py",
    "content": "import os\r\n\r\nimport torch\r\nfrom torch import nn\r\nfrom torch.nn import functional as F\r\nfrom torch.autograd import Function\r\nfrom torch.utils.cpp_extension import load\r\n\r\n\r\nmodule_path = os.path.dirname(__file__)\r\nfused = load(\r\n    \"fused\",\r\n    sources=[\r\n        os.path.join(module_path, \"fused_bias_act.cpp\"),\r\n        os.path.join(module_path, \"fused_bias_act_kernel.cu\"),\r\n    ],\r\n)\r\n\r\n\r\nclass FusedLeakyReLUFunctionBackward(Function):\r\n    @staticmethod\r\n    def forward(ctx, grad_output, out, bias, negative_slope, scale):\r\n        ctx.save_for_backward(out)\r\n        ctx.negative_slope = negative_slope\r\n        ctx.scale = scale\r\n\r\n        empty = grad_output.new_empty(0)\r\n\r\n        grad_input = fused.fused_bias_act(\r\n            grad_output, empty, out, 3, 1, negative_slope, scale\r\n        )\r\n\r\n        dim = [0]\r\n\r\n        if grad_input.ndim > 2:\r\n            dim += list(range(2, grad_input.ndim))\r\n\r\n        if bias:\r\n            grad_bias = grad_input.sum(dim).detach()\r\n\r\n        else:\r\n            grad_bias = empty\r\n\r\n        return grad_input, grad_bias\r\n\r\n    @staticmethod\r\n    def backward(ctx, gradgrad_input, gradgrad_bias):\r\n        out, = ctx.saved_tensors\r\n        gradgrad_out = fused.fused_bias_act(\r\n            gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale\r\n        )\r\n\r\n        return gradgrad_out, None, None, None, None\r\n\r\n\r\nclass FusedLeakyReLUFunction(Function):\r\n    @staticmethod\r\n    def forward(ctx, input, bias, negative_slope, scale):\r\n        empty = input.new_empty(0)\r\n\r\n        ctx.bias = bias is not None\r\n\r\n        if bias is None:\r\n            bias = empty\r\n\r\n        out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)\r\n        ctx.save_for_backward(out)\r\n        ctx.negative_slope = negative_slope\r\n        ctx.scale = scale\r\n\r\n        return out\r\n\r\n    @staticmethod\r\n    def backward(ctx, grad_output):\r\n        out, = ctx.saved_tensors\r\n\r\n        grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(\r\n            grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale\r\n        )\r\n\r\n        if not ctx.bias:\r\n            grad_bias = None\r\n\r\n        return grad_input, grad_bias, None, None\r\n\r\n\r\nclass FusedLeakyReLU(nn.Module):\r\n    def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):\r\n        super().__init__()\r\n\r\n        if bias:\r\n            self.bias = nn.Parameter(torch.zeros(channel))\r\n\r\n        else:\r\n            self.bias = None\r\n\r\n        self.negative_slope = negative_slope\r\n        self.scale = scale\r\n\r\n    def forward(self, input):\r\n        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)\r\n\r\n\r\ndef fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):\r\n    if input.device.type == \"cpu\":\r\n        if bias is not None:\r\n            rest_dim = [1] * (input.ndim - bias.ndim - 1)\r\n            return (\r\n                F.leaky_relu(\r\n                    input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2\r\n                )\r\n                * scale\r\n            )\r\n\r\n        else:\r\n            return F.leaky_relu(input, negative_slope=0.2) * scale\r\n\r\n    else:\r\n        return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)\r\n"
  },
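  {
    "path": "examples/fused_act_sketch.py",
    "content": "# Illustrative usage sketch (not part of the original StyleSDF release).\n# It exercises the fused bias + leaky ReLU op wrapped in op/fused_act.py. On CPU tensors the\n# pure-PyTorch fallback inside fused_leaky_relu runs; note that importing the op package\n# JIT-compiles the CUDA extension, so a working nvcc toolchain is assumed. All shapes below\n# are hypothetical.\nimport torch\n\nfrom op import FusedLeakyReLU, fused_leaky_relu\n\nx = torch.randn(2, 8, 4, 4)   # (N, C, H, W) feature map\nbias = torch.zeros(8)         # per-channel bias, broadcast over H and W\n\n# functional form: leaky_relu(x + bias) scaled by sqrt(2)\ny = fused_leaky_relu(x, bias)\n\n# module form, as used by the StyledConv / ConvLayer blocks in model.py\nact = FusedLeakyReLU(8)\nprint(y.shape, act(x).shape)\n"
  },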
  {
    "path": "op/fused_bias_act.cpp",
    "content": "#include <torch/extension.h>\r\n\r\n\r\ntorch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,\r\n    int act, int grad, float alpha, float scale);\r\n\r\n#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x \" must be a CUDA tensor\")\r\n#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x \" must be contiguous\")\r\n#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)\r\n\r\ntorch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,\r\n    int act, int grad, float alpha, float scale) {\r\n    CHECK_CUDA(input);\r\n    CHECK_CUDA(bias);\r\n\r\n    return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);\r\n}\r\n\r\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\r\n    m.def(\"fused_bias_act\", &fused_bias_act, \"fused bias act (CUDA)\");\r\n}"
  },
  {
    "path": "op/fused_bias_act_kernel.cu",
    "content": "// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.\r\n//\r\n// This work is made available under the Nvidia Source Code License-NC.\r\n// To view a copy of this license, visit\r\n// https://nvlabs.github.io/stylegan2/license.html\r\n\r\n#include <torch/types.h>\r\n\r\n#include <ATen/ATen.h>\r\n#include <ATen/AccumulateType.h>\r\n#include <ATen/cuda/CUDAContext.h>\r\n#include <ATen/cuda/CUDAApplyUtils.cuh>\r\n\r\n#include <cuda.h>\r\n#include <cuda_runtime.h>\r\n\r\n\r\ntemplate <typename scalar_t>\r\nstatic __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,\r\n    int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {\r\n    int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;\r\n\r\n    scalar_t zero = 0.0;\r\n\r\n    for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {\r\n        scalar_t x = p_x[xi];\r\n\r\n        if (use_bias) {\r\n            x += p_b[(xi / step_b) % size_b];\r\n        }\r\n\r\n        scalar_t ref = use_ref ? p_ref[xi] : zero;\r\n\r\n        scalar_t y;\r\n\r\n        switch (act * 10 + grad) {\r\n            default:\r\n            case 10: y = x; break;\r\n            case 11: y = x; break;\r\n            case 12: y = 0.0; break;\r\n\r\n            case 30: y = (x > 0.0) ? x : x * alpha; break;\r\n            case 31: y = (ref > 0.0) ? x : x * alpha; break;\r\n            case 32: y = 0.0; break;\r\n        }\r\n\r\n        out[xi] = y * scale;\r\n    }\r\n}\r\n\r\n\r\ntorch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,\r\n    int act, int grad, float alpha, float scale) {\r\n    int curDevice = -1;\r\n    cudaGetDevice(&curDevice);\r\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);\r\n\r\n    auto x = input.contiguous();\r\n    auto b = bias.contiguous();\r\n    auto ref = refer.contiguous();\r\n\r\n    int use_bias = b.numel() ? 1 : 0;\r\n    int use_ref = ref.numel() ? 1 : 0;\r\n\r\n    int size_x = x.numel();\r\n    int size_b = b.numel();\r\n    int step_b = 1;\r\n\r\n    for (int i = 1 + 1; i < x.dim(); i++) {\r\n        step_b *= x.size(i);\r\n    }\r\n\r\n    int loop_x = 4;\r\n    int block_size = 4 * 32;\r\n    int grid_size = (size_x - 1) / (loop_x * block_size) + 1;\r\n\r\n    auto y = torch::empty_like(x);\r\n\r\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), \"fused_bias_act_kernel\", [&] {\r\n        fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(\r\n            y.data_ptr<scalar_t>(),\r\n            x.data_ptr<scalar_t>(),\r\n            b.data_ptr<scalar_t>(),\r\n            ref.data_ptr<scalar_t>(),\r\n            act,\r\n            grad,\r\n            alpha,\r\n            scale,\r\n            loop_x,\r\n            size_x,\r\n            step_b,\r\n            size_b,\r\n            use_bias,\r\n            use_ref\r\n        );\r\n    });\r\n\r\n    return y;\r\n}"
  },
  {
    "path": "op/upfirdn2d.cpp",
    "content": "#include <torch/extension.h>\r\n\r\n\r\ntorch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,\r\n                            int up_x, int up_y, int down_x, int down_y,\r\n                            int pad_x0, int pad_x1, int pad_y0, int pad_y1);\r\n\r\n#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x \" must be a CUDA tensor\")\r\n#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x \" must be contiguous\")\r\n#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)\r\n\r\ntorch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,\r\n                        int up_x, int up_y, int down_x, int down_y,\r\n                        int pad_x0, int pad_x1, int pad_y0, int pad_y1) {\r\n    CHECK_CUDA(input);\r\n    CHECK_CUDA(kernel);\r\n\r\n    return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);\r\n}\r\n\r\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\r\n    m.def(\"upfirdn2d\", &upfirdn2d, \"upfirdn2d (CUDA)\");\r\n}"
  },
  {
    "path": "op/upfirdn2d.py",
    "content": "import os\r\n\r\nimport torch\r\nfrom torch.nn import functional as F\r\nfrom torch.autograd import Function\r\nfrom torch.utils.cpp_extension import load\r\nfrom pdb import set_trace as st\r\n\r\n\r\nmodule_path = os.path.dirname(__file__)\r\nupfirdn2d_op = load(\r\n    \"upfirdn2d\",\r\n    sources=[\r\n        os.path.join(module_path, \"upfirdn2d.cpp\"),\r\n        os.path.join(module_path, \"upfirdn2d_kernel.cu\"),\r\n    ],\r\n)\r\n\r\n\r\nclass UpFirDn2dBackward(Function):\r\n    @staticmethod\r\n    def forward(\r\n        ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size\r\n    ):\r\n\r\n        up_x, up_y = up\r\n        down_x, down_y = down\r\n        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad\r\n\r\n        grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)\r\n\r\n        grad_input = upfirdn2d_op.upfirdn2d(\r\n            grad_output,\r\n            grad_kernel,\r\n            down_x,\r\n            down_y,\r\n            up_x,\r\n            up_y,\r\n            g_pad_x0,\r\n            g_pad_x1,\r\n            g_pad_y0,\r\n            g_pad_y1,\r\n        )\r\n        grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])\r\n\r\n        ctx.save_for_backward(kernel)\r\n\r\n        pad_x0, pad_x1, pad_y0, pad_y1 = pad\r\n\r\n        ctx.up_x = up_x\r\n        ctx.up_y = up_y\r\n        ctx.down_x = down_x\r\n        ctx.down_y = down_y\r\n        ctx.pad_x0 = pad_x0\r\n        ctx.pad_x1 = pad_x1\r\n        ctx.pad_y0 = pad_y0\r\n        ctx.pad_y1 = pad_y1\r\n        ctx.in_size = in_size\r\n        ctx.out_size = out_size\r\n\r\n        return grad_input\r\n\r\n    @staticmethod\r\n    def backward(ctx, gradgrad_input):\r\n        kernel, = ctx.saved_tensors\r\n\r\n        gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)\r\n\r\n        gradgrad_out = upfirdn2d_op.upfirdn2d(\r\n            gradgrad_input,\r\n            kernel,\r\n            ctx.up_x,\r\n            ctx.up_y,\r\n            ctx.down_x,\r\n            ctx.down_y,\r\n            ctx.pad_x0,\r\n            ctx.pad_x1,\r\n            ctx.pad_y0,\r\n            ctx.pad_y1,\r\n        )\r\n        # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])\r\n        gradgrad_out = gradgrad_out.view(\r\n            ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]\r\n        )\r\n\r\n        return gradgrad_out, None, None, None, None, None, None, None, None\r\n\r\n\r\nclass UpFirDn2d(Function):\r\n    @staticmethod\r\n    def forward(ctx, input, kernel, up, down, pad):\r\n        up_x, up_y = up\r\n        down_x, down_y = down\r\n        pad_x0, pad_x1, pad_y0, pad_y1 = pad\r\n\r\n        kernel_h, kernel_w = kernel.shape\r\n        batch, channel, in_h, in_w = input.shape\r\n        ctx.in_size = input.shape\r\n\r\n        input = input.reshape(-1, in_h, in_w, 1)\r\n\r\n        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))\r\n\r\n        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1\r\n        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1\r\n        ctx.out_size = (out_h, out_w)\r\n\r\n        ctx.up = (up_x, up_y)\r\n        ctx.down = (down_x, down_y)\r\n        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)\r\n\r\n        g_pad_x0 = kernel_w - pad_x0 - 1\r\n        g_pad_y0 = kernel_h - pad_y0 - 1\r\n        g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1\r\n        g_pad_y1 = 
in_h * up_y - out_h * down_y + pad_y0 - up_y + 1\r\n\r\n        ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)\r\n\r\n        out = upfirdn2d_op.upfirdn2d(\r\n            input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1\r\n        )\r\n        # out = out.view(major, out_h, out_w, minor)\r\n        out = out.view(-1, channel, out_h, out_w)\r\n\r\n        return out\r\n\r\n    @staticmethod\r\n    def backward(ctx, grad_output):\r\n        kernel, grad_kernel = ctx.saved_tensors\r\n\r\n        grad_input = UpFirDn2dBackward.apply(\r\n            grad_output,\r\n            kernel,\r\n            grad_kernel,\r\n            ctx.up,\r\n            ctx.down,\r\n            ctx.pad,\r\n            ctx.g_pad,\r\n            ctx.in_size,\r\n            ctx.out_size,\r\n        )\r\n\r\n        return grad_input, None, None, None, None\r\n\r\n\r\ndef upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):\r\n    if input.device.type == \"cpu\":\r\n        out = upfirdn2d_native(\r\n            input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]\r\n        )\r\n\r\n    else:\r\n        out = UpFirDn2d.apply(\r\n            input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])\r\n        )\r\n\r\n    return out\r\n\r\n\r\ndef upfirdn2d_native(\r\n    input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1\r\n):\r\n    _, channel, in_h, in_w = input.shape\r\n    input = input.reshape(-1, in_h, in_w, 1)\r\n\r\n    _, in_h, in_w, minor = input.shape\r\n    kernel_h, kernel_w = kernel.shape\r\n\r\n    out = input.view(-1, in_h, 1, in_w, 1, minor)\r\n    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])\r\n    out = out.view(-1, in_h * up_y, in_w * up_x, minor)\r\n\r\n    out = F.pad(\r\n        out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]\r\n    )\r\n    out = out[\r\n        :,\r\n        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),\r\n        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),\r\n        :,\r\n    ]\r\n\r\n    out = out.permute(0, 3, 1, 2)\r\n    out = out.reshape(\r\n        [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]\r\n    )\r\n    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)\r\n    out = F.conv2d(out, w)\r\n    out = out.reshape(\r\n        -1,\r\n        minor,\r\n        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,\r\n        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,\r\n    )\r\n    out = out.permute(0, 2, 3, 1)\r\n    out = out[:, ::down_y, ::down_x, :]\r\n\r\n    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1\r\n    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1\r\n\r\n    return out.view(-1, channel, out_h, out_w)\r\n"
  },
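  {
    "path": "examples/upfirdn2d_sketch.py",
    "content": "# Illustrative usage sketch (not part of the original StyleSDF release).\n# It shows the upsample-FIR-downsample primitive from op/upfirdn2d.py with the [1, 3, 3, 1]\n# separable blur kernel that the Upsample/Downsample/Blur modules in model.py rely on. On CPU\n# tensors the upfirdn2d_native fallback runs; importing the op package still JIT-compiles the\n# CUDA extension, so an nvcc toolchain is assumed. Gains and paddings follow model.py.\nimport torch\n\nfrom op import upfirdn2d\n\n# build the normalized 2D blur kernel the same way make_kernel in model.py does\nk = torch.tensor([1.0, 3.0, 3.0, 1.0])\nkernel = k[None, :] * k[:, None]\nkernel = kernel / kernel.sum()\n\nx = torch.randn(1, 3, 16, 16)\n\n# 2x upsample + blur, with the kernel gain and padding used by model.Upsample (factor=2)\nup = upfirdn2d(x, kernel * 4, up=2, down=1, pad=(2, 1))\n# blur + 2x downsample, with the padding used by model.Downsample (factor=2)\ndown = upfirdn2d(x, kernel, up=1, down=2, pad=(1, 1))\nprint(up.shape, down.shape)  # torch.Size([1, 3, 32, 32]) torch.Size([1, 3, 8, 8])\n"
  },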
  {
    "path": "op/upfirdn2d_kernel.cu",
    "content": "// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.\r\n//\r\n// This work is made available under the Nvidia Source Code License-NC.\r\n// To view a copy of this license, visit\r\n// https://nvlabs.github.io/stylegan2/license.html\r\n\r\n#include <torch/types.h>\r\n\r\n#include <ATen/ATen.h>\r\n#include <ATen/AccumulateType.h>\r\n#include <ATen/cuda/CUDAApplyUtils.cuh>\r\n#include <ATen/cuda/CUDAContext.h>\r\n\r\n#include <cuda.h>\r\n#include <cuda_runtime.h>\r\n\r\nstatic __host__ __device__ __forceinline__ int floor_div(int a, int b) {\r\n  int c = a / b;\r\n\r\n  if (c * b > a) {\r\n    c--;\r\n  }\r\n\r\n  return c;\r\n}\r\n\r\nstruct UpFirDn2DKernelParams {\r\n  int up_x;\r\n  int up_y;\r\n  int down_x;\r\n  int down_y;\r\n  int pad_x0;\r\n  int pad_x1;\r\n  int pad_y0;\r\n  int pad_y1;\r\n\r\n  int major_dim;\r\n  int in_h;\r\n  int in_w;\r\n  int minor_dim;\r\n  int kernel_h;\r\n  int kernel_w;\r\n  int out_h;\r\n  int out_w;\r\n  int loop_major;\r\n  int loop_x;\r\n};\r\n\r\ntemplate <typename scalar_t>\r\n__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,\r\n                                       const scalar_t *kernel,\r\n                                       const UpFirDn2DKernelParams p) {\r\n  int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;\r\n  int out_y = minor_idx / p.minor_dim;\r\n  minor_idx -= out_y * p.minor_dim;\r\n  int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;\r\n  int major_idx_base = blockIdx.z * p.loop_major;\r\n\r\n  if (out_x_base >= p.out_w || out_y >= p.out_h ||\r\n      major_idx_base >= p.major_dim) {\r\n    return;\r\n  }\r\n\r\n  int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;\r\n  int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);\r\n  int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;\r\n  int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;\r\n\r\n  for (int loop_major = 0, major_idx = major_idx_base;\r\n       loop_major < p.loop_major && major_idx < p.major_dim;\r\n       loop_major++, major_idx++) {\r\n    for (int loop_x = 0, out_x = out_x_base;\r\n         loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {\r\n      int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;\r\n      int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);\r\n      int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;\r\n      int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;\r\n\r\n      const scalar_t *x_p =\r\n          &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +\r\n                 minor_idx];\r\n      const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];\r\n      int x_px = p.minor_dim;\r\n      int k_px = -p.up_x;\r\n      int x_py = p.in_w * p.minor_dim;\r\n      int k_py = -p.up_y * p.kernel_w;\r\n\r\n      scalar_t v = 0.0f;\r\n\r\n      for (int y = 0; y < h; y++) {\r\n        for (int x = 0; x < w; x++) {\r\n          v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);\r\n          x_p += x_px;\r\n          k_p += k_px;\r\n        }\r\n\r\n        x_p += x_py - w * x_px;\r\n        k_p += k_py - w * k_px;\r\n      }\r\n\r\n      out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +\r\n          minor_idx] = v;\r\n    }\r\n  }\r\n}\r\n\r\ntemplate <typename scalar_t, int up_x, int up_y, int down_x, int down_y,\r\n          int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>\r\n__global__ void upfirdn2d_kernel(scalar_t 
*out, const scalar_t *input,\r\n                                 const scalar_t *kernel,\r\n                                 const UpFirDn2DKernelParams p) {\r\n  const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;\r\n  const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;\r\n\r\n  __shared__ volatile float sk[kernel_h][kernel_w];\r\n  __shared__ volatile float sx[tile_in_h][tile_in_w];\r\n\r\n  int minor_idx = blockIdx.x;\r\n  int tile_out_y = minor_idx / p.minor_dim;\r\n  minor_idx -= tile_out_y * p.minor_dim;\r\n  tile_out_y *= tile_out_h;\r\n  int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;\r\n  int major_idx_base = blockIdx.z * p.loop_major;\r\n\r\n  if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |\r\n      major_idx_base >= p.major_dim) {\r\n    return;\r\n  }\r\n\r\n  for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;\r\n       tap_idx += blockDim.x) {\r\n    int ky = tap_idx / kernel_w;\r\n    int kx = tap_idx - ky * kernel_w;\r\n    scalar_t v = 0.0;\r\n\r\n    if (kx < p.kernel_w & ky < p.kernel_h) {\r\n      v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];\r\n    }\r\n\r\n    sk[ky][kx] = v;\r\n  }\r\n\r\n  for (int loop_major = 0, major_idx = major_idx_base;\r\n       loop_major < p.loop_major & major_idx < p.major_dim;\r\n       loop_major++, major_idx++) {\r\n    for (int loop_x = 0, tile_out_x = tile_out_x_base;\r\n         loop_x < p.loop_x & tile_out_x < p.out_w;\r\n         loop_x++, tile_out_x += tile_out_w) {\r\n      int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;\r\n      int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;\r\n      int tile_in_x = floor_div(tile_mid_x, up_x);\r\n      int tile_in_y = floor_div(tile_mid_y, up_y);\r\n\r\n      __syncthreads();\r\n\r\n      for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;\r\n           in_idx += blockDim.x) {\r\n        int rel_in_y = in_idx / tile_in_w;\r\n        int rel_in_x = in_idx - rel_in_y * tile_in_w;\r\n        int in_x = rel_in_x + tile_in_x;\r\n        int in_y = rel_in_y + tile_in_y;\r\n\r\n        scalar_t v = 0.0;\r\n\r\n        if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {\r\n          v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *\r\n                        p.minor_dim +\r\n                    minor_idx];\r\n        }\r\n\r\n        sx[rel_in_y][rel_in_x] = v;\r\n      }\r\n\r\n      __syncthreads();\r\n      for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;\r\n           out_idx += blockDim.x) {\r\n        int rel_out_y = out_idx / tile_out_w;\r\n        int rel_out_x = out_idx - rel_out_y * tile_out_w;\r\n        int out_x = rel_out_x + tile_out_x;\r\n        int out_y = rel_out_y + tile_out_y;\r\n\r\n        int mid_x = tile_mid_x + rel_out_x * down_x;\r\n        int mid_y = tile_mid_y + rel_out_y * down_y;\r\n        int in_x = floor_div(mid_x, up_x);\r\n        int in_y = floor_div(mid_y, up_y);\r\n        int rel_in_x = in_x - tile_in_x;\r\n        int rel_in_y = in_y - tile_in_y;\r\n        int kernel_x = (in_x + 1) * up_x - mid_x - 1;\r\n        int kernel_y = (in_y + 1) * up_y - mid_y - 1;\r\n\r\n        scalar_t v = 0.0;\r\n\r\n#pragma unroll\r\n        for (int y = 0; y < kernel_h / up_y; y++)\r\n#pragma unroll\r\n          for (int x = 0; x < kernel_w / up_x; x++)\r\n            v += sx[rel_in_y + y][rel_in_x + x] *\r\n                 sk[kernel_y + y * up_y][kernel_x + x * up_x];\r\n\r\n        if (out_x < p.out_w & 
out_y < p.out_h) {\r\n          out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +\r\n              minor_idx] = v;\r\n        }\r\n      }\r\n    }\r\n  }\r\n}\r\n\r\ntorch::Tensor upfirdn2d_op(const torch::Tensor &input,\r\n                           const torch::Tensor &kernel, int up_x, int up_y,\r\n                           int down_x, int down_y, int pad_x0, int pad_x1,\r\n                           int pad_y0, int pad_y1) {\r\n  int curDevice = -1;\r\n  cudaGetDevice(&curDevice);\r\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);\r\n\r\n  UpFirDn2DKernelParams p;\r\n\r\n  auto x = input.contiguous();\r\n  auto k = kernel.contiguous();\r\n\r\n  p.major_dim = x.size(0);\r\n  p.in_h = x.size(1);\r\n  p.in_w = x.size(2);\r\n  p.minor_dim = x.size(3);\r\n  p.kernel_h = k.size(0);\r\n  p.kernel_w = k.size(1);\r\n  p.up_x = up_x;\r\n  p.up_y = up_y;\r\n  p.down_x = down_x;\r\n  p.down_y = down_y;\r\n  p.pad_x0 = pad_x0;\r\n  p.pad_x1 = pad_x1;\r\n  p.pad_y0 = pad_y0;\r\n  p.pad_y1 = pad_y1;\r\n\r\n  p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /\r\n            p.down_y;\r\n  p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /\r\n            p.down_x;\r\n\r\n  auto out =\r\n      at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());\r\n\r\n  int mode = -1;\r\n\r\n  int tile_out_h = -1;\r\n  int tile_out_w = -1;\r\n\r\n  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&\r\n      p.kernel_h <= 4 && p.kernel_w <= 4) {\r\n    mode = 1;\r\n    tile_out_h = 16;\r\n    tile_out_w = 64;\r\n  }\r\n\r\n  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&\r\n      p.kernel_h <= 3 && p.kernel_w <= 3) {\r\n    mode = 2;\r\n    tile_out_h = 16;\r\n    tile_out_w = 64;\r\n  }\r\n\r\n  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&\r\n      p.kernel_h <= 4 && p.kernel_w <= 4) {\r\n    mode = 3;\r\n    tile_out_h = 16;\r\n    tile_out_w = 64;\r\n  }\r\n\r\n  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&\r\n      p.kernel_h <= 2 && p.kernel_w <= 2) {\r\n    mode = 4;\r\n    tile_out_h = 16;\r\n    tile_out_w = 64;\r\n  }\r\n\r\n  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&\r\n      p.kernel_h <= 4 && p.kernel_w <= 4) {\r\n    mode = 5;\r\n    tile_out_h = 8;\r\n    tile_out_w = 32;\r\n  }\r\n\r\n  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&\r\n      p.kernel_h <= 2 && p.kernel_w <= 2) {\r\n    mode = 6;\r\n    tile_out_h = 8;\r\n    tile_out_w = 32;\r\n  }\r\n\r\n  dim3 block_size;\r\n  dim3 grid_size;\r\n\r\n  if (tile_out_h > 0 && tile_out_w > 0) {\r\n    p.loop_major = (p.major_dim - 1) / 16384 + 1;\r\n    p.loop_x = 1;\r\n    block_size = dim3(32 * 8, 1, 1);\r\n    grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,\r\n                     (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,\r\n                     (p.major_dim - 1) / p.loop_major + 1);\r\n  } else {\r\n    p.loop_major = (p.major_dim - 1) / 16384 + 1;\r\n    p.loop_x = 4;\r\n    block_size = dim3(4, 32, 1);\r\n    grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,\r\n                     (p.out_w - 1) / (p.loop_x * block_size.y) + 1,\r\n                     (p.major_dim - 1) / p.loop_major + 1);\r\n  }\r\n\r\n  AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), \"upfirdn2d_cuda\", [&] {\r\n    switch (mode) {\r\n    case 1:\r\n      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 
64>\r\n          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),\r\n                                                 x.data_ptr<scalar_t>(),\r\n                                                 k.data_ptr<scalar_t>(), p);\r\n\r\n      break;\r\n\r\n    case 2:\r\n      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>\r\n          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),\r\n                                                 x.data_ptr<scalar_t>(),\r\n                                                 k.data_ptr<scalar_t>(), p);\r\n\r\n      break;\r\n\r\n    case 3:\r\n      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>\r\n          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),\r\n                                                 x.data_ptr<scalar_t>(),\r\n                                                 k.data_ptr<scalar_t>(), p);\r\n\r\n      break;\r\n\r\n    case 4:\r\n      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>\r\n          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),\r\n                                                 x.data_ptr<scalar_t>(),\r\n                                                 k.data_ptr<scalar_t>(), p);\r\n\r\n      break;\r\n\r\n    case 5:\r\n      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>\r\n          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),\r\n                                                 x.data_ptr<scalar_t>(),\r\n                                                 k.data_ptr<scalar_t>(), p);\r\n\r\n      break;\r\n\r\n    case 6:\r\n      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>\r\n          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),\r\n                                                 x.data_ptr<scalar_t>(),\r\n                                                 k.data_ptr<scalar_t>(), p);\r\n\r\n      break;\r\n\r\n    default:\r\n      upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(\r\n          out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),\r\n          k.data_ptr<scalar_t>(), p);\r\n    }\r\n  });\r\n\r\n  return out;\r\n}"
  },
  {
    "path": "options.py",
    "content": "import configargparse\nfrom munch import *\nfrom pdb import set_trace as st\n\nclass BaseOptions():\n    def __init__(self):\n        self.parser = configargparse.ArgumentParser()\n        self.initialized = False\n\n    def initialize(self):\n        # Dataset options\n        dataset = self.parser.add_argument_group('dataset')\n        dataset.add_argument(\"--dataset_path\", type=str, default='./datasets/FFHQ', help=\"path to the lmdb dataset\")\n\n        # Experiment Options\n        experiment = self.parser.add_argument_group('experiment')\n        experiment.add_argument('--config', is_config_file=True, help='config file path')\n        experiment.add_argument(\"--expname\", type=str, default='debug', help='experiment name')\n        experiment.add_argument(\"--ckpt\", type=str, default='300000', help=\"path to the checkpoints to resume training\")\n        experiment.add_argument(\"--continue_training\", action=\"store_true\", help=\"continue training the model\")\n\n        # Training loop options\n        training = self.parser.add_argument_group('training')\n        training.add_argument(\"--checkpoints_dir\", type=str, default='./checkpoint', help='checkpoints directory name')\n        training.add_argument(\"--iter\", type=int, default=300000, help=\"total number of training iterations\")\n        training.add_argument(\"--batch\", type=int, default=4, help=\"batch sizes for each GPU. A single RTX2080 can fit batch=4, chunck=1 into memory.\")\n        training.add_argument(\"--chunk\", type=int, default=4, help='number of samples within a batch to processed in parallel, decrease if running out of memory')\n        training.add_argument(\"--val_n_sample\", type=int, default=8, help=\"number of test samples generated during training\")\n        training.add_argument(\"--d_reg_every\", type=int, default=16, help=\"interval for applying r1 regularization to the StyleGAN generator\")\n        training.add_argument(\"--g_reg_every\", type=int, default=4, help=\"interval for applying path length regularization to the StyleGAN generator\")\n        training.add_argument(\"--local_rank\", type=int, default=0, help=\"local rank for distributed training\")\n        training.add_argument(\"--mixing\", type=float, default=0.9, help=\"probability of latent code mixing\")\n        training.add_argument(\"--lr\", type=float, default=0.002, help=\"learning rate\")\n        training.add_argument(\"--r1\", type=float, default=10, help=\"weight of the r1 regularization\")\n        training.add_argument(\"--view_lambda\", type=float, default=15, help=\"weight of the viewpoint regularization\")\n        training.add_argument(\"--eikonal_lambda\", type=float, default=0.1, help=\"weight of the eikonal regularization\")\n        training.add_argument(\"--min_surf_lambda\", type=float, default=0.05, help=\"weight of the minimal surface regularization\")\n        training.add_argument(\"--min_surf_beta\", type=float, default=100.0, help=\"weight of the minimal surface regularization\")\n        training.add_argument(\"--path_regularize\", type=float, default=2, help=\"weight of the path length regularization\")\n        training.add_argument(\"--path_batch_shrink\", type=int, default=2, help=\"batch size reducing factor for the path length regularization (reduce memory consumption)\")\n        training.add_argument(\"--wandb\", action=\"store_true\", help=\"use weights and biases logging\")\n        training.add_argument(\"--no_sphere_init\", action=\"store_true\", help=\"do not 
initialize the volume renderer with a sphere SDF\")\n\n        # Inference Options\n        inference = self.parser.add_argument_group('inference')\n        inference.add_argument(\"--results_dir\", type=str, default='./evaluations', help='results/evaluations directory name')\n        inference.add_argument(\"--truncation_ratio\", type=float, default=0.5, help=\"truncation ratio, controls the diversity vs. quality tradeoff. Higher truncation ratio would generate more diverse results\")\n        inference.add_argument(\"--truncation_mean\", type=int, default=10000, help=\"number of vectors to calculate mean for the truncation\")\n        inference.add_argument(\"--identities\", type=int, default=16, help=\"number of identities to be generated\")\n        inference.add_argument(\"--num_views_per_id\", type=int, default=1, help=\"number of viewpoints generated per identity\")\n        inference.add_argument(\"--no_surface_renderings\", action=\"store_true\", help=\"when true, only RGB outputs will be generated. otherwise, both RGB and depth videos/renderings will be generated. this cuts the processing time per video\")\n        inference.add_argument(\"--fixed_camera_angles\", action=\"store_true\", help=\"when true, the generator will render identities from a fixed set of camera angles.\")\n        inference.add_argument(\"--azim_video\", action=\"store_true\", help=\"when true, the camera trajectory will travel along the azimuth direction. Otherwise, the camera will travel along an ellipsoid trajectory.\")\n\n        # Generator options\n        model = self.parser.add_argument_group('model')\n        model.add_argument(\"--size\", type=int, default=256, help=\"image sizes for the model\")\n        model.add_argument(\"--style_dim\", type=int, default=256, help=\"number of style input dimensions\")\n        model.add_argument(\"--channel_multiplier\", type=int, default=2, help=\"channel multiplier factor for the StyleGAN decoder. config-f = 2, else = 1\")\n        model.add_argument(\"--n_mlp\", type=int, default=8, help=\"number of mlp layers in stylegan's mapping network\")\n        model.add_argument(\"--lr_mapping\", type=float, default=0.01, help='learning rate reduction for mapping network MLP layers')\n        model.add_argument(\"--renderer_spatial_output_dim\", type=int, default=64, help='spatial resolution of the StyleGAN decoder inputs')\n        model.add_argument(\"--project_noise\", action='store_true', help='when true, use geometry-aware noise projection to reduce flickering effects (see supplementary section C.1 in the paper). warning: processing time significantly increases with this flag to ~20 minutes per video.')\n\n        # Camera options\n        camera = self.parser.add_argument_group('camera')\n        camera.add_argument(\"--uniform\", action=\"store_true\", help=\"when true, the camera position is sampled from a uniform distribution. Gaussian distribution is the default\")\n        camera.add_argument(\"--azim\", type=float, default=0.3, help=\"camera azimuth angle std/range in Radians\")\n        camera.add_argument(\"--elev\", type=float, default=0.15, help=\"camera elevation angle std/range in Radians\")\n        camera.add_argument(\"--fov\", type=float, default=6, help=\"camera field of view half angle in Degrees\")\n        camera.add_argument(\"--dist_radius\", type=float, default=0.12, help=\"radius of points sampling distance from the origin. 
determines the near and far fields\")\n\n        # Volume Renderer options\n        rendering = self.parser.add_argument_group('rendering')\n        # MLP model parameters\n        rendering.add_argument(\"--depth\", type=int, default=8, help='layers in network')\n        rendering.add_argument(\"--width\", type=int, default=256, help='channels per layer')\n        # Volume representation options\n        rendering.add_argument(\"--no_sdf\", action='store_true', help='By default, the raw MLP outputs represent an underlying signed distance field (SDF). When true, the MLP outputs represent the traditional NeRF density field.')\n        rendering.add_argument(\"--no_z_normalize\", action='store_true', help='By default, the model normalizes input coordinates such that the z coordinate is in [-1,1]. When true that feature is disabled.')\n        rendering.add_argument(\"--static_viewdirs\", action='store_true', help='when true, use static viewing direction input to the MLP')\n        # Ray integration options\n        rendering.add_argument(\"--N_samples\", type=int, default=24, help='number of samples per ray')\n        rendering.add_argument(\"--no_offset_sampling\", action='store_true', help='when true, use random stratified sampling when rendering the volume, otherwise offset sampling is used. (See Equation (3) in Sec. 3.2 of the paper)')\n        rendering.add_argument(\"--perturb\", type=float, default=1., help='set to 0. for no jitter, 1. for jitter')\n        rendering.add_argument(\"--raw_noise_std\", type=float, default=0., help='std dev of noise added to regularize sigma_a output, 1e0 recommended')\n        rendering.add_argument(\"--force_background\", action='store_true', help='force the last depth sample to act as background in case of a transparent ray')\n        # Set volume renderer outputs\n        rendering.add_argument(\"--return_xyz\", action='store_true', help='when true, the volume renderer also returns the xyz point cloud of the surface. This point cloud is used to produce depth map renderings')\n        rendering.add_argument(\"--return_sdf\", action='store_true', help='when true, the volume renderer also returns the SDF network outputs for each location in the volume')\n\n        self.initialized = True\n\n    def parse(self):\n        self.opt = Munch()\n        if not self.initialized:\n            self.initialize()\n        try:\n            args = self.parser.parse_args()\n        except: # solves argparse error in google colab\n            args = self.parser.parse_args(args=[])\n\n        for group in self.parser._action_groups[2:]:\n            title = group.title\n            self.opt[title] = Munch()\n            for action in group._group_actions:\n                dest = action.dest\n                self.opt[title][dest] = args.__getattribute__(dest)\n\n        return self.opt\n"
  },
  {
    "path": "prepare_data.py",
    "content": "import argparse\nfrom io import BytesIO\nimport multiprocessing\nfrom functools import partial\n\nfrom PIL import Image\nimport lmdb\nfrom tqdm import tqdm\nfrom torchvision import datasets\nfrom torchvision.transforms import functional as trans_fn\nfrom pdb import set_trace as st\n\n\ndef resize_and_convert(img, size, resample):\n    img = trans_fn.resize(img, size, resample)\n    img = trans_fn.center_crop(img, size)\n    buffer = BytesIO()\n    img.save(buffer, format=\"png\")\n    val = buffer.getvalue()\n\n    return val\n\n\ndef resize_multiple(\n    img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS):\n    imgs = []\n\n    for size in sizes:\n        imgs.append(resize_and_convert(img, size, resample))\n\n    return imgs\n\n\ndef resize_worker(img_file, sizes, resample):\n    i, file = img_file\n    img = Image.open(file)\n    img = img.convert(\"RGB\")\n    out = resize_multiple(img, sizes=sizes, resample=resample)\n\n    return i, out\n\n\ndef prepare(\n    env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS\n):\n    resize_fn = partial(resize_worker, sizes=sizes, resample=resample)\n\n    files = sorted(dataset.imgs, key=lambda x: x[0])\n    files = [(i, file) for i, (file, label) in enumerate(files)]\n    total = 0\n\n    with multiprocessing.Pool(n_worker) as pool:\n        for i, imgs in tqdm(pool.imap_unordered(resize_fn, files)):\n            for size, img in zip(sizes, imgs):\n                key = f\"{size}-{str(i).zfill(5)}\".encode(\"utf-8\")\n\n                with env.begin(write=True) as txn:\n                    txn.put(key, img)\n\n            total += 1\n\n        with env.begin(write=True) as txn:\n            txn.put(\"length\".encode(\"utf-8\"), str(total).encode(\"utf-8\"))\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Preprocess images for model training\")\n    parser.add_argument(\"--size\", type=str, default=\"64,512,1024\", help=\"resolutions of images for the dataset\")\n    parser.add_argument(\"--n_worker\", type=int, default=32, help=\"number of workers for preparing dataset\")\n    parser.add_argument(\"--resample\", type=str, default=\"lanczos\", help=\"resampling methods for resizing images\")\n    parser.add_argument(\"--out_path\", type=str, default=\"datasets/FFHQ\", help=\"Target path of the output lmdb dataset\")\n    parser.add_argument(\"in_path\", type=str, help=\"path to the input image dataset\")\n    args = parser.parse_args()\n\n    resample_map = {\"lanczos\": Image.LANCZOS, \"bilinear\": Image.BILINEAR}\n    resample = resample_map[args.resample]\n\n    sizes = [int(s.strip()) for s in args.size.split(\",\")]\n\n    print(f\"Make dataset of image sizes:\", \", \".join(str(s) for s in sizes))\n\n    imgset = datasets.ImageFolder(args.in_path)\n\n    with lmdb.open(args.out_path, map_size=1024 ** 4, readahead=False) as env:\n        prepare(env, imgset, args.n_worker, sizes=sizes, resample=resample)\n"
  },
  {
    "path": "render_video.py",
    "content": "import os\nimport torch\nimport trimesh\nimport numpy as np\nimport skvideo.io\nfrom munch import *\nfrom PIL import Image\nfrom tqdm import tqdm\nfrom torch.nn import functional as F\nfrom torch.utils import data\nfrom torchvision import utils\nfrom torchvision import transforms\nfrom options import BaseOptions\nfrom model import Generator\nfrom utils import (\n    generate_camera_params, align_volume, extract_mesh_with_marching_cubes,\n    xyz2mesh, create_cameras, create_mesh_renderer, add_textures,\n    )\nfrom pytorch3d.structures import Meshes\n\nfrom pdb import set_trace as st\n\ntorch.random.manual_seed(1234)\n\n\ndef render_video(opt, g_ema, surface_g_ema, device, mean_latent, surface_mean_latent):\n    g_ema.eval()\n    if not opt.no_surface_renderings or opt.project_noise:\n        surface_g_ema.eval()\n\n    images = torch.Tensor(0, 3, opt.size, opt.size)\n    num_frames = 250\n    # Generate video trajectory\n    trajectory = np.zeros((num_frames,3), dtype=np.float32)\n\n    # set camera trajectory\n    # sweep azimuth angles (4 seconds)\n    if opt.azim_video:\n        t = np.linspace(0, 1, num_frames)\n        elev = 0\n        fov = opt.camera.fov\n        if opt.camera.uniform:\n            azim = opt.camera.azim * np.cos(t * 2 * np.pi)\n        else:\n            azim = 1.5 * opt.camera.azim * np.cos(t * 2 * np.pi)\n\n        trajectory[:num_frames,0] = azim\n        trajectory[:num_frames,1] = elev\n        trajectory[:num_frames,2] = fov\n\n    # elipsoid sweep (4 seconds)\n    else:\n        t = np.linspace(0, 1, num_frames)\n        fov = opt.camera.fov #+ 1 * np.sin(t * 2 * np.pi)\n        if opt.camera.uniform:\n            elev = opt.camera.elev / 2 + opt.camera.elev / 2  * np.sin(t * 2 * np.pi)\n            azim = opt.camera.azim  * np.cos(t * 2 * np.pi)\n        else:\n            elev = 1.5 * opt.camera.elev * np.sin(t * 2 * np.pi)\n            azim = 1.5 * opt.camera.azim * np.cos(t * 2 * np.pi)\n\n        trajectory[:num_frames,0] = azim\n        trajectory[:num_frames,1] = elev\n        trajectory[:num_frames,2] = fov\n\n    trajectory = torch.from_numpy(trajectory).to(device)\n\n    # generate input parameters for the camera trajectory\n    # sample_cam_poses, sample_focals, sample_near, sample_far = \\\n    # generate_camera_params(trajectory, opt.renderer_output_size, device, dist_radius=opt.camera.dist_radius)\n    sample_cam_extrinsics, sample_focals, sample_near, sample_far, _ = \\\n    generate_camera_params(opt.renderer_output_size, device, locations=trajectory[:,:2],\n                           fov_ang=trajectory[:,2:], dist_radius=opt.camera.dist_radius)\n\n    # In case of noise projection, generate input parameters for the frontal position.\n    # The reference mesh for the noise projection is extracted from the frontal position.\n    # For more details see section C.1 in the supplementary material.\n    if opt.project_noise:\n        frontal_pose = torch.tensor([[0.0,0.0,opt.camera.fov]]).to(device)\n        # frontal_cam_pose, frontal_focals, frontal_near, frontal_far = \\\n        # generate_camera_params(frontal_pose, opt.surf_extraction_output_size, device, dist_radius=opt.camera.dist_radius)\n        frontal_cam_pose, frontal_focals, frontal_near, frontal_far, _ = \\\n        generate_camera_params(opt.surf_extraction_output_size, device, location=frontal_pose[:,:2],\n                               fov_ang=frontal_pose[:,2:], dist_radius=opt.camera.dist_radius)\n\n    # create geometry renderer (renders the depth maps)\n    
cameras = create_cameras(azim=np.rad2deg(trajectory[0,0].cpu().numpy()),\n                             elev=np.rad2deg(trajectory[0,1].cpu().numpy()),\n                             dist=1, device=device)\n    renderer = create_mesh_renderer(cameras, image_size=512, specular_color=((0,0,0),),\n                    ambient_color=((0.1,.1,.1),), diffuse_color=((0.75,.75,.75),),\n                    device=device)\n\n    suffix = '_azim' if opt.azim_video else '_elipsoid'\n\n    # generate videos\n    for i in range(opt.identities):\n        print('Processing identity {}/{}...'.format(i+1, opt.identities))\n        chunk = 1\n        sample_z = torch.randn(1, opt.style_dim, device=device).repeat(chunk,1)\n        video_filename = 'sample_video_{}{}.mp4'.format(i,suffix)\n        writer = skvideo.io.FFmpegWriter(os.path.join(opt.results_dst_dir, video_filename),\n                                         outputdict={'-pix_fmt': 'yuv420p', '-crf': '10'})\n        if not opt.no_surface_renderings:\n            depth_video_filename = 'sample_depth_video_{}{}.mp4'.format(i,suffix)\n            depth_writer = skvideo.io.FFmpegWriter(os.path.join(opt.results_dst_dir, depth_video_filename),\n                                             outputdict={'-pix_fmt': 'yuv420p', '-crf': '1'})\n\n\n        ####################### Extract initial surface mesh from the frontal viewpoint #############\n        # For more details see section C.1 in the supplementary material.\n        if opt.project_noise:\n            with torch.no_grad():\n                frontal_surface_out = surface_g_ema([sample_z],\n                                                    frontal_cam_pose,\n                                                    frontal_focals,\n                                                    frontal_near,\n                                                    frontal_far,\n                                                    truncation=opt.truncation_ratio,\n                                                    truncation_latent=surface_mean_latent,\n                                                    return_sdf=True)\n                frontal_sdf = frontal_surface_out[2].cpu()\n\n            print('Extracting Identity {} Frontal view Marching Cubes for consistent video rendering'.format(i))\n\n            frostum_aligned_frontal_sdf = align_volume(frontal_sdf)\n            del frontal_sdf\n\n            try:\n                frontal_marching_cubes_mesh = extract_mesh_with_marching_cubes(frostum_aligned_frontal_sdf)\n            except ValueError:\n                frontal_marching_cubes_mesh = None\n\n            if frontal_marching_cubes_mesh != None:\n                frontal_marching_cubes_mesh_filename = os.path.join(opt.results_dst_dir,'sample_{}_frontal_marching_cubes_mesh{}.obj'.format(i,suffix))\n                with open(frontal_marching_cubes_mesh_filename, 'w') as f:\n                    frontal_marching_cubes_mesh.export(f,file_type='obj')\n\n            del frontal_surface_out\n            torch.cuda.empty_cache()\n        #############################################################################################\n\n        for j in tqdm(range(0, num_frames, chunk)):\n            with torch.no_grad():\n                out = g_ema([sample_z],\n                            sample_cam_extrinsics[j:j+chunk],\n                            sample_focals[j:j+chunk],\n                            sample_near[j:j+chunk],\n                            sample_far[j:j+chunk],\n                            
truncation=opt.truncation_ratio,\n                            truncation_latent=mean_latent,\n                            randomize_noise=False,\n                            project_noise=opt.project_noise,\n                            mesh_path=frontal_marching_cubes_mesh_filename if opt.project_noise else None)\n\n                rgb = out[0].cpu()\n\n                # this is done to fit to RTX2080 RAM size (11GB)\n                del out\n                torch.cuda.empty_cache()\n\n                # Convert RGB from [-1, 1] to [0,255]\n                rgb = 127.5 * (rgb.clamp(-1,1).permute(0,2,3,1).cpu().numpy() + 1)\n\n                # Add RGB, frame to video\n                for k in range(chunk):\n                    writer.writeFrame(rgb[k])\n\n                ########## Extract surface ##########\n                if not opt.no_surface_renderings:\n                    scale = surface_g_ema.renderer.out_im_res / g_ema.renderer.out_im_res\n                    surface_sample_focals = sample_focals * scale\n                    surface_out = surface_g_ema([sample_z],\n                                                sample_cam_extrinsics[j:j+chunk],\n                                                surface_sample_focals[j:j+chunk],\n                                                sample_near[j:j+chunk],\n                                                sample_far[j:j+chunk],\n                                                truncation=opt.truncation_ratio,\n                                                truncation_latent=surface_mean_latent,\n                                                return_xyz=True)\n                    xyz = surface_out[2].cpu()\n\n                    # this is done to fit to RTX2080 RAM size (11GB)\n                    del surface_out\n                    torch.cuda.empty_cache()\n\n                    # Render mesh for video\n                    depth_mesh = xyz2mesh(xyz)\n                    mesh = Meshes(\n                        verts=[torch.from_numpy(np.asarray(depth_mesh.vertices)).to(torch.float32).to(device)],\n                        faces = [torch.from_numpy(np.asarray(depth_mesh.faces)).to(torch.float32).to(device)],\n                        textures=None,\n                        verts_normals=[torch.from_numpy(np.copy(np.asarray(depth_mesh.vertex_normals))).to(torch.float32).to(device)],\n                    )\n                    mesh = add_textures(mesh)\n                    cameras = create_cameras(azim=np.rad2deg(trajectory[j,0].cpu().numpy()),\n                                             elev=np.rad2deg(trajectory[j,1].cpu().numpy()),\n                                             fov=2*trajectory[j,2].cpu().numpy(),\n                                             dist=1, device=device)\n                    renderer = create_mesh_renderer(cameras, image_size=512,\n                                                    light_location=((0.0,1.0,5.0),), specular_color=((0.2,0.2,0.2),),\n                                                    ambient_color=((0.1,0.1,0.1),), diffuse_color=((0.65,.65,.65),),\n                                                    device=device)\n\n                    mesh_image = 255 * renderer(mesh).cpu().numpy()\n                    mesh_image = mesh_image[...,:3]\n\n                    # Add depth frame to video\n                    for k in range(chunk):\n                        depth_writer.writeFrame(mesh_image[k])\n\n        # Close video writers\n        writer.close()\n        if not opt.no_surface_renderings:\n         
   depth_writer.close()\n\n\nif __name__ == \"__main__\":\n    device = \"cuda\"\n    opt = BaseOptions().parse()\n    opt.model.is_test = True\n    opt.model.style_dim = 256\n    opt.model.freeze_renderer = False\n    opt.inference.size = opt.model.size\n    opt.inference.camera = opt.camera\n    opt.inference.renderer_output_size = opt.model.renderer_spatial_output_dim\n    opt.inference.style_dim = opt.model.style_dim\n    opt.inference.project_noise = opt.model.project_noise\n    opt.rendering.perturb = 0\n    opt.rendering.force_background = True\n    opt.rendering.static_viewdirs = True\n    opt.rendering.return_sdf = True\n    opt.rendering.N_samples = 64\n\n    # find checkpoint directory\n    # check if there's a fully trained model\n    checkpoints_dir = 'full_models'\n    checkpoint_path = os.path.join(checkpoints_dir, opt.experiment.expname + '.pt')\n    if os.path.isfile(checkpoint_path):\n        # define results directory name\n        result_model_dir = 'final_model'\n    else:\n        checkpoints_dir = os.path.join('checkpoint', opt.experiment.expname, 'full_pipeline')\n        checkpoint_path = os.path.join(checkpoints_dir,\n                                       'models_{}.pt'.format(opt.experiment.ckpt.zfill(7)))\n        # define results directory name\n        result_model_dir = 'iter_{}'.format(opt.experiment.ckpt.zfill(7))\n\n    results_dir_basename = os.path.join(opt.inference.results_dir, opt.experiment.expname)\n    opt.inference.results_dst_dir = os.path.join(results_dir_basename, result_model_dir, 'videos')\n    if opt.model.project_noise:\n        opt.inference.results_dst_dir = os.path.join(opt.inference.results_dst_dir, 'with_noise_projection')\n\n    os.makedirs(opt.inference.results_dst_dir, exist_ok=True)\n\n    # load saved model\n    checkpoint = torch.load(checkpoint_path)\n\n    # load image generation model\n    g_ema = Generator(opt.model, opt.rendering).to(device)\n\n    # temp fix because of wrong noise sizes\n    pretrained_weights_dict = checkpoint[\"g_ema\"]\n    model_dict = g_ema.state_dict()\n    for k, v in pretrained_weights_dict.items():\n        if v.size() == model_dict[k].size():\n            model_dict[k] = v\n\n    g_ema.load_state_dict(model_dict)\n\n    # load the volume renderer weights into a second generator that extracts surfaces at 128x128x128\n    if not opt.inference.no_surface_renderings or opt.model.project_noise:\n        opt['surf_extraction'] = Munch()\n        opt.surf_extraction.rendering = opt.rendering\n        opt.surf_extraction.model = opt.model.copy()\n        opt.surf_extraction.model.renderer_spatial_output_dim = 128\n        opt.surf_extraction.rendering.N_samples = opt.surf_extraction.model.renderer_spatial_output_dim\n        opt.surf_extraction.rendering.return_xyz = True\n        opt.surf_extraction.rendering.return_sdf = True\n        opt.inference.surf_extraction_output_size = opt.surf_extraction.model.renderer_spatial_output_dim\n        surface_g_ema = Generator(opt.surf_extraction.model, opt.surf_extraction.rendering, full_pipeline=False).to(device)\n\n        # Load weights to surface extractor\n        surface_extractor_dict = surface_g_ema.state_dict()\n        for k, v in pretrained_weights_dict.items():\n            if k in surface_extractor_dict.keys() and v.size() == surface_extractor_dict[k].size():\n                surface_extractor_dict[k] = v\n\n        surface_g_ema.load_state_dict(surface_extractor_dict)\n    else:\n        surface_g_ema = None\n\n    # get the mean latent vector for g_ema\n    if 
opt.inference.truncation_ratio < 1:\n        with torch.no_grad():\n            mean_latent = g_ema.mean_latent(opt.inference.truncation_mean, device)\n    else:\n        mean_latent = None\n\n    # get the mean latent vector for surface_g_ema\n    if not opt.inference.no_surface_renderings or opt.model.project_noise:\n        surface_mean_latent = mean_latent[0]\n    else:\n        surface_mean_latent = None\n\n    render_video(opt.inference, g_ema, surface_g_ema, device, mean_latent, surface_mean_latent)\n"
  },
  {
    "path": "requirements.txt",
    "content": "lmdb\nnumpy\nninja\npillow\nrequests\ntqdm\nscipy\nscikit-image\nscikit-video\ntrimesh[easy]\nconfigargparse\nmunch\nwandb\n"
  },
  {
    "path": "scripts/train_afhq_full_pipeline_512x512.sh",
    "content": "python -m torch.distributed.launch --nproc_per_node 4 new_train.py --batch 8 --chunk 4 --azim 0.15 --r1 50.0 --expname afhq512x512 --dataset_path ./datasets/AFHQ/train/ --size 512 --wandb\n"
  },
  {
    "path": "scripts/train_afhq_vol_renderer.sh",
    "content": "python -m torch.distributed.launch --nproc_per_node 2 train_volume_renderer.py --batch 12 --chunk 6 --expname afhq_sdf_vol_renderer --dataset_path ./datasets/AFHQ/train --azim 0.15 --wandb\n"
  },
  {
    "path": "scripts/train_ffhq_full_pipeline_1024x1024.sh",
    "content": "python -m torch.distributed.launch --nproc_per_node 4 train_full_pipeline.py --batch 8 --chunk 2 --expname ffhq1024x1024 --size 1024 --wandb\n"
  },
  {
    "path": "scripts/train_ffhq_vol_renderer.sh",
    "content": "python -m torch.distributed.launch --nproc_per_node 2 train_volume_renderer.py --batch 12 --chunk 6 --expname ffhq_sdf_vol_renderer --dataset_path ./datasets/FFHQ/ --wandb\n"
  },
  {
    "path": "train_full_pipeline.py",
    "content": "import argparse\nimport math\nimport random\nimport os\nimport yaml\nimport numpy as np\nimport torch\nfrom torch import nn, autograd, optim\nfrom torch.nn import functional as F\nfrom torch.utils import data\nimport torch.distributed as dist\nfrom torchvision import transforms, utils\nfrom tqdm import tqdm\nfrom PIL import Image\nfrom losses import *\nfrom options import BaseOptions\nfrom model import Generator, Discriminator\nfrom dataset import MultiResolutionDataset\nfrom utils import data_sampler, requires_grad, accumulate, sample_data, make_noise, mixing_noise, generate_camera_params\nfrom distributed import get_rank, synchronize, reduce_loss_dict, reduce_sum, get_world_size\n\ntry:\n    import wandb\nexcept ImportError:\n    wandb = None\n\n\ndef train(opt, experiment_opt, loader, generator, discriminator, g_optim, d_optim, g_ema, device):\n    loader = sample_data(loader)\n\n    pbar = range(opt.iter)\n\n    if get_rank() == 0:\n        pbar = tqdm(pbar, initial=opt.start_iter, dynamic_ncols=True, smoothing=0.01)\n\n    mean_path_length = 0\n\n    d_loss_val = 0\n    g_loss_val = 0\n    r1_loss = torch.tensor(0.0, device=device)\n    g_gan_loss = torch.tensor(0.0, device=device)\n    path_loss = torch.tensor(0.0, device=device)\n    path_lengths = torch.tensor(0.0, device=device)\n    mean_path_length_avg = 0\n    loss_dict = {}\n\n    if opt.distributed:\n        g_module = generator.module\n        d_module = discriminator.module\n    else:\n        g_module = generator\n        d_module = discriminator\n\n    accum = 0.5 ** (32 / (10 * 1000))\n\n    sample_z = [torch.randn(opt.val_n_sample, opt.style_dim, device=device).repeat_interleave(8, dim=0)]\n    sample_cam_extrinsics, sample_focals, sample_near, sample_far, _ = generate_camera_params(opt.renderer_output_size, device, batch=opt.val_n_sample, sweep=True,\n                                                                                         uniform=opt.camera.uniform, azim_range=opt.camera.azim,\n                                                                                         elev_range=opt.camera.elev, fov_ang=opt.camera.fov,\n                                                                                         dist_radius=opt.camera.dist_radius)\n\n    for idx in pbar:\n        i = idx + opt.start_iter\n\n        if i > opt.iter:\n            print(\"Done!\")\n\n            break\n\n        requires_grad(generator, False)\n        requires_grad(discriminator, True)\n        discriminator.zero_grad()\n        d_regularize = i % opt.d_reg_every == 0\n        real_imgs, real_thumb_imgs = next(loader)\n        real_imgs = real_imgs.to(device)\n        real_thumb_imgs = real_thumb_imgs.to(device)\n        noise = mixing_noise(opt.batch, opt.style_dim, opt.mixing, device)\n        cam_extrinsics, focal, near, far, gt_viewpoints = generate_camera_params(opt.renderer_output_size, device, batch=opt.batch,\n                                                                            uniform=opt.camera.uniform, azim_range=opt.camera.azim,\n                                                                            elev_range=opt.camera.elev, fov_ang=opt.camera.fov,\n                                                                            dist_radius=opt.camera.dist_radius)\n\n        for j in range(0, opt.batch, opt.chunk):\n            curr_real_imgs = real_imgs[j:j+opt.chunk]\n            curr_real_thumb_imgs = real_thumb_imgs[j:j+opt.chunk]\n            curr_noise = [n[j:j+opt.chunk] for n in 
noise]\n            gen_imgs, _ = generator(curr_noise,\n                                    cam_extrinsics[j:j+opt.chunk],\n                                    focal[j:j+opt.chunk],\n                                    near[j:j+opt.chunk],\n                                    far[j:j+opt.chunk])\n\n            fake_pred = discriminator(gen_imgs.detach())\n\n            if d_regularize:\n                curr_real_imgs.requires_grad = True\n                curr_real_thumb_imgs.requires_grad = True\n\n            real_pred = discriminator(curr_real_imgs)\n            d_gan_loss = d_logistic_loss(real_pred, fake_pred)\n\n            if d_regularize:\n                grad_penalty = d_r1_loss(real_pred, curr_real_imgs)\n                r1_loss = opt.r1 * 0.5 * grad_penalty * opt.d_reg_every\n            else:\n                r1_loss = torch.zeros_like(r1_loss)\n\n            d_loss = d_gan_loss + r1_loss\n            d_loss.backward()\n\n        d_optim.step()\n\n        loss_dict[\"d\"] = d_gan_loss\n        loss_dict[\"real_score\"] = real_pred.mean()\n        loss_dict[\"fake_score\"] = fake_pred.mean()\n        if d_regularize or i == opt.start_iter:\n            loss_dict[\"r1\"] = r1_loss.mean()\n\n        requires_grad(generator, True)\n        requires_grad(discriminator, False)\n\n        for j in range(0, opt.batch, opt.chunk):\n            noise = mixing_noise(opt.chunk, opt.style_dim, opt.mixing, device)\n            cam_extrinsics, focal, near, far, gt_viewpoints = generate_camera_params(opt.renderer_output_size, device, batch=opt.chunk,\n                                                                                uniform=opt.camera.uniform, azim_range=opt.camera.azim,\n                                                                                elev_range=opt.camera.elev, fov_ang=opt.camera.fov,\n                                                                                dist_radius=opt.camera.dist_radius)\n\n            fake_img, _ = generator(noise, cam_extrinsics, focal, near, far)\n            fake_pred = discriminator(fake_img)\n            g_gan_loss = g_nonsaturating_loss(fake_pred)\n\n            g_loss = g_gan_loss\n            g_loss.backward()\n\n\n        g_optim.step()\n        generator.zero_grad()\n\n        loss_dict[\"g\"] = g_gan_loss\n\n        # generator path regularization\n        g_regularize = (opt.g_reg_every > 0) and (i % opt.g_reg_every == 0)\n        if g_regularize:\n            path_batch_size = max(1, opt.batch // opt.path_batch_shrink)\n            path_noise = mixing_noise(path_batch_size, opt.style_dim, opt.mixing, device)\n            path_cam_extrinsics, path_focal, path_near, path_far, _ = generate_camera_params(opt.renderer_output_size, device, batch=path_batch_size,\n                                                                                        uniform=opt.camera.uniform, azim_range=opt.camera.azim,\n                                                                                        elev_range=opt.camera.elev, fov_ang=opt.camera.fov,\n                                                                                        dist_radius=opt.camera.dist_radius)\n\n            for j in range(0, path_batch_size, opt.chunk):\n                path_fake_img, path_latents = generator(path_noise, path_cam_extrinsics,\n                                                        path_focal, path_near, path_far,\n                                                        return_latents=True)\n\n                path_loss, mean_path_length, 
path_lengths = g_path_regularize(\n                    path_fake_img, path_latents, mean_path_length\n                )\n\n                weighted_path_loss = opt.path_regularize * opt.g_reg_every * path_loss# * opt.chunk / path_batch_size\n                if opt.path_batch_shrink:\n                    weighted_path_loss += 0 * path_fake_img[0, 0, 0, 0]\n\n                weighted_path_loss.backward()\n\n            g_optim.step()\n            generator.zero_grad()\n\n            mean_path_length_avg = (reduce_sum(mean_path_length).item() / get_world_size())\n\n        loss_dict[\"path\"] = path_loss\n        loss_dict[\"path_length\"] = path_lengths.mean()\n\n        accumulate(g_ema, g_module, accum)\n        loss_reduced = reduce_loss_dict(loss_dict)\n        d_loss_val = loss_reduced[\"d\"].mean().item()\n        g_loss_val = loss_reduced[\"g\"].mean().item()\n        r1_val = loss_reduced[\"r1\"].mean().item()\n        real_score_val = loss_reduced[\"real_score\"].mean().item()\n        fake_score_val = loss_reduced[\"fake_score\"].mean().item()\n        path_loss_val = loss_reduced[\"path\"].mean().item()\n        path_length_val = loss_reduced[\"path_length\"].mean().item()\n\n        if get_rank() == 0:\n            pbar.set_description(\n                (f\"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; path: {path_loss_val:.4f}\")\n            )\n\n            if i % 1000 == 0 or i == opt.start_iter:\n                with torch.no_grad():\n                    thumbs_samples = torch.Tensor(0, 3, opt.renderer_output_size, opt.renderer_output_size)\n                    samples = torch.Tensor(0, 3, opt.size, opt.size)\n                    step_size = 8\n                    mean_latent = g_module.mean_latent(10000, device)\n                    for k in range(0, opt.val_n_sample * 8, step_size):\n                        curr_samples, curr_thumbs = g_ema([sample_z[0][k:k+step_size]],\n                                                           sample_cam_extrinsics[k:k+step_size],\n                                                           sample_focals[k:k+step_size],\n                                                           sample_near[k:k+step_size],\n                                                           sample_far[k:k+step_size],\n                                                           truncation=0.7,\n                                                           truncation_latent=mean_latent)\n                        samples = torch.cat([samples, curr_samples.cpu()], 0)\n                        thumbs_samples = torch.cat([thumbs_samples, curr_thumbs.cpu()], 0)\n\n                    if i % 10000 == 0:\n                        utils.save_image(samples,\n                            os.path.join(opt.checkpoints_dir, experiment_opt.expname, 'full_pipeline', f\"samples/{str(i).zfill(7)}.png\"),\n                            nrow=int(opt.val_n_sample),\n                            normalize=True,\n                            value_range=(-1, 1),)\n\n                        utils.save_image(thumbs_samples,\n                            os.path.join(opt.checkpoints_dir, experiment_opt.expname, 'full_pipeline', f\"samples/{str(i).zfill(7)}_thumbs.png\"),\n                            nrow=int(opt.val_n_sample),\n                            normalize=True,\n                            value_range=(-1, 1),)\n\n            if wandb and opt.wandb:\n                wandb_log_dict = {\"Generator\": g_loss_val,\n                                  \"Discriminator\": d_loss_val,\n    
                              \"R1\": r1_val,\n                                  \"Real Score\": real_score_val,\n                                  \"Fake Score\": fake_score_val,\n                                  \"Path Length Regularization\": path_loss_val,\n                                  \"Path Length\": path_length_val,\n                                  \"Mean Path Length\": mean_path_length,\n                                  }\n                if i % 5000 == 0:\n                    wandb_grid = utils.make_grid(samples, nrow=int(opt.val_n_sample),\n                                                   normalize=True, value_range=(-1, 1))\n                    wandb_ndarr = (255 * wandb_grid.permute(1, 2, 0).numpy()).astype(np.uint8)\n                    wandb_images = Image.fromarray(wandb_ndarr)\n                    wandb_log_dict.update({\"examples\": [wandb.Image(wandb_images,\n                                            caption=\"Generated samples for azimuth angles of: -0.35, -0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35 Radians.\")]})\n\n                    wandb_thumbs_grid = utils.make_grid(thumbs_samples, nrow=int(opt.val_n_sample),\n                                                        normalize=True, value_range=(-1, 1))\n                    wandb_thumbs_ndarr = (255 * wandb_thumbs_grid.permute(1, 2, 0).numpy()).astype(np.uint8)\n                    wandb_thumbs = Image.fromarray(wandb_thumbs_ndarr)\n                    wandb_log_dict.update({\"thumb_examples\": [wandb.Image(wandb_thumbs,\n                                            caption=\"Generated samples for azimuth angles of: -0.35, -0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35 Radians.\")]})\n\n                wandb.log(wandb_log_dict)\n\n            if i % 10000 == 0:\n                torch.save(\n                    {\n                        \"g\": g_module.state_dict(),\n                        \"d\": d_module.state_dict(),\n                        \"g_ema\": g_ema.state_dict(),\n                    },\n                    os.path.join(opt.checkpoints_dir, experiment_opt.expname, 'full_pipeline', f\"models_{str(i).zfill(7)}.pt\")\n                )\n                print('Successfully saved checkpoint for iteration {}.'.format(i))\n\n    if get_rank() == 0:\n        # create final model directory\n        final_model_path = os.path.join('full_models', opt.experiment.expname)\n        os.makedirs(final_model_path, exist_ok=True)\n        torch.save(\n            {\n                \"g\": g_module.state_dict(),\n                \"d\": d_module.state_dict(),\n                \"g_ema\": g_ema.state_dict(),\n            },\n            os.path.join(final_model_path, experiment_opt.expname + '.pt')\n        )\n        print('Successfully saved final model.')\n\n\nif __name__ == \"__main__\":\n    device = \"cuda\"\n    opt = BaseOptions().parse()\n    opt.training.camera = opt.camera\n    opt.training.size = opt.model.size\n    opt.training.renderer_output_size = opt.model.renderer_spatial_output_dim\n    opt.training.style_dim = opt.model.style_dim\n    opt.model.freeze_renderer = True\n\n    n_gpu = int(os.environ[\"WORLD_SIZE\"]) if \"WORLD_SIZE\" in os.environ else 1\n    opt.training.distributed = n_gpu > 1\n\n    if opt.training.distributed:\n        torch.cuda.set_device(opt.training.local_rank)\n        torch.distributed.init_process_group(backend=\"nccl\", init_method=\"env://\")\n        synchronize()\n\n    # create checkpoints directories\n    os.makedirs(os.path.join(opt.training.checkpoints_dir, 
opt.experiment.expname, 'full_pipeline'), exist_ok=True)\n    os.makedirs(os.path.join(opt.training.checkpoints_dir, opt.experiment.expname, 'full_pipeline', 'samples'), exist_ok=True)\n\n    discriminator = Discriminator(opt.model).to(device)\n    generator = Generator(opt.model, opt.rendering).to(device)\n    g_ema = Generator(opt.model, opt.rendering, ema=True).to(device)\n    g_ema.eval()\n\n    g_reg_ratio = opt.training.g_reg_every / (opt.training.g_reg_every + 1) if opt.training.g_reg_every > 0 else 1\n    d_reg_ratio = opt.training.d_reg_every / (opt.training.d_reg_every + 1)\n\n    params_g = []\n    params_dict_g = dict(generator.named_parameters())\n    for key, value in params_dict_g.items():\n        decoder_cond = ('decoder' in key)\n        if decoder_cond:\n            params_g += [{'params':[value], 'lr':opt.training.lr * g_reg_ratio}]\n\n    g_optim = optim.Adam(params_g, #generator.parameters(),\n                         lr=opt.training.lr * g_reg_ratio,\n                         betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio))\n    d_optim = optim.Adam(discriminator.parameters(),\n                         lr=opt.training.lr * d_reg_ratio,# * g_d_ratio,\n                         betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio))\n\n    opt.training.start_iter = 0\n    if opt.experiment.continue_training and opt.experiment.ckpt is not None:\n        if get_rank() == 0:\n            print(\"load model:\", opt.experiment.ckpt)\n        ckpt_path = os.path.join(opt.training.checkpoints_dir,\n                                 opt.experiment.expname,\n                                 'models_{}.pt'.format(opt.experiment.ckpt.zfill(7)))\n        ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)\n\n        try:\n            opt.training.start_iter = int(opt.experiment.ckpt) + 1\n\n        except ValueError:\n            pass\n\n        generator.load_state_dict(ckpt[\"g\"])\n        discriminator.load_state_dict(ckpt[\"d\"])\n        g_ema.load_state_dict(ckpt[\"g_ema\"])\n    else:\n        # save configuration\n        opt_path = os.path.join(opt.training.checkpoints_dir, opt.experiment.expname, 'full_pipeline', f\"opt.yaml\")\n        with open(opt_path,'w') as f:\n            yaml.safe_dump(opt, f)\n\n    if not opt.experiment.continue_training:\n        if get_rank() == 0:\n            print(\"loading pretrained renderer weights...\")\n\n        pretrained_renderer_path = os.path.join('./pretrained_renderer',\n                                                opt.experiment.expname + '_vol_renderer.pt')\n        try:\n            ckpt = torch.load(pretrained_renderer_path, map_location=lambda storage, loc: storage)\n        except:\n            print('Pretrained volume renderer experiment name does not match the full pipeline experiment name.')\n            vol_renderer_expname = str(input('Please enter the pretrained volume renderer experiment name:'))\n            pretrained_renderer_path = os.path.join('./pretrained_renderer',\n                                                    vol_renderer_expname + '.pt')\n            ckpt = torch.load(pretrained_renderer_path, map_location=lambda storage, loc: storage)\n\n        pretrained_renderer_dict = ckpt[\"g_ema\"]\n        model_dict = generator.state_dict()\n        for k, v in pretrained_renderer_dict.items():\n            if v.size() == model_dict[k].size():\n                model_dict[k] = v\n\n        generator.load_state_dict(model_dict)\n\n    # initialize g_ema weights to generator weights\n    
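# note: accumulate with decay 0 copies the generator weights into g_ema outright (no EMA blending on this first call)\n    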
accumulate(g_ema, generator, 0)\n\n    # set distributed models\n    if opt.training.distributed:\n        generator = nn.parallel.DistributedDataParallel(\n            generator,\n            device_ids=[opt.training.local_rank],\n            output_device=opt.training.local_rank,\n            broadcast_buffers=True,\n            find_unused_parameters=True,\n        )\n\n        discriminator = nn.parallel.DistributedDataParallel(\n            discriminator,\n            device_ids=[opt.training.local_rank],\n            output_device=opt.training.local_rank,\n            broadcast_buffers=False,\n            find_unused_parameters=True\n        )\n\n    transform = transforms.Compose(\n        [transforms.ToTensor(),\n         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)])\n\n    dataset = MultiResolutionDataset(opt.dataset.dataset_path, transform, opt.model.size,\n                                     opt.model.renderer_spatial_output_dim)\n    loader = data.DataLoader(\n        dataset,\n        batch_size=opt.training.batch,\n        sampler=data_sampler(dataset, shuffle=True, distributed=opt.training.distributed),\n        drop_last=True,\n    )\n\n    if get_rank() == 0 and wandb is not None and opt.training.wandb:\n        wandb.init(project=\"StyleSDF\")\n        wandb.run.name = opt.experiment.expname\n        wandb.config.dataset = os.path.basename(opt.dataset.dataset_path)\n        wandb.config.update(opt.training)\n        wandb.config.update(opt.model)\n        wandb.config.update(opt.rendering)\n\n    train(opt.training, opt.experiment, loader, generator, discriminator, g_optim, d_optim, g_ema, device)\n"
  },
  {
    "path": "train_volume_renderer.py",
    "content": "import argparse\nimport math\nimport random\nimport os\nimport yaml\nimport numpy as np\nimport torch\nimport torch.distributed as dist\nfrom torch import nn, autograd, optim\nfrom torch.nn import functional as F\nfrom torch.utils import data\nfrom torchvision import transforms, utils\nfrom tqdm import tqdm\nfrom PIL import Image\nfrom losses import *\nfrom options import BaseOptions\nfrom model import Generator, VolumeRenderDiscriminator\nfrom dataset import MultiResolutionDataset\nfrom utils import data_sampler, requires_grad, accumulate, sample_data, make_noise, mixing_noise, generate_camera_params\nfrom distributed import get_rank, synchronize, reduce_loss_dict, reduce_sum, get_world_size\n\ntry:\n    import wandb\nexcept ImportError:\n    wandb = None\n\n\ndef train(opt, experiment_opt, loader, generator, discriminator, g_optim, d_optim, g_ema, device):\n    loader = sample_data(loader)\n\n    mean_path_length = 0\n\n    d_loss_val = 0\n    r1_loss = torch.tensor(0.0, device=device)\n    d_view_loss = torch.tensor(0.0, device=device)\n    g_view_loss = torch.tensor(0.0, device=device)\n    g_eikonal = torch.tensor(0.0, device=device)\n    g_minimal_surface = torch.tensor(0.0, device=device)\n\n    g_loss_val = 0\n    loss_dict = {}\n\n    viewpoint_condition = opt.view_lambda > 0\n\n    if opt.distributed:\n        g_module = generator.module\n        d_module = discriminator.module\n    else:\n        g_module = generator\n        d_module = discriminator\n\n    accum = 0.5 ** (32 / (10 * 1000))\n\n    sample_z = [torch.randn(opt.val_n_sample, opt.style_dim, device=device).repeat_interleave(8,dim=0)]\n    sample_cam_extrinsics, sample_focals, sample_near, sample_far, _ = generate_camera_params(opt.renderer_output_size, device, batch=opt.val_n_sample, sweep=True,\n                                                                                         uniform=opt.camera.uniform, azim_range=opt.camera.azim,\n                                                                                         elev_range=opt.camera.elev, fov_ang=opt.camera.fov,\n                                                                                         dist_radius=opt.camera.dist_radius)\n\n    if opt.with_sdf and opt.sphere_init and opt.start_iter == 0:\n        init_pbar = range(10000)\n        if get_rank() == 0:\n            init_pbar = tqdm(init_pbar, initial=0, dynamic_ncols=True, smoothing=0.01)\n\n        generator.zero_grad()\n        for idx in init_pbar:\n            noise = mixing_noise(3, opt.style_dim, opt.mixing, device)\n            cam_extrinsics, focal, near, far, gt_viewpoints = generate_camera_params(opt.renderer_output_size, device, batch=3,\n                                                                                uniform=opt.camera.uniform, azim_range=opt.camera.azim,\n                                                                                elev_range=opt.camera.elev, fov_ang=opt.camera.fov,\n                                                                                dist_radius=opt.camera.dist_radius)\n            sdf, target_values = g_module.init_forward(noise, cam_extrinsics, focal, near, far)\n            loss = F.l1_loss(sdf, target_values)\n            loss.backward()\n            g_optim.step()\n            generator.zero_grad()\n            if get_rank() == 0:\n                init_pbar.set_description((f\"MLP init to sphere procedure - Loss: {loss.item():.4f}\"))\n\n        accumulate(g_ema, g_module, 0)\n        torch.save(\n            
{\n                \"g\": g_module.state_dict(),\n                \"d\": d_module.state_dict(),\n                \"g_ema\": g_ema.state_dict(),\n            },\n            os.path.join(opt.checkpoints_dir, experiment_opt.expname, f\"sdf_init_models_{str(0).zfill(7)}.pt\")\n        )\n        print('Successfully saved checkpoint for SDF initialized MLP.')\n\n    pbar = range(opt.iter)\n    if get_rank() == 0:\n        pbar = tqdm(pbar, initial=opt.start_iter, dynamic_ncols=True, smoothing=0.01)\n\n    for idx in pbar:\n        i = idx + opt.start_iter\n\n        if i > opt.iter:\n            print(\"Done!\")\n\n            break\n\n        requires_grad(generator, False)\n        requires_grad(discriminator, True)\n        discriminator.zero_grad()\n        _, real_imgs = next(loader)\n        real_imgs = real_imgs.to(device)\n        noise = mixing_noise(opt.batch, opt.style_dim, opt.mixing, device)\n        cam_extrinsics, focal, near, far, gt_viewpoints = generate_camera_params(opt.renderer_output_size, device, batch=opt.batch,\n                                                                            uniform=opt.camera.uniform, azim_range=opt.camera.azim,\n                                                                            elev_range=opt.camera.elev, fov_ang=opt.camera.fov,\n                                                                            dist_radius=opt.camera.dist_radius)\n        gen_imgs = []\n        for j in range(0, opt.batch, opt.chunk):\n            curr_noise = [n[j:j+opt.chunk] for n in noise]\n            _, fake_img = generator(curr_noise,\n                                    cam_extrinsics[j:j+opt.chunk],\n                                    focal[j:j+opt.chunk],\n                                    near[j:j+opt.chunk],\n                                    far[j:j+opt.chunk])\n\n            gen_imgs += [fake_img]\n\n        gen_imgs = torch.cat(gen_imgs, 0)\n        fake_pred, fake_viewpoint_pred = discriminator(gen_imgs.detach())\n        if viewpoint_condition:\n            d_view_loss = opt.view_lambda * viewpoints_loss(fake_viewpoint_pred, gt_viewpoints)\n\n        real_imgs.requires_grad = True\n        real_pred, _ = discriminator(real_imgs)\n        d_gan_loss = d_logistic_loss(real_pred, fake_pred)\n        grad_penalty = d_r1_loss(real_pred, real_imgs)\n        r1_loss = opt.r1 * 0.5 * grad_penalty\n        d_loss = d_gan_loss + r1_loss + d_view_loss\n        d_loss.backward()\n        d_optim.step()\n\n        loss_dict[\"d\"] = d_gan_loss\n        loss_dict[\"r1\"] = r1_loss\n        loss_dict[\"d_view\"] = d_view_loss\n        loss_dict[\"real_score\"] = real_pred.mean()\n        loss_dict[\"fake_score\"] = fake_pred.mean()\n\n        requires_grad(generator, True)\n        requires_grad(discriminator, False)\n\n        for j in range(0, opt.batch, opt.chunk):\n            noise = mixing_noise(opt.chunk, opt.style_dim, opt.mixing, device)\n            cam_extrinsics, focal, near, far, curr_gt_viewpoints = generate_camera_params(opt.renderer_output_size, device, batch=opt.chunk,\n                                                                                     uniform=opt.camera.uniform, azim_range=opt.camera.azim,\n                                                                                     elev_range=opt.camera.elev, fov_ang=opt.camera.fov,\n                                                                                     dist_radius=opt.camera.dist_radius)\n\n            out = generator(noise, cam_extrinsics, focal, near, 
far,\n                            return_sdf=opt.min_surf_lambda > 0,\n                            return_eikonal=opt.eikonal_lambda > 0)\n            fake_img  = out[1]\n            if opt.min_surf_lambda > 0:\n                sdf = out[2]\n            if opt.eikonal_lambda > 0:\n                eikonal_term = out[3]\n\n            fake_pred, fake_viewpoint_pred = discriminator(fake_img)\n            if viewpoint_condition:\n                g_view_loss = opt.view_lambda * viewpoints_loss(fake_viewpoint_pred, curr_gt_viewpoints)\n\n            if opt.with_sdf and opt.eikonal_lambda > 0:\n                g_eikonal, g_minimal_surface = eikonal_loss(eikonal_term, sdf=sdf if opt.min_surf_lambda > 0 else None,\n                                                            beta=opt.min_surf_beta)\n                g_eikonal = opt.eikonal_lambda * g_eikonal\n                if opt.min_surf_lambda > 0:\n                    g_minimal_surface = opt.min_surf_lambda * g_minimal_surface\n\n            g_gan_loss = g_nonsaturating_loss(fake_pred)\n            g_loss = g_gan_loss + g_view_loss + g_eikonal + g_minimal_surface\n            g_loss.backward()\n\n        g_optim.step()\n        generator.zero_grad()\n        loss_dict[\"g\"] = g_gan_loss\n        loss_dict[\"g_view\"] = g_view_loss\n        loss_dict[\"g_eikonal\"] = g_eikonal\n        loss_dict[\"g_minimal_surface\"] = g_minimal_surface\n\n        accumulate(g_ema, g_module, accum)\n\n        loss_reduced = reduce_loss_dict(loss_dict)\n        d_loss_val = loss_reduced[\"d\"].mean().item()\n        g_loss_val = loss_reduced[\"g\"].mean().item()\n        r1_val = loss_reduced[\"r1\"].mean().item()\n        real_score_val = loss_reduced[\"real_score\"].mean().item()\n        fake_score_val = loss_reduced[\"fake_score\"].mean().item()\n        d_view_val = loss_reduced[\"d_view\"].mean().item()\n        g_view_val = loss_reduced[\"g_view\"].mean().item()\n        g_eikonal_loss = loss_reduced[\"g_eikonal\"].mean().item()\n        g_minimal_surface_loss = loss_reduced[\"g_minimal_surface\"].mean().item()\n        g_beta_val = g_module.renderer.sigmoid_beta.item() if opt.with_sdf else 0\n\n        if get_rank() == 0:\n            pbar.set_description(\n                (f\"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; viewpoint: {d_view_val+g_view_val:.4f}; eikonal: {g_eikonal_loss:.4f}; surf: {g_minimal_surface_loss:.4f}\")\n            )\n\n            if i % 1000 == 0:\n                with torch.no_grad():\n                    samples = torch.Tensor(0, 3, opt.renderer_output_size, opt.renderer_output_size)\n                    step_size = 4\n                    mean_latent = g_module.mean_latent(10000, device)\n                    for k in range(0, opt.val_n_sample * 8, step_size):\n                        _, curr_samples = g_ema([sample_z[0][k:k+step_size]],\n                                                sample_cam_extrinsics[k:k+step_size],\n                                                sample_focals[k:k+step_size],\n                                                sample_near[k:k+step_size],\n                                                sample_far[k:k+step_size],\n                                                truncation=0.7,\n                                                truncation_latent=mean_latent,)\n                        samples = torch.cat([samples, curr_samples.cpu()], 0)\n\n                    if i % 10000 == 0:\n                        utils.save_image(samples,\n                            
os.path.join(opt.checkpoints_dir, experiment_opt.expname, 'volume_renderer', f\"samples/{str(i).zfill(7)}.png\"),\n                            nrow=int(opt.val_n_sample),\n                            normalize=True,\n                            value_range=(-1, 1),)\n\n            if wandb and opt.wandb:\n                wandb_log_dict = {\"Generator\": g_loss_val,\n                                  \"Discriminator\": d_loss_val,\n                                  \"R1\": r1_val,\n                                  \"Real Score\": real_score_val,\n                                  \"Fake Score\": fake_score_val,\n                                  \"D viewpoint\": d_view_val,\n                                  \"G viewpoint\": g_view_val,\n                                  \"G eikonal loss\": g_eikonal_loss,\n                                  \"G minimal surface loss\": g_minimal_surface_loss,\n                                  }\n                if opt.with_sdf:\n                    wandb_log_dict.update({\"Beta value\": g_beta_val})\n\n                if i % 1000 == 0:\n                    wandb_grid = utils.make_grid(samples, nrow=int(opt.val_n_sample),\n                                                   normalize=True, value_range=(-1, 1))\n                    wandb_ndarr = (255 * wandb_grid.permute(1, 2, 0).numpy()).astype(np.uint8)\n                    wandb_images = Image.fromarray(wandb_ndarr)\n                    wandb_log_dict.update({\"examples\": [wandb.Image(wandb_images,\n                                            caption=\"Generated samples for azimuth angles of: -35, -25, -15, -5, 5, 15, 25, 35 degrees.\")]})\n\n                wandb.log(wandb_log_dict)\n\n            if i % 10000 == 0 or (i < 10000 and i % 1000 == 0):\n                torch.save(\n                    {\n                        \"g\": g_module.state_dict(),\n                        \"d\": d_module.state_dict(),\n                        \"g_ema\": g_ema.state_dict(),\n                    },\n                    os.path.join(opt.checkpoints_dir, experiment_opt.expname, 'volume_renderer', f\"models_{str(i).zfill(7)}.pt\")\n                )\n                print('Successfully saved checkpoint for iteration {}.'.format(i))\n\n    if get_rank() == 0:\n        # create final model directory\n        final_model_path = 'pretrained_renderer'\n        os.makedirs(final_model_path, exist_ok=True)\n        torch.save(\n            {\n                \"g\": g_module.state_dict(),\n                \"d\": d_module.state_dict(),\n                \"g_ema\": g_ema.state_dict(),\n            },\n            os.path.join(final_model_path, experiment_opt.expname + '_vol_renderer.pt')\n        )\n        print('Successfully saved final model.')\n\n\nif __name__ == \"__main__\":\n    device = \"cuda\"\n    opt = BaseOptions().parse()\n    opt.model.freeze_renderer = False\n    opt.model.no_viewpoint_loss = opt.training.view_lambda == 0.0\n    opt.training.camera = opt.camera\n    opt.training.renderer_output_size = opt.model.renderer_spatial_output_dim\n    opt.training.style_dim = opt.model.style_dim\n    opt.training.with_sdf = not opt.rendering.no_sdf\n    if opt.training.with_sdf and opt.training.min_surf_lambda > 0:\n        opt.rendering.return_sdf = True\n    opt.training.iter = 200001\n    opt.rendering.no_features_output = True\n\n    n_gpu = int(os.environ[\"WORLD_SIZE\"]) if \"WORLD_SIZE\" in os.environ else 1\n    opt.training.distributed = n_gpu > 1\n\n    if opt.training.distributed:\n        
torch.cuda.set_device(int(os.environ[\"LOCAL_RANK\"]))\n        torch.distributed.init_process_group(backend=\"nccl\", init_method=\"env://\")\n        synchronize()\n\n    # create checkpoints directories\n    os.makedirs(os.path.join(opt.training.checkpoints_dir, opt.experiment.expname, 'volume_renderer'), exist_ok=True)\n    os.makedirs(os.path.join(opt.training.checkpoints_dir, opt.experiment.expname, 'volume_renderer', 'samples'), exist_ok=True)\n\n    discriminator = VolumeRenderDiscriminator(opt.model).to(device)\n    generator = Generator(opt.model, opt.rendering, full_pipeline=False).to(device)\n    g_ema = Generator(opt.model, opt.rendering, ema=True, full_pipeline=False).to(device)\n\n    g_ema.eval()\n    accumulate(g_ema, generator, 0)\n    g_optim = optim.Adam(generator.parameters(), lr=2e-5, betas=(0, 0.9))\n    d_optim = optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0, 0.9))\n\n    opt.training.start_iter = 0\n\n    if opt.experiment.continue_training and opt.experiment.ckpt is not None:\n        if get_rank() == 0:\n            print(\"load model:\", opt.experiment.ckpt)\n        ckpt_path = os.path.join(opt.training.checkpoints_dir,\n                                 opt.experiment.expname,\n                                 'models_{}.pt'.format(opt.experiment.ckpt.zfill(7)))\n        ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)\n\n        try:\n            opt.training.start_iter = int(opt.experiment.ckpt) + 1\n\n        except ValueError:\n            pass\n\n        generator.load_state_dict(ckpt[\"g\"])\n        discriminator.load_state_dict(ckpt[\"d\"])\n        g_ema.load_state_dict(ckpt[\"g_ema\"])\n        if \"g_optim\" in ckpt.keys():\n            g_optim.load_state_dict(ckpt[\"g_optim\"])\n            d_optim.load_state_dict(ckpt[\"d_optim\"])\n\n    sphere_init_path = './pretrained_renderer/sphere_init.pt'\n    if opt.training.no_sphere_init:\n        opt.training.sphere_init = False\n    elif not opt.experiment.continue_training and opt.training.with_sdf and os.path.isfile(sphere_init_path):\n        if get_rank() == 0:\n            print(\"loading sphere inititialized model\")\n        ckpt = torch.load(sphere_init_path, map_location=lambda storage, loc: storage)\n        generator.load_state_dict(ckpt[\"g\"])\n        discriminator.load_state_dict(ckpt[\"d\"])\n        g_ema.load_state_dict(ckpt[\"g_ema\"])\n        opt.training.sphere_init = False\n    else:\n        opt.training.sphere_init = True\n\n    if opt.training.distributed:\n        generator = nn.parallel.DistributedDataParallel(\n            generator,\n            device_ids=[opt.training.local_rank],\n            output_device=opt.training.local_rank,\n            broadcast_buffers=True,\n            find_unused_parameters=True,\n        )\n\n        discriminator = nn.parallel.DistributedDataParallel(\n            discriminator,\n            device_ids=[opt.training.local_rank],\n            output_device=opt.training.local_rank,\n            broadcast_buffers=False,\n            find_unused_parameters=True\n        )\n\n    transform = transforms.Compose(\n        [transforms.ToTensor(),\n         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)])\n\n    dataset = MultiResolutionDataset(opt.dataset.dataset_path, transform, opt.model.size,\n                                     opt.model.renderer_spatial_output_dim)\n    loader = data.DataLoader(\n        dataset,\n        batch_size=opt.training.batch,\n        
sampler=data_sampler(dataset, shuffle=True, distributed=opt.training.distributed),\n        drop_last=True,\n    )\n    opt.training.dataset_name = opt.dataset.dataset_path.lower()\n\n    # save options\n    opt_path = os.path.join(opt.training.checkpoints_dir, opt.experiment.expname, 'volume_renderer', f\"opt.yaml\")\n    with open(opt_path,'w') as f:\n        yaml.safe_dump(opt, f)\n\n    # set wandb environment\n    if get_rank() == 0 and wandb is not None and opt.training.wandb:\n        wandb.init(project=\"StyleSDF\")\n        wandb.run.name = opt.experiment.expname\n        wandb.config.dataset = os.path.basename(opt.dataset.dataset_path)\n        wandb.config.update(opt.training)\n        wandb.config.update(opt.model)\n        wandb.config.update(opt.rendering)\n\n    train(opt.training, opt.experiment, loader, generator, discriminator, g_optim, d_optim, g_ema, device)\n"
  },
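  {
    "path": "examples/render_samples_sketch.py",
    "content": "'''\nIllustrative usage sketch (not part of the original release): renders an\nazimuth sweep with the EMA volume-renderer generator, mirroring the\nvalidation loop in the training script above. The checkpoint path and the\noutput file name below are placeholders; point them at your own experiment.\n'''\nimport torch\nfrom torchvision import utils\n\nfrom options import BaseOptions\nfrom model import Generator\nfrom utils import generate_camera_params\n\ndevice = 'cuda'\nopt = BaseOptions().parse()\nopt.model.freeze_renderer = False\nopt.rendering.no_features_output = True\n\n# Placeholder checkpoint path (hypothetical experiment name and iteration).\nckpt_path = './checkpoint/<expname>/volume_renderer/models_0200000.pt'\nckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)\n\ng_ema = Generator(opt.model, opt.rendering, ema=True, full_pipeline=False).to(device)\ng_ema.load_state_dict(ckpt['g_ema'])\ng_ema.eval()\n\nn_sample = 4\nresolution = opt.model.renderer_spatial_output_dim\nwith torch.no_grad():\n    mean_latent = g_ema.mean_latent(10000, device)\n    # One latent code per identity, repeated over the 8 sweep viewpoints.\n    sample_z = torch.randn(n_sample, opt.model.style_dim, device=device).repeat_interleave(8, dim=0)\n    extrinsics, focal, near, far, _ = generate_camera_params(resolution, device, batch=n_sample, sweep=True,\n                                                             uniform=opt.camera.uniform, azim_range=opt.camera.azim,\n                                                             elev_range=opt.camera.elev, fov_ang=opt.camera.fov,\n                                                             dist_radius=opt.camera.dist_radius)\n\n    images = []\n    for k in range(0, n_sample * 8, 4):  # chunk the batch to limit memory, as in training\n        _, img = g_ema([sample_z[k:k+4]], extrinsics[k:k+4], focal[k:k+4],\n                       near[k:k+4], far[k:k+4],\n                       truncation=0.7, truncation_latent=mean_latent)\n        images.append(img.cpu())\n\n    utils.save_image(torch.cat(images, 0), 'azimuth_sweep.png',\n                     nrow=8, normalize=True, value_range=(-1, 1))\n"
  },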
  {
    "path": "utils.py",
    "content": "import torch\nimport random\nimport trimesh\nimport numpy as np\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom torch.utils import data\nfrom scipy.spatial import Delaunay\nfrom skimage.measure import marching_cubes\nfrom pdb import set_trace as st\n\nimport pytorch3d.io\nfrom pytorch3d.structures import Meshes\nfrom pytorch3d.renderer import (\n    look_at_view_transform,\n    FoVPerspectiveCameras,\n    PointLights,\n    RasterizationSettings,\n    MeshRenderer,\n    MeshRasterizer,\n    SoftPhongShader,\n    TexturesVertex,\n)\n\n######################### Dataset util functions ###########################\n# Get data sampler\ndef data_sampler(dataset, shuffle, distributed):\n    if distributed:\n        return data.distributed.DistributedSampler(dataset, shuffle=shuffle)\n\n    if shuffle:\n        return data.RandomSampler(dataset)\n\n    else:\n        return data.SequentialSampler(dataset)\n\n# Get data minibatch\ndef sample_data(loader):\n    while True:\n        for batch in loader:\n            yield batch\n\n############################## Model weights util functions #################\n# Turn model gradients on/off\ndef requires_grad(model, flag=True):\n    for p in model.parameters():\n        p.requires_grad = flag\n\n# Exponential moving average for generator weights\ndef accumulate(model1, model2, decay=0.999):\n    par1 = dict(model1.named_parameters())\n    par2 = dict(model2.named_parameters())\n\n    for k in par1.keys():\n        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)\n\n################### Latent code (Z) sampling util functions ####################\n# Sample Z space latent codes for the generator\ndef make_noise(batch, latent_dim, n_noise, device):\n    if n_noise == 1:\n        return torch.randn(batch, latent_dim, device=device)\n\n    noises = torch.randn(n_noise, batch, latent_dim, device=device).unbind(0)\n\n    return noises\n\n\ndef mixing_noise(batch, latent_dim, prob, device):\n    if prob > 0 and random.random() < prob:\n        return make_noise(batch, latent_dim, 2, device)\n    else:\n        return [make_noise(batch, latent_dim, 1, device)]\n\n################# Camera parameters sampling ####################\ndef generate_camera_params(resolution, device, batch=1, locations=None, sweep=False,\n                           uniform=False, azim_range=0.3, elev_range=0.15,\n                           fov_ang=6, dist_radius=0.12):\n    if locations != None:\n        azim = locations[:,0].view(-1,1)\n        elev = locations[:,1].view(-1,1)\n\n        # generate intrinsic parameters\n        # fix distance to 1\n        dist = torch.ones(azim.shape[0], 1, device=device)\n        near, far = (dist - dist_radius).unsqueeze(-1), (dist + dist_radius).unsqueeze(-1)\n        fov_angle = fov_ang * torch.ones(azim.shape[0], 1, device=device).view(-1,1) * np.pi / 180\n        focal = 0.5 * resolution / torch.tan(fov_angle).unsqueeze(-1)\n    elif sweep:\n        # generate camera locations on the unit sphere\n        azim = (-azim_range + (2 * azim_range / 7) * torch.arange(8, device=device)).view(-1,1).repeat(batch,1)\n        elev = (-elev_range + 2 * elev_range * torch.rand(batch, 1, device=device).repeat(1,8).view(-1,1))\n\n        # generate intrinsic parameters\n        dist = (torch.ones(batch, 1, device=device)).repeat(1,8).view(-1,1)\n        near, far = (dist - dist_radius).unsqueeze(-1), (dist + dist_radius).unsqueeze(-1)\n        fov_angle = fov_ang * torch.ones(batch, 1, 
device=device).repeat(1,8).view(-1,1) * np.pi / 180\n        focal = 0.5 * resolution / torch.tan(fov_angle).unsqueeze(-1)\n    else:\n        # sample camera locations on the unit sphere\n        if uniform:\n            azim = (-azim_range + 2 * azim_range * torch.rand(batch, 1, device=device))\n            elev = (-elev_range + 2 * elev_range * torch.rand(batch, 1, device=device))\n        else:\n            azim = (azim_range * torch.randn(batch, 1, device=device))\n            elev = (elev_range * torch.randn(batch, 1, device=device))\n\n        # generate intrinsic parameters\n        dist = torch.ones(batch, 1, device=device) # restrict camera position to be on the unit sphere\n        near, far = (dist - dist_radius).unsqueeze(-1), (dist + dist_radius).unsqueeze(-1)\n        fov_angle = fov_ang * torch.ones(batch, 1, device=device) * np.pi / 180 # full fov is 12 degrees\n        focal = 0.5 * resolution / torch.tan(fov_angle).unsqueeze(-1)\n\n    viewpoint = torch.cat([azim, elev], 1)\n\n    #### Generate camera extrinsic matrix ##########\n\n    # convert angles to xyz coordinates\n    x = torch.cos(elev) * torch.sin(azim)\n    y = torch.sin(elev)\n    z = torch.cos(elev) * torch.cos(azim)\n    camera_dir = torch.stack([x, y, z], dim=1).view(-1,3)\n    camera_loc = dist * camera_dir\n\n    # get rotation matrices (assume object is at the world coordinates origin)\n    up = torch.tensor([[0,1,0]]).float().to(device) * torch.ones_like(dist)\n    z_axis = F.normalize(camera_dir, eps=1e-5) # the -z direction points into the screen\n    x_axis = F.normalize(torch.cross(up, z_axis, dim=1), eps=1e-5)\n    y_axis = F.normalize(torch.cross(z_axis, x_axis, dim=1), eps=1e-5)\n    is_close = torch.isclose(x_axis, torch.tensor(0.0), atol=5e-3).all(dim=1, keepdim=True)\n    if is_close.any():\n        replacement = F.normalize(torch.cross(y_axis, z_axis, dim=1), eps=1e-5)\n        x_axis = torch.where(is_close, replacement, x_axis)\n    R = torch.cat((x_axis[:, None, :], y_axis[:, None, :], z_axis[:, None, :]), dim=1)\n    T = camera_loc[:, :, None]\n    extrinsics = torch.cat((R.transpose(1,2),T), -1)\n\n    return extrinsics, focal, near, far, viewpoint\n\n#################### Mesh generation util functions ########################\n# Reshape sampling volume to camera frostum\ndef align_volume(volume, near=0.88, far=1.12):\n    b, h, w, d, c = volume.shape\n    yy, xx, zz = torch.meshgrid(torch.linspace(-1, 1, h),\n                                torch.linspace(-1, 1, w),\n                                torch.linspace(-1, 1, d))\n\n    grid = torch.stack([xx, yy, zz], -1).to(volume.device)\n\n    frostum_adjustment_coeffs = torch.linspace(far / near, 1, d).view(1,1,1,-1,1).to(volume.device)\n    frostum_grid = grid.unsqueeze(0)\n    frostum_grid[...,:2] = frostum_grid[...,:2] * frostum_adjustment_coeffs\n    out_of_boundary = torch.any((frostum_grid.lt(-1).logical_or(frostum_grid.gt(1))), -1, keepdim=True)\n    frostum_grid = frostum_grid.permute(0,3,1,2,4).contiguous()\n    permuted_volume = volume.permute(0,4,3,1,2).contiguous()\n    final_volume = F.grid_sample(permuted_volume, frostum_grid, padding_mode=\"border\", align_corners=True)\n    final_volume = final_volume.permute(0,3,4,2,1).contiguous()\n    # set a non-zero value to grid locations outside of the frostum to avoid marching cubes distortions.\n    # It happens because pytorch grid_sample uses zeros padding.\n    final_volume[out_of_boundary] = 1\n\n    return final_volume\n\n# Extract mesh with marching cubes\ndef 
extract_mesh_with_marching_cubes(sdf):\n    b, h, w, d, _ = sdf.shape\n\n    # change coordinate order from (y,x,z) to (x,y,z)\n    sdf_vol = sdf[0,...,0].permute(1,0,2).cpu().numpy()\n\n    # scale vertices\n    verts, faces, _, _ = marching_cubes(sdf_vol, 0)\n    verts[:,0] = (verts[:,0]/float(w)-0.5)*0.24\n    verts[:,1] = (verts[:,1]/float(h)-0.5)*0.24\n    verts[:,2] = (verts[:,2]/float(d)-0.5)*0.24\n\n    # fix normal direction\n    verts[:,2] *= -1; verts[:,1] *= -1\n    mesh = trimesh.Trimesh(verts, faces)\n\n    return mesh\n\n# Generate mesh from xyz point cloud\ndef xyz2mesh(xyz):\n    b, _, h, w = xyz.shape\n    x, y = np.meshgrid(np.arange(h), np.arange(w))\n\n    # Extract mesh faces from xyz maps\n    tri = Delaunay(np.concatenate((x.reshape((h*w, 1)), y.reshape((h*w, 1))), 1))\n    faces = tri.simplices\n\n    # invert normals\n    faces[:,[0, 1]] = faces[:,[1, 0]]\n\n    # generate_meshes\n    mesh = trimesh.Trimesh(xyz.squeeze(0).permute(1,2,0).view(h*w,3).cpu().numpy(), faces)\n\n    return mesh\n\n\n################# Mesh rendering util functions #############################\ndef add_textures(meshes:Meshes, vertex_colors=None) -> Meshes:\n    verts = meshes.verts_padded()\n    if vertex_colors is None:\n        vertex_colors = torch.ones_like(verts) # (N, V, 3)\n    textures = TexturesVertex(verts_features=vertex_colors)\n    meshes_t = Meshes(\n        verts=verts,\n        faces=meshes.faces_padded(),\n        textures=textures,\n        verts_normals=meshes.verts_normals_padded(),\n    )\n    return meshes_t\n\n\ndef create_cameras(\n    R=None, T=None,\n    azim=0, elev=0., dist=1.,\n    fov=12., znear=0.01,\n    device=\"cuda\") -> FoVPerspectiveCameras:\n    \"\"\"\n    all the camera parameters can be a single number, a list, or a torch tensor.\n    \"\"\"\n    if R is None or T is None:\n        R, T = look_at_view_transform(dist=dist, azim=azim, elev=elev, device=device)\n    cameras = FoVPerspectiveCameras(device=device, R=R, T=T, znear=znear, fov=fov)\n    return cameras\n\n\ndef create_mesh_renderer(\n        cameras: FoVPerspectiveCameras,\n        image_size: int = 256,\n        blur_radius: float = 1e-6,\n        light_location=((-0.5, 1., 5.0),),\n        device=\"cuda\",\n        **light_kwargs,\n):\n    \"\"\"\n    If don't want to show direct texture color without shading, set the light_kwargs as\n    ambient_color=((1, 1, 1), ), diffuse_color=((0, 0, 0), ), specular_color=((0, 0, 0), )\n    \"\"\"\n    # We will also create a Phong renderer. 
This is simpler and only needs to render one face per pixel.\n    raster_settings = RasterizationSettings(\n        image_size=image_size,\n        blur_radius=blur_radius,\n        faces_per_pixel=5,\n    )\n    # We can add a point light in front of the object.\n    lights = PointLights(\n        device=device, location=light_location, **light_kwargs\n    )\n    phong_renderer = MeshRenderer(\n        rasterizer=MeshRasterizer(\n            cameras=cameras,\n            raster_settings=raster_settings\n        ),\n        shader=SoftPhongShader(device=device, cameras=cameras, lights=lights)\n    )\n\n    return phong_renderer\n\n\n## custom renderer\nclass MeshRendererWithDepth(nn.Module):\n    def __init__(self, rasterizer, shader):\n        super().__init__()\n        self.rasterizer = rasterizer\n        self.shader = shader\n\n    def forward(self, meshes_world, **kwargs) -> torch.Tensor:\n        fragments = self.rasterizer(meshes_world, **kwargs)\n        images = self.shader(fragments, meshes_world, **kwargs)\n        return images, fragments.zbuf\n\n\ndef create_depth_mesh_renderer(\n        cameras: FoVPerspectiveCameras,\n        image_size: int = 256,\n        blur_radius: float = 1e-6,\n        device=\"cuda\",\n        **light_kwargs,\n):\n    \"\"\"\n    If don't want to show direct texture color without shading, set the light_kwargs as\n    ambient_color=((1, 1, 1), ), diffuse_color=((0, 0, 0), ), specular_color=((0, 0, 0), )\n    \"\"\"\n    # We will also create a Phong renderer. This is simpler and only needs to render one face per pixel.\n    raster_settings = RasterizationSettings(\n        image_size=image_size,\n        blur_radius=blur_radius,\n        faces_per_pixel=17,\n    )\n    # We can add a point light in front of the object.\n    lights = PointLights(\n        device=device, location=((-0.5, 1., 5.0),), **light_kwargs\n    )\n    renderer = MeshRendererWithDepth(\n        rasterizer=MeshRasterizer(\n            cameras=cameras,\n            raster_settings=raster_settings,\n            device=device,\n        ),\n        shader=SoftPhongShader(device=device, cameras=cameras, lights=lights)\n    )\n\n    return renderer\n"
  },
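  {
    "path": "examples/render_extracted_mesh_sketch.py",
    "content": "'''\nIllustrative usage sketch (not part of the original release): renders a mesh,\ne.g. one produced by extract_mesh_with_marching_cubes or xyz2mesh in utils.py,\nwith the PyTorch3D helpers defined there. The input mesh path is a placeholder.\n'''\nimport torch\nimport trimesh\nfrom pytorch3d.structures import Meshes\n\nfrom utils import add_textures, create_cameras, create_mesh_renderer\n\ndevice = 'cuda'\n\n# Placeholder path, e.g. a mesh previously exported with trimesh.\nmesh = trimesh.load('extracted_mesh.obj', process=False)\nverts = torch.tensor(mesh.vertices, dtype=torch.float32, device=device)\nfaces = torch.tensor(mesh.faces, dtype=torch.int64, device=device)\np3d_mesh = add_textures(Meshes(verts=[verts], faces=[faces]))\n\n# The generator cameras sit on the unit sphere with a ~12 degree fov,\n# so similar values are used for the preview camera here.\ncameras = create_cameras(azim=20, elev=10, dist=1.0, fov=12, device=device)\nrenderer = create_mesh_renderer(cameras, image_size=256, device=device)\n\nwith torch.no_grad():\n    image = renderer(p3d_mesh)  # (1, image_size, image_size, 4) RGBA tensor\n"
  },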
  {
    "path": "volume_renderer.py",
    "content": "import math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.autograd as autograd\nimport numpy as np\nfrom functools import partial\nfrom pdb import set_trace as st\n\n\n# Basic SIREN fully connected layer\nclass LinearLayer(nn.Module):\n    def __init__(self, in_dim, out_dim, bias=True, bias_init=0, std_init=1, freq_init=False, is_first=False):\n        super().__init__()\n        if is_first:\n            self.weight = nn.Parameter(torch.empty(out_dim, in_dim).uniform_(-1 / in_dim, 1 / in_dim))\n        elif freq_init:\n            self.weight = nn.Parameter(torch.empty(out_dim, in_dim).uniform_(-np.sqrt(6 / in_dim) / 25, np.sqrt(6 / in_dim) / 25))\n        else:\n            self.weight = nn.Parameter(0.25 * nn.init.kaiming_normal_(torch.randn(out_dim, in_dim), a=0.2, mode='fan_in', nonlinearity='leaky_relu'))\n\n        self.bias = nn.Parameter(nn.init.uniform_(torch.empty(out_dim), a=-np.sqrt(1/in_dim), b=np.sqrt(1/in_dim)))\n\n        self.bias_init = bias_init\n        self.std_init = std_init\n\n    def forward(self, input):\n        out = self.std_init * F.linear(input, self.weight, bias=self.bias) + self.bias_init\n\n        return out\n\n# Siren layer with frequency modulation and offset\nclass FiLMSiren(nn.Module):\n    def __init__(self, in_channel, out_channel, style_dim, is_first=False):\n        super().__init__()\n        self.in_channel = in_channel\n        self.out_channel = out_channel\n\n        if is_first:\n            self.weight = nn.Parameter(torch.empty(out_channel, in_channel).uniform_(-1 / 3, 1 / 3))\n        else:\n            self.weight = nn.Parameter(torch.empty(out_channel, in_channel).uniform_(-np.sqrt(6 / in_channel) / 25, np.sqrt(6 / in_channel) / 25))\n\n        self.bias = nn.Parameter(nn.Parameter(nn.init.uniform_(torch.empty(out_channel), a=-np.sqrt(1/in_channel), b=np.sqrt(1/in_channel))))\n        self.activation = torch.sin\n\n        self.gamma = LinearLayer(style_dim, out_channel, bias_init=30, std_init=15)\n        self.beta = LinearLayer(style_dim, out_channel, bias_init=0, std_init=0.25)\n\n    def forward(self, input, style):\n        batch, features = style.shape\n        out = F.linear(input, self.weight, bias=self.bias)\n        gamma = self.gamma(style).view(batch, 1, 1, 1, features)\n        beta = self.beta(style).view(batch, 1, 1, 1, features)\n\n        out = self.activation(gamma * out + beta)\n\n        return out\n\n\n# Siren Generator Model\nclass SirenGenerator(nn.Module):\n    def __init__(self, D=8, W=256, style_dim=256, input_ch=3, input_ch_views=3, output_ch=4,\n                 output_features=True):\n        super(SirenGenerator, self).__init__()\n        self.D = D\n        self.W = W\n        self.input_ch = input_ch\n        self.input_ch_views = input_ch_views\n        self.style_dim = style_dim\n        self.output_features = output_features\n\n        self.pts_linears = nn.ModuleList(\n            [FiLMSiren(3, W, style_dim=style_dim, is_first=True)] + \\\n            [FiLMSiren(W, W, style_dim=style_dim) for i in range(D-1)])\n\n        self.views_linears = FiLMSiren(input_ch_views + W, W,\n                                       style_dim=style_dim)\n        self.rgb_linear = LinearLayer(W, 3, freq_init=True)\n        self.sigma_linear = LinearLayer(W, 1, freq_init=True)\n\n    def forward(self, x, styles):\n        input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)\n        mlp_out = input_pts.contiguous()\n        for i in 
range(len(self.pts_linears)):\n            mlp_out = self.pts_linears[i](mlp_out, styles)\n\n\n        sdf = self.sigma_linear(mlp_out)\n\n        mlp_out = torch.cat([mlp_out, input_views], -1)\n        out_features = self.views_linears(mlp_out, styles)\n        rgb = self.rgb_linear(out_features)\n\n        outputs = torch.cat([rgb, sdf], -1)\n        if self.output_features:\n            outputs = torch.cat([outputs, out_features], -1)\n\n        return outputs\n\n\n# Full volume renderer\nclass VolumeFeatureRenderer(nn.Module):\n    def __init__(self, opt, style_dim=256, out_im_res=64, mode='train'):\n        super().__init__()\n        self.test = mode != 'train'\n        self.perturb = opt.perturb\n        self.offset_sampling = not opt.no_offset_sampling # Stratified sampling used otherwise\n        self.N_samples = opt.N_samples\n        self.raw_noise_std = opt.raw_noise_std\n        self.return_xyz = opt.return_xyz\n        self.return_sdf = opt.return_sdf\n        self.static_viewdirs = opt.static_viewdirs\n        self.z_normalize = not opt.no_z_normalize\n        self.out_im_res = out_im_res\n        self.force_background = opt.force_background\n        self.with_sdf = not opt.no_sdf\n        if 'no_features_output' in opt.keys():\n            self.output_features = False\n        else:\n            self.output_features = True\n\n        if self.with_sdf:\n            self.sigmoid_beta = nn.Parameter(0.1 * torch.ones(1))\n\n        # create meshgrid to generate rays\n        i, j = torch.meshgrid(torch.linspace(0.5, self.out_im_res - 0.5, self.out_im_res),\n                              torch.linspace(0.5, self.out_im_res - 0.5, self.out_im_res))\n\n        self.register_buffer('i', i.t().unsqueeze(0), persistent=False)\n        self.register_buffer('j', j.t().unsqueeze(0), persistent=False)\n\n        # create integration values\n        if self.offset_sampling:\n            t_vals = torch.linspace(0., 1.-1/self.N_samples, steps=self.N_samples).view(1,1,1,-1)\n        else: # Original NeRF Stratified sampling\n            t_vals = torch.linspace(0., 1., steps=self.N_samples).view(1,1,1,-1)\n\n        self.register_buffer('t_vals', t_vals, persistent=False)\n        self.register_buffer('inf', torch.Tensor([1e10]), persistent=False)\n        self.register_buffer('zero_idx', torch.LongTensor([0]), persistent=False)\n\n        if self.test:\n            self.perturb = False\n            self.raw_noise_std = 0.\n\n        self.channel_dim = -1\n        self.samples_dim = 3\n        self.input_ch = 3\n        self.input_ch_views = 3\n        self.feature_out_size = opt.width\n\n        # set Siren Generator model\n        self.network = SirenGenerator(D=opt.depth, W=opt.width, style_dim=style_dim, input_ch=self.input_ch,\n                                      output_ch=4, input_ch_views=self.input_ch_views,\n                                      output_features=self.output_features)\n\n    def get_rays(self, focal, c2w):\n        dirs = torch.stack([(self.i - self.out_im_res * .5) / focal,\n                            -(self.j - self.out_im_res * .5) / focal,\n                            -torch.ones_like(self.i).expand(focal.shape[0],self.out_im_res, self.out_im_res)], -1)\n\n        # Rotate ray directions from camera frame to the world frame\n        rays_d = torch.sum(dirs[..., None, :] * c2w[:,None,None,:3,:3], -1)  # dot product, equals to: [c2w.dot(dir) for dir in dirs]\n\n        # Translate camera frame's origin to the world frame. 
It is the origin of all rays.\n        rays_o = c2w[:,None,None,:3,-1].expand(rays_d.shape)\n        if self.static_viewdirs:\n            viewdirs = dirs\n        else:\n            viewdirs = rays_d\n\n        return rays_o, rays_d, viewdirs\n\n    def get_eikonal_term(self, pts, sdf):\n        eikonal_term = autograd.grad(outputs=sdf, inputs=pts,\n                                     grad_outputs=torch.ones_like(sdf),\n                                     create_graph=True)[0]\n\n        return eikonal_term\n\n    def sdf_activation(self, input):\n        sigma = torch.sigmoid(input / self.sigmoid_beta) / self.sigmoid_beta\n\n        return sigma\n\n    def volume_integration(self, raw, z_vals, rays_d, pts, return_eikonal=False):\n        dists = z_vals[...,1:] - z_vals[...,:-1]\n        rays_d_norm = torch.norm(rays_d.unsqueeze(self.samples_dim), dim=self.channel_dim)\n        # dists still has 4 dimensions here instead of 5, hence, in this case samples dim is actually the channel dim\n        dists = torch.cat([dists, self.inf.expand(rays_d_norm.shape)], self.channel_dim)  # [N_rays, N_samples]\n        dists = dists * rays_d_norm\n\n        # If sdf modeling is off, the sdf variable stores the\n        # pre-integration raw sigma MLP outputs.\n        if self.output_features:\n            rgb, sdf, features = torch.split(raw, [3, 1, self.feature_out_size], dim=self.channel_dim)\n        else:\n            rgb, sdf = torch.split(raw, [3, 1], dim=self.channel_dim)\n\n        noise = 0.\n        if self.raw_noise_std > 0.:\n            noise = torch.randn_like(sdf) * self.raw_noise_std\n\n        if self.with_sdf:\n            sigma = self.sdf_activation(-sdf)\n\n            if return_eikonal:\n                eikonal_term = self.get_eikonal_term(pts, sdf)\n            else:\n                eikonal_term = None\n\n            sigma = 1 - torch.exp(-sigma * dists.unsqueeze(self.channel_dim))\n        else:\n            sigma = sdf\n            eikonal_term = None\n\n            sigma = 1 - torch.exp(-F.softplus(sigma + noise) * dists.unsqueeze(self.channel_dim))\n\n        visibility = torch.cumprod(torch.cat([torch.ones_like(torch.index_select(sigma, self.samples_dim, self.zero_idx)),\n                                              1.-sigma + 1e-10], self.samples_dim), self.samples_dim)\n        visibility = visibility[...,:-1,:]\n        weights = sigma * visibility\n\n\n        if self.return_sdf:\n            sdf_out = sdf\n        else:\n            sdf_out = None\n\n        if self.force_background:\n            weights[...,-1,:] = 1 - weights[...,:-1,:].sum(self.samples_dim)\n\n        rgb_map = -1 + 2 * torch.sum(weights * torch.sigmoid(rgb), self.samples_dim)  # switch to [-1,1] value range\n\n        if self.output_features:\n            feature_map = torch.sum(weights * features, self.samples_dim)\n        else:\n            feature_map = None\n\n        # Return surface point cloud in world coordinates.\n        # This is used to generate the depth maps visualizations.\n        # We use world coordinates to avoid transformation errors between\n        # surface renderings from different viewpoints.\n        if self.return_xyz:\n            xyz = torch.sum(weights * pts, self.samples_dim)\n            mask = weights[...,-1,:] # background probability map\n        else:\n            xyz = None\n            mask = None\n\n        return rgb_map, feature_map, sdf_out, mask, xyz, eikonal_term\n\n    def run_network(self, inputs, viewdirs, styles=None):\n        input_dirs = 
viewdirs.unsqueeze(self.samples_dim).expand(inputs.shape)\n        net_inputs = torch.cat([inputs, input_dirs], self.channel_dim)\n        outputs = self.network(net_inputs, styles=styles)\n\n        return outputs\n\n    def render_rays(self, ray_batch, styles=None, return_eikonal=False):\n        batch, h, w, _ = ray_batch.shape\n        split_pattern = [3, 3, 2]\n        if ray_batch.shape[-1] > 8:\n            split_pattern += [3]\n            rays_o, rays_d, bounds, viewdirs = torch.split(ray_batch, split_pattern, dim=self.channel_dim)\n        else:\n            rays_o, rays_d, bounds = torch.split(ray_batch, split_pattern, dim=self.channel_dim)\n            viewdirs = None\n\n        near, far = torch.split(bounds, [1, 1], dim=self.channel_dim)\n        z_vals = near * (1.-self.t_vals) + far * (self.t_vals)\n\n        if self.perturb > 0.:\n            if self.offset_sampling:\n                # random offset samples\n                upper = torch.cat([z_vals[...,1:], far], -1)\n                lower = z_vals.detach()\n                t_rand = torch.rand(batch, h, w).unsqueeze(self.channel_dim).to(z_vals.device)\n            else:\n                # get intervals between samples\n                mids = .5 * (z_vals[...,1:] + z_vals[...,:-1])\n                upper = torch.cat([mids, z_vals[...,-1:]], -1)\n                lower = torch.cat([z_vals[...,:1], mids], -1)\n                # stratified samples in those intervals\n                t_rand = torch.rand(z_vals.shape).to(z_vals.device)\n\n            z_vals = lower + (upper - lower) * t_rand\n\n        pts = rays_o.unsqueeze(self.samples_dim) + rays_d.unsqueeze(self.samples_dim) * z_vals.unsqueeze(self.channel_dim)\n\n        if return_eikonal:\n            pts.requires_grad = True\n\n        if self.z_normalize:\n            normalized_pts = pts * 2 / ((far - near).unsqueeze(self.samples_dim))\n        else:\n            normalized_pts = pts\n\n        raw = self.run_network(normalized_pts, viewdirs, styles=styles)\n        rgb_map, features, sdf, mask, xyz, eikonal_term = self.volume_integration(raw, z_vals, rays_d, pts, return_eikonal=return_eikonal)\n\n        return rgb_map, features, sdf, mask, xyz, eikonal_term\n\n    def render(self, focal, c2w, near, far, styles, c2w_staticcam=None, return_eikonal=False):\n        rays_o, rays_d, viewdirs = self.get_rays(focal, c2w)\n        viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)\n\n        # Create ray batch\n        near = near.unsqueeze(-1) * torch.ones_like(rays_d[...,:1])\n        far = far.unsqueeze(-1) * torch.ones_like(rays_d[...,:1])\n        rays = torch.cat([rays_o, rays_d, near, far], -1)\n        rays = torch.cat([rays, viewdirs], -1)\n        rays = rays.float()\n        rgb, features, sdf, mask, xyz, eikonal_term = self.render_rays(rays, styles=styles, return_eikonal=return_eikonal)\n\n        return rgb, features, sdf, mask, xyz, eikonal_term\n\n    def mlp_init_pass(self, cam_poses, focal, near, far, styles=None):\n        rays_o, rays_d, viewdirs = self.get_rays(focal, cam_poses)\n        viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)\n\n        near = near.unsqueeze(-1) * torch.ones_like(rays_d[...,:1])\n        far = far.unsqueeze(-1) * torch.ones_like(rays_d[...,:1])\n        z_vals = near * (1.-self.t_vals) + far * (self.t_vals)\n\n        # get intervals between samples\n        mids = .5 * (z_vals[...,1:] + z_vals[...,:-1])\n        upper = torch.cat([mids, z_vals[...,-1:]], -1)\n        lower = 
torch.cat([z_vals[...,:1], mids], -1)\n        # stratified samples in those intervals\n        t_rand = torch.rand(z_vals.shape).to(z_vals.device)\n\n        z_vals = lower + (upper - lower) * t_rand\n        pts = rays_o.unsqueeze(self.samples_dim) + rays_d.unsqueeze(self.samples_dim) * z_vals.unsqueeze(self.channel_dim)\n        if self.z_normalize:\n            normalized_pts = pts * 2 / ((far - near).unsqueeze(self.samples_dim))\n        else:\n            normalized_pts = pts\n\n        raw = self.run_network(normalized_pts, viewdirs, styles=styles)\n        _, sdf = torch.split(raw, [3, 1], dim=self.channel_dim)\n        sdf = sdf.squeeze(self.channel_dim)\n        target_values = pts.detach().norm(dim=-1) - ((far - near) / 4)\n\n        return sdf, target_values\n\n    def forward(self, cam_poses, focal, near, far, styles=None, return_eikonal=False):\n        rgb, features, sdf, mask, xyz, eikonal_term = self.render(focal, c2w=cam_poses, near=near, far=far, styles=styles, return_eikonal=return_eikonal)\n\n        rgb = rgb.permute(0,3,1,2).contiguous()\n        if self.output_features:\n            features = features.permute(0,3,1,2).contiguous()\n\n        if xyz != None:\n            xyz = xyz.permute(0,3,1,2).contiguous()\n            mask = mask.permute(0,3,1,2).contiguous()\n\n        return rgb, features, sdf, mask, xyz, eikonal_term\n"
  }
]