Full Code of royorel/StyleSDF for AI

main 1c995101b80b cached
29 files
202.7 KB
51.7k tokens
171 symbols
1 requests
Download .txt
Showing preview only (212K chars total). Download the full file or copy to clipboard to get everything.
Repository: royorel/StyleSDF
Branch: main
Commit: 1c995101b80b
Files: 29
Total size: 202.7 KB

Directory structure:
gitextract__fgegk92/

├── .gitignore
├── LICENSE
├── README.md
├── StyleSDF_demo.ipynb
├── dataset.py
├── distributed.py
├── download_models.py
├── generate_shapes_and_images.py
├── losses.py
├── model.py
├── op/
│   ├── __init__.py
│   ├── fused_act.py
│   ├── fused_bias_act.cpp
│   ├── fused_bias_act_kernel.cu
│   ├── upfirdn2d.cpp
│   ├── upfirdn2d.py
│   └── upfirdn2d_kernel.cu
├── options.py
├── prepare_data.py
├── render_video.py
├── requirements.txt
├── scripts/
│   ├── train_afhq_full_pipeline_512x512.sh
│   ├── train_afhq_vol_renderer.sh
│   ├── train_ffhq_full_pipeline_1024x1024.sh
│   └── train_ffhq_vol_renderer.sh
├── train_full_pipeline.py
├── train_volume_renderer.py
├── utils.py
└── volume_renderer.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
__pycache__/
checkpoint/
datasets/
evaluations/
full_models/
pretrained_renderer/
wandb/


================================================
FILE: LICENSE
================================================
Copyright (C) 2022 Roy Or-El, Xuan Luo, Mengyi Shan, Eli Shechtman, Jeong Joon Park, Ira Kemelmacher-Shlizerman.
All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).

The software is made available under Creative Commons BY-NC-SA 4.0 license
by University of Washington. You can use, redistribute, and adapt it
for non-commercial purposes, as long as you (a) give appropriate credit
by citing our paper, (b) indicate any changes that you've made,
and (c) distribute any derivative works under the same license.

THE AUTHORS DISCLAIM ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE.
IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL
DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING
OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.


================================================
FILE: README.md
================================================
# StyleSDF: High-Resolution 3D-Consistent Image and Geometry Generation
### [Project Page](https://stylesdf.github.io/) | [Paper](https://arxiv.org/pdf/2112.11427.pdf) | [HuggingFace Demo](https://huggingface.co/spaces/SerdarHelli/StyleSDF-3D)

[![Explore in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/royorel/StyleSDF/blob/main/StyleSDF_demo.ipynb)<br>

[Roy Or-El](https://homes.cs.washington.edu/~royorel/)<sup>1</sup> ,
[Xuan Luo](https://roxanneluo.github.io/)<sup>1</sup>,
[Mengyi Shan](https://shanmy.github.io/)<sup>1</sup>,
[Eli Shechtman](https://research.adobe.com/person/eli-shechtman/)<sup>2</sup>,
[Jeong Joon Park](https://jjparkcv.github.io/)<sup>3</sup>,
[Ira Kemelmacher-Shlizerman](https://www.irakemelmacher.com/)<sup>1</sup><br>
<sup>1</sup>University of Washington, <sup>2</sup>Adobe Research, <sup>3</sup>Stanford University

<div align="center">
<img src=./assets/teaser.png>
</div>

## Updates
12/26/2022: A new HuggingFace demo is now available. Special thanks to [@SerdarHelli](https://github.com/SerdarHelli) for the implementation.<br>
3/27/2022: Fixed a bug in the sphere initialization code (init_forward function was missing, see commit [0bd8741](https://github.com/royorel/StyleSDF/commit/0bd8741f26048d26160a495a9056b5f1da1a60a1)).<br>
3/22/2022: **Added training files**.<br>
3/9/2022: Fixed a bug in the calculation of the mean w vector (see commit [d4dd17d](https://github.com/royorel/StyleSDF/commit/d4dd17de09fd58adefc7ed49487476af6018894f)).<br>
3/4/2022: Testing code and Colab demo were released. **Training files will be released soon.**

## Overview
StyleSDF is a 3D-aware GAN, aimed at solving two main challenges:
1. High-resolution, view-consistent generation of the RGB images.
2. Generating detailed 3D shapes.

StyleSDF is trained only on single-view RGB data. The 3D geometry is learned implicitly with an SDF-based volume renderer.<br>

This code is the official PyTorch implementation of the paper:
> **StyleSDF: High-Resolution 3D-Consistent Image and Geometry Generation**<br>
> Roy Or-El, Xuan Luo, Mengyi Shan, Eli Shechtman, Jeong Joon Park, Ira Kemelmacher-Shlizerman<br>
> CVPR 2022<br>
> https://arxiv.org/pdf/2112.11427.pdf

## Abstract
We introduce a high resolution, 3D-consistent image and
shape generation technique which we call StyleSDF. Our
method is trained on single-view RGB data only, and stands
on the shoulders of StyleGAN2 for image generation, while
solving two main challenges in 3D-aware GANs: 1) high-resolution,
view-consistent generation of the RGB images,
and 2) detailed 3D shape. We achieve this by merging a
SDF-based 3D representation with a style-based 2D generator.
Our 3D implicit network renders low-resolution feature
maps, from which the style-based network generates
view-consistent, 1024×1024 images. Notably, our SDFbased
3D modeling defines detailed 3D surfaces, leading
to consistent volume rendering. Our method shows higher
quality results compared to state of the art in terms of visual
and geometric quality.

## Pre-Requisits
You must have a **GPU with CUDA support** in order to run the code.

This code requires **PyTorch**, **PyTorch3D** and **torchvision** to be installed, please go to [PyTorch.org](https://pytorch.org/) and [PyTorch3d.org](https://pytorch3d.org/) for installation info.<br>
We tested our code on Python 3.8.5, PyTorch 1.9.0, PyTorch3D 0.6.1 and torchvision 0.10.0.

The following packages should also be installed:
1. lmdb
2. numpy
3. ninja
4. pillow
5. requests
6. tqdm
7. scipy
8. skimage
9. skvideo
10. trimesh[easy]
11. configargparse
12. munch
13. wandb (optional)

If any of these packages are not installed on your computer, you can install them using the supplied `requirements.txt` file:<br>
```pip install -r requirements.txt```

## Download Pre-trained Models
The pre-trained models can be downloaded by running `python download_models.py`.

## Quick Demo
You can explore our method in Google Colab [![Explore in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/royorel/StyleSDF/blob/main/StyleSDF_demo.ipynb).

Alternatively, you can download the pretrained models by running:<br>
`python download_models.py`

To generate human faces from the model pre-trained on FFHQ, run:<br>
`python generate_shapes_and_images.py --expname ffhq1024x1024 --size 1024 --identities NUMBER_OF_FACES`

To generate animal faces from the model pre-trained on AFHQ, run:<br>
`python generate_shapes_and_images.py --expname afhq512x512 --size 512 --identities NUMBER_OF_FACES`

## Generating images and meshes
To generate images and meshes from a trained model, run:
`python generate_shapes_and_images.py --expname NAME_OF_TRAINED_MODEL --size MODEL_OUTPUT_SIZE --identities NUMBER_OF_FACES`

The script will generate an RGB image, a mesh generated from depth map, and the mesh extracted with Marching cubes.

### Optional flags for image and mesh generation
```
  --no_surface_renderings          When true, only RGB outputs will be generated. otherwise, both RGB and depth videos/renderings will be generated. (default: false)
  --fixed_camera_angles            When true, the generator will render indentities from a fixed set of camera angles. (default: false)
```

## Generating Videos
To generate videos from a trained model, run: <br>
`python render_video.py --expname NAME_OF_TRAINED_MODEL --size MODEL_OUTPUT_SIZE --identities NUMBER_OF_FACES`.

This script will generate RGB video as well as depth map video for each identity. The average processing time per video is ~5-10 minutes on an RTX2080 Ti GPU.

### Optional flags for video rendering
```
  --no_surface_videos          When true, only RGB video will be generated when running render_video.py. otherwise, both RGB and depth videos will be generated. this cuts the processing time per video. (default: false)
  --azim_video                 When true, the camera trajectory will travel along the azimuth direction. Otherwise, the camera will travel along an ellipsoid trajectory. (default: ellipsoid)
  --project_noise              When true, use geometry-aware noise projection to reduce flickering effects (see supplementary section C.1 in the paper). Warning: processing time significantly increases with this flag to ~20 minutes per video. (default: false)
```

## Training
### Preparing your Dataset
If you wish to train a model from scratch, first you need to convert your dataset to an lmdb format. Run:<br>
`python prepare_data.py --out_path OUTPUT_LMDB_PATH --n_worker N_WORKER --size SIZE1,SIZE2,SIZE3,... INPUT_DATASET_PATH`

### Training the volume renderer
#### Training scripts
To train the volume renderer on FFHQ run: `bash ./scripts/train_ffhq_vol_renderer.sh`. <br>
To train the volume renderer on AFHQ run: `bash ./scripts/train_afhq_vol_renderer.sh`.

* The scripts above use distributed training. To train the models on a single GPU (not recommended) remove `-m torch.distributed.launch --nproc_per_node NUM_GPUS` from the script.

#### Training on a new dataset
To train the volume renderer on a new dataset, run:<br>
`python train_volume_renderer.py --batch BATCH_SIZE --chunk CHUNK_SIZE --expname EXPERIMENT_NAME --dataset_path DATASET_PATH`

Ideally, `CHUNK_SIZE` should be the same as `BATCH_SIZE`, but on most GPUs it will likely cause an out of memory error. In such case, reduce `CHUNK_SIZE` to perform gradient accumulation.

**Important note**: The best way to monitor the SDF convergence is to look at "Beta value" graph on wandb. Convergence is successful once beta reaches values below approx. 3*10<sup>-3</sup>. If the SDF is not converging, increase the R1 regularization weight. Another helpful option (to a lesser degree) is to decrease the weight of the minimal surface regularization.  

#### Distributed training
If you have multiple GPUs you can train your model on multiple instances by running:<br>
`python -m torch.distributed.launch --nproc_per_node NUM_GPUS train_volume_renderer.py --batch BATCH_SIZE --chunk CHUNK_SIZE --expname EXPERIMENT_NAME --dataset_path DATASET_PATH`

### Training the full pipeline
#### Training scripts
To train the volume renderer on FFHQ run: `bash ./scripts/train_ffhq_full_pipeline_1024x1024.sh`. <br>
To train the volume renderer on AFHQ run: `bash ./scripts/train_afhq_full_pipeline_512x512.sh`.

* The scripts above assume that the volume renderer model was already trained. **Do not run them from scratch.**
* The scripts above use distributed training. To train the models on a single GPU (not recommended) remove `-m torch.distributed.launch --nproc_per_node NUM_GPUS` from the script.

#### Training on a new dataset
To train the full pipeline on a new dataset, **first train the volume renderer separately**. <br>
After the volume renderer training is finished, run:<br>
`python train_full_pipeline.py --batch BATCH_SIZE --chunk CHUNK_SIZE --expname EXPERIMENT_NAME --size OUTPUT_SIZE`

Ideally, `CHUNK_SIZE` should be the same as `BATCH_SIZE`, but on most GPUs it will likely cause an out of memory error. In such case, reduce `CHUNK_SIZE` to perform gradient accumulation.

#### Distributed training
If you have multiple GPUs you can train your model on multiple instances by running:<br>
`python -m torch.distributed.launch --nproc_per_node NUM_GPUS train_full_pipeline.py --batch BATCH_SIZE --chunk CHUNK_SIZE --expname EXPERIMENT_NAME --size OUTPUT_SIZE`

Here, **BATCH_SIZE represents the batch per GPU**, not the overall batch size.

### Optional training flags
```
Training regime options:
  --iter                 Total number of training iterations. (default: 300,000)
  --wandb                Use use weights and biases logging. (default: False)
  --r1                   Weight of the r1 regularization. (default: 10.0)
  --view_lambda          Weight of the viewpoint regularization. (Equation 6, default: 15)
  --eikonal_lambda       Weight of the eikonal regularization. (Equation 7, default: 0.1)
  --min_surf_lambda      Weight of the minimal surface regularization. (Equation 8, default: 0.05)

Camera options:
  --uniform              When true, the camera position is sampled from uniform distribution. (default: gaussian)
  --azim                 Camera azimuth angle std (guassian)/range (uniform) in Radians. (default: 0.3 Rad.)
  --elev                 Camera elevation angle std (guassian)/range (uniform) in Radians. (default: 0.15 Rad.)
  --fov                  Camera field of view half angle in **Degrees**. (default: 6 Deg.)
  --dist_radius          Radius of points sampling distance from the origin. Determines the near and far fields. (default: 0.12)
```

## Citation
If you use this code for your research, please cite our paper.
```
@InProceedings{orel2022stylesdf,
  title={Style{SDF}: {H}igh-{R}esolution {3D}-{C}onsistent {I}mage and {G}eometry {G}eneration},
  author    = {Or-El, Roy and 
               Luo, Xuan and 
               Shan, Mengyi and 
               Shechtman, Eli and 
               Park, Jeong Joon and 
               Kemelmacher-Shlizerman, Ira},
  booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  month     = {June},
  year      = {2022},
  pages     = {13503-13513}
}
```

## Acknowledgments
This code is inspired by rosinality's [StyleGAN2-PyTorch](https://github.com/rosinality/stylegan2-pytorch) and Yen-Chen Lin's [NeRF-pytorch](https://github.com/yenchenlin/nerf-pytorch).

A special thanks to [@SerdarHelli](https://github.com/SerdarHelli) for implementing the HuggingFace demo.


================================================
FILE: StyleSDF_demo.ipynb
================================================
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "StyleSDF_demo.ipynb",
      "private_outputs": true,
      "provenance": [],
      "collapsed_sections": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/royorel/StyleSDF/blob/main/StyleSDF_demo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "#StyleSDF Demo\n",
        "\n",
        "This Colab notebook demonstrates the capabilities of the StyleSDF 3D-aware GAN architecture proposed in our paper.\n",
        "\n",
        "This colab generates images with their correspondinig 3D meshes"
      ],
      "metadata": {
        "id": "b86fxLSo1gqI"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "First, let's download the github repository and install all dependencies."
      ],
      "metadata": {
        "id": "8wSAkifH2Bk5"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!git clone https://github.com/royorel/StyleSDF.git\n",
        "%cd StyleSDF\n",
        "!pip3 install -r requirements.txt"
      ],
      "metadata": {
        "id": "1GTFRig12CuH"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "And install pytorch3D..."
      ],
      "metadata": {
        "id": "GwFJ8oLY4i8d"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install -U fvcore\n",
        "import sys\n",
        "import torch\n",
        "pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
        "version_str=\"\".join([\n",
        "    f\"py3{sys.version_info.minor}_cu\",\n",
        "    torch.version.cuda.replace(\".\",\"\"),\n",
        "    f\"_pyt{pyt_version_str}\"\n",
        "])\n",
        "!pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html"
      ],
      "metadata": {
        "id": "zl3Vpddz3ols"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Now let's download the pretrained models for FFHQ and AFHQ."
      ],
      "metadata": {
        "id": "OgMcslVS5vbC"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!python download_models.py"
      ],
      "metadata": {
        "id": "r1iDkz7r5wnO"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Here, we import libraries and set options.\n",
        "\n",
        "Note: this might take a while (approx. 1-2 minutes) since CUDA kernels need to be compiled."
      ],
      "metadata": {
        "id": "F2JUyGq4JDa8"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "import torch\n",
        "import trimesh\n",
        "import numpy as np\n",
        "from munch import *\n",
        "from options import BaseOptions\n",
        "from model import Generator\n",
        "from generate_shapes_and_images import generate\n",
        "from render_video import render_video\n",
        "\n",
        "\n",
        "torch.random.manual_seed(321)\n",
        "\n",
        "\n",
        "device = \"cuda\"\n",
        "opt = BaseOptions().parse()\n",
        "opt.camera.uniform = True\n",
        "opt.model.is_test = True\n",
        "opt.model.freeze_renderer = False\n",
        "opt.rendering.offset_sampling = True\n",
        "opt.rendering.static_viewdirs = True\n",
        "opt.rendering.force_background = True\n",
        "opt.rendering.perturb = 0\n",
        "opt.inference.renderer_output_size = opt.model.renderer_spatial_output_dim\n",
        "opt.inference.style_dim = opt.model.style_dim\n",
        "opt.inference.project_noise = opt.model.project_noise"
      ],
      "metadata": {
        "id": "Qfamt8J0JGn5"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TLOiDzgKtYyi"
      },
      "source": [
        "Don't worry about this message above, \n",
        "```\n",
        "usage: ipykernel_launcher.py [-h] [--dataset_path DATASET_PATH]\n",
        "                             [--config CONFIG] [--expname EXPNAME]\n",
        "                             [--ckpt CKPT] [--continue_training]\n",
        "                             ...\n",
        "                             ...\n",
        "ipykernel_launcher.py: error: unrecognized arguments: -f /root/.local/share/jupyter/runtime/kernel-c9d47a98-bdba-4a5f-9f0a-e1437c7228b6.json\n",
        "```\n",
        "everything is perfectly fine..."
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Here, we define our model.\n",
        "\n",
        "Set the options below according to your choosing:\n",
        "1. If you plan to try the method for the AFHQ dataset (animal faces), change `model_type` to 'afhq'. Default: `ffhq` (human faces).\n",
        "2. If you wish to turn off depth rendering and marching cubes extraction and generate only RGB images, set `opt.inference.no_surface_renderings = True`. Default: `False`.\n",
        "3. If you wish to generate the image from a specific set of viewpoints, set `opt.inference.fixed_camera_angles = True`. Default: `False`.\n",
        "4. Set the number of identities you wish to create in `opt.inference.identities`. Default: `4`.\n",
        "5. Select the number of views per identity in `opt.inference.num_views_per_id`,<br>\n",
        "   (Only applicable when `opt.inference.fixed_camera_angles` is false). Default: `1`. "
      ],
      "metadata": {
        "id": "40dk1eEJo9MF"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# User options\n",
        "model_type = 'ffhq' # Whether to load the FFHQ or AFHQ model\n",
        "opt.inference.no_surface_renderings = False # When true, only RGB images will be created\n",
        "opt.inference.fixed_camera_angles = False # When true, each identity will be rendered from a specific set of 13 viewpoints. Otherwise, random views are generated\n",
        "opt.inference.identities = 4 # Number of identities to generate\n",
        "opt.inference.num_views_per_id = 1 # Number of viewpoints generated per identity. This option is ignored if opt.inference.fixed_camera_angles is true.\n",
        "\n",
        "# Load saved model\n",
        "if model_type == 'ffhq':\n",
        "    model_path = 'ffhq1024x1024.pt'\n",
        "    opt.model.size = 1024\n",
        "    opt.experiment.expname = 'ffhq1024x1024'\n",
        "else:\n",
        "    opt.inference.camera.azim = 0.15\n",
        "    model_path = 'afhq512x512.pt'\n",
        "    opt.model.size = 512\n",
        "    opt.experiment.expname = 'afhq512x512'\n",
        "\n",
        "# Create results directory\n",
        "result_model_dir = 'final_model'\n",
        "results_dir_basename = os.path.join(opt.inference.results_dir, opt.experiment.expname)\n",
        "opt.inference.results_dst_dir = os.path.join(results_dir_basename, result_model_dir)\n",
        "if opt.inference.fixed_camera_angles:\n",
        "    opt.inference.results_dst_dir = os.path.join(opt.inference.results_dst_dir, 'fixed_angles')\n",
        "else:\n",
        "    opt.inference.results_dst_dir = os.path.join(opt.inference.results_dst_dir, 'random_angles')\n",
        "\n",
        "os.makedirs(opt.inference.results_dst_dir, exist_ok=True)\n",
        "os.makedirs(os.path.join(opt.inference.results_dst_dir, 'images'), exist_ok=True)\n",
        "if not opt.inference.no_surface_renderings:\n",
        "    os.makedirs(os.path.join(opt.inference.results_dst_dir, 'depth_map_meshes'), exist_ok=True)\n",
        "    os.makedirs(os.path.join(opt.inference.results_dst_dir, 'marching_cubes_meshes'), exist_ok=True)\n",
        "\n",
        "opt.inference.camera = opt.camera\n",
        "opt.inference.size = opt.model.size\n",
        "checkpoint_path = os.path.join('full_models', model_path)\n",
        "checkpoint = torch.load(checkpoint_path)\n",
        "\n",
        "# Load image generation model\n",
        "g_ema = Generator(opt.model, opt.rendering).to(device)\n",
        "pretrained_weights_dict = checkpoint[\"g_ema\"]\n",
        "model_dict = g_ema.state_dict()\n",
        "for k, v in pretrained_weights_dict.items():\n",
        "    if v.size() == model_dict[k].size():\n",
        "        model_dict[k] = v\n",
        "\n",
        "g_ema.load_state_dict(model_dict)\n",
        "\n",
        "# Load a second volume renderer that extracts surfaces at 128x128x128 (or higher) for better surface resolution\n",
        "if not opt.inference.no_surface_renderings:\n",
        "    opt['surf_extraction'] = Munch()\n",
        "    opt.surf_extraction.rendering = opt.rendering\n",
        "    opt.surf_extraction.model = opt.model.copy()\n",
        "    opt.surf_extraction.model.renderer_spatial_output_dim = 128\n",
        "    opt.surf_extraction.rendering.N_samples = opt.surf_extraction.model.renderer_spatial_output_dim\n",
        "    opt.surf_extraction.rendering.return_xyz = True\n",
        "    opt.surf_extraction.rendering.return_sdf = True\n",
        "    surface_g_ema = Generator(opt.surf_extraction.model, opt.surf_extraction.rendering, full_pipeline=False).to(device)\n",
        "\n",
        "\n",
        "    # Load weights to surface extractor\n",
        "    surface_extractor_dict = surface_g_ema.state_dict()\n",
        "    for k, v in pretrained_weights_dict.items():\n",
        "        if k in surface_extractor_dict.keys() and v.size() == surface_extractor_dict[k].size():\n",
        "            surface_extractor_dict[k] = v\n",
        "\n",
        "    surface_g_ema.load_state_dict(surface_extractor_dict)\n",
        "else:\n",
        "    surface_g_ema = None\n",
        "\n",
        "# Get the mean latent vector for g_ema\n",
        "if opt.inference.truncation_ratio < 1:\n",
        "    with torch.no_grad():\n",
        "        mean_latent = g_ema.mean_latent(opt.inference.truncation_mean, device)\n",
        "else:\n",
        "    surface_mean_latent = None\n",
        "\n",
        "# Get the mean latent vector for surface_g_ema\n",
        "if not opt.inference.no_surface_renderings:\n",
        "    surface_mean_latent = mean_latent[0]\n",
        "else:\n",
        "    surface_mean_latent = None"
      ],
      "metadata": {
        "id": "CUcWipIlpINT"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Generating images and meshes\n",
        "\n",
        "Finally, we run the network. The results will be saved to `evaluations/[model_name]/final_model/[fixed/random]_angles`, according to the selected setup."
      ],
      "metadata": {
        "id": "N9pqjPDYCwIJ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "generate(opt.inference, g_ema, surface_g_ema, device, mean_latent, surface_mean_latent)"
      ],
      "metadata": {
        "id": "UG4hZgigDfG7"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Now let's examine the results\n",
        "\n",
        "Tip: for better mesh visualization, we recommend dowwnloading the result meshes and view them with Meshlab.\n",
        "\n",
        "Meshes loaction is: `evaluations/[model_name]/final_model/[fixed/random]_angles/[depth_map/marching_cubes]_meshes`."
      ],
      "metadata": {
        "id": "obfjXuDxNuZl"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from PIL import Image\n",
        "from trimesh.viewer.notebook import scene_to_html as mesh2html\n",
        "from IPython.display import HTML as viewer_html\n",
        "\n",
        "# First let's look at the images\n",
        "img_dir = os.path.join(opt.inference.results_dst_dir,'images')\n",
        "im_list = sorted([entry for entry in os.listdir(img_dir) if 'thumb' not in entry])\n",
        "img = Image.new('RGB', (256 * len(im_list), 256))\n",
        "for i, im_file in enumerate(im_list):\n",
        "    im_path = os.path.join(img_dir, im_file)\n",
        "    curr_img = Image.open(im_path).resize((256,256)) # the displayed image is scaled to fit to the screen\n",
        "    img.paste(curr_img, (256 * i, 0))\n",
        "\n",
        "display(img)\n",
        "\n",
        "# And now, we'll move on to display the marching cubes and depth map meshes\n",
        "\n",
        "marching_cubes_meshes_dir = os.path.join(opt.inference.results_dst_dir,'marching_cubes_meshes')\n",
        "marching_cubes_meshes_list = sorted([os.path.join(marching_cubes_meshes_dir, entry) for entry in os.listdir(marching_cubes_meshes_dir) if 'obj' in entry])\n",
        "depth_map_meshes_dir = os.path.join(opt.inference.results_dst_dir,'depth_map_meshes')\n",
        "depth_map_meshes_list = sorted([os.path.join(depth_map_meshes_dir, entry) for entry in os.listdir(depth_map_meshes_dir) if 'obj' in entry])\n",
        "for i, mesh_files in enumerate(zip(marching_cubes_meshes_list, depth_map_meshes_list)):\n",
        "    mc_mesh_file, dm_mesh_file = mesh_files[0], mesh_files[1]\n",
        "    marching_cubes_mesh = trimesh.Scene(trimesh.load_mesh(mc_mesh_file, 'obj'))  \n",
        "    curr_mc_html = mesh2html(marching_cubes_mesh).replace('\"', '&quot;')\n",
        "    display(viewer_html(' '.join(['<iframe srcdoc=\"{srcdoc}\"',\n",
        "                            'width=\"{width}px\" height=\"{height}px\"',\n",
        "                            'style=\"border:none;\"></iframe>']).format(\n",
        "                            srcdoc=curr_mc_html, height=256, width=256)))\n",
        "    depth_map_mesh = trimesh.Scene(trimesh.load_mesh(dm_mesh_file, 'obj'))  \n",
        "    curr_dm_html = mesh2html(depth_map_mesh).replace('\"', '&quot;')\n",
        "    display(viewer_html(' '.join(['<iframe srcdoc=\"{srcdoc}\"',\n",
        "                            'width=\"{width}px\" height=\"{height}px\"',\n",
        "                            'style=\"border:none;\"></iframe>']).format(\n",
        "                            srcdoc=curr_dm_html, height=256, width=256)))"
      ],
      "metadata": {
        "id": "k3jXCVT7N2YZ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Generating videos\n",
        "\n",
        "Additionally, we can also render videos. The results will be saved to `evaluations/[model_name]/final_model/videos`.\n",
        "\n",
        "Set the options below according to your choosing:\n",
        "1. If you wish to generate only RGB videos, set `opt.inference.no_surface_renderings = True`. Default: `False`.\n",
        "2. Set the camera trajectory. To travel along the azimuth direction set `opt.inference.azim_video = True`, to travel in an ellipsoid trajectory set `opt.inference.azim_video = False`. Default: `False`.\n",
        "\n",
        "###Important Note: \n",
        " - Processing time for videos when `opt.inference.no_surface_renderings = False` is very lengthy (~ 15-20 minutes per video). Rendering each depth frame for the depth videos is very slow.<br>\n",
        " - Processing time for videos when `opt.inference.no_surface_renderings = True` is much faster (~ 1-2 minutes per video)"
      ],
      "metadata": {
        "id": "Cxq3c6anN4D0"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Options\n",
        "opt.inference.no_surface_renderings = True # When true, only RGB videos will be created\n",
        "opt.inference.azim_video = True # When true, the camera trajectory will travel along the azimuth direction. Otherwise, the camera will travel along an ellipsoid trajectory.\n",
        "\n",
        "opt.inference.results_dst_dir = os.path.join(os.path.split(opt.inference.results_dst_dir)[0], 'videos')\n",
        "os.makedirs(opt.inference.results_dst_dir, exist_ok=True)\n",
        "render_video(opt.inference, g_ema, surface_g_ema, device, mean_latent, surface_mean_latent)"
      ],
      "metadata": {
        "id": "nblhnZgcOST8"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Let's watch the result videos.\n",
        "\n",
        "The output video files are relatively large, so it might take a while (about 1-2 minutes) for all of them to be loaded. "
      ],
      "metadata": {
        "id": "jkBYvlaeTGu1"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "%%script bash --bg\n",
        "python3 -m https.server 8000"
      ],
      "metadata": {
        "id": "s-A9oIFXnYdM"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# change ffhq1024x1024 to afhq512x512 if you are working on the AFHQ model\n",
        "%%html\n",
        "<div>\n",
        "  <video width=256 controls><source src=\"https://localhost:8000/evaluations/ffhq1024x1024/final_model/videos/sample_video_0_azim.mp4\" type=\"video/mp4\"></video>\n",
        "  <video width=256 controls><source src=\"https://localhost:8000/evaluations/ffhq1024x1024/final_model/videos/sample_video_1_azim.mp4\" type=\"video/mp4\"></video>\n",
        "  <video width=256 controls><source src=\"https://localhost:8000/evaluations/ffhq1024x1024/final_model/videos/sample_video_2_azim.mp4\" type=\"video/mp4\"></video>\n",
        "  <video width=256 controls><source src=\"https://localhost:8000/evaluations/ffhq1024x1024/final_model/videos/sample_video_3_azim.mp4\" type=\"video/mp4\"></video>\n",
        "</div>"
      ],
      "metadata": {
        "id": "Rz8tq-yYnZMD"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "An alternative way to view the videos with python code. \n",
        "It loads the videos faster, but very often it crashes the notebook since the video file are too large.\n",
        "\n",
        "**It is not recommended to view the files this way**.\n",
        "\n",
        "If the notebook does crash, you can also refresh the webpage and manually download the videos.<br>\n",
        "The videos are located in `evaluations/<model_name>/final_model/videos`"
      ],
      "metadata": {
        "id": "KP_Nrl8yqL65"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# from base64 import b64encode\n",
        "\n",
        "# videos_dir = opt.inference.results_dst_dir\n",
        "# videos_list = sorted([os.path.join(videos_dir, entry) for entry in os.listdir(videos_dir) if 'mp4' in entry])\n",
        "# for i, video_file in enumerate(videos_list):\n",
        "#     if i != 1:\n",
        "#         continue\n",
        "#     mp4 = open(video_file,'rb').read()\n",
        "#     data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
        "#     display(viewer_html(\"\"\"<video width={0} controls>\n",
        "#                                 <source src=\"{1}\" type=\"{2}\">\n",
        "#                           </video>\"\"\".format(256, data_url, \"video/mp4\")))"
      ],
      "metadata": {
        "id": "LzsQTq1hTNSH"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}

================================================
FILE: dataset.py
================================================
import os
import csv
import lmdb
import random
import numpy as np
import torchvision.transforms.functional as TF
from PIL import Image
from io import BytesIO
from torch.utils.data import Dataset


class MultiResolutionDataset(Dataset):
    def __init__(self, path, transform, resolution=256, nerf_resolution=64):
        self.env = lmdb.open(
            path,
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )

        if not self.env:
            raise IOError('Cannot open lmdb dataset', path)

        with self.env.begin(write=False) as txn:
            self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))

        self.resolution = resolution
        self.nerf_resolution = nerf_resolution
        self.transform = transform

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        with self.env.begin(write=False) as txn:
            key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8')
            img_bytes = txn.get(key)

        buffer = BytesIO(img_bytes)
        img = Image.open(buffer)
        if random.random() > 0.5:
            img = TF.hflip(img)

        thumb_img = img.resize((self.nerf_resolution, self.nerf_resolution), Image.HAMMING)
        img = self.transform(img)
        thumb_img = self.transform(thumb_img)

        return img, thumb_img


================================================
FILE: distributed.py
================================================
import math
import pickle

import torch
from torch import distributed as dist
from torch.utils.data.sampler import Sampler


def get_rank():
    if not dist.is_available():
        return 0

    if not dist.is_initialized():
        return 0

    return dist.get_rank()


def synchronize():
    if not dist.is_available():
        return

    if not dist.is_initialized():
        return

    world_size = dist.get_world_size()

    if world_size == 1:
        return

    dist.barrier()


def get_world_size():
    if not dist.is_available():
        return 1

    if not dist.is_initialized():
        return 1

    return dist.get_world_size()


def reduce_sum(tensor):
    if not dist.is_available():
        return tensor

    if not dist.is_initialized():
        return tensor

    tensor = tensor.clone()
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)

    return tensor


def gather_grad(params):
    world_size = get_world_size()

    if world_size == 1:
        return

    for param in params:
        if param.grad is not None:
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            param.grad.data.div_(world_size)


def all_gather(data):
    world_size = get_world_size()

    if world_size == 1:
        return [data]

    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to('cuda')

    local_size = torch.IntTensor([tensor.numel()]).to('cuda')
    size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda'))

    if local_size != max_size:
        padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda')
        tensor = torch.cat((tensor, padding), 0)

    dist.all_gather(tensor_list, tensor)

    data_list = []

    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list


def reduce_loss_dict(loss_dict):
    world_size = get_world_size()

    if world_size < 2:
        return loss_dict

    with torch.no_grad():
        keys = []
        losses = []

        for k in sorted(loss_dict.keys()):
            keys.append(k)
            losses.append(loss_dict[k])

        losses = torch.stack(losses, 0)
        dist.reduce(losses, dst=0)

        if dist.get_rank() == 0:
            losses /= world_size

        reduced_losses = {k: v for k, v in zip(keys, losses)}

    return reduced_losses


================================================
FILE: download_models.py
================================================
import os
import html
import glob
import uuid
import hashlib
import requests
from tqdm import tqdm
from pdb import set_trace as st


ffhq_full_model_spec = dict(file_url='https://drive.google.com/uc?id=13s_dH768zJ3IHUjySVbD1DqcMolqKlBi',
                            alt_url='', file_size=202570217, file_md5='1ab9522157e537351fcf4faed9c92abb',
                            file_path='full_models/ffhq1024x1024.pt',)
ffhq_volume_renderer_spec = dict(file_url='https://drive.google.com/uc?id=1zzB-ACuas7lSAln8pDnIqlWOEg869CCK',
                                 alt_url='', file_size=63736197, file_md5='fe62f26032ccc8f04e1101b07bcd7462',
                                 file_path='pretrained_renderer/ffhq_vol_renderer.pt',)
afhq_full_model_spec = dict(file_url='https://drive.google.com/uc?id=1jZcV5__EPS56JBllRmUfUjMHExql_eSq',
                            alt_url='', file_size=184908743, file_md5='91eeaab2da5c0d2134c04bb56ac5aeb6',
                            file_path='full_models/afhq512x512.pt',)
afhq_volume_renderer_spec = dict(file_url='https://drive.google.com/uc?id=1xhZjuJt_teghAQEoJevAxrU5_8VqqX42',
                                 alt_url='', file_size=63736197, file_md5='cb3beb6cd3c43d9119356400165e7f26',
                                 file_path='pretrained_renderer/afhq_vol_renderer.pt',)
volume_renderer_sphere_init_spec = dict(file_url='https://drive.google.com/uc?id=1CcYWbHFrJyb4u5rBFx_ww-ceYWd309tk',
                                        alt_url='', file_size=63736197, file_md5='a6be653435b0e49633e6838f18ff4df6',
                                        file_path='pretrained_renderer/sphere_init.pt',)


def download_pretrained_models():
    print('Downloading sphere initialized volume renderer')
    with requests.Session() as session:
        try:
            download_file(session, volume_renderer_sphere_init_spec)
        except:
            print('Google Drive download failed.\n' \
                  'Trying do download from alternate server')
            download_file(session, volume_renderer_sphere_init_spec, use_alt_url=True)

    print('Downloading FFHQ pretrained volume renderer')
    with requests.Session() as session:
        try:
            download_file(session, ffhq_volume_renderer_spec)
        except:
            print('Google Drive download failed.\n' \
                  'Trying do download from alternate server')
            download_file(session, ffhq_volume_renderer_spec, use_alt_url=True)

    print('Downloading FFHQ full model (1024x1024)')
    with requests.Session() as session:
        try:
            download_file(session, ffhq_full_model_spec)
        except:
            print('Google Drive download failed.\n' \
                  'Trying do download from alternate server')
            download_file(session, ffhq_full_model_spec, use_alt_url=True)

    print('Done!')

    print('Downloading AFHQ pretrained volume renderer')
    with requests.Session() as session:
        try:
            download_file(session, afhq_volume_renderer_spec)
        except:
            print('Google Drive download failed.\n' \
                  'Trying do download from alternate server')
            download_file(session, afhq_volume_renderer_spec, use_alt_url=True)

    print('Done!')

    print('Downloading Downloading AFHQ full model (512x512)')
    with requests.Session() as session:
        try:
            download_file(session, afhq_full_model_spec)
        except:
            print('Google Drive download failed.\n' \
                  'Trying do download from alternate server')
            download_file(session, afhq_full_model_spec, use_alt_url=True)

    print('Done!')

def download_file(session, file_spec, use_alt_url=False, chunk_size=128, num_attempts=10):
    file_path = file_spec['file_path']
    if use_alt_url:
        file_url = file_spec['alt_url']
    else:
        file_url = file_spec['file_url']

    file_dir = os.path.dirname(file_path)
    tmp_path = file_path + '.tmp.' + uuid.uuid4().hex
    if file_dir:
        os.makedirs(file_dir, exist_ok=True)

    progress_bar = tqdm(total=file_spec['file_size'], unit='B', unit_scale=True)
    for attempts_left in reversed(range(num_attempts)):
        data_size = 0
        progress_bar.reset()
        try:
            # Download.
            data_md5 = hashlib.md5()
            with session.get(file_url, stream=True) as res:
                res.raise_for_status()
                with open(tmp_path, 'wb') as f:
                    for chunk in res.iter_content(chunk_size=chunk_size<<10):
                        progress_bar.update(len(chunk))
                        f.write(chunk)
                        data_size += len(chunk)
                        data_md5.update(chunk)

            # Validate.
            if 'file_size' in file_spec and data_size != file_spec['file_size']:
                raise IOError('Incorrect file size', file_path)
            if 'file_md5' in file_spec and data_md5.hexdigest() != file_spec['file_md5']:
                raise IOError('Incorrect file MD5', file_path)
            break

        except:
            # Last attempt => raise error.
            if not attempts_left:
                raise

            # Handle Google Drive virus checker nag.
            if data_size > 0 and data_size < 8192:
                with open(tmp_path, 'rb') as f:
                    data = f.read()
                links = [html.unescape(link) for link in data.decode('utf-8').split('"') if 'confirm=t' in link]
                if len(links) == 1:
                    file_url = requests.compat.urljoin(file_url, links[0])
                    continue

    progress_bar.close()

    # Rename temp file to the correct name.
    os.replace(tmp_path, file_path) # atomic

    # Attempt to clean up any leftover temps.
    for filename in glob.glob(file_path + '.tmp.*'):
        try:
            os.remove(filename)
        except:
            pass

if __name__ == "__main__":
    download_pretrained_models()


================================================
FILE: generate_shapes_and_images.py
================================================
import os
import torch
import trimesh
import numpy as np
from munch import *
from PIL import Image
from tqdm import tqdm
from torch.nn import functional as F
from torch.utils import data
from torchvision import utils
from torchvision import transforms
from skimage.measure import marching_cubes
from scipy.spatial import Delaunay
from options import BaseOptions
from model import Generator
from utils import (
    generate_camera_params,
    align_volume,
    extract_mesh_with_marching_cubes,
    xyz2mesh,
)


torch.random.manual_seed(1234)


def generate(opt, g_ema, surface_g_ema, device, mean_latent, surface_mean_latent):
    g_ema.eval()
    if not opt.no_surface_renderings:
        surface_g_ema.eval()

    # set camera angles
    if opt.fixed_camera_angles:
        # These can be changed to any other specific viewpoints.
        # You can add or remove viewpoints as you wish
        locations = torch.tensor([[0, 0],
                                  [-1.5 * opt.camera.azim, 0],
                                  [-1 * opt.camera.azim, 0],
                                  [-0.5 * opt.camera.azim, 0],
                                  [0.5 * opt.camera.azim, 0],
                                  [1 * opt.camera.azim, 0],
                                  [1.5 * opt.camera.azim, 0],
                                  [0, -1.5 * opt.camera.elev],
                                  [0, -1 * opt.camera.elev],
                                  [0, -0.5 * opt.camera.elev],
                                  [0, 0.5 * opt.camera.elev],
                                  [0, 1 * opt.camera.elev],
                                  [0, 1.5 * opt.camera.elev]], device=device)
        # For zooming in/out change the values of fov
        # (This can be defined for each view separately via a custom tensor
        # like the locations tensor above. Tensor shape should be [locations.shape[0],1])
        # reasonable values are [0.75 * opt.camera.fov, 1.25 * opt.camera.fov]
        fov = opt.camera.fov * torch.ones((locations.shape[0],1), device=device)
        num_viewdirs = locations.shape[0]
    else: # draw random camera angles
        locations = None
        # fov = None
        fov = opt.camera.fov
        num_viewdirs = opt.num_views_per_id

    # generate images
    for i in tqdm(range(opt.identities)):
        with torch.no_grad():
            chunk = 8
            sample_z = torch.randn(1, opt.style_dim, device=device).repeat(num_viewdirs,1)
            sample_cam_extrinsics, sample_focals, sample_near, sample_far, sample_locations = \
            generate_camera_params(opt.renderer_output_size, device, batch=num_viewdirs,
                                   locations=locations, #input_fov=fov,
                                   uniform=opt.camera.uniform, azim_range=opt.camera.azim,
                                   elev_range=opt.camera.elev, fov_ang=fov,
                                   dist_radius=opt.camera.dist_radius)
            rgb_images = torch.Tensor(0, 3, opt.size, opt.size)
            rgb_images_thumbs = torch.Tensor(0, 3, opt.renderer_output_size, opt.renderer_output_size)
            for j in range(0, num_viewdirs, chunk):
                out = g_ema([sample_z[j:j+chunk]],
                            sample_cam_extrinsics[j:j+chunk],
                            sample_focals[j:j+chunk],
                            sample_near[j:j+chunk],
                            sample_far[j:j+chunk],
                            truncation=opt.truncation_ratio,
                            truncation_latent=mean_latent)

                rgb_images = torch.cat([rgb_images, out[0].cpu()], 0)
                rgb_images_thumbs = torch.cat([rgb_images_thumbs, out[1].cpu()], 0)

            utils.save_image(rgb_images,
                os.path.join(opt.results_dst_dir, 'images','{}.png'.format(str(i).zfill(7))),
                nrow=num_viewdirs,
                normalize=True,
                padding=0,
                value_range=(-1, 1),)

            utils.save_image(rgb_images_thumbs,
                os.path.join(opt.results_dst_dir, 'images','{}_thumb.png'.format(str(i).zfill(7))),
                nrow=num_viewdirs,
                normalize=True,
                padding=0,
                value_range=(-1, 1),)

            # this is done to fit to RTX2080 RAM size (11GB)
            del out
            torch.cuda.empty_cache()

            if not opt.no_surface_renderings:
                surface_chunk = 1
                scale = surface_g_ema.renderer.out_im_res / g_ema.renderer.out_im_res
                surface_sample_focals = sample_focals * scale
                for j in range(0, num_viewdirs, surface_chunk):
                    surface_out = surface_g_ema([sample_z[j:j+surface_chunk]],
                                                sample_cam_extrinsics[j:j+surface_chunk],
                                                surface_sample_focals[j:j+surface_chunk],
                                                sample_near[j:j+surface_chunk],
                                                sample_far[j:j+surface_chunk],
                                                truncation=opt.truncation_ratio,
                                                truncation_latent=surface_mean_latent,
                                                return_sdf=True,
                                                return_xyz=True)

                    xyz = surface_out[2].cpu()
                    sdf = surface_out[3].cpu()

                    # this is done to fit to RTX2080 RAM size (11GB)
                    del surface_out
                    torch.cuda.empty_cache()

                    # mesh extractions are done one at a time
                    for k in range(surface_chunk):
                        curr_locations = sample_locations[j:j+surface_chunk]
                        loc_str = '_azim{}_elev{}'.format(int(curr_locations[k,0] * 180 / np.pi),
                                                          int(curr_locations[k,1] * 180 / np.pi))

                        # Save depth outputs as meshes
                        depth_mesh_filename = os.path.join(opt.results_dst_dir,'depth_map_meshes','sample_{}_depth_mesh{}.obj'.format(i, loc_str))
                        depth_mesh = xyz2mesh(xyz[k:k+surface_chunk])
                        if depth_mesh != None:
                            with open(depth_mesh_filename, 'w') as f:
                                depth_mesh.export(f,file_type='obj')

                        # extract full geometry with marching cubes
                        if j == 0:
                            try:
                                frostum_aligned_sdf = align_volume(sdf)
                                marching_cubes_mesh = extract_mesh_with_marching_cubes(frostum_aligned_sdf[k:k+surface_chunk])
                            except ValueError:
                                marching_cubes_mesh = None
                                print('Marching cubes extraction failed.')
                                print('Please check whether the SDF values are all larger (or all smaller) than 0.')

                            if marching_cubes_mesh != None:
                                marching_cubes_mesh_filename = os.path.join(opt.results_dst_dir,'marching_cubes_meshes','sample_{}_marching_cubes_mesh{}.obj'.format(i, loc_str))
                                with open(marching_cubes_mesh_filename, 'w') as f:
                                    marching_cubes_mesh.export(f,file_type='obj')


if __name__ == "__main__":
    device = "cuda"
    opt = BaseOptions().parse()
    opt.model.is_test = True
    opt.model.freeze_renderer = False
    opt.rendering.offset_sampling = True
    opt.rendering.static_viewdirs = True
    opt.rendering.force_background = True
    opt.rendering.perturb = 0
    opt.inference.size = opt.model.size
    opt.inference.camera = opt.camera
    opt.inference.renderer_output_size = opt.model.renderer_spatial_output_dim
    opt.inference.style_dim = opt.model.style_dim
    opt.inference.project_noise = opt.model.project_noise
    opt.inference.return_xyz = opt.rendering.return_xyz

    # find checkpoint directory
    # check if there's a fully trained model
    checkpoints_dir = 'full_models'
    checkpoint_path = os.path.join(checkpoints_dir, opt.experiment.expname + '.pt')
    if os.path.isfile(checkpoint_path):
        # define results directory name
        result_model_dir = 'final_model'
    else:
        checkpoints_dir = os.path.join('checkpoint', opt.experiment.expname, 'full_pipeline')
        checkpoint_path = os.path.join(checkpoints_dir,
                                       'models_{}.pt'.format(opt.experiment.ckpt.zfill(7)))
        # define results directory name
        result_model_dir = 'iter_{}'.format(opt.experiment.ckpt.zfill(7))

    # create results directory
    results_dir_basename = os.path.join(opt.inference.results_dir, opt.experiment.expname)
    opt.inference.results_dst_dir = os.path.join(results_dir_basename, result_model_dir)
    if opt.inference.fixed_camera_angles:
        opt.inference.results_dst_dir = os.path.join(opt.inference.results_dst_dir, 'fixed_angles')
    else:
        opt.inference.results_dst_dir = os.path.join(opt.inference.results_dst_dir, 'random_angles')
    os.makedirs(opt.inference.results_dst_dir, exist_ok=True)
    os.makedirs(os.path.join(opt.inference.results_dst_dir, 'images'), exist_ok=True)
    if not opt.inference.no_surface_renderings:
        os.makedirs(os.path.join(opt.inference.results_dst_dir, 'depth_map_meshes'), exist_ok=True)
        os.makedirs(os.path.join(opt.inference.results_dst_dir, 'marching_cubes_meshes'), exist_ok=True)

    # load saved model
    checkpoint = torch.load(checkpoint_path)

    # load image generation model
    g_ema = Generator(opt.model, opt.rendering).to(device)
    pretrained_weights_dict = checkpoint["g_ema"]
    model_dict = g_ema.state_dict()
    for k, v in pretrained_weights_dict.items():
        if v.size() == model_dict[k].size():
            model_dict[k] = v

    g_ema.load_state_dict(model_dict)

    # load a second volume renderer that extracts surfaces at 128x128x128 (or higher) for better surface resolution
    if not opt.inference.no_surface_renderings:
        opt['surf_extraction'] = Munch()
        opt.surf_extraction.rendering = opt.rendering
        opt.surf_extraction.model = opt.model.copy()
        opt.surf_extraction.model.renderer_spatial_output_dim = 128
        opt.surf_extraction.rendering.N_samples = opt.surf_extraction.model.renderer_spatial_output_dim
        opt.surf_extraction.rendering.return_xyz = True
        opt.surf_extraction.rendering.return_sdf = True
        surface_g_ema = Generator(opt.surf_extraction.model, opt.surf_extraction.rendering, full_pipeline=False).to(device)

        # Load weights to surface extractor
        surface_extractor_dict = surface_g_ema.state_dict()
        for k, v in pretrained_weights_dict.items():
            if k in surface_extractor_dict.keys() and v.size() == surface_extractor_dict[k].size():
                surface_extractor_dict[k] = v

        surface_g_ema.load_state_dict(surface_extractor_dict)
    else:
        surface_g_ema = None

    # get the mean latent vector for g_ema
    if opt.inference.truncation_ratio < 1:
        with torch.no_grad():
            mean_latent = g_ema.mean_latent(opt.inference.truncation_mean, device)
    else:
        surface_mean_latent = None

    # get the mean latent vector for surface_g_ema
    if not opt.inference.no_surface_renderings:
        surface_mean_latent = mean_latent[0]
    else:
        surface_mean_latent = None

    generate(opt.inference, g_ema, surface_g_ema, device, mean_latent, surface_mean_latent)


================================================
FILE: losses.py
================================================
import math
import torch
from torch import autograd
from torch.nn import functional as F


def viewpoints_loss(viewpoint_pred, viewpoint_target):
    loss = F.smooth_l1_loss(viewpoint_pred, viewpoint_target)

    return loss


def eikonal_loss(eikonal_term, sdf=None, beta=100):
    if eikonal_term == None:
        eikonal_loss = 0
    else:
        eikonal_loss = ((eikonal_term.norm(dim=-1) - 1) ** 2).mean()

    if sdf == None:
        minimal_surface_loss = torch.tensor(0.0, device=eikonal_term.device)
    else:
        minimal_surface_loss = torch.exp(-beta * torch.abs(sdf)).mean()

    return eikonal_loss, minimal_surface_loss


def d_logistic_loss(real_pred, fake_pred):
    real_loss = F.softplus(-real_pred)
    fake_loss = F.softplus(fake_pred)

    return real_loss.mean() + fake_loss.mean()


def d_r1_loss(real_pred, real_img):
    grad_real, = autograd.grad(outputs=real_pred.sum(),
                               inputs=real_img,
                               create_graph=True)

    grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()
    return grad_penalty


def g_nonsaturating_loss(fake_pred):
    loss = F.softplus(-fake_pred).mean()

    return loss


def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    noise = torch.randn_like(fake_img) / math.sqrt(
        fake_img.shape[2] * fake_img.shape[3]
    )
    grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents,
                         create_graph=True, only_inputs=True)[0]
    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))

    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)

    path_penalty = (path_lengths - path_mean).pow(2).mean()

    return path_penalty, path_mean.detach(), path_lengths


================================================
FILE: model.py
================================================
import math
import random
import trimesh
import torch
import numpy as np
from torch import nn
from torch.nn import functional as F
from volume_renderer import VolumeFeatureRenderer
from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
from pdb import set_trace as st

from utils import (
    create_cameras,
    create_mesh_renderer,
    add_textures,
    create_depth_mesh_renderer,
)
from pytorch3d.renderer import TexturesUV
from pytorch3d.structures import Meshes
from pytorch3d.renderer import look_at_view_transform, FoVPerspectiveCameras
from pytorch3d.transforms import matrix_to_euler_angles


class PixelNorm(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)


class MappingLinear(nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, activation=None, is_last=False):
        super().__init__()
        if is_last:
            weight_std = 0.25
        else:
            weight_std = 1

        self.weight = nn.Parameter(weight_std * nn.init.kaiming_normal_(torch.empty(out_dim, in_dim), a=0.2, mode='fan_in', nonlinearity='leaky_relu'))

        if bias:
            self.bias = nn.Parameter(nn.init.uniform_(torch.empty(out_dim), a=-np.sqrt(1/in_dim), b=np.sqrt(1/in_dim)))
        else:
            self.bias = None

        self.activation = activation

    def forward(self, input):
        if self.activation != None:
            out = F.linear(input, self.weight)
            out = fused_leaky_relu(out, self.bias, scale=1)
        else:
            out = F.linear(input, self.weight, bias=self.bias)

        return out

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
        )


def make_kernel(k):
    k = torch.tensor(k, dtype=torch.float32)

    if k.ndim == 1:
        k = k[None, :] * k[:, None]

    k /= k.sum()

    return k


class Upsample(nn.Module):
    def __init__(self, kernel, factor=2):
        super().__init__()

        self.factor = factor
        kernel = make_kernel(kernel) * (factor ** 2)
        self.register_buffer("kernel", kernel)

        p = kernel.shape[0] - factor

        pad0 = (p + 1) // 2 + factor - 1
        pad1 = p // 2

        self.pad = (pad0, pad1)

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)

        return out


class Downsample(nn.Module):
    def __init__(self, kernel, factor=2):
        super().__init__()

        self.factor = factor
        kernel = make_kernel(kernel)
        self.register_buffer("kernel", kernel)

        p = kernel.shape[0] - factor

        pad0 = (p + 1) // 2
        pad1 = p // 2

        self.pad = (pad0, pad1)

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)

        return out


class Blur(nn.Module):
    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()

        kernel = make_kernel(kernel)

        if upsample_factor > 1:
            kernel = kernel * (upsample_factor ** 2)

        self.register_buffer("kernel", kernel)

        self.pad = pad

    def forward(self, input):
        out = upfirdn2d(input, self.kernel, pad=self.pad)

        return out


class EqualConv2d(nn.Module):
    def __init__(
        self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
    ):
        super().__init__()

        self.weight = nn.Parameter(
            torch.randn(out_channel, in_channel, kernel_size, kernel_size)
        )
        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)

        self.stride = stride
        self.padding = padding

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channel))

        else:
            self.bias = None

    def forward(self, input):
        out = F.conv2d(
            input,
            self.weight * self.scale,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )

        return out

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
            f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
        )


class EqualLinear(nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1,
                 activation=None):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))

        else:
            self.bias = None

        self.activation = activation

        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        if self.activation:
            out = F.linear(input, self.weight * self.scale)
            out = fused_leaky_relu(out, self.bias * self.lr_mul)

        else:
            out = F.linear(input, self.weight * self.scale,
                           bias=self.bias * self.lr_mul)

        return out

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
        )


class ModulatedConv2d(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, style_dim, demodulate=True,
                 upsample=False, downsample=False, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        self.eps = 1e-8
        self.kernel_size = kernel_size
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.upsample = upsample
        self.downsample = downsample

        if upsample:
            factor = 2
            p = (len(blur_kernel) - factor) - (kernel_size - 1)
            pad0 = (p + 1) // 2 + factor - 1
            pad1 = p // 2 + 1

            self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            self.blur = Blur(blur_kernel, pad=(pad0, pad1))

        fan_in = in_channel * kernel_size ** 2
        self.scale = 1 / math.sqrt(fan_in)
        self.padding = kernel_size // 2

        self.weight = nn.Parameter(
            torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
        )

        self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)

        self.demodulate = demodulate

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
            f"upsample={self.upsample}, downsample={self.downsample})"
        )

    def forward(self, input, style):
        batch, in_channel, height, width = input.shape
        style = self.modulation(style).view(batch, 1, in_channel, 1, 1)

        weight = self.scale * self.weight * style

        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
            weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)

        weight = weight.view(
            batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
        )

        if self.upsample:
            input = input.view(1, batch * in_channel, height, width)
            weight = weight.view(
                batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
            )
            weight = weight.transpose(1, 2).reshape(
                batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
            )
            out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)
            out = self.blur(out)

        elif self.downsample:
            input = self.blur(input)
            _, _, height, width = input.shape
            input = input.view(1, batch * in_channel, height, width)
            out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)

        else:
            input = input.view(1, batch * in_channel, height, width)
            out = F.conv2d(input, weight, padding=self.padding, groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)

        return out


class NoiseInjection(nn.Module):
    def __init__(self, project=False):
        super().__init__()
        self.project = project
        self.weight = nn.Parameter(torch.zeros(1))
        self.prev_noise = None
        self.mesh_fn = None
        self.vert_noise = None

    def create_pytorch_mesh(self, trimesh):
        v=trimesh.vertices; f=trimesh.faces
        verts = torch.from_numpy(np.asarray(v)).to(torch.float32).cuda()
        mesh_pytorch = Meshes(
            verts=[verts],
            faces = [torch.from_numpy(np.asarray(f)).to(torch.float32).cuda()],
            textures=None
        )
        if self.vert_noise == None or verts.shape[0] != self.vert_noise.shape[1]:
            self.vert_noise = torch.ones_like(verts)[:,0:1].cpu().normal_().expand(-1,3).unsqueeze(0)

        mesh_pytorch = add_textures(meshes=mesh_pytorch, vertex_colors=self.vert_noise.to(verts.device))

        return mesh_pytorch

    def load_mc_mesh(self, filename, resolution=128, im_res=64):
        import trimesh

        mc_tri=trimesh.load_mesh(filename)
        v=mc_tri.vertices; f=mc_tri.faces
        mesh2=trimesh.base.Trimesh(vertices=v, faces=f)
        if im_res==64 or im_res==128:
            pytorch3d_mesh = self.create_pytorch_mesh(mesh2)
            return pytorch3d_mesh
        v,f = trimesh.remesh.subdivide(v,f)
        mesh2_subdiv = trimesh.base.Trimesh(vertices=v, faces=f)
        if im_res==256:
            pytorch3d_mesh = self.create_pytorch_mesh(mesh2_subdiv);
            return pytorch3d_mesh
        v,f = trimesh.remesh.subdivide(mesh2_subdiv.vertices,mesh2_subdiv.faces)
        mesh3_subdiv = trimesh.base.Trimesh(vertices=v, faces=f)
        if im_res==256:
            pytorch3d_mesh = self.create_pytorch_mesh(mesh3_subdiv);
            return pytorch3d_mesh
        v,f = trimesh.remesh.subdivide(mesh3_subdiv.vertices,mesh3_subdiv.faces)
        mesh4_subdiv = trimesh.base.Trimesh(vertices=v, faces=f)

        pytorch3d_mesh = self.create_pytorch_mesh(mesh4_subdiv)

        return pytorch3d_mesh

    def project_noise(self, noise, transform, mesh_path=None):
        batch, _, height, width = noise.shape
        assert(batch == 1)  # assuming during inference batch size is 1

        angles = matrix_to_euler_angles(transform[0:1,:,:3], "ZYX")
        azim = float(angles[0][1])
        elev = float(-angles[0][2])

        cameras = create_cameras(azim=azim*180/np.pi,elev=elev*180/np.pi,fov=12.,dist=1)

        renderer = create_depth_mesh_renderer(cameras, image_size=height,
                specular_color=((0,0,0),), ambient_color=((1.,1.,1.),),diffuse_color=((0,0,0),))


        if self.mesh_fn is None or self.mesh_fn != mesh_path:
            self.mesh_fn = mesh_path

        pytorch3d_mesh = self.load_mc_mesh(mesh_path, im_res=height)
        rgb, depth = renderer(pytorch3d_mesh)

        depth_max = depth.max(-1)[0].view(-1) # (NxN)
        depth_valid = depth_max > 0.
        if self.prev_noise is None:
            self.prev_noise = noise
        noise_copy = self.prev_noise.clone()
        noise_copy.view(-1)[depth_valid] = rgb[0,:,:,0].view(-1)[depth_valid]
        noise_copy = noise_copy.reshape(1,1,height,height)  # 1x1xNxN

        return noise_copy


    def forward(self, image, noise=None, transform=None, mesh_path=None):
        batch, _, height, width = image.shape
        if noise is None:
            noise = image.new_empty(batch, 1, height, width).normal_()
        elif self.project:
            noise = self.project_noise(noise, transform, mesh_path=mesh_path)

        return image + self.weight * noise


class StyledConv(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, style_dim,
                 upsample=False, blur_kernel=[1, 3, 3, 1], project_noise=False):
        super().__init__()

        self.conv = ModulatedConv2d(
            in_channel,
            out_channel,
            kernel_size,
            style_dim,
            upsample=upsample,
            blur_kernel=blur_kernel,
        )

        self.noise = NoiseInjection(project=project_noise)
        self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
        self.activate = FusedLeakyReLU(out_channel)

    def forward(self, input, style, noise=None, transform=None, mesh_path=None):
        out = self.conv(input, style)
        out = self.noise(out, noise=noise, transform=transform, mesh_path=mesh_path)
        out = self.activate(out)

        return out


class ToRGB(nn.Module):
    def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        self.upsample = upsample
        out_channels = 3
        if upsample:
            self.upsample = Upsample(blur_kernel)

        self.conv = ModulatedConv2d(in_channel, out_channels, 1, style_dim, demodulate=False)
        self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))

    def forward(self, input, style, skip=None):
        out = self.conv(input, style)
        out = out + self.bias

        if skip is not None:
            if self.upsample:
                skip = self.upsample(skip)

            out = out + skip

        return out


class ConvLayer(nn.Sequential):
    def __init__(self, in_channel, out_channel, kernel_size, downsample=False,
                 blur_kernel=[1, 3, 3, 1], bias=True, activate=True):
        layers = []

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            layers.append(Blur(blur_kernel, pad=(pad0, pad1)))

            stride = 2
            self.padding = 0

        else:
            stride = 1
            self.padding = kernel_size // 2

        layers.append(
            EqualConv2d(
                in_channel,
                out_channel,
                kernel_size,
                padding=self.padding,
                stride=stride,
                bias=bias and not activate,
            )
        )

        if activate:
            layers.append(FusedLeakyReLU(out_channel, bias=bias))

        super().__init__(*layers)


class Decoder(nn.Module):
    def __init__(self, model_opt, blur_kernel=[1, 3, 3, 1]):
        super().__init__()
        # decoder mapping network
        self.size = model_opt.size
        self.style_dim = model_opt.style_dim * 2
        thumb_im_size = model_opt.renderer_spatial_output_dim

        layers = [PixelNorm(),
                   EqualLinear(
                       self.style_dim // 2, self.style_dim, lr_mul=model_opt.lr_mapping, activation="fused_lrelu"
                   )]

        for i in range(4):
            layers.append(
                EqualLinear(
                    self.style_dim, self.style_dim, lr_mul=model_opt.lr_mapping, activation="fused_lrelu"
                )
            )

        self.style = nn.Sequential(*layers)

        # decoder network
        self.channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * model_opt.channel_multiplier,
            128: 128 * model_opt.channel_multiplier,
            256: 64 * model_opt.channel_multiplier,
            512: 32 * model_opt.channel_multiplier,
            1024: 16 * model_opt.channel_multiplier,
        }

        decoder_in_size = model_opt.renderer_spatial_output_dim

        # image decoder
        self.log_size = int(math.log(self.size, 2))
        self.log_in_size = int(math.log(decoder_in_size, 2))

        self.conv1 = StyledConv(
            model_opt.feature_encoder_in_channels,
            self.channels[decoder_in_size], 3, self.style_dim, blur_kernel=blur_kernel,
            project_noise=model_opt.project_noise)

        self.to_rgb1 = ToRGB(self.channels[decoder_in_size], self.style_dim, upsample=False)

        self.num_layers = (self.log_size - self.log_in_size) * 2 + 1

        self.convs = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        in_channel = self.channels[decoder_in_size]

        for layer_idx in range(self.num_layers):
            res = (layer_idx + 2 * self.log_in_size + 1) // 2
            shape = [1, 1, 2 ** (res), 2 ** (res)]
            self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape))

        for i in range(self.log_in_size+1, self.log_size + 1):
            out_channel = self.channels[2 ** i]

            self.convs.append(
                StyledConv(in_channel, out_channel, 3, self.style_dim, upsample=True,
                           blur_kernel=blur_kernel, project_noise=model_opt.project_noise)
            )

            self.convs.append(
                StyledConv(out_channel, out_channel, 3, self.style_dim,
                           blur_kernel=blur_kernel, project_noise=model_opt.project_noise)
            )

            self.to_rgbs.append(ToRGB(out_channel, self.style_dim))

            in_channel = out_channel

        self.n_latent = (self.log_size - self.log_in_size) * 2 + 2

    def mean_latent(self, renderer_latent):
        latent = self.style(renderer_latent).mean(0, keepdim=True)

        return latent

    def get_latent(self, input):
        return self.style(input)

    def styles_and_noise_forward(self, styles, noise, inject_index=None, truncation=1,
                                 truncation_latent=None, input_is_latent=False,
                                 randomize_noise=True):
        if not input_is_latent:
            styles = [self.style(s) for s in styles]

        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers
            else:
                noise = [
                    getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
                ]

        if (truncation < 1):
            style_t = []

            for style in styles:
                style_t.append(
                    truncation_latent[1] + truncation * (style - truncation_latent[1])
                )

            styles = style_t

        if len(styles) < 2:
            inject_index = self.n_latent

            if styles[0].ndim < 3:
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)

            else:
                latent = styles[0]
        else:
            if inject_index is None:
                inject_index = random.randint(1, self.n_latent - 1)

            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)

            latent = torch.cat([latent, latent2], 1)

        return latent, noise

    def forward(self, features, styles, rgbd_in=None, transform=None,
                return_latents=False, inject_index=None, truncation=1,
                truncation_latent=None, input_is_latent=False, noise=None,
                randomize_noise=True, mesh_path=None):
        latent, noise = self.styles_and_noise_forward(styles, noise, inject_index, truncation,
                                                      truncation_latent, input_is_latent,
                                                      randomize_noise)

        out = self.conv1(features, latent[:, 0], noise=noise[0],
                         transform=transform, mesh_path=mesh_path)

        skip = self.to_rgb1(out, latent[:, 1], skip=rgbd_in)

        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(
            self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
        ):
            out = conv1(out, latent[:, i], noise=noise1,
                           transform=transform, mesh_path=mesh_path)
            out = conv2(out, latent[:, i + 1], noise=noise2,
                                   transform=transform, mesh_path=mesh_path)
            skip = to_rgb(out, latent[:, i + 2], skip=skip)

            i += 2

        out_latent = latent if return_latents else None
        image = skip

        return image, out_latent


class Generator(nn.Module):
    def __init__(self, model_opt, renderer_opt, blur_kernel=[1, 3, 3, 1], ema=False, full_pipeline=True):
        super().__init__()
        self.size = model_opt.size
        self.style_dim = model_opt.style_dim
        self.num_layers = 1
        self.train_renderer = not model_opt.freeze_renderer
        self.full_pipeline = full_pipeline
        model_opt.feature_encoder_in_channels = renderer_opt.width

        if ema or 'is_test' in model_opt.keys():
            self.is_train = False
        else:
            self.is_train = True

        # volume renderer mapping_network
        layers = []
        for i in range(3):
            layers.append(
                MappingLinear(self.style_dim, self.style_dim, activation="fused_lrelu")
            )

        self.style = nn.Sequential(*layers)

        # volume renderer
        thumb_im_size = model_opt.renderer_spatial_output_dim
        self.renderer = VolumeFeatureRenderer(renderer_opt, style_dim=self.style_dim,
                                              out_im_res=thumb_im_size)

        if self.full_pipeline:
            self.decoder = Decoder(model_opt)

    def make_noise(self):
        device = self.input.input.device

        noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))

        return noises

    def mean_latent(self, n_latent, device):
        latent_in = torch.randn(n_latent, self.style_dim, device=device)
        renderer_latent = self.style(latent_in)
        renderer_latent_mean = renderer_latent.mean(0, keepdim=True)
        if self.full_pipeline:
            decoder_latent_mean = self.decoder.mean_latent(renderer_latent)
        else:
            decoder_latent_mean = None

        return [renderer_latent_mean, decoder_latent_mean]

    def get_latent(self, input):
        return self.style(input)

    def styles_and_noise_forward(self, styles, inject_index=None, truncation=1,
                                 truncation_latent=None, input_is_latent=False):
        if not input_is_latent:
            styles = [self.style(s) for s in styles]

        if truncation < 1:
            style_t = []

            for style in styles:
                style_t.append(
                    truncation_latent[0] + truncation * (style - truncation_latent[0])
                )

            styles = style_t

        return styles

    def init_forward(self, styles, cam_poses, focals, near=0.88, far=1.12):
        latent = self.styles_and_noise_forward(styles)

        sdf, target_values = self.renderer.mlp_init_pass(cam_poses, focals, near, far, styles=latent[0])

        return sdf, target_values

    def forward(self, styles, cam_poses, focals, near=0.88, far=1.12, return_latents=False,
                inject_index=None, truncation=1, truncation_latent=None,
                input_is_latent=False, noise=None, randomize_noise=True,
                return_sdf=False, return_xyz=False, return_eikonal=False,
                project_noise=False, mesh_path=None):

        # do not calculate renderer gradients if renderer weights are frozen
        with torch.set_grad_enabled(self.is_train and self.train_renderer):
            latent = self.styles_and_noise_forward(styles, inject_index, truncation,
                                                   truncation_latent, input_is_latent)

            thumb_rgb, features, sdf, mask, xyz, eikonal_term = self.renderer(cam_poses, focals, near, far, styles=latent[0], return_eikonal=return_eikonal)

        if self.full_pipeline:
            rgb, decoder_latent = self.decoder(features, latent,
                                               transform=cam_poses if project_noise else None,
                                               return_latents=return_latents,
                                               inject_index=inject_index, truncation=truncation,
                                               truncation_latent=truncation_latent, noise=noise,
                                               input_is_latent=input_is_latent, randomize_noise=randomize_noise,
                                               mesh_path=mesh_path)

        else:
            rgb = None

        if return_latents:
            return rgb, decoder_latent
        else:
            out = (rgb, thumb_rgb)
            if return_xyz:
                out += (xyz,)
            if return_sdf:
                out += (sdf,)
            if return_eikonal:
                out += (eikonal_term,)
            if return_xyz:
                out += (mask,)

            return out

############# Volume Renderer Building Blocks & Discriminator ##################
class VolumeRenderDiscConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, bias=True, activate=False):
        super(VolumeRenderDiscConv2d, self).__init__()

        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size, stride, padding, bias=bias and not activate)

        self.activate = activate
        if self.activate:
            self.activation = FusedLeakyReLU(out_channels, bias=bias, scale=1)
            bias_init_coef = np.sqrt(1 / (in_channels * kernel_size * kernel_size))
            nn.init.uniform_(self.activation.bias, a=-bias_init_coef, b=bias_init_coef)


    def forward(self, input):
        """
        input_tensor_shape: (N, C_in,H,W)
        output_tensor_shape: (N,C_out,H_out,W_out)
        :return: Conv2d + activation Result
        """
        out = self.conv(input)
        if self.activate:
            out = self.activation(out)

        return out


class AddCoords(nn.Module):
    def __init__(self):
        super(AddCoords, self).__init__()

    def forward(self, input_tensor):
        """
        :param input_tensor: shape (N, C_in, H, W)
        :return:
        """
        batch_size_shape, channel_in_shape, dim_y, dim_x = input_tensor.shape
        xx_channel = torch.arange(dim_x, dtype=torch.float32, device=input_tensor.device).repeat(1,1,dim_y,1)
        yy_channel = torch.arange(dim_y, dtype=torch.float32, device=input_tensor.device).repeat(1,1,dim_x,1).transpose(2,3)

        xx_channel = xx_channel / (dim_x - 1)
        yy_channel = yy_channel / (dim_y - 1)

        xx_channel = xx_channel * 2 - 1
        yy_channel = yy_channel * 2 - 1

        xx_channel = xx_channel.repeat(batch_size_shape, 1, 1, 1)
        yy_channel = yy_channel.repeat(batch_size_shape, 1, 1, 1)
        out = torch.cat([input_tensor, yy_channel, xx_channel], dim=1)

        return out


class CoordConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, bias=True):
        super(CoordConv2d, self).__init__()

        self.addcoords = AddCoords()
        self.conv = nn.Conv2d(in_channels + 2, out_channels,
                              kernel_size, stride=stride, padding=padding, bias=bias)

    def forward(self, input_tensor):
        """
        input_tensor_shape: (N, C_in,H,W)
        output_tensor_shape: N,C_out,H_out,W_out)
        :return: CoordConv2d Result
        """
        out = self.addcoords(input_tensor)
        out = self.conv(out)

        return out


class CoordConvLayer(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, bias=True, activate=True):
        super(CoordConvLayer, self).__init__()
        layers = []
        stride = 1
        self.activate = activate
        self.padding = kernel_size // 2 if kernel_size > 2 else 0

        self.conv = CoordConv2d(in_channel, out_channel, kernel_size,
                                padding=self.padding, stride=stride,
                                bias=bias and not activate)

        if activate:
            self.activation = FusedLeakyReLU(out_channel, bias=bias, scale=1)

        bias_init_coef = np.sqrt(1 / (in_channel * kernel_size * kernel_size))
        nn.init.uniform_(self.activation.bias, a=-bias_init_coef, b=bias_init_coef)

    def forward(self, input):
        out = self.conv(input)
        if self.activate:
            out = self.activation(out)

        return out


class VolumeRenderResBlock(nn.Module):
    def __init__(self, in_channel, out_channel):
        super().__init__()

        self.conv1 = CoordConvLayer(in_channel, out_channel, 3)
        self.conv2 = CoordConvLayer(out_channel, out_channel, 3)
        self.pooling = nn.AvgPool2d(2)
        self.downsample = nn.AvgPool2d(2)
        if out_channel != in_channel:
            self.skip = VolumeRenderDiscConv2d(in_channel, out_channel, 1)
        else:
            self.skip = None

    def forward(self, input):
        out = self.conv1(input)
        out = self.conv2(out)
        out = self.pooling(out)

        downsample_in = self.downsample(input)
        if self.skip != None:
            skip_in = self.skip(downsample_in)
        else:
            skip_in = downsample_in

        out = (out + skip_in) / math.sqrt(2)

        return out


class VolumeRenderDiscriminator(nn.Module):
    def __init__(self, opt):
        super().__init__()
        init_size = opt.renderer_spatial_output_dim
        self.viewpoint_loss = not opt.no_viewpoint_loss
        final_out_channel = 3 if self.viewpoint_loss else 1
        channels = {
            2: 400,
            4: 400,
            8: 400,
            16: 400,
            32: 256,
            64: 128,
            128: 64,
        }

        convs = [VolumeRenderDiscConv2d(3, channels[init_size], 1, activate=True)]

        log_size = int(math.log(init_size, 2))

        in_channel = channels[init_size]

        for i in range(log_size-1, 0, -1):
            out_channel = channels[2 ** i]

            convs.append(VolumeRenderResBlock(in_channel, out_channel))

            in_channel = out_channel

        self.convs = nn.Sequential(*convs)

        self.final_conv = VolumeRenderDiscConv2d(in_channel, final_out_channel, 2)

    def forward(self, input):
        out = self.convs(input)
        out = self.final_conv(out)
        gan_preds = out[:,0:1]
        gan_preds = gan_preds.view(-1, 1)
        if self.viewpoint_loss:
            viewpoints_preds = out[:,1:]
            viewpoints_preds = viewpoints_preds.view(-1,2)
        else:
            viewpoints_preds = None

        return gan_preds, viewpoints_preds

######################### StyleGAN Discriminator ########################
class ResBlock(nn.Module):
    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1], merge=False):
        super().__init__()

        self.conv1 = ConvLayer(2 * in_channel if merge else in_channel, in_channel, 3)
        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
        self.skip = ConvLayer(2 * in_channel if merge else in_channel, out_channel,
                              1, downsample=True, activate=False, bias=False)

    def forward(self, input):
        out = self.conv1(input)
        out = self.conv2(out)
        out = (out + self.skip(input)) / math.sqrt(2)

        return out


class Discriminator(nn.Module):
    def __init__(self, opt, blur_kernel=[1, 3, 3, 1]):
        super().__init__()
        init_size = opt.size

        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * opt.channel_multiplier,
            128: 128 * opt.channel_multiplier,
            256: 64 * opt.channel_multiplier,
            512: 32 * opt.channel_multiplier,
            1024: 16 * opt.channel_multiplier,
        }

        convs = [ConvLayer(3, channels[init_size], 1)]

        log_size = int(math.log(init_size, 2))

        in_channel = channels[init_size]

        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]

            convs.append(ResBlock(in_channel, out_channel, blur_kernel))

            in_channel = out_channel

        self.convs = nn.Sequential(*convs)

        self.stddev_group = 4
        self.stddev_feat = 1

        # minibatch discrimination
        in_channel += 1

        self.final_conv = ConvLayer(in_channel, channels[4], 3)
        self.final_linear = nn.Sequential(
            EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
            EqualLinear(channels[4], 1),
        )

    def forward(self, input):
        out = self.convs(input)

        # minibatch discrimination
        batch, channel, height, width = out.shape
        group = min(batch, self.stddev_group)
        if batch % group != 0:
            group = 3 if batch % 3 == 0 else 2

        stddev = out.view(
            group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
        )
        stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
        stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
        stddev = stddev.repeat(group, 1, height, width)
        final_out = torch.cat([out, stddev], 1)

        # final layers
        final_out = self.final_conv(final_out)
        final_out = final_out.view(batch, -1)
        final_out = self.final_linear(final_out)
        gan_preds = final_out[:,:1]

        return gan_preds


================================================
FILE: op/__init__.py
================================================
from .fused_act import FusedLeakyReLU, fused_leaky_relu
from .upfirdn2d import upfirdn2d


================================================
FILE: op/fused_act.py
================================================
import os

import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function
from torch.utils.cpp_extension import load


module_path = os.path.dirname(__file__)
fused = load(
    "fused",
    sources=[
        os.path.join(module_path, "fused_bias_act.cpp"),
        os.path.join(module_path, "fused_bias_act_kernel.cu"),
    ],
)


class FusedLeakyReLUFunctionBackward(Function):
    @staticmethod
    def forward(ctx, grad_output, out, bias, negative_slope, scale):
        ctx.save_for_backward(out)
        ctx.negative_slope = negative_slope
        ctx.scale = scale

        empty = grad_output.new_empty(0)

        grad_input = fused.fused_bias_act(
            grad_output, empty, out, 3, 1, negative_slope, scale
        )

        dim = [0]

        if grad_input.ndim > 2:
            dim += list(range(2, grad_input.ndim))

        if bias:
            grad_bias = grad_input.sum(dim).detach()

        else:
            grad_bias = empty

        return grad_input, grad_bias

    @staticmethod
    def backward(ctx, gradgrad_input, gradgrad_bias):
        out, = ctx.saved_tensors
        gradgrad_out = fused.fused_bias_act(
            gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
        )

        return gradgrad_out, None, None, None, None


class FusedLeakyReLUFunction(Function):
    @staticmethod
    def forward(ctx, input, bias, negative_slope, scale):
        empty = input.new_empty(0)

        ctx.bias = bias is not None

        if bias is None:
            bias = empty

        out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
        ctx.save_for_backward(out)
        ctx.negative_slope = negative_slope
        ctx.scale = scale

        return out

    @staticmethod
    def backward(ctx, grad_output):
        out, = ctx.saved_tensors

        grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
            grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale
        )

        if not ctx.bias:
            grad_bias = None

        return grad_input, grad_bias, None, None


class FusedLeakyReLU(nn.Module):
    def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
        super().__init__()

        if bias:
            self.bias = nn.Parameter(torch.zeros(channel))

        else:
            self.bias = None

        self.negative_slope = negative_slope
        self.scale = scale

    def forward(self, input):
        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)


def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
    if input.device.type == "cpu":
        if bias is not None:
            rest_dim = [1] * (input.ndim - bias.ndim - 1)
            return (
                F.leaky_relu(
                    input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
                )
                * scale
            )

        else:
            return F.leaky_relu(input, negative_slope=0.2) * scale

    else:
        return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)


================================================
FILE: op/fused_bias_act.cpp
================================================
#include <torch/extension.h>


torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
    int act, int grad, float alpha, float scale);

#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
    int act, int grad, float alpha, float scale) {
    CHECK_CUDA(input);
    CHECK_CUDA(bias);

    return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
}

================================================
FILE: op/fused_bias_act_kernel.cu
================================================
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html

#include <torch/types.h>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>

#include <cuda.h>
#include <cuda_runtime.h>


template <typename scalar_t>
static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
    int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
    int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;

    scalar_t zero = 0.0;

    for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
        scalar_t x = p_x[xi];

        if (use_bias) {
            x += p_b[(xi / step_b) % size_b];
        }

        scalar_t ref = use_ref ? p_ref[xi] : zero;

        scalar_t y;

        switch (act * 10 + grad) {
            default:
            case 10: y = x; break;
            case 11: y = x; break;
            case 12: y = 0.0; break;

            case 30: y = (x > 0.0) ? x : x * alpha; break;
            case 31: y = (ref > 0.0) ? x : x * alpha; break;
            case 32: y = 0.0; break;
        }

        out[xi] = y * scale;
    }
}


torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
    int act, int grad, float alpha, float scale) {
    int curDevice = -1;
    cudaGetDevice(&curDevice);
    cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);

    auto x = input.contiguous();
    auto b = bias.contiguous();
    auto ref = refer.contiguous();

    int use_bias = b.numel() ? 1 : 0;
    int use_ref = ref.numel() ? 1 : 0;

    int size_x = x.numel();
    int size_b = b.numel();
    int step_b = 1;

    for (int i = 1 + 1; i < x.dim(); i++) {
        step_b *= x.size(i);
    }

    int loop_x = 4;
    int block_size = 4 * 32;
    int grid_size = (size_x - 1) / (loop_x * block_size) + 1;

    auto y = torch::empty_like(x);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
        fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
            y.data_ptr<scalar_t>(),
            x.data_ptr<scalar_t>(),
            b.data_ptr<scalar_t>(),
            ref.data_ptr<scalar_t>(),
            act,
            grad,
            alpha,
            scale,
            loop_x,
            size_x,
            step_b,
            size_b,
            use_bias,
            use_ref
        );
    });

    return y;
}

================================================
FILE: op/upfirdn2d.cpp
================================================
#include <torch/extension.h>


torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
                            int up_x, int up_y, int down_x, int down_y,
                            int pad_x0, int pad_x1, int pad_y0, int pad_y1);

#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
                        int up_x, int up_y, int down_x, int down_y,
                        int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
    CHECK_CUDA(input);
    CHECK_CUDA(kernel);

    return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
}

================================================
FILE: op/upfirdn2d.py
================================================
import os

import torch
from torch.nn import functional as F
from torch.autograd import Function
from torch.utils.cpp_extension import load
from pdb import set_trace as st


module_path = os.path.dirname(__file__)
upfirdn2d_op = load(
    "upfirdn2d",
    sources=[
        os.path.join(module_path, "upfirdn2d.cpp"),
        os.path.join(module_path, "upfirdn2d_kernel.cu"),
    ],
)


class UpFirDn2dBackward(Function):
    @staticmethod
    def forward(
        ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
    ):

        up_x, up_y = up
        down_x, down_y = down
        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad

        grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)

        grad_input = upfirdn2d_op.upfirdn2d(
            grad_output,
            grad_kernel,
            down_x,
            down_y,
            up_x,
            up_y,
            g_pad_x0,
            g_pad_x1,
            g_pad_y0,
            g_pad_y1,
        )
        grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])

        ctx.save_for_backward(kernel)

        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        ctx.up_x = up_x
        ctx.up_y = up_y
        ctx.down_x = down_x
        ctx.down_y = down_y
        ctx.pad_x0 = pad_x0
        ctx.pad_x1 = pad_x1
        ctx.pad_y0 = pad_y0
        ctx.pad_y1 = pad_y1
        ctx.in_size = in_size
        ctx.out_size = out_size

        return grad_input

    @staticmethod
    def backward(ctx, gradgrad_input):
        kernel, = ctx.saved_tensors

        gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)

        gradgrad_out = upfirdn2d_op.upfirdn2d(
            gradgrad_input,
            kernel,
            ctx.up_x,
            ctx.up_y,
            ctx.down_x,
            ctx.down_y,
            ctx.pad_x0,
            ctx.pad_x1,
            ctx.pad_y0,
            ctx.pad_y1,
        )
        # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
        gradgrad_out = gradgrad_out.view(
            ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
        )

        return gradgrad_out, None, None, None, None, None, None, None, None


class UpFirDn2d(Function):
    @staticmethod
    def forward(ctx, input, kernel, up, down, pad):
        up_x, up_y = up
        down_x, down_y = down
        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        kernel_h, kernel_w = kernel.shape
        batch, channel, in_h, in_w = input.shape
        ctx.in_size = input.shape

        input = input.reshape(-1, in_h, in_w, 1)

        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))

        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
        ctx.out_size = (out_h, out_w)

        ctx.up = (up_x, up_y)
        ctx.down = (down_x, down_y)
        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)

        g_pad_x0 = kernel_w - pad_x0 - 1
        g_pad_y0 = kernel_h - pad_y0 - 1
        g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
        g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1

        ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)

        out = upfirdn2d_op.upfirdn2d(
            input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
        )
        # out = out.view(major, out_h, out_w, minor)
        out = out.view(-1, channel, out_h, out_w)

        return out

    @staticmethod
    def backward(ctx, grad_output):
        kernel, grad_kernel = ctx.saved_tensors

        grad_input = UpFirDn2dBackward.apply(
            grad_output,
            kernel,
            grad_kernel,
            ctx.up,
            ctx.down,
            ctx.pad,
            ctx.g_pad,
            ctx.in_size,
            ctx.out_size,
        )

        return grad_input, None, None, None, None


def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    if input.device.type == "cpu":
        out = upfirdn2d_native(
            input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
        )

    else:
        out = UpFirDn2d.apply(
            input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
        )

    return out


def upfirdn2d_native(
    input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
):
    _, channel, in_h, in_w = input.shape
    input = input.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = input.shape
    kernel_h, kernel_w = kernel.shape

    out = input.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    out = F.pad(
        out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
    )
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape(
        [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
    )
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1

    return out.view(-1, channel, out_h, out_w)


================================================
FILE: op/upfirdn2d_kernel.cu
================================================
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html

#include <torch/types.h>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/CUDAContext.h>

#include <cuda.h>
#include <cuda_runtime.h>

static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
  int c = a / b;

  if (c * b > a) {
    c--;
  }

  return c;
}

struct UpFirDn2DKernelParams {
  int up_x;
  int up_y;
  int down_x;
  int down_y;
  int pad_x0;
  int pad_x1;
  int pad_y0;
  int pad_y1;

  int major_dim;
  int in_h;
  int in_w;
  int minor_dim;
  int kernel_h;
  int kernel_w;
  int out_h;
  int out_w;
  int loop_major;
  int loop_x;
};

template <typename scalar_t>
__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
                                       const scalar_t *kernel,
                                       const UpFirDn2DKernelParams p) {
  int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int out_y = minor_idx / p.minor_dim;
  minor_idx -= out_y * p.minor_dim;
  int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
  int major_idx_base = blockIdx.z * p.loop_major;

  if (out_x_base >= p.out_w || out_y >= p.out_h ||
      major_idx_base >= p.major_dim) {
    return;
  }

  int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
  int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
  int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
  int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;

  for (int loop_major = 0, major_idx = major_idx_base;
       loop_major < p.loop_major && major_idx < p.major_dim;
       loop_major++, major_idx++) {
    for (int loop_x = 0, out_x = out_x_base;
         loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
      int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
      int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
      int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
      int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;

      const scalar_t *x_p =
          &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
                 minor_idx];
      const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
      int x_px = p.minor_dim;
      int k_px = -p.up_x;
      int x_py = p.in_w * p.minor_dim;
      int k_py = -p.up_y * p.kernel_w;

      scalar_t v = 0.0f;

      for (int y = 0; y < h; y++) {
        for (int x = 0; x < w; x++) {
          v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
          x_p += x_px;
          k_p += k_px;
        }

        x_p += x_py - w * x_px;
        k_p += k_py - w * k_px;
      }

      out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
          minor_idx] = v;
    }
  }
}

template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
          int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
                                 const scalar_t *kernel,
                                 const UpFirDn2DKernelParams p) {
  const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
  const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;

  __shared__ volatile float sk[kernel_h][kernel_w];
  __shared__ volatile float sx[tile_in_h][tile_in_w];

  int minor_idx = blockIdx.x;
  int tile_out_y = minor_idx / p.minor_dim;
  minor_idx -= tile_out_y * p.minor_dim;
  tile_out_y *= tile_out_h;
  int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
  int major_idx_base = blockIdx.z * p.loop_major;

  if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
      major_idx_base >= p.major_dim) {
    return;
  }

  for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
       tap_idx += blockDim.x) {
    int ky = tap_idx / kernel_w;
    int kx = tap_idx - ky * kernel_w;
    scalar_t v = 0.0;

    if (kx < p.kernel_w & ky < p.kernel_h) {
      v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
    }

    sk[ky][kx] = v;
  }

  for (int loop_major = 0, major_idx = major_idx_base;
       loop_major < p.loop_major & major_idx < p.major_dim;
       loop_major++, major_idx++) {
    for (int loop_x = 0, tile_out_x = tile_out_x_base;
         loop_x < p.loop_x & tile_out_x < p.out_w;
         loop_x++, tile_out_x += tile_out_w) {
      int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
      int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
      int tile_in_x = floor_div(tile_mid_x, up_x);
      int tile_in_y = floor_div(tile_mid_y, up_y);

      __syncthreads();

      for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
           in_idx += blockDim.x) {
        int rel_in_y = in_idx / tile_in_w;
        int rel_in_x = in_idx - rel_in_y * tile_in_w;
        int in_x = rel_in_x + tile_in_x;
        int in_y = rel_in_y + tile_in_y;

        scalar_t v = 0.0;

        if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
          v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
                        p.minor_dim +
                    minor_idx];
        }

        sx[rel_in_y][rel_in_x] = v;
      }

      __syncthreads();
      for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
           out_idx += blockDim.x) {
        int rel_out_y = out_idx / tile_out_w;
        int rel_out_x = out_idx - rel_out_y * tile_out_w;
        int out_x = rel_out_x + tile_out_x;
        int out_y = rel_out_y + tile_out_y;

        int mid_x = tile_mid_x + rel_out_x * down_x;
        int mid_y = tile_mid_y + rel_out_y * down_y;
        int in_x = floor_div(mid_x, up_x);
        int in_y = floor_div(mid_y, up_y);
        int rel_in_x = in_x - tile_in_x;
        int rel_in_y = in_y - tile_in_y;
        int kernel_x = (in_x + 1) * up_x - mid_x - 1;
        int kernel_y = (in_y + 1) * up_y - mid_y - 1;

        scalar_t v = 0.0;

#pragma unroll
        for (int y = 0; y < kernel_h / up_y; y++)
#pragma unroll
          for (int x = 0; x < kernel_w / up_x; x++)
            v += sx[rel_in_y + y][rel_in_x + x] *
                 sk[kernel_y + y * up_y][kernel_x + x * up_x];

        if (out_x < p.out_w & out_y < p.out_h) {
          out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
              minor_idx] = v;
        }
      }
    }
  }
}

torch::Tensor upfirdn2d_op(const torch::Tensor &input,
                           const torch::Tensor &kernel, int up_x, int up_y,
                           int down_x, int down_y, int pad_x0, int pad_x1,
                           int pad_y0, int pad_y1) {
  int curDevice = -1;
  cudaGetDevice(&curDevice);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);

  UpFirDn2DKernelParams p;

  auto x = input.contiguous();
  auto k = kernel.contiguous();

  p.major_dim = x.size(0);
  p.in_h = x.size(1);
  p.in_w = x.size(2);
  p.minor_dim = x.size(3);
  p.kernel_h = k.size(0);
  p.kernel_w = k.size(1);
  p.up_x = up_x;
  p.up_y = up_y;
  p.down_x = down_x;
  p.down_y = down_y;
  p.pad_x0 = pad_x0;
  p.pad_x1 = pad_x1;
  p.pad_y0 = pad_y0;
  p.pad_y1 = pad_y1;

  p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
            p.down_y;
  p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
            p.down_x;

  auto out =
      at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());

  int mode = -1;

  int tile_out_h = -1;
  int tile_out_w = -1;

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 1;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 3 && p.kernel_w <= 3) {
    mode = 2;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 3;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 2 && p.kernel_w <= 2) {
    mode = 4;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 5;
    tile_out_h = 8;
    tile_out_w = 32;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
      p.kernel_h <= 2 && p.kernel_w <= 2) {
    mode = 6;
    tile_out_h = 8;
    tile_out_w = 32;
  }

  dim3 block_size;
  dim3 grid_size;

  if (tile_out_h > 0 && tile_out_w > 0) {
    p.loop_major = (p.major_dim - 1) / 16384 + 1;
    p.loop_x = 1;
    block_size = dim3(32 * 8, 1, 1);
    grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
                     (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
                     (p.major_dim - 1) / p.loop_major + 1);
  } else {
    p.loop_major = (p.major_dim - 1) / 16384 + 1;
    p.loop_x = 4;
    block_size = dim3(4, 32, 1);
    grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
                     (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
                     (p.major_dim - 1) / p.loop_major + 1);
  }

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
    switch (mode) {
    case 1:
      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 2:
      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 3:
      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 4:
      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 5:
      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 6:
      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    default:
      upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
          out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
          k.data_ptr<scalar_t>(), p);
    }
  });

  return out;
}

================================================
FILE: options.py
================================================
import configargparse
from munch import *
from pdb import set_trace as st

class BaseOptions():
    def __init__(self):
        self.parser = configargparse.ArgumentParser()
        self.initialized = False

    def initialize(self):
        # Dataset options
        dataset = self.parser.add_argument_group('dataset')
        dataset.add_argument("--dataset_path", type=str, default='./datasets/FFHQ', help="path to the lmdb dataset")

        # Experiment Options
        experiment = self.parser.add_argument_group('experiment')
        experiment.add_argument('--config', is_config_file=True, help='config file path')
        experiment.add_argument("--expname", type=str, default='debug', help='experiment name')
        experiment.add_argument("--ckpt", type=str, default='300000', help="path to the checkpoints to resume training")
        experiment.add_argument("--continue_training", action="store_true", help="continue training the model")

        # Training loop options
        training = self.parser.add_argument_group('training')
        training.add_argument("--checkpoints_dir", type=str, default='./checkpoint', help='checkpoints directory name')
        training.add_argument("--iter", type=int, default=300000, help="total number of training iterations")
        training.add_argument("--batch", type=int, default=4, help="batch sizes for each GPU. A single RTX2080 can fit batch=4, chunck=1 into memory.")
        training.add_argument("--chunk", type=int, default=4, help='number of samples within a batch to processed in parallel, decrease if running out of memory')
        training.add_argument("--val_n_sample", type=int, default=8, help="number of test samples generated during training")
        training.add_argument("--d_reg_every", type=int, default=16, help="interval for applying r1 regularization to the StyleGAN generator")
        training.add_argument("--g_reg_every", type=int, default=4, help="interval for applying path length regularization to the StyleGAN generator")
        training.add_argument("--local_rank", type=int, default=0, help="local rank for distributed training")
        training.add_argument("--mixing", type=float, default=0.9, help="probability of latent code mixing")
        training.add_argument("--lr", type=float, default=0.002, help="learning rate")
        training.add_argument("--r1", type=float, default=10, help="weight of the r1 regularization")
        training.add_argument("--view_lambda", type=float, default=15, help="weight of the viewpoint regularization")
        training.add_argument("--eikonal_lambda", type=float, default=0.1, help="weight of the eikonal regularization")
        training.add_argument("--min_surf_lambda", type=float, default=0.05, help="weight of the minimal surface regularization")
        training.add_argument("--min_surf_beta", type=float, default=100.0, help="weight of the minimal surface regularization")
        training.add_argument("--path_regularize", type=float, default=2, help="weight of the path length regularization")
        training.add_argument("--path_batch_shrink", type=int, default=2, help="batch size reducing factor for the path length regularization (reduce memory consumption)")
        training.add_argument("--wandb", action="store_true", help="use weights and biases logging")
        training.add_argument("--no_sphere_init", action="store_true", help="do not initialize the volume renderer with a sphere SDF")

        # Inference Options
        inference = self.parser.add_argument_group('inference')
        inference.add_argument("--results_dir", type=str, default='./evaluations', help='results/evaluations directory name')
        inference.add_argument("--truncation_ratio", type=float, default=0.5, help="truncation ratio, controls the diversity vs. quality tradeoff. Higher truncation ratio would generate more diverse results")
        inference.add_argument("--truncation_mean", type=int, default=10000, help="number of vectors to calculate mean for the truncation")
        inference.add_argument("--identities", type=int, default=16, help="number of identities to be generated")
        inference.add_argument("--num_views_per_id", type=int, default=1, help="number of viewpoints generated per identity")
        inference.add_argument("--no_surface_renderings", action="store_true", help="when true, only RGB outputs will be generated. otherwise, both RGB and depth videos/renderings will be generated. this cuts the processing time per video")
        inference.add_argument("--fixed_camera_angles", action="store_true", help="when true, the generator will render indentities from a fixed set of camera angles.")
        inference.add_argument("--azim_video", action="store_true", help="when true, the camera trajectory will travel along the azimuth direction. Otherwise, the camera will travel along an ellipsoid trajectory.")

        # Generator options
        model = self.parser.add_argument_group('model')
        model.add_argument("--size", type=int, default=256, help="image sizes for the model")
        model.add_argument("--style_dim", type=int, default=256, help="number of style input dimensions")
        model.add_argument("--channel_multiplier", type=int, default=2, help="channel multiplier factor for the StyleGAN decoder. config-f = 2, else = 1")
        model.add_argument("--n_mlp", type=int, default=8, help="number of mlp layers in stylegan's mapping network")
        model.add_argument("--lr_mapping", type=float, default=0.01, help='learning rate reduction for mapping network MLP layers')
        model.add_argument("--renderer_spatial_output_dim", type=int, default=64, help='spatial resolution of the StyleGAN decoder inputs')
        model.add_argument("--project_noise", action='store_true', help='when true, use geometry-aware noise projection to reduce flickering effects (see supplementary section C.1 in the paper). warning: processing time significantly increases with this flag to ~20 minutes per video.')

        # Camera options
        camera = self.parser.add_argument_group('camera')
        camera.add_argument("--uniform", action="store_true", help="when true, the camera position is sampled from uniform distribution. Gaussian distribution is the default")
        camera.add_argument("--azim", type=float, default=0.3, help="camera azimuth angle std/range in Radians")
        camera.add_argument("--elev", type=float, default=0.15, help="camera elevation angle std/range in Radians")
        camera.add_argument("--fov", type=float, default=6, help="camera field of view half angle in Degrees")
        camera.add_argument("--dist_radius", type=float, default=0.12, help="radius of points sampling distance from the origin. determines the near and far fields")

        # Volume Renderer options
        rendering = self.parser.add_argument_group('rendering')
        # MLP model parameters
        rendering.add_argument("--depth", type=int, default=8, help='layers in network')
        rendering.add_argument("--width", type=int, default=256, help='channels per layer')
        # Volume representation options
        rendering.add_argument("--no_sdf", action='store_true', help='By default, the raw MLP outputs represent an underline signed distance field (SDF). When true, the MLP outputs represent the traditional NeRF density field.')
        rendering.add_argument("--no_z_normalize", action='store_true', help='By default, the model normalizes input coordinates such that the z coordinate is in [-1,1]. When true that feature is disabled.')
        rendering.add_argument("--static_viewdirs", action='store_true', help='when true, use static viewing direction input to the MLP')
        # Ray intergration options
        rendering.add_argument("--N_samples", type=int, default=24, help='number of samples per ray')
        rendering.add_argument("--no_offset_sampling", action='store_true', help='when true, use random stratified sampling when rendering the volume, otherwise offset sampling is used. (See Equation (3) in Sec. 3.2 of the paper)')
        rendering.add_argument("--perturb", type=float, default=1., help='set to 0. for no jitter, 1. for jitter')
        rendering.add_argument("--raw_noise_std", type=float, default=0., help='std dev of noise added to regularize sigma_a output, 1e0 recommended')
        rendering.add_argument("--force_background", action='store_true', help='force the last depth sample to act as background in case of a transparent ray')
        # Set volume renderer outputs
        rendering.add_argument("--return_xyz", action='store_true', help='when true, the volume renderer also returns the xyz point could of the surface. This point cloud is used to produce depth map renderings')
        rendering.add_argument("--return_sdf", action='store_true', help='when true, the volume renderer also returns the SDF network outputs for each location in the volume')

        self.initialized = True

    def parse(self):
        self.opt = Munch()
        if not self.initialized:
            self.initialize()
        try:
            args = self.parser.parse_args()
        except: # solves argparse error in google colab
            args = self.parser.parse_args(args=[])

        for group in self.parser._action_groups[2:]:
            title = group.title
            self.opt[title] = Munch()
            for action in group._group_actions:
                dest = action.dest
                self.opt[title][dest] = args.__getattribute__(dest)

        return self.opt


================================================
FILE: prepare_data.py
================================================
import argparse
from io import BytesIO
import multiprocessing
from functools import partial

from PIL import Image
import lmdb
from tqdm import tqdm
from torchvision import datasets
from torchvision.transforms import functional as trans_fn
from pdb import set_trace as st


def resize_and_convert(img, size, resample):
    img = trans_fn.resize(img, size, resample)
    img = trans_fn.center_crop(img, size)
    buffer = BytesIO()
    img.save(buffer, format="png")
    val = buffer.getvalue()

    return val


def resize_multiple(
    img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS):
    imgs = []

    for size in sizes:
        imgs.append(resize_and_convert(img, size, resample))

    return imgs


def resize_worker(img_file, sizes, resample):
    i, file = img_file
    img = Image.open(file)
    img = img.convert("RGB")
    out = resize_multiple(img, sizes=sizes, resample=resample)

    return i, out


def prepare(
    env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS
):
    resize_fn = partial(resize_worker, sizes=sizes, resample=resample)

    files = sorted(dataset.imgs, key=lambda x: x[0])
    files = [(i, file) for i, (file, label) in enumerate(files)]
    total = 0

    with multiprocessing.Pool(n_worker) as pool:
        for i, imgs in tqdm(pool.imap_unordered(resize_fn, files)):
            for size, img in zip(sizes, imgs):
                key = f"{size}-{str(i).zfill(5)}".encode("utf-8")

                with env.begin(write=True) as txn:
                    txn.put(key, img)

            total += 1

        with env.begin(write=True) as txn:
            txn.put("length".encode("utf-8"), str(total).encode("utf-8"))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Preprocess images for model training")
    parser.add_argument("--size", type=str, default="64,512,1024", help="resolutions of images for the dataset")
    parser.add_argument("--n_worker", type=int, default=32, help="number of workers for preparing dataset")
    parser.add_argument("--resample", type=str, default="lanczos", help="resampling methods for resizing images")
    parser.add_argument("--out_path", type=str, default="datasets/FFHQ", help="Target path of the output lmdb dataset")
    parser.add_argument("in_path", type=str, help="path to the input image dataset")
    args = parser.parse_args()

    resample_map = {"lanczos": Image.LANCZOS, "bilinear": Image.BILINEAR}
    resample = resample_map[args.resample]

    sizes = [int(s.strip()) for s in args.size.split(",")]

    print(f"Make dataset of image sizes:", ", ".join(str(s) for s in sizes))

    imgset = datasets.ImageFolder(args.in_path)

    with lmdb.open(args.out_path, map_size=1024 ** 4, readahead=False) as env:
        prepare(env, imgset, args.n_worker, sizes=sizes, resample=resample)


================================================
FILE: render_video.py
================================================
import os
import torch
import trimesh
import numpy as np
import skvideo.io
from munch import *
from PIL import Image
from tqdm import tqdm
from torch.nn import functional as F
from torch.utils import data
from torchvision import utils
from torchvision import transforms
from options import BaseOptions
from model import Generator
from utils import (
    generate_camera_params, align_volume, extract_mesh_with_marching_cubes,
    xyz2mesh, create_cameras, create_mesh_renderer, add_textures,
    )
from pytorch3d.structures import Meshes

from pdb import set_trace as st

torch.random.manual_seed(1234)


def render_video(opt, g_ema, surface_g_ema, device, mean_latent, surface_mean_latent):
    g_ema.eval()
    if not opt.no_surface_renderings or opt.project_noise:
        surface_g_ema.eval()

    images = torch.Tensor(0, 3, opt.size, opt.size)
    num_frames = 250
    # Generate video trajectory
    trajectory = np.zeros((num_frames,3), dtype=np.float32)

    # set camera trajectory
    # sweep azimuth angles (4 seconds)
    if opt.azim_video:
        t = np.linspace(0, 1, num_frames)
        elev = 0
        fov = opt.camera.fov
        if opt.camera.uniform:
            azim = opt.camera.azim * np.cos(t * 2 * np.pi)
        else:
            azim = 1.5 * opt.camera.azim * np.cos(t * 2 * np.pi)

        trajectory[:num_frames,0] = azim
        trajectory[:num_frames,1] = elev
        trajectory[:num_frames,2] = fov

    # elipsoid sweep (4 seconds)
    else:
        t = np.linspace(0, 1, num_frames)
        fov = opt.camera.fov #+ 1 * np.sin(t * 2 * np.pi)
        if opt.camera.uniform:
            elev = opt.camera.elev / 2 + opt.camera.elev / 2  * np.sin(t * 2 * np.pi)
            azim = opt.camera.azim  * np.cos(t * 2 * np.pi)
        else:
            elev = 1.5 * opt.camera.elev * np.sin(t * 2 * np.pi)
            azim = 1.5 * opt.camera.azim * np.cos(t * 2 * np.pi)

        trajectory[:num_frames,0] = azim
        trajectory[:num_frames,1] = elev
        trajectory[:num_frames,2] = fov

    trajectory = torch.from_numpy(trajectory).to(device)

    # generate input parameters for the camera trajectory
    # sample_cam_poses, sample_focals, sample_near, sample_far = \
    # generate_camera_params(trajectory, opt.renderer_output_size, device, dist_radius=opt.camera.dist_radius)
    sample_cam_extrinsics, sample_focals, sample_near, sample_far, _ = \
    generate_camera_params(opt.renderer_output_size, device, locations=trajectory[:,:2],
                           fov_ang=trajectory[:,2:], dist_radius=opt.camera.dist_radius)

    # In case of noise projection, generate input parameters for the frontal position.
    # The reference mesh for the noise projection is extracted from the frontal position.
    # For more details see section C.1 in the supplementary material.
    if opt.project_noise:
        frontal_pose = torch.tensor([[0.0,0.0,opt.camera.fov]]).to(device)
        # frontal_cam_pose, frontal_focals, frontal_near, frontal_far = \
        # generate_camera_params(frontal_pose, opt.surf_extraction_output_size, device, dist_radius=opt.camera.dist_radius)
        frontal_cam_pose, frontal_focals, frontal_near, frontal_far, _ = \
        generate_camera_params(opt.surf_extraction_output_size, device, location=frontal_pose[:,:2],
                               fov_ang=frontal_pose[:,2:], dist_radius=opt.camera.dist_radius)

    # create geometry renderer (renders the depth maps)
    cameras = create_cameras(azim=np.rad2deg(trajectory[0,0].cpu().numpy()),
                             elev=np.rad2deg(trajectory[0,1].cpu().numpy()),
                             dist=1, device=device)
    renderer = create_mesh_renderer(cameras, image_size=512, specular_color=((0,0,0),),
                    ambient_color=((0.1,.1,.1),), diffuse_color=((0.75,.75,.75),),
                    device=device)

    suffix = '_azim' if opt.azim_video else '_elipsoid'

    # generate videos
    for i in range(opt.identities):
        print('Processing identity {}/{}...'.format(i+1, opt.identities))
        chunk = 1
        sample_z = torch.randn(1, opt.style_dim, device=device).repeat(chunk,1)
        video_filename = 'sample_video_{}{}.mp4'.format(i,suffix)
        writer = skvideo.io.FFmpegWriter(os.path.join(opt.results_dst_dir, video_filename),
                                         outputdict={'-pix_fmt': 'yuv420p', '-crf': '10'})
        if not opt.no_surface_renderings:
            depth_video_filename = 'sample_depth_video_{}{}.mp4'.format(i,suffix)
            depth_writer = skvideo.io.FFmpegWriter(os.path.join(opt.results_dst_dir, depth_video_filename),
                                             outputdict={'-pix_fmt': 'yuv420p', '-crf': '1'})


        ####################### Extract initial surface mesh from the frontal viewpoint #############
        # For more details see section C.1 in the supplementary material.
        if opt.project_noise:
            with torch.no_grad():
                frontal_surface_out = surface_g_ema([sample_z],
                                                    frontal_cam_pose,
                                                    frontal_focals,
                                                    frontal_near,
                                                    frontal_far,
                                                    truncation=opt.truncation_ratio,
                                                    truncation_latent=surface_mean_latent,
                                                    return_sdf=True)
                frontal_sdf = frontal_surface_out[2].cpu()

            print('Extracting Identity {} Frontal view Marching Cubes for consistent video rendering'.format(i))

            frostum_aligned_frontal_sdf = align_volume(frontal_sdf)
            del frontal_sdf

            try:
                frontal_marching_cubes_mesh = extract_mesh_with_marching_cubes(frostum_aligned_frontal_sdf)
            except ValueError:
                frontal_marching_cubes_mesh = None

            if frontal_marching_cubes_mesh != None:
                frontal_marching_cubes_mesh_filename = os.path.join(opt.results_dst_dir,'sample_{}_frontal_marching_cubes_mesh{}.obj'.format(i,suffix))
                with open(frontal_marching_cubes_mesh_filename, 'w') as f:
                    frontal_marching_cubes_mesh.export(f,file_type='obj')

            del frontal_surface_out
            torch.cuda.empty_cache()
        #############################################################################################

        for j in tqdm(range(0, num_frames, chunk)):
            with torch.no_grad():
                out = g_ema([sample_z],
                            sample_cam_extrinsics[j:j+chunk],
                            sample_focals[j:j+chunk],
                            sample_near[j:j+chunk],
                            sample_far[j:j+chunk],
                            truncation=opt.truncation_ratio,
                            truncation_latent=mean_latent,
                            randomize_noise=False,
                            project_noise=opt.project_noise,
                            mesh_path=frontal_marching_cubes_mesh_filename if opt.project_noise else None)

                rgb = out[0].cpu()

                # this is done to fit to RTX2080 RAM size (11GB)
                del out
                torch.cuda.empty_cache()

                # Convert RGB from [-1, 1] to [0,255]
                rgb = 127.5 * (rgb.clamp(-1,1).permute(0,2,3,1).cpu().numpy() + 1)

                # Add RGB, frame to video
                for k in range(chunk):
                    writer.writeFrame(rgb[k])

                ########## Extract surface ##########
                if not opt.no_surface_renderings:
                    scale = surface_g_ema.renderer.out_im_res / g_ema.renderer.out_im_res
                    surface_sample_focals = sample_focals * scale
                    surface_out = surface_g_ema([sample_z],
                                                sample_cam_extrinsics[j:j+chunk],
                                                surface_sample_focals[j:j+chunk],
                                                sample_near[j:j+chunk],
                                                sample_far[j:j+chunk],
                                                truncation=opt.truncation_ratio,
                                                truncation_latent=surface_mean_latent,
                                                return_xyz=True)
                    xyz = surface_out[2].cpu()

                    # this is done to fit to RTX2080 RAM size (11GB)
                    del surface_out
                    torch.cuda.empty_cache()

                    # Render mesh for video
                    depth_mesh = xyz2mesh(xyz)
                    mesh = Meshes(
                        verts=[torch.from_numpy(np.asarray(depth_mesh.vertices)).to(torch.float32).to(device)],
                        faces = [torch.from_numpy(np.asarray(depth_mesh.faces)).to(torch.float32).to(device)],
                        textures=None,
                        verts_normals=[torch.from_numpy(np.copy(np.asarray(depth_mesh.vertex_normals))).to(torch.float32).to(device)],
                    )
                    mesh = add_textures(mesh)
                    cameras = create_cameras(azim=np.rad2deg(trajectory[j,0].cpu().numpy()),
                                             elev=np.rad2deg(trajectory[j,1].cpu().numpy()),
                                             fov=2*trajectory[j,2].cpu().numpy(),
                                             dist=1, device=device)
                    renderer = create_mesh_renderer(cameras, image_size=512,
                                                    light_location=((0.0,1.0,5.0),), specular_color=((0.2,0.2,0.2),),
                                                    ambient_color=((0.1,0.1,0.1),), diffuse_color=((0.65,.65,.65),),
                                                    device=device)

                    mesh_image = 255 * renderer(mesh).cpu().numpy()
                    mesh_image = mesh_image[...,:3]

                    # Add depth frame to video
                    for k in range(chunk):
                        depth_writer.writeFrame(mesh_image[k])

        # Close video writers
        writer.close()
        if not opt.no_surface_renderings:
            depth_writer.close()


if __name__ == "__main__":
    device = "cuda"
    opt = BaseOptions().parse()
    opt.model.is_test = True
    opt.model.style_dim = 256
    opt.model.freeze_renderer = False
    opt.inference.size = opt.model.size
    opt.inference.camera = opt.camera
    opt.inference.renderer_output_size = opt.model.renderer_spatial_output_dim
    opt.inference.style_dim = opt.model.style_dim
    opt.inference.project_noise = opt.model.project_noise
    opt.rendering.perturb = 0
    opt.rendering.force_background = True
    opt.rendering.static_viewdirs = True
    opt.rendering.return_sdf = True
    opt.rendering.N_samples = 64

    # find checkpoint directory
    # check if there's a fully trained model
    checkpoints_dir = 'full_models'
    checkpoint_path = os.path.join(checkpoints_dir, opt.experiment.expname + '.pt')
    if os.path.isfile(checkpoint_path):
        # define results directory name
        result_model_dir = 'final_model'
    else:
        checkpoints_dir = os.path.join('checkpoint', opt.experiment.expname, 'full_pipeline')
        checkpoint_path = os.path.join(checkpoints_dir,
                                       'models_{}.pt'.format(opt.experiment.ckpt.zfill(7)))
        # define results directory name
        result_model_dir = 'iter_{}'.format(opt.experiment.ckpt.zfill(7))

    results_dir_basename = os.path.join(opt.inference.results_dir, opt.experiment.expname)
    opt.inference.results_dst_dir = os.path.join(results_dir_basename, result_model_dir, 'videos')
    if opt.model.project_noise:
        opt.inference.results_dst_dir = os.path.join(opt.inference.results_dst_dir, 'with_noise_projection')

    os.makedirs(opt.inference.results_dst_dir, exist_ok=True)

    # load saved model
    checkpoint = torch.load(checkpoint_path)

    # load image generation model
    g_ema = Generator(opt.model, opt.rendering).to(device)

    # temp fix because of wrong noise sizes
    pretrained_weights_dict = checkpoint["g_ema"]
    model_dict = g_ema.state_dict()
    for k, v in pretrained_weights_dict.items():
        if v.size() == model_dict[k].size():
            model_dict[k] = v

    g_ema.load_state_dict(model_dict)

    # load a the volume renderee to a second that extracts surfaces at 128x128x128
    if not opt.inference.no_surface_renderings or opt.model.project_noise:
        opt['surf_extraction'] = Munch()
        opt.surf_extraction.rendering = opt.rendering
        opt.surf_extraction.model = opt.model.copy()
        opt.surf_extraction.model.renderer_spatial_output_dim = 128
        opt.surf_extraction.rendering.N_samples = opt.surf_extraction.model.renderer_spatial_output_dim
        opt.surf_extraction.rendering.return_xyz = True
        opt.surf_extraction.rendering.return_sdf = True
        opt.inference.surf_extraction_output_size = opt.surf_extraction.model.renderer_spatial_output_dim
        surface_g_ema = Generator(opt.surf_extraction.model, opt.surf_extraction.rendering, full_pipeline=False).to(device)

        # Load weights to surface extractor
        surface_extractor_dict = surface_g_ema.state_dict()
        for k, v in pretrained_weights_dict.items():
            if k in surface_extractor_dict.keys() and v.size() == surface_extractor_dict[k].size():
                surface_extractor_dict[k] = v

        surface_g_ema.load_state_dict(surface_extractor_dict)
    else:
        surface_g_ema = None

    # get the mean latent vector for g_ema
    if opt.inference.truncation_ratio < 1:
        with torch.no_grad():
            mean_latent = g_ema.mean_latent(opt.inference.truncation_mean, device)
    else:
        mean_latent = None

    # get the mean latent vector for surface_g_ema
    if not opt.inference.no_surface_renderings or opt.model.project_noise:
        surface_mean_latent = mean_latent[0]
    else:
        surface_mean_latent = None

    render_video(opt.inference, g_ema, surface_g_ema, device, mean_latent, surface_mean_latent)


================================================
FILE: requirements.txt
================================================
lmdb
numpy
ninja
pillow
requests
tqdm
scipy
scikit-image
scikit-video
trimesh[easy]
configargparse
munch
wandb


================================================
FILE: scripts/train_afhq_full_pipeline_512x512.sh
================================================
python -m torch.distributed.launch --nproc_per_node 4 new_train.py --batch 8 --chunk 4 --azim 0.15 --r1 50.0 --expname afhq512x512 --dataset_path ./datasets/AFHQ/train/ --size 512 --wandb


================================================
FILE: scripts/train_afhq_vol_renderer.sh
================================================
python -m torch.distributed.launch --nproc_per_node 2 train_volume_renderer.py --batch 12 --chunk 6 --expname afhq_sdf_vol_renderer --dataset_path ./datasets/AFHQ/train --azim 0.15 --wandb


================================================
FILE: scripts/train_ffhq_full_pipeline_1024x1024.sh
================================================
python -m torch.distributed.launch --nproc_per_node 4 train_full_pipeline.py --batch 8 --chunk 2 --expname ffhq1024x1024 --size 1024 --wandb


================================================
FILE: scripts/train_ffhq_vol_renderer.sh
================================================
python -m torch.distributed.launch --nproc_per_node 2 train_volume_renderer.py --batch 12 --chunk 6 --expname ffhq_sdf_vol_renderer --dataset_path ./datasets/FFHQ/ --wandb


================================================
FILE: train_full_pipeline.py
================================================
import argparse
import math
import random
import os
import yaml
import numpy as np
import torch
from torch import nn, autograd, optim
from torch.nn import functional as F
from torch.utils import data
import torch.distributed as dist
from torchvision import transforms, utils
from tqdm import tqdm
from PIL import Image
from losses import *
from options import BaseOptions
from model import Generator, Discriminator
from dataset import MultiResolutionDataset
from utils import data_sampler, requires_grad, accumulate, sample_data, make_noise, mixing_noise, generate_camera_params
from distributed import get_rank, synchronize, reduce_loss_dict, reduce_sum, get_world_size

try:
    import wandb
except ImportError:
    wandb = None


def train(opt, experiment_opt, loader, generator, discriminator, g_optim, d_optim, g_ema, device):
    loader = sample_data(loader)

    pbar = range(opt.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar, initial=opt.start_iter, dynamic_ncols=True, smoothing=0.01)

    mean_path_length = 0

    d_loss_val = 0
    g_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_gan_loss = torch.tensor(0.0, device=device)
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if opt.distributed:
        g_module = generator.module
        d_module = discriminator.module
    else:
        g_module = generator
        d_module = discriminator

    accum = 0.5 ** (32 / (10 * 1000))

    sample_z = [torch.randn(opt.val_n_sample, opt.style_dim, device=device).repeat_interleave(8, dim=0)]
    sample_cam_extrinsics, sample_focals, sample_near, sample_far, _ = generate_camera_params(opt.renderer_output_size, device, batch=opt.val_n_sample, sweep=True,
                                                                                         uniform=opt.camera.uniform, azim_range=opt.camera.azim,
                                                                                         elev_range=opt.camera.elev, fov_ang=opt.camera.fov,
                                                                                         dist_radius=opt.camera.dist_radius)

    for idx in pbar:
        i = idx + opt.start_iter

        if i > opt.iter:
            print("Done!")

            break

        requires_grad(generator, False)
        requires_grad(discriminator, True)
        discriminator.zero_grad()
        d_regularize = i % opt.d_reg_every == 0
        real_imgs, real_thumb_imgs = next(loader)
        real_imgs = real_imgs.to(device)
        real_thumb_imgs = real_thumb_imgs.to(device)
        noise = mixing_noise(opt.batch, opt.style_dim, opt.mixing, device)
        cam_extrinsics, focal, near, far, gt_viewpoints = generate_camera_params(opt.renderer_output_size, device, batch=opt.batch,
                                                                            uniform=opt.camera.uniform, azim_range=opt.camera.azim,
                                                                            elev_range=opt.camera.elev, fov_ang=opt.camera.fov,
                                                                            dist_radius=opt.camera.dist_radius)

        for j in range(0, opt.batch, opt.chunk):
            curr_real_imgs = real_imgs[j:j+opt.chunk]
            curr_real_thumb_imgs = real_thumb_imgs[j:j+opt.chunk]
            curr_noise = [n[j:j+opt.chunk] for n in noise]
            gen_imgs, _ = generator(curr_noise,
                                    cam_extrinsics[j:j+opt.chunk],
                                    focal[j:j+opt.chunk],
                                    near[j:j+opt.chunk],
                                    far[j:j+opt.chunk])

            fake_pred = discriminator(gen_imgs.detach())

            if d_regularize:
                curr_real_imgs.requires_grad = True
                curr_real_thumb_imgs.requires_grad = True

            real_pred = discriminator(curr_real_imgs)
            d_gan_loss = d_logistic_loss(real_pred, fake_pred)

            if d_regularize:
                grad_penalty = d_r1_loss(real_pred, curr_real_imgs)
                r1_loss = opt.r1 * 0.5 * grad_penalty * opt.d_reg_every
            else:
                r1_loss = torch.zeros_like(r1_loss)

            d_loss = d_gan_loss + r1_loss
            d_loss.backward()

        d_optim.step()

        loss_dict["d"] = d_gan_loss
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()
        if d_regularize or i == opt.start_iter:
            loss_dict["r1"] = r1_loss.mean()

        requires_grad(generator, True)
        requires_grad(discriminator, False)

        for j in range(0, opt.batch, opt.chunk):
            noise = mixing_noise(opt.chunk, opt.style_dim, opt.mixing, device)
            cam_extrinsics, focal, near, far, gt_viewpoints = generate_camera_params(opt.renderer_output_size, device, batch=opt.chunk,
                                                                                uniform=opt.camera.uniform, azim_range=opt.camera.azim,
                                                                                elev_range=opt.camera.elev, fov_ang=opt.camera.fov,
                                                                                dist_radius=opt.camera.dist_radius)

            fake_img, _ = generator(noise, cam_extrinsics, focal, near, far)
            fake_pred = discriminator(fake_img)
            g_gan_loss = g_nonsaturating_loss(fake_pred)

            g_loss = g_gan_loss
            g_loss.backward()


        g_optim.step()
        generator.zero_grad()

        loss_dict["g"] = g_gan_loss

        # generator path regularization
        g_regularize = (opt.g_reg_every > 0) and (i % opt.g_reg_every == 0)
        if g_regularize:
            path_batch_size = max(1, opt.batch // opt.path_batch_shrink)
            path_noise = mixing_noise(path_batch_size, opt.style_dim, opt.mixing, device)
            path_cam_extrinsics, path_focal, path_near, path_far, _ = generate_camera_params(opt.renderer_output_size, device, batch=path_batch_size,
                                                                                        uniform=opt.camera.uniform, azim_range=opt.camera.azim,
                                                                                        elev_range=opt.camera.elev, fov_ang=opt.camera.fov,
                                                                                        dist_radius=opt.camera.dist_radius)

            for j in range(0, path_batch_size, opt.chunk):
                path_fake_img, path_latents = generator(path_noise, path_cam_extrinsics,
                                                        path_focal, path_near, path_far,
                                                        return_latents=True)

                path_loss, mean_path_length, path_lengths = g_path_regularize(
                    path_fake_img, path_latents, mean_path_length
                )

                weighted_path_loss = opt.path_regularize * opt.g_reg_every * path_loss# * opt.chunk / path_batch_size
                if opt.path_batch_shrink:
                    weighted_path_loss += 0 * path_fake_img[0, 0, 0, 0]

                weighted_path_loss.backward()

            g_optim.step()
            generator.zero_grad()

            mean_path_length_avg = (reduce_sum(mean_path_length).item() / get_world_size())

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(g_ema, g_module, accum)
        loss_reduced = reduce_loss_dict(loss_dict)
        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description(
                (f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; path: {path_loss_val:.4f}")
            )

            if i % 1000 == 0 or i == opt.start_iter:
                with torch.no_grad():
                    thumbs_samples = torch.Tensor(0, 3, opt.renderer_output_size, opt.renderer_output_size)
                    samples = torch.Tensor(0, 3, opt.size, opt.size)
                    step_size = 8
                    mean_latent = g_module.mean_latent(10000, device)
                    for k in range(0, opt.val_n_sample * 8, step_size):
                        curr_samples, curr_thumbs = g_ema([sample_z[0][k:k+step_size]],
                                                           sample_cam_extrinsics[k:k+step_size],
                                                           sample_focals[k:k+step_size],
                                                           sample_near[k:k+step_size],
                                                           sample_far[k:k+step_size],
                                                           truncation=0.7,
                                                           truncation_latent=mean_latent)
                        samples = torch.cat([samples, curr_samples.cpu()], 0)
                        thumbs_samples = torch.cat([thumbs_samples, curr_thumbs.cpu()], 0)

                    if i % 10000 == 0:
                        utils.save_image(samples,
                            os.path.join(opt.checkpoints_dir, experiment_opt.expname, 'full_pipeline', f"samples/{str(i).zfill(7)}.png"),
                            nrow=int(opt.val_n_sample),
                            normalize=True,
                            value_range=(-1, 1),)

                        utils.save_image(thumbs_samples,
                            os.path.join(opt.checkpoints_dir, experiment_opt.expname, 'full_pipeline', f"samples/{str(i).zfill(7)}_thumbs.png"),
                            nrow=int(opt.val_n_sample),
                            normalize=True,
                            value_range=(-1, 1),)

            if wandb and opt.wandb:
                wandb_log_dict = {"Generator": g_loss_val,
                                  "Discriminator": d_loss_val,
                                  "R1": r1_val,
                                  "Real Score": real_score_val,
                                  "Fake Score": fake_score_val,
                                  "Path Length Regularization": path_loss_val,
                                  "Path Length": path_length_val,
                                  "Mean Path Length": mean_path_length,
                                  }
                if i % 5000 == 0:
                    wandb_grid = utils.make_grid(samples, nrow=int(opt.val_n_sample),
                                                   normalize=True, value_range=(-1, 1))
                    wandb_ndarr = (255 * wandb_grid.permute(1, 2, 0).numpy()).astype(np.uint8)
                    wandb_images = Image.fromarray(wandb_ndarr)
                    wandb_log_dict.update({"examples": [wandb.Image(wandb_images,
                                            caption="Generated samples for azimuth angles of: -0.35, -0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35 Radians.")]})

                    wandb_thumbs_grid = utils.make_grid(thumbs_samples, nrow=int(opt.val_n_sample),
                                                        normalize=True, value_range=(-1, 1))
                    wandb_thumbs_ndarr = (255 * wandb_thumbs_grid.permute(1, 2, 0).numpy()).astype(np.uint8)
                    wandb_thumbs = Image.fromarray(wandb_thumbs_ndarr)
                    wandb_log_dict.update({"thumb_examples": [wandb.Image(wandb_thumbs,
                                            caption="Generated samples for azimuth angles of: -0.35, -0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35 Radians.")]})

                wandb.log(wandb_log_dict)

            if i % 10000 == 0:
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                    },
                    os.path.join(opt.checkpoints_dir, experiment_opt.expname, 'full_pipeline', f"models_{str(i).zfill(7)}.pt")
                )
                print('Successfully saved checkpoint for iteration {}.'.format(i))

    if get_rank() == 0:
        # create final model directory
        final_model_path = os.path.join('full_models', opt.experiment.expname)
        os.makedirs(final_model_path, exist_ok=True)
        torch.save(
            {
                "g": g_module.state_dict(),
                "d": d_module.state_dict(),
                "g_ema": g_ema.state_dict(),
            },
            os.path.join(final_model_path, experiment_opt.expname + '.pt')
        )
        print('Successfully saved final model.')


if __name__ == "__main__":
    device = "cuda"
    opt = BaseOptions().parse()
    opt.training.camera = opt.camera
    opt.training.size = opt.model.size
    opt.training.renderer_output_size = opt.model.renderer_spatial_output_dim
    opt.training.style_dim = opt.model.style_dim
    opt.model.freeze_renderer = True

    n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    opt.training.distributed = n_gpu > 1

    if opt.training.distributed:
        torch.cuda.set_device(opt.training.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
        synchronize()

    # create checkpoints directories
    os.makedirs(os.path.join(opt.training.checkpoints_dir, opt.experiment.expname, 'full_pipeline'), exist_ok=True)
    os.makedirs(os.path.join(opt.training.checkpoints_dir, opt.experiment.expname, 'full_pipeline', 'samples'), exist_ok=True)

    discriminator = Discriminator(opt.model).to(device)
    generator = Generator(opt.model, opt.rendering).to(device)
    g_ema = Generator(opt.model, opt.rendering, ema=True).to(device)
    g_ema.eval()

    g_reg_ratio = opt.training.g_reg_every / (opt.training.g_reg_every + 1) if opt.training.g_reg_every > 0 else 1
    d_reg_ratio = opt.training.d_reg_every / (opt.training.d_reg_every + 1)

    params_g = []
    params_dict_g = dict(generator.named_parameters())
    for key, value in params_dict_g.items():
        decoder_cond = ('decoder' in key)
        if decoder_cond:
            params_g += [{'params':[value], 'lr':opt.training.lr * g_reg_ratio}]

    g_optim = optim.Adam(params_g, #generator.parameters(),
                         lr=opt.training.lr * g_reg_ratio,
                         betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio))
    d_optim = optim.Adam(discriminator.parameters(),
                         lr=opt.training.lr * d_reg_ratio,# * g_d_ratio,
                         betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio))

    opt.training.start_iter = 0
    if opt.experiment.continue_training and opt.experiment.ckpt is not None:
        if get_rank() == 0:
            print("load model:", opt.experiment.ckpt)
        ckpt_path = os.path.join(opt.training.checkpoints_dir,
                                 opt.experiment.expname,
                                 'models_{}.pt'.format(opt.experiment.ckpt.zfill(7)))
        ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)

        try:
            opt.training.start_iter = int(opt.experiment.ckpt) + 1

        except ValueError:
            pass

        generator.load_state_dict(ckpt["g"])
        discriminator.load_state_dict(ckpt["d"])
        g_ema.load_state_dict(ckpt["g_ema"])
    else:
        # save configuration
        opt_path = os.path.join(opt.training.checkpoints_dir, opt.experiment.expname, 'full_pipeline', f"opt.yaml")
        with open(opt_path,'w') as f:
            yaml.safe_dump(opt, f)

    if not opt.experiment.continue_training:
        if get_rank() == 0:
            print("loading pretrained renderer weights...")

        pretrained_renderer_path = os.path.join('./pretrained_renderer',
                                                opt.experiment.expname + '_vol_renderer.pt')
        try:
            ckpt = torch.load(pretrained_renderer_path, map_location=lambda storage, loc: storage)
        except:
            print('Pretrained volume renderer experiment name does not match the full pipeline experiment name.')
            vol_renderer_expname = str(input('Please enter the pretrained volume renderer experiment name:'))
            pretrained_renderer_path = os.path.join('./pretrained_renderer',
                                                    vol_renderer_expname + '.pt')
            ckpt = torch.load(pretrained_renderer_path, map_location=lambda storage, loc: storage)

        pretrained_renderer_dict = ckpt["g_ema"]
        model_dict = generator.state_dict()
        for k, v in pretrained_renderer_dict.items():
            if v.size() == model_dict[k].size():
                model_dict[k] = v

        generator.load_state_dict(model_dict)

    # initialize g_ema weights to generator weights
    accumulate(g_ema, generator, 0)

    # set distributed models
    if opt.training.distributed:
        generator = nn.parallel.DistributedDataParallel(
            generator,
            device_ids=[opt.training.local_rank],
            output_device=opt.training.local_rank,
            broadcast_buffers=True,
            find_unused_parameters=True,
        )

        discriminator = nn.parallel.DistributedDataParallel(
            discriminator,
            device_ids=[opt.training.local_rank],
            output_device=opt.training.local_rank,
            broadcast_buffers=False,
            find_unused_parameters=True
        )

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)])

    dataset = MultiResolutionDataset(opt.dataset.dataset_path, transform, opt.model.size,
                                     opt.model.renderer_spatial_output_dim)
    loader = data.DataLoader(
        dataset,
        batch_size=opt.training.batch,
        sampler=data_sampler(dataset, shuffle=True, distributed=opt.training.distributed),
        drop_last=True,
    )

    if get_rank() == 0 and wandb is not None and opt.training.wandb:
        wandb.init(project="StyleSDF")
        wandb.run.name = opt.experiment.expname
        wandb.config.dataset = os.path.basename(opt.dataset.dataset_path)
        wandb.config.update(opt.training)
        wandb.config.update(opt.model)
        wandb.config.update(opt.rendering)

    train(opt.training, opt.experiment, loader, generator, discriminator, g_optim, d_optim, g_ema, device)


================================================
FILE: train_volume_renderer.py
================================================
import argparse
import math
import random
import os
import yaml
import numpy as np
import torch
import torch.distributed as dist
from torch import nn, autograd, optim
from torch.nn import functional as F
from torch.utils import data
from torchvision import transforms, utils
from tqdm import tqdm
from PIL import Image
from losses import *
from options import BaseOptions
from model import Generator, VolumeRenderDiscriminator
from dataset import MultiResolutionDataset
from utils import data_sampler, requires_grad, accumulate, sample_data, make_noise, mixing_noise, generate_camera_params
from distributed import get_rank, synchronize, reduce_loss_dict, reduce_sum, get_world_size

try:
    import wandb
except ImportError:
    wandb = None


def train(opt, experiment_opt, loader, generator, discriminator, g_optim, d_optim, g_ema, device):
    loader = sample_data(loader)

    mean_path_length = 0

    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    d_view_loss = torch.tensor(0.0, device=device)
    g_view_loss = torch.tensor(0.0, device=device)
    g_eikonal = torch.tensor(0.0, device=device)
    g_minimal_surface = torch.tensor(0.0, device=device)

    g_loss_val = 0
    loss_dict = {}

    viewpoint_condition = opt.view_lambda > 0

    if opt.distributed:
        g_module = generator.module
        d_module = discriminator.module
    else:
        g_module = generator
        d_module = discriminator

    accum = 0.5 ** (32 / (10 * 1000))

    sample_z = [torch.randn(opt.val_n_sample, opt.style_dim, device=device).repeat_interleave(8,dim=0)]
    sample_cam_extrinsics, sample_focals, sample_near, sample_far, _ = generate_camera_params(opt.renderer_output_size, device, batch=opt.val_n_sample, sweep=True,
                                                                                         uniform=opt.camera.uniform, azim_range=opt.camera.azim,
                                                                                         elev_range=opt.camera.elev, fov_ang=opt.camera.fov,
                                                                                         dist_radius=opt.camera.dist_radius)

    if opt.with_sdf and opt.sphere_init and opt.start_iter == 0:
        init_pbar = range(10000)
        if get_rank() == 0:
            init_pbar = tqdm(init_pbar, initial=0, dynamic_ncols=True, smoothing=0.01)

        generator.zero_grad()
        for idx in init_pbar:
            noise = mixing_noise(3, opt.style_dim, opt.mixing, device)
            cam_extrinsics, focal, near, far, gt_viewpoints = generate_camera_params(opt.renderer_output_size, device, batch=3,
                                                                                uniform=opt.camera.uniform, azim_range=opt.camera.azim,
                                                                                elev_range=opt.camera.elev, fov_ang=opt.camera.fov,
                                                                                dist_radius=opt.camera.dist_radius)
            sdf, target_values = g_module.init_forward(noise, cam_extrinsics, focal, near, far)
            loss = F.l1_loss(sdf, target_values)
            loss.backward()
            g_optim.step()
            generator.zero_grad()
            if get_rank() == 0:
                init_pbar.set_description((f"MLP init to sphere procedure - Loss: {loss.item():.4f}"))

        accumulate(g_ema, g_module, 0)
        torch.save(
            {
                "g": g_module.state_dict(),
                "d": d_module.state_dict(),
                "g_ema": g_ema.state_dict(),
            },
            os.path.join(opt.checkpoints_dir, experiment_opt.expname, f"sdf_init_models_{str(0).zfill(7)}.pt")
        )
        print('Successfully saved checkpoint for SDF initialized MLP.')

    pbar = range(opt.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar, initial=opt.start_iter, dynamic_ncols=True, smoothing=0.01)

    for idx in pbar:
        i = idx + opt.start_iter

        if i > opt.iter:
            print("Done!")

            break

        requires_grad(generator, False)
        requires_grad(discriminator, True)
        discriminator.zero_grad()
        _, real_imgs = next(loader)
        real_imgs = real_imgs.to(device)
        noise = mixing_noise(opt.batch, opt.style_dim, opt.mixing, device)
        cam_extrinsics, focal, near, far, gt_viewpoints = generate_camera_params(opt.renderer_output_size, device, batch=opt.batch,
                                                                            uniform=opt.camera.uniform, azim_range=opt.camera.azim,
                                                                            elev_range=opt.camera.elev, fov_ang=opt.camera.fov,
                                                                            dist_radius=opt.camera.dist_radius)
        gen_imgs = []
        for j in range(0, opt.batch, opt.chunk):
            curr_noise = [n[j:j+opt.chunk] for n in noise]
            _, fake_img = generator(curr_noise,
                                    cam_extrinsics[j:j+opt.chunk],
                                    focal[j:j+opt.chunk],
                                    near[j:j+opt.chunk],
                                    far[j:j+opt.chunk])

            gen_imgs += [fake_img]

        gen_imgs = torch.cat(gen_imgs, 0)
        fake_pred, fake_viewpoint_pred = discriminator(gen_imgs.detach())
        if viewpoint_condition:
            d_view_loss = opt.view_lambda * viewpoints_loss(fake_viewpoint_pred, gt_viewpoints)

        real_imgs.requires_grad = True
        real_pred, _ = discriminator(real_imgs)
        d_gan_loss = d_logistic_loss(real_pred, fake_pred)
        grad_penalty = d_r1_loss(real_pred, real_imgs)
        r1_loss = opt.r1 * 0.5 * grad_penalty
        d_loss = d_gan_loss + r1_loss + d_view_loss
        d_loss.backward()
        d_optim.step()

        loss_dict["d"] = d_gan_loss
        loss_dict["r1"] = r1_loss
        loss_dict["d_view"] = d_view_loss
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        requires_grad(generator, True)
        requires_grad(discriminator, False)

        for j in range(0, opt.batch, opt.chunk):
            noise = mixing_noise(opt.chunk, opt.style_dim, opt.mixing, device)
            cam_extrinsics, focal, near, far, curr_gt_viewpoints = generate_camera_params(opt.renderer_output_size, device, batch=opt.chunk,
                                                                                     uniform=opt.camera.uniform, azim_range=opt.camera.azim,
                                                                                     elev_range=opt.camera.elev, fov_ang=opt.camera.fov,
                                                                                     dist_radius=opt.camera.dist_radius)

            out = generator(noise, cam_extrinsics, focal, near, far,
                            return_sdf=opt.min_surf_lambda > 0,
                            return_eikonal=opt.eikonal_lambda > 0)
            fake_img  = out[1]
            if opt.min_surf_lambda > 0:
                sdf = out[2]
            if opt.eikonal_lambda > 0:
                eikonal_term = out[3]

            fake_pred, fake_viewpoint_pred = discriminator(fake_img)
            if viewpoint_condition:
                g_view_loss = opt.view_lambda * viewpoints_loss(fake_viewpoint_pred, curr_gt_viewpoints)

            if opt.with_sdf and opt.eikonal_lambda > 0:
                g_eikonal, g_minimal_surface = eikonal_loss(eikonal_term, sdf=sdf if opt.min_surf_lambda > 0 else None,
                                                            beta=opt.min_surf_beta)
                g_eikonal = opt.eikonal_lambda * g_eikonal
                if opt.min_surf_lambda > 0:
                    g_minimal_surface = opt.min_surf_lambda * g_minimal_surface

            g_gan_loss = g_nonsaturating_loss(fake_pred)
            g_loss = g_gan_loss + g_view_loss + g_eikonal + g_minimal_surface
            g_loss.backward()

        g_optim.step()
        generator.zero_grad()
        loss_dict["g"] = g_gan_loss
        loss_dict["g_view"] = g_view_loss
        loss_dict["g_eikonal"] = g_eikonal
        loss_dict["g_minimal_surface"] = g_minimal_surface

        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)
        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        d_view_val = loss_reduced["d_view"].mean().item()
        g_view_val = loss_reduced["g_view"].mean().item()
        g_eikonal_loss = loss_reduced["g_eikonal"].mean().item()
        g_minimal_surface_loss = loss_reduced["g_minimal_surface"].mean().item()
        g_beta_val = g_module.renderer.sigmoid_beta.item() if opt.with_sdf else 0

        if get_rank() == 0:
            pbar.set_description(
                (f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; viewpoint: {d_view_val+g_view_val:.4f}; eikonal: {g_eikonal_loss:.4f}; surf: {g_minimal_surface_loss:.4f}")
            )

            if i % 1000 == 0:
                with torch.no_grad():
                    samples = torch.Tensor(0, 3, opt.renderer_output_size, opt.renderer_output_size)
                    step_size = 4
                    mean_latent = g_module.mean_latent(10000, device)
                    for k in range(0, opt.val_n_sample * 8, step_size):
                        _, curr_samples = g_ema([sample_z[0][k:k+step_size]],
                                                sample_cam_extrinsics[k:k+step_size],
                                                sample_focals[k:k+step_size],
                                                sample_near[k:k+step_size],
                                                sample_far[k:k+step_size],
                                                truncation=0.7,
                                                truncation_latent=mean_latent,)
                        samples = torch.cat([samples, curr_samples.cpu()], 0)

                    if i % 10000 == 0:
                        utils.save_image(samples,
                            os.path.join(opt.checkpoints_dir, experiment_opt.expname, 'volume_renderer', f"samples/{str(i).zfill(7)}.png"),
                            nrow=int(opt.val_n_sample),
                            normalize=True,
                            value_range=(-1, 1),)

            if wandb and opt.wandb:
                wandb_log_dict = {"Generator": g_loss_val,
                                  "Discriminator": d_loss_val,
                                  "R1": r1_val,
                                  "Real Score": real_score_val,
                                  "Fake Score": fake_score_val,
                                  "D viewpoint": d_view_val,
                                  "G viewpoint": g_view_val,
                                  "G eikonal loss": g_eikonal_loss,
                                  "G minimal surface loss": g_minimal_surface_loss,
                                  }
                if opt.with_sdf:
                    wandb_log_dict.update({"Beta value": g_beta_val})

                if i % 1000 == 0:
                    wandb_grid = utils.make_grid(samples, nrow=int(opt.val_n_sample),
                                                   normalize=True, value_range=(-1, 1))
                    wandb_ndarr = (255 * wandb_grid.permute(1, 2, 0).numpy()).astype(np.uint8)
                    wandb_images = Image.fromarray(wandb_ndarr)
                    wandb_log_dict.update({"examples": [wandb.Image(wandb_images,
                                            caption="Generated samples for azimuth angles of: -35, -25, -15, -5, 5, 15, 25, 35 degrees.")]})

                wandb.log(wandb_log_dict)

            if i % 10000 == 0 or (i < 10000 and i % 1000 == 0):
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                    },
                    os.path.join(opt.checkpoints_dir, experiment_opt.expname, 'volume_renderer', f"models_{str(i).zfill(7)}.pt")
                )
                print('Successfully saved checkpoint for iteration {}.'.format(i))

    if get_rank() == 0:
        # create final model directory
        final_model_path = 'pretrained_renderer'
        os.makedirs(final_model_path, exist_ok=True)
        torch.save(
            {
                "g": g_module.state_dict(),
                "d": d_module.state_dict(),
                "g_ema": g_ema.state_dict(),
            },
            os.path.join(final_model_path, experiment_opt.expname + '_vol_renderer.pt')
        )
        print('Successfully saved final model.')


if __name__ == "__main__":
    device = "cuda"
    opt = BaseOptions().parse()
    opt.model.freeze_renderer = False
    opt.model.no_viewpoint_loss = opt.training.view_lambda == 0.0
    opt.training.camera = opt.camera
    opt.training.renderer_output_size = opt.model.renderer_spatial_output_dim
    opt.training.style_dim = opt.model.style_dim
    opt.training.with_sdf = not opt.rendering.no_sdf
    if opt.training.with_sdf and opt.training.min_surf_lambda > 0:
        opt.rendering.return_sdf = True
    opt.training.iter = 200001
    opt.rendering.no_features_output = True

    n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    opt.training.distributed = n_gpu > 1

    if opt.training.distributed:
        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
        synchronize()

    # create checkpoints directories
    os.makedirs(os.path.join(opt.training.checkpoints_dir, opt.experiment.expname, 'volume_renderer'), exist_ok=True)
    os.makedirs(os.path.join(opt.training.checkpoints_dir, opt.experiment.expname, 'volume_renderer', 'samples'), exist_ok=True)

    discriminator = VolumeRenderDiscriminator(opt.model).to(device)
    generator = Generator(opt.model, opt.rendering, full_pipeline=False).to(device)
    g_ema = Generator(opt.model, opt.rendering, ema=True, full_pipeline=False).to(device)

    g_ema.eval()
    accumulate(g_ema, generator, 0)
    g_optim = optim.Adam(generator.parameters(), lr=2e-5, betas=(0, 0.9))
    d_optim = optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0, 0.9))

    opt.training.start_iter = 0

    if opt.experiment.continue_training and opt.experiment.ckpt is not None:
        if get_rank() == 0:
            print("load model:", opt.experiment.ckpt)
        ckpt_path = os.path.join(opt.training.checkpoints_dir,
                                 opt.experiment.expname,
                                 'models_{}.pt'.format(opt.experiment.ckpt.zfill(7)))
        ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)

        try:
            opt.training.start_iter = int(opt.experiment.ckpt) + 1

        except ValueError:
            pass

        generator.load_state_dict(ckpt["g"])
        discriminator.load_state_dict(ckpt["d"])
        g_ema.load_state_dict(ckpt["g_ema"])
        if "g_optim" in ckpt.keys():
            g_optim.load_state_dict(ckpt["g_optim"])
            d_optim.load_state_dict(ckpt["d_optim"])

    sphere_init_path = './pretrained_renderer/sphere_init.pt'
    if opt.training.no_sphere_init:
        opt.training.sphere_init = False
    elif not opt.experiment.continue_training and opt.training.with_sdf and os.path.isfile(sphere_init_path):
        if get_rank() == 0:
            print("loading sphere inititialized model")
        ckpt = torch.load(sphere_init_path, map_location=lambda storage, loc: storage)
        generator.load_state_dict(ckpt["g"])
        discriminator.load_state_dict(ckpt["d"])
        g_ema.load_state_dict(ckpt["g_ema"])
        opt.training.sphere_init = False
    else:
        opt.training.sphere_init = True

    if opt.training.distributed:
        generator = nn.parallel.DistributedDataParallel(
            generator,
            device_ids=[opt.training.local_rank],
            output_device=opt.training.local_rank,
            broadcast_buffers=True,
            find_unused_parameters=True,
        )

        discriminator = nn.parallel.DistributedDataParallel(
            discriminator,
            device_ids=[opt.training.local_rank],
            output_device=opt.training.local_rank,
            broadcast_buffers=False,
            find_unused_parameters=True
        )

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)])

    dataset = MultiResolutionDataset(opt.dataset.dataset_path, transform, opt.model.size,
                                     opt.model.renderer_spatial_output_dim)
    loader = data.DataLoader(
        dataset,
        batch_size=opt.training.batch,
        sampler=data_sampler(dataset, shuffle=True, distributed=opt.training.distributed),
        drop_last=True,
    )
    opt.training.dataset_name = opt.dataset.dataset_path.lower()

    # save options
    opt_path = os.path.join(opt.training.checkpoints_dir, opt.experiment.expname, 'volume_renderer', f"opt.yaml")
    with open(opt_path,'w') as f:
        yaml.safe_dump(opt, f)

    # set wandb environment
    if get_rank() == 0 and wandb is not None and opt.training.wandb:
        wandb.init(project="StyleSDF")
        wandb.run.name = opt.experiment.expname
        wandb.config.dataset = os.path.basename(opt.dataset.dataset_path)
        wandb.config.update(opt.training)
        wandb.config.update(opt.model)
        wandb.config.update(opt.rendering)

    train(opt.training, opt.experiment, loader, generator, discriminator, g_optim, d_optim, g_ema, device)


================================================
FILE: utils.py
================================================
import torch
import random
import trimesh
import numpy as np
from torch import nn
from torch.nn import functional as F
from torch.utils import data
from scipy.spatial import Delaunay
from skimage.measure import marching_cubes
from pdb import set_trace as st

import pytorch3d.io
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
    look_at_view_transform,
    FoVPerspectiveCameras,
    PointLights,
    RasterizationSettings,
    MeshRenderer,
    MeshRasterizer,
    SoftPhongShader,
    TexturesVertex,
)

######################### Dataset util functions ###########################
# Get data sampler
def data_sampler(dataset, shuffle, distributed):
    if distributed:
        return data.distributed.DistributedSampler(dataset, shuffle=shuffle)

    if shuffle:
        return data.RandomSampler(dataset)

    else:
        return data.SequentialSampler(dataset)

# Get data minibatch
def sample_data(loader):
    while True:
        for batch in loader:
            yield batch

############################## Model weights util functions #################
# Turn model gradients on/off
def requires_grad(model, flag=True):
    for p in model.parameters():
        p.requires_grad = flag

# Exponential moving average for generator weights
def accumulate(model1, model2, decay=0.999):
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())

    for k in par1.keys():
        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)

################### Latent code (Z) sampling util functions ####################
# Sample Z space latent codes for the generator
def make_noise(batch, latent_dim, n_noise, device):
    if n_noise == 1:
        return torch.randn(batch, latent_dim, device=device)

    noises = torch.randn(n_noise, batch, latent_dim, device=device).unbind(0)

    return noises


def mixing_noise(batch, latent_dim, prob, device):
    if prob > 0 and random.random() < prob:
        return make_noise(batch, latent_dim, 2, device)
    else:
        return [make_noise(batch, latent_dim, 1, device)]

################# Camera parameters sampling ####################
def generate_camera_params(resolution, device, batch=1, locations=None, sweep=False,
                           uniform=False, azim_range=0.3, elev_range=0.15,
                           fov_ang=6, dist_radius=0.12):
    if locations != None:
        azim = locations[:,0].view(-1,1)
        elev = locations[:,1].view(-1,1)

        # generate intrinsic parameters
        # fix distance to 1
        dist = torch.ones(azim.shape[0], 1, device=device)
        near, far = (dist - dist_radius).unsqueeze(-1), (dist + dist_radius).unsqueeze(-1)
        fov_angle = fov_ang * torch.ones(azim.shape[0], 1, device=device).view(-1,1) * np.pi / 180
        focal = 0.5 * resolution / torch.tan(fov_angle).unsqueeze(-1)
    elif sweep:
        # generate camera locations on the unit sphere
        azim = (-azim_range + (2 * azim_range / 7) * torch.arange(8, device=device)).view(-1,1).repeat(batch,1)
        elev = (-elev_range + 2 * elev_range * torch.rand(batch, 1, device=device).repeat(1,8).view(-1,1))

        # generate intrinsic parameters
        dist = (torch.ones(batch, 1, device=device)).repeat(1,8).view(-1,1)
        near, far = (dist - dist_radius).unsqueeze(-1), (dist + dist_radius).unsqueeze(-1)
        fov_angle = fov_ang * torch.ones(batch, 1, device=device).repeat(1,8).view(-1,1) * np.pi / 180
        focal = 0.5 * resolution / torch.tan(fov_angle).unsqueeze(-1)
    else:
        # sample camera locations on the unit sphere
        if uniform:
            azim = (-azim_range + 2 * azim_range * torch.rand(batch, 1, device=device))
            elev = (-elev_range + 2 * elev_range * torch.rand(batch, 1, device=device))
        else:
            azim = (azim_range * torch.randn(batch, 1, device=device))
            elev = (elev_range * torch.randn(batch, 1, device=device))

        # generate intrinsic parameters
        dist = torch.ones(batch, 1, device=device) # restrict camera position to be on the unit sphere
        near, far = (dist - dist_radius).unsqueeze(-1), (dist + dist_radius).unsqueeze(-1)
        fov_angle = fov_ang * torch.ones(batch, 1, device=device) * np.pi / 180 # full fov is 12 degrees
        focal = 0.5 * resolution / torch.tan(fov_angle).unsqueeze(-1)

    viewpoint = torch.cat([azim, elev], 1)

    #### Generate camera extrinsic matrix ##########

    # convert angles to xyz coordinates
    x = torch.cos(elev) * torch.sin(azim)
    y = torch.sin(elev)
    z = torch.cos(elev) * torch.cos(azim)
    camera_dir = torch.stack([x, y, z], dim=1).view(-1,3)
    camera_loc = dist * camera_dir

    # get rotation matrices (assume object is at the world coordinates origin)
    up = torch.tensor([[0,1,0]]).float().to(device) * torch.ones_like(dist)
    z_axis = F.normalize(camera_dir, eps=1e-5) # the -z direction points into the screen
    x_axis = F.normalize(torch.cross(up, z_axis, dim=1), eps=1e-5)
    y_axis = F.normalize(torch.cross(z_axis, x_axis, dim=1), eps=1e-5)
    is_close = torch.isclose(x_axis, torch.tensor(0.0), atol=5e-3).all(dim=1, keepdim=True)
    if is_close.any():
        replacement = F.normalize(torch.cross(y_axis, z_axis, dim=1), eps=1e-5)
        x_axis = torch.where(is_close, replacement, x_axis)
    R = torch.cat((x_axis[:, None, :], y_axis[:, None, :], z_axis[:, None, :]), dim=1)
    T = camera_loc[:, :, None]
    extrinsics = torch.cat((R.transpose(1,2),T), -1)

    return extrinsics, focal, near, far, viewpoint

#################### Mesh generation util functions ########################
# Reshape sampling volume to camera frostum
def align_volume(volume, near=0.88, far=1.12):
    b, h, w, d, c = volume.shape
    yy, xx, zz = torch.meshgrid(torch.linspace(-1, 1, h),
                                torch.linspace(-1, 1, w),
                                torch.linspace(-1, 1, d))

    grid = torch.stack([xx, yy, zz], -1).to(volume.device)

    frostum_adjustment_coeffs = torch.linspace(far / near, 1, d).view(1,1,1,-1,1).to(volume.device)
    frostum_grid = grid.unsqueeze(0)
    frostum_grid[...,:2] = frostum_grid[...,:2] * frostum_adjustment_coeffs
    out_of_boundary = torch.any((frostum_grid.lt(-1).logical_or(frostum_grid.gt(1))), -1, keepdim=True)
    frostum_grid = frostum_grid.permute(0,3,1,2,4).contiguous()
    permuted_volume = volume.permute(0,4,3,1,2).contiguous()
    final_volume = F.grid_sample(permuted_volume, frostum_grid, padding_mode="border", align_corners=True)
    final_volume = final_volume.permute(0,3,4,2,1).contiguous()
    # set a non-zero value to grid locations outside of the frostum to avoid marching cubes distortions.
    # It happens because pytorch grid_sample uses zeros padding.
    final_volume[out_of_boundary] = 1

    return final_volume

# Extract mesh with marching cubes
def extract_mesh_with_marching_cubes(sdf):
    b, h, w, d, _ = sdf.shape

    # change coordinate order from (y,x,z) to (x,y,z)
    sdf_vol = sdf[0,...,0].permute(1,0,2).cpu().numpy()

    # scale vertices
    verts, faces, _, _ = marching_cubes(sdf_vol, 0)
    verts[:,0] = (verts[:,0]/float(w)-0.5)*0.24
    verts[:,1] = (verts[:,1]/float(h)-0.5)*0.24
    verts[:,2] = (verts[:,2]/float(d)-0.5)*0.24

    # fix normal direction
    verts[:,2] *= -1; verts[:,1] *= -1
    mesh = trimesh.Trimesh(verts, faces)

    return mesh

# Generate mesh from xyz point cloud
def xyz2mesh(xyz):
    b, _, h, w = xyz.shape
    x, y = np.meshgrid(np.arange(h), np.arange(w))

    # Extract mesh faces from xyz maps
    tri = Delaunay(np.concatenate((x.reshape((h*w, 1)), y.reshape((h*w, 1))), 1))
    faces = tri.simplices

    # invert normals
    faces[:,[0, 1]] = faces[:,[1, 0]]

    # generate_meshes
    mesh = trimesh.Trimesh(xyz.squeeze(0).permute(1,2,0).view(h*w,3).cpu().numpy(), faces)

    return mesh


################# Mesh rendering util functions #############################
def add_textures(meshes:Meshes, vertex_colors=None) -> Meshes:
    verts = meshes.verts_padded()
    if vertex_colors is None:
        vertex_colors = torch.ones_like(verts) # (N, V, 3)
    textures = TexturesVertex(verts_features=vertex_colors)
    meshes_t = Meshes(
        verts=verts,
        faces=meshes.faces_padded(),
        textures=textures,
        verts_normals=meshes.verts_normals_padded(),
    )
    return meshes_t


def create_cameras(
    R=None, T=None,
    azim=0, elev=0., dist=1.,
    fov=12., znear=0.01,
    device="cuda") -> FoVPerspectiveCameras:
    """
    all the camera parameters can be a single number, a list, or a torch tensor.
    """
    if R is None or T is None:
        R, T = look_at_view_transform(dist=dist, azim=azim, elev=elev, device=device)
    cameras = FoVPerspectiveCameras(device=device, R=R, T=T, znear=znear, fov=fov)
    return cameras


def create_mesh_renderer(
        cameras: FoVPerspectiveCameras,
        image_size: int = 256,
        blur_radius: float = 1e-6,
        light_location=((-0.5, 1., 5.0),),
        device="cuda",
        **light_kwargs,
):
    """
    If don't want to show direct texture color without shading, set the light_kwargs as
    ambient_color=((1, 1, 1), ), diffuse_color=((0, 0, 0), ), specular_color=((0, 0, 0), )
    """
    # We will also create a Phong renderer. This is simpler and only needs to render one face per pixel.
    raster_settings = RasterizationSettings(
        image_size=image_size,
        blur_radius=blur_radius,
        faces_per_pixel=5,
    )
    # We can add a point light in front of the object.
    lights = PointLights(
        device=device, location=light_location, **light_kwargs
    )
    phong_renderer = MeshRenderer(
        rasterizer=MeshRasterizer(
            cameras=cameras,
            raster_settings=raster_settings
        ),
        shader=SoftPhongShader(device=device, cameras=cameras, lights=lights)
    )

    return phong_renderer


## custom renderer
class MeshRendererWithDepth(nn.Module):
    def __init__(self, rasterizer, shader):
        super().__init__()
        self.rasterizer = rasterizer
        self.shader = shader

    def forward(self, meshes_world, **kwargs) -> torch.Tensor:
        fragments = self.rasterizer(meshes_world, **kwargs)
        images = self.shader(fragments, meshes_world, **kwargs)
        return images, fragments.zbuf


def create_depth_mesh_renderer(
        cameras: FoVPerspectiveCameras,
        image_size: int = 256,
        blur_radius: float = 1e-6,
        device="cuda",
        **light_kwargs,
):
    """
    If don't want to show direct texture color without shading, set the light_kwargs as
    ambient_color=((1, 1, 1), ), diffuse_color=((0, 0, 0), ), specular_color=((0, 0, 0), )
    """
    # We will also create a Phong renderer. This is simpler and only needs to render one face per pixel.
    raster_settings = RasterizationSettings(
        image_size=image_size,
        blur_radius=blur_radius,
        faces_per_pixel=17,
    )
    # We can add a point light in front of the object.
    lights = PointLights(
        device=device, location=((-0.5, 1., 5.0),), **light_kwargs
    )
    renderer = MeshRendererWithDepth(
        rasterizer=MeshRasterizer(
            cameras=cameras,
            raster_settings=raster_settings,
            device=device,
        ),
        shader=SoftPhongShader(device=device, cameras=cameras, lights=lights)
    )

    return renderer


================================================
FILE: volume_renderer.py
================================================
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import numpy as np
from functools import partial
from pdb import set_trace as st


# Basic SIREN fully connected layer
class LinearLayer(nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, bias_init=0, std_init=1, freq_init=False, is_first=False):
        super().__init__()
        if is_first:
            self.weight = nn.Parameter(torch.empty(out_dim, in_dim).uniform_(-1 / in_dim, 1 / in_dim))
        elif freq_init:
            self.weight = nn.Parameter(torch.empty(out_dim, in_dim).uniform_(-np.sqrt(6 / in_dim) / 25, np.sqrt(6 / in_dim) / 25))
        else:
            self.weight = nn.Parameter(0.25 * nn.init.kaiming_normal_(torch.randn(out_dim, in_dim), a=0.2, mode='fan_in', nonlinearity='leaky_relu'))

        self.bias = nn.Parameter(nn.init.uniform_(torch.empty(out_dim), a=-np.sqrt(1/in_dim), b=np.sqrt(1/in_dim)))

        self.bias_init = bias_init
        self.std_init = std_init

    def forward(self, input):
        out = self.std_init * F.linear(input, self.weight, bias=self.bias) + self.bias_init

        return out

# Siren layer with frequency modulation and offset
class FiLMSiren(nn.Module):
    def __init__(self, in_channel, out_channel, style_dim, is_first=False):
        super().__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel

        if is_first:
            self.weight = nn.Parameter(torch.empty(out_channel, in_channel).uniform_(-1 / 3, 1 / 3))
        else:
            self.weight = nn.Parameter(torch.empty(out_channel, in_channel).uniform_(-np.sqrt(6 / in_channel) / 25, np.sqrt(6 / in_channel) / 25))

        self.bias = nn.Parameter(nn.Parameter(nn.init.uniform_(torch.empty(out_channel), a=-np.sqrt(1/in_channel), b=np.sqrt(1/in_channel))))
        self.activation = torch.sin

        self.gamma = LinearLayer(style_dim, out_channel, bias_init=30, std_init=15)
        self.beta = LinearLayer(style_dim, out_channel, bias_init=0, std_init=0.25)

    def forward(self, input, style):
        batch, features = style.shape
        out = F.linear(input, self.weight, bias=self.bias)
        gamma = self.gamma(style).view(batch, 1, 1, 1, features)
        beta = self.beta(style).view(batch, 1, 1, 1, features)

        out = self.activation(gamma * out + beta)

        return out


# Siren Generator Model
class SirenGenerator(nn.Module):
    def __init__(self, D=8, W=256, style_dim=256, input_ch=3, input_ch_views=3, output_ch=4,
                 output_features=True):
        super(SirenGenerator, self).__init__()
        self.D = D
        self.W = W
        self.input_ch = input_ch
        self.input_ch_views = input_ch_views
        self.style_dim = style_dim
        self.output_features = output_features

        self.pts_linears = nn.ModuleList(
            [FiLMSiren(3, W, style_dim=style_dim, is_first=True)] + \
            [FiLMSiren(W, W, style_dim=style_dim) for i in range(D-1)])

        self.views_linears = FiLMSiren(input_ch_views + W, W,
                                       style_dim=style_dim)
        self.rgb_linear = LinearLayer(W, 3, freq_init=True)
        self.sigma_linear = LinearLayer(W, 1, freq_init=True)

    def forward(self, x, styles):
        input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
        mlp_out = input_pts.contiguous()
        for i in range(len(self.pts_linears)):
            mlp_out = self.pts_linears[i](mlp_out, styles)


        sdf = self.sigma_linear(mlp_out)

        mlp_out = torch.cat([mlp_out, input_views], -1)
        out_features = self.views_linears(mlp_out, styles)
        rgb = self.rgb_linear(out_features)

        outputs = torch.cat([rgb, sdf], -1)
        if self.output_features:
            outputs = torch.cat([outputs, out_features], -1)

        return outputs


# Full vol
Download .txt
gitextract__fgegk92/

├── .gitignore
├── LICENSE
├── README.md
├── StyleSDF_demo.ipynb
├── dataset.py
├── distributed.py
├── download_models.py
├── generate_shapes_and_images.py
├── losses.py
├── model.py
├── op/
│   ├── __init__.py
│   ├── fused_act.py
│   ├── fused_bias_act.cpp
│   ├── fused_bias_act_kernel.cu
│   ├── upfirdn2d.cpp
│   ├── upfirdn2d.py
│   └── upfirdn2d_kernel.cu
├── options.py
├── prepare_data.py
├── render_video.py
├── requirements.txt
├── scripts/
│   ├── train_afhq_full_pipeline_512x512.sh
│   ├── train_afhq_vol_renderer.sh
│   ├── train_ffhq_full_pipeline_1024x1024.sh
│   └── train_ffhq_vol_renderer.sh
├── train_full_pipeline.py
├── train_volume_renderer.py
├── utils.py
└── volume_renderer.py
Download .txt
SYMBOL INDEX (171 symbols across 17 files)

FILE: dataset.py
  class MultiResolutionDataset (line 12) | class MultiResolutionDataset(Dataset):
    method __init__ (line 13) | def __init__(self, path, transform, resolution=256, nerf_resolution=64):
    method __len__ (line 33) | def __len__(self):
    method __getitem__ (line 36) | def __getitem__(self, index):

FILE: distributed.py
  function get_rank (line 9) | def get_rank():
  function synchronize (line 19) | def synchronize():
  function get_world_size (line 34) | def get_world_size():
  function reduce_sum (line 44) | def reduce_sum(tensor):
  function gather_grad (line 57) | def gather_grad(params):
  function all_gather (line 69) | def all_gather(data):
  function reduce_loss_dict (line 104) | def reduce_loss_dict(loss_dict):

FILE: download_models.py
  function download_pretrained_models (line 28) | def download_pretrained_models():
  function download_file (line 80) | def download_file(session, file_spec, use_alt_url=False, chunk_size=128,...

FILE: generate_shapes_and_images.py
  function generate (line 27) | def generate(opt, g_ema, surface_g_ema, device, mean_latent, surface_mea...

FILE: losses.py
  function viewpoints_loss (line 7) | def viewpoints_loss(viewpoint_pred, viewpoint_target):
  function eikonal_loss (line 13) | def eikonal_loss(eikonal_term, sdf=None, beta=100):
  function d_logistic_loss (line 27) | def d_logistic_loss(real_pred, fake_pred):
  function d_r1_loss (line 34) | def d_r1_loss(real_pred, real_img):
  function g_nonsaturating_loss (line 43) | def g_nonsaturating_loss(fake_pred):
  function g_path_regularize (line 49) | def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):

FILE: model.py
  class PixelNorm (line 24) | class PixelNorm(nn.Module):
    method __init__ (line 25) | def __init__(self):
    method forward (line 28) | def forward(self, input):
  class MappingLinear (line 32) | class MappingLinear(nn.Module):
    method __init__ (line 33) | def __init__(self, in_dim, out_dim, bias=True, activation=None, is_las...
    method forward (line 49) | def forward(self, input):
    method __repr__ (line 58) | def __repr__(self):
  function make_kernel (line 64) | def make_kernel(k):
  class Upsample (line 75) | class Upsample(nn.Module):
    method __init__ (line 76) | def __init__(self, kernel, factor=2):
    method forward (line 90) | def forward(self, input):
  class Downsample (line 96) | class Downsample(nn.Module):
    method __init__ (line 97) | def __init__(self, kernel, factor=2):
    method forward (line 111) | def forward(self, input):
  class Blur (line 117) | class Blur(nn.Module):
    method __init__ (line 118) | def __init__(self, kernel, pad, upsample_factor=1):
    method forward (line 130) | def forward(self, input):
  class EqualConv2d (line 136) | class EqualConv2d(nn.Module):
    method __init__ (line 137) | def __init__(
    method forward (line 156) | def forward(self, input):
    method __repr__ (line 167) | def __repr__(self):
  class EqualLinear (line 174) | class EqualLinear(nn.Module):
    method __init__ (line 175) | def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1,
    method forward (line 192) | def forward(self, input):
    method __repr__ (line 203) | def __repr__(self):
  class ModulatedConv2d (line 209) | class ModulatedConv2d(nn.Module):
    method __init__ (line 210) | def __init__(self, in_channel, out_channel, kernel_size, style_dim, de...
    method __repr__ (line 249) | def __repr__(self):
    method forward (line 255) | def forward(self, input, style):
  class NoiseInjection (line 299) | class NoiseInjection(nn.Module):
    method __init__ (line 300) | def __init__(self, project=False):
    method create_pytorch_mesh (line 308) | def create_pytorch_mesh(self, trimesh):
    method load_mc_mesh (line 323) | def load_mc_mesh(self, filename, resolution=128, im_res=64):
    method project_noise (line 349) | def project_noise(self, noise, transform, mesh_path=None):
    method forward (line 380) | def forward(self, image, noise=None, transform=None, mesh_path=None):
  class StyledConv (line 390) | class StyledConv(nn.Module):
    method __init__ (line 391) | def __init__(self, in_channel, out_channel, kernel_size, style_dim,
    method forward (line 408) | def forward(self, input, style, noise=None, transform=None, mesh_path=...
  class ToRGB (line 416) | class ToRGB(nn.Module):
    method __init__ (line 417) | def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[...
    method forward (line 428) | def forward(self, input, style, skip=None):
  class ConvLayer (line 441) | class ConvLayer(nn.Sequential):
    method __init__ (line 442) | def __init__(self, in_channel, out_channel, kernel_size, downsample=Fa...
  class Decoder (line 478) | class Decoder(nn.Module):
    method __init__ (line 479) | def __init__(self, model_opt, blur_kernel=[1, 3, 3, 1]):
    method mean_latent (line 559) | def mean_latent(self, renderer_latent):
    method get_latent (line 564) | def get_latent(self, input):
    method styles_and_noise_forward (line 567) | def styles_and_noise_forward(self, styles, noise, inject_index=None, t...
    method forward (line 610) | def forward(self, features, styles, rgbd_in=None, transform=None,
  class Generator (line 641) | class Generator(nn.Module):
    method __init__ (line 642) | def __init__(self, model_opt, renderer_opt, blur_kernel=[1, 3, 3, 1], ...
    method make_noise (line 673) | def make_noise(self):
    method mean_latent (line 684) | def mean_latent(self, n_latent, device):
    method get_latent (line 695) | def get_latent(self, input):
    method styles_and_noise_forward (line 698) | def styles_and_noise_forward(self, styles, inject_index=None, truncati...
    method init_forward (line 715) | def init_forward(self, styles, cam_poses, focals, near=0.88, far=1.12):
    method forward (line 722) | def forward(self, styles, cam_poses, focals, near=0.88, far=1.12, retu...
  class VolumeRenderDiscConv2d (line 763) | class VolumeRenderDiscConv2d(nn.Module):
    method __init__ (line 764) | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
    method forward (line 778) | def forward(self, input):
  class AddCoords (line 791) | class AddCoords(nn.Module):
    method __init__ (line 792) | def __init__(self):
    method forward (line 795) | def forward(self, input_tensor):
  class CoordConv2d (line 817) | class CoordConv2d(nn.Module):
    method __init__ (line 818) | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
    method forward (line 826) | def forward(self, input_tensor):
  class CoordConvLayer (line 838) | class CoordConvLayer(nn.Module):
    method __init__ (line 839) | def __init__(self, in_channel, out_channel, kernel_size, bias=True, ac...
    method forward (line 856) | def forward(self, input):
  class VolumeRenderResBlock (line 864) | class VolumeRenderResBlock(nn.Module):
    method __init__ (line 865) | def __init__(self, in_channel, out_channel):
    method forward (line 877) | def forward(self, input):
  class VolumeRenderDiscriminator (line 893) | class VolumeRenderDiscriminator(nn.Module):
    method __init__ (line 894) | def __init__(self, opt):
    method forward (line 926) | def forward(self, input):
  class ResBlock (line 940) | class ResBlock(nn.Module):
    method __init__ (line 941) | def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1], ...
    method forward (line 949) | def forward(self, input):
  class Discriminator (line 957) | class Discriminator(nn.Module):
    method __init__ (line 958) | def __init__(self, opt, blur_kernel=[1, 3, 3, 1]):
    method forward (line 1001) | def forward(self, input):

FILE: op/fused_act.py
  class FusedLeakyReLUFunctionBackward (line 20) | class FusedLeakyReLUFunctionBackward(Function):
    method forward (line 22) | def forward(ctx, grad_output, out, bias, negative_slope, scale):
    method backward (line 47) | def backward(ctx, gradgrad_input, gradgrad_bias):
  class FusedLeakyReLUFunction (line 56) | class FusedLeakyReLUFunction(Function):
    method forward (line 58) | def forward(ctx, input, bias, negative_slope, scale):
    method backward (line 74) | def backward(ctx, grad_output):
  class FusedLeakyReLU (line 87) | class FusedLeakyReLU(nn.Module):
    method __init__ (line 88) | def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** ...
    method forward (line 100) | def forward(self, input):
  function fused_leaky_relu (line 104) | def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):

FILE: op/fused_bias_act.cpp
  function fused_bias_act (line 11) | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Te...
  function PYBIND11_MODULE (line 19) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

FILE: op/upfirdn2d.cpp
  function upfirdn2d (line 12) | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor&...
  function PYBIND11_MODULE (line 21) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

FILE: op/upfirdn2d.py
  class UpFirDn2dBackward (line 20) | class UpFirDn2dBackward(Function):
    method forward (line 22) | def forward(
    method backward (line 64) | def backward(ctx, gradgrad_input):
  class UpFirDn2d (line 89) | class UpFirDn2d(Function):
    method forward (line 91) | def forward(ctx, input, kernel, up, down, pad):
    method backward (line 128) | def backward(ctx, grad_output):
  function upfirdn2d (line 146) | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
  function upfirdn2d_native (line 160) | def upfirdn2d_native(

FILE: options.py
  class BaseOptions (line 5) | class BaseOptions():
    method __init__ (line 6) | def __init__(self):
    method initialize (line 10) | def initialize(self):
    method parse (line 94) | def parse(self):

FILE: prepare_data.py
  function resize_and_convert (line 14) | def resize_and_convert(img, size, resample):
  function resize_multiple (line 24) | def resize_multiple(
  function resize_worker (line 34) | def resize_worker(img_file, sizes, resample):
  function prepare (line 43) | def prepare(

FILE: render_video.py
  function render_video (line 26) | def render_video(opt, g_ema, surface_g_ema, device, mean_latent, surface...

FILE: train_full_pipeline.py
  function train (line 28) | def train(opt, experiment_opt, loader, generator, discriminator, g_optim...

FILE: train_volume_renderer.py
  function train (line 28) | def train(opt, experiment_opt, loader, generator, discriminator, g_optim...

FILE: utils.py
  function data_sampler (line 27) | def data_sampler(dataset, shuffle, distributed):
  function sample_data (line 38) | def sample_data(loader):
  function requires_grad (line 45) | def requires_grad(model, flag=True):
  function accumulate (line 50) | def accumulate(model1, model2, decay=0.999):
  function make_noise (line 59) | def make_noise(batch, latent_dim, n_noise, device):
  function mixing_noise (line 68) | def mixing_noise(batch, latent_dim, prob, device):
  function generate_camera_params (line 75) | def generate_camera_params(resolution, device, batch=1, locations=None, ...
  function align_volume (line 141) | def align_volume(volume, near=0.88, far=1.12):
  function extract_mesh_with_marching_cubes (line 164) | def extract_mesh_with_marching_cubes(sdf):
  function xyz2mesh (line 183) | def xyz2mesh(xyz):
  function add_textures (line 201) | def add_textures(meshes:Meshes, vertex_colors=None) -> Meshes:
  function create_cameras (line 215) | def create_cameras(
  function create_mesh_renderer (line 229) | def create_mesh_renderer(
  class MeshRendererWithDepth (line 263) | class MeshRendererWithDepth(nn.Module):
    method __init__ (line 264) | def __init__(self, rasterizer, shader):
    method forward (line 269) | def forward(self, meshes_world, **kwargs) -> torch.Tensor:
  function create_depth_mesh_renderer (line 275) | def create_depth_mesh_renderer(

FILE: volume_renderer.py
  class LinearLayer (line 12) | class LinearLayer(nn.Module):
    method __init__ (line 13) | def __init__(self, in_dim, out_dim, bias=True, bias_init=0, std_init=1...
    method forward (line 27) | def forward(self, input):
  class FiLMSiren (line 33) | class FiLMSiren(nn.Module):
    method __init__ (line 34) | def __init__(self, in_channel, out_channel, style_dim, is_first=False):
    method forward (line 50) | def forward(self, input, style):
  class SirenGenerator (line 62) | class SirenGenerator(nn.Module):
    method __init__ (line 63) | def __init__(self, D=8, W=256, style_dim=256, input_ch=3, input_ch_vie...
    method forward (line 82) | def forward(self, x, styles):
  class VolumeFeatureRenderer (line 103) | class VolumeFeatureRenderer(nn.Module):
    method __init__ (line 104) | def __init__(self, opt, style_dim=256, out_im_res=64, mode='train'):
    method get_rays (line 158) | def get_rays(self, focal, c2w):
    method get_eikonal_term (line 175) | def get_eikonal_term(self, pts, sdf):
    method sdf_activation (line 182) | def sdf_activation(self, input):
    method volume_integration (line 187) | def volume_integration(self, raw, z_vals, rays_d, pts, return_eikonal=...
    method run_network (line 254) | def run_network(self, inputs, viewdirs, styles=None):
    method render_rays (line 261) | def render_rays(self, ray_batch, styles=None, return_eikonal=False):
    method render (line 305) | def render(self, focal, c2w, near, far, styles, c2w_staticcam=None, re...
    method mlp_init_pass (line 319) | def mlp_init_pass(self, cam_poses, focal, near, far, styles=None):
    method forward (line 348) | def forward(self, cam_poses, focal, near, far, styles=None, return_eik...
Condensed preview — 29 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (217K chars).
[
  {
    "path": ".gitignore",
    "chars": 89,
    "preview": "__pycache__/\ncheckpoint/\ndatasets/\nevaluations/\nfull_models/\npretrained_renderer/\nwandb/\n"
  },
  {
    "path": "LICENSE",
    "chars": 1054,
    "preview": "Copyright (C) 2022 Roy Or-El, Xuan Luo, Mengyi Shan, Eli Shechtman, Jeong Joon Park, Ira Kemelmacher-Shlizerman.\nAll rig"
  },
  {
    "path": "README.md",
    "chars": 11597,
    "preview": "# StyleSDF: High-Resolution 3D-Consistent Image and Geometry Generation\n### [Project Page](https://stylesdf.github.io/) "
  },
  {
    "path": "StyleSDF_demo.ipynb",
    "chars": 20021,
    "preview": "{\n  \"nbformat\": 4,\n  \"nbformat_minor\": 0,\n  \"metadata\": {\n    \"colab\": {\n      \"name\": \"StyleSDF_demo.ipynb\",\n      \"pri"
  },
  {
    "path": "dataset.py",
    "chars": 1432,
    "preview": "import os\nimport csv\nimport lmdb\nimport random\nimport numpy as np\nimport torchvision.transforms.functional as TF\nfrom PI"
  },
  {
    "path": "distributed.py",
    "chars": 2711,
    "preview": "import math\nimport pickle\n\nimport torch\nfrom torch import distributed as dist\nfrom torch.utils.data.sampler import Sampl"
  },
  {
    "path": "download_models.py",
    "chars": 5986,
    "preview": "import os\nimport html\nimport glob\nimport uuid\nimport hashlib\nimport requests\nfrom tqdm import tqdm\nfrom pdb import set_t"
  },
  {
    "path": "generate_shapes_and_images.py",
    "chars": 11783,
    "preview": "import os\nimport torch\nimport trimesh\nimport numpy as np\nfrom munch import *\nfrom PIL import Image\nfrom tqdm import tqdm"
  },
  {
    "path": "losses.py",
    "chars": 1786,
    "preview": "import math\nimport torch\nfrom torch import autograd\nfrom torch.nn import functional as F\n\n\ndef viewpoints_loss(viewpoint"
  },
  {
    "path": "model.py",
    "chars": 34016,
    "preview": "import math\nimport random\nimport trimesh\nimport torch\nimport numpy as np\nfrom torch import nn\nfrom torch.nn import funct"
  },
  {
    "path": "op/__init__.py",
    "chars": 89,
    "preview": "from .fused_act import FusedLeakyReLU, fused_leaky_relu\nfrom .upfirdn2d import upfirdn2d\n"
  },
  {
    "path": "op/fused_act.py",
    "chars": 3262,
    "preview": "import os\r\n\r\nimport torch\r\nfrom torch import nn\r\nfrom torch.nn import functional as F\r\nfrom torch.autograd import Functi"
  },
  {
    "path": "op/fused_bias_act.cpp",
    "chars": 846,
    "preview": "#include <torch/extension.h>\r\n\r\n\r\ntorch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias,"
  },
  {
    "path": "op/fused_bias_act_kernel.cu",
    "chars": 2875,
    "preview": "// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.\r\n//\r\n// This work is made available under the Nvidia Sou"
  },
  {
    "path": "op/upfirdn2d.cpp",
    "chars": 988,
    "preview": "#include <torch/extension.h>\r\n\r\n\r\ntorch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,\r\n "
  },
  {
    "path": "op/upfirdn2d.py",
    "chars": 5905,
    "preview": "import os\r\n\r\nimport torch\r\nfrom torch.nn import functional as F\r\nfrom torch.autograd import Function\r\nfrom torch.utils.c"
  },
  {
    "path": "op/upfirdn2d_kernel.cu",
    "chars": 12079,
    "preview": "// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.\r\n//\r\n// This work is made available under the Nvidia Sou"
  },
  {
    "path": "options.py",
    "chars": 9522,
    "preview": "import configargparse\nfrom munch import *\nfrom pdb import set_trace as st\n\nclass BaseOptions():\n    def __init__(self):\n"
  },
  {
    "path": "prepare_data.py",
    "chars": 2836,
    "preview": "import argparse\nfrom io import BytesIO\nimport multiprocessing\nfrom functools import partial\n\nfrom PIL import Image\nimpor"
  },
  {
    "path": "render_video.py",
    "chars": 14374,
    "preview": "import os\nimport torch\nimport trimesh\nimport numpy as np\nimport skvideo.io\nfrom munch import *\nfrom PIL import Image\nfro"
  },
  {
    "path": "requirements.txt",
    "chars": 111,
    "preview": "lmdb\nnumpy\nninja\npillow\nrequests\ntqdm\nscipy\nscikit-image\nscikit-video\ntrimesh[easy]\nconfigargparse\nmunch\nwandb\n"
  },
  {
    "path": "scripts/train_afhq_full_pipeline_512x512.sh",
    "chars": 188,
    "preview": "python -m torch.distributed.launch --nproc_per_node 4 new_train.py --batch 8 --chunk 4 --azim 0.15 --r1 50.0 --expname a"
  },
  {
    "path": "scripts/train_afhq_vol_renderer.sh",
    "chars": 189,
    "preview": "python -m torch.distributed.launch --nproc_per_node 2 train_volume_renderer.py --batch 12 --chunk 6 --expname afhq_sdf_v"
  },
  {
    "path": "scripts/train_ffhq_full_pipeline_1024x1024.sh",
    "chars": 141,
    "preview": "python -m torch.distributed.launch --nproc_per_node 4 train_full_pipeline.py --batch 8 --chunk 2 --expname ffhq1024x1024"
  },
  {
    "path": "scripts/train_ffhq_vol_renderer.sh",
    "chars": 172,
    "preview": "python -m torch.distributed.launch --nproc_per_node 2 train_volume_renderer.py --batch 12 --chunk 6 --expname ffhq_sdf_v"
  },
  {
    "path": "train_full_pipeline.py",
    "chars": 18895,
    "preview": "import argparse\nimport math\nimport random\nimport os\nimport yaml\nimport numpy as np\nimport torch\nfrom torch import nn, au"
  },
  {
    "path": "train_volume_renderer.py",
    "chars": 18054,
    "preview": "import argparse\nimport math\nimport random\nimport os\nimport yaml\nimport numpy as np\nimport torch\nimport torch.distributed"
  },
  {
    "path": "utils.py",
    "chars": 11434,
    "preview": "import torch\nimport random\nimport trimesh\nimport numpy as np\nfrom torch import nn\nfrom torch.nn import functional as F\nf"
  },
  {
    "path": "volume_renderer.py",
    "chars": 15137,
    "preview": "import math\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.autograd as autograd\nimport "
  }
]

About this extraction

This page contains the full source code of the royorel/StyleSDF GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 29 files (202.7 KB), approximately 51.7k tokens, and a symbol index with 171 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!