Repository: royorel/StyleSDF
Branch: main
Commit: 1c995101b80b
Files: 29
Total size: 202.7 KB
Directory structure:
gitextract__fgegk92/
├── .gitignore
├── LICENSE
├── README.md
├── StyleSDF_demo.ipynb
├── dataset.py
├── distributed.py
├── download_models.py
├── generate_shapes_and_images.py
├── losses.py
├── model.py
├── op/
│ ├── __init__.py
│ ├── fused_act.py
│ ├── fused_bias_act.cpp
│ ├── fused_bias_act_kernel.cu
│ ├── upfirdn2d.cpp
│ ├── upfirdn2d.py
│ └── upfirdn2d_kernel.cu
├── options.py
├── prepare_data.py
├── render_video.py
├── requirements.txt
├── scripts/
│ ├── train_afhq_full_pipeline_512x512.sh
│ ├── train_afhq_vol_renderer.sh
│ ├── train_ffhq_full_pipeline_1024x1024.sh
│ └── train_ffhq_vol_renderer.sh
├── train_full_pipeline.py
├── train_volume_renderer.py
├── utils.py
└── volume_renderer.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
__pycache__/
checkpoint/
datasets/
evaluations/
full_models/
pretrained_renderer/
wandb/
================================================
FILE: LICENSE
================================================
Copyright (C) 2022 Roy Or-El, Xuan Luo, Mengyi Shan, Eli Shechtman, Jeong Joon Park, Ira Kemelmacher-Shlizerman.
All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
The software is made available under Creative Commons BY-NC-SA 4.0 license
by University of Washington. You can use, redistribute, and adapt it
for non-commercial purposes, as long as you (a) give appropriate credit
by citing our paper, (b) indicate any changes that you've made,
and (c) distribute any derivative works under the same license.
THE AUTHORS DISCLAIM ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE.
IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL
DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING
OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
================================================
FILE: README.md
================================================
# StyleSDF: High-Resolution 3D-Consistent Image and Geometry Generation
### [Project Page](https://stylesdf.github.io/) | [Paper](https://arxiv.org/pdf/2112.11427.pdf) | [HuggingFace Demo](https://huggingface.co/spaces/SerdarHelli/StyleSDF-3D)
[](https://colab.research.google.com/github/royorel/StyleSDF/blob/main/StyleSDF_demo.ipynb)<br>
[Roy Or-El](https://homes.cs.washington.edu/~royorel/)<sup>1</sup> ,
[Xuan Luo](https://roxanneluo.github.io/)<sup>1</sup>,
[Mengyi Shan](https://shanmy.github.io/)<sup>1</sup>,
[Eli Shechtman](https://research.adobe.com/person/eli-shechtman/)<sup>2</sup>,
[Jeong Joon Park](https://jjparkcv.github.io/)<sup>3</sup>,
[Ira Kemelmacher-Shlizerman](https://www.irakemelmacher.com/)<sup>1</sup><br>
<sup>1</sup>University of Washington, <sup>2</sup>Adobe Research, <sup>3</sup>Stanford University
<div align="center">
<img src="./assets/teaser.png">
</div>
## Updates
12/26/2022: A new HuggingFace demo is now available. Special thanks to [@SerdarHelli](https://github.com/SerdarHelli) for the implementation.<br>
3/27/2022: Fixed a bug in the sphere initialization code (init_forward function was missing, see commit [0bd8741](https://github.com/royorel/StyleSDF/commit/0bd8741f26048d26160a495a9056b5f1da1a60a1)).<br>
3/22/2022: **Added training files**.<br>
3/9/2022: Fixed a bug in the calculation of the mean w vector (see commit [d4dd17d](https://github.com/royorel/StyleSDF/commit/d4dd17de09fd58adefc7ed49487476af6018894f)).<br>
3/4/2022: Testing code and Colab demo were released. **Training files will be released soon.**
## Overview
StyleSDF is a 3D-aware GAN, aimed at solving two main challenges:
1. High-resolution, view-consistent generation of the RGB images.
2. Generating detailed 3D shapes.
StyleSDF is trained only on single-view RGB data. The 3D geometry is learned implicitly with an SDF-based volume renderer.<br>
This code is the official PyTorch implementation of the paper:
> **StyleSDF: High-Resolution 3D-Consistent Image and Geometry Generation**<br>
> Roy Or-El, Xuan Luo, Mengyi Shan, Eli Shechtman, Jeong Joon Park, Ira Kemelmacher-Shlizerman<br>
> CVPR 2022<br>
> https://arxiv.org/pdf/2112.11427.pdf
## Abstract
We introduce a high resolution, 3D-consistent image and
shape generation technique which we call StyleSDF. Our
method is trained on single-view RGB data only, and stands
on the shoulders of StyleGAN2 for image generation, while
solving two main challenges in 3D-aware GANs: 1) high-resolution,
view-consistent generation of the RGB images,
and 2) detailed 3D shape. We achieve this by merging an
SDF-based 3D representation with a style-based 2D generator.
Our 3D implicit network renders low-resolution feature
maps, from which the style-based network generates
view-consistent, 1024×1024 images. Notably, our SDF-based
3D modeling defines detailed 3D surfaces, leading
to consistent volume rendering. Our method shows higher
quality results compared to state of the art in terms of visual
and geometric quality.
## Prerequisites
You must have a **GPU with CUDA support** in order to run the code.
This code requires **PyTorch**, **PyTorch3D** and **torchvision** to be installed; please see [PyTorch.org](https://pytorch.org/) and [PyTorch3d.org](https://pytorch3d.org/) for installation instructions.<br>
We tested our code on Python 3.8.5, PyTorch 1.9.0, PyTorch3D 0.6.1 and torchvision 0.10.0.
The following packages should also be installed:
1. lmdb
2. numpy
3. ninja
4. pillow
5. requests
6. tqdm
7. scipy
8. skimage
9. skvideo
10. trimesh[easy]
11. configargparse
12. munch
13. wandb (optional)
If any of these packages are not installed on your computer, you can install them using the supplied `requirements.txt` file:<br>
```pip install -r requirements.txt```
## Download Pre-trained Models
The pre-trained models can be downloaded by running `python download_models.py`.
## Quick Demo
You can explore our method in Google Colab [](https://colab.research.google.com/github/royorel/StyleSDF/blob/main/StyleSDF_demo.ipynb).
Alternatively, you can download the pretrained models by running:<br>
`python download_models.py`
To generate human faces from the model pre-trained on FFHQ, run:<br>
`python generate_shapes_and_images.py --expname ffhq1024x1024 --size 1024 --identities NUMBER_OF_FACES`
To generate animal faces from the model pre-trained on AFHQ, run:<br>
`python generate_shapes_and_images.py --expname afhq512x512 --size 512 --identities NUMBER_OF_FACES`
## Generating images and meshes
To generate images and meshes from a trained model, run:
`python generate_shapes_and_images.py --expname NAME_OF_TRAINED_MODEL --size MODEL_OUTPUT_SIZE --identities NUMBER_OF_FACES`
The script will generate an RGB image, a mesh generated from the depth map, and a mesh extracted with marching cubes.
### Optional flags for image and mesh generation
```
--no_surface_renderings When true, only RGB outputs will be generated. Otherwise, both RGB and depth renderings will be generated. (default: false)
--fixed_camera_angles When true, the generator will render identities from a fixed set of camera angles. (default: false)
```
## Generating Videos
To generate videos from a trained model, run: <br>
`python render_video.py --expname NAME_OF_TRAINED_MODEL --size MODEL_OUTPUT_SIZE --identities NUMBER_OF_FACES`.
This script will generate RGB video as well as depth map video for each identity. The average processing time per video is ~5-10 minutes on an RTX2080 Ti GPU.
### Optional flags for video rendering
```
--no_surface_videos When true, only RGB videos will be generated when running render_video.py. Otherwise, both RGB and depth videos will be generated. Setting this flag cuts the processing time per video. (default: false)
--azim_video When true, the camera trajectory will travel along the azimuth direction. Otherwise, the camera will travel along an ellipsoid trajectory. (default: false, i.e. ellipsoid trajectory)
--project_noise When true, use geometry-aware noise projection to reduce flickering effects (see supplementary section C.1 in the paper). Warning: processing time significantly increases with this flag to ~20 minutes per video. (default: false)
```
## Training
### Preparing your Dataset
If you wish to train a model from scratch, first you need to convert your dataset to an lmdb format. Run:<br>
`python prepare_data.py --out_path OUTPUT_LMDB_PATH --n_worker N_WORKER --size SIZE1,SIZE2,SIZE3,... INPUT_DATASET_PATH`
### Training the volume renderer
#### Training scripts
To train the volume renderer on FFHQ run: `bash ./scripts/train_ffhq_vol_renderer.sh`. <br>
To train the volume renderer on AFHQ run: `bash ./scripts/train_afhq_vol_renderer.sh`.
* The scripts above use distributed training. To train the models on a single GPU (not recommended) remove `-m torch.distributed.launch --nproc_per_node NUM_GPUS` from the script.
#### Training on a new dataset
To train the volume renderer on a new dataset, run:<br>
`python train_volume_renderer.py --batch BATCH_SIZE --chunk CHUNK_SIZE --expname EXPERIMENT_NAME --dataset_path DATASET_PATH`
Ideally, `CHUNK_SIZE` should be the same as `BATCH_SIZE`, but on most GPUs it will likely cause an out of memory error. In such case, reduce `CHUNK_SIZE` to perform gradient accumulation.
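The batch/chunk relationship described above can be sketched as follows. This is an illustrative, framework-free sketch of gradient accumulation, not the repo's training loop; `chunked` and `accumulate_gradients` are hypothetical helpers:

```python
# Illustrative sketch (not the repo's code): gradient accumulation splits one
# batch of size BATCH into BATCH // CHUNK chunks, backpropagates each chunk
# separately, and averages the results so the update matches a full-batch step.

def chunked(batch, chunk_size):
    """Yield successive chunk_size-sized slices of the batch."""
    for i in range(0, len(batch), chunk_size):
        yield batch[i:i + chunk_size]

def accumulate_gradients(batch, chunk_size, grad_fn):
    """Average per-sample gradients, computed chunk by chunk."""
    total = 0.0
    for chunk in chunked(batch, chunk_size):
        # each chunk contributes the sum of its per-sample gradients
        total += sum(grad_fn(x) for x in chunk)
    return total / len(batch)
```

With `grad_fn = lambda x: 2 * x` (the gradient of x²), accumulating over chunks of any size yields the same mean gradient as processing the whole batch at once, which is why reducing `CHUNK_SIZE` only changes memory usage, not the update.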
**Important note**: The best way to monitor the SDF convergence is to look at "Beta value" graph on wandb. Convergence is successful once beta reaches values below approx. 3*10<sup>-3</sup>. If the SDF is not converging, increase the R1 regularization weight. Another helpful option (to a lesser degree) is to decrease the weight of the minimal surface regularization.
#### Distributed training
If you have multiple GPUs you can train your model on multiple instances by running:<br>
`python -m torch.distributed.launch --nproc_per_node NUM_GPUS train_volume_renderer.py --batch BATCH_SIZE --chunk CHUNK_SIZE --expname EXPERIMENT_NAME --dataset_path DATASET_PATH`
### Training the full pipeline
#### Training scripts
To train the full pipeline on FFHQ run: `bash ./scripts/train_ffhq_full_pipeline_1024x1024.sh`. <br>
To train the full pipeline on AFHQ run: `bash ./scripts/train_afhq_full_pipeline_512x512.sh`.
* The scripts above assume that the volume renderer model was already trained. **Do not run them from scratch.**
* The scripts above use distributed training. To train the models on a single GPU (not recommended) remove `-m torch.distributed.launch --nproc_per_node NUM_GPUS` from the script.
#### Training on a new dataset
To train the full pipeline on a new dataset, **first train the volume renderer separately**. <br>
After the volume renderer training is finished, run:<br>
`python train_full_pipeline.py --batch BATCH_SIZE --chunk CHUNK_SIZE --expname EXPERIMENT_NAME --size OUTPUT_SIZE`
Ideally, `CHUNK_SIZE` should be the same as `BATCH_SIZE`, but on most GPUs it will likely cause an out of memory error. In such case, reduce `CHUNK_SIZE` to perform gradient accumulation.
#### Distributed training
If you have multiple GPUs you can train your model on multiple instances by running:<br>
`python -m torch.distributed.launch --nproc_per_node NUM_GPUS train_full_pipeline.py --batch BATCH_SIZE --chunk CHUNK_SIZE --expname EXPERIMENT_NAME --size OUTPUT_SIZE`
Here, **BATCH_SIZE represents the batch per GPU**, not the overall batch size.
### Optional training flags
```
Training regime options:
--iter Total number of training iterations. (default: 300,000)
--wandb Use Weights & Biases logging. (default: False)
--r1 Weight of the r1 regularization. (default: 10.0)
--view_lambda Weight of the viewpoint regularization. (Equation 6, default: 15)
--eikonal_lambda Weight of the eikonal regularization. (Equation 7, default: 0.1)
--min_surf_lambda Weight of the minimal surface regularization. (Equation 8, default: 0.05)
Camera options:
--uniform When true, the camera position is sampled from a uniform distribution. (default: false, i.e. Gaussian)
--azim Camera azimuth angle std (Gaussian)/range (uniform) in Radians. (default: 0.3 Rad.)
--elev Camera elevation angle std (Gaussian)/range (uniform) in Radians. (default: 0.15 Rad.)
--fov Camera field of view half angle in **Degrees**. (default: 6 Deg.)
--dist_radius Radius of points sampling distance from the origin. Determines the near and far fields. (default: 0.12)
```
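As a rough illustration of what `--uniform`, `--azim`, and `--elev` control, camera angles could be sampled like this. This is a hypothetical sketch, not the repo's sampling code; `sample_camera_angles` is an invented helper:

```python
import random

# Hypothetical sketch of the camera-pose sampling the flags describe:
# azimuth/elevation are drawn either from a zero-mean Gaussian with the given
# std, or uniformly from [-range, range], in radians.

def sample_camera_angles(azim=0.3, elev=0.15, uniform=False, rng=random):
    if uniform:
        # --uniform: azim/elev are half-ranges of a uniform distribution
        return rng.uniform(-azim, azim), rng.uniform(-elev, elev)
    # default: azim/elev are standard deviations of a Gaussian
    return rng.gauss(0.0, azim), rng.gauss(0.0, elev)
```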
## Citation
If you use this code for your research, please cite our paper.
```
@InProceedings{orel2022stylesdf,
title={Style{SDF}: {H}igh-{R}esolution {3D}-{C}onsistent {I}mage and {G}eometry {G}eneration},
author = {Or-El, Roy and
Luo, Xuan and
Shan, Mengyi and
Shechtman, Eli and
Park, Jeong Joon and
Kemelmacher-Shlizerman, Ira},
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2022},
pages = {13503-13513}
}
```
## Acknowledgments
This code is inspired by rosinality's [StyleGAN2-PyTorch](https://github.com/rosinality/stylegan2-pytorch) and Yen-Chen Lin's [NeRF-pytorch](https://github.com/yenchenlin/nerf-pytorch).
A special thanks to [@SerdarHelli](https://github.com/SerdarHelli) for implementing the HuggingFace demo.
================================================
FILE: StyleSDF_demo.ipynb
================================================
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "StyleSDF_demo.ipynb",
"private_outputs": true,
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/royorel/StyleSDF/blob/main/StyleSDF_demo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# StyleSDF Demo\n",
"\n",
"This Colab notebook demonstrates the capabilities of the StyleSDF 3D-aware GAN architecture proposed in our paper.\n",
"\n",
"This Colab generates images with their corresponding 3D meshes."
],
"metadata": {
"id": "b86fxLSo1gqI"
}
},
{
"cell_type": "markdown",
"source": [
"First, let's download the github repository and install all dependencies."
],
"metadata": {
"id": "8wSAkifH2Bk5"
}
},
{
"cell_type": "code",
"source": [
"!git clone https://github.com/royorel/StyleSDF.git\n",
"%cd StyleSDF\n",
"!pip3 install -r requirements.txt"
],
"metadata": {
"id": "1GTFRig12CuH"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"And install pytorch3D..."
],
"metadata": {
"id": "GwFJ8oLY4i8d"
}
},
{
"cell_type": "code",
"source": [
"!pip install -U fvcore\n",
"import sys\n",
"import torch\n",
"pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
"version_str=\"\".join([\n",
" f\"py3{sys.version_info.minor}_cu\",\n",
" torch.version.cuda.replace(\".\",\"\"),\n",
" f\"_pyt{pyt_version_str}\"\n",
"])\n",
"!pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html"
],
"metadata": {
"id": "zl3Vpddz3ols"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Now let's download the pretrained models for FFHQ and AFHQ."
],
"metadata": {
"id": "OgMcslVS5vbC"
}
},
{
"cell_type": "code",
"source": [
"!python download_models.py"
],
"metadata": {
"id": "r1iDkz7r5wnO"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Here, we import libraries and set options.\n",
"\n",
"Note: this might take a while (approx. 1-2 minutes) since CUDA kernels need to be compiled."
],
"metadata": {
"id": "F2JUyGq4JDa8"
}
},
{
"cell_type": "code",
"source": [
"import os\n",
"import torch\n",
"import trimesh\n",
"import numpy as np\n",
"from munch import *\n",
"from options import BaseOptions\n",
"from model import Generator\n",
"from generate_shapes_and_images import generate\n",
"from render_video import render_video\n",
"\n",
"\n",
"torch.random.manual_seed(321)\n",
"\n",
"\n",
"device = \"cuda\"\n",
"opt = BaseOptions().parse()\n",
"opt.camera.uniform = True\n",
"opt.model.is_test = True\n",
"opt.model.freeze_renderer = False\n",
"opt.rendering.offset_sampling = True\n",
"opt.rendering.static_viewdirs = True\n",
"opt.rendering.force_background = True\n",
"opt.rendering.perturb = 0\n",
"opt.inference.renderer_output_size = opt.model.renderer_spatial_output_dim\n",
"opt.inference.style_dim = opt.model.style_dim\n",
"opt.inference.project_noise = opt.model.project_noise"
],
"metadata": {
"id": "Qfamt8J0JGn5"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "TLOiDzgKtYyi"
},
"source": [
"Don't worry about this message above, \n",
"```\n",
"usage: ipykernel_launcher.py [-h] [--dataset_path DATASET_PATH]\n",
" [--config CONFIG] [--expname EXPNAME]\n",
" [--ckpt CKPT] [--continue_training]\n",
" ...\n",
" ...\n",
"ipykernel_launcher.py: error: unrecognized arguments: -f /root/.local/share/jupyter/runtime/kernel-c9d47a98-bdba-4a5f-9f0a-e1437c7228b6.json\n",
"```\n",
"everything is perfectly fine..."
]
},
{
"cell_type": "markdown",
"source": [
"Here, we define our model.\n",
"\n",
"Set the options below according to your choosing:\n",
"1. If you plan to try the method for the AFHQ dataset (animal faces), change `model_type` to 'afhq'. Default: `ffhq` (human faces).\n",
"2. If you wish to turn off depth rendering and marching cubes extraction and generate only RGB images, set `opt.inference.no_surface_renderings = True`. Default: `False`.\n",
"3. If you wish to generate the image from a specific set of viewpoints, set `opt.inference.fixed_camera_angles = True`. Default: `False`.\n",
"4. Set the number of identities you wish to create in `opt.inference.identities`. Default: `4`.\n",
"5. Select the number of views per identity in `opt.inference.num_views_per_id`,<br>\n",
" (Only applicable when `opt.inference.fixed_camera_angles` is false). Default: `1`. "
],
"metadata": {
"id": "40dk1eEJo9MF"
}
},
{
"cell_type": "code",
"source": [
"# User options\n",
"model_type = 'ffhq' # Whether to load the FFHQ or AFHQ model\n",
"opt.inference.no_surface_renderings = False # When true, only RGB images will be created\n",
"opt.inference.fixed_camera_angles = False # When true, each identity will be rendered from a specific set of 13 viewpoints. Otherwise, random views are generated\n",
"opt.inference.identities = 4 # Number of identities to generate\n",
"opt.inference.num_views_per_id = 1 # Number of viewpoints generated per identity. This option is ignored if opt.inference.fixed_camera_angles is true.\n",
"\n",
"# Load saved model\n",
"if model_type == 'ffhq':\n",
" model_path = 'ffhq1024x1024.pt'\n",
" opt.model.size = 1024\n",
" opt.experiment.expname = 'ffhq1024x1024'\n",
"else:\n",
" opt.inference.camera.azim = 0.15\n",
" model_path = 'afhq512x512.pt'\n",
" opt.model.size = 512\n",
" opt.experiment.expname = 'afhq512x512'\n",
"\n",
"# Create results directory\n",
"result_model_dir = 'final_model'\n",
"results_dir_basename = os.path.join(opt.inference.results_dir, opt.experiment.expname)\n",
"opt.inference.results_dst_dir = os.path.join(results_dir_basename, result_model_dir)\n",
"if opt.inference.fixed_camera_angles:\n",
" opt.inference.results_dst_dir = os.path.join(opt.inference.results_dst_dir, 'fixed_angles')\n",
"else:\n",
" opt.inference.results_dst_dir = os.path.join(opt.inference.results_dst_dir, 'random_angles')\n",
"\n",
"os.makedirs(opt.inference.results_dst_dir, exist_ok=True)\n",
"os.makedirs(os.path.join(opt.inference.results_dst_dir, 'images'), exist_ok=True)\n",
"if not opt.inference.no_surface_renderings:\n",
" os.makedirs(os.path.join(opt.inference.results_dst_dir, 'depth_map_meshes'), exist_ok=True)\n",
" os.makedirs(os.path.join(opt.inference.results_dst_dir, 'marching_cubes_meshes'), exist_ok=True)\n",
"\n",
"opt.inference.camera = opt.camera\n",
"opt.inference.size = opt.model.size\n",
"checkpoint_path = os.path.join('full_models', model_path)\n",
"checkpoint = torch.load(checkpoint_path)\n",
"\n",
"# Load image generation model\n",
"g_ema = Generator(opt.model, opt.rendering).to(device)\n",
"pretrained_weights_dict = checkpoint[\"g_ema\"]\n",
"model_dict = g_ema.state_dict()\n",
"for k, v in pretrained_weights_dict.items():\n",
" if v.size() == model_dict[k].size():\n",
" model_dict[k] = v\n",
"\n",
"g_ema.load_state_dict(model_dict)\n",
"\n",
"# Load a second volume renderer that extracts surfaces at 128x128x128 (or higher) for better surface resolution\n",
"if not opt.inference.no_surface_renderings:\n",
" opt['surf_extraction'] = Munch()\n",
" opt.surf_extraction.rendering = opt.rendering\n",
" opt.surf_extraction.model = opt.model.copy()\n",
" opt.surf_extraction.model.renderer_spatial_output_dim = 128\n",
" opt.surf_extraction.rendering.N_samples = opt.surf_extraction.model.renderer_spatial_output_dim\n",
" opt.surf_extraction.rendering.return_xyz = True\n",
" opt.surf_extraction.rendering.return_sdf = True\n",
" surface_g_ema = Generator(opt.surf_extraction.model, opt.surf_extraction.rendering, full_pipeline=False).to(device)\n",
"\n",
"\n",
" # Load weights to surface extractor\n",
" surface_extractor_dict = surface_g_ema.state_dict()\n",
" for k, v in pretrained_weights_dict.items():\n",
" if k in surface_extractor_dict.keys() and v.size() == surface_extractor_dict[k].size():\n",
" surface_extractor_dict[k] = v\n",
"\n",
" surface_g_ema.load_state_dict(surface_extractor_dict)\n",
"else:\n",
" surface_g_ema = None\n",
"\n",
"# Get the mean latent vector for g_ema\n",
"if opt.inference.truncation_ratio < 1:\n",
" with torch.no_grad():\n",
" mean_latent = g_ema.mean_latent(opt.inference.truncation_mean, device)\n",
"else:\n",
" mean_latent = None\n",
"\n",
"# Get the mean latent vector for surface_g_ema\n",
"if not opt.inference.no_surface_renderings:\n",
" surface_mean_latent = mean_latent[0]\n",
"else:\n",
" surface_mean_latent = None"
],
"metadata": {
"id": "CUcWipIlpINT"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Generating images and meshes\n",
"\n",
"Finally, we run the network. The results will be saved to `evaluations/[model_name]/final_model/[fixed/random]_angles`, according to the selected setup."
],
"metadata": {
"id": "N9pqjPDYCwIJ"
}
},
{
"cell_type": "code",
"source": [
"generate(opt.inference, g_ema, surface_g_ema, device, mean_latent, surface_mean_latent)"
],
"metadata": {
"id": "UG4hZgigDfG7"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Now let's examine the results\n",
"\n",
"Tip: for better mesh visualization, we recommend downloading the result meshes and viewing them with Meshlab.\n",
"\n",
"The meshes are located in: `evaluations/[model_name]/final_model/[fixed/random]_angles/[depth_map/marching_cubes]_meshes`."
],
"metadata": {
"id": "obfjXuDxNuZl"
}
},
{
"cell_type": "code",
"source": [
"from PIL import Image\n",
"from trimesh.viewer.notebook import scene_to_html as mesh2html\n",
"from IPython.display import HTML as viewer_html\n",
"\n",
"# First let's look at the images\n",
"img_dir = os.path.join(opt.inference.results_dst_dir,'images')\n",
"im_list = sorted([entry for entry in os.listdir(img_dir) if 'thumb' not in entry])\n",
"img = Image.new('RGB', (256 * len(im_list), 256))\n",
"for i, im_file in enumerate(im_list):\n",
" im_path = os.path.join(img_dir, im_file)\n",
" curr_img = Image.open(im_path).resize((256,256)) # the displayed image is scaled to fit to the screen\n",
" img.paste(curr_img, (256 * i, 0))\n",
"\n",
"display(img)\n",
"\n",
"# And now, we'll move on to display the marching cubes and depth map meshes\n",
"\n",
"marching_cubes_meshes_dir = os.path.join(opt.inference.results_dst_dir,'marching_cubes_meshes')\n",
"marching_cubes_meshes_list = sorted([os.path.join(marching_cubes_meshes_dir, entry) for entry in os.listdir(marching_cubes_meshes_dir) if 'obj' in entry])\n",
"depth_map_meshes_dir = os.path.join(opt.inference.results_dst_dir,'depth_map_meshes')\n",
"depth_map_meshes_list = sorted([os.path.join(depth_map_meshes_dir, entry) for entry in os.listdir(depth_map_meshes_dir) if 'obj' in entry])\n",
"for i, mesh_files in enumerate(zip(marching_cubes_meshes_list, depth_map_meshes_list)):\n",
" mc_mesh_file, dm_mesh_file = mesh_files[0], mesh_files[1]\n",
" marching_cubes_mesh = trimesh.Scene(trimesh.load_mesh(mc_mesh_file, 'obj')) \n",
" curr_mc_html = mesh2html(marching_cubes_mesh).replace('\"', '"')\n",
" display(viewer_html(' '.join(['<iframe srcdoc=\"{srcdoc}\"',\n",
" 'width=\"{width}px\" height=\"{height}px\"',\n",
" 'style=\"border:none;\"></iframe>']).format(\n",
" srcdoc=curr_mc_html, height=256, width=256)))\n",
" depth_map_mesh = trimesh.Scene(trimesh.load_mesh(dm_mesh_file, 'obj')) \n",
" curr_dm_html = mesh2html(depth_map_mesh).replace('\"', '"')\n",
" display(viewer_html(' '.join(['<iframe srcdoc=\"{srcdoc}\"',\n",
" 'width=\"{width}px\" height=\"{height}px\"',\n",
" 'style=\"border:none;\"></iframe>']).format(\n",
" srcdoc=curr_dm_html, height=256, width=256)))"
],
"metadata": {
"id": "k3jXCVT7N2YZ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Generating videos\n",
"\n",
"Additionally, we can also render videos. The results will be saved to `evaluations/[model_name]/final_model/videos`.\n",
"\n",
"Set the options below according to your choosing:\n",
"1. If you wish to generate only RGB videos, set `opt.inference.no_surface_renderings = True`. Default: `False`.\n",
"2. Set the camera trajectory. To travel along the azimuth direction set `opt.inference.azim_video = True`, to travel in an ellipsoid trajectory set `opt.inference.azim_video = False`. Default: `False`.\n",
"\n",
"### Important Note:\n",
" - Processing time for videos when `opt.inference.no_surface_renderings = False` is very lengthy (~ 15-20 minutes per video). Rendering each depth frame for the depth videos is very slow.<br>\n",
" - Processing time for videos when `opt.inference.no_surface_renderings = True` is much faster (~ 1-2 minutes per video)"
],
"metadata": {
"id": "Cxq3c6anN4D0"
}
},
{
"cell_type": "code",
"source": [
"# Options\n",
"opt.inference.no_surface_renderings = True # When true, only RGB videos will be created\n",
"opt.inference.azim_video = True # When true, the camera trajectory will travel along the azimuth direction. Otherwise, the camera will travel along an ellipsoid trajectory.\n",
"\n",
"opt.inference.results_dst_dir = os.path.join(os.path.split(opt.inference.results_dst_dir)[0], 'videos')\n",
"os.makedirs(opt.inference.results_dst_dir, exist_ok=True)\n",
"render_video(opt.inference, g_ema, surface_g_ema, device, mean_latent, surface_mean_latent)"
],
"metadata": {
"id": "nblhnZgcOST8"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Let's watch the result videos.\n",
"\n",
"The output video files are relatively large, so it might take a while (about 1-2 minutes) for all of them to be loaded. "
],
"metadata": {
"id": "jkBYvlaeTGu1"
}
},
{
"cell_type": "code",
"source": [
"%%script bash --bg\n",
"python3 -m http.server 8000"
],
"metadata": {
"id": "s-A9oIFXnYdM"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"%%html\n",
"<!-- change ffhq1024x1024 to afhq512x512 if you are working on the AFHQ model -->\n",
"<div>\n",
" <video width=256 controls><source src=\"http://localhost:8000/evaluations/ffhq1024x1024/final_model/videos/sample_video_0_azim.mp4\" type=\"video/mp4\"></video>\n",
" <video width=256 controls><source src=\"http://localhost:8000/evaluations/ffhq1024x1024/final_model/videos/sample_video_1_azim.mp4\" type=\"video/mp4\"></video>\n",
" <video width=256 controls><source src=\"http://localhost:8000/evaluations/ffhq1024x1024/final_model/videos/sample_video_2_azim.mp4\" type=\"video/mp4\"></video>\n",
" <video width=256 controls><source src=\"http://localhost:8000/evaluations/ffhq1024x1024/final_model/videos/sample_video_3_azim.mp4\" type=\"video/mp4\"></video>\n",
"</div>"
],
"metadata": {
"id": "Rz8tq-yYnZMD"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"An alternative way to view the videos is with Python code.\n",
"It loads the videos faster, but it often crashes the notebook since the video files are too large.\n",
"\n",
"**It is not recommended to view the files this way**.\n",
"\n",
"If the notebook does crash, you can also refresh the webpage and manually download the videos.<br>\n",
"The videos are located in `evaluations/<model_name>/final_model/videos`"
],
"metadata": {
"id": "KP_Nrl8yqL65"
}
},
{
"cell_type": "code",
"source": [
"# from base64 import b64encode\n",
"\n",
"# videos_dir = opt.inference.results_dst_dir\n",
"# videos_list = sorted([os.path.join(videos_dir, entry) for entry in os.listdir(videos_dir) if 'mp4' in entry])\n",
"# for i, video_file in enumerate(videos_list):\n",
"# if i != 1:\n",
"# continue\n",
"# mp4 = open(video_file,'rb').read()\n",
"# data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
"# display(viewer_html(\"\"\"<video width={0} controls>\n",
"# <source src=\"{1}\" type=\"{2}\">\n",
"# </video>\"\"\".format(256, data_url, \"video/mp4\")))"
],
"metadata": {
"id": "LzsQTq1hTNSH"
},
"execution_count": null,
"outputs": []
}
]
}
================================================
FILE: dataset.py
================================================
import os
import csv
import lmdb
import random
import numpy as np
import torchvision.transforms.functional as TF
from PIL import Image
from io import BytesIO
from torch.utils.data import Dataset
class MultiResolutionDataset(Dataset):
def __init__(self, path, transform, resolution=256, nerf_resolution=64):
self.env = lmdb.open(
path,
max_readers=32,
readonly=True,
lock=False,
readahead=False,
meminit=False,
)
if not self.env:
raise IOError('Cannot open lmdb dataset', path)
with self.env.begin(write=False) as txn:
self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))
self.resolution = resolution
self.nerf_resolution = nerf_resolution
self.transform = transform
def __len__(self):
return self.length
def __getitem__(self, index):
with self.env.begin(write=False) as txn:
key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8')
img_bytes = txn.get(key)
buffer = BytesIO(img_bytes)
img = Image.open(buffer)
if random.random() > 0.5:
img = TF.hflip(img)
thumb_img = img.resize((self.nerf_resolution, self.nerf_resolution), Image.HAMMING)
img = self.transform(img)
thumb_img = self.transform(thumb_img)
return img, thumb_img
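For reference, the lmdb key that `__getitem__` builds above (resolution plus a zero-padded index, presumably matching what `prepare_data.py` writes) can be isolated as a small sketch; `lmdb_key` is a hypothetical helper, not part of the repo:

```python
# Sketch of the lmdb key format used by MultiResolutionDataset.__getitem__:
# "<resolution>-<index zero-padded to 5 digits>", UTF-8 encoded.

def lmdb_key(resolution, index):
    """Build the lookup key for an image at a given resolution and index."""
    return f'{resolution}-{str(index).zfill(5)}'.encode('utf-8')
```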
================================================
FILE: distributed.py
================================================
import math
import pickle
import torch
from torch import distributed as dist
from torch.utils.data.sampler import Sampler
def get_rank():
if not dist.is_available():
return 0
if not dist.is_initialized():
return 0
return dist.get_rank()
def synchronize():
if not dist.is_available():
return
if not dist.is_initialized():
return
world_size = dist.get_world_size()
if world_size == 1:
return
dist.barrier()
def get_world_size():
if not dist.is_available():
return 1
if not dist.is_initialized():
return 1
return dist.get_world_size()
def reduce_sum(tensor):
if not dist.is_available():
return tensor
if not dist.is_initialized():
return tensor
tensor = tensor.clone()
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
return tensor
def gather_grad(params):
world_size = get_world_size()
if world_size == 1:
return
for param in params:
if param.grad is not None:
dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
param.grad.data.div_(world_size)
def all_gather(data):
world_size = get_world_size()
if world_size == 1:
return [data]
buffer = pickle.dumps(data)
storage = torch.ByteStorage.from_buffer(buffer)
tensor = torch.ByteTensor(storage).to('cuda')
local_size = torch.IntTensor([tensor.numel()]).to('cuda')
size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)]
dist.all_gather(size_list, local_size)
size_list = [int(size.item()) for size in size_list]
max_size = max(size_list)
tensor_list = []
for _ in size_list:
tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda'))
if local_size != max_size:
padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda')
tensor = torch.cat((tensor, padding), 0)
dist.all_gather(tensor_list, tensor)
data_list = []
for size, tensor in zip(size_list, tensor_list):
buffer = tensor.cpu().numpy().tobytes()[:size]
data_list.append(pickle.loads(buffer))
return data_list
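all_gather works around fixed-size collective buffers by pickling each rank's payload, padding every byte tensor to the longest one, and truncating with the recorded sizes on receipt. A single-process, torch-free sketch of that pad-then-truncate round trip:

```python
import pickle

# Simulate three "ranks" with differently sized pickled payloads.
payloads = [pickle.dumps({'rank': r, 'extra': 'x' * r}) for r in range(3)]
sizes = [len(p) for p in payloads]
max_size = max(sizes)

# Pad every buffer to the max size (what all_gather does with zero bytes)...
padded = [p + b'\x00' * (max_size - len(p)) for p in payloads]

# ...then truncate back to the true size before unpickling, as in the final loop.
recovered = [pickle.loads(buf[:size]) for size, buf in zip(sizes, padded)]
assert [d['rank'] for d in recovered] == [0, 1, 2]
```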
def reduce_loss_dict(loss_dict):
world_size = get_world_size()
if world_size < 2:
return loss_dict
with torch.no_grad():
keys = []
losses = []
for k in sorted(loss_dict.keys()):
keys.append(k)
losses.append(loss_dict[k])
losses = torch.stack(losses, 0)
dist.reduce(losses, dst=0)
if dist.get_rank() == 0:
losses /= world_size
reduced_losses = {k: v for k, v in zip(keys, losses)}
return reduced_losses
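In single-GPU runs all of the helpers above collapse to no-ops; in multi-GPU runs, reduce_loss_dict averages each logged loss across ranks on rank 0. A plain-Python sketch of that averaging (hypothetical `reduce_loss_dict_sketch`, no torch):

```python
# Hypothetical torch-free stand-in: each dict plays the role of one rank's
# loss_dict; the result is the elementwise mean rank 0 would see.
def reduce_loss_dict_sketch(per_rank_dicts):
    world_size = len(per_rank_dicts)
    keys = sorted(per_rank_dicts[0])  # sorted, matching reduce_loss_dict
    return {k: sum(d[k] for d in per_rank_dicts) / world_size for k in keys}

averaged = reduce_loss_dict_sketch([{'d': 1.0, 'g': 2.0}, {'d': 3.0, 'g': 4.0}])
assert averaged == {'d': 2.0, 'g': 3.0}
```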
================================================
FILE: download_models.py
================================================
import os
import html
import glob
import uuid
import hashlib
import requests
from tqdm import tqdm
from pdb import set_trace as st
ffhq_full_model_spec = dict(file_url='https://drive.google.com/uc?id=13s_dH768zJ3IHUjySVbD1DqcMolqKlBi',
alt_url='', file_size=202570217, file_md5='1ab9522157e537351fcf4faed9c92abb',
file_path='full_models/ffhq1024x1024.pt',)
ffhq_volume_renderer_spec = dict(file_url='https://drive.google.com/uc?id=1zzB-ACuas7lSAln8pDnIqlWOEg869CCK',
alt_url='', file_size=63736197, file_md5='fe62f26032ccc8f04e1101b07bcd7462',
file_path='pretrained_renderer/ffhq_vol_renderer.pt',)
afhq_full_model_spec = dict(file_url='https://drive.google.com/uc?id=1jZcV5__EPS56JBllRmUfUjMHExql_eSq',
alt_url='', file_size=184908743, file_md5='91eeaab2da5c0d2134c04bb56ac5aeb6',
file_path='full_models/afhq512x512.pt',)
afhq_volume_renderer_spec = dict(file_url='https://drive.google.com/uc?id=1xhZjuJt_teghAQEoJevAxrU5_8VqqX42',
alt_url='', file_size=63736197, file_md5='cb3beb6cd3c43d9119356400165e7f26',
file_path='pretrained_renderer/afhq_vol_renderer.pt',)
volume_renderer_sphere_init_spec = dict(file_url='https://drive.google.com/uc?id=1CcYWbHFrJyb4u5rBFx_ww-ceYWd309tk',
alt_url='', file_size=63736197, file_md5='a6be653435b0e49633e6838f18ff4df6',
file_path='pretrained_renderer/sphere_init.pt',)
def download_pretrained_models():
print('Downloading sphere initialized volume renderer')
with requests.Session() as session:
try:
download_file(session, volume_renderer_sphere_init_spec)
except Exception:
print('Google Drive download failed.\n' \
'Trying to download from alternate server')
download_file(session, volume_renderer_sphere_init_spec, use_alt_url=True)
print('Downloading FFHQ pretrained volume renderer')
with requests.Session() as session:
try:
download_file(session, ffhq_volume_renderer_spec)
except Exception:
print('Google Drive download failed.\n' \
'Trying to download from alternate server')
download_file(session, ffhq_volume_renderer_spec, use_alt_url=True)
print('Downloading FFHQ full model (1024x1024)')
with requests.Session() as session:
try:
download_file(session, ffhq_full_model_spec)
except Exception:
print('Google Drive download failed.\n' \
'Trying to download from alternate server')
download_file(session, ffhq_full_model_spec, use_alt_url=True)
print('Done!')
print('Downloading AFHQ pretrained volume renderer')
with requests.Session() as session:
try:
download_file(session, afhq_volume_renderer_spec)
except Exception:
print('Google Drive download failed.\n' \
'Trying to download from alternate server')
download_file(session, afhq_volume_renderer_spec, use_alt_url=True)
print('Done!')
print('Downloading AFHQ full model (512x512)')
with requests.Session() as session:
try:
download_file(session, afhq_full_model_spec)
except Exception:
print('Google Drive download failed.\n' \
'Trying to download from alternate server')
download_file(session, afhq_full_model_spec, use_alt_url=True)
print('Done!')
def download_file(session, file_spec, use_alt_url=False, chunk_size=128, num_attempts=10):
file_path = file_spec['file_path']
if use_alt_url:
file_url = file_spec['alt_url']
else:
file_url = file_spec['file_url']
file_dir = os.path.dirname(file_path)
tmp_path = file_path + '.tmp.' + uuid.uuid4().hex
if file_dir:
os.makedirs(file_dir, exist_ok=True)
progress_bar = tqdm(total=file_spec['file_size'], unit='B', unit_scale=True)
for attempts_left in reversed(range(num_attempts)):
data_size = 0
progress_bar.reset()
try:
# Download.
data_md5 = hashlib.md5()
with session.get(file_url, stream=True) as res:
res.raise_for_status()
with open(tmp_path, 'wb') as f:
for chunk in res.iter_content(chunk_size=chunk_size<<10):
progress_bar.update(len(chunk))
f.write(chunk)
data_size += len(chunk)
data_md5.update(chunk)
# Validate.
if 'file_size' in file_spec and data_size != file_spec['file_size']:
raise IOError('Incorrect file size', file_path)
if 'file_md5' in file_spec and data_md5.hexdigest() != file_spec['file_md5']:
raise IOError('Incorrect file MD5', file_path)
break
except Exception:
# Last attempt => raise error.
if not attempts_left:
raise
# Handle Google Drive virus checker nag.
if data_size > 0 and data_size < 8192:
with open(tmp_path, 'rb') as f:
data = f.read()
links = [html.unescape(link) for link in data.decode('utf-8').split('"') if 'confirm=t' in link]
if len(links) == 1:
file_url = requests.compat.urljoin(file_url, links[0])
continue
progress_bar.close()
# Rename temp file to the correct name.
os.replace(tmp_path, file_path) # atomic
# Attempt to clean up any leftover temps.
for filename in glob.glob(file_path + '.tmp.*'):
try:
os.remove(filename)
except OSError:
pass
if __name__ == "__main__":
download_pretrained_models()
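download_file hashes the stream chunk by chunk while writing, so MD5 validation adds no extra pass over the file. A small sketch of that incremental-hashing property (synthetic payload, not one of the real checkpoints):

```python
import hashlib

# Updating an MD5 object per chunk yields the same digest as hashing the
# whole payload at once -- the property download_file relies on.
data = b'x' * 300_000            # synthetic stand-in for a checkpoint file
chunk_size = 128 << 10           # 128 KiB, i.e. the default chunk_size kwarg << 10
h = hashlib.md5()
for i in range(0, len(data), chunk_size):
    h.update(data[i:i + chunk_size])
assert h.hexdigest() == hashlib.md5(data).hexdigest()
```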
================================================
FILE: generate_shapes_and_images.py
================================================
import os
import torch
import trimesh
import numpy as np
from munch import *
from PIL import Image
from tqdm import tqdm
from torch.nn import functional as F
from torch.utils import data
from torchvision import utils
from torchvision import transforms
from skimage.measure import marching_cubes
from scipy.spatial import Delaunay
from options import BaseOptions
from model import Generator
from utils import (
generate_camera_params,
align_volume,
extract_mesh_with_marching_cubes,
xyz2mesh,
)
torch.random.manual_seed(1234)
def generate(opt, g_ema, surface_g_ema, device, mean_latent, surface_mean_latent):
g_ema.eval()
if not opt.no_surface_renderings:
surface_g_ema.eval()
# set camera angles
if opt.fixed_camera_angles:
# These can be changed to any other specific viewpoints.
# You can add or remove viewpoints as you wish
locations = torch.tensor([[0, 0],
[-1.5 * opt.camera.azim, 0],
[-1 * opt.camera.azim, 0],
[-0.5 * opt.camera.azim, 0],
[0.5 * opt.camera.azim, 0],
[1 * opt.camera.azim, 0],
[1.5 * opt.camera.azim, 0],
[0, -1.5 * opt.camera.elev],
[0, -1 * opt.camera.elev],
[0, -0.5 * opt.camera.elev],
[0, 0.5 * opt.camera.elev],
[0, 1 * opt.camera.elev],
[0, 1.5 * opt.camera.elev]], device=device)
# For zooming in/out change the values of fov
# (This can be defined for each view separately via a custom tensor
# like the locations tensor above. Tensor shape should be [locations.shape[0],1])
# reasonable values are [0.75 * opt.camera.fov, 1.25 * opt.camera.fov]
fov = opt.camera.fov * torch.ones((locations.shape[0],1), device=device)
num_viewdirs = locations.shape[0]
else: # draw random camera angles
locations = None
# fov = None
fov = opt.camera.fov
num_viewdirs = opt.num_views_per_id
# generate images
for i in tqdm(range(opt.identities)):
with torch.no_grad():
chunk = 8
sample_z = torch.randn(1, opt.style_dim, device=device).repeat(num_viewdirs,1)
sample_cam_extrinsics, sample_focals, sample_near, sample_far, sample_locations = \
generate_camera_params(opt.renderer_output_size, device, batch=num_viewdirs,
locations=locations, #input_fov=fov,
uniform=opt.camera.uniform, azim_range=opt.camera.azim,
elev_range=opt.camera.elev, fov_ang=fov,
dist_radius=opt.camera.dist_radius)
rgb_images = torch.Tensor(0, 3, opt.size, opt.size)
rgb_images_thumbs = torch.Tensor(0, 3, opt.renderer_output_size, opt.renderer_output_size)
for j in range(0, num_viewdirs, chunk):
out = g_ema([sample_z[j:j+chunk]],
sample_cam_extrinsics[j:j+chunk],
sample_focals[j:j+chunk],
sample_near[j:j+chunk],
sample_far[j:j+chunk],
truncation=opt.truncation_ratio,
truncation_latent=mean_latent)
rgb_images = torch.cat([rgb_images, out[0].cpu()], 0)
rgb_images_thumbs = torch.cat([rgb_images_thumbs, out[1].cpu()], 0)
utils.save_image(rgb_images,
os.path.join(opt.results_dst_dir, 'images','{}.png'.format(str(i).zfill(7))),
nrow=num_viewdirs,
normalize=True,
padding=0,
value_range=(-1, 1),)
utils.save_image(rgb_images_thumbs,
os.path.join(opt.results_dst_dir, 'images','{}_thumb.png'.format(str(i).zfill(7))),
nrow=num_viewdirs,
normalize=True,
padding=0,
value_range=(-1, 1),)
# free GPU memory so the run fits within the RTX 2080's 11GB of RAM
del out
torch.cuda.empty_cache()
if not opt.no_surface_renderings:
surface_chunk = 1
scale = surface_g_ema.renderer.out_im_res / g_ema.renderer.out_im_res
surface_sample_focals = sample_focals * scale
for j in range(0, num_viewdirs, surface_chunk):
surface_out = surface_g_ema([sample_z[j:j+surface_chunk]],
sample_cam_extrinsics[j:j+surface_chunk],
surface_sample_focals[j:j+surface_chunk],
sample_near[j:j+surface_chunk],
sample_far[j:j+surface_chunk],
truncation=opt.truncation_ratio,
truncation_latent=surface_mean_latent,
return_sdf=True,
return_xyz=True)
xyz = surface_out[2].cpu()
sdf = surface_out[3].cpu()
# free GPU memory so the run fits within the RTX 2080's 11GB of RAM
del surface_out
torch.cuda.empty_cache()
# mesh extractions are done one at a time
for k in range(surface_chunk):
curr_locations = sample_locations[j:j+surface_chunk]
loc_str = '_azim{}_elev{}'.format(int(curr_locations[k,0] * 180 / np.pi),
int(curr_locations[k,1] * 180 / np.pi))
# Save depth outputs as meshes
depth_mesh_filename = os.path.join(opt.results_dst_dir,'depth_map_meshes','sample_{}_depth_mesh{}.obj'.format(i, loc_str))
depth_mesh = xyz2mesh(xyz[k:k+surface_chunk])
if depth_mesh is not None:
with open(depth_mesh_filename, 'w') as f:
depth_mesh.export(f,file_type='obj')
# extract full geometry with marching cubes
if j == 0:
try:
frustum_aligned_sdf = align_volume(sdf)
marching_cubes_mesh = extract_mesh_with_marching_cubes(frustum_aligned_sdf[k:k+surface_chunk])
except ValueError:
marching_cubes_mesh = None
print('Marching cubes extraction failed.')
print('Please check whether the SDF values are all larger (or all smaller) than 0.')
if marching_cubes_mesh is not None:
marching_cubes_mesh_filename = os.path.join(opt.results_dst_dir,'marching_cubes_meshes','sample_{}_marching_cubes_mesh{}.obj'.format(i, loc_str))
with open(marching_cubes_mesh_filename, 'w') as f:
marching_cubes_mesh.export(f,file_type='obj')
if __name__ == "__main__":
device = "cuda"
opt = BaseOptions().parse()
opt.model.is_test = True
opt.model.freeze_renderer = False
opt.rendering.offset_sampling = True
opt.rendering.static_viewdirs = True
opt.rendering.force_background = True
opt.rendering.perturb = 0
opt.inference.size = opt.model.size
opt.inference.camera = opt.camera
opt.inference.renderer_output_size = opt.model.renderer_spatial_output_dim
opt.inference.style_dim = opt.model.style_dim
opt.inference.project_noise = opt.model.project_noise
opt.inference.return_xyz = opt.rendering.return_xyz
# find checkpoint directory
# check if there's a fully trained model
checkpoints_dir = 'full_models'
checkpoint_path = os.path.join(checkpoints_dir, opt.experiment.expname + '.pt')
if os.path.isfile(checkpoint_path):
# define results directory name
result_model_dir = 'final_model'
else:
checkpoints_dir = os.path.join('checkpoint', opt.experiment.expname, 'full_pipeline')
checkpoint_path = os.path.join(checkpoints_dir,
'models_{}.pt'.format(opt.experiment.ckpt.zfill(7)))
# define results directory name
result_model_dir = 'iter_{}'.format(opt.experiment.ckpt.zfill(7))
# create results directory
results_dir_basename = os.path.join(opt.inference.results_dir, opt.experiment.expname)
opt.inference.results_dst_dir = os.path.join(results_dir_basename, result_model_dir)
if opt.inference.fixed_camera_angles:
opt.inference.results_dst_dir = os.path.join(opt.inference.results_dst_dir, 'fixed_angles')
else:
opt.inference.results_dst_dir = os.path.join(opt.inference.results_dst_dir, 'random_angles')
os.makedirs(opt.inference.results_dst_dir, exist_ok=True)
os.makedirs(os.path.join(opt.inference.results_dst_dir, 'images'), exist_ok=True)
if not opt.inference.no_surface_renderings:
os.makedirs(os.path.join(opt.inference.results_dst_dir, 'depth_map_meshes'), exist_ok=True)
os.makedirs(os.path.join(opt.inference.results_dst_dir, 'marching_cubes_meshes'), exist_ok=True)
# load saved model
checkpoint = torch.load(checkpoint_path)
# load image generation model
g_ema = Generator(opt.model, opt.rendering).to(device)
pretrained_weights_dict = checkpoint["g_ema"]
model_dict = g_ema.state_dict()
for k, v in pretrained_weights_dict.items():
if v.size() == model_dict[k].size():
model_dict[k] = v
g_ema.load_state_dict(model_dict)
# load a second volume renderer that extracts surfaces at 128x128x128 (or higher) for better surface resolution
if not opt.inference.no_surface_renderings:
opt['surf_extraction'] = Munch()
opt.surf_extraction.rendering = opt.rendering
opt.surf_extraction.model = opt.model.copy()
opt.surf_extraction.model.renderer_spatial_output_dim = 128
opt.surf_extraction.rendering.N_samples = opt.surf_extraction.model.renderer_spatial_output_dim
opt.surf_extraction.rendering.return_xyz = True
opt.surf_extraction.rendering.return_sdf = True
surface_g_ema = Generator(opt.surf_extraction.model, opt.surf_extraction.rendering, full_pipeline=False).to(device)
# Load weights to surface extractor
surface_extractor_dict = surface_g_ema.state_dict()
for k, v in pretrained_weights_dict.items():
if k in surface_extractor_dict.keys() and v.size() == surface_extractor_dict[k].size():
surface_extractor_dict[k] = v
surface_g_ema.load_state_dict(surface_extractor_dict)
else:
surface_g_ema = None
# get the mean latent vector for g_ema
if opt.inference.truncation_ratio < 1:
with torch.no_grad():
mean_latent = g_ema.mean_latent(opt.inference.truncation_mean, device)
else:
mean_latent = None
# get the mean latent vector for surface_g_ema
if not opt.inference.no_surface_renderings and mean_latent is not None:
surface_mean_latent = mean_latent[0]
else:
surface_mean_latent = None
generate(opt.inference, g_ema, surface_g_ema, device, mean_latent, surface_mean_latent)
================================================
FILE: losses.py
================================================
import math
import torch
from torch import autograd
from torch.nn import functional as F
def viewpoints_loss(viewpoint_pred, viewpoint_target):
loss = F.smooth_l1_loss(viewpoint_pred, viewpoint_target)
return loss
def eikonal_loss(eikonal_term, sdf=None, beta=100):
if eikonal_term is None:
eikonal_loss = 0
else:
eikonal_loss = ((eikonal_term.norm(dim=-1) - 1) ** 2).mean()
if sdf is None:
# eikonal_term may also be None here, so don't read its device unconditionally
device = eikonal_term.device if eikonal_term is not None else 'cpu'
minimal_surface_loss = torch.tensor(0.0, device=device)
else:
minimal_surface_loss = torch.exp(-beta * torch.abs(sdf)).mean()
return eikonal_loss, minimal_surface_loss
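eikonal_loss penalizes deviation of the SDF gradient norm from 1, and the accompanying minimal-surface term exponentially penalizes SDF values near zero away from the surface. A torch-free scalar sketch of both terms:

```python
import math

# A unit-norm gradient sample incurs zero eikonal penalty.
grad = [0.6, 0.8, 0.0]
eik = (math.sqrt(sum(g * g for g in grad)) - 1.0) ** 2
assert eik < 1e-12

# The minimal-surface term exp(-beta * |sdf|) is large near the zero level
# set and decays quickly as |sdf| grows, discouraging spurious surfaces.
beta = 100
near_surface = math.exp(-beta * abs(0.001))
far_from_surface = math.exp(-beta * abs(0.1))
assert near_surface > far_from_surface
```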
def d_logistic_loss(real_pred, fake_pred):
real_loss = F.softplus(-real_pred)
fake_loss = F.softplus(fake_pred)
return real_loss.mean() + fake_loss.mean()
def d_r1_loss(real_pred, real_img):
grad_real, = autograd.grad(outputs=real_pred.sum(),
inputs=real_img,
create_graph=True)
grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()
return grad_penalty
def g_nonsaturating_loss(fake_pred):
loss = F.softplus(-fake_pred).mean()
return loss
def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
noise = torch.randn_like(fake_img) / math.sqrt(
fake_img.shape[2] * fake_img.shape[3]
)
grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents,
create_graph=True, only_inputs=True)[0]
path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
path_penalty = (path_lengths - path_mean).pow(2).mean()
return path_penalty, path_mean.detach(), path_lengths
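The logistic discriminator loss and the non-saturating generator loss above both reduce to softplus of signed logits. A torch-free scalar sketch (hand-rolled numerically stable softplus):

```python
import math

def softplus(x):
    # numerically stable log(1 + e^x)
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

real_pred, fake_pred = 2.0, -1.0
d_loss = softplus(-real_pred) + softplus(fake_pred)  # d_logistic_loss on scalars
g_loss = softplus(-fake_pred)                        # g_nonsaturating_loss on a scalar
# D is rewarded for confident real/fake logits; G is penalized when D
# confidently rejects its samples.
assert 0.43 < d_loss < 0.45
assert 1.31 < g_loss < 1.32
```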
================================================
FILE: model.py
================================================
import math
import random
import trimesh
import torch
import numpy as np
from torch import nn
from torch.nn import functional as F
from volume_renderer import VolumeFeatureRenderer
from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
from pdb import set_trace as st
from utils import (
create_cameras,
create_mesh_renderer,
add_textures,
create_depth_mesh_renderer,
)
from pytorch3d.renderer import TexturesUV
from pytorch3d.structures import Meshes
from pytorch3d.renderer import look_at_view_transform, FoVPerspectiveCameras
from pytorch3d.transforms import matrix_to_euler_angles
class PixelNorm(nn.Module):
def __init__(self):
super().__init__()
def forward(self, input):
return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
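PixelNorm divides each feature vector by its RMS over the channel dimension, so its output always has (approximately) unit RMS. A torch-free sketch on a single two-channel vector:

```python
import math

# PixelNorm on one feature vector: x / sqrt(mean(x^2) + eps).
x = [3.0, 4.0]
rms = math.sqrt(sum(v * v for v in x) / len(x) + 1e-8)
y = [v / rms for v in x]

# The normalized vector has unit RMS regardless of the input's scale.
out_rms = math.sqrt(sum(v * v for v in y) / len(y))
assert abs(out_rms - 1.0) < 1e-4
```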
class MappingLinear(nn.Module):
def __init__(self, in_dim, out_dim, bias=True, activation=None, is_last=False):
super().__init__()
if is_last:
weight_std = 0.25
else:
weight_std = 1
self.weight = nn.Parameter(weight_std * nn.init.kaiming_normal_(torch.empty(out_dim, in_dim), a=0.2, mode='fan_in', nonlinearity='leaky_relu'))
if bias:
self.bias = nn.Parameter(nn.init.uniform_(torch.empty(out_dim), a=-np.sqrt(1/in_dim), b=np.sqrt(1/in_dim)))
else:
self.bias = None
self.activation = activation
def forward(self, input):
if self.activation is not None:
out = F.linear(input, self.weight)
out = fused_leaky_relu(out, self.bias, scale=1)
else:
out = F.linear(input, self.weight, bias=self.bias)
return out
def __repr__(self):
return (
f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
)
def make_kernel(k):
k = torch.tensor(k, dtype=torch.float32)
if k.ndim == 1:
k = k[None, :] * k[:, None]
k /= k.sum()
return k
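make_kernel turns the default 1-D blur kernel [1, 3, 3, 1] into a separable 2-D filter: the outer product of the kernel with itself, normalized to sum to 1. A torch-free sketch of that computation:

```python
# Outer product of the 1-D kernel with itself, then normalize to sum to 1,
# mirroring make_kernel([1, 3, 3, 1]).
k = [1, 3, 3, 1]
k2d = [[a * b for b in k] for a in k]
total = sum(sum(row) for row in k2d)  # (1 + 3 + 3 + 1) ** 2 = 64
k2d = [[v / total for v in row] for row in k2d]

assert abs(sum(sum(row) for row in k2d) - 1.0) < 1e-12
assert k2d[1][1] == 9 / 64  # the center taps carry the most weight
```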
class Upsample(nn.Module):
def __init__(self, kernel, factor=2):
super().__init__()
self.factor = factor
kernel = make_kernel(kernel) * (factor ** 2)
self.register_buffer("kernel", kernel)
p = kernel.shape[0] - factor
pad0 = (p + 1) // 2 + factor - 1
pad1 = p // 2
self.pad = (pad0, pad1)
def forward(self, input):
out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
return out
class Downsample(nn.Module):
def __init__(self, kernel, factor=2):
super().__init__()
self.factor = factor
kernel = make_kernel(kernel)
self.register_buffer("kernel", kernel)
p = kernel.shape[0] - factor
pad0 = (p + 1) // 2
pad1 = p // 2
self.pad = (pad0, pad1)
def forward(self, input):
out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
return out
class Blur(nn.Module):
def __init__(self, kernel, pad, upsample_factor=1):
super().__init__()
kernel = make_kernel(kernel)
if upsample_factor > 1:
kernel = kernel * (upsample_factor ** 2)
self.register_buffer("kernel", kernel)
self.pad = pad
def forward(self, input):
out = upfirdn2d(input, self.kernel, pad=self.pad)
return out
class EqualConv2d(nn.Module):
def __init__(
self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
):
super().__init__()
self.weight = nn.Parameter(
torch.randn(out_channel, in_channel, kernel_size, kernel_size)
)
self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
self.stride = stride
self.padding = padding
if bias:
self.bias = nn.Parameter(torch.zeros(out_channel))
else:
self.bias = None
def forward(self, input):
out = F.conv2d(
input,
self.weight * self.scale,
bias=self.bias,
stride=self.stride,
padding=self.padding,
)
return out
def __repr__(self):
return (
f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
)
class EqualLinear(nn.Module):
def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1,
activation=None):
super().__init__()
self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
if bias:
self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
else:
self.bias = None
self.activation = activation
self.scale = (1 / math.sqrt(in_dim)) * lr_mul
self.lr_mul = lr_mul
def forward(self, input):
if self.activation:
out = F.linear(input, self.weight * self.scale)
out = fused_leaky_relu(out, self.bias * self.lr_mul)
else:
out = F.linear(input, self.weight * self.scale,
bias=self.bias * self.lr_mul)
return out
def __repr__(self):
return (
f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
)
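EqualConv2d and EqualLinear implement StyleGAN's equalized learning rate: weights are stored at unit variance (divided by lr_mul) and rescaled by (1/sqrt(fan_in)) * lr_mul on every forward pass, so He-style scaling is applied at runtime rather than at initialization. A scalar sketch of the scale computation:

```python
import math

# Runtime scale as computed in EqualLinear.__init__ for illustrative sizes.
in_dim, lr_mul = 512, 0.01
scale = (1 / math.sqrt(in_dim)) * lr_mul

# Undoing the fan-in factor recovers the learning-rate multiplier.
assert abs(scale * math.sqrt(in_dim) - lr_mul) < 1e-12
assert scale < lr_mul  # fan-in > 1 always shrinks the effective weight scale
```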
class ModulatedConv2d(nn.Module):
def __init__(self, in_channel, out_channel, kernel_size, style_dim, demodulate=True,
upsample=False, downsample=False, blur_kernel=[1, 3, 3, 1]):
super().__init__()
self.eps = 1e-8
self.kernel_size = kernel_size
self.in_channel = in_channel
self.out_channel = out_channel
self.upsample = upsample
self.downsample = downsample
if upsample:
factor = 2
p = (len(blur_kernel) - factor) - (kernel_size - 1)
pad0 = (p + 1) // 2 + factor - 1
pad1 = p // 2 + 1
self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
if downsample:
factor = 2
p = (len(blur_kernel) - factor) + (kernel_size - 1)
pad0 = (p + 1) // 2
pad1 = p // 2
self.blur = Blur(blur_kernel, pad=(pad0, pad1))
fan_in = in_channel * kernel_size ** 2
self.scale = 1 / math.sqrt(fan_in)
self.padding = kernel_size // 2
self.weight = nn.Parameter(
torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
)
self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
self.demodulate = demodulate
def __repr__(self):
return (
f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
f"upsample={self.upsample}, downsample={self.downsample})"
)
def forward(self, input, style):
batch, in_channel, height, width = input.shape
style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
weight = self.scale * self.weight * style
if self.demodulate:
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
weight = weight.view(
batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
)
if self.upsample:
input = input.view(1, batch * in_channel, height, width)
weight = weight.view(
batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
)
weight = weight.transpose(1, 2).reshape(
batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
)
out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
_, _, height, width = out.shape
out = out.view(batch, self.out_channel, height, width)
out = self.blur(out)
elif self.downsample:
input = self.blur(input)
_, _, height, width = input.shape
input = input.view(1, batch * in_channel, height, width)
out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
_, _, height, width = out.shape
out = out.view(batch, self.out_channel, height, width)
else:
input = input.view(1, batch * in_channel, height, width)
out = F.conv2d(input, weight, padding=self.padding, groups=batch)
_, _, height, width = out.shape
out = out.view(batch, self.out_channel, height, width)
return out
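The demodulation step in ModulatedConv2d renormalizes each style-modulated output channel back to unit L2 norm over its (in_channel, kh, kw) weights. A torch-free sketch for one flattened channel:

```python
import math

# Modulated weights for one output channel, flattened over (in, kh, kw);
# illustrative values, not trained weights.
w = [0.5, -1.0, 2.0]
demod = 1.0 / math.sqrt(sum(x * x for x in w) + 1e-8)  # as in forward()
w_demod = [x * demod for x in w]

# After demodulation the channel has (approximately) unit L2 norm,
# canceling the scale the style injected.
norm = math.sqrt(sum(x * x for x in w_demod))
assert abs(norm - 1.0) < 1e-4
```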
class NoiseInjection(nn.Module):
def __init__(self, project=False):
super().__init__()
self.project = project
self.weight = nn.Parameter(torch.zeros(1))
self.prev_noise = None
self.mesh_fn = None
self.vert_noise = None
def create_pytorch_mesh(self, trimesh):
v=trimesh.vertices; f=trimesh.faces
verts = torch.from_numpy(np.asarray(v)).to(torch.float32).cuda()
mesh_pytorch = Meshes(
verts=[verts],
faces = [torch.from_numpy(np.asarray(f)).to(torch.float32).cuda()],
textures=None
)
if self.vert_noise is None or verts.shape[0] != self.vert_noise.shape[1]:
self.vert_noise = torch.ones_like(verts)[:,0:1].cpu().normal_().expand(-1,3).unsqueeze(0)
mesh_pytorch = add_textures(meshes=mesh_pytorch, vertex_colors=self.vert_noise.to(verts.device))
return mesh_pytorch
def load_mc_mesh(self, filename, resolution=128, im_res=64):
import trimesh
mc_tri=trimesh.load_mesh(filename)
v=mc_tri.vertices; f=mc_tri.faces
mesh2=trimesh.base.Trimesh(vertices=v, faces=f)
if im_res==64 or im_res==128:
pytorch3d_mesh = self.create_pytorch_mesh(mesh2)
return pytorch3d_mesh
v,f = trimesh.remesh.subdivide(v,f)
mesh2_subdiv = trimesh.base.Trimesh(vertices=v, faces=f)
if im_res==256:
pytorch3d_mesh = self.create_pytorch_mesh(mesh2_subdiv)
return pytorch3d_mesh
v,f = trimesh.remesh.subdivide(mesh2_subdiv.vertices,mesh2_subdiv.faces)
mesh3_subdiv = trimesh.base.Trimesh(vertices=v, faces=f)
if im_res==512:
pytorch3d_mesh = self.create_pytorch_mesh(mesh3_subdiv)
return pytorch3d_mesh
v,f = trimesh.remesh.subdivide(mesh3_subdiv.vertices,mesh3_subdiv.faces)
mesh4_subdiv = trimesh.base.Trimesh(vertices=v, faces=f)
pytorch3d_mesh = self.create_pytorch_mesh(mesh4_subdiv)
return pytorch3d_mesh
def project_noise(self, noise, transform, mesh_path=None):
batch, _, height, width = noise.shape
assert(batch == 1) # assuming during inference batch size is 1
angles = matrix_to_euler_angles(transform[0:1,:,:3], "ZYX")
azim = float(angles[0][1])
elev = float(-angles[0][2])
cameras = create_cameras(azim=azim*180/np.pi,elev=elev*180/np.pi,fov=12.,dist=1)
renderer = create_depth_mesh_renderer(cameras, image_size=height,
specular_color=((0,0,0),), ambient_color=((1.,1.,1.),),diffuse_color=((0,0,0),))
if self.mesh_fn is None or self.mesh_fn != mesh_path:
self.mesh_fn = mesh_path
pytorch3d_mesh = self.load_mc_mesh(mesh_path, im_res=height)
rgb, depth = renderer(pytorch3d_mesh)
depth_max = depth.max(-1)[0].view(-1) # (NxN)
depth_valid = depth_max > 0.
if self.prev_noise is None:
self.prev_noise = noise
noise_copy = self.prev_noise.clone()
noise_copy.view(-1)[depth_valid] = rgb[0,:,:,0].view(-1)[depth_valid]
noise_copy = noise_copy.reshape(1,1,height,height) # 1x1xNxN
return noise_copy
def forward(self, image, noise=None, transform=None, mesh_path=None):
batch, _, height, width = image.shape
if noise is None:
noise = image.new_empty(batch, 1, height, width).normal_()
elif self.project:
noise = self.project_noise(noise, transform, mesh_path=mesh_path)
return image + self.weight * noise
class StyledConv(nn.Module):
def __init__(self, in_channel, out_channel, kernel_size, style_dim,
upsample=False, blur_kernel=[1, 3, 3, 1], project_noise=False):
super().__init__()
self.conv = ModulatedConv2d(
in_channel,
out_channel,
kernel_size,
style_dim,
upsample=upsample,
blur_kernel=blur_kernel,
)
self.noise = NoiseInjection(project=project_noise)
self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
self.activate = FusedLeakyReLU(out_channel)
def forward(self, input, style, noise=None, transform=None, mesh_path=None):
out = self.conv(input, style)
out = self.noise(out, noise=noise, transform=transform, mesh_path=mesh_path)
out = self.activate(out)
return out
class ToRGB(nn.Module):
def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
super().__init__()
self.upsample = upsample
out_channels = 3
if upsample:
self.upsample = Upsample(blur_kernel)
self.conv = ModulatedConv2d(in_channel, out_channels, 1, style_dim, demodulate=False)
self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
def forward(self, input, style, skip=None):
out = self.conv(input, style)
out = out + self.bias
if skip is not None:
if self.upsample:
skip = self.upsample(skip)
out = out + skip
return out
class ConvLayer(nn.Sequential):
def __init__(self, in_channel, out_channel, kernel_size, downsample=False,
blur_kernel=[1, 3, 3, 1], bias=True, activate=True):
layers = []
if downsample:
factor = 2
p = (len(blur_kernel) - factor) + (kernel_size - 1)
pad0 = (p + 1) // 2
pad1 = p // 2
layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
stride = 2
self.padding = 0
else:
stride = 1
self.padding = kernel_size // 2
layers.append(
EqualConv2d(
in_channel,
out_channel,
kernel_size,
padding=self.padding,
stride=stride,
bias=bias and not activate,
)
)
if activate:
layers.append(FusedLeakyReLU(out_channel, bias=bias))
super().__init__(*layers)
class Decoder(nn.Module):
def __init__(self, model_opt, blur_kernel=[1, 3, 3, 1]):
super().__init__()
# decoder mapping network
self.size = model_opt.size
self.style_dim = model_opt.style_dim * 2
thumb_im_size = model_opt.renderer_spatial_output_dim
layers = [PixelNorm(),
EqualLinear(
self.style_dim // 2, self.style_dim, lr_mul=model_opt.lr_mapping, activation="fused_lrelu"
)]
for i in range(4):
layers.append(
EqualLinear(
self.style_dim, self.style_dim, lr_mul=model_opt.lr_mapping, activation="fused_lrelu"
)
)
self.style = nn.Sequential(*layers)
# decoder network
self.channels = {
4: 512,
8: 512,
16: 512,
32: 512,
64: 256 * model_opt.channel_multiplier,
128: 128 * model_opt.channel_multiplier,
256: 64 * model_opt.channel_multiplier,
512: 32 * model_opt.channel_multiplier,
1024: 16 * model_opt.channel_multiplier,
}
decoder_in_size = model_opt.renderer_spatial_output_dim
# image decoder
self.log_size = int(math.log(self.size, 2))
self.log_in_size = int(math.log(decoder_in_size, 2))
self.conv1 = StyledConv(
model_opt.feature_encoder_in_channels,
self.channels[decoder_in_size], 3, self.style_dim, blur_kernel=blur_kernel,
project_noise=model_opt.project_noise)
self.to_rgb1 = ToRGB(self.channels[decoder_in_size], self.style_dim, upsample=False)
self.num_layers = (self.log_size - self.log_in_size) * 2 + 1
self.convs = nn.ModuleList()
self.upsamples = nn.ModuleList()
self.to_rgbs = nn.ModuleList()
self.noises = nn.Module()
in_channel = self.channels[decoder_in_size]
for layer_idx in range(self.num_layers):
res = (layer_idx + 2 * self.log_in_size + 1) // 2
shape = [1, 1, 2 ** (res), 2 ** (res)]
self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape))
for i in range(self.log_in_size+1, self.log_size + 1):
out_channel = self.channels[2 ** i]
self.convs.append(
StyledConv(in_channel, out_channel, 3, self.style_dim, upsample=True,
blur_kernel=blur_kernel, project_noise=model_opt.project_noise)
)
self.convs.append(
StyledConv(out_channel, out_channel, 3, self.style_dim,
blur_kernel=blur_kernel, project_noise=model_opt.project_noise)
)
self.to_rgbs.append(ToRGB(out_channel, self.style_dim))
in_channel = out_channel
self.n_latent = (self.log_size - self.log_in_size) * 2 + 2
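The Decoder sizes itself from the gap between the renderer's output resolution and the final image size. Worked out for the FFHQ configuration (64x64 renderer features upsampled to 1024x1024):

```python
import math

# Same formulas as Decoder.__init__, for the FFHQ 64 -> 1024 configuration.
size, decoder_in_size = 1024, 64
log_size = int(math.log(size, 2))            # 10
log_in_size = int(math.log(decoder_in_size, 2))  # 6
num_layers = (log_size - log_in_size) * 2 + 1    # conv1 + 2 convs per upsample stage
n_latent = (log_size - log_in_size) * 2 + 2      # one style per conv + to_rgb inputs

assert (num_layers, n_latent) == (9, 10)
```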
def mean_latent(self, renderer_latent):
latent = self.style(renderer_latent).mean(0, keepdim=True)
return latent
def get_latent(self, input):
return self.style(input)
def styles_and_noise_forward(self, styles, noise, inject_index=None, truncation=1,
truncation_latent=None, input_is_latent=False,
randomize_noise=True):
if not input_is_latent:
styles = [self.style(s) for s in styles]
if noise is None:
if randomize_noise:
noise = [None] * self.num_layers
else:
noise = [
getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
]
if (truncation < 1):
style_t = []
for style in styles:
style_t.append(
truncation_latent[1] + truncation * (style - truncation_latent[1])
)
styles = style_t
if len(styles) < 2:
inject_index = self.n_latent
if styles[0].ndim < 3:
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
else:
latent = styles[0]
else:
if inject_index is None:
inject_index = random.randint(1, self.n_latent - 1)
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
latent = torch.cat([latent, latent2], 1)
return latent, noise
def forward(self, features, styles, rgbd_in=None, transform=None,
return_latents=False, inject_index=None, truncation=1,
truncation_latent=None, input_is_latent=False, noise=None,
randomize_noise=True, mesh_path=None):
latent, noise = self.styles_and_noise_forward(styles, noise, inject_index, truncation,
truncation_latent, input_is_latent,
randomize_noise)
out = self.conv1(features, latent[:, 0], noise=noise[0],
transform=transform, mesh_path=mesh_path)
skip = self.to_rgb1(out, latent[:, 1], skip=rgbd_in)
i = 1
for conv1, conv2, noise1, noise2, to_rgb in zip(
self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
):
out = conv1(out, latent[:, i], noise=noise1,
transform=transform, mesh_path=mesh_path)
out = conv2(out, latent[:, i + 1], noise=noise2,
transform=transform, mesh_path=mesh_path)
skip = to_rgb(out, latent[:, i + 2], skip=skip)
i += 2
out_latent = latent if return_latents else None
image = skip
return image, out_latent
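The latent broadcasting and style-mixing logic in `styles_and_noise_forward` above can be sketched framework-free. A minimal NumPy stand-in (`mix_styles` is a hypothetical name; the real code operates on torch tensors): each mapped latent is repeated across its share of the layers, and the two halves are concatenated at the injection index.

```python
import numpy as np

def mix_styles(w1, w2, n_latent, inject_index):
    # Repeat each mapped latent across its share of the per-layer slots,
    # then concatenate at inject_index (mirrors styles_and_noise_forward).
    latent1 = np.repeat(w1[:, None, :], inject_index, axis=1)
    latent2 = np.repeat(w2[:, None, :], n_latent - inject_index, axis=1)
    return np.concatenate([latent1, latent2], axis=1)

w1 = np.zeros((2, 8))  # first style, batch of 2, style_dim 8
w2 = np.ones((2, 8))   # second style
latent = mix_styles(w1, w2, n_latent=10, inject_index=4)
print(latent.shape)  # (2, 10, 8): layers 0-3 from w1, layers 4-9 from w2
```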
class Generator(nn.Module):
def __init__(self, model_opt, renderer_opt, blur_kernel=[1, 3, 3, 1], ema=False, full_pipeline=True):
super().__init__()
self.size = model_opt.size
self.style_dim = model_opt.style_dim
self.num_layers = 1
self.train_renderer = not model_opt.freeze_renderer
self.full_pipeline = full_pipeline
model_opt.feature_encoder_in_channels = renderer_opt.width
if ema or 'is_test' in model_opt.keys():
self.is_train = False
else:
self.is_train = True
# volume renderer mapping_network
layers = []
for i in range(3):
layers.append(
MappingLinear(self.style_dim, self.style_dim, activation="fused_lrelu")
)
self.style = nn.Sequential(*layers)
# volume renderer
thumb_im_size = model_opt.renderer_spatial_output_dim
self.renderer = VolumeFeatureRenderer(renderer_opt, style_dim=self.style_dim,
out_im_res=thumb_im_size)
if self.full_pipeline:
self.decoder = Decoder(model_opt)
def make_noise(self):
device = self.input.input.device
noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
for i in range(3, self.log_size + 1):
for _ in range(2):
noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
return noises
def mean_latent(self, n_latent, device):
latent_in = torch.randn(n_latent, self.style_dim, device=device)
renderer_latent = self.style(latent_in)
renderer_latent_mean = renderer_latent.mean(0, keepdim=True)
if self.full_pipeline:
decoder_latent_mean = self.decoder.mean_latent(renderer_latent)
else:
decoder_latent_mean = None
return [renderer_latent_mean, decoder_latent_mean]
def get_latent(self, input):
return self.style(input)
def styles_and_noise_forward(self, styles, inject_index=None, truncation=1,
truncation_latent=None, input_is_latent=False):
if not input_is_latent:
styles = [self.style(s) for s in styles]
if truncation < 1:
style_t = []
for style in styles:
style_t.append(
truncation_latent[0] + truncation * (style - truncation_latent[0])
)
styles = style_t
return styles
def init_forward(self, styles, cam_poses, focals, near=0.88, far=1.12):
latent = self.styles_and_noise_forward(styles)
sdf, target_values = self.renderer.mlp_init_pass(cam_poses, focals, near, far, styles=latent[0])
return sdf, target_values
def forward(self, styles, cam_poses, focals, near=0.88, far=1.12, return_latents=False,
inject_index=None, truncation=1, truncation_latent=None,
input_is_latent=False, noise=None, randomize_noise=True,
return_sdf=False, return_xyz=False, return_eikonal=False,
project_noise=False, mesh_path=None):
# do not calculate renderer gradients if renderer weights are frozen
with torch.set_grad_enabled(self.is_train and self.train_renderer):
latent = self.styles_and_noise_forward(styles, inject_index, truncation,
truncation_latent, input_is_latent)
thumb_rgb, features, sdf, mask, xyz, eikonal_term = self.renderer(cam_poses, focals, near, far, styles=latent[0], return_eikonal=return_eikonal)
if self.full_pipeline:
rgb, decoder_latent = self.decoder(features, latent,
transform=cam_poses if project_noise else None,
return_latents=return_latents,
inject_index=inject_index, truncation=truncation,
truncation_latent=truncation_latent, noise=noise,
input_is_latent=input_is_latent, randomize_noise=randomize_noise,
mesh_path=mesh_path)
else:
rgb, decoder_latent = None, None
if return_latents:
return rgb, decoder_latent
else:
out = (rgb, thumb_rgb)
if return_xyz:
out += (xyz,)
if return_sdf:
out += (sdf,)
if return_eikonal:
out += (eikonal_term,)
if return_xyz:
out += (mask,)
return out
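The truncation step in `styles_and_noise_forward` interpolates each mapped latent toward a precomputed mean latent (`mean_latent` above). A minimal sketch of that interpolation, assuming NumPy in place of torch and a hypothetical `truncate` helper:

```python
import numpy as np

def truncate(w, w_mean, psi):
    # Truncation trick: pull mapped latents toward the mean latent.
    # psi=1 disables truncation; psi=0 collapses every sample to the mean.
    return w_mean + psi * (w - w_mean)

w = np.array([2.0])       # a mapped latent
w_mean = np.array([0.0])  # mean latent from many random samples
out = truncate(w, w_mean, 0.5)
print(out)  # [1.]
```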
############# Volume Renderer Building Blocks & Discriminator ##################
class VolumeRenderDiscConv2d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, bias=True, activate=False):
super(VolumeRenderDiscConv2d, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels,
kernel_size, stride, padding, bias=bias and not activate)
self.activate = activate
if self.activate:
self.activation = FusedLeakyReLU(out_channels, bias=bias, scale=1)
bias_init_coef = np.sqrt(1 / (in_channels * kernel_size * kernel_size))
nn.init.uniform_(self.activation.bias, a=-bias_init_coef, b=bias_init_coef)
def forward(self, input):
"""
input_tensor_shape: (N, C_in,H,W)
output_tensor_shape: (N,C_out,H_out,W_out)
:return: Conv2d + activation Result
"""
out = self.conv(input)
if self.activate:
out = self.activation(out)
return out
class AddCoords(nn.Module):
def __init__(self):
super(AddCoords, self).__init__()
def forward(self, input_tensor):
"""
:param input_tensor: shape (N, C_in, H, W)
:return:
"""
batch_size_shape, channel_in_shape, dim_y, dim_x = input_tensor.shape
xx_channel = torch.arange(dim_x, dtype=torch.float32, device=input_tensor.device).repeat(1,1,dim_y,1)
yy_channel = torch.arange(dim_y, dtype=torch.float32, device=input_tensor.device).repeat(1,1,dim_x,1).transpose(2,3)
xx_channel = xx_channel / (dim_x - 1)
yy_channel = yy_channel / (dim_y - 1)
xx_channel = xx_channel * 2 - 1
yy_channel = yy_channel * 2 - 1
xx_channel = xx_channel.repeat(batch_size_shape, 1, 1, 1)
yy_channel = yy_channel.repeat(batch_size_shape, 1, 1, 1)
out = torch.cat([input_tensor, yy_channel, xx_channel], dim=1)
return out
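`AddCoords` appends two channels holding the normalized x and y pixel coordinates in [-1, 1]. A NumPy sketch of the same construction (hypothetical `add_coords` helper; `arange / (dim - 1) * 2 - 1` in the module above is equivalent to `linspace(-1, 1, dim)`):

```python
import numpy as np

def add_coords(x):
    # Append normalized (-1..1) y and x coordinate channels, as in AddCoords.
    n, c, h, w = x.shape
    xx = np.tile(np.linspace(-1.0, 1.0, w), (n, 1, h, 1))           # (n,1,h,w)
    yy = np.tile(np.linspace(-1.0, 1.0, h)[:, None], (n, 1, 1, w))  # (n,1,h,w)
    return np.concatenate([x, yy, xx], axis=1)

x = np.zeros((2, 3, 4, 5))
out = add_coords(x)
print(out.shape)  # (2, 5, 4, 5): channel 3 is yy, channel 4 is xx
```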
class CoordConv2d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, bias=True):
super(CoordConv2d, self).__init__()
self.addcoords = AddCoords()
self.conv = nn.Conv2d(in_channels + 2, out_channels,
kernel_size, stride=stride, padding=padding, bias=bias)
def forward(self, input_tensor):
"""
input_tensor_shape: (N, C_in,H,W)
output_tensor_shape: N,C_out,H_out,W_out)
:return: CoordConv2d Result
"""
out = self.addcoords(input_tensor)
out = self.conv(out)
return out
class CoordConvLayer(nn.Module):
def __init__(self, in_channel, out_channel, kernel_size, bias=True, activate=True):
super(CoordConvLayer, self).__init__()
layers = []
stride = 1
self.activate = activate
self.padding = kernel_size // 2 if kernel_size > 2 else 0
self.conv = CoordConv2d(in_channel, out_channel, kernel_size,
padding=self.padding, stride=stride,
bias=bias and not activate)
if activate:
self.activation = FusedLeakyReLU(out_channel, bias=bias, scale=1)
bias_init_coef = np.sqrt(1 / (in_channel * kernel_size * kernel_size))
nn.init.uniform_(self.activation.bias, a=-bias_init_coef, b=bias_init_coef)
def forward(self, input):
out = self.conv(input)
if self.activate:
out = self.activation(out)
return out
class VolumeRenderResBlock(nn.Module):
def __init__(self, in_channel, out_channel):
super().__init__()
self.conv1 = CoordConvLayer(in_channel, out_channel, 3)
self.conv2 = CoordConvLayer(out_channel, out_channel, 3)
self.pooling = nn.AvgPool2d(2)
self.downsample = nn.AvgPool2d(2)
if out_channel != in_channel:
self.skip = VolumeRenderDiscConv2d(in_channel, out_channel, 1)
else:
self.skip = None
def forward(self, input):
out = self.conv1(input)
out = self.conv2(out)
out = self.pooling(out)
downsample_in = self.downsample(input)
if self.skip is not None:
skip_in = self.skip(downsample_in)
else:
skip_in = downsample_in
out = (out + skip_in) / math.sqrt(2)
return out
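Dividing the merged branches by sqrt(2) in `VolumeRenderResBlock` keeps the activation variance roughly constant: if the two branches are approximately independent with unit variance, their sum has variance 2. A quick numerical check with NumPy:

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal(100_000)  # stand-in for the conv branch
b = rng.standard_normal(100_000)  # stand-in for the skip branch
out = (a + b) / np.sqrt(2.0)      # residual merge as in VolumeRenderResBlock
print(round(out.var(), 1))        # ~1.0, variance preserved
```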
class VolumeRenderDiscriminator(nn.Module):
def __init__(self, opt):
super().__init__()
init_size = opt.renderer_spatial_output_dim
self.viewpoint_loss = not opt.no_viewpoint_loss
final_out_channel = 3 if self.viewpoint_loss else 1
channels = {
2: 400,
4: 400,
8: 400,
16: 400,
32: 256,
64: 128,
128: 64,
}
convs = [VolumeRenderDiscConv2d(3, channels[init_size], 1, activate=True)]
log_size = int(math.log(init_size, 2))
in_channel = channels[init_size]
for i in range(log_size-1, 0, -1):
out_channel = channels[2 ** i]
convs.append(VolumeRenderResBlock(in_channel, out_channel))
in_channel = out_channel
self.convs = nn.Sequential(*convs)
self.final_conv = VolumeRenderDiscConv2d(in_channel, final_out_channel, 2)
def forward(self, input):
out = self.convs(input)
out = self.final_conv(out)
gan_preds = out[:,0:1]
gan_preds = gan_preds.view(-1, 1)
if self.viewpoint_loss:
viewpoints_preds = out[:,1:]
viewpoints_preds = viewpoints_preds.view(-1,2)
else:
viewpoints_preds = None
return gan_preds, viewpoints_preds
######################### StyleGAN Discriminator ########################
class ResBlock(nn.Module):
def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1], merge=False):
super().__init__()
self.conv1 = ConvLayer(2 * in_channel if merge else in_channel, in_channel, 3)
self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
self.skip = ConvLayer(2 * in_channel if merge else in_channel, out_channel,
1, downsample=True, activate=False, bias=False)
def forward(self, input):
out = self.conv1(input)
out = self.conv2(out)
out = (out + self.skip(input)) / math.sqrt(2)
return out
class Discriminator(nn.Module):
def __init__(self, opt, blur_kernel=[1, 3, 3, 1]):
super().__init__()
init_size = opt.size
channels = {
4: 512,
8: 512,
16: 512,
32: 512,
64: 256 * opt.channel_multiplier,
128: 128 * opt.channel_multiplier,
256: 64 * opt.channel_multiplier,
512: 32 * opt.channel_multiplier,
1024: 16 * opt.channel_multiplier,
}
convs = [ConvLayer(3, channels[init_size], 1)]
log_size = int(math.log(init_size, 2))
in_channel = channels[init_size]
for i in range(log_size, 2, -1):
out_channel = channels[2 ** (i - 1)]
convs.append(ResBlock(in_channel, out_channel, blur_kernel))
in_channel = out_channel
self.convs = nn.Sequential(*convs)
self.stddev_group = 4
self.stddev_feat = 1
# minibatch discrimination
in_channel += 1
self.final_conv = ConvLayer(in_channel, channels[4], 3)
self.final_linear = nn.Sequential(
EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
EqualLinear(channels[4], 1),
)
def forward(self, input):
out = self.convs(input)
# minibatch discrimination
batch, channel, height, width = out.shape
group = min(batch, self.stddev_group)
if batch % group != 0:
group = 3 if batch % 3 == 0 else 2
stddev = out.view(
group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
)
stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
stddev = stddev.repeat(group, 1, height, width)
final_out = torch.cat([out, stddev], 1)
# final layers
final_out = self.final_conv(final_out)
final_out = final_out.view(batch, -1)
final_out = self.final_linear(final_out)
gan_preds = final_out[:,:1]
return gan_preds
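The minibatch-discrimination block in `Discriminator.forward` computes a per-group standard deviation of the features, averages it into a single scalar per group, and appends it as one extra constant channel. A NumPy sketch of the statistic (hypothetical `minibatch_stddev` helper, mirroring the reshaping above with `stddev_feat=1`):

```python
import numpy as np

def minibatch_stddev(x, group_size=4):
    # Append one channel holding the per-group feature std, averaged
    # over channels and pixels (mirrors Discriminator.forward).
    n, c, h, w = x.shape
    g = min(n, group_size)
    y = x.reshape(g, n // g, c, h, w)
    std = np.sqrt(y.var(axis=0) + 1e-8)            # std over the group
    std = std.mean(axis=(1, 2, 3), keepdims=True)  # one scalar per group
    std = np.tile(std, (g, 1, h, w))               # broadcast back to (n,1,h,w)
    return np.concatenate([x, std], axis=1)

x = np.random.default_rng(1).standard_normal((4, 8, 4, 4))
out = minibatch_stddev(x)
print(out.shape)  # (4, 9, 4, 4): one extra, spatially constant channel
```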
================================================
FILE: op/__init__.py
================================================
from .fused_act import FusedLeakyReLU, fused_leaky_relu
from .upfirdn2d import upfirdn2d
================================================
FILE: op/fused_act.py
================================================
import os
import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function
from torch.utils.cpp_extension import load
module_path = os.path.dirname(__file__)
fused = load(
"fused",
sources=[
os.path.join(module_path, "fused_bias_act.cpp"),
os.path.join(module_path, "fused_bias_act_kernel.cu"),
],
)
class FusedLeakyReLUFunctionBackward(Function):
@staticmethod
def forward(ctx, grad_output, out, bias, negative_slope, scale):
ctx.save_for_backward(out)
ctx.negative_slope = negative_slope
ctx.scale = scale
empty = grad_output.new_empty(0)
grad_input = fused.fused_bias_act(
grad_output, empty, out, 3, 1, negative_slope, scale
)
dim = [0]
if grad_input.ndim > 2:
dim += list(range(2, grad_input.ndim))
if bias:
grad_bias = grad_input.sum(dim).detach()
else:
grad_bias = empty
return grad_input, grad_bias
@staticmethod
def backward(ctx, gradgrad_input, gradgrad_bias):
out, = ctx.saved_tensors
gradgrad_out = fused.fused_bias_act(
gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
)
return gradgrad_out, None, None, None, None
class FusedLeakyReLUFunction(Function):
@staticmethod
def forward(ctx, input, bias, negative_slope, scale):
empty = input.new_empty(0)
ctx.bias = bias is not None
if bias is None:
bias = empty
out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
ctx.save_for_backward(out)
ctx.negative_slope = negative_slope
ctx.scale = scale
return out
@staticmethod
def backward(ctx, grad_output):
out, = ctx.saved_tensors
grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale
)
if not ctx.bias:
grad_bias = None
return grad_input, grad_bias, None, None
class FusedLeakyReLU(nn.Module):
def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
super().__init__()
if bias:
self.bias = nn.Parameter(torch.zeros(channel))
else:
self.bias = None
self.negative_slope = negative_slope
self.scale = scale
def forward(self, input):
return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
if input.device.type == "cpu":
if bias is not None:
rest_dim = [1] * (input.ndim - bias.ndim - 1)
return (
F.leaky_relu(
input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
)
* scale
)
else:
return F.leaky_relu(input, negative_slope=negative_slope) * scale
else:
return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
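The CPU fallback above amounts to: add the per-channel bias, apply a leaky ReLU, then rescale by sqrt(2) to compensate for the variance lost in the negative half. A NumPy reference of that computation (hypothetical `fused_leaky_relu_ref`; the real op works on torch tensors and the CUDA path fuses all three steps):

```python
import numpy as np

def fused_leaky_relu_ref(x, bias, negative_slope=0.2, scale=2 ** 0.5):
    # Bias is broadcast over all dims except the channel dim (dim 1),
    # matching the bias.view(1, C, 1, ...) reshape in the CPU path.
    y = x + bias.reshape(1, -1, *([1] * (x.ndim - 2)))
    return np.where(y > 0, y, y * negative_slope) * scale

x = np.array([[[-1.0, 2.0]]])  # shape (1, 1, 2)
b = np.zeros(1)
out = fused_leaky_relu_ref(x, b)
# negative input scaled by 0.2, then everything scaled by sqrt(2)
```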
================================================
FILE: op/fused_bias_act.cpp
================================================
#include <torch/extension.h>
torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
int act, int grad, float alpha, float scale);
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
int act, int grad, float alpha, float scale) {
CHECK_CUDA(input);
CHECK_CUDA(bias);
return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
}
================================================
FILE: op/fused_bias_act_kernel.cu
================================================
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html
#include <torch/types.h>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <cuda.h>
#include <cuda_runtime.h>
template <typename scalar_t>
static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
scalar_t zero = 0.0;
for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
scalar_t x = p_x[xi];
if (use_bias) {
x += p_b[(xi / step_b) % size_b];
}
scalar_t ref = use_ref ? p_ref[xi] : zero;
scalar_t y;
switch (act * 10 + grad) {
default:
case 10: y = x; break;
case 11: y = x; break;
case 12: y = 0.0; break;
case 30: y = (x > 0.0) ? x : x * alpha; break;
case 31: y = (ref > 0.0) ? x : x * alpha; break;
case 32: y = 0.0; break;
}
out[xi] = y * scale;
}
}
torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
int act, int grad, float alpha, float scale) {
int curDevice = -1;
cudaGetDevice(&curDevice);
cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
auto x = input.contiguous();
auto b = bias.contiguous();
auto ref = refer.contiguous();
int use_bias = b.numel() ? 1 : 0;
int use_ref = ref.numel() ? 1 : 0;
int size_x = x.numel();
int size_b = b.numel();
int step_b = 1;
for (int i = 2; i < x.dim(); i++) {
step_b *= x.size(i);
}
int loop_x = 4;
int block_size = 4 * 32;
int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
auto y = torch::empty_like(x);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
y.data_ptr<scalar_t>(),
x.data_ptr<scalar_t>(),
b.data_ptr<scalar_t>(),
ref.data_ptr<scalar_t>(),
act,
grad,
alpha,
scale,
loop_x,
size_x,
step_b,
size_b,
use_bias,
use_ref
);
});
return y;
}
================================================
FILE: op/upfirdn2d.cpp
================================================
#include <torch/extension.h>
torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
int up_x, int up_y, int down_x, int down_y,
int pad_x0, int pad_x1, int pad_y0, int pad_y1);
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
int up_x, int up_y, int down_x, int down_y,
int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
CHECK_CUDA(input);
CHECK_CUDA(kernel);
return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
}
================================================
FILE: op/upfirdn2d.py
================================================
import os
import torch
from torch.nn import functional as F
from torch.autograd import Function
from torch.utils.cpp_extension import load
from pdb import set_trace as st
module_path = os.path.dirname(__file__)
upfirdn2d_op = load(
"upfirdn2d",
sources=[
os.path.join(module_path, "upfirdn2d.cpp"),
os.path.join(module_path, "upfirdn2d_kernel.cu"),
],
)
class UpFirDn2dBackward(Function):
@staticmethod
def forward(
ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
):
up_x, up_y = up
down_x, down_y = down
g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
grad_input = upfirdn2d_op.upfirdn2d(
grad_output,
grad_kernel,
down_x,
down_y,
up_x,
up_y,
g_pad_x0,
g_pad_x1,
g_pad_y0,
g_pad_y1,
)
grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
ctx.save_for_backward(kernel)
pad_x0, pad_x1, pad_y0, pad_y1 = pad
ctx.up_x = up_x
ctx.up_y = up_y
ctx.down_x = down_x
ctx.down_y = down_y
ctx.pad_x0 = pad_x0
ctx.pad_x1 = pad_x1
ctx.pad_y0 = pad_y0
ctx.pad_y1 = pad_y1
ctx.in_size = in_size
ctx.out_size = out_size
return grad_input
@staticmethod
def backward(ctx, gradgrad_input):
kernel, = ctx.saved_tensors
gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
gradgrad_out = upfirdn2d_op.upfirdn2d(
gradgrad_input,
kernel,
ctx.up_x,
ctx.up_y,
ctx.down_x,
ctx.down_y,
ctx.pad_x0,
ctx.pad_x1,
ctx.pad_y0,
ctx.pad_y1,
)
# gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
gradgrad_out = gradgrad_out.view(
ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
)
return gradgrad_out, None, None, None, None, None, None, None, None
class UpFirDn2d(Function):
@staticmethod
def forward(ctx, input, kernel, up, down, pad):
up_x, up_y = up
down_x, down_y = down
pad_x0, pad_x1, pad_y0, pad_y1 = pad
kernel_h, kernel_w = kernel.shape
batch, channel, in_h, in_w = input.shape
ctx.in_size = input.shape
input = input.reshape(-1, in_h, in_w, 1)
ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
ctx.out_size = (out_h, out_w)
ctx.up = (up_x, up_y)
ctx.down = (down_x, down_y)
ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
g_pad_x0 = kernel_w - pad_x0 - 1
g_pad_y0 = kernel_h - pad_y0 - 1
g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
out = upfirdn2d_op.upfirdn2d(
input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
)
# out = out.view(major, out_h, out_w, minor)
out = out.view(-1, channel, out_h, out_w)
return out
@staticmethod
def backward(ctx, grad_output):
kernel, grad_kernel = ctx.saved_tensors
grad_input = UpFirDn2dBackward.apply(
grad_output,
kernel,
grad_kernel,
ctx.up,
ctx.down,
ctx.pad,
ctx.g_pad,
ctx.in_size,
ctx.out_size,
)
return grad_input, None, None, None, None
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
if input.device.type == "cpu":
out = upfirdn2d_native(
input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
)
else:
out = UpFirDn2d.apply(
input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
)
return out
def upfirdn2d_native(
input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
):
_, channel, in_h, in_w = input.shape
input = input.reshape(-1, in_h, in_w, 1)
_, in_h, in_w, minor = input.shape
kernel_h, kernel_w = kernel.shape
out = input.view(-1, in_h, 1, in_w, 1, minor)
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
out = F.pad(
out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
)
out = out[
:,
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
:,
]
out = out.permute(0, 3, 1, 2)
out = out.reshape(
[-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
)
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
out = F.conv2d(out, w)
out = out.reshape(
-1,
minor,
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
)
out = out.permute(0, 2, 3, 1)
out = out[:, ::down_y, ::down_x, :]
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
return out.view(-1, channel, out_h, out_w)
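`upfirdn2d` upsamples by zero-insertion, pads, convolves with the FIR kernel, then downsamples; the output size follows the closed form used at the end of `upfirdn2d_native`. A small pure-Python helper (hypothetical name) evaluating that formula for symmetric up/down factors:

```python
def upfirdn2d_out_shape(in_h, in_w, kernel_h, kernel_w, up=1, down=1, pad=(0, 0)):
    # Output size after upsample-by-up, pad, FIR filter, downsample-by-down,
    # matching the out_h / out_w expressions in upfirdn2d_native.
    out_h = (in_h * up + pad[0] + pad[1] - kernel_h) // down + 1
    out_w = (in_w * up + pad[0] + pad[1] - kernel_w) // down + 1
    return out_h, out_w

# 2x upsampling with a 4-tap blur kernel (StyleGAN2-style padding):
print(upfirdn2d_out_shape(8, 8, 4, 4, up=2, down=1, pad=(2, 1)))  # (16, 16)
# Plain 3-tap blur with same-size padding:
print(upfirdn2d_out_shape(8, 8, 3, 3, pad=(1, 1)))  # (8, 8)
```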
================================================
FILE: op/upfirdn2d_kernel.cu
================================================
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html
#include <torch/types.h>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
int c = a / b;
if (c * b > a) {
c--;
}
return c;
}
struct UpFirDn2DKernelParams {
int up_x;
int up_y;
int down_x;
int down_y;
int pad_x0;
int pad_x1;
int pad_y0;
int pad_y1;
int major_dim;
int in_h;
int in_w;
int minor_dim;
int kernel_h;
int kernel_w;
int out_h;
int out_w;
int loop_major;
int loop_x;
};
template <typename scalar_t>
__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
const scalar_t *kernel,
const UpFirDn2DKernelParams p) {
int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
int out_y = minor_idx / p.minor_dim;
minor_idx -= out_y * p.minor_dim;
int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
int major_idx_base = blockIdx.z * p.loop_major;
if (out_x_base >= p.out_w || out_y >= p.out_h ||
major_idx_base >= p.major_dim) {
return;
}
int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
for (int loop_major = 0, major_idx = major_idx_base;
loop_major < p.loop_major && major_idx < p.major_dim;
loop_major++, major_idx++) {
for (int loop_x = 0, out_x = out_x_base;
loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
const scalar_t *x_p =
&input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
minor_idx];
const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
int x_px = p.minor_dim;
int k_px = -p.up_x;
int x_py = p.in_w * p.minor_dim;
int k_py = -p.up_y * p.kernel_w;
scalar_t v = 0.0f;
for (int y = 0; y < h; y++) {
for (int x = 0; x < w; x++) {
v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
x_p += x_px;
k_p += k_px;
}
x_p += x_py - w * x_px;
k_p += k_py - w * k_px;
}
out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
minor_idx] = v;
}
}
}
template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
const scalar_t *kernel,
const UpFirDn2DKernelParams p) {
const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
__shared__ volatile float sk[kernel_h][kernel_w];
__shared__ volatile float sx[tile_in_h][tile_in_w];
int minor_idx = blockIdx.x;
int tile_out_y = minor_idx / p.minor_dim;
minor_idx -= tile_out_y * p.minor_dim;
tile_out_y *= tile_out_h;
int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
int major_idx_base = blockIdx.z * p.loop_major;
if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
major_idx_base >= p.major_dim) {
return;
}
for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
tap_idx += blockDim.x) {
int ky = tap_idx / kernel_w;
int kx = tap_idx - ky * kernel_w;
scalar_t v = 0.0;
if (kx < p.kernel_w & ky < p.kernel_h) {
v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
}
sk[ky][kx] = v;
}
for (int loop_major = 0, major_idx = major_idx_base;
loop_major < p.loop_major & major_idx < p.major_dim;
loop_major++, major_idx++) {
for (int loop_x = 0, tile_out_x = tile_out_x_base;
loop_x < p.loop_x & tile_out_x < p.out_w;
loop_x++, tile_out_x += tile_out_w) {
int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
int tile_in_x = floor_div(tile_mid_x, up_x);
int tile_in_y = floor_div(tile_mid_y, up_y);
__syncthreads();
for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
in_idx += blockDim.x) {
int rel_in_y = in_idx / tile_in_w;
int rel_in_x = in_idx - rel_in_y * tile_in_w;
int in_x = rel_in_x + tile_in_x;
int in_y = rel_in_y + tile_in_y;
scalar_t v = 0.0;
if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
p.minor_dim +
minor_idx];
}
sx[rel_in_y][rel_in_x] = v;
}
__syncthreads();
for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
out_idx += blockDim.x) {
int rel_out_y = out_idx / tile_out_w;
int rel_out_x = out_idx - rel_out_y * tile_out_w;
int out_x = rel_out_x + tile_out_x;
int out_y = rel_out_y + tile_out_y;
int mid_x = tile_mid_x + rel_out_x * down_x;
int mid_y = tile_mid_y + rel_out_y * down_y;
int in_x = floor_div(mid_x, up_x);
int in_y = floor_div(mid_y, up_y);
int rel_in_x = in_x - tile_in_x;
int rel_in_y = in_y - tile_in_y;
int kernel_x = (in_x + 1) * up_x - mid_x - 1;
int kernel_y = (in_y + 1) * up_y - mid_y - 1;
scalar_t v = 0.0;
#pragma unroll
for (int y = 0; y < kernel_h / up_y; y++)
#pragma unroll
for (int x = 0; x < kernel_w / up_x; x++)
v += sx[rel_in_y + y][rel_in_x + x] *
sk[kernel_y + y * up_y][kernel_x + x * up_x];
if (out_x < p.out_w & out_y < p.out_h) {
out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
minor_idx] = v;
}
}
}
}
}
torch::Tensor upfirdn2d_op(const torch::Tensor &input,
const torch::Tensor &kernel, int up_x, int up_y,
int down_x, int down_y, int pad_x0, int pad_x1,
int pad_y0, int pad_y1) {
int curDevice = -1;
cudaGetDevice(&curDevice);
cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
UpFirDn2DKernelParams p;
auto x = input.contiguous();
auto k = kernel.contiguous();
p.major_dim = x.size(0);
p.in_h = x.size(1);
p.in_w = x.size(2);
p.minor_dim = x.size(3);
p.kernel_h = k.size(0);
p.kernel_w = k.size(1);
p.up_x = up_x;
p.up_y = up_y;
p.down_x = down_x;
p.down_y = down_y;
p.pad_x0 = pad_x0;
p.pad_x1 = pad_x1;
p.pad_y0 = pad_y0;
p.pad_y1 = pad_y1;
p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
p.down_y;
p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
p.down_x;
auto out =
at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
int mode = -1;
int tile_out_h = -1;
int tile_out_w = -1;
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
p.kernel_h <= 4 && p.kernel_w <= 4) {
mode = 1;
tile_out_h = 16;
tile_out_w = 64;
}
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
p.kernel_h <= 3 && p.kernel_w <= 3) {
mode = 2;
tile_out_h = 16;
tile_out_w = 64;
}
if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
p.kernel_h <= 4 && p.kernel_w <= 4) {
mode = 3;
tile_out_h = 16;
tile_out_w = 64;
}
if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
p.kernel_h <= 2 && p.kernel_w <= 2) {
mode = 4;
tile_out_h = 16;
tile_out_w = 64;
}
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
p.kernel_h <= 4 && p.kernel_w <= 4) {
mode = 5;
tile_out_h = 8;
tile_out_w = 32;
}
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
p.kernel_h <= 2 && p.kernel_w <= 2) {
mode = 6;
tile_out_h = 8;
tile_out_w = 32;
}
dim3 block_size;
dim3 grid_size;
if (tile_out_h > 0 && tile_out_w > 0) {
p.loop_major = (p.major_dim - 1) / 16384 + 1;
p.loop_x = 1;
block_size = dim3(32 * 8, 1, 1);
grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
(p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
(p.major_dim - 1) / p.loop_major + 1);
} else {
p.loop_major = (p.major_dim - 1) / 16384 + 1;
p.loop_x = 4;
block_size = dim3(4, 32, 1);
grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
(p.out_w - 1) / (p.loop_x * block_size.y) + 1,
(p.major_dim - 1) / p.loop_major + 1);
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
switch (mode) {
case 1:
upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
x.data_ptr<scalar_t>(),
k.data_ptr<scalar_t>(), p);
break;
case 2:
upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
x.data_ptr<scalar_t>(),
k.data_ptr<scalar_t>(), p);
break;
case 3:
upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
x.data_ptr<scalar_t>(),
k.data_ptr<scalar_t>(), p);
break;
case 4:
upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
x.data_ptr<scalar_t>(),
k.data_ptr<scalar_t>(), p);
break;
case 5:
upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
x.data_ptr<scalar_t>(),
k.data_ptr<scalar_t>(), p);
break;
case 6:
// note: reuses the 4x4 kernel instantiation from case 5; the template
// kernel size is an upper bound, so this stays correct for kernels up to
// 2x2, just not specialized for them
upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
x.data_ptr<scalar_t>(),
k.data_ptr<scalar_t>(), p);
break;
default:
upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
k.data_ptr<scalar_t>(), p);
}
});
return out;
}
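The dispatch above only selects among specialized CUDA instantiations; the operation itself is zero-upsample, zero-pad, FIR-filter, then downsample. A minimal pure-Python sketch of the 1-D case for reference (the helper name is illustrative, not part of this repo):

```python
def upfirdn1d(x, k, up=1, down=1, pad=(0, 0)):
    """Reference semantics of upfirdn along one axis: insert zeros
    between samples (up), zero-pad the ends, convolve with the FIR
    kernel k (flipped, i.e. true convolution), keep every down-th
    sample."""
    # 1) upsample: insert (up - 1) zeros between samples
    ups = []
    for v in x:
        ups.append(v)
        ups.extend([0.0] * (up - 1))
    if up > 1:
        del ups[-(up - 1):]  # drop trailing inserted zeros
    # 2) pad both ends with zeros
    ups = [0.0] * pad[0] + ups + [0.0] * pad[1]
    # 3) convolve over the valid region (kernel flipped)
    n = len(ups) - len(k) + 1
    out = [sum(ups[i + j] * k[len(k) - 1 - j] for j in range(len(k)))
           for i in range(n)]
    # 4) downsample
    return out[::down]
```

The CUDA path applies this separably in 2-D with the tile sizes chosen above; the Python fallback in `op/upfirdn2d.py` does the equivalent with a flipped-kernel `conv2d`.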
================================================
FILE: options.py
================================================
import configargparse
from munch import *
from pdb import set_trace as st
class BaseOptions():
def __init__(self):
self.parser = configargparse.ArgumentParser()
self.initialized = False
def initialize(self):
# Dataset options
dataset = self.parser.add_argument_group('dataset')
dataset.add_argument("--dataset_path", type=str, default='./datasets/FFHQ', help="path to the lmdb dataset")
# Experiment Options
experiment = self.parser.add_argument_group('experiment')
experiment.add_argument('--config', is_config_file=True, help='config file path')
experiment.add_argument("--expname", type=str, default='debug', help='experiment name')
experiment.add_argument("--ckpt", type=str, default='300000', help="path to the checkpoints to resume training")
experiment.add_argument("--continue_training", action="store_true", help="continue training the model")
# Training loop options
training = self.parser.add_argument_group('training')
training.add_argument("--checkpoints_dir", type=str, default='./checkpoint', help='checkpoints directory name')
training.add_argument("--iter", type=int, default=300000, help="total number of training iterations")
training.add_argument("--batch", type=int, default=4, help="batch size per GPU. A single RTX2080 can fit batch=4, chunk=1 into memory.")
training.add_argument("--chunk", type=int, default=4, help='number of samples within a batch to be processed in parallel, decrease if running out of memory')
training.add_argument("--val_n_sample", type=int, default=8, help="number of test samples generated during training")
training.add_argument("--d_reg_every", type=int, default=16, help="interval for applying r1 regularization to the StyleGAN discriminator")
training.add_argument("--g_reg_every", type=int, default=4, help="interval for applying path length regularization to the StyleGAN generator")
training.add_argument("--local_rank", type=int, default=0, help="local rank for distributed training")
training.add_argument("--mixing", type=float, default=0.9, help="probability of latent code mixing")
training.add_argument("--lr", type=float, default=0.002, help="learning rate")
training.add_argument("--r1", type=float, default=10, help="weight of the r1 regularization")
training.add_argument("--view_lambda", type=float, default=15, help="weight of the viewpoint regularization")
training.add_argument("--eikonal_lambda", type=float, default=0.1, help="weight of the eikonal regularization")
training.add_argument("--min_surf_lambda", type=float, default=0.05, help="weight of the minimal surface regularization")
training.add_argument("--min_surf_beta", type=float, default=100.0, help="beta (falloff) parameter of the minimal surface regularization")
training.add_argument("--path_regularize", type=float, default=2, help="weight of the path length regularization")
training.add_argument("--path_batch_shrink", type=int, default=2, help="batch size reducing factor for the path length regularization (reduce memory consumption)")
training.add_argument("--wandb", action="store_true", help="use weights and biases logging")
training.add_argument("--no_sphere_init", action="store_true", help="do not initialize the volume renderer with a sphere SDF")
# Inference Options
inference = self.parser.add_argument_group('inference')
inference.add_argument("--results_dir", type=str, default='./evaluations', help='results/evaluations directory name')
inference.add_argument("--truncation_ratio", type=float, default=0.5, help="truncation ratio, controls the diversity vs. quality tradeoff. A higher truncation ratio generates more diverse results")
inference.add_argument("--truncation_mean", type=int, default=10000, help="number of vectors to calculate mean for the truncation")
inference.add_argument("--identities", type=int, default=16, help="number of identities to be generated")
inference.add_argument("--num_views_per_id", type=int, default=1, help="number of viewpoints generated per identity")
inference.add_argument("--no_surface_renderings", action="store_true", help="when true, only RGB outputs will be generated. otherwise, both RGB and depth videos/renderings will be generated. this cuts the processing time per video")
inference.add_argument("--fixed_camera_angles", action="store_true", help="when true, the generator will render identities from a fixed set of camera angles.")
inference.add_argument("--azim_video", action="store_true", help="when true, the camera trajectory will travel along the azimuth direction. Otherwise, the camera will travel along an ellipsoid trajectory.")
# Generator options
model = self.parser.add_argument_group('model')
model.add_argument("--size", type=int, default=256, help="image sizes for the model")
model.add_argument("--style_dim", type=int, default=256, help="number of style input dimensions")
model.add_argument("--channel_multiplier", type=int, default=2, help="channel multiplier factor for the StyleGAN decoder. config-f = 2, else = 1")
model.add_argument("--n_mlp", type=int, default=8, help="number of mlp layers in stylegan's mapping network")
model.add_argument("--lr_mapping", type=float, default=0.01, help='learning rate multiplier for the mapping network MLP layers')
model.add_argument("--renderer_spatial_output_dim", type=int, default=64, help='spatial resolution of the StyleGAN decoder inputs')
model.add_argument("--project_noise", action='store_true', help='when true, use geometry-aware noise projection to reduce flickering effects (see supplementary section C.1 in the paper). warning: processing time significantly increases with this flag to ~20 minutes per video.')
# Camera options
camera = self.parser.add_argument_group('camera')
camera.add_argument("--uniform", action="store_true", help="when true, the camera position is sampled from uniform distribution. Gaussian distribution is the default")
camera.add_argument("--azim", type=float, default=0.3, help="camera azimuth angle std/range in Radians")
camera.add_argument("--elev", type=float, default=0.15, help="camera elevation angle std/range in Radians")
camera.add_argument("--fov", type=float, default=6, help="camera field of view half angle in Degrees")
camera.add_argument("--dist_radius", type=float, default=0.12, help="sampling radius of the camera distance from the origin; determines the near and far fields")
# Volume Renderer options
rendering = self.parser.add_argument_group('rendering')
# MLP model parameters
rendering.add_argument("--depth", type=int, default=8, help='layers in network')
rendering.add_argument("--width", type=int, default=256, help='channels per layer')
# Volume representation options
rendering.add_argument("--no_sdf", action='store_true', help='By default, the raw MLP outputs represent an underlying signed distance field (SDF). When true, the MLP outputs represent the traditional NeRF density field.')
rendering.add_argument("--no_z_normalize", action='store_true', help='By default, the model normalizes input coordinates such that the z coordinate is in [-1,1]. When true that feature is disabled.')
rendering.add_argument("--static_viewdirs", action='store_true', help='when true, use static viewing direction input to the MLP')
# Ray integration options
rendering.add_argument("--N_samples", type=int, default=24, help='number of samples per ray')
rendering.add_argument("--no_offset_sampling", action='store_true', help='when true, use random stratified sampling when rendering the volume, otherwise offset sampling is used. (See Equation (3) in Sec. 3.2 of the paper)')
rendering.add_argument("--perturb", type=float, default=1., help='set to 0. for no jitter, 1. for jitter')
rendering.add_argument("--raw_noise_std", type=float, default=0., help='std dev of noise added to regularize sigma_a output, 1e0 recommended')
rendering.add_argument("--force_background", action='store_true', help='force the last depth sample to act as background in case of a transparent ray')
# Set volume renderer outputs
rendering.add_argument("--return_xyz", action='store_true', help='when true, the volume renderer also returns the xyz point cloud of the surface. This point cloud is used to produce depth map renderings')
rendering.add_argument("--return_sdf", action='store_true', help='when true, the volume renderer also returns the SDF network outputs for each location in the volume')
self.initialized = True
def parse(self):
self.opt = Munch()
if not self.initialized:
self.initialize()
try:
args = self.parser.parse_args()
except: # solves argparse error in google colab
args = self.parser.parse_args(args=[])
for group in self.parser._action_groups[2:]:
title = group.title
self.opt[title] = Munch()
for action in group._group_actions:
dest = action.dest
self.opt[title][dest] = args.__getattribute__(dest)
return self.opt
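parse() walks `self.parser._action_groups[2:]` (argparse reserves the first two groups for its built-in positional and optional groups) and nests each group's values under its title, which is what makes `opt.training.lr`-style access possible after the Munch wrap. A stdlib-only sketch of the same grouping trick, with illustrative option names:

```python
import argparse

def parse_grouped(argv=None):
    """Group parsed arguments by their argparse group title,
    mirroring BaseOptions.parse() with plain dicts."""
    parser = argparse.ArgumentParser()
    training = parser.add_argument_group('training')
    training.add_argument('--lr', type=float, default=0.002)
    model = parser.add_argument_group('model')
    model.add_argument('--size', type=int, default=256)
    args = parser.parse_args(argv)
    opt = {}
    # indices 0 and 1 are argparse's built-in groups; user groups follow
    for group in parser._action_groups[2:]:
        opt[group.title] = {a.dest: getattr(args, a.dest)
                            for a in group._group_actions}
    return opt
```

Wrapping the returned dicts in Munch, as options.py does, turns `opt['training']['lr']` into `opt.training.lr`.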
================================================
FILE: prepare_data.py
================================================
import argparse
from io import BytesIO
import multiprocessing
from functools import partial
from PIL import Image
import lmdb
from tqdm import tqdm
from torchvision import datasets
from torchvision.transforms import functional as trans_fn
from pdb import set_trace as st
def resize_and_convert(img, size, resample):
img = trans_fn.resize(img, size, resample)
img = trans_fn.center_crop(img, size)
buffer = BytesIO()
img.save(buffer, format="png")
val = buffer.getvalue()
return val
def resize_multiple(
img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS):
imgs = []
for size in sizes:
imgs.append(resize_and_convert(img, size, resample))
return imgs
def resize_worker(img_file, sizes, resample):
i, file = img_file
img = Image.open(file)
img = img.convert("RGB")
out = resize_multiple(img, sizes=sizes, resample=resample)
return i, out
def prepare(
env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS
):
resize_fn = partial(resize_worker, sizes=sizes, resample=resample)
files = sorted(dataset.imgs, key=lambda x: x[0])
files = [(i, file) for i, (file, label) in enumerate(files)]
total = 0
with multiprocessing.Pool(n_worker) as pool:
for i, imgs in tqdm(pool.imap_unordered(resize_fn, files)):
for size, img in zip(sizes, imgs):
key = f"{size}-{str(i).zfill(5)}".encode("utf-8")
with env.begin(write=True) as txn:
txn.put(key, img)
total += 1
with env.begin(write=True) as txn:
txn.put("length".encode("utf-8"), str(total).encode("utf-8"))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Preprocess images for model training")
parser.add_argument("--size", type=str, default="64,512,1024", help="resolutions of images for the dataset")
parser.add_argument("--n_worker", type=int, default=32, help="number of workers for preparing dataset")
parser.add_argument("--resample", type=str, default="lanczos", help="resampling methods for resizing images")
parser.add_argument("--out_path", type=str, default="datasets/FFHQ", help="Target path of the output lmdb dataset")
parser.add_argument("in_path", type=str, help="path to the input image dataset")
args = parser.parse_args()
resample_map = {"lanczos": Image.LANCZOS, "bilinear": Image.BILINEAR}
resample = resample_map[args.resample]
sizes = [int(s.strip()) for s in args.size.split(",")]
print("Making dataset of image sizes:", ", ".join(str(s) for s in sizes))
imgset = datasets.ImageFolder(args.in_path)
with lmdb.open(args.out_path, map_size=1024 ** 4, readahead=False) as env:
prepare(env, imgset, args.n_worker, sizes=sizes, resample=resample)
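Readers of this LMDB (e.g. the MultiResolutionDataset in dataset.py) must rebuild the exact keys written above: resolution and a five-digit zero-padded index joined by a dash, UTF-8 encoded. A sketch of the scheme (the helper name is ours, not from the repo):

```python
def lmdb_key(size, index):
    """Key layout used when writing the dataset:
    b'<size>-<index zero-padded to 5 digits>'."""
    return f"{size}-{str(index).zfill(5)}".encode("utf-8")
```

The record count is stored separately under the key b"length" as a decimal string, as done at the end of prepare().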
================================================
FILE: render_video.py
================================================
import os
import torch
import trimesh
import numpy as np
import skvideo.io
from munch import *
from PIL import Image
from tqdm import tqdm
from torch.nn import functional as F
from torch.utils import data
from torchvision import utils
from torchvision import transforms
from options import BaseOptions
from model import Generator
from utils import (
generate_camera_params, align_volume, extract_mesh_with_marching_cubes,
xyz2mesh, create_cameras, create_mesh_renderer, add_textures,
)
from pytorch3d.structures import Meshes
from pdb import set_trace as st
torch.random.manual_seed(1234)
def render_video(opt, g_ema, surface_g_ema, device, mean_latent, surface_mean_latent):
g_ema.eval()
if not opt.no_surface_renderings or opt.project_noise:
surface_g_ema.eval()
images = torch.Tensor(0, 3, opt.size, opt.size)
num_frames = 250
# Generate video trajectory
trajectory = np.zeros((num_frames,3), dtype=np.float32)
# set camera trajectory
# sweep azimuth angles (4 seconds)
if opt.azim_video:
t = np.linspace(0, 1, num_frames)
elev = 0
fov = opt.camera.fov
if opt.camera.uniform:
azim = opt.camera.azim * np.cos(t * 2 * np.pi)
else:
azim = 1.5 * opt.camera.azim * np.cos(t * 2 * np.pi)
trajectory[:num_frames,0] = azim
trajectory[:num_frames,1] = elev
trajectory[:num_frames,2] = fov
# ellipsoid sweep (4 seconds)
else:
t = np.linspace(0, 1, num_frames)
fov = opt.camera.fov #+ 1 * np.sin(t * 2 * np.pi)
if opt.camera.uniform:
elev = opt.camera.elev / 2 + opt.camera.elev / 2 * np.sin(t * 2 * np.pi)
azim = opt.camera.azim * np.cos(t * 2 * np.pi)
else:
elev = 1.5 * opt.camera.elev * np.sin(t * 2 * np.pi)
azim = 1.5 * opt.camera.azim * np.cos(t * 2 * np.pi)
trajectory[:num_frames,0] = azim
trajectory[:num_frames,1] = elev
trajectory[:num_frames,2] = fov
trajectory = torch.from_numpy(trajectory).to(device)
# generate input parameters for the camera trajectory
# sample_cam_poses, sample_focals, sample_near, sample_far = \
# generate_camera_params(trajectory, opt.renderer_output_size, device, dist_radius=opt.camera.dist_radius)
sample_cam_extrinsics, sample_focals, sample_near, sample_far, _ = \
generate_camera_params(opt.renderer_output_size, device, locations=trajectory[:,:2],
fov_ang=trajectory[:,2:], dist_radius=opt.camera.dist_radius)
# In case of noise projection, generate input parameters for the frontal position.
# The reference mesh for the noise projection is extracted from the frontal position.
# For more details see section C.1 in the supplementary material.
if opt.project_noise:
frontal_pose = torch.tensor([[0.0,0.0,opt.camera.fov]]).to(device)
# frontal_cam_pose, frontal_focals, frontal_near, frontal_far = \
# generate_camera_params(frontal_pose, opt.surf_extraction_output_size, device, dist_radius=opt.camera.dist_radius)
frontal_cam_pose, frontal_focals, frontal_near, frontal_far, _ = \
generate_camera_params(opt.surf_extraction_output_size, device, locations=frontal_pose[:,:2],
fov_ang=frontal_pose[:,2:], dist_radius=opt.camera.dist_radius)
# create geometry renderer (renders the depth maps)
cameras = create_cameras(azim=np.rad2deg(trajectory[0,0].cpu().numpy()),
elev=np.rad2deg(trajectory[0,1].cpu().numpy()),
dist=1, device=device)
renderer = create_mesh_renderer(cameras, image_size=512, specular_color=((0,0,0),),
ambient_color=((0.1,.1,.1),), diffuse_color=((0.75,.75,.75),),
device=device)
suffix = '_azim' if opt.azim_video else '_elipsoid'
# generate videos
for i in range(opt.identities):
print('Processing identity {}/{}...'.format(i+1, opt.identities))
chunk = 1
sample_z = torch.randn(1, opt.style_dim, device=device).repeat(chunk,1)
video_filename = 'sample_video_{}{}.mp4'.format(i,suffix)
writer = skvideo.io.FFmpegWriter(os.path.join(opt.results_dst_dir, video_filename),
outputdict={'-pix_fmt': 'yuv420p', '-crf': '10'})
if not opt.no_surface_renderings:
depth_video_filename = 'sample_depth_video_{}{}.mp4'.format(i,suffix)
depth_writer = skvideo.io.FFmpegWriter(os.path.join(opt.results_dst_dir, depth_video_filename),
outputdict={'-pix_fmt': 'yuv420p', '-crf': '1'})
####################### Extract initial surface mesh from the frontal viewpoint #############
# For more details see section C.1 in the supplementary material.
if opt.project_noise:
with torch.no_grad():
frontal_surface_out = surface_g_ema([sample_z],
frontal_cam_pose,
frontal_focals,
frontal_near,
frontal_far,
truncation=opt.truncation_ratio,
truncation_latent=surface_mean_latent,
return_sdf=True)
frontal_sdf = frontal_surface_out[2].cpu()
print('Extracting Identity {} Frontal view Marching Cubes for consistent video rendering'.format(i))
frustum_aligned_frontal_sdf = align_volume(frontal_sdf)
del frontal_sdf
try:
frontal_marching_cubes_mesh = extract_mesh_with_marching_cubes(frustum_aligned_frontal_sdf)
except ValueError:
frontal_marching_cubes_mesh = None
if frontal_marching_cubes_mesh is not None:
frontal_marching_cubes_mesh_filename = os.path.join(opt.results_dst_dir,'sample_{}_frontal_marching_cubes_mesh{}.obj'.format(i,suffix))
with open(frontal_marching_cubes_mesh_filename, 'w') as f:
frontal_marching_cubes_mesh.export(f,file_type='obj')
del frontal_surface_out
torch.cuda.empty_cache()
#############################################################################################
for j in tqdm(range(0, num_frames, chunk)):
with torch.no_grad():
out = g_ema([sample_z],
sample_cam_extrinsics[j:j+chunk],
sample_focals[j:j+chunk],
sample_near[j:j+chunk],
sample_far[j:j+chunk],
truncation=opt.truncation_ratio,
truncation_latent=mean_latent,
randomize_noise=False,
project_noise=opt.project_noise,
mesh_path=frontal_marching_cubes_mesh_filename if opt.project_noise else None)
rgb = out[0].cpu()
# this is done to fit into the RTX2080 RAM size (11GB)
del out
torch.cuda.empty_cache()
# Convert RGB from [-1, 1] to [0,255]
rgb = 127.5 * (rgb.clamp(-1,1).permute(0,2,3,1).cpu().numpy() + 1)
# Add RGB, frame to video
for k in range(chunk):
writer.writeFrame(rgb[k])
########## Extract surface ##########
if not opt.no_surface_renderings:
scale = surface_g_ema.renderer.out_im_res / g_ema.renderer.out_im_res
surface_sample_focals = sample_focals * scale
surface_out = surface_g_ema([sample_z],
sample_cam_extrinsics[j:j+chunk],
surface_sample_focals[j:j+chunk],
sample_near[j:j+chunk],
sample_far[j:j+chunk],
truncation=opt.truncation_ratio,
truncation_latent=surface_mean_latent,
return_xyz=True)
xyz = surface_out[2].cpu()
# this is done to fit into the RTX2080 RAM size (11GB)
del surface_out
torch.cuda.empty_cache()
# Render mesh for video
depth_mesh = xyz2mesh(xyz)
mesh = Meshes(
verts=[torch.from_numpy(np.asarray(depth_mesh.vertices)).to(torch.float32).to(device)],
faces = [torch.from_numpy(np.asarray(depth_mesh.faces)).to(torch.float32).to(device)],
textures=None,
verts_normals=[torch.from_numpy(np.copy(np.asarray(depth_mesh.vertex_normals))).to(torch.float32).to(device)],
)
mesh = add_textures(mesh)
cameras = create_cameras(azim=np.rad2deg(trajectory[j,0].cpu().numpy()),
elev=np.rad2deg(trajectory[j,1].cpu().numpy()),
fov=2*trajectory[j,2].cpu().numpy(),
dist=1, device=device)
renderer = create_mesh_renderer(cameras, image_size=512,
light_location=((0.0,1.0,5.0),), specular_color=((0.2,0.2,0.2),),
ambient_color=((0.1,0.1,0.1),), diffuse_color=((0.65,.65,.65),),
device=device)
mesh_image = 255 * renderer(mesh).cpu().numpy()
mesh_image = mesh_image[...,:3]
# Add depth frame to video
for k in range(chunk):
depth_writer.writeFrame(mesh_image[k])
# Close video writers
writer.close()
if not opt.no_surface_renderings:
depth_writer.close()
if __name__ == "__main__":
device = "cuda"
opt = BaseOptions().parse()
opt.model.is_test = True
opt.model.style_dim = 256
opt.model.freeze_renderer = False
opt.inference.size = opt.model.size
opt.inference.camera = opt.camera
opt.inference.renderer_output_size = opt.model.renderer_spatial_output_dim
opt.inference.style_dim = opt.model.style_dim
opt.inference.project_noise = opt.model.project_noise
opt.rendering.perturb = 0
opt.rendering.force_background = True
opt.rendering.static_viewdirs = True
opt.rendering.return_sdf = True
opt.rendering.N_samples = 64
# find checkpoint directory
# check if there's a fully trained model
checkpoints_dir = 'full_models'
checkpoint_path = os.path.join(checkpoints_dir, opt.experiment.expname + '.pt')
if os.path.isfile(checkpoint_path):
# define results directory name
result_model_dir = 'final_model'
else:
checkpoints_dir = os.path.join('checkpoint', opt.experiment.expname, 'full_pipeline')
checkpoint_path = os.path.join(checkpoints_dir,
'models_{}.pt'.format(opt.experiment.ckpt.zfill(7)))
# define results directory name
result_model_dir = 'iter_{}'.format(opt.experiment.ckpt.zfill(7))
results_dir_basename = os.path.join(opt.inference.results_dir, opt.experiment.expname)
opt.inference.results_dst_dir = os.path.join(results_dir_basename, result_model_dir, 'videos')
if opt.model.project_noise:
opt.inference.results_dst_dir = os.path.join(opt.inference.results_dst_dir, 'with_noise_projection')
os.makedirs(opt.inference.results_dst_dir, exist_ok=True)
# load saved model
checkpoint = torch.load(checkpoint_path)
# load image generation model
g_ema = Generator(opt.model, opt.rendering).to(device)
# temp fix because of wrong noise sizes
pretrained_weights_dict = checkpoint["g_ema"]
model_dict = g_ema.state_dict()
for k, v in pretrained_weights_dict.items():
if v.size() == model_dict[k].size():
model_dict[k] = v
g_ema.load_state_dict(model_dict)
# load the volume renderer into a second model that extracts surfaces at a 128x128x128 resolution
if not opt.inference.no_surface_renderings or opt.model.project_noise:
opt['surf_extraction'] = Munch()
opt.surf_extraction.rendering = opt.rendering
opt.surf_extraction.model = opt.model.copy()
opt.surf_extraction.model.renderer_spatial_output_dim = 128
opt.surf_extraction.rendering.N_samples = opt.surf_extraction.model.renderer_spatial_output_dim
opt.surf_extraction.rendering.return_xyz = True
opt.surf_extraction.rendering.return_sdf = True
opt.inference.surf_extraction_output_size = opt.surf_extraction.model.renderer_spatial_output_dim
surface_g_ema = Generator(opt.surf_extraction.model, opt.surf_extraction.rendering, full_pipeline=False).to(device)
# Load weights to surface extractor
surface_extractor_dict = surface_g_ema.state_dict()
for k, v in pretrained_weights_dict.items():
if k in surface_extractor_dict.keys() and v.size() == surface_extractor_dict[k].size():
surface_extractor_dict[k] = v
surface_g_ema.load_state_dict(surface_extractor_dict)
else:
surface_g_ema = None
# get the mean latent vector for g_ema
if opt.inference.truncation_ratio < 1:
with torch.no_grad():
mean_latent = g_ema.mean_latent(opt.inference.truncation_mean, device)
else:
mean_latent = None
# get the mean latent vector for surface_g_ema
if not opt.inference.no_surface_renderings or opt.model.project_noise:
surface_mean_latent = mean_latent[0]
else:
surface_mean_latent = None
render_video(opt.inference, g_ema, surface_g_ema, device, mean_latent, surface_mean_latent)
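Each video frame is mapped from the generator's [-1, 1] output range to 8-bit values via clamp-then-affine (the `127.5 * (rgb.clamp(-1,1) ... + 1)` step in render_video). A scalar sketch of that mapping, with an illustrative helper name:

```python
def to_video_range(v):
    """Map a generator output value in [-1, 1] to [0, 255].
    Out-of-range values are clamped first, as in render_video."""
    v = max(-1.0, min(1.0, v))
    return 127.5 * (v + 1.0)
```

The depth frames skip this step because the mesh renderer already emits values in [0, 1], which are simply scaled by 255.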
================================================
FILE: requirements.txt
================================================
lmdb
numpy
ninja
pillow
requests
tqdm
scipy
scikit-image
scikit-video
trimesh[easy]
configargparse
munch
wandb
================================================
FILE: scripts/train_afhq_full_pipeline_512x512.sh
================================================
python -m torch.distributed.launch --nproc_per_node 4 train_full_pipeline.py --batch 8 --chunk 4 --azim 0.15 --r1 50.0 --expname afhq512x512 --dataset_path ./datasets/AFHQ/train/ --size 512 --wandb
================================================
FILE: scripts/train_afhq_vol_renderer.sh
================================================
python -m torch.distributed.launch --nproc_per_node 2 train_volume_renderer.py --batch 12 --chunk 6 --expname afhq_sdf_vol_renderer --dataset_path ./datasets/AFHQ/train --azim 0.15 --wandb
================================================
FILE: scripts/train_ffhq_full_pipeline_1024x1024.sh
================================================
python -m torch.distributed.launch --nproc_per_node 4 train_full_pipeline.py --batch 8 --chunk 2 --expname ffhq1024x1024 --size 1024 --wandb
================================================
FILE: scripts/train_ffhq_vol_renderer.sh
================================================
python -m torch.distributed.launch --nproc_per_node 2 train_volume_renderer.py --batch 12 --chunk 6 --expname ffhq_sdf_vol_renderer --dataset_path ./datasets/FFHQ/ --wandb
================================================
FILE: train_full_pipeline.py
================================================
import argparse
import math
import random
import os
import yaml
import numpy as np
import torch
from torch import nn, autograd, optim
from torch.nn import functional as F
from torch.utils import data
import torch.distributed as dist
from torchvision import transforms, utils
from tqdm import tqdm
from PIL import Image
from losses import *
from options import BaseOptions
from model import Generator, Discriminator
from dataset import MultiResolutionDataset
from utils import data_sampler, requires_grad, accumulate, sample_data, make_noise, mixing_noise, generate_camera_params
from distributed import get_rank, synchronize, reduce_loss_dict, reduce_sum, get_world_size
try:
import wandb
except ImportError:
wandb = None
def train(opt, experiment_opt, loader, generator, discriminator, g_optim, d_optim, g_ema, device):
loader = sample_data(loader)
pbar = range(opt.iter)
if get_rank() == 0:
pbar = tqdm(pbar, initial=opt.start_iter, dynamic_ncols=True, smoothing=0.01)
mean_path_length = 0
d_loss_val = 0
g_loss_val = 0
r1_loss = torch.tensor(0.0, device=device)
g_gan_loss = torch.tensor(0.0, device=device)
path_loss = torch.tensor(0.0, device=device)
path_lengths = torch.tensor(0.0, device=device)
mean_path_length_avg = 0
loss_dict = {}
if opt.distributed:
g_module = generator.module
d_module = discriminator.module
else:
g_module = generator
d_module = discriminator
accum = 0.5 ** (32 / (10 * 1000))
sample_z = [torch.randn(opt.val_n_sample, opt.style_dim, device=device).repeat_interleave(8, dim=0)]
sample_cam_extrinsics, sample_focals, sample_near, sample_far, _ = generate_camera_params(opt.renderer_output_size, device, batch=opt.val_n_sample, sweep=True,
uniform=opt.camera.uniform, azim_range=opt.camera.azim,
elev_range=opt.camera.elev, fov_ang=opt.camera.fov,
dist_radius=opt.camera.dist_radius)
for idx in pbar:
i = idx + opt.start_iter
if i > opt.iter:
print("Done!")
break
requires_grad(generator, False)
requires_grad(discriminator, True)
discriminator.zero_grad()
d_regularize = i % opt.d_reg_every == 0
real_imgs, real_thumb_imgs = next(loader)
real_imgs = real_imgs.to(device)
real_thumb_imgs = real_thumb_imgs.to(device)
noise = mixing_noise(opt.batch, opt.style_dim, opt.mixing, device)
cam_extrinsics, focal, near, far, gt_viewpoints = generate_camera_params(opt.renderer_output_size, device, batch=opt.batch,
uniform=opt.camera.uniform, azim_range=opt.camera.azim,
elev_range=opt.camera.elev, fov_ang=opt.camera.fov,
dist_radius=opt.camera.dist_radius)
for j in range(0, opt.batch, opt.chunk):
curr_real_imgs = real_imgs[j:j+opt.chunk]
curr_real_thumb_imgs = real_thumb_imgs[j:j+opt.chunk]
curr_noise = [n[j:j+opt.chunk] for n in noise]
gen_imgs, _ = generator(curr_noise,
cam_extrinsics[j:j+opt.chunk],
focal[j:j+opt.chunk],
near[j:j+opt.chunk],
far[j:j+opt.chunk])
fake_pred = discriminator(gen_imgs.detach())
if d_regularize:
curr_real_imgs.requires_grad = True
curr_real_thumb_imgs.requires_grad = True
real_pred = discriminator(curr_real_imgs)
d_gan_loss = d_logistic_loss(real_pred, fake_pred)
if d_regularize:
grad_penalty = d_r1_loss(real_pred, curr_real_imgs)
r1_loss = opt.r1 * 0.5 * grad_penalty * opt.d_reg_every
else:
r1_loss = torch.zeros_like(r1_loss)
d_loss = d_gan_loss + r1_loss
d_loss.backward()
d_optim.step()
loss_dict["d"] = d_gan_loss
loss_dict["real_score"] = real_pred.mean()
loss_dict["fake_score"] = fake_pred.mean()
if d_regularize or i == opt.start_iter:
loss_dict["r1"] = r1_loss.mean()
requires_grad(generator, True)
requires_grad(discriminator, False)
for j in range(0, opt.batch, opt.chunk):
noise = mixing_noise(opt.chunk, opt.style_dim, opt.mixing, device)
cam_extrinsics, focal, near, far, gt_viewpoints = generate_camera_params(opt.renderer_output_size, device, batch=opt.chunk,
uniform=opt.camera.uniform, azim_range=opt.camera.azim,
elev_range=opt.camera.elev, fov_ang=opt.camera.fov,
dist_radius=opt.camera.dist_radius)
fake_img, _ = generator(noise, cam_extrinsics, focal, near, far)
fake_pred = discriminator(fake_img)
g_gan_loss = g_nonsaturating_loss(fake_pred)
g_loss = g_gan_loss
g_loss.backward()
g_optim.step()
generator.zero_grad()
loss_dict["g"] = g_gan_loss
# generator path regularization
g_regularize = (opt.g_reg_every > 0) and (i % opt.g_reg_every == 0)
if g_regularize:
path_batch_size = max(1, opt.batch // opt.path_batch_shrink)
path_noise = mixing_noise(path_batch_size, opt.style_dim, opt.mixing, device)
path_cam_extrinsics, path_focal, path_near, path_far, _ = generate_camera_params(opt.renderer_output_size, device, batch=path_batch_size,
uniform=opt.camera.uniform, azim_range=opt.camera.azim,
elev_range=opt.camera.elev, fov_ang=opt.camera.fov,
dist_radius=opt.camera.dist_radius)
for j in range(0, path_batch_size, opt.chunk):
path_fake_img, path_latents = generator(path_noise, path_cam_extrinsics,
path_focal, path_near, path_far,
return_latents=True)
path_loss, mean_path_length, path_lengths = g_path_regularize(
path_fake_img, path_latents, mean_path_length
)
weighted_path_loss = opt.path_regularize * opt.g_reg_every * path_loss# * opt.chunk / path_batch_size
if opt.path_batch_shrink:
weighted_path_loss += 0 * path_fake_img[0, 0, 0, 0]
weighted_path_loss.backward()
g_optim.step()
generator.zero_grad()
mean_path_length_avg = (reduce_sum(mean_path_length).item() / get_world_size())
loss_dict["path"] = path_loss
loss_dict["path_length"] = path_lengths.mean()
accumulate(g_ema, g_module, accum)
loss_reduced = reduce_loss_dict(loss_dict)
d_loss_val = loss_reduced["d"].mean().item()
g_loss_val = loss_reduced["g"].mean().item()
r1_val = loss_reduced["r1"].mean().item()
real_score_val = loss_reduced["real_score"].mean().item()
fake_score_val = loss_reduced["fake_score"].mean().item()
path_loss_val = loss_reduced["path"].mean().item()
path_length_val = loss_reduced["path_length"].mean().item()
if get_rank() == 0:
pbar.set_description(
(f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; path: {path_loss_val:.4f}")
)
if i % 1000 == 0 or i == opt.start_iter:
with torch.no_grad():
thumbs_samples = torch.Tensor(0, 3, opt.renderer_output_size, opt.renderer_output_size)
samples = torch.Tensor(0, 3, opt.size, opt.size)
step_size = 8
mean_latent = g_module.mean_latent(10000, device)
for k in range(0, opt.val_n_sample * 8, step_size):
curr_samples, curr_thumbs = g_ema([sample_z[0][k:k+step_size]],
sample_cam_extrinsics[k:k+step_size],
sample_focals[k:k+step_size],
sample_near[k:k+step_size],
sample_far[k:k+step_size],
truncation=0.7,
truncation_latent=mean_latent)
samples = torch.cat([samples, curr_samples.cpu()], 0)
thumbs_samples = torch.cat([thumbs_samples, curr_thumbs.cpu()], 0)
if i % 10000 == 0:
utils.save_image(samples,
os.path.join(opt.checkpoints_dir, experiment_opt.expname, 'full_pipeline', f"samples/{str(i).zfill(7)}.png"),
nrow=int(opt.val_n_sample),
normalize=True,
value_range=(-1, 1),)
utils.save_image(thumbs_samples,
os.path.join(opt.checkpoints_dir, experiment_opt.expname, 'full_pipeline', f"samples/{str(i).zfill(7)}_thumbs.png"),
nrow=int(opt.val_n_sample),
normalize=True,
value_range=(-1, 1),)
if wandb and opt.wandb:
wandb_log_dict = {"Generator": g_loss_val,
"Discriminator": d_loss_val,
"R1": r1_val,
"Real Score": real_score_val,
"Fake Score": fake_score_val,
"Path Length Regularization": path_loss_val,
"Path Length": path_length_val,
"Mean Path Length": mean_path_length,
}
if i % 5000 == 0:
wandb_grid = utils.make_grid(samples, nrow=int(opt.val_n_sample),
normalize=True, value_range=(-1, 1))
wandb_ndarr = (255 * wandb_grid.permute(1, 2, 0).numpy()).astype(np.uint8)
wandb_images = Image.fromarray(wandb_ndarr)
wandb_log_dict.update({"examples": [wandb.Image(wandb_images,
caption="Generated samples for azimuth angles of: -0.35, -0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35 Radians.")]})
wandb_thumbs_grid = utils.make_grid(thumbs_samples, nrow=int(opt.val_n_sample),
normalize=True, value_range=(-1, 1))
wandb_thumbs_ndarr = (255 * wandb_thumbs_grid.permute(1, 2, 0).numpy()).astype(np.uint8)
wandb_thumbs = Image.fromarray(wandb_thumbs_ndarr)
wandb_log_dict.update({"thumb_examples": [wandb.Image(wandb_thumbs,
caption="Generated samples for azimuth angles of: -0.35, -0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35 Radians.")]})
wandb.log(wandb_log_dict)
if i % 10000 == 0:
torch.save(
{
"g": g_module.state_dict(),
"d": d_module.state_dict(),
"g_ema": g_ema.state_dict(),
},
os.path.join(opt.checkpoints_dir, experiment_opt.expname, 'full_pipeline', f"models_{str(i).zfill(7)}.pt")
)
print('Successfully saved checkpoint for iteration {}.'.format(i))
if get_rank() == 0:
# create final model directory
final_model_path = os.path.join('full_models', opt.experiment.expname)
os.makedirs(final_model_path, exist_ok=True)
torch.save(
{
"g": g_module.state_dict(),
"d": d_module.state_dict(),
"g_ema": g_ema.state_dict(),
},
os.path.join(final_model_path, experiment_opt.expname + '.pt')
)
print('Successfully saved final model.')
if __name__ == "__main__":
device = "cuda"
opt = BaseOptions().parse()
opt.training.camera = opt.camera
opt.training.size = opt.model.size
opt.training.renderer_output_size = opt.model.renderer_spatial_output_dim
opt.training.style_dim = opt.model.style_dim
opt.model.freeze_renderer = True
n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
opt.training.distributed = n_gpu > 1
if opt.training.distributed:
torch.cuda.set_device(opt.training.local_rank)
torch.distributed.init_process_group(backend="nccl", init_method="env://")
synchronize()
# create checkpoints directories
os.makedirs(os.path.join(opt.training.checkpoints_dir, opt.experiment.expname, 'full_pipeline'), exist_ok=True)
os.makedirs(os.path.join(opt.training.checkpoints_dir, opt.experiment.expname, 'full_pipeline', 'samples'), exist_ok=True)
discriminator = Discriminator(opt.model).to(device)
generator = Generator(opt.model, opt.rendering).to(device)
g_ema = Generator(opt.model, opt.rendering, ema=True).to(device)
g_ema.eval()
g_reg_ratio = opt.training.g_reg_every / (opt.training.g_reg_every + 1) if opt.training.g_reg_every > 0 else 1
d_reg_ratio = opt.training.d_reg_every / (opt.training.d_reg_every + 1)
params_g = []
params_dict_g = dict(generator.named_parameters())
for key, value in params_dict_g.items():
decoder_cond = ('decoder' in key)
if decoder_cond:
params_g += [{'params':[value], 'lr':opt.training.lr * g_reg_ratio}]
g_optim = optim.Adam(params_g,
lr=opt.training.lr * g_reg_ratio,
betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio))
d_optim = optim.Adam(discriminator.parameters(),
lr=opt.training.lr * d_reg_ratio,
betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio))
opt.training.start_iter = 0
if opt.experiment.continue_training and opt.experiment.ckpt is not None:
if get_rank() == 0:
print("load model:", opt.experiment.ckpt)
ckpt_path = os.path.join(opt.training.checkpoints_dir,
opt.experiment.expname,
'models_{}.pt'.format(opt.experiment.ckpt.zfill(7)))
ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
try:
opt.training.start_iter = int(opt.experiment.ckpt) + 1
except ValueError:
pass
generator.load_state_dict(ckpt["g"])
discriminator.load_state_dict(ckpt["d"])
g_ema.load_state_dict(ckpt["g_ema"])
else:
# save configuration
opt_path = os.path.join(opt.training.checkpoints_dir, opt.experiment.expname, 'full_pipeline', "opt.yaml")
with open(opt_path,'w') as f:
yaml.safe_dump(opt, f)
if not opt.experiment.continue_training:
if get_rank() == 0:
print("loading pretrained renderer weights...")
pretrained_renderer_path = os.path.join('./pretrained_renderer',
opt.experiment.expname + '_vol_renderer.pt')
try:
ckpt = torch.load(pretrained_renderer_path, map_location=lambda storage, loc: storage)
except FileNotFoundError:
print('Pretrained volume renderer experiment name does not match the full pipeline experiment name.')
vol_renderer_expname = str(input('Please enter the pretrained volume renderer experiment name:'))
pretrained_renderer_path = os.path.join('./pretrained_renderer',
vol_renderer_expname + '.pt')
ckpt = torch.load(pretrained_renderer_path, map_location=lambda storage, loc: storage)
pretrained_renderer_dict = ckpt["g_ema"]
model_dict = generator.state_dict()
for k, v in pretrained_renderer_dict.items():
if v.size() == model_dict[k].size():
model_dict[k] = v
generator.load_state_dict(model_dict)
# initialize g_ema weights to generator weights
accumulate(g_ema, generator, 0)
# set distributed models
if opt.training.distributed:
generator = nn.parallel.DistributedDataParallel(
generator,
device_ids=[opt.training.local_rank],
output_device=opt.training.local_rank,
broadcast_buffers=True,
find_unused_parameters=True,
)
discriminator = nn.parallel.DistributedDataParallel(
discriminator,
device_ids=[opt.training.local_rank],
output_device=opt.training.local_rank,
broadcast_buffers=False,
find_unused_parameters=True
)
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)])
dataset = MultiResolutionDataset(opt.dataset.dataset_path, transform, opt.model.size,
opt.model.renderer_spatial_output_dim)
loader = data.DataLoader(
dataset,
batch_size=opt.training.batch,
sampler=data_sampler(dataset, shuffle=True, distributed=opt.training.distributed),
drop_last=True,
)
if get_rank() == 0 and wandb is not None and opt.training.wandb:
wandb.init(project="StyleSDF")
wandb.run.name = opt.experiment.expname
wandb.config.dataset = os.path.basename(opt.dataset.dataset_path)
wandb.config.update(opt.training)
wandb.config.update(opt.model)
wandb.config.update(opt.rendering)
train(opt.training, opt.experiment, loader, generator, discriminator, g_optim, d_optim, g_ema, device)
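The `g_reg_ratio` / `d_reg_ratio` computation above applies the StyleGAN2 lazy-regularization trick. A minimal plain-Python sketch (the function name `lazy_reg_adam_params` is illustrative, not from the repo): when a regularizer runs only once every `reg_every` steps, the Adam learning rate and betas are rescaled so its effective strength matches per-step application.

```python
# Sketch of StyleGAN2-style lazy-regularization hyperparameter scaling,
# mirroring the g_reg_ratio / d_reg_ratio logic above (hypothetical helper).
def lazy_reg_adam_params(lr, beta2, reg_every):
    # a regularizer applied every `reg_every` steps contributes
    # reg_every / (reg_every + 1) of the per-step updates on average
    ratio = reg_every / (reg_every + 1) if reg_every > 0 else 1
    return lr * ratio, (0 ** ratio, beta2 ** ratio)

lr, betas = lazy_reg_adam_params(0.002, 0.99, 4)  # e.g. g_reg_every = 4
print(round(lr, 6))  # 0.0016
```

With `reg_every = 4` the ratio is 0.8, so both the learning rate and the beta exponents shrink accordingly, exactly as in the `g_optim` construction above.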
================================================
FILE: train_volume_renderer.py
================================================
import argparse
import math
import random
import os
import yaml
import numpy as np
import torch
import torch.distributed as dist
from torch import nn, autograd, optim
from torch.nn import functional as F
from torch.utils import data
from torchvision import transforms, utils
from tqdm import tqdm
from PIL import Image
from losses import *
from options import BaseOptions
from model import Generator, VolumeRenderDiscriminator
from dataset import MultiResolutionDataset
from utils import data_sampler, requires_grad, accumulate, sample_data, make_noise, mixing_noise, generate_camera_params
from distributed import get_rank, synchronize, reduce_loss_dict, reduce_sum, get_world_size
try:
import wandb
except ImportError:
wandb = None
def train(opt, experiment_opt, loader, generator, discriminator, g_optim, d_optim, g_ema, device):
loader = sample_data(loader)
mean_path_length = 0
d_loss_val = 0
r1_loss = torch.tensor(0.0, device=device)
d_view_loss = torch.tensor(0.0, device=device)
g_view_loss = torch.tensor(0.0, device=device)
g_eikonal = torch.tensor(0.0, device=device)
g_minimal_surface = torch.tensor(0.0, device=device)
g_loss_val = 0
loss_dict = {}
viewpoint_condition = opt.view_lambda > 0
if opt.distributed:
g_module = generator.module
d_module = discriminator.module
else:
g_module = generator
d_module = discriminator
accum = 0.5 ** (32 / (10 * 1000))
sample_z = [torch.randn(opt.val_n_sample, opt.style_dim, device=device).repeat_interleave(8,dim=0)]
sample_cam_extrinsics, sample_focals, sample_near, sample_far, _ = generate_camera_params(opt.renderer_output_size, device, batch=opt.val_n_sample, sweep=True,
uniform=opt.camera.uniform, azim_range=opt.camera.azim,
elev_range=opt.camera.elev, fov_ang=opt.camera.fov,
dist_radius=opt.camera.dist_radius)
if opt.with_sdf and opt.sphere_init and opt.start_iter == 0:
init_pbar = range(10000)
if get_rank() == 0:
init_pbar = tqdm(init_pbar, initial=0, dynamic_ncols=True, smoothing=0.01)
generator.zero_grad()
for idx in init_pbar:
noise = mixing_noise(3, opt.style_dim, opt.mixing, device)
cam_extrinsics, focal, near, far, gt_viewpoints = generate_camera_params(opt.renderer_output_size, device, batch=3,
uniform=opt.camera.uniform, azim_range=opt.camera.azim,
elev_range=opt.camera.elev, fov_ang=opt.camera.fov,
dist_radius=opt.camera.dist_radius)
sdf, target_values = g_module.init_forward(noise, cam_extrinsics, focal, near, far)
loss = F.l1_loss(sdf, target_values)
loss.backward()
g_optim.step()
generator.zero_grad()
if get_rank() == 0:
init_pbar.set_description((f"MLP init to sphere procedure - Loss: {loss.item():.4f}"))
accumulate(g_ema, g_module, 0)
torch.save(
{
"g": g_module.state_dict(),
"d": d_module.state_dict(),
"g_ema": g_ema.state_dict(),
},
os.path.join(opt.checkpoints_dir, experiment_opt.expname, f"sdf_init_models_{str(0).zfill(7)}.pt")
)
print('Successfully saved checkpoint for SDF initialized MLP.')
pbar = range(opt.iter)
if get_rank() == 0:
pbar = tqdm(pbar, initial=opt.start_iter, dynamic_ncols=True, smoothing=0.01)
for idx in pbar:
i = idx + opt.start_iter
if i > opt.iter:
print("Done!")
break
requires_grad(generator, False)
requires_grad(discriminator, True)
discriminator.zero_grad()
_, real_imgs = next(loader)
real_imgs = real_imgs.to(device)
noise = mixing_noise(opt.batch, opt.style_dim, opt.mixing, device)
cam_extrinsics, focal, near, far, gt_viewpoints = generate_camera_params(opt.renderer_output_size, device, batch=opt.batch,
uniform=opt.camera.uniform, azim_range=opt.camera.azim,
elev_range=opt.camera.elev, fov_ang=opt.camera.fov,
dist_radius=opt.camera.dist_radius)
gen_imgs = []
for j in range(0, opt.batch, opt.chunk):
curr_noise = [n[j:j+opt.chunk] for n in noise]
_, fake_img = generator(curr_noise,
cam_extrinsics[j:j+opt.chunk],
focal[j:j+opt.chunk],
near[j:j+opt.chunk],
far[j:j+opt.chunk])
gen_imgs += [fake_img]
gen_imgs = torch.cat(gen_imgs, 0)
fake_pred, fake_viewpoint_pred = discriminator(gen_imgs.detach())
if viewpoint_condition:
d_view_loss = opt.view_lambda * viewpoints_loss(fake_viewpoint_pred, gt_viewpoints)
real_imgs.requires_grad = True
real_pred, _ = discriminator(real_imgs)
d_gan_loss = d_logistic_loss(real_pred, fake_pred)
grad_penalty = d_r1_loss(real_pred, real_imgs)
r1_loss = opt.r1 * 0.5 * grad_penalty
d_loss = d_gan_loss + r1_loss + d_view_loss
d_loss.backward()
d_optim.step()
loss_dict["d"] = d_gan_loss
loss_dict["r1"] = r1_loss
loss_dict["d_view"] = d_view_loss
loss_dict["real_score"] = real_pred.mean()
loss_dict["fake_score"] = fake_pred.mean()
requires_grad(generator, True)
requires_grad(discriminator, False)
for j in range(0, opt.batch, opt.chunk):
noise = mixing_noise(opt.chunk, opt.style_dim, opt.mixing, device)
cam_extrinsics, focal, near, far, curr_gt_viewpoints = generate_camera_params(opt.renderer_output_size, device, batch=opt.chunk,
uniform=opt.camera.uniform, azim_range=opt.camera.azim,
elev_range=opt.camera.elev, fov_ang=opt.camera.fov,
dist_radius=opt.camera.dist_radius)
out = generator(noise, cam_extrinsics, focal, near, far,
return_sdf=opt.min_surf_lambda > 0,
return_eikonal=opt.eikonal_lambda > 0)
fake_img = out[1]
if opt.min_surf_lambda > 0:
sdf = out[2]
if opt.eikonal_lambda > 0:
eikonal_term = out[3]
fake_pred, fake_viewpoint_pred = discriminator(fake_img)
if viewpoint_condition:
g_view_loss = opt.view_lambda * viewpoints_loss(fake_viewpoint_pred, curr_gt_viewpoints)
if opt.with_sdf and opt.eikonal_lambda > 0:
g_eikonal, g_minimal_surface = eikonal_loss(eikonal_term, sdf=sdf if opt.min_surf_lambda > 0 else None,
beta=opt.min_surf_beta)
g_eikonal = opt.eikonal_lambda * g_eikonal
if opt.min_surf_lambda > 0:
g_minimal_surface = opt.min_surf_lambda * g_minimal_surface
g_gan_loss = g_nonsaturating_loss(fake_pred)
g_loss = g_gan_loss + g_view_loss + g_eikonal + g_minimal_surface
g_loss.backward()
g_optim.step()
generator.zero_grad()
loss_dict["g"] = g_gan_loss
loss_dict["g_view"] = g_view_loss
loss_dict["g_eikonal"] = g_eikonal
loss_dict["g_minimal_surface"] = g_minimal_surface
accumulate(g_ema, g_module, accum)
loss_reduced = reduce_loss_dict(loss_dict)
d_loss_val = loss_reduced["d"].mean().item()
g_loss_val = loss_reduced["g"].mean().item()
r1_val = loss_reduced["r1"].mean().item()
real_score_val = loss_reduced["real_score"].mean().item()
fake_score_val = loss_reduced["fake_score"].mean().item()
d_view_val = loss_reduced["d_view"].mean().item()
g_view_val = loss_reduced["g_view"].mean().item()
g_eikonal_loss = loss_reduced["g_eikonal"].mean().item()
g_minimal_surface_loss = loss_reduced["g_minimal_surface"].mean().item()
g_beta_val = g_module.renderer.sigmoid_beta.item() if opt.with_sdf else 0
if get_rank() == 0:
pbar.set_description(
(f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; viewpoint: {d_view_val+g_view_val:.4f}; eikonal: {g_eikonal_loss:.4f}; surf: {g_minimal_surface_loss:.4f}")
)
if i % 1000 == 0:
with torch.no_grad():
samples = torch.Tensor(0, 3, opt.renderer_output_size, opt.renderer_output_size)
step_size = 4
mean_latent = g_module.mean_latent(10000, device)
for k in range(0, opt.val_n_sample * 8, step_size):
_, curr_samples = g_ema([sample_z[0][k:k+step_size]],
sample_cam_extrinsics[k:k+step_size],
sample_focals[k:k+step_size],
sample_near[k:k+step_size],
sample_far[k:k+step_size],
truncation=0.7,
truncation_latent=mean_latent,)
samples = torch.cat([samples, curr_samples.cpu()], 0)
if i % 10000 == 0:
utils.save_image(samples,
os.path.join(opt.checkpoints_dir, experiment_opt.expname, 'volume_renderer', f"samples/{str(i).zfill(7)}.png"),
nrow=int(opt.val_n_sample),
normalize=True,
value_range=(-1, 1),)
if wandb and opt.wandb:
wandb_log_dict = {"Generator": g_loss_val,
"Discriminator": d_loss_val,
"R1": r1_val,
"Real Score": real_score_val,
"Fake Score": fake_score_val,
"D viewpoint": d_view_val,
"G viewpoint": g_view_val,
"G eikonal loss": g_eikonal_loss,
"G minimal surface loss": g_minimal_surface_loss,
}
if opt.with_sdf:
wandb_log_dict.update({"Beta value": g_beta_val})
if i % 1000 == 0:
wandb_grid = utils.make_grid(samples, nrow=int(opt.val_n_sample),
normalize=True, value_range=(-1, 1))
wandb_ndarr = (255 * wandb_grid.permute(1, 2, 0).numpy()).astype(np.uint8)
wandb_images = Image.fromarray(wandb_ndarr)
wandb_log_dict.update({"examples": [wandb.Image(wandb_images,
caption="Generated samples for azimuth angles of: -35, -25, -15, -5, 5, 15, 25, 35 degrees.")]})
wandb.log(wandb_log_dict)
if i % 10000 == 0 or (i < 10000 and i % 1000 == 0):
torch.save(
{
"g": g_module.state_dict(),
"d": d_module.state_dict(),
"g_ema": g_ema.state_dict(),
},
os.path.join(opt.checkpoints_dir, experiment_opt.expname, 'volume_renderer', f"models_{str(i).zfill(7)}.pt")
)
print('Successfully saved checkpoint for iteration {}.'.format(i))
if get_rank() == 0:
# create final model directory
final_model_path = 'pretrained_renderer'
os.makedirs(final_model_path, exist_ok=True)
torch.save(
{
"g": g_module.state_dict(),
"d": d_module.state_dict(),
"g_ema": g_ema.state_dict(),
},
os.path.join(final_model_path, experiment_opt.expname + '_vol_renderer.pt')
)
print('Successfully saved final model.')
if __name__ == "__main__":
device = "cuda"
opt = BaseOptions().parse()
opt.model.freeze_renderer = False
opt.model.no_viewpoint_loss = opt.training.view_lambda == 0.0
opt.training.camera = opt.camera
opt.training.renderer_output_size = opt.model.renderer_spatial_output_dim
opt.training.style_dim = opt.model.style_dim
opt.training.with_sdf = not opt.rendering.no_sdf
if opt.training.with_sdf and opt.training.min_surf_lambda > 0:
opt.rendering.return_sdf = True
opt.training.iter = 200001
opt.rendering.no_features_output = True
n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
opt.training.distributed = n_gpu > 1
if opt.training.distributed:
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
torch.distributed.init_process_group(backend="nccl", init_method="env://")
synchronize()
# create checkpoints directories
os.makedirs(os.path.join(opt.training.checkpoints_dir, opt.experiment.expname, 'volume_renderer'), exist_ok=True)
os.makedirs(os.path.join(opt.training.checkpoints_dir, opt.experiment.expname, 'volume_renderer', 'samples'), exist_ok=True)
discriminator = VolumeRenderDiscriminator(opt.model).to(device)
generator = Generator(opt.model, opt.rendering, full_pipeline=False).to(device)
g_ema = Generator(opt.model, opt.rendering, ema=True, full_pipeline=False).to(device)
g_ema.eval()
accumulate(g_ema, generator, 0)
g_optim = optim.Adam(generator.parameters(), lr=2e-5, betas=(0, 0.9))
d_optim = optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0, 0.9))
opt.training.start_iter = 0
if opt.experiment.continue_training and opt.experiment.ckpt is not None:
if get_rank() == 0:
print("load model:", opt.experiment.ckpt)
ckpt_path = os.path.join(opt.training.checkpoints_dir,
opt.experiment.expname,
'models_{}.pt'.format(opt.experiment.ckpt.zfill(7)))
ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
try:
opt.training.start_iter = int(opt.experiment.ckpt) + 1
except ValueError:
pass
generator.load_state_dict(ckpt["g"])
discriminator.load_state_dict(ckpt["d"])
g_ema.load_state_dict(ckpt["g_ema"])
if "g_optim" in ckpt.keys():
g_optim.load_state_dict(ckpt["g_optim"])
d_optim.load_state_dict(ckpt["d_optim"])
sphere_init_path = './pretrained_renderer/sphere_init.pt'
if opt.training.no_sphere_init:
opt.training.sphere_init = False
elif not opt.experiment.continue_training and opt.training.with_sdf and os.path.isfile(sphere_init_path):
if get_rank() == 0:
print("loading sphere inititialized model")
ckpt = torch.load(sphere_init_path, map_location=lambda storage, loc: storage)
generator.load_state_dict(ckpt["g"])
discriminator.load_state_dict(ckpt["d"])
g_ema.load_state_dict(ckpt["g_ema"])
opt.training.sphere_init = False
else:
opt.training.sphere_init = True
if opt.training.distributed:
generator = nn.parallel.DistributedDataParallel(
generator,
device_ids=[opt.training.local_rank],
output_device=opt.training.local_rank,
broadcast_buffers=True,
find_unused_parameters=True,
)
discriminator = nn.parallel.DistributedDataParallel(
discriminator,
device_ids=[opt.training.local_rank],
output_device=opt.training.local_rank,
broadcast_buffers=False,
find_unused_parameters=True
)
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)])
dataset = MultiResolutionDataset(opt.dataset.dataset_path, transform, opt.model.size,
opt.model.renderer_spatial_output_dim)
loader = data.DataLoader(
dataset,
batch_size=opt.training.batch,
sampler=data_sampler(dataset, shuffle=True, distributed=opt.training.distributed),
drop_last=True,
)
opt.training.dataset_name = opt.dataset.dataset_path.lower()
# save options
opt_path = os.path.join(opt.training.checkpoints_dir, opt.experiment.expname, 'volume_renderer', "opt.yaml")
with open(opt_path,'w') as f:
yaml.safe_dump(opt, f)
# set wandb environment
if get_rank() == 0 and wandb is not None and opt.training.wandb:
wandb.init(project="StyleSDF")
wandb.run.name = opt.experiment.expname
wandb.config.dataset = os.path.basename(opt.dataset.dataset_path)
wandb.config.update(opt.training)
wandb.config.update(opt.model)
wandb.config.update(opt.rendering)
train(opt.training, opt.experiment, loader, generator, discriminator, g_optim, d_optim, g_ema, device)
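The "MLP init to sphere" loop above regresses the SDF network toward a spherical shape before adversarial training begins. A hedged numpy sketch (not the repo's `init_forward`) of the kind of analytic target such a regression uses: the signed distance to an origin-centered sphere, negative inside and positive outside.

```python
import numpy as np

# Analytic signed distance to a sphere of the given radius centered at
# the origin: negative inside, zero on the surface, positive outside.
def sphere_sdf(points, radius=1.0):
    return np.linalg.norm(points, axis=-1) - radius

pts = np.array([[0.0, 0.0, 0.0],   # center        -> -1
                [2.0, 0.0, 0.0],   # outside       ->  1
                [0.0, 1.0, 0.0]])  # on the surface->  0
print(sphere_sdf(pts))  # [-1.  1.  0.]
```

Fitting the MLP's raw SDF output to such a target with an L1 loss, as in the init loop above, gives the network a well-behaved zero-level set to start from.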
================================================
FILE: utils.py
================================================
import torch
import random
import trimesh
import numpy as np
from torch import nn
from torch.nn import functional as F
from torch.utils import data
from scipy.spatial import Delaunay
from skimage.measure import marching_cubes
from pdb import set_trace as st
import pytorch3d.io
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
look_at_view_transform,
FoVPerspectiveCameras,
PointLights,
RasterizationSettings,
MeshRenderer,
MeshRasterizer,
SoftPhongShader,
TexturesVertex,
)
######################### Dataset util functions ###########################
# Get data sampler
def data_sampler(dataset, shuffle, distributed):
if distributed:
return data.distributed.DistributedSampler(dataset, shuffle=shuffle)
if shuffle:
return data.RandomSampler(dataset)
else:
return data.SequentialSampler(dataset)
# Get data minibatch
def sample_data(loader):
while True:
for batch in loader:
yield batch
############################## Model weights util functions #################
# Turn model gradients on/off
def requires_grad(model, flag=True):
for p in model.parameters():
p.requires_grad = flag
# Exponential moving average for generator weights
def accumulate(model1, model2, decay=0.999):
par1 = dict(model1.named_parameters())
par2 = dict(model2.named_parameters())
for k in par1.keys():
par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)
################### Latent code (Z) sampling util functions ####################
# Sample Z space latent codes for the generator
def make_noise(batch, latent_dim, n_noise, device):
if n_noise == 1:
return torch.randn(batch, latent_dim, device=device)
noises = torch.randn(n_noise, batch, latent_dim, device=device).unbind(0)
return noises
def mixing_noise(batch, latent_dim, prob, device):
if prob > 0 and random.random() < prob:
return make_noise(batch, latent_dim, 2, device)
else:
return [make_noise(batch, latent_dim, 1, device)]
################# Camera parameters sampling ####################
def generate_camera_params(resolution, device, batch=1, locations=None, sweep=False,
uniform=False, azim_range=0.3, elev_range=0.15,
fov_ang=6, dist_radius=0.12):
if locations is not None:
azim = locations[:,0].view(-1,1)
elev = locations[:,1].view(-1,1)
# generate intrinsic parameters
# fix distance to 1
dist = torch.ones(azim.shape[0], 1, device=device)
near, far = (dist - dist_radius).unsqueeze(-1), (dist + dist_radius).unsqueeze(-1)
fov_angle = fov_ang * torch.ones(azim.shape[0], 1, device=device).view(-1,1) * np.pi / 180
focal = 0.5 * resolution / torch.tan(fov_angle).unsqueeze(-1)
elif sweep:
# generate camera locations on the unit sphere
azim = (-azim_range + (2 * azim_range / 7) * torch.arange(8, device=device)).view(-1,1).repeat(batch,1)
elev = (-elev_range + 2 * elev_range * torch.rand(batch, 1, device=device).repeat(1,8).view(-1,1))
# generate intrinsic parameters
dist = (torch.ones(batch, 1, device=device)).repeat(1,8).view(-1,1)
near, far = (dist - dist_radius).unsqueeze(-1), (dist + dist_radius).unsqueeze(-1)
fov_angle = fov_ang * torch.ones(batch, 1, device=device).repeat(1,8).view(-1,1) * np.pi / 180
focal = 0.5 * resolution / torch.tan(fov_angle).unsqueeze(-1)
else:
# sample camera locations on the unit sphere
if uniform:
azim = (-azim_range + 2 * azim_range * torch.rand(batch, 1, device=device))
elev = (-elev_range + 2 * elev_range * torch.rand(batch, 1, device=device))
else:
azim = (azim_range * torch.randn(batch, 1, device=device))
elev = (elev_range * torch.randn(batch, 1, device=device))
# generate intrinsic parameters
dist = torch.ones(batch, 1, device=device) # restrict camera position to be on the unit sphere
near, far = (dist - dist_radius).unsqueeze(-1), (dist + dist_radius).unsqueeze(-1)
fov_angle = fov_ang * torch.ones(batch, 1, device=device) * np.pi / 180 # full fov is 12 degrees
focal = 0.5 * resolution / torch.tan(fov_angle).unsqueeze(-1)
viewpoint = torch.cat([azim, elev], 1)
#### Generate camera extrinsic matrix ##########
# convert angles to xyz coordinates
x = torch.cos(elev) * torch.sin(azim)
y = torch.sin(elev)
z = torch.cos(elev) * torch.cos(azim)
camera_dir = torch.stack([x, y, z], dim=1).view(-1,3)
camera_loc = dist * camera_dir
# get rotation matrices (assume object is at the world coordinates origin)
up = torch.tensor([[0,1,0]]).float().to(device) * torch.ones_like(dist)
z_axis = F.normalize(camera_dir, eps=1e-5) # the -z direction points into the screen
x_axis = F.normalize(torch.cross(up, z_axis, dim=1), eps=1e-5)
y_axis = F.normalize(torch.cross(z_axis, x_axis, dim=1), eps=1e-5)
is_close = torch.isclose(x_axis, torch.tensor(0.0), atol=5e-3).all(dim=1, keepdim=True)
if is_close.any():
replacement = F.normalize(torch.cross(y_axis, z_axis, dim=1), eps=1e-5)
x_axis = torch.where(is_close, replacement, x_axis)
R = torch.cat((x_axis[:, None, :], y_axis[:, None, :], z_axis[:, None, :]), dim=1)
T = camera_loc[:, :, None]
extrinsics = torch.cat((R.transpose(1,2),T), -1)
return extrinsics, focal, near, far, viewpoint
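The extrinsics above come from a look-at construction: a camera on the unit sphere at `(azim, elev)` facing the origin with y-up, whose axes are built from cross products exactly as in the code. A self-contained numpy sketch (the helper name `look_at_origin` is illustrative) that verifies the resulting rotation is orthonormal:

```python
import numpy as np

# Look-at rotation for a camera on the unit sphere facing the origin,
# mirroring the x/y/z-axis construction in generate_camera_params.
def look_at_origin(azim, elev):
    z = np.array([np.cos(elev) * np.sin(azim),
                  np.sin(elev),
                  np.cos(elev) * np.cos(azim)])   # viewing direction
    up = np.array([0.0, 1.0, 0.0])
    x = np.cross(up, z); x /= np.linalg.norm(x)   # camera right
    y = np.cross(z, x);  y /= np.linalg.norm(y)   # camera up
    return np.stack([x, y, z])                    # rows: camera axes

R = look_at_origin(azim=0.3, elev=0.1)
print(np.allclose(R @ R.T, np.eye(3)))  # True: orthonormal rotation
```

The degenerate case handled by `is_close` in the code above (camera direction parallel to `up`, which makes `up x z` vanish) is not covered by this sketch.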
#################### Mesh generation util functions ########################
# Reshape sampling volume to camera frustum
def align_volume(volume, near=0.88, far=1.12):
b, h, w, d, c = volume.shape
yy, xx, zz = torch.meshgrid(torch.linspace(-1, 1, h),
torch.linspace(-1, 1, w),
torch.linspace(-1, 1, d))
grid = torch.stack([xx, yy, zz], -1).to(volume.device)
frostum_adjustment_coeffs = torch.linspace(far / near, 1, d).view(1,1,1,-1,1).to(volume.device)
frostum_grid = grid.unsqueeze(0)
frostum_grid[...,:2] = frostum_grid[...,:2] * frostum_adjustment_coeffs
out_of_boundary = torch.any((frostum_grid.lt(-1).logical_or(frostum_grid.gt(1))), -1, keepdim=True)
frostum_grid = frostum_grid.permute(0,3,1,2,4).contiguous()
permuted_volume = volume.permute(0,4,3,1,2).contiguous()
final_volume = F.grid_sample(permuted_volume, frostum_grid, padding_mode="border", align_corners=True)
final_volume = final_volume.permute(0,3,4,2,1).contiguous()
# set a non-zero value at grid locations outside the frustum to avoid marching cubes distortions.
# This is needed because pytorch grid_sample uses zero padding.
final_volume[out_of_boundary] = 1
return final_volume
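The per-depth x/y scaling in `align_volume` widens the sampling cube by `far / near` at the near plane and leaves it unscaled (factor 1) at the far plane, reshaping the cube into the camera frustum. The coefficients in isolation:

```python
import numpy as np

# Depth-dependent x/y scale factors, as in align_volume above:
# far/near at the near plane, linearly down to 1 at the far plane.
near, far, d = 0.88, 1.12, 5   # default near/far shell, 5 depth slices
coeffs = np.linspace(far / near, 1.0, d)
print(coeffs.round(4))
```

With the default `near=0.88`, `far=1.12` the near plane is stretched by about 1.27x, matching a frustum that narrows toward the camera.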
# Extract mesh with marching cubes
def extract_mesh_with_marching_cubes(sdf):
b, h, w, d, _ = sdf.shape
# change coordinate order from (y,x,z) to (x,y,z)
sdf_vol = sdf[0,...,0].permute(1,0,2).cpu().numpy()
# scale vertices
verts, faces, _, _ = marching_cubes(sdf_vol, 0)
verts[:,0] = (verts[:,0]/float(w)-0.5)*0.24
verts[:,1] = (verts[:,1]/float(h)-0.5)*0.24
verts[:,2] = (verts[:,2]/float(d)-0.5)*0.24
# fix normal direction
verts[:,2] *= -1; verts[:,1] *= -1
mesh = trimesh.Trimesh(verts, faces)
return mesh
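The vertex rescaling above maps marching-cubes voxel indices in `[0, size)` to world coordinates spanning 0.24 units, i.e. the `[-0.12, 0.12]` shell around the unit-distance camera (`dist_radius=0.12` in `generate_camera_params`). A plain-Python sketch of that mapping (the helper name `voxel_to_world` is illustrative):

```python
# Voxel index -> world coordinate, as in the verts rescaling above:
# (v / size - 0.5) * 0.24 maps [0, size) onto [-0.12, 0.12).
def voxel_to_world(v, size, extent=0.24):
    return (v / float(size) - 0.5) * extent

print(voxel_to_world(0, 128))   # -0.12 (near boundary of the shell)
print(voxel_to_world(64, 128))  # 0.0   (center of the volume)
```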
# Generate mesh from xyz point cloud
def xyz2mesh(xyz):
b, _, h, w = xyz.shape
x, y = np.meshgrid(np.arange(h), np.arange(w))
# Extract mesh faces from xyz maps
tri = Delaunay(np.concatenate((x.reshape((h*w, 1)), y.reshape((h*w, 1))), 1))
faces = tri.simplices
# invert normals
faces[:,[0, 1]] = faces[:,[1, 0]]
# generate_meshes
mesh = trimesh.Trimesh(xyz.squeeze(0).permute(1,2,0).view(h*w,3).cpu().numpy(), faces)
return mesh
################# Mesh rendering util functions #############################
def add_textures(meshes:Meshes, vertex_colors=None) -> Meshes:
verts = meshes.verts_padded()
if vertex_colors is None:
vertex_colors = torch.ones_like(verts) # (N, V, 3)
textures = TexturesVertex(verts_features=vertex_colors)
meshes_t = Meshes(
verts=verts,
faces=meshes.faces_padded(),
textures=textures,
verts_normals=meshes.verts_normals_padded(),
)
return meshes_t
def create_cameras(
R=None, T=None,
azim=0, elev=0., dist=1.,
fov=12., znear=0.01,
device="cuda") -> FoVPerspectiveCameras:
"""
all the camera parameters can be a single number, a list, or a torch tensor.
"""
if R is None or T is None:
R, T = look_at_view_transform(dist=dist, azim=azim, elev=elev, device=device)
cameras = FoVPerspectiveCameras(device=device, R=R, T=T, znear=znear, fov=fov)
return cameras
def create_mesh_renderer(
cameras: FoVPerspectiveCameras,
image_size: int = 256,
blur_radius: float = 1e-6,
light_location=((-0.5, 1., 5.0),),
device="cuda",
**light_kwargs,
):
"""
If don't want to show direct texture color without shading, set the light_kwargs as
ambient_color=((1, 1, 1), ), diffuse_color=((0, 0, 0), ), specular_color=((0, 0, 0), )
"""
# Create a soft Phong renderer; faces_per_pixel sets how many faces are blended per pixel.
raster_settings = RasterizationSettings(
image_size=image_size,
blur_radius=blur_radius,
faces_per_pixel=5,
)
# We can add a point light in front of the object.
lights = PointLights(
device=device, location=light_location, **light_kwargs
)
phong_renderer = MeshRenderer(
rasterizer=MeshRasterizer(
cameras=cameras,
raster_settings=raster_settings
),
shader=SoftPhongShader(device=device, cameras=cameras, lights=lights)
)
return phong_renderer
## custom renderer
class MeshRendererWithDepth(nn.Module):
def __init__(self, rasterizer, shader):
super().__init__()
self.rasterizer = rasterizer
self.shader = shader
def forward(self, meshes_world, **kwargs) -> torch.Tensor:
fragments = self.rasterizer(meshes_world, **kwargs)
images = self.shader(fragments, meshes_world, **kwargs)
return images, fragments.zbuf
def create_depth_mesh_renderer(
cameras: FoVPerspectiveCameras,
image_size: int = 256,
blur_radius: float = 1e-6,
device="cuda",
**light_kwargs,
):
"""
If don't want to show direct texture color without shading, set the light_kwargs as
ambient_color=((1, 1, 1), ), diffuse_color=((0, 0, 0), ), specular_color=((0, 0, 0), )
"""
# Create a soft Phong renderer; faces_per_pixel sets how many faces are blended per pixel.
raster_settings = RasterizationSettings(
image_size=image_size,
blur_radius=blur_radius,
faces_per_pixel=17,
)
# We can add a point light in front of the object.
lights = PointLights(
device=device, location=((-0.5, 1., 5.0),), **light_kwargs
)
renderer = MeshRendererWithDepth(
rasterizer=MeshRasterizer(
cameras=cameras,
raster_settings=raster_settings,
device=device,
),
shader=SoftPhongShader(device=device, cameras=cameras, lights=lights)
)
return renderer
================================================
FILE: volume_renderer.py
================================================
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import numpy as np
from functools import partial
from pdb import set_trace as st
# Basic SIREN fully connected layer
class LinearLayer(nn.Module):
def __init__(self, in_dim, out_dim, bias=True, bias_init=0, std_init=1, freq_init=False, is_first=False):
super().__init__()
if is_first:
self.weight = nn.Parameter(torch.empty(out_dim, in_dim).uniform_(-1 / in_dim, 1 / in_dim))
elif freq_init:
self.weight = nn.Parameter(torch.empty(out_dim, in_dim).uniform_(-np.sqrt(6 / in_dim) / 25, np.sqrt(6 / in_dim) / 25))
else:
self.weight = nn.Parameter(0.25 * nn.init.kaiming_normal_(torch.randn(out_dim, in_dim), a=0.2, mode='fan_in', nonlinearity='leaky_relu'))
self.bias = nn.Parameter(nn.init.uniform_(torch.empty(out_dim), a=-np.sqrt(1/in_dim), b=np.sqrt(1/in_dim)))
self.bias_init = bias_init
self.std_init = std_init
def forward(self, input):
out = self.std_init * F.linear(input, self.weight, bias=self.bias) + self.bias_init
return out
# Siren layer with frequency modulation and offset
class FiLMSiren(nn.Module):
def __init__(self, in_channel, out_channel, style_dim, is_first=False):
super().__init__()
self.in_channel = in_channel
self.out_channel = out_channel
if is_first:
self.weight = nn.Parameter(torch.empty(out_channel, in_channel).uniform_(-1 / 3, 1 / 3))
else:
self.weight = nn.Parameter(torch.empty(out_channel, in_channel).uniform_(-np.sqrt(6 / in_channel) / 25, np.sqrt(6 / in_channel) / 25))
self.bias = nn.Parameter(nn.init.uniform_(torch.empty(out_channel), a=-np.sqrt(1/in_channel), b=np.sqrt(1/in_channel)))
self.activation = torch.sin
self.gamma = LinearLayer(style_dim, out_channel, bias_init=30, std_init=15)
self.beta = LinearLayer(style_dim, out_channel, bias_init=0, std_init=0.25)
def forward(self, input, style):
batch, features = style.shape
out = F.linear(input, self.weight, bias=self.bias)
gamma = self.gamma(style).view(batch, 1, 1, 1, features)
beta = self.beta(style).view(batch, 1, 1, 1, features)
out = self.activation(gamma * out + beta)
return out
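The `view(batch, 1, 1, 1, features)` reshape is what makes the FiLM modulation work: per-sample `gamma`/`beta` vectors of shape `[B, C]` are broadcast over a 5-D point tensor. A small sketch of that broadcast (the `[B, H, W, N, C]` interpretation, i.e. an H x W grid of rays with N samples each, is an assumption about how the renderer batches points):

```python
import torch

# gamma/beta come out of the style MLP as [B, C]; reshaping to
# [B, 1, 1, 1, C] broadcasts them across every spatial/sample position.
B, H, W, N, C = 2, 4, 4, 6, 8
out = torch.randn(B, H, W, N, C)       # pre-activation features per point
gamma = torch.randn(B, C).view(B, 1, 1, 1, C)
beta = torch.randn(B, C).view(B, 1, 1, 1, C)
modulated = torch.sin(gamma * out + beta)
print(modulated.shape)  # torch.Size([2, 4, 4, 6, 8])
```

Because `sin` is bounded, `gamma` effectively controls the frequency of each feature channel and `beta` its phase, which is the FiLM-SIREN conditioning scheme.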
# Siren Generator Model
class SirenGenerator(nn.Module):
    def __init__(self, D=8, W=256, style_dim=256, input_ch=3, input_ch_views=3, output_ch=4,
                 output_features=True):
        super(SirenGenerator, self).__init__()
        self.D = D
        self.W = W
        self.input_ch = input_ch
        self.input_ch_views = input_ch_views
        self.style_dim = style_dim
        self.output_features = output_features

        self.pts_linears = nn.ModuleList(
            [FiLMSiren(3, W, style_dim=style_dim, is_first=True)] + \
            [FiLMSiren(W, W, style_dim=style_dim) for i in range(D - 1)])

        self.views_linears = FiLMSiren(input_ch_views + W, W,
                                       style_dim=style_dim)
        self.rgb_linear = LinearLayer(W, 3, freq_init=True)
        self.sigma_linear = LinearLayer(W, 1, freq_init=True)

    def forward(self, x, styles):
        input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
        mlp_out = input_pts.contiguous()
        for i in range(len(self.pts_linears)):
            mlp_out = self.pts_linears[i](mlp_out, styles)

        sdf = self.sigma_linear(mlp_out)

        mlp_out = torch.cat([mlp_out, input_views], -1)
        out_features = self.views_linears(mlp_out, styles)
        rgb = self.rgb_linear(out_features)

        outputs = torch.cat([rgb, sdf], -1)
        if self.output_features:
            outputs = torch.cat([outputs, out_features], -1)

        return outputs
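With `output_features=True`, the generator packs RGB (3 channels), SDF (1 channel), and the view-conditioned features (W channels) into one tensor along the last dimension. A sketch of how a downstream consumer would unpack it (the 5-D batch shape is illustrative; only the `3 + 1 + W` channel layout follows from the `torch.cat` calls above):

```python
import torch

# Packed output layout: [..., rgb(3) | sdf(1) | features(W)].
W = 256
outputs = torch.randn(2, 4, 4, 6, 3 + 1 + W)  # e.g. [B, H, W_im, N_samples, C]
rgb, sdf, features = torch.split(outputs, [3, 1, W], dim=-1)
print(rgb.shape[-1], sdf.shape[-1], features.shape[-1])  # 3 1 256
```

Keeping the three outputs concatenated lets the renderer run the whole MLP once per point and slice afterwards, rather than making separate forward passes per head.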
# Full vol
SYMBOL INDEX (171 symbols across 17 files)
FILE: dataset.py
class MultiResolutionDataset (line 12) | class MultiResolutionDataset(Dataset):
method __init__ (line 13) | def __init__(self, path, transform, resolution=256, nerf_resolution=64):
method __len__ (line 33) | def __len__(self):
method __getitem__ (line 36) | def __getitem__(self, index):
FILE: distributed.py
function get_rank (line 9) | def get_rank():
function synchronize (line 19) | def synchronize():
function get_world_size (line 34) | def get_world_size():
function reduce_sum (line 44) | def reduce_sum(tensor):
function gather_grad (line 57) | def gather_grad(params):
function all_gather (line 69) | def all_gather(data):
function reduce_loss_dict (line 104) | def reduce_loss_dict(loss_dict):
FILE: download_models.py
function download_pretrained_models (line 28) | def download_pretrained_models():
function download_file (line 80) | def download_file(session, file_spec, use_alt_url=False, chunk_size=128,...
FILE: generate_shapes_and_images.py
function generate (line 27) | def generate(opt, g_ema, surface_g_ema, device, mean_latent, surface_mea...
FILE: losses.py
function viewpoints_loss (line 7) | def viewpoints_loss(viewpoint_pred, viewpoint_target):
function eikonal_loss (line 13) | def eikonal_loss(eikonal_term, sdf=None, beta=100):
function d_logistic_loss (line 27) | def d_logistic_loss(real_pred, fake_pred):
function d_r1_loss (line 34) | def d_r1_loss(real_pred, real_img):
function g_nonsaturating_loss (line 43) | def g_nonsaturating_loss(fake_pred):
function g_path_regularize (line 49) | def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
FILE: model.py
class PixelNorm (line 24) | class PixelNorm(nn.Module):
method __init__ (line 25) | def __init__(self):
method forward (line 28) | def forward(self, input):
class MappingLinear (line 32) | class MappingLinear(nn.Module):
method __init__ (line 33) | def __init__(self, in_dim, out_dim, bias=True, activation=None, is_las...
method forward (line 49) | def forward(self, input):
method __repr__ (line 58) | def __repr__(self):
function make_kernel (line 64) | def make_kernel(k):
class Upsample (line 75) | class Upsample(nn.Module):
method __init__ (line 76) | def __init__(self, kernel, factor=2):
method forward (line 90) | def forward(self, input):
class Downsample (line 96) | class Downsample(nn.Module):
method __init__ (line 97) | def __init__(self, kernel, factor=2):
method forward (line 111) | def forward(self, input):
class Blur (line 117) | class Blur(nn.Module):
method __init__ (line 118) | def __init__(self, kernel, pad, upsample_factor=1):
method forward (line 130) | def forward(self, input):
class EqualConv2d (line 136) | class EqualConv2d(nn.Module):
method __init__ (line 137) | def __init__(
method forward (line 156) | def forward(self, input):
method __repr__ (line 167) | def __repr__(self):
class EqualLinear (line 174) | class EqualLinear(nn.Module):
method __init__ (line 175) | def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1,
method forward (line 192) | def forward(self, input):
method __repr__ (line 203) | def __repr__(self):
class ModulatedConv2d (line 209) | class ModulatedConv2d(nn.Module):
method __init__ (line 210) | def __init__(self, in_channel, out_channel, kernel_size, style_dim, de...
method __repr__ (line 249) | def __repr__(self):
method forward (line 255) | def forward(self, input, style):
class NoiseInjection (line 299) | class NoiseInjection(nn.Module):
method __init__ (line 300) | def __init__(self, project=False):
method create_pytorch_mesh (line 308) | def create_pytorch_mesh(self, trimesh):
method load_mc_mesh (line 323) | def load_mc_mesh(self, filename, resolution=128, im_res=64):
method project_noise (line 349) | def project_noise(self, noise, transform, mesh_path=None):
method forward (line 380) | def forward(self, image, noise=None, transform=None, mesh_path=None):
class StyledConv (line 390) | class StyledConv(nn.Module):
method __init__ (line 391) | def __init__(self, in_channel, out_channel, kernel_size, style_dim,
method forward (line 408) | def forward(self, input, style, noise=None, transform=None, mesh_path=...
class ToRGB (line 416) | class ToRGB(nn.Module):
method __init__ (line 417) | def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[...
method forward (line 428) | def forward(self, input, style, skip=None):
class ConvLayer (line 441) | class ConvLayer(nn.Sequential):
method __init__ (line 442) | def __init__(self, in_channel, out_channel, kernel_size, downsample=Fa...
class Decoder (line 478) | class Decoder(nn.Module):
method __init__ (line 479) | def __init__(self, model_opt, blur_kernel=[1, 3, 3, 1]):
method mean_latent (line 559) | def mean_latent(self, renderer_latent):
method get_latent (line 564) | def get_latent(self, input):
method styles_and_noise_forward (line 567) | def styles_and_noise_forward(self, styles, noise, inject_index=None, t...
method forward (line 610) | def forward(self, features, styles, rgbd_in=None, transform=None,
class Generator (line 641) | class Generator(nn.Module):
method __init__ (line 642) | def __init__(self, model_opt, renderer_opt, blur_kernel=[1, 3, 3, 1], ...
method make_noise (line 673) | def make_noise(self):
method mean_latent (line 684) | def mean_latent(self, n_latent, device):
method get_latent (line 695) | def get_latent(self, input):
method styles_and_noise_forward (line 698) | def styles_and_noise_forward(self, styles, inject_index=None, truncati...
method init_forward (line 715) | def init_forward(self, styles, cam_poses, focals, near=0.88, far=1.12):
method forward (line 722) | def forward(self, styles, cam_poses, focals, near=0.88, far=1.12, retu...
class VolumeRenderDiscConv2d (line 763) | class VolumeRenderDiscConv2d(nn.Module):
method __init__ (line 764) | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
method forward (line 778) | def forward(self, input):
class AddCoords (line 791) | class AddCoords(nn.Module):
method __init__ (line 792) | def __init__(self):
method forward (line 795) | def forward(self, input_tensor):
class CoordConv2d (line 817) | class CoordConv2d(nn.Module):
method __init__ (line 818) | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
method forward (line 826) | def forward(self, input_tensor):
class CoordConvLayer (line 838) | class CoordConvLayer(nn.Module):
method __init__ (line 839) | def __init__(self, in_channel, out_channel, kernel_size, bias=True, ac...
method forward (line 856) | def forward(self, input):
class VolumeRenderResBlock (line 864) | class VolumeRenderResBlock(nn.Module):
method __init__ (line 865) | def __init__(self, in_channel, out_channel):
method forward (line 877) | def forward(self, input):
class VolumeRenderDiscriminator (line 893) | class VolumeRenderDiscriminator(nn.Module):
method __init__ (line 894) | def __init__(self, opt):
method forward (line 926) | def forward(self, input):
class ResBlock (line 940) | class ResBlock(nn.Module):
method __init__ (line 941) | def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1], ...
method forward (line 949) | def forward(self, input):
class Discriminator (line 957) | class Discriminator(nn.Module):
method __init__ (line 958) | def __init__(self, opt, blur_kernel=[1, 3, 3, 1]):
method forward (line 1001) | def forward(self, input):
FILE: op/fused_act.py
class FusedLeakyReLUFunctionBackward (line 20) | class FusedLeakyReLUFunctionBackward(Function):
method forward (line 22) | def forward(ctx, grad_output, out, bias, negative_slope, scale):
method backward (line 47) | def backward(ctx, gradgrad_input, gradgrad_bias):
class FusedLeakyReLUFunction (line 56) | class FusedLeakyReLUFunction(Function):
method forward (line 58) | def forward(ctx, input, bias, negative_slope, scale):
method backward (line 74) | def backward(ctx, grad_output):
class FusedLeakyReLU (line 87) | class FusedLeakyReLU(nn.Module):
method __init__ (line 88) | def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** ...
method forward (line 100) | def forward(self, input):
function fused_leaky_relu (line 104) | def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
FILE: op/fused_bias_act.cpp
function fused_bias_act (line 11) | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Te...
function PYBIND11_MODULE (line 19) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: op/upfirdn2d.cpp
function upfirdn2d (line 12) | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor&...
function PYBIND11_MODULE (line 21) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: op/upfirdn2d.py
class UpFirDn2dBackward (line 20) | class UpFirDn2dBackward(Function):
method forward (line 22) | def forward(
method backward (line 64) | def backward(ctx, gradgrad_input):
class UpFirDn2d (line 89) | class UpFirDn2d(Function):
method forward (line 91) | def forward(ctx, input, kernel, up, down, pad):
method backward (line 128) | def backward(ctx, grad_output):
function upfirdn2d (line 146) | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
function upfirdn2d_native (line 160) | def upfirdn2d_native(
FILE: options.py
class BaseOptions (line 5) | class BaseOptions():
method __init__ (line 6) | def __init__(self):
method initialize (line 10) | def initialize(self):
method parse (line 94) | def parse(self):
FILE: prepare_data.py
function resize_and_convert (line 14) | def resize_and_convert(img, size, resample):
function resize_multiple (line 24) | def resize_multiple(
function resize_worker (line 34) | def resize_worker(img_file, sizes, resample):
function prepare (line 43) | def prepare(
FILE: render_video.py
function render_video (line 26) | def render_video(opt, g_ema, surface_g_ema, device, mean_latent, surface...
FILE: train_full_pipeline.py
function train (line 28) | def train(opt, experiment_opt, loader, generator, discriminator, g_optim...
FILE: train_volume_renderer.py
function train (line 28) | def train(opt, experiment_opt, loader, generator, discriminator, g_optim...
FILE: utils.py
function data_sampler (line 27) | def data_sampler(dataset, shuffle, distributed):
function sample_data (line 38) | def sample_data(loader):
function requires_grad (line 45) | def requires_grad(model, flag=True):
function accumulate (line 50) | def accumulate(model1, model2, decay=0.999):
function make_noise (line 59) | def make_noise(batch, latent_dim, n_noise, device):
function mixing_noise (line 68) | def mixing_noise(batch, latent_dim, prob, device):
function generate_camera_params (line 75) | def generate_camera_params(resolution, device, batch=1, locations=None, ...
function align_volume (line 141) | def align_volume(volume, near=0.88, far=1.12):
function extract_mesh_with_marching_cubes (line 164) | def extract_mesh_with_marching_cubes(sdf):
function xyz2mesh (line 183) | def xyz2mesh(xyz):
function add_textures (line 201) | def add_textures(meshes:Meshes, vertex_colors=None) -> Meshes:
function create_cameras (line 215) | def create_cameras(
function create_mesh_renderer (line 229) | def create_mesh_renderer(
class MeshRendererWithDepth (line 263) | class MeshRendererWithDepth(nn.Module):
method __init__ (line 264) | def __init__(self, rasterizer, shader):
method forward (line 269) | def forward(self, meshes_world, **kwargs) -> torch.Tensor:
function create_depth_mesh_renderer (line 275) | def create_depth_mesh_renderer(
FILE: volume_renderer.py
class LinearLayer (line 12) | class LinearLayer(nn.Module):
method __init__ (line 13) | def __init__(self, in_dim, out_dim, bias=True, bias_init=0, std_init=1...
method forward (line 27) | def forward(self, input):
class FiLMSiren (line 33) | class FiLMSiren(nn.Module):
method __init__ (line 34) | def __init__(self, in_channel, out_channel, style_dim, is_first=False):
method forward (line 50) | def forward(self, input, style):
class SirenGenerator (line 62) | class SirenGenerator(nn.Module):
method __init__ (line 63) | def __init__(self, D=8, W=256, style_dim=256, input_ch=3, input_ch_vie...
method forward (line 82) | def forward(self, x, styles):
class VolumeFeatureRenderer (line 103) | class VolumeFeatureRenderer(nn.Module):
method __init__ (line 104) | def __init__(self, opt, style_dim=256, out_im_res=64, mode='train'):
method get_rays (line 158) | def get_rays(self, focal, c2w):
method get_eikonal_term (line 175) | def get_eikonal_term(self, pts, sdf):
method sdf_activation (line 182) | def sdf_activation(self, input):
method volume_integration (line 187) | def volume_integration(self, raw, z_vals, rays_d, pts, return_eikonal=...
method run_network (line 254) | def run_network(self, inputs, viewdirs, styles=None):
method render_rays (line 261) | def render_rays(self, ray_batch, styles=None, return_eikonal=False):
method render (line 305) | def render(self, focal, c2w, near, far, styles, c2w_staticcam=None, re...
method mlp_init_pass (line 319) | def mlp_init_pass(self, cam_poses, focal, near, far, styles=None):
method forward (line 348) | def forward(self, cam_poses, focal, near, far, styles=None, return_eik...