Repository: royorel/StyleSDF
Branch: main
Commit: 1c995101b80b
Files: 29
Total size: 202.7 KB

Directory structure:
gitextract__fgegk92/
├── .gitignore
├── LICENSE
├── README.md
├── StyleSDF_demo.ipynb
├── dataset.py
├── distributed.py
├── download_models.py
├── generate_shapes_and_images.py
├── losses.py
├── model.py
├── op/
│   ├── __init__.py
│   ├── fused_act.py
│   ├── fused_bias_act.cpp
│   ├── fused_bias_act_kernel.cu
│   ├── upfirdn2d.cpp
│   ├── upfirdn2d.py
│   └── upfirdn2d_kernel.cu
├── options.py
├── prepare_data.py
├── render_video.py
├── requirements.txt
├── scripts/
│   ├── train_afhq_full_pipeline_512x512.sh
│   ├── train_afhq_vol_renderer.sh
│   ├── train_ffhq_full_pipeline_1024x1024.sh
│   └── train_ffhq_vol_renderer.sh
├── train_full_pipeline.py
├── train_volume_renderer.py
├── utils.py
└── volume_renderer.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
__pycache__/
checkpoint/
datasets/
evaluations/
full_models/
pretrained_renderer/
wandb/

================================================
FILE: LICENSE
================================================
Copyright (C) 2022 Roy Or-El, Xuan Luo, Mengyi Shan, Eli Shechtman, Jeong Joon Park, Ira Kemelmacher-Shlizerman.
All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).

The software is made available under Creative Commons BY-NC-SA 4.0 license by University of Washington. You can use, redistribute, and adapt it for non-commercial purposes, as long as you (a) give appropriate credit by citing our paper, (b) indicate any changes that you've made, and (c) distribute any derivative works under the same license.

THE AUTHORS DISCLAIM ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE.
IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

================================================
FILE: README.md
================================================
# StyleSDF: High-Resolution 3D-Consistent Image and Geometry Generation
### [Project Page](https://stylesdf.github.io/) | [Paper](https://arxiv.org/pdf/2112.11427.pdf) | [HuggingFace Demo](https://huggingface.co/spaces/SerdarHelli/StyleSDF-3D)

[![Explore in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/royorel/StyleSDF/blob/main/StyleSDF_demo.ipynb)
[Roy Or-El](https://homes.cs.washington.edu/~royorel/)1, [Xuan Luo](https://roxanneluo.github.io/)1, [Mengyi Shan](https://shanmy.github.io/)1, [Eli Shechtman](https://research.adobe.com/person/eli-shechtman/)2, [Jeong Joon Park](https://jjparkcv.github.io/)3, [Ira Kemelmacher-Shlizerman](https://www.irakemelmacher.com/)1
1University of Washington, 2Adobe Research, 3Stanford University
## Updates
12/26/2022: A new HuggingFace demo is now available. Special thanks to [@SerdarHelli](https://github.com/SerdarHelli) for the implementation.
3/27/2022: Fixed a bug in the sphere initialization code (init_forward function was missing, see commit [0bd8741](https://github.com/royorel/StyleSDF/commit/0bd8741f26048d26160a495a9056b5f1da1a60a1)).
3/22/2022: **Added training files**.
3/9/2022: Fixed a bug in the calculation of the mean w vector (see commit [d4dd17d](https://github.com/royorel/StyleSDF/commit/d4dd17de09fd58adefc7ed49487476af6018894f)).
3/4/2022: Testing code and Colab demo were released. **Training files will be released soon.**

## Overview
StyleSDF is a 3D-aware GAN aimed at solving two main challenges:
1. High-resolution, view-consistent generation of the RGB images.
2. Generating detailed 3D shapes.

StyleSDF is trained only on single-view RGB data. The 3D geometry is learned implicitly with an SDF-based volume renderer.
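To make the idea of an SDF-based volume renderer more concrete, here is a minimal sketch in plain Python. It is not the repository's code. It shows one way that signed distances sampled along a single camera ray can be turned into densities and then alpha-composited. It assumes the scaled-sigmoid mapping `sigma = Sigmoid(-d / alpha) / alpha` described in the StyleSDF paper, where `alpha` is a learned sharpness parameter. We assume `alpha` corresponds to the "Beta value" monitored during training. The function names and sample values are hypothetical, and `volume_renderer.py` remains the authoritative implementation.

```python
import math
from typing import List, Sequence, Tuple


def stable_sigmoid(x: float) -> float:
    """Numerically stable logistic function (avoids overflow in math.exp)."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def sdf_to_density(sdf: float, alpha: float) -> float:
    """Assumed StyleSDF-style mapping: sigma = sigmoid(-sdf / alpha) / alpha.

    Inside the surface (sdf < 0) the density approaches 1 / alpha; far outside it
    decays to 0. A smaller alpha gives a sharper transition around sdf == 0.
    """
    return stable_sigmoid(-sdf / alpha) / alpha


def composite_ray(
    sdfs: Sequence[float],
    colors: Sequence[Tuple[float, float, float]],
    deltas: Sequence[float],
    alpha: float,
) -> Tuple[Tuple[float, float, float], List[float]]:
    """NeRF-style quadrature: weight_i = T_i * (1 - exp(-sigma_i * delta_i))."""
    transmittance = 1.0  # fraction of light that survives up to the current sample
    rgb = [0.0, 0.0, 0.0]
    weights: List[float] = []
    for sdf, color, delta in zip(sdfs, colors, deltas):
        sigma = sdf_to_density(sdf, alpha)
        opacity = 1.0 - math.exp(-sigma * delta)
        weight = transmittance * opacity
        weights.append(weight)
        for c in range(3):
            rgb[c] += weight * color[c]
        transmittance *= 1.0 - opacity
    return (rgb[0], rgb[1], rgb[2]), weights


if __name__ == "__main__":
    # Hypothetical ray with five evenly spaced samples; the surface (sdf == 0) is at index 2.
    sdfs = [0.2, 0.1, 0.0, -0.1, -0.2]
    colors = [(1.0, 0.0, 0.0)] * len(sdfs)
    deltas = [0.1] * len(sdfs)
    for alpha in (0.1, 0.01):
        _, weights = composite_ray(sdfs, colors, deltas, alpha)
        print(f"alpha={alpha}: weights={[round(w, 3) for w in weights]}")
```

With a smaller `alpha`, almost all of the rendering weight lands on the sample closest to the zero level set, so the implied surface becomes sharp. This matches the intuition behind watching beta drop during training.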
This code is the official PyTorch implementation of the paper:
> **StyleSDF: High-Resolution 3D-Consistent Image and Geometry Generation**
> Roy Or-El, Xuan Luo, Mengyi Shan, Eli Shechtman, Jeong Joon Park, Ira Kemelmacher-Shlizerman
> CVPR 2022
> https://arxiv.org/pdf/2112.11427.pdf

## Abstract
We introduce a high resolution, 3D-consistent image and shape generation technique which we call StyleSDF. Our method is trained on single-view RGB data only, and stands on the shoulders of StyleGAN2 for image generation, while solving two main challenges in 3D-aware GANs: 1) high-resolution, view-consistent generation of the RGB images, and 2) detailed 3D shape. We achieve this by merging an SDF-based 3D representation with a style-based 2D generator. Our 3D implicit network renders low-resolution feature maps, from which the style-based network generates view-consistent, 1024×1024 images. Notably, our SDF-based 3D modeling defines detailed 3D surfaces, leading to consistent volume rendering. Our method shows higher quality results compared to state of the art in terms of visual and geometric quality.

## Prerequisites
You must have a **GPU with CUDA support** in order to run the code.

This code requires **PyTorch**, **PyTorch3D** and **torchvision** to be installed. Please go to [PyTorch.org](https://pytorch.org/) and [PyTorch3d.org](https://pytorch3d.org/) for installation info.
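After installing, you can confirm the prerequisites with a small standalone check like the hypothetical helper below. It is not part of the repository. It uses `importlib.util.find_spec` to test whether `torch`, `torchvision` and `pytorch3d` (the import name of PyTorch3D) are importable. If PyTorch is present, it also calls `torch.cuda.is_available()` to check whether a CUDA device is visible.

```python
import importlib.util
from typing import Dict, Iterable

# Import names of the core packages listed above (PyTorch3D imports as "pytorch3d").
REQUIRED_MODULES = ("torch", "torchvision", "pytorch3d")


def check_prerequisites(modules: Iterable[str] = REQUIRED_MODULES) -> Dict[str, bool]:
    """Return {module_name: importable?} plus a "cuda" entry for GPU availability."""
    status = {name: importlib.util.find_spec(name) is not None for name in modules}
    cuda_ok = False
    if status.get("torch"):
        import torch  # imported lazily so the check itself works without PyTorch

        cuda_ok = torch.cuda.is_available()
    status["cuda"] = cuda_ok
    return status


if __name__ == "__main__":
    for name, ok in check_prerequisites().items():
        print(f"{name:12s} {'OK' if ok else 'MISSING'}")
```

Using `find_spec` instead of importing each package keeps the check fast. It also avoids triggering the CUDA kernel compilation that some of these packages perform on import.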
We tested our code on Python 3.8.5, PyTorch 1.9.0, PyTorch3D 0.6.1 and torchvision 0.10.0.

The following packages should also be installed:
1. lmdb
2. numpy
3. ninja
4. pillow
5. requests
6. tqdm
7. scipy
8. skimage
9. skvideo
10. trimesh[easy]
11. configargparse
12. munch
13. wandb (optional)

If any of these packages are not installed on your computer, you can install them using the supplied `requirements.txt` file:
```
pip install -r requirements.txt
```

## Download Pre-trained Models
The pre-trained models can be downloaded by running `python download_models.py`.

## Quick Demo
You can explore our method in Google Colab [![Explore in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/royorel/StyleSDF/blob/main/StyleSDF_demo.ipynb).

Alternatively, you can download the pretrained models by running:
`python download_models.py`

To generate human faces from the model pre-trained on FFHQ, run:
`python generate_shapes_and_images.py --expname ffhq1024x1024 --size 1024 --identities NUMBER_OF_FACES`

To generate animal faces from the model pre-trained on AFHQ, run:
`python generate_shapes_and_images.py --expname afhq512x512 --size 512 --identities NUMBER_OF_FACES`

## Generating images and meshes
To generate images and meshes from a trained model, run:
`python generate_shapes_and_images.py --expname NAME_OF_TRAINED_MODEL --size MODEL_OUTPUT_SIZE --identities NUMBER_OF_FACES`

The script will generate an RGB image, a mesh generated from the depth map, and a mesh extracted with marching cubes.

### Optional flags for image and mesh generation
```
  --no_surface_renderings   When true, only RGB outputs will be generated.
                            Otherwise, both RGB and depth videos/renderings will be generated. (default: false)
  --fixed_camera_angles     When true, the generator will render identities from a fixed set of camera angles. (default: false)
```

## Generating Videos
To generate videos from a trained model, run:
`python render_video.py --expname NAME_OF_TRAINED_MODEL --size MODEL_OUTPUT_SIZE --identities NUMBER_OF_FACES`.

This script will generate an RGB video as well as a depth map video for each identity. The average processing time per video is ~5-10 minutes on an RTX 2080 Ti GPU.

### Optional flags for video rendering
```
  --no_surface_videos       When true, only RGB videos will be generated when running render_video.py.
                            Otherwise, both RGB and depth videos will be generated.
                            Setting this flag cuts the processing time per video. (default: false)
  --azim_video              When true, the camera trajectory will travel along the azimuth direction.
                            Otherwise, the camera will travel along an ellipsoid trajectory. (default: false, i.e. ellipsoid)
  --project_noise           When true, use geometry-aware noise projection to reduce flickering effects (see supplementary section C.1 in the paper).
                            Warning: processing time significantly increases with this flag to ~20 minutes per video. (default: false)
```

## Training
### Preparing your Dataset
If you wish to train a model from scratch, first you need to convert your dataset to an lmdb format. Run:
`python prepare_data.py --out_path OUTPUT_LMDB_PATH --n_worker N_WORKER --size SIZE1,SIZE2,SIZE3,... INPUT_DATASET_PATH`

### Training the volume renderer
#### Training scripts
To train the volume renderer on FFHQ run: `bash ./scripts/train_ffhq_vol_renderer.sh`.
To train the volume renderer on AFHQ run: `bash ./scripts/train_afhq_vol_renderer.sh`.

* The scripts above use distributed training. To train the models on a single GPU (not recommended), remove `-m torch.distributed.launch --nproc_per_node NUM_GPUS` from the script.

#### Training on a new dataset
To train the volume renderer on a new dataset, run:
`python train_volume_renderer.py --batch BATCH_SIZE --chunk CHUNK_SIZE --expname EXPERIMENT_NAME --dataset_path DATASET_PATH`

Ideally, `CHUNK_SIZE` should be the same as `BATCH_SIZE`, but on most GPUs this will likely cause an out-of-memory error. In such a case, reduce `CHUNK_SIZE` to perform gradient accumulation.

**Important note**: The best way to monitor SDF convergence is to look at the "Beta value" graph on wandb. Convergence is successful once beta reaches values below approx. 3×10⁻³. If the SDF is not converging, increase the R1 regularization weight. Another helpful option (to a lesser degree) is to decrease the weight of the minimal surface regularization.

#### Distributed training
If you have multiple GPUs, you can train your model on multiple instances by running:
`python -m torch.distributed.launch --nproc_per_node NUM_GPUS train_volume_renderer.py --batch BATCH_SIZE --chunk CHUNK_SIZE --expname EXPERIMENT_NAME --dataset_path DATASET_PATH`

### Training the full pipeline
#### Training scripts
To train the full pipeline on FFHQ run: `bash ./scripts/train_ffhq_full_pipeline_1024x1024.sh`.
To train the full pipeline on AFHQ run: `bash ./scripts/train_afhq_full_pipeline_512x512.sh`.

* The scripts above assume that the volume renderer model was already trained. **Do not run them from scratch.**
* The scripts above use distributed training. To train the models on a single GPU (not recommended), remove `-m torch.distributed.launch --nproc_per_node NUM_GPUS` from the script.

#### Training on a new dataset
To train the full pipeline on a new dataset, **first train the volume renderer separately**.
After the volume renderer training is finished, run:
`python train_full_pipeline.py --batch BATCH_SIZE --chunk CHUNK_SIZE --expname EXPERIMENT_NAME --size OUTPUT_SIZE`

Ideally, `CHUNK_SIZE` should be the same as `BATCH_SIZE`, but on most GPUs this will likely cause an out-of-memory error. In such a case, reduce `CHUNK_SIZE` to perform gradient accumulation.

#### Distributed training
If you have multiple GPUs, you can train your model on multiple instances by running:
`python -m torch.distributed.launch --nproc_per_node NUM_GPUS train_full_pipeline.py --batch BATCH_SIZE --chunk CHUNK_SIZE --expname EXPERIMENT_NAME --size OUTPUT_SIZE`

Here, **BATCH_SIZE represents the batch per GPU**, not the overall batch size.

### Optional training flags
```
Training regime options:
  --iter                Total number of training iterations. (default: 300,000)
  --wandb               Use Weights & Biases logging. (default: False)
  --r1                  Weight of the R1 regularization. (default: 10.0)
  --view_lambda         Weight of the viewpoint regularization. (Equation 6, default: 15)
  --eikonal_lambda      Weight of the eikonal regularization. (Equation 7, default: 0.1)
  --min_surf_lambda     Weight of the minimal surface regularization. (Equation 8, default: 0.05)

Camera options:
  --uniform             When true, the camera position is sampled from a uniform distribution. (default: gaussian)
  --azim                Camera azimuth angle std (gaussian)/range (uniform) in Radians. (default: 0.3 Rad.)
  --elev                Camera elevation angle std (gaussian)/range (uniform) in Radians. (default: 0.15 Rad.)
  --fov                 Camera field of view half angle in **Degrees**. (default: 6 Deg.)
  --dist_radius         Radius of points sampling distance from the origin. Determines the near and far fields. (default: 0.12)
```

## Citation
If you use this code for your research, please cite our paper.
```
@InProceedings{orel2022stylesdf,
  title     = {Style{SDF}: {H}igh-{R}esolution {3D}-{C}onsistent {I}mage and {G}eometry {G}eneration},
  author    = {Or-El, Roy and Luo, Xuan and Shan, Mengyi and Shechtman, Eli and Park, Jeong Joon and Kemelmacher-Shlizerman, Ira},
  booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  month     = {June},
  year      = {2022},
  pages     = {13503-13513}
}
```

## Acknowledgments
This code is inspired by rosinality's [StyleGAN2-PyTorch](https://github.com/rosinality/stylegan2-pytorch) and Yen-Chen Lin's [NeRF-pytorch](https://github.com/yenchenlin/nerf-pytorch).
A special thanks to [@SerdarHelli](https://github.com/SerdarHelli) for implementing the HuggingFace demo.

================================================
FILE: StyleSDF_demo.ipynb
================================================
{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "StyleSDF_demo.ipynb", "private_outputs": true, "provenance": [], "collapsed_sections": [], "include_colab_link": true }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "view-in-github", "colab_type": "text" }, "source": [ "\"Open" ] }, { "cell_type": "markdown", "source": [ "# StyleSDF Demo\n", "\n", "This Colab notebook demonstrates the capabilities of the StyleSDF 3D-aware GAN architecture proposed in our paper.\n", "\n", "This colab generates images with their corresponding 3D meshes" ], "metadata": { "id": "b86fxLSo1gqI" } }, { "cell_type": "markdown", "source": [ "First, let's download the github repository and install all dependencies." ], "metadata": { "id": "8wSAkifH2Bk5" } }, { "cell_type": "code", "source": [ "!git clone https://github.com/royorel/StyleSDF.git\n", "%cd StyleSDF\n", "!pip3 install -r requirements.txt" ], "metadata": { "id": "1GTFRig12CuH" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "And install pytorch3D..." 
], "metadata": { "id": "GwFJ8oLY4i8d" } }, { "cell_type": "code", "source": [ "!pip install -U fvcore\n", "import sys\n", "import torch\n", "pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n", "version_str=\"\".join([\n", " f\"py3{sys.version_info.minor}_cu\",\n", " torch.version.cuda.replace(\".\",\"\"),\n", " f\"_pyt{pyt_version_str}\"\n", "])\n", "!pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html" ], "metadata": { "id": "zl3Vpddz3ols" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Now let's download the pretrained models for FFHQ and AFHQ." ], "metadata": { "id": "OgMcslVS5vbC" } }, { "cell_type": "code", "source": [ "!python download_models.py" ], "metadata": { "id": "r1iDkz7r5wnO" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Here, we import libraries and set options.\n", "\n", "Note: this might take a while (approx. 1-2 minutes) since CUDA kernels need to be compiled." 
], "metadata": { "id": "F2JUyGq4JDa8" } }, { "cell_type": "code", "source": [ "import os\n", "import torch\n", "import trimesh\n", "import numpy as np\n", "from munch import *\n", "from options import BaseOptions\n", "from model import Generator\n", "from generate_shapes_and_images import generate\n", "from render_video import render_video\n", "\n", "\n", "torch.random.manual_seed(321)\n", "\n", "\n", "device = \"cuda\"\n", "opt = BaseOptions().parse()\n", "opt.camera.uniform = True\n", "opt.model.is_test = True\n", "opt.model.freeze_renderer = False\n", "opt.rendering.offset_sampling = True\n", "opt.rendering.static_viewdirs = True\n", "opt.rendering.force_background = True\n", "opt.rendering.perturb = 0\n", "opt.inference.renderer_output_size = opt.model.renderer_spatial_output_dim\n", "opt.inference.style_dim = opt.model.style_dim\n", "opt.inference.project_noise = opt.model.project_noise" ], "metadata": { "id": "Qfamt8J0JGn5" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "TLOiDzgKtYyi" }, "source": [ "Don't worry about this message above, \n", "```\n", "usage: ipykernel_launcher.py [-h] [--dataset_path DATASET_PATH]\n", " [--config CONFIG] [--expname EXPNAME]\n", " [--ckpt CKPT] [--continue_training]\n", " ...\n", " ...\n", "ipykernel_launcher.py: error: unrecognized arguments: -f /root/.local/share/jupyter/runtime/kernel-c9d47a98-bdba-4a5f-9f0a-e1437c7228b6.json\n", "```\n", "everything is perfectly fine..." ] }, { "cell_type": "markdown", "source": [ "Here, we define our model.\n", "\n", "Set the options below according to your choosing:\n", "1. If you plan to try the method for the AFHQ dataset (animal faces), change `model_type` to 'afhq'. Default: `ffhq` (human faces).\n", "2. If you wish to turn off depth rendering and marching cubes extraction and generate only RGB images, set `opt.inference.no_surface_renderings = True`. Default: `False`.\n", "3. 
If you wish to generate the image from a specific set of viewpoints, set `opt.inference.fixed_camera_angles = True`. Default: `False`.\n", "4. Set the number of identities you wish to create in `opt.inference.identities`. Default: `4`.\n", "5. Select the number of views per identity in `opt.inference.num_views_per_id`,
\n", " (Only applicable when `opt.inference.fixed_camera_angles` is false). Default: `1`. " ], "metadata": { "id": "40dk1eEJo9MF" } }, { "cell_type": "code", "source": [ "# User options\n", "model_type = 'ffhq' # Whether to load the FFHQ or AFHQ model\n", "opt.inference.no_surface_renderings = False # When true, only RGB images will be created\n", "opt.inference.fixed_camera_angles = False # When true, each identity will be rendered from a specific set of 13 viewpoints. Otherwise, random views are generated\n", "opt.inference.identities = 4 # Number of identities to generate\n", "opt.inference.num_views_per_id = 1 # Number of viewpoints generated per identity. This option is ignored if opt.inference.fixed_camera_angles is true.\n", "\n", "# Load saved model\n", "if model_type == 'ffhq':\n", " model_path = 'ffhq1024x1024.pt'\n", " opt.model.size = 1024\n", " opt.experiment.expname = 'ffhq1024x1024'\n", "else:\n", " opt.inference.camera.azim = 0.15\n", " model_path = 'afhq512x512.pt'\n", " opt.model.size = 512\n", " opt.experiment.expname = 'afhq512x512'\n", "\n", "# Create results directory\n", "result_model_dir = 'final_model'\n", "results_dir_basename = os.path.join(opt.inference.results_dir, opt.experiment.expname)\n", "opt.inference.results_dst_dir = os.path.join(results_dir_basename, result_model_dir)\n", "if opt.inference.fixed_camera_angles:\n", " opt.inference.results_dst_dir = os.path.join(opt.inference.results_dst_dir, 'fixed_angles')\n", "else:\n", " opt.inference.results_dst_dir = os.path.join(opt.inference.results_dst_dir, 'random_angles')\n", "\n", "os.makedirs(opt.inference.results_dst_dir, exist_ok=True)\n", "os.makedirs(os.path.join(opt.inference.results_dst_dir, 'images'), exist_ok=True)\n", "if not opt.inference.no_surface_renderings:\n", " os.makedirs(os.path.join(opt.inference.results_dst_dir, 'depth_map_meshes'), exist_ok=True)\n", " os.makedirs(os.path.join(opt.inference.results_dst_dir, 'marching_cubes_meshes'), exist_ok=True)\n", "\n", 
"opt.inference.camera = opt.camera\n", "opt.inference.size = opt.model.size\n", "checkpoint_path = os.path.join('full_models', model_path)\n", "checkpoint = torch.load(checkpoint_path)\n", "\n", "# Load image generation model\n", "g_ema = Generator(opt.model, opt.rendering).to(device)\n", "pretrained_weights_dict = checkpoint[\"g_ema\"]\n", "model_dict = g_ema.state_dict()\n", "for k, v in pretrained_weights_dict.items():\n", " if v.size() == model_dict[k].size():\n", " model_dict[k] = v\n", "\n", "g_ema.load_state_dict(model_dict)\n", "\n", "# Load a second volume renderer that extracts surfaces at 128x128x128 (or higher) for better surface resolution\n", "if not opt.inference.no_surface_renderings:\n", " opt['surf_extraction'] = Munch()\n", " opt.surf_extraction.rendering = opt.rendering\n", " opt.surf_extraction.model = opt.model.copy()\n", " opt.surf_extraction.model.renderer_spatial_output_dim = 128\n", " opt.surf_extraction.rendering.N_samples = opt.surf_extraction.model.renderer_spatial_output_dim\n", " opt.surf_extraction.rendering.return_xyz = True\n", " opt.surf_extraction.rendering.return_sdf = True\n", " surface_g_ema = Generator(opt.surf_extraction.model, opt.surf_extraction.rendering, full_pipeline=False).to(device)\n", "\n", "\n", " # Load weights to surface extractor\n", " surface_extractor_dict = surface_g_ema.state_dict()\n", " for k, v in pretrained_weights_dict.items():\n", " if k in surface_extractor_dict.keys() and v.size() == surface_extractor_dict[k].size():\n", " surface_extractor_dict[k] = v\n", "\n", " surface_g_ema.load_state_dict(surface_extractor_dict)\n", "else:\n", " surface_g_ema = None\n", "\n", "# Get the mean latent vector for g_ema\n", "if opt.inference.truncation_ratio < 1:\n", " with torch.no_grad():\n", " mean_latent = g_ema.mean_latent(opt.inference.truncation_mean, device)\n", "else:\n", " mean_latent = None\n", "\n", "# Get the mean latent vector for surface_g_ema\n", "if not opt.inference.no_surface_renderings and mean_latent is not None:\n", " 
surface_mean_latent = mean_latent[0]\n", "else:\n", " surface_mean_latent = None" ], "metadata": { "id": "CUcWipIlpINT" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Generating images and meshes\n", "\n", "Finally, we run the network. The results will be saved to `evaluations/[model_name]/final_model/[fixed/random]_angles`, according to the selected setup." ], "metadata": { "id": "N9pqjPDYCwIJ" } }, { "cell_type": "code", "source": [ "generate(opt.inference, g_ema, surface_g_ema, device, mean_latent, surface_mean_latent)" ], "metadata": { "id": "UG4hZgigDfG7" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Now let's examine the results.\n", "\n", "Tip: for better mesh visualization, we recommend downloading the result meshes and viewing them with MeshLab.\n", "\n", "The meshes are located in: `evaluations/[model_name]/final_model/[fixed/random]_angles/[depth_map/marching_cubes]_meshes`."
], "metadata": { "id": "obfjXuDxNuZl" } }, { "cell_type": "code", "source": [ "from PIL import Image\n", "from trimesh.viewer.notebook import scene_to_html as mesh2html\n", "from IPython.display import HTML as viewer_html\n", "\n", "# First let's look at the images\n", "img_dir = os.path.join(opt.inference.results_dst_dir,'images')\n", "im_list = sorted([entry for entry in os.listdir(img_dir) if 'thumb' not in entry])\n", "img = Image.new('RGB', (256 * len(im_list), 256))\n", "for i, im_file in enumerate(im_list):\n", " im_path = os.path.join(img_dir, im_file)\n", " curr_img = Image.open(im_path).resize((256,256)) # the displayed image is scaled to fit to the screen\n", " img.paste(curr_img, (256 * i, 0))\n", "\n", "display(img)\n", "\n", "# And now, we'll move on to display the marching cubes and depth map meshes\n", "\n", "marching_cubes_meshes_dir = os.path.join(opt.inference.results_dst_dir,'marching_cubes_meshes')\n", "marching_cubes_meshes_list = sorted([os.path.join(marching_cubes_meshes_dir, entry) for entry in os.listdir(marching_cubes_meshes_dir) if 'obj' in entry])\n", "depth_map_meshes_dir = os.path.join(opt.inference.results_dst_dir,'depth_map_meshes')\n", "depth_map_meshes_list = sorted([os.path.join(depth_map_meshes_dir, entry) for entry in os.listdir(depth_map_meshes_dir) if 'obj' in entry])\n", "for i, mesh_files in enumerate(zip(marching_cubes_meshes_list, depth_map_meshes_list)):\n", " mc_mesh_file, dm_mesh_file = mesh_files[0], mesh_files[1]\n", " marching_cubes_mesh = trimesh.Scene(trimesh.load_mesh(mc_mesh_file, 'obj')) \n", " curr_mc_html = mesh2html(marching_cubes_mesh).replace('\"', '"')\n", " display(viewer_html(' '.join(['']).format(\n", " srcdoc=curr_mc_html, height=256, width=256)))\n", " depth_map_mesh = trimesh.Scene(trimesh.load_mesh(dm_mesh_file, 'obj')) \n", " curr_dm_html = mesh2html(depth_map_mesh).replace('\"', '"')\n", " display(viewer_html(' '.join(['']).format(\n", " srcdoc=curr_dm_html, height=256, width=256)))" ], 
"metadata": { "id": "k3jXCVT7N2YZ" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Generating videos\n", "\n", "Additionally, we can also render videos. The results will be saved to `evaluations/[model_name]/final_model/videos`.\n", "\n", "Set the options below according to your choosing:\n", "1. If you wish to generate only RGB videos, set `opt.inference.no_surface_renderings = True`. Default: `False`.\n", "2. Set the camera trajectory. To travel along the azimuth direction set `opt.inference.azim_video = True`, to travel in an ellipsoid trajectory set `opt.inference.azim_video = False`. Default: `False`.\n", "\n", "###Important Note: \n", " - Processing time for videos when `opt.inference.no_surface_renderings = False` is very lengthy (~ 15-20 minutes per video). Rendering each depth frame for the depth videos is very slow.
\n", " - Processing time for videos when `opt.inference.no_surface_renderings = True` is much faster (~ 1-2 minutes per video)" ], "metadata": { "id": "Cxq3c6anN4D0" } }, { "cell_type": "code", "source": [ "# Options\n", "opt.inference.no_surface_renderings = True # When true, only RGB videos will be created\n", "opt.inference.azim_video = True # When true, the camera trajectory will travel along the azimuth direction. Otherwise, the camera will travel along an ellipsoid trajectory.\n", "\n", "opt.inference.results_dst_dir = os.path.join(os.path.split(opt.inference.results_dst_dir)[0], 'videos')\n", "os.makedirs(opt.inference.results_dst_dir, exist_ok=True)\n", "render_video(opt.inference, g_ema, surface_g_ema, device, mean_latent, surface_mean_latent)" ], "metadata": { "id": "nblhnZgcOST8" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Let's watch the result videos.\n", "\n", "The output video files are relatively large, so it might take a while (about 1-2 minutes) for all of them to be loaded. " ], "metadata": { "id": "jkBYvlaeTGu1" } }, { "cell_type": "code", "source": [ "%%script bash --bg\n", "python3 -m http.server 8000" ], "metadata": { "id": "s-A9oIFXnYdM" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "%%html\n", "<!-- change ffhq1024x1024 to afhq512x512 if you are working on the AFHQ model -->\n", "
\n", " \n", " \n", " \n", " \n", "
" ], "metadata": { "id": "Rz8tq-yYnZMD" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "An alternative way to view the videos with Python code. \n", "It loads the videos faster, but very often it crashes the notebook since the video files are too large.\n", "\n", "**It is not recommended to view the files this way**.\n", "\n", "If the notebook does crash, you can also refresh the webpage and manually download the videos.
\n", "The videos are located in `evaluations//final_model/videos`" ], "metadata": { "id": "KP_Nrl8yqL65" } }, { "cell_type": "code", "source": [ "# from base64 import b64encode\n", "\n", "# videos_dir = opt.inference.results_dst_dir\n", "# videos_list = sorted([os.path.join(videos_dir, entry) for entry in os.listdir(videos_dir) if 'mp4' in entry])\n", "# for i, video_file in enumerate(videos_list):\n", "# if i != 1:\n", "# continue\n", "# mp4 = open(video_file,'rb').read()\n", "# data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n", "# display(viewer_html(\"\"\"\"\"\".format(256, data_url, \"video/mp4\")))" ], "metadata": { "id": "LzsQTq1hTNSH" }, "execution_count": null, "outputs": [] } ] }

================================================
FILE: dataset.py
================================================
import os
import csv
import lmdb
import random
import numpy as np
import torchvision.transforms.functional as TF
from PIL import Image
from io import BytesIO
from torch.utils.data import Dataset


class MultiResolutionDataset(Dataset):
    def __init__(self, path, transform, resolution=256, nerf_resolution=64):
        self.env = lmdb.open(
            path,
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )

        if not self.env:
            raise IOError('Cannot open lmdb dataset', path)

        with self.env.begin(write=False) as txn:
            self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))

        self.resolution = resolution
        self.nerf_resolution = nerf_resolution
        self.transform = transform

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        with self.env.begin(write=False) as txn:
            key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8')
            img_bytes = txn.get(key)

        buffer = BytesIO(img_bytes)
        img = Image.open(buffer)
        if random.random() > 0.5:
            img = TF.hflip(img)

        thumb_img = img.resize((self.nerf_resolution, self.nerf_resolution), Image.HAMMING)
        img = self.transform(img)
        thumb_img = self.transform(thumb_img)

        return img, thumb_img
================================================ FILE: distributed.py ================================================ import math import pickle import torch from torch import distributed as dist from torch.utils.data.sampler import Sampler def get_rank(): if not dist.is_available(): return 0 if not dist.is_initialized(): return 0 return dist.get_rank() def synchronize(): if not dist.is_available(): return if not dist.is_initialized(): return world_size = dist.get_world_size() if world_size == 1: return dist.barrier() def get_world_size(): if not dist.is_available(): return 1 if not dist.is_initialized(): return 1 return dist.get_world_size() def reduce_sum(tensor): if not dist.is_available(): return tensor if not dist.is_initialized(): return tensor tensor = tensor.clone() dist.all_reduce(tensor, op=dist.ReduceOp.SUM) return tensor def gather_grad(params): world_size = get_world_size() if world_size == 1: return for param in params: if param.grad is not None: dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) param.grad.data.div_(world_size) def all_gather(data): world_size = get_world_size() if world_size == 1: return [data] buffer = pickle.dumps(data) storage = torch.ByteStorage.from_buffer(buffer) tensor = torch.ByteTensor(storage).to('cuda') local_size = torch.IntTensor([tensor.numel()]).to('cuda') size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)] dist.all_gather(size_list, local_size) size_list = [int(size.item()) for size in size_list] max_size = max(size_list) tensor_list = [] for _ in size_list: tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda')) if local_size != max_size: padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda') tensor = torch.cat((tensor, padding), 0) dist.all_gather(tensor_list, tensor) data_list = [] for size, tensor in zip(size_list, tensor_list): buffer = tensor.cpu().numpy().tobytes()[:size] data_list.append(pickle.loads(buffer)) return data_list def reduce_loss_dict(loss_dict): 
world_size = get_world_size() if world_size < 2: return loss_dict with torch.no_grad(): keys = [] losses = [] for k in sorted(loss_dict.keys()): keys.append(k) losses.append(loss_dict[k]) losses = torch.stack(losses, 0) dist.reduce(losses, dst=0) if dist.get_rank() == 0: losses /= world_size reduced_losses = {k: v for k, v in zip(keys, losses)} return reduced_losses ================================================ FILE: download_models.py ================================================ import os import html import glob import uuid import hashlib import requests from tqdm import tqdm from pdb import set_trace as st ffhq_full_model_spec = dict(file_url='https://drive.google.com/uc?id=13s_dH768zJ3IHUjySVbD1DqcMolqKlBi', alt_url='', file_size=202570217, file_md5='1ab9522157e537351fcf4faed9c92abb', file_path='full_models/ffhq1024x1024.pt',) ffhq_volume_renderer_spec = dict(file_url='https://drive.google.com/uc?id=1zzB-ACuas7lSAln8pDnIqlWOEg869CCK', alt_url='', file_size=63736197, file_md5='fe62f26032ccc8f04e1101b07bcd7462', file_path='pretrained_renderer/ffhq_vol_renderer.pt',) afhq_full_model_spec = dict(file_url='https://drive.google.com/uc?id=1jZcV5__EPS56JBllRmUfUjMHExql_eSq', alt_url='', file_size=184908743, file_md5='91eeaab2da5c0d2134c04bb56ac5aeb6', file_path='full_models/afhq512x512.pt',) afhq_volume_renderer_spec = dict(file_url='https://drive.google.com/uc?id=1xhZjuJt_teghAQEoJevAxrU5_8VqqX42', alt_url='', file_size=63736197, file_md5='cb3beb6cd3c43d9119356400165e7f26', file_path='pretrained_renderer/afhq_vol_renderer.pt',) volume_renderer_sphere_init_spec = dict(file_url='https://drive.google.com/uc?id=1CcYWbHFrJyb4u5rBFx_ww-ceYWd309tk', alt_url='', file_size=63736197, file_md5='a6be653435b0e49633e6838f18ff4df6', file_path='pretrained_renderer/sphere_init.pt',) def download_pretrained_models(): print('Downloading sphere initialized volume renderer') with requests.Session() as session: try: download_file(session, volume_renderer_sphere_init_spec) except: 
print('Google Drive download failed.\n' \ 'Trying to download from alternate server') download_file(session, volume_renderer_sphere_init_spec, use_alt_url=True) print('Downloading FFHQ pretrained volume renderer') with requests.Session() as session: try: download_file(session, ffhq_volume_renderer_spec) except: print('Google Drive download failed.\n' \ 'Trying to download from alternate server') download_file(session, ffhq_volume_renderer_spec, use_alt_url=True) print('Downloading FFHQ full model (1024x1024)') with requests.Session() as session: try: download_file(session, ffhq_full_model_spec) except: print('Google Drive download failed.\n' \ 'Trying to download from alternate server') download_file(session, ffhq_full_model_spec, use_alt_url=True) print('Done!') print('Downloading AFHQ pretrained volume renderer') with requests.Session() as session: try: download_file(session, afhq_volume_renderer_spec) except: print('Google Drive download failed.\n' \ 'Trying to download from alternate server') download_file(session, afhq_volume_renderer_spec, use_alt_url=True) print('Done!') print('Downloading AFHQ full model (512x512)') with requests.Session() as session: try: download_file(session, afhq_full_model_spec) except: print('Google Drive download failed.\n' \ 'Trying to download from alternate server') download_file(session, afhq_full_model_spec, use_alt_url=True) print('Done!') def download_file(session, file_spec, use_alt_url=False, chunk_size=128, num_attempts=10): file_path = file_spec['file_path'] if use_alt_url: file_url = file_spec['alt_url'] else: file_url = file_spec['file_url'] file_dir = os.path.dirname(file_path) tmp_path = file_path + '.tmp.' + uuid.uuid4().hex if file_dir: os.makedirs(file_dir, exist_ok=True) progress_bar = tqdm(total=file_spec['file_size'], unit='B', unit_scale=True) for attempts_left in reversed(range(num_attempts)): data_size = 0 progress_bar.reset() try: # Download.
data_md5 = hashlib.md5() with session.get(file_url, stream=True) as res: res.raise_for_status() with open(tmp_path, 'wb') as f: for chunk in res.iter_content(chunk_size=chunk_size<<10): progress_bar.update(len(chunk)) f.write(chunk) data_size += len(chunk) data_md5.update(chunk) # Validate. if 'file_size' in file_spec and data_size != file_spec['file_size']: raise IOError('Incorrect file size', file_path) if 'file_md5' in file_spec and data_md5.hexdigest() != file_spec['file_md5']: raise IOError('Incorrect file MD5', file_path) break except: # Last attempt => raise error. if not attempts_left: raise # Handle Google Drive virus checker nag. if data_size > 0 and data_size < 8192: with open(tmp_path, 'rb') as f: data = f.read() links = [html.unescape(link) for link in data.decode('utf-8').split('"') if 'confirm=t' in link] if len(links) == 1: file_url = requests.compat.urljoin(file_url, links[0]) continue progress_bar.close() # Rename temp file to the correct name. os.replace(tmp_path, file_path) # atomic # Attempt to clean up any leftover temps. 
for filename in glob.glob(file_path + '.tmp.*'): try: os.remove(filename) except: pass if __name__ == "__main__": download_pretrained_models() ================================================ FILE: generate_shapes_and_images.py ================================================ import os import torch import trimesh import numpy as np from munch import * from PIL import Image from tqdm import tqdm from torch.nn import functional as F from torch.utils import data from torchvision import utils from torchvision import transforms from skimage.measure import marching_cubes from scipy.spatial import Delaunay from options import BaseOptions from model import Generator from utils import ( generate_camera_params, align_volume, extract_mesh_with_marching_cubes, xyz2mesh, ) torch.random.manual_seed(1234) def generate(opt, g_ema, surface_g_ema, device, mean_latent, surface_mean_latent): g_ema.eval() if not opt.no_surface_renderings: surface_g_ema.eval() # set camera angles if opt.fixed_camera_angles: # These can be changed to any other specific viewpoints. # You can add or remove viewpoints as you wish locations = torch.tensor([[0, 0], [-1.5 * opt.camera.azim, 0], [-1 * opt.camera.azim, 0], [-0.5 * opt.camera.azim, 0], [0.5 * opt.camera.azim, 0], [1 * opt.camera.azim, 0], [1.5 * opt.camera.azim, 0], [0, -1.5 * opt.camera.elev], [0, -1 * opt.camera.elev], [0, -0.5 * opt.camera.elev], [0, 0.5 * opt.camera.elev], [0, 1 * opt.camera.elev], [0, 1.5 * opt.camera.elev]], device=device) # For zooming in/out change the values of fov # (This can be defined for each view separately via a custom tensor # like the locations tensor above. 
Tensor shape should be [locations.shape[0],1]) # reasonable values are [0.75 * opt.camera.fov, 1.25 * opt.camera.fov] fov = opt.camera.fov * torch.ones((locations.shape[0],1), device=device) num_viewdirs = locations.shape[0] else: # draw random camera angles locations = None # fov = None fov = opt.camera.fov num_viewdirs = opt.num_views_per_id # generate images for i in tqdm(range(opt.identities)): with torch.no_grad(): chunk = 8 sample_z = torch.randn(1, opt.style_dim, device=device).repeat(num_viewdirs,1) sample_cam_extrinsics, sample_focals, sample_near, sample_far, sample_locations = \ generate_camera_params(opt.renderer_output_size, device, batch=num_viewdirs, locations=locations, #input_fov=fov, uniform=opt.camera.uniform, azim_range=opt.camera.azim, elev_range=opt.camera.elev, fov_ang=fov, dist_radius=opt.camera.dist_radius) rgb_images = torch.Tensor(0, 3, opt.size, opt.size) rgb_images_thumbs = torch.Tensor(0, 3, opt.renderer_output_size, opt.renderer_output_size) for j in range(0, num_viewdirs, chunk): out = g_ema([sample_z[j:j+chunk]], sample_cam_extrinsics[j:j+chunk], sample_focals[j:j+chunk], sample_near[j:j+chunk], sample_far[j:j+chunk], truncation=opt.truncation_ratio, truncation_latent=mean_latent) rgb_images = torch.cat([rgb_images, out[0].cpu()], 0) rgb_images_thumbs = torch.cat([rgb_images_thumbs, out[1].cpu()], 0) utils.save_image(rgb_images, os.path.join(opt.results_dst_dir, 'images','{}.png'.format(str(i).zfill(7))), nrow=num_viewdirs, normalize=True, padding=0, value_range=(-1, 1),) utils.save_image(rgb_images_thumbs, os.path.join(opt.results_dst_dir, 'images','{}_thumb.png'.format(str(i).zfill(7))), nrow=num_viewdirs, normalize=True, padding=0, value_range=(-1, 1),) # this is done to fit to RTX2080 RAM size (11GB) del out torch.cuda.empty_cache() if not opt.no_surface_renderings: surface_chunk = 1 scale = surface_g_ema.renderer.out_im_res / g_ema.renderer.out_im_res surface_sample_focals = sample_focals * scale for j in range(0, num_viewdirs, 
surface_chunk): surface_out = surface_g_ema([sample_z[j:j+surface_chunk]], sample_cam_extrinsics[j:j+surface_chunk], surface_sample_focals[j:j+surface_chunk], sample_near[j:j+surface_chunk], sample_far[j:j+surface_chunk], truncation=opt.truncation_ratio, truncation_latent=surface_mean_latent, return_sdf=True, return_xyz=True) xyz = surface_out[2].cpu() sdf = surface_out[3].cpu() # this is done to fit to RTX2080 RAM size (11GB) del surface_out torch.cuda.empty_cache() # mesh extractions are done one at a time for k in range(surface_chunk): curr_locations = sample_locations[j:j+surface_chunk] loc_str = '_azim{}_elev{}'.format(int(curr_locations[k,0] * 180 / np.pi), int(curr_locations[k,1] * 180 / np.pi)) # Save depth outputs as meshes depth_mesh_filename = os.path.join(opt.results_dst_dir,'depth_map_meshes','sample_{}_depth_mesh{}.obj'.format(i, loc_str)) depth_mesh = xyz2mesh(xyz[k:k+surface_chunk]) if depth_mesh != None: with open(depth_mesh_filename, 'w') as f: depth_mesh.export(f,file_type='obj') # extract full geometry with marching cubes if j == 0: try: frostum_aligned_sdf = align_volume(sdf) marching_cubes_mesh = extract_mesh_with_marching_cubes(frostum_aligned_sdf[k:k+surface_chunk]) except ValueError: marching_cubes_mesh = None print('Marching cubes extraction failed.') print('Please check whether the SDF values are all larger (or all smaller) than 0.') if marching_cubes_mesh != None: marching_cubes_mesh_filename = os.path.join(opt.results_dst_dir,'marching_cubes_meshes','sample_{}_marching_cubes_mesh{}.obj'.format(i, loc_str)) with open(marching_cubes_mesh_filename, 'w') as f: marching_cubes_mesh.export(f,file_type='obj') if __name__ == "__main__": device = "cuda" opt = BaseOptions().parse() opt.model.is_test = True opt.model.freeze_renderer = False opt.rendering.offset_sampling = True opt.rendering.static_viewdirs = True opt.rendering.force_background = True opt.rendering.perturb = 0 opt.inference.size = opt.model.size opt.inference.camera = opt.camera 
opt.inference.renderer_output_size = opt.model.renderer_spatial_output_dim opt.inference.style_dim = opt.model.style_dim opt.inference.project_noise = opt.model.project_noise opt.inference.return_xyz = opt.rendering.return_xyz # find checkpoint directory # check if there's a fully trained model checkpoints_dir = 'full_models' checkpoint_path = os.path.join(checkpoints_dir, opt.experiment.expname + '.pt') if os.path.isfile(checkpoint_path): # define results directory name result_model_dir = 'final_model' else: checkpoints_dir = os.path.join('checkpoint', opt.experiment.expname, 'full_pipeline') checkpoint_path = os.path.join(checkpoints_dir, 'models_{}.pt'.format(opt.experiment.ckpt.zfill(7))) # define results directory name result_model_dir = 'iter_{}'.format(opt.experiment.ckpt.zfill(7)) # create results directory results_dir_basename = os.path.join(opt.inference.results_dir, opt.experiment.expname) opt.inference.results_dst_dir = os.path.join(results_dir_basename, result_model_dir) if opt.inference.fixed_camera_angles: opt.inference.results_dst_dir = os.path.join(opt.inference.results_dst_dir, 'fixed_angles') else: opt.inference.results_dst_dir = os.path.join(opt.inference.results_dst_dir, 'random_angles') os.makedirs(opt.inference.results_dst_dir, exist_ok=True) os.makedirs(os.path.join(opt.inference.results_dst_dir, 'images'), exist_ok=True) if not opt.inference.no_surface_renderings: os.makedirs(os.path.join(opt.inference.results_dst_dir, 'depth_map_meshes'), exist_ok=True) os.makedirs(os.path.join(opt.inference.results_dst_dir, 'marching_cubes_meshes'), exist_ok=True) # load saved model checkpoint = torch.load(checkpoint_path) # load image generation model g_ema = Generator(opt.model, opt.rendering).to(device) pretrained_weights_dict = checkpoint["g_ema"] model_dict = g_ema.state_dict() for k, v in pretrained_weights_dict.items(): if v.size() == model_dict[k].size(): model_dict[k] = v g_ema.load_state_dict(model_dict) # load a second volume renderer that 
extracts surfaces at 128x128x128 (or higher) for better surface resolution if not opt.inference.no_surface_renderings: opt['surf_extraction'] = Munch() opt.surf_extraction.rendering = opt.rendering opt.surf_extraction.model = opt.model.copy() opt.surf_extraction.model.renderer_spatial_output_dim = 128 opt.surf_extraction.rendering.N_samples = opt.surf_extraction.model.renderer_spatial_output_dim opt.surf_extraction.rendering.return_xyz = True opt.surf_extraction.rendering.return_sdf = True surface_g_ema = Generator(opt.surf_extraction.model, opt.surf_extraction.rendering, full_pipeline=False).to(device) # Load weights to surface extractor surface_extractor_dict = surface_g_ema.state_dict() for k, v in pretrained_weights_dict.items(): if k in surface_extractor_dict.keys() and v.size() == surface_extractor_dict[k].size(): surface_extractor_dict[k] = v surface_g_ema.load_state_dict(surface_extractor_dict) else: surface_g_ema = None # get the mean latent vector for g_ema (only needed when truncation is applied) if opt.inference.truncation_ratio < 1: with torch.no_grad(): mean_latent = g_ema.mean_latent(opt.inference.truncation_mean, device) else: mean_latent = None # get the mean latent vector for surface_g_ema if not opt.inference.no_surface_renderings and mean_latent is not None: surface_mean_latent = mean_latent[0] else: surface_mean_latent = None generate(opt.inference, g_ema, surface_g_ema, device, mean_latent, surface_mean_latent) ================================================ FILE: losses.py ================================================ import math import torch from torch import autograd from torch.nn import functional as F def viewpoints_loss(viewpoint_pred, viewpoint_target): loss = F.smooth_l1_loss(viewpoint_pred, viewpoint_target) return loss def eikonal_loss(eikonal_term, sdf=None, beta=100): if eikonal_term is None: eikonal_loss = 0 else: eikonal_loss = ((eikonal_term.norm(dim=-1) - 1) ** 2).mean() if sdf is None: minimal_surface_loss = torch.tensor(0.0, device=eikonal_term.device) else: minimal_surface_loss =
torch.exp(-beta * torch.abs(sdf)).mean() return eikonal_loss, minimal_surface_loss def d_logistic_loss(real_pred, fake_pred): real_loss = F.softplus(-real_pred) fake_loss = F.softplus(fake_pred) return real_loss.mean() + fake_loss.mean() def d_r1_loss(real_pred, real_img): grad_real, = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True) grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean() return grad_penalty def g_nonsaturating_loss(fake_pred): loss = F.softplus(-fake_pred).mean() return loss def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01): noise = torch.randn_like(fake_img) / math.sqrt( fake_img.shape[2] * fake_img.shape[3] ) grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True, only_inputs=True)[0] path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1)) path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length) path_penalty = (path_lengths - path_mean).pow(2).mean() return path_penalty, path_mean.detach(), path_lengths ================================================ FILE: model.py ================================================ import math import random import trimesh import torch import numpy as np from torch import nn from torch.nn import functional as F from volume_renderer import VolumeFeatureRenderer from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d from pdb import set_trace as st from utils import ( create_cameras, create_mesh_renderer, add_textures, create_depth_mesh_renderer, ) from pytorch3d.renderer import TexturesUV from pytorch3d.structures import Meshes from pytorch3d.renderer import look_at_view_transform, FoVPerspectiveCameras from pytorch3d.transforms import matrix_to_euler_angles class PixelNorm(nn.Module): def __init__(self): super().__init__() def forward(self, input): return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) class MappingLinear(nn.Module): def __init__(self, in_dim, 
out_dim, bias=True, activation=None, is_last=False): super().__init__() if is_last: weight_std = 0.25 else: weight_std = 1 self.weight = nn.Parameter(weight_std * nn.init.kaiming_normal_(torch.empty(out_dim, in_dim), a=0.2, mode='fan_in', nonlinearity='leaky_relu')) if bias: self.bias = nn.Parameter(nn.init.uniform_(torch.empty(out_dim), a=-np.sqrt(1/in_dim), b=np.sqrt(1/in_dim))) else: self.bias = None self.activation = activation def forward(self, input): if self.activation != None: out = F.linear(input, self.weight) out = fused_leaky_relu(out, self.bias, scale=1) else: out = F.linear(input, self.weight, bias=self.bias) return out def __repr__(self): return ( f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})" ) def make_kernel(k): k = torch.tensor(k, dtype=torch.float32) if k.ndim == 1: k = k[None, :] * k[:, None] k /= k.sum() return k class Upsample(nn.Module): def __init__(self, kernel, factor=2): super().__init__() self.factor = factor kernel = make_kernel(kernel) * (factor ** 2) self.register_buffer("kernel", kernel) p = kernel.shape[0] - factor pad0 = (p + 1) // 2 + factor - 1 pad1 = p // 2 self.pad = (pad0, pad1) def forward(self, input): out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) return out class Downsample(nn.Module): def __init__(self, kernel, factor=2): super().__init__() self.factor = factor kernel = make_kernel(kernel) self.register_buffer("kernel", kernel) p = kernel.shape[0] - factor pad0 = (p + 1) // 2 pad1 = p // 2 self.pad = (pad0, pad1) def forward(self, input): out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) return out class Blur(nn.Module): def __init__(self, kernel, pad, upsample_factor=1): super().__init__() kernel = make_kernel(kernel) if upsample_factor > 1: kernel = kernel * (upsample_factor ** 2) self.register_buffer("kernel", kernel) self.pad = pad def forward(self, input): out = upfirdn2d(input, self.kernel, pad=self.pad) return out class 
EqualConv2d(nn.Module): def __init__( self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True ): super().__init__() self.weight = nn.Parameter( torch.randn(out_channel, in_channel, kernel_size, kernel_size) ) self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) self.stride = stride self.padding = padding if bias: self.bias = nn.Parameter(torch.zeros(out_channel)) else: self.bias = None def forward(self, input): out = F.conv2d( input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding, ) return out def __repr__(self): return ( f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]}," f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})" ) class EqualLinear(nn.Module): def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None): super().__init__() self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) if bias: self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) else: self.bias = None self.activation = activation self.scale = (1 / math.sqrt(in_dim)) * lr_mul self.lr_mul = lr_mul def forward(self, input): if self.activation: out = F.linear(input, self.weight * self.scale) out = fused_leaky_relu(out, self.bias * self.lr_mul) else: out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul) return out def __repr__(self): return ( f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})" ) class ModulatedConv2d(nn.Module): def __init__(self, in_channel, out_channel, kernel_size, style_dim, demodulate=True, upsample=False, downsample=False, blur_kernel=[1, 3, 3, 1]): super().__init__() self.eps = 1e-8 self.kernel_size = kernel_size self.in_channel = in_channel self.out_channel = out_channel self.upsample = upsample self.downsample = downsample if upsample: factor = 2 p = (len(blur_kernel) - factor) - (kernel_size - 1) pad0 = (p + 1) // 2 + factor - 1 pad1 = p // 2 + 1 self.blur = 
Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) if downsample: factor = 2 p = (len(blur_kernel) - factor) + (kernel_size - 1) pad0 = (p + 1) // 2 pad1 = p // 2 self.blur = Blur(blur_kernel, pad=(pad0, pad1)) fan_in = in_channel * kernel_size ** 2 self.scale = 1 / math.sqrt(fan_in) self.padding = kernel_size // 2 self.weight = nn.Parameter( torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) ) self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) self.demodulate = demodulate def __repr__(self): return ( f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, " f"upsample={self.upsample}, downsample={self.downsample})" ) def forward(self, input, style): batch, in_channel, height, width = input.shape style = self.modulation(style).view(batch, 1, in_channel, 1, 1) weight = self.scale * self.weight * style if self.demodulate: demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) weight = weight.view( batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size ) if self.upsample: input = input.view(1, batch * in_channel, height, width) weight = weight.view( batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size ) weight = weight.transpose(1, 2).reshape( batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size ) out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) _, _, height, width = out.shape out = out.view(batch, self.out_channel, height, width) out = self.blur(out) elif self.downsample: input = self.blur(input) _, _, height, width = input.shape input = input.view(1, batch * in_channel, height, width) out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) _, _, height, width = out.shape out = out.view(batch, self.out_channel, height, width) else: input = input.view(1, batch * in_channel, height, width) out = F.conv2d(input, weight, padding=self.padding, 
groups=batch) _, _, height, width = out.shape out = out.view(batch, self.out_channel, height, width) return out class NoiseInjection(nn.Module): def __init__(self, project=False): super().__init__() self.project = project self.weight = nn.Parameter(torch.zeros(1)) self.prev_noise = None self.mesh_fn = None self.vert_noise = None def create_pytorch_mesh(self, trimesh): v=trimesh.vertices; f=trimesh.faces verts = torch.from_numpy(np.asarray(v)).to(torch.float32).cuda() mesh_pytorch = Meshes( verts=[verts], faces = [torch.from_numpy(np.asarray(f)).to(torch.float32).cuda()], textures=None ) if self.vert_noise is None or verts.shape[0] != self.vert_noise.shape[1]: self.vert_noise = torch.ones_like(verts)[:,0:1].cpu().normal_().expand(-1,3).unsqueeze(0) mesh_pytorch = add_textures(meshes=mesh_pytorch, vertex_colors=self.vert_noise.to(verts.device)) return mesh_pytorch def load_mc_mesh(self, filename, resolution=128, im_res=64): import trimesh mc_tri=trimesh.load_mesh(filename) v=mc_tri.vertices; f=mc_tri.faces mesh2=trimesh.base.Trimesh(vertices=v, faces=f) if im_res==64 or im_res==128: pytorch3d_mesh = self.create_pytorch_mesh(mesh2) return pytorch3d_mesh v,f = trimesh.remesh.subdivide(v,f) mesh2_subdiv = trimesh.base.Trimesh(vertices=v, faces=f) if im_res==256: pytorch3d_mesh = self.create_pytorch_mesh(mesh2_subdiv); return pytorch3d_mesh v,f = trimesh.remesh.subdivide(mesh2_subdiv.vertices,mesh2_subdiv.faces) mesh3_subdiv = trimesh.base.Trimesh(vertices=v, faces=f) if im_res==512: pytorch3d_mesh = self.create_pytorch_mesh(mesh3_subdiv); return pytorch3d_mesh v,f = trimesh.remesh.subdivide(mesh3_subdiv.vertices,mesh3_subdiv.faces) mesh4_subdiv = trimesh.base.Trimesh(vertices=v, faces=f) pytorch3d_mesh = self.create_pytorch_mesh(mesh4_subdiv) return pytorch3d_mesh def project_noise(self, noise, transform, mesh_path=None): batch, _, height, width = noise.shape assert(batch == 1) # assuming during inference batch size is 1 angles =
matrix_to_euler_angles(transform[0:1,:,:3], "ZYX") azim = float(angles[0][1]) elev = float(-angles[0][2]) cameras = create_cameras(azim=azim*180/np.pi,elev=elev*180/np.pi,fov=12.,dist=1) renderer = create_depth_mesh_renderer(cameras, image_size=height, specular_color=((0,0,0),), ambient_color=((1.,1.,1.),),diffuse_color=((0,0,0),)) if self.mesh_fn is None or self.mesh_fn != mesh_path: self.mesh_fn = mesh_path pytorch3d_mesh = self.load_mc_mesh(mesh_path, im_res=height) rgb, depth = renderer(pytorch3d_mesh) depth_max = depth.max(-1)[0].view(-1) # (NxN) depth_valid = depth_max > 0. if self.prev_noise is None: self.prev_noise = noise noise_copy = self.prev_noise.clone() noise_copy.view(-1)[depth_valid] = rgb[0,:,:,0].view(-1)[depth_valid] noise_copy = noise_copy.reshape(1,1,height,height) # 1x1xNxN return noise_copy def forward(self, image, noise=None, transform=None, mesh_path=None): batch, _, height, width = image.shape if noise is None: noise = image.new_empty(batch, 1, height, width).normal_() elif self.project: noise = self.project_noise(noise, transform, mesh_path=mesh_path) return image + self.weight * noise class StyledConv(nn.Module): def __init__(self, in_channel, out_channel, kernel_size, style_dim, upsample=False, blur_kernel=[1, 3, 3, 1], project_noise=False): super().__init__() self.conv = ModulatedConv2d( in_channel, out_channel, kernel_size, style_dim, upsample=upsample, blur_kernel=blur_kernel, ) self.noise = NoiseInjection(project=project_noise) self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) self.activate = FusedLeakyReLU(out_channel) def forward(self, input, style, noise=None, transform=None, mesh_path=None): out = self.conv(input, style) out = self.noise(out, noise=noise, transform=transform, mesh_path=mesh_path) out = self.activate(out) return out class ToRGB(nn.Module): def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): super().__init__() self.upsample = upsample out_channels = 3 if upsample: 
self.upsample = Upsample(blur_kernel) self.conv = ModulatedConv2d(in_channel, out_channels, 1, style_dim, demodulate=False) self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1)) def forward(self, input, style, skip=None): out = self.conv(input, style) out = out + self.bias if skip is not None: if self.upsample: skip = self.upsample(skip) out = out + skip return out class ConvLayer(nn.Sequential): def __init__(self, in_channel, out_channel, kernel_size, downsample=False, blur_kernel=[1, 3, 3, 1], bias=True, activate=True): layers = [] if downsample: factor = 2 p = (len(blur_kernel) - factor) + (kernel_size - 1) pad0 = (p + 1) // 2 pad1 = p // 2 layers.append(Blur(blur_kernel, pad=(pad0, pad1))) stride = 2 self.padding = 0 else: stride = 1 self.padding = kernel_size // 2 layers.append( EqualConv2d( in_channel, out_channel, kernel_size, padding=self.padding, stride=stride, bias=bias and not activate, ) ) if activate: layers.append(FusedLeakyReLU(out_channel, bias=bias)) super().__init__(*layers) class Decoder(nn.Module): def __init__(self, model_opt, blur_kernel=[1, 3, 3, 1]): super().__init__() # decoder mapping network self.size = model_opt.size self.style_dim = model_opt.style_dim * 2 thumb_im_size = model_opt.renderer_spatial_output_dim layers = [PixelNorm(), EqualLinear( self.style_dim // 2, self.style_dim, lr_mul=model_opt.lr_mapping, activation="fused_lrelu" )] for i in range(4): layers.append( EqualLinear( self.style_dim, self.style_dim, lr_mul=model_opt.lr_mapping, activation="fused_lrelu" ) ) self.style = nn.Sequential(*layers) # decoder network self.channels = { 4: 512, 8: 512, 16: 512, 32: 512, 64: 256 * model_opt.channel_multiplier, 128: 128 * model_opt.channel_multiplier, 256: 64 * model_opt.channel_multiplier, 512: 32 * model_opt.channel_multiplier, 1024: 16 * model_opt.channel_multiplier, } decoder_in_size = model_opt.renderer_spatial_output_dim # image decoder self.log_size = int(math.log(self.size, 2)) self.log_in_size = 
int(math.log(decoder_in_size, 2)) self.conv1 = StyledConv( model_opt.feature_encoder_in_channels, self.channels[decoder_in_size], 3, self.style_dim, blur_kernel=blur_kernel, project_noise=model_opt.project_noise) self.to_rgb1 = ToRGB(self.channels[decoder_in_size], self.style_dim, upsample=False) self.num_layers = (self.log_size - self.log_in_size) * 2 + 1 self.convs = nn.ModuleList() self.upsamples = nn.ModuleList() self.to_rgbs = nn.ModuleList() self.noises = nn.Module() in_channel = self.channels[decoder_in_size] for layer_idx in range(self.num_layers): res = (layer_idx + 2 * self.log_in_size + 1) // 2 shape = [1, 1, 2 ** (res), 2 ** (res)] self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape)) for i in range(self.log_in_size+1, self.log_size + 1): out_channel = self.channels[2 ** i] self.convs.append( StyledConv(in_channel, out_channel, 3, self.style_dim, upsample=True, blur_kernel=blur_kernel, project_noise=model_opt.project_noise) ) self.convs.append( StyledConv(out_channel, out_channel, 3, self.style_dim, blur_kernel=blur_kernel, project_noise=model_opt.project_noise) ) self.to_rgbs.append(ToRGB(out_channel, self.style_dim)) in_channel = out_channel self.n_latent = (self.log_size - self.log_in_size) * 2 + 2 def mean_latent(self, renderer_latent): latent = self.style(renderer_latent).mean(0, keepdim=True) return latent def get_latent(self, input): return self.style(input) def styles_and_noise_forward(self, styles, noise, inject_index=None, truncation=1, truncation_latent=None, input_is_latent=False, randomize_noise=True): if not input_is_latent: styles = [self.style(s) for s in styles] if noise is None: if randomize_noise: noise = [None] * self.num_layers else: noise = [ getattr(self.noises, f"noise_{i}") for i in range(self.num_layers) ] if (truncation < 1): style_t = [] for style in styles: style_t.append( truncation_latent[1] + truncation * (style - truncation_latent[1]) ) styles = style_t if len(styles) < 2: inject_index = self.n_latent if 
styles[0].ndim < 3: latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) else: latent = styles[0] else: if inject_index is None: inject_index = random.randint(1, self.n_latent - 1) latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) latent = torch.cat([latent, latent2], 1) return latent, noise def forward(self, features, styles, rgbd_in=None, transform=None, return_latents=False, inject_index=None, truncation=1, truncation_latent=None, input_is_latent=False, noise=None, randomize_noise=True, mesh_path=None): latent, noise = self.styles_and_noise_forward(styles, noise, inject_index, truncation, truncation_latent, input_is_latent, randomize_noise) out = self.conv1(features, latent[:, 0], noise=noise[0], transform=transform, mesh_path=mesh_path) skip = self.to_rgb1(out, latent[:, 1], skip=rgbd_in) i = 1 for conv1, conv2, noise1, noise2, to_rgb in zip( self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs ): out = conv1(out, latent[:, i], noise=noise1, transform=transform, mesh_path=mesh_path) out = conv2(out, latent[:, i + 1], noise=noise2, transform=transform, mesh_path=mesh_path) skip = to_rgb(out, latent[:, i + 2], skip=skip) i += 2 out_latent = latent if return_latents else None image = skip return image, out_latent class Generator(nn.Module): def __init__(self, model_opt, renderer_opt, blur_kernel=[1, 3, 3, 1], ema=False, full_pipeline=True): super().__init__() self.size = model_opt.size self.style_dim = model_opt.style_dim self.num_layers = 1 self.train_renderer = not model_opt.freeze_renderer self.full_pipeline = full_pipeline model_opt.feature_encoder_in_channels = renderer_opt.width if ema or 'is_test' in model_opt.keys(): self.is_train = False else: self.is_train = True # volume renderer mapping_network layers = [] for i in range(3): layers.append( MappingLinear(self.style_dim, self.style_dim, activation="fused_lrelu") ) self.style = 
nn.Sequential(*layers) # volume renderer thumb_im_size = model_opt.renderer_spatial_output_dim self.renderer = VolumeFeatureRenderer(renderer_opt, style_dim=self.style_dim, out_im_res=thumb_im_size) if self.full_pipeline: self.decoder = Decoder(model_opt) def make_noise(self): device = self.input.input.device noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)] for i in range(3, self.log_size + 1): for _ in range(2): noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device)) return noises def mean_latent(self, n_latent, device): latent_in = torch.randn(n_latent, self.style_dim, device=device) renderer_latent = self.style(latent_in) renderer_latent_mean = renderer_latent.mean(0, keepdim=True) if self.full_pipeline: decoder_latent_mean = self.decoder.mean_latent(renderer_latent) else: decoder_latent_mean = None return [renderer_latent_mean, decoder_latent_mean] def get_latent(self, input): return self.style(input) def styles_and_noise_forward(self, styles, inject_index=None, truncation=1, truncation_latent=None, input_is_latent=False): if not input_is_latent: styles = [self.style(s) for s in styles] if truncation < 1: style_t = [] for style in styles: style_t.append( truncation_latent[0] + truncation * (style - truncation_latent[0]) ) styles = style_t return styles def init_forward(self, styles, cam_poses, focals, near=0.88, far=1.12): latent = self.styles_and_noise_forward(styles) sdf, target_values = self.renderer.mlp_init_pass(cam_poses, focals, near, far, styles=latent[0]) return sdf, target_values def forward(self, styles, cam_poses, focals, near=0.88, far=1.12, return_latents=False, inject_index=None, truncation=1, truncation_latent=None, input_is_latent=False, noise=None, randomize_noise=True, return_sdf=False, return_xyz=False, return_eikonal=False, project_noise=False, mesh_path=None): # do not calculate renderer gradients if renderer weights are frozen with torch.set_grad_enabled(self.is_train and self.train_renderer): latent = 
self.styles_and_noise_forward(styles, inject_index, truncation, truncation_latent, input_is_latent) thumb_rgb, features, sdf, mask, xyz, eikonal_term = self.renderer(cam_poses, focals, near, far, styles=latent[0], return_eikonal=return_eikonal) if self.full_pipeline: rgb, decoder_latent = self.decoder(features, latent, transform=cam_poses if project_noise else None, return_latents=return_latents, inject_index=inject_index, truncation=truncation, truncation_latent=truncation_latent, noise=noise, input_is_latent=input_is_latent, randomize_noise=randomize_noise, mesh_path=mesh_path) else: rgb = None if return_latents: return rgb, decoder_latent else: out = (rgb, thumb_rgb) if return_xyz: out += (xyz,) if return_sdf: out += (sdf,) if return_eikonal: out += (eikonal_term,) if return_xyz: out += (mask,) return out ############# Volume Renderer Building Blocks & Discriminator ################## class VolumeRenderDiscConv2d(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, activate=False): super(VolumeRenderDiscConv2d, self).__init__() self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias and not activate) self.activate = activate if self.activate: self.activation = FusedLeakyReLU(out_channels, bias=bias, scale=1) bias_init_coef = np.sqrt(1 / (in_channels * kernel_size * kernel_size)) nn.init.uniform_(self.activation.bias, a=-bias_init_coef, b=bias_init_coef) def forward(self, input): """ input_tensor_shape: (N, C_in,H,W) output_tensor_shape: (N,C_out,H_out,W_out) :return: Conv2d + activation Result """ out = self.conv(input) if self.activate: out = self.activation(out) return out class AddCoords(nn.Module): def __init__(self): super(AddCoords, self).__init__() def forward(self, input_tensor): """ :param input_tensor: shape (N, C_in, H, W) :return: """ batch_size_shape, channel_in_shape, dim_y, dim_x = input_tensor.shape xx_channel = torch.arange(dim_x, dtype=torch.float32, 
device=input_tensor.device).repeat(1,1,dim_y,1) yy_channel = torch.arange(dim_y, dtype=torch.float32, device=input_tensor.device).repeat(1,1,dim_x,1).transpose(2,3) xx_channel = xx_channel / (dim_x - 1) yy_channel = yy_channel / (dim_y - 1) xx_channel = xx_channel * 2 - 1 yy_channel = yy_channel * 2 - 1 xx_channel = xx_channel.repeat(batch_size_shape, 1, 1, 1) yy_channel = yy_channel.repeat(batch_size_shape, 1, 1, 1) out = torch.cat([input_tensor, yy_channel, xx_channel], dim=1) return out class CoordConv2d(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True): super(CoordConv2d, self).__init__() self.addcoords = AddCoords() self.conv = nn.Conv2d(in_channels + 2, out_channels, kernel_size, stride=stride, padding=padding, bias=bias) def forward(self, input_tensor): """ input_tensor_shape: (N, C_in,H,W) output_tensor_shape: (N,C_out,H_out,W_out) :return: CoordConv2d Result """ out = self.addcoords(input_tensor) out = self.conv(out) return out class CoordConvLayer(nn.Module): def __init__(self, in_channel, out_channel, kernel_size, bias=True, activate=True): super(CoordConvLayer, self).__init__() layers = [] stride = 1 self.activate = activate self.padding = kernel_size // 2 if kernel_size > 2 else 0 self.conv = CoordConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride, bias=bias and not activate) if activate: self.activation = FusedLeakyReLU(out_channel, bias=bias, scale=1) bias_init_coef = np.sqrt(1 / (in_channel * kernel_size * kernel_size)) nn.init.uniform_(self.activation.bias, a=-bias_init_coef, b=bias_init_coef) def forward(self, input): out = self.conv(input) if self.activate: out = self.activation(out) return out class VolumeRenderResBlock(nn.Module): def __init__(self, in_channel, out_channel): super().__init__() self.conv1 = CoordConvLayer(in_channel, out_channel, 3) self.conv2 = CoordConvLayer(out_channel, out_channel, 3) self.pooling = nn.AvgPool2d(2) self.downsample =
nn.AvgPool2d(2) if out_channel != in_channel: self.skip = VolumeRenderDiscConv2d(in_channel, out_channel, 1) else: self.skip = None def forward(self, input): out = self.conv1(input) out = self.conv2(out) out = self.pooling(out) downsample_in = self.downsample(input) if self.skip is not None: skip_in = self.skip(downsample_in) else: skip_in = downsample_in out = (out + skip_in) / math.sqrt(2) return out class VolumeRenderDiscriminator(nn.Module): def __init__(self, opt): super().__init__() init_size = opt.renderer_spatial_output_dim self.viewpoint_loss = not opt.no_viewpoint_loss final_out_channel = 3 if self.viewpoint_loss else 1 channels = { 2: 400, 4: 400, 8: 400, 16: 400, 32: 256, 64: 128, 128: 64, } convs = [VolumeRenderDiscConv2d(3, channels[init_size], 1, activate=True)] log_size = int(math.log(init_size, 2)) in_channel = channels[init_size] for i in range(log_size-1, 0, -1): out_channel = channels[2 ** i] convs.append(VolumeRenderResBlock(in_channel, out_channel)) in_channel = out_channel self.convs = nn.Sequential(*convs) self.final_conv = VolumeRenderDiscConv2d(in_channel, final_out_channel, 2) def forward(self, input): out = self.convs(input) out = self.final_conv(out) gan_preds = out[:,0:1] gan_preds = gan_preds.view(-1, 1) if self.viewpoint_loss: viewpoints_preds = out[:,1:] viewpoints_preds = viewpoints_preds.view(-1,2) else: viewpoints_preds = None return gan_preds, viewpoints_preds ######################### StyleGAN Discriminator ######################## class ResBlock(nn.Module): def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1], merge=False): super().__init__() self.conv1 = ConvLayer(2 * in_channel if merge else in_channel, in_channel, 3) self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) self.skip = ConvLayer(2 * in_channel if merge else in_channel, out_channel, 1, downsample=True, activate=False, bias=False) def forward(self, input): out = self.conv1(input) out = self.conv2(out) out = (out + 
self.skip(input)) /
math.sqrt(2) return out class Discriminator(nn.Module): def __init__(self, opt, blur_kernel=[1, 3, 3, 1]): super().__init__() init_size = opt.size channels = { 4: 512, 8: 512, 16: 512, 32: 512, 64: 256 * opt.channel_multiplier, 128: 128 * opt.channel_multiplier, 256: 64 * opt.channel_multiplier, 512: 32 * opt.channel_multiplier, 1024: 16 * opt.channel_multiplier, } convs = [ConvLayer(3, channels[init_size], 1)] log_size = int(math.log(init_size, 2)) in_channel = channels[init_size] for i in range(log_size, 2, -1): out_channel = channels[2 ** (i - 1)] convs.append(ResBlock(in_channel, out_channel, blur_kernel)) in_channel = out_channel self.convs = nn.Sequential(*convs) self.stddev_group = 4 self.stddev_feat = 1 # minibatch discrimination in_channel += 1 self.final_conv = ConvLayer(in_channel, channels[4], 3) self.final_linear = nn.Sequential( EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"), EqualLinear(channels[4], 1), ) def forward(self, input): out = self.convs(input) # minibatch discrimination batch, channel, height, width = out.shape group = min(batch, self.stddev_group) if batch % group != 0: group = 3 if batch % 3 == 0 else 2 stddev = out.view( group, -1, self.stddev_feat, channel // self.stddev_feat, height, width ) stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) stddev = stddev.repeat(group, 1, height, width) final_out = torch.cat([out, stddev], 1) # final layers final_out = self.final_conv(final_out) final_out = final_out.view(batch, -1) final_out = self.final_linear(final_out) gan_preds = final_out[:,:1] return gan_preds ================================================ FILE: op/__init__.py ================================================ from .fused_act import FusedLeakyReLU, fused_leaky_relu from .upfirdn2d import upfirdn2d ================================================ FILE: op/fused_act.py ================================================ import os import torch 
from torch import nn from torch.nn import functional as F from torch.autograd import Function from torch.utils.cpp_extension import load module_path = os.path.dirname(__file__) fused = load( "fused", sources=[ os.path.join(module_path, "fused_bias_act.cpp"), os.path.join(module_path, "fused_bias_act_kernel.cu"), ], ) class FusedLeakyReLUFunctionBackward(Function): @staticmethod def forward(ctx, grad_output, out, bias, negative_slope, scale): ctx.save_for_backward(out) ctx.negative_slope = negative_slope ctx.scale = scale empty = grad_output.new_empty(0) grad_input = fused.fused_bias_act( grad_output, empty, out, 3, 1, negative_slope, scale ) dim = [0] if grad_input.ndim > 2: dim += list(range(2, grad_input.ndim)) if bias: grad_bias = grad_input.sum(dim).detach() else: grad_bias = empty return grad_input, grad_bias @staticmethod def backward(ctx, gradgrad_input, gradgrad_bias): out, = ctx.saved_tensors gradgrad_out = fused.fused_bias_act( gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale ) return gradgrad_out, None, None, None, None class FusedLeakyReLUFunction(Function): @staticmethod def forward(ctx, input, bias, negative_slope, scale): empty = input.new_empty(0) ctx.bias = bias is not None if bias is None: bias = empty out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) ctx.save_for_backward(out) ctx.negative_slope = negative_slope ctx.scale = scale return out @staticmethod def backward(ctx, grad_output): out, = ctx.saved_tensors grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale ) if not ctx.bias: grad_bias = None return grad_input, grad_bias, None, None class FusedLeakyReLU(nn.Module): def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5): super().__init__() if bias: self.bias = nn.Parameter(torch.zeros(channel)) else: self.bias = None self.negative_slope = negative_slope self.scale = scale def forward(self, input): return 
fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5): if input.device.type == "cpu": if bias is not None: rest_dim = [1] * (input.ndim - bias.ndim - 1) return ( F.leaky_relu( input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope ) * scale ) else: return F.leaky_relu(input, negative_slope=negative_slope) * scale else: return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) ================================================ FILE: op/fused_bias_act.cpp ================================================ #include <torch/extension.h> torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, int act, int grad, float alpha, float scale); #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, int act, int grad, float alpha, float scale) { CHECK_CUDA(input); CHECK_CUDA(bias); return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); } ================================================ FILE: op/fused_bias_act_kernel.cu ================================================ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. // // This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit // https://nvlabs.github.io/stylegan2/license.html #include <torch/types.h> #include <ATen/ATen.h> #include <ATen/AccumulateType.h> #include <ATen/cuda/CUDAApplyUtils.cuh> #include <ATen/cuda/CUDAContext.h> #include <cuda.h> #include <cuda_runtime.h> template <typename scalar_t> static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; scalar_t zero = 0.0; for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { scalar_t x = p_x[xi]; if (use_bias) { x += p_b[(xi / step_b) % size_b]; } scalar_t ref = use_ref ? p_ref[xi] : zero; scalar_t y; switch (act * 10 + grad) { default: case 10: y = x; break; case 11: y = x; break; case 12: y = 0.0; break; case 30: y = (x > 0.0) ? x : x * alpha; break; case 31: y = (ref > 0.0) ? x : x * alpha; break; case 32: y = 0.0; break; } out[xi] = y * scale; } } torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, int act, int grad, float alpha, float scale) { int curDevice = -1; cudaGetDevice(&curDevice); cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); auto x = input.contiguous(); auto b = bias.contiguous(); auto ref = refer.contiguous(); int use_bias = b.numel() ? 1 : 0; int use_ref = ref.numel() ?
1 : 0; int size_x = x.numel(); int size_b = b.numel(); int step_b = 1; for (int i = 1 + 1; i < x.dim(); i++) { step_b *= x.size(i); } int loop_x = 4; int block_size = 4 * 32; int grid_size = (size_x - 1) / (loop_x * block_size) + 1; auto y = torch::empty_like(x); AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>( y.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), b.data_ptr<scalar_t>(), ref.data_ptr<scalar_t>(), act, grad, alpha, scale, loop_x, size_x, step_b, size_b, use_bias, use_ref ); }); return y; } ================================================ FILE: op/upfirdn2d.cpp ================================================ #include <torch/extension.h> torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, int up_x, int up_y, int down_x, int down_y, int pad_x0, int pad_x1, int pad_y0, int pad_y1); #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, int up_x, int up_y, int down_x, int down_y, int pad_x0, int pad_x1, int pad_y0, int pad_y1) { CHECK_CUDA(input); CHECK_CUDA(kernel); return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); } ================================================ FILE: op/upfirdn2d.py ================================================ import os import torch from torch.nn import functional as F from torch.autograd import Function from torch.utils.cpp_extension import load from pdb import set_trace as st module_path = os.path.dirname(__file__) upfirdn2d_op = load( "upfirdn2d", sources=[ os.path.join(module_path, "upfirdn2d.cpp"), 
os.path.join(module_path, "upfirdn2d_kernel.cu"), ], ) class UpFirDn2dBackward(Function): @staticmethod def
forward( ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size ): up_x, up_y = up down_x, down_y = down g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) grad_input = upfirdn2d_op.upfirdn2d( grad_output, grad_kernel, down_x, down_y, up_x, up_y, g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1, ) grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) ctx.save_for_backward(kernel) pad_x0, pad_x1, pad_y0, pad_y1 = pad ctx.up_x = up_x ctx.up_y = up_y ctx.down_x = down_x ctx.down_y = down_y ctx.pad_x0 = pad_x0 ctx.pad_x1 = pad_x1 ctx.pad_y0 = pad_y0 ctx.pad_y1 = pad_y1 ctx.in_size = in_size ctx.out_size = out_size return grad_input @staticmethod def backward(ctx, gradgrad_input): kernel, = ctx.saved_tensors gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) gradgrad_out = upfirdn2d_op.upfirdn2d( gradgrad_input, kernel, ctx.up_x, ctx.up_y, ctx.down_x, ctx.down_y, ctx.pad_x0, ctx.pad_x1, ctx.pad_y0, ctx.pad_y1, ) # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) gradgrad_out = gradgrad_out.view( ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] ) return gradgrad_out, None, None, None, None, None, None, None, None class UpFirDn2d(Function): @staticmethod def forward(ctx, input, kernel, up, down, pad): up_x, up_y = up down_x, down_y = down pad_x0, pad_x1, pad_y0, pad_y1 = pad kernel_h, kernel_w = kernel.shape batch, channel, in_h, in_w = input.shape ctx.in_size = input.shape input = input.reshape(-1, in_h, in_w, 1) ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 ctx.out_size = (out_h, out_w) ctx.up = (up_x, up_y) ctx.down = (down_x, down_y) ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) g_pad_x0 = kernel_w - pad_x0 - 1 g_pad_y0 = kernel_h - pad_y0 - 
1 g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) out = upfirdn2d_op.upfirdn2d( input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 ) # out = out.view(major, out_h, out_w, minor) out = out.view(-1, channel, out_h, out_w) return out @staticmethod def backward(ctx, grad_output): kernel, grad_kernel = ctx.saved_tensors grad_input = UpFirDn2dBackward.apply( grad_output, kernel, grad_kernel, ctx.up, ctx.down, ctx.pad, ctx.g_pad, ctx.in_size, ctx.out_size, ) return grad_input, None, None, None, None def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): if input.device.type == "cpu": out = upfirdn2d_native( input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] ) else: out = UpFirDn2d.apply( input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) ) return out def upfirdn2d_native( input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 ): _, channel, in_h, in_w = input.shape input = input.reshape(-1, in_h, in_w, 1) _, in_h, in_w, minor = input.shape kernel_h, kernel_w = kernel.shape out = input.view(-1, in_h, 1, in_w, 1, minor) out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) out = out.view(-1, in_h * up_y, in_w * up_x, minor) out = F.pad( out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] ) out = out[ :, max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), :, ] out = out.permute(0, 3, 1, 2) out = out.reshape( [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] ) w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) out = F.conv2d(out, w) out = out.reshape( -1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, ) out = out.permute(0, 2, 3, 1) out = out[:, ::down_y, ::down_x, :] out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // 
down_y + 1 out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 return out.view(-1, channel, out_h, out_w) ================================================ FILE: op/upfirdn2d_kernel.cu ================================================ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. // // This work is made available under the Nvidia Source Code License-NC. // To view a copy of this license, visit // https://nvlabs.github.io/stylegan2/license.html #include <torch/types.h> #include <ATen/ATen.h> #include <ATen/AccumulateType.h> #include <ATen/cuda/CUDAApplyUtils.cuh> #include <ATen/cuda/CUDAContext.h> #include <cuda.h> #include <cuda_runtime.h> static __host__ __device__ __forceinline__ int floor_div(int a, int b) { int c = a / b; if (c * b > a) { c--; } return c; } struct UpFirDn2DKernelParams { int up_x; int up_y; int down_x; int down_y; int pad_x0; int pad_x1; int pad_y0; int pad_y1; int major_dim; int in_h; int in_w; int minor_dim; int kernel_h; int kernel_w; int out_h; int out_w; int loop_major; int loop_x; }; template <typename scalar_t> __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, const scalar_t *kernel, const UpFirDn2DKernelParams p) { int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; int out_y = minor_idx / p.minor_dim; minor_idx -= out_y * p.minor_dim; int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; int major_idx_base = blockIdx.z * p.loop_major; if (out_x_base >= p.out_w || out_y >= p.out_h || major_idx_base >= p.major_dim) { return; } int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major && major_idx < p.major_dim; loop_major++, major_idx++) { for (int loop_x = 0, out_x = out_x_base; loop_x < p.loop_x && out_x < p.out_w; 
loop_x++, out_x += blockDim.y) { int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; const scalar_t *x_p = &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx]; const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; int x_px = p.minor_dim; int k_px = -p.up_x; int x_py = p.in_w * p.minor_dim; int k_py = -p.up_y * p.kernel_w; scalar_t v = 0.0f; for (int y = 0; y < h; y++) { for (int x = 0; x < w; x++) { v += static_cast<float>(*x_p) * static_cast<float>(*k_p); x_p += x_px; k_p += k_px; } x_p += x_py - w * x_px; k_p += k_py - w * k_px; } out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v; } } } template <typename scalar_t, int up_x, int up_y, int down_x, int down_y, int kernel_h, int kernel_w, int tile_out_h, int tile_out_w> __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, const scalar_t *kernel, const UpFirDn2DKernelParams p) { const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; __shared__ volatile float sk[kernel_h][kernel_w]; __shared__ volatile float sx[tile_in_h][tile_in_w]; int minor_idx = blockIdx.x; int tile_out_y = minor_idx / p.minor_dim; minor_idx -= tile_out_y * p.minor_dim; tile_out_y *= tile_out_h; int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; int major_idx_base = blockIdx.z * p.loop_major; if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) { return; } for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) { int ky = tap_idx / kernel_w; int kx = tap_idx - ky * kernel_w; scalar_t v = 0.0; if (kx < p.kernel_w & ky < p.kernel_h) { v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; } sk[ky][kx] = v; } for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) { 
for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x +=
tile_out_w) { int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; int tile_in_x = floor_div(tile_mid_x, up_x); int tile_in_y = floor_div(tile_mid_y, up_y); __syncthreads(); for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) { int rel_in_y = in_idx / tile_in_w; int rel_in_x = in_idx - rel_in_y * tile_in_w; int in_x = rel_in_x + tile_in_x; int in_y = rel_in_y + tile_in_y; scalar_t v = 0.0; if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx]; } sx[rel_in_y][rel_in_x] = v; } __syncthreads(); for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) { int rel_out_y = out_idx / tile_out_w; int rel_out_x = out_idx - rel_out_y * tile_out_w; int out_x = rel_out_x + tile_out_x; int out_y = rel_out_y + tile_out_y; int mid_x = tile_mid_x + rel_out_x * down_x; int mid_y = tile_mid_y + rel_out_y * down_y; int in_x = floor_div(mid_x, up_x); int in_y = floor_div(mid_y, up_y); int rel_in_x = in_x - tile_in_x; int rel_in_y = in_y - tile_in_y; int kernel_x = (in_x + 1) * up_x - mid_x - 1; int kernel_y = (in_y + 1) * up_y - mid_y - 1; scalar_t v = 0.0; #pragma unroll for (int y = 0; y < kernel_h / up_y; y++) #pragma unroll for (int x = 0; x < kernel_w / up_x; x++) v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x]; if (out_x < p.out_w & out_y < p.out_h) { out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v; } } } } } torch::Tensor upfirdn2d_op(const torch::Tensor &input, const torch::Tensor &kernel, int up_x, int up_y, int down_x, int down_y, int pad_x0, int pad_x1, int pad_y0, int pad_y1) { int curDevice = -1; cudaGetDevice(&curDevice); cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); UpFirDn2DKernelParams p; auto x = input.contiguous(); auto k = kernel.contiguous(); p.major_dim = 
x.size(0); p.in_h = x.size(1); p.in_w = x.size(2); p.minor_dim = x.size(3); p.kernel_h = k.size(0); p.kernel_w = k.size(1); p.up_x = up_x; p.up_y = up_y; p.down_x = down_x; p.down_y = down_y; p.pad_x0 = pad_x0; p.pad_x1 = pad_x1; p.pad_y0 = pad_y0; p.pad_y1 = pad_y1; p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y; p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x; auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); int mode = -1; int tile_out_h = -1; int tile_out_w = -1; if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { mode = 1; tile_out_h = 16; tile_out_w = 64; } if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) { mode = 2; tile_out_h = 16; tile_out_w = 64; } if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { mode = 3; tile_out_h = 16; tile_out_w = 64; } if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) { mode = 4; tile_out_h = 16; tile_out_w = 64; } if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) { mode = 5; tile_out_h = 8; tile_out_w = 32; } if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) { mode = 6; tile_out_h = 8; tile_out_w = 32; } dim3 block_size; dim3 grid_size; if (tile_out_h > 0 && tile_out_w > 0) { p.loop_major = (p.major_dim - 1) / 16384 + 1; p.loop_x = 1; block_size = dim3(32 * 8, 1, 1); grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, (p.major_dim - 1) / p.loop_major + 1); } else { p.loop_major = (p.major_dim - 1) / 16384 + 1; p.loop_x = 4; block_size = dim3(4, 32, 1); grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1, (p.out_w - 1) / (p.loop_x * 
block_size.y) + 1, (p.major_dim - 1) / p.loop_major + 1); } AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { switch (mode) { case 1: upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64> <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p); break; case 2: upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64> <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p); break; case 3: upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64> <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p); break; case 4: upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64> <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p); break; case 5: upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32> <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p); break; case 6: upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 2, 2, 8, 32> <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p); break; default: upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>( out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p); } }); return out; } ================================================ FILE: options.py ================================================ import configargparse from munch import * from pdb import set_trace as st class BaseOptions(): def __init__(self): self.parser = configargparse.ArgumentParser() self.initialized = False def initialize(self): # Dataset options dataset = self.parser.add_argument_group('dataset') dataset.add_argument("--dataset_path", type=str, default='./datasets/FFHQ', help="path to the lmdb dataset") # Experiment Options experiment = 
self.parser.add_argument_group('experiment') experiment.add_argument('--config', is_config_file=True, help='config file path') experiment.add_argument("--expname", type=str, default='debug', help='experiment name') experiment.add_argument("--ckpt", type=str, default='300000', help="path to the checkpoints to resume training") experiment.add_argument("--continue_training", action="store_true", help="continue training the model") # Training loop options training = self.parser.add_argument_group('training') training.add_argument("--checkpoints_dir", type=str, default='./checkpoint', help='checkpoints directory name') training.add_argument("--iter", type=int, default=300000, help="total
number of training iterations") training.add_argument("--batch", type=int, default=4, help="batch size per GPU. A single RTX 2080 can fit batch=4, chunk=1 into memory.") training.add_argument("--chunk", type=int, default=4, help='number of samples within a batch to be processed in parallel; decrease if running out of memory') training.add_argument("--val_n_sample", type=int, default=8, help="number of test samples generated during training") training.add_argument("--d_reg_every", type=int, default=16, help="interval for applying R1 regularization to the StyleGAN discriminator") training.add_argument("--g_reg_every", type=int, default=4, help="interval for applying path length regularization to the StyleGAN generator") training.add_argument("--local_rank", type=int, default=0, help="local rank for distributed training") training.add_argument("--mixing", type=float, default=0.9, help="probability of latent code mixing") training.add_argument("--lr", type=float, default=0.002, help="learning rate") training.add_argument("--r1", type=float, default=10, help="weight of the R1 regularization") training.add_argument("--view_lambda", type=float, default=15, help="weight of the viewpoint regularization") training.add_argument("--eikonal_lambda", type=float, default=0.1, help="weight of the eikonal regularization") training.add_argument("--min_surf_lambda", type=float, default=0.05, help="weight of the minimal surface regularization") training.add_argument("--min_surf_beta", type=float, default=100.0, help="scale factor (beta) in the exponent of the minimal surface regularization") training.add_argument("--path_regularize", type=float, default=2, help="weight of the path length regularization") training.add_argument("--path_batch_shrink", type=int, default=2, help="batch size reducing factor for the path length regularization (reduces memory consumption)") training.add_argument("--wandb", action="store_true", help="use Weights & Biases logging") 
training.add_argument("--no_sphere_init",
action="store_true", help="do not initialize the volume renderer with a sphere SDF") # Inference Options inference = self.parser.add_argument_group('inference') inference.add_argument("--results_dir", type=str, default='./evaluations', help='results/evaluations directory name') inference.add_argument("--truncation_ratio", type=float, default=0.5, help="truncation ratio; controls the diversity vs. quality tradeoff. A higher truncation ratio generates more diverse results") inference.add_argument("--truncation_mean", type=int, default=10000, help="number of latent vectors used to compute the mean latent for truncation") inference.add_argument("--identities", type=int, default=16, help="number of identities to be generated") inference.add_argument("--num_views_per_id", type=int, default=1, help="number of viewpoints generated per identity") inference.add_argument("--no_surface_renderings", action="store_true", help="when true, only RGB outputs are generated; otherwise, both RGB and depth videos/renderings are generated. Enabling this flag cuts the processing time per video") inference.add_argument("--fixed_camera_angles", action="store_true", help="when true, the generator will render identities from a fixed set of camera angles.") inference.add_argument("--azim_video", action="store_true", help="when true, the camera trajectory will travel along the azimuth direction. Otherwise, the camera will travel along an ellipsoid trajectory.") # Generator options model = self.parser.add_argument_group('model') model.add_argument("--size", type=int, default=256, help="image size for the model") model.add_argument("--style_dim", type=int, default=256, help="number of style input dimensions") model.add_argument("--channel_multiplier", type=int, default=2, help="channel multiplier factor for the StyleGAN decoder.
config-f = 2, else = 1") model.add_argument("--n_mlp", type=int, default=8, help="number of MLP layers in StyleGAN's mapping network") model.add_argument("--lr_mapping", type=float, default=0.01, help='learning rate reduction for mapping network MLP layers') model.add_argument("--renderer_spatial_output_dim", type=int, default=64, help='spatial resolution of the StyleGAN decoder inputs') model.add_argument("--project_noise", action='store_true', help='when true, use geometry-aware noise projection to reduce flickering effects (see supplementary section C.1 in the paper). Warning: processing time significantly increases with this flag to ~20 minutes per video.') # Camera options camera = self.parser.add_argument_group('camera') camera.add_argument("--uniform", action="store_true", help="when true, the camera position is sampled from a uniform distribution. Gaussian distribution is the default") camera.add_argument("--azim", type=float, default=0.3, help="camera azimuth angle std/range in Radians") camera.add_argument("--elev", type=float, default=0.15, help="camera elevation angle std/range in Radians") camera.add_argument("--fov", type=float, default=6, help="camera field of view half angle in Degrees") camera.add_argument("--dist_radius", type=float, default=0.12, help="radius of points sampling distance from the origin. Determines the near and far fields") # Volume Renderer options rendering = self.parser.add_argument_group('rendering') # MLP model parameters rendering.add_argument("--depth", type=int, default=8, help='number of layers in the network') rendering.add_argument("--width", type=int, default=256, help='channels per layer') # Volume representation options rendering.add_argument("--no_sdf", action='store_true', help='By default, the raw MLP outputs represent an underlying signed distance field (SDF). 
When true, the MLP outputs represent the traditional NeRF density field.') rendering.add_argument("--no_z_normalize", action='store_true', help='By default, the model normalizes input coordinates such that the z coordinate is in [-1,1]. When true, this normalization is disabled.') rendering.add_argument("--static_viewdirs", action='store_true', help='when true, use static viewing direction input to the MLP') # Ray integration options rendering.add_argument("--N_samples", type=int, default=24, help='number of samples per ray') rendering.add_argument("--no_offset_sampling", action='store_true', help='when true, use random stratified sampling when rendering the volume, otherwise offset sampling is used. (See Equation (3) in Sec. 3.2 of the paper)') rendering.add_argument("--perturb", type=float, default=1., help='set to 0. for no jitter, 1. for jitter') rendering.add_argument("--raw_noise_std", type=float, default=0., help='std dev of noise added to regularize sigma_a output, 1e0 recommended') rendering.add_argument("--force_background", action='store_true', help='force the last depth sample to act as background in case of a transparent ray') # Set volume renderer outputs rendering.add_argument("--return_xyz", action='store_true', help='when true, the volume renderer also returns the xyz point cloud of the surface. 
This point cloud is used to produce depth map renderings') rendering.add_argument("--return_sdf", action='store_true', help='when true, the volume renderer also returns the SDF network outputs for each location in the volume') self.initialized = True def parse(self): self.opt = Munch() if not self.initialized: self.initialize() try: args = self.parser.parse_args() except: # solves argparse error in google colab args = self.parser.parse_args(args=[]) for group in self.parser._action_groups[2:]: title = group.title self.opt[title] = Munch() for action in group._group_actions: dest = action.dest self.opt[title][dest] = args.__getattribute__(dest) return self.opt ================================================ FILE: prepare_data.py ================================================ import argparse from io import BytesIO import multiprocessing from functools import partial from PIL import Image import lmdb from tqdm import tqdm from torchvision import datasets from torchvision.transforms import functional as trans_fn from pdb import set_trace as st def resize_and_convert(img, size, resample): img = trans_fn.resize(img, size, resample) img = trans_fn.center_crop(img, size) buffer = BytesIO() img.save(buffer, format="png") val = buffer.getvalue() return val def resize_multiple( img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS): imgs = [] for size in sizes: imgs.append(resize_and_convert(img, size, resample)) return imgs def resize_worker(img_file, sizes, resample): i, file = img_file img = Image.open(file) img = img.convert("RGB") out = resize_multiple(img, sizes=sizes, resample=resample) return i, out def prepare( env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS ): resize_fn = partial(resize_worker, sizes=sizes, resample=resample) files = sorted(dataset.imgs, key=lambda x: x[0]) files = [(i, file) for i, (file, label) in enumerate(files)] total = 0 with multiprocessing.Pool(n_worker) as pool: for i, imgs in 
tqdm(pool.imap_unordered(resize_fn, files)): for size, img in zip(sizes, imgs): key = f"{size}-{str(i).zfill(5)}".encode("utf-8") with env.begin(write=True) as txn: txn.put(key, img) total += 1 with env.begin(write=True) as txn: txn.put("length".encode("utf-8"), str(total).encode("utf-8")) if __name__ == "__main__": parser = argparse.ArgumentParser(description="Preprocess images for model training") parser.add_argument("--size", type=str, default="64,512,1024", help="resolutions of images for the dataset") parser.add_argument("--n_worker", type=int, default=32, help="number of workers for preparing dataset") parser.add_argument("--resample", type=str, default="lanczos", help="resampling methods for resizing images") parser.add_argument("--out_path", type=str, default="datasets/FFHQ", help="Target path of the output lmdb dataset") parser.add_argument("in_path", type=str, help="path to the input image dataset") args = parser.parse_args() resample_map = {"lanczos": Image.LANCZOS, "bilinear": Image.BILINEAR} resample = resample_map[args.resample] sizes = [int(s.strip()) for s in args.size.split(",")] print(f"Make dataset of image sizes:", ", ".join(str(s) for s in sizes)) imgset = datasets.ImageFolder(args.in_path) with lmdb.open(args.out_path, map_size=1024 ** 4, readahead=False) as env: prepare(env, imgset, args.n_worker, sizes=sizes, resample=resample) ================================================ FILE: render_video.py ================================================ import os import torch import trimesh import numpy as np import skvideo.io from munch import * from PIL import Image from tqdm import tqdm from torch.nn import functional as F from torch.utils import data from torchvision import utils from torchvision import transforms from options import BaseOptions from model import Generator from utils import ( generate_camera_params, align_volume, extract_mesh_with_marching_cubes, xyz2mesh, create_cameras, create_mesh_renderer, add_textures, ) from 
pytorch3d.structures import Meshes from pdb import set_trace as st torch.random.manual_seed(1234) def render_video(opt, g_ema, surface_g_ema, device, mean_latent, surface_mean_latent): g_ema.eval() if not opt.no_surface_renderings or opt.project_noise: surface_g_ema.eval() images = torch.Tensor(0, 3, opt.size, opt.size) num_frames = 250 # Generate video trajectory trajectory = np.zeros((num_frames,3), dtype=np.float32) # set camera trajectory # sweep azimuth angles (4 seconds) if opt.azim_video: t = np.linspace(0, 1, num_frames) elev = 0 fov = opt.camera.fov if opt.camera.uniform: azim = opt.camera.azim * np.cos(t * 2 * np.pi) else: azim = 1.5 * opt.camera.azim * np.cos(t * 2 * np.pi) trajectory[:num_frames,0] = azim trajectory[:num_frames,1] = elev trajectory[:num_frames,2] = fov # ellipsoid sweep (4 seconds) else: t = np.linspace(0, 1, num_frames) fov = opt.camera.fov #+ 1 * np.sin(t * 2 * np.pi) if opt.camera.uniform: elev = opt.camera.elev / 2 + opt.camera.elev / 2 * np.sin(t * 2 * np.pi) azim = opt.camera.azim * np.cos(t * 2 * np.pi) else: elev = 1.5 * opt.camera.elev * np.sin(t * 2 * np.pi) azim = 1.5 * opt.camera.azim * np.cos(t * 2 * np.pi) trajectory[:num_frames,0] = azim trajectory[:num_frames,1] = elev trajectory[:num_frames,2] = fov trajectory = torch.from_numpy(trajectory).to(device) # generate input parameters for the camera trajectory # sample_cam_poses, sample_focals, sample_near, sample_far = \ # generate_camera_params(trajectory, opt.renderer_output_size, device, dist_radius=opt.camera.dist_radius) sample_cam_extrinsics, sample_focals, sample_near, sample_far, _ = \ generate_camera_params(opt.renderer_output_size, device, locations=trajectory[:,:2], fov_ang=trajectory[:,2:], dist_radius=opt.camera.dist_radius) # In case of noise projection, generate input parameters for the frontal position. # The reference mesh for the noise projection is extracted from the frontal position. # For more details see section C.1 in the supplementary material. 
if opt.project_noise: frontal_pose = torch.tensor([[0.0,0.0,opt.camera.fov]]).to(device) # frontal_cam_pose, frontal_focals, frontal_near, frontal_far = \ # generate_camera_params(frontal_pose, opt.surf_extraction_output_size, device, dist_radius=opt.camera.dist_radius) frontal_cam_pose, frontal_focals, frontal_near, frontal_far, _ = \ generate_camera_params(opt.surf_extraction_output_size, device, locations=frontal_pose[:,:2], fov_ang=frontal_pose[:,2:], dist_radius=opt.camera.dist_radius) # create geometry renderer (renders the depth maps) cameras = create_cameras(azim=np.rad2deg(trajectory[0,0].cpu().numpy()), elev=np.rad2deg(trajectory[0,1].cpu().numpy()), dist=1, device=device) renderer = create_mesh_renderer(cameras, image_size=512, specular_color=((0,0,0),), ambient_color=((0.1,.1,.1),), diffuse_color=((0.75,.75,.75),), device=device) suffix = '_azim' if opt.azim_video else '_elipsoid' # generate videos for i in range(opt.identities): print('Processing identity {}/{}...'.format(i+1, opt.identities)) chunk = 1 sample_z = torch.randn(1, opt.style_dim, device=device).repeat(chunk,1) video_filename = 'sample_video_{}{}.mp4'.format(i,suffix) writer = skvideo.io.FFmpegWriter(os.path.join(opt.results_dst_dir, video_filename), outputdict={'-pix_fmt': 'yuv420p', '-crf': '10'}) if not opt.no_surface_renderings: depth_video_filename = 'sample_depth_video_{}{}.mp4'.format(i,suffix) depth_writer = skvideo.io.FFmpegWriter(os.path.join(opt.results_dst_dir, depth_video_filename), outputdict={'-pix_fmt': 'yuv420p', '-crf': '1'}) ####################### Extract initial surface mesh from the frontal viewpoint ############# # For more details see section C.1 in the supplementary material. 
if opt.project_noise: with torch.no_grad(): frontal_surface_out = surface_g_ema([sample_z], frontal_cam_pose, frontal_focals, frontal_near, frontal_far, truncation=opt.truncation_ratio, truncation_latent=surface_mean_latent, return_sdf=True) frontal_sdf = frontal_surface_out[2].cpu() print('Extracting Identity {} Frontal view Marching Cubes for consistent video rendering'.format(i)) frostum_aligned_frontal_sdf = align_volume(frontal_sdf) del frontal_sdf try: frontal_marching_cubes_mesh = extract_mesh_with_marching_cubes(frostum_aligned_frontal_sdf) except ValueError: frontal_marching_cubes_mesh = None if frontal_marching_cubes_mesh != None: frontal_marching_cubes_mesh_filename = os.path.join(opt.results_dst_dir,'sample_{}_frontal_marching_cubes_mesh{}.obj'.format(i,suffix)) with open(frontal_marching_cubes_mesh_filename, 'w') as f: frontal_marching_cubes_mesh.export(f,file_type='obj') del frontal_surface_out torch.cuda.empty_cache() ############################################################################################# for j in tqdm(range(0, num_frames, chunk)): with torch.no_grad(): out = g_ema([sample_z], sample_cam_extrinsics[j:j+chunk], sample_focals[j:j+chunk], sample_near[j:j+chunk], sample_far[j:j+chunk], truncation=opt.truncation_ratio, truncation_latent=mean_latent, randomize_noise=False, project_noise=opt.project_noise, mesh_path=frontal_marching_cubes_mesh_filename if opt.project_noise else None) rgb = out[0].cpu() # this is done to fit to RTX2080 RAM size (11GB) del out torch.cuda.empty_cache() # Convert RGB from [-1, 1] to [0,255] rgb = 127.5 * (rgb.clamp(-1,1).permute(0,2,3,1).cpu().numpy() + 1) # Add RGB, frame to video for k in range(chunk): writer.writeFrame(rgb[k]) ########## Extract surface ########## if not opt.no_surface_renderings: scale = surface_g_ema.renderer.out_im_res / g_ema.renderer.out_im_res surface_sample_focals = sample_focals * scale surface_out = surface_g_ema([sample_z], sample_cam_extrinsics[j:j+chunk], 
surface_sample_focals[j:j+chunk], sample_near[j:j+chunk], sample_far[j:j+chunk], truncation=opt.truncation_ratio, truncation_latent=surface_mean_latent, return_xyz=True) xyz = surface_out[2].cpu() # this is done to fit to RTX2080 RAM size (11GB) del surface_out torch.cuda.empty_cache() # Render mesh for video depth_mesh = xyz2mesh(xyz) mesh = Meshes( verts=[torch.from_numpy(np.asarray(depth_mesh.vertices)).to(torch.float32).to(device)], faces = [torch.from_numpy(np.asarray(depth_mesh.faces)).to(torch.float32).to(device)], textures=None, verts_normals=[torch.from_numpy(np.copy(np.asarray(depth_mesh.vertex_normals))).to(torch.float32).to(device)], ) mesh = add_textures(mesh) cameras = create_cameras(azim=np.rad2deg(trajectory[j,0].cpu().numpy()), elev=np.rad2deg(trajectory[j,1].cpu().numpy()), fov=2*trajectory[j,2].cpu().numpy(), dist=1, device=device) renderer = create_mesh_renderer(cameras, image_size=512, light_location=((0.0,1.0,5.0),), specular_color=((0.2,0.2,0.2),), ambient_color=((0.1,0.1,0.1),), diffuse_color=((0.65,.65,.65),), device=device) mesh_image = 255 * renderer(mesh).cpu().numpy() mesh_image = mesh_image[...,:3] # Add depth frame to video for k in range(chunk): depth_writer.writeFrame(mesh_image[k]) # Close video writers writer.close() if not opt.no_surface_renderings: depth_writer.close() if __name__ == "__main__": device = "cuda" opt = BaseOptions().parse() opt.model.is_test = True opt.model.style_dim = 256 opt.model.freeze_renderer = False opt.inference.size = opt.model.size opt.inference.camera = opt.camera opt.inference.renderer_output_size = opt.model.renderer_spatial_output_dim opt.inference.style_dim = opt.model.style_dim opt.inference.project_noise = opt.model.project_noise opt.rendering.perturb = 0 opt.rendering.force_background = True opt.rendering.static_viewdirs = True opt.rendering.return_sdf = True opt.rendering.N_samples = 64 # find checkpoint directory # check if there's a fully trained model checkpoints_dir = 'full_models' 
checkpoint_path = os.path.join(checkpoints_dir, opt.experiment.expname + '.pt') if os.path.isfile(checkpoint_path): # define results directory name result_model_dir = 'final_model' else: checkpoints_dir = os.path.join('checkpoint', opt.experiment.expname, 'full_pipeline') checkpoint_path = os.path.join(checkpoints_dir, 'models_{}.pt'.format(opt.experiment.ckpt.zfill(7))) # define results directory name result_model_dir = 'iter_{}'.format(opt.experiment.ckpt.zfill(7)) results_dir_basename = os.path.join(opt.inference.results_dir, opt.experiment.expname) opt.inference.results_dst_dir = os.path.join(results_dir_basename, result_model_dir, 'videos') if opt.model.project_noise: opt.inference.results_dst_dir = os.path.join(opt.inference.results_dst_dir, 'with_noise_projection') os.makedirs(opt.inference.results_dst_dir, exist_ok=True) # load saved model checkpoint = torch.load(checkpoint_path) # load image generation model g_ema = Generator(opt.model, opt.rendering).to(device) # temporary fix for mismatched noise sizes: load only weights whose shapes match pretrained_weights_dict = checkpoint["g_ema"] model_dict = g_ema.state_dict() for k, v in pretrained_weights_dict.items(): if v.size() == model_dict[k].size(): model_dict[k] = v g_ema.load_state_dict(model_dict) # load the volume renderer into a second generator that extracts surfaces at 128x128x128 if not opt.inference.no_surface_renderings or opt.model.project_noise: opt['surf_extraction'] = Munch() opt.surf_extraction.rendering = opt.rendering opt.surf_extraction.model = opt.model.copy() opt.surf_extraction.model.renderer_spatial_output_dim = 128 opt.surf_extraction.rendering.N_samples = opt.surf_extraction.model.renderer_spatial_output_dim opt.surf_extraction.rendering.return_xyz = True opt.surf_extraction.rendering.return_sdf = True opt.inference.surf_extraction_output_size = opt.surf_extraction.model.renderer_spatial_output_dim surface_g_ema = Generator(opt.surf_extraction.model, opt.surf_extraction.rendering, full_pipeline=False).to(device) # Load weights to 
surface extractor surface_extractor_dict = surface_g_ema.state_dict() for k, v in pretrained_weights_dict.items(): if k in surface_extractor_dict.keys() and v.size() == surface_extractor_dict[k].size(): surface_extractor_dict[k] = v surface_g_ema.load_state_dict(surface_extractor_dict) else: surface_g_ema = None # get the mean latent vector for g_ema if opt.inference.truncation_ratio < 1: with torch.no_grad(): mean_latent = g_ema.mean_latent(opt.inference.truncation_mean, device) else: mean_latent = None # get the mean latent vector for surface_g_ema if not opt.inference.no_surface_renderings or opt.model.project_noise: surface_mean_latent = mean_latent[0] if mean_latent is not None else None else: surface_mean_latent = None render_video(opt.inference, g_ema, surface_g_ema, device, mean_latent, surface_mean_latent) ================================================ FILE: requirements.txt ================================================ lmdb numpy ninja pillow requests tqdm scipy scikit-image scikit-video trimesh[easy] configargparse munch wandb ================================================ FILE: scripts/train_afhq_full_pipeline_512x512.sh ================================================ python -m torch.distributed.launch --nproc_per_node 4 train_full_pipeline.py --batch 8 --chunk 4 --azim 0.15 --r1 50.0 --expname afhq512x512 --dataset_path ./datasets/AFHQ/train/ --size 512 --wandb ================================================ FILE: scripts/train_afhq_vol_renderer.sh ================================================ python -m torch.distributed.launch --nproc_per_node 2 train_volume_renderer.py --batch 12 --chunk 6 --expname afhq_sdf_vol_renderer --dataset_path ./datasets/AFHQ/train --azim 0.15 --wandb ================================================ FILE: scripts/train_ffhq_full_pipeline_1024x1024.sh ================================================ python -m torch.distributed.launch --nproc_per_node 4 train_full_pipeline.py --batch 8 --chunk 2 --expname ffhq1024x1024 
--size 1024 --wandb 
================================================ FILE: scripts/train_ffhq_vol_renderer.sh ================================================ python -m torch.distributed.launch --nproc_per_node 2 train_volume_renderer.py --batch 12 --chunk 6 --expname ffhq_sdf_vol_renderer --dataset_path ./datasets/FFHQ/ --wandb ================================================ FILE: train_full_pipeline.py ================================================ import argparse import math import random import os import yaml import numpy as np import torch from torch import nn, autograd, optim from torch.nn import functional as F from torch.utils import data import torch.distributed as dist from torchvision import transforms, utils from tqdm import tqdm from PIL import Image from losses import * from options import BaseOptions from model import Generator, Discriminator from dataset import MultiResolutionDataset from utils import data_sampler, requires_grad, accumulate, sample_data, make_noise, mixing_noise, generate_camera_params from distributed import get_rank, synchronize, reduce_loss_dict, reduce_sum, get_world_size try: import wandb except ImportError: wandb = None def train(opt, experiment_opt, loader, generator, discriminator, g_optim, d_optim, g_ema, device): loader = sample_data(loader) pbar = range(opt.iter) if get_rank() == 0: pbar = tqdm(pbar, initial=opt.start_iter, dynamic_ncols=True, smoothing=0.01) mean_path_length = 0 d_loss_val = 0 g_loss_val = 0 r1_loss = torch.tensor(0.0, device=device) g_gan_loss = torch.tensor(0.0, device=device) path_loss = torch.tensor(0.0, device=device) path_lengths = torch.tensor(0.0, device=device) mean_path_length_avg = 0 loss_dict = {} if opt.distributed: g_module = generator.module d_module = discriminator.module else: g_module = generator d_module = discriminator accum = 0.5 ** (32 / (10 * 1000)) sample_z = [torch.randn(opt.val_n_sample, opt.style_dim, device=device).repeat_interleave(8, dim=0)] sample_cam_extrinsics, sample_focals, sample_near, 
sample_far, _ = generate_camera_params(opt.renderer_output_size, device, batch=opt.val_n_sample, sweep=True, uniform=opt.camera.uniform, azim_range=opt.camera.azim, elev_range=opt.camera.elev, fov_ang=opt.camera.fov, dist_radius=opt.camera.dist_radius) for idx in pbar: i = idx + opt.start_iter if i > opt.iter: print("Done!") break requires_grad(generator, False) requires_grad(discriminator, True) discriminator.zero_grad() d_regularize = i % opt.d_reg_every == 0 real_imgs, real_thumb_imgs = next(loader) real_imgs = real_imgs.to(device) real_thumb_imgs = real_thumb_imgs.to(device) noise = mixing_noise(opt.batch, opt.style_dim, opt.mixing, device) cam_extrinsics, focal, near, far, gt_viewpoints = generate_camera_params(opt.renderer_output_size, device, batch=opt.batch, uniform=opt.camera.uniform, azim_range=opt.camera.azim, elev_range=opt.camera.elev, fov_ang=opt.camera.fov, dist_radius=opt.camera.dist_radius) for j in range(0, opt.batch, opt.chunk): curr_real_imgs = real_imgs[j:j+opt.chunk] curr_real_thumb_imgs = real_thumb_imgs[j:j+opt.chunk] curr_noise = [n[j:j+opt.chunk] for n in noise] gen_imgs, _ = generator(curr_noise, cam_extrinsics[j:j+opt.chunk], focal[j:j+opt.chunk], near[j:j+opt.chunk], far[j:j+opt.chunk]) fake_pred = discriminator(gen_imgs.detach()) if d_regularize: curr_real_imgs.requires_grad = True curr_real_thumb_imgs.requires_grad = True real_pred = discriminator(curr_real_imgs) d_gan_loss = d_logistic_loss(real_pred, fake_pred) if d_regularize: grad_penalty = d_r1_loss(real_pred, curr_real_imgs) r1_loss = opt.r1 * 0.5 * grad_penalty * opt.d_reg_every else: r1_loss = torch.zeros_like(r1_loss) d_loss = d_gan_loss + r1_loss d_loss.backward() d_optim.step() loss_dict["d"] = d_gan_loss loss_dict["real_score"] = real_pred.mean() loss_dict["fake_score"] = fake_pred.mean() if d_regularize or i == opt.start_iter: loss_dict["r1"] = r1_loss.mean() requires_grad(generator, True) requires_grad(discriminator, False) for j in range(0, opt.batch, opt.chunk): noise 
= mixing_noise(opt.chunk, opt.style_dim, opt.mixing, device) cam_extrinsics, focal, near, far, gt_viewpoints = generate_camera_params(opt.renderer_output_size, device, batch=opt.chunk, uniform=opt.camera.uniform, azim_range=opt.camera.azim, elev_range=opt.camera.elev, fov_ang=opt.camera.fov, dist_radius=opt.camera.dist_radius) fake_img, _ = generator(noise, cam_extrinsics, focal, near, far) fake_pred = discriminator(fake_img) g_gan_loss = g_nonsaturating_loss(fake_pred) g_loss = g_gan_loss g_loss.backward() g_optim.step() generator.zero_grad() loss_dict["g"] = g_gan_loss # generator path regularization g_regularize = (opt.g_reg_every > 0) and (i % opt.g_reg_every == 0) if g_regularize: path_batch_size = max(1, opt.batch // opt.path_batch_shrink) path_noise = mixing_noise(path_batch_size, opt.style_dim, opt.mixing, device) path_cam_extrinsics, path_focal, path_near, path_far, _ = generate_camera_params(opt.renderer_output_size, device, batch=path_batch_size, uniform=opt.camera.uniform, azim_range=opt.camera.azim, elev_range=opt.camera.elev, fov_ang=opt.camera.fov, dist_radius=opt.camera.dist_radius) for j in range(0, path_batch_size, opt.chunk): path_fake_img, path_latents = generator(path_noise, path_cam_extrinsics, path_focal, path_near, path_far, return_latents=True) path_loss, mean_path_length, path_lengths = g_path_regularize( path_fake_img, path_latents, mean_path_length ) weighted_path_loss = opt.path_regularize * opt.g_reg_every * path_loss# * opt.chunk / path_batch_size if opt.path_batch_shrink: weighted_path_loss += 0 * path_fake_img[0, 0, 0, 0] weighted_path_loss.backward() g_optim.step() generator.zero_grad() mean_path_length_avg = (reduce_sum(mean_path_length).item() / get_world_size()) loss_dict["path"] = path_loss loss_dict["path_length"] = path_lengths.mean() accumulate(g_ema, g_module, accum) loss_reduced = reduce_loss_dict(loss_dict) d_loss_val = loss_reduced["d"].mean().item() g_loss_val = loss_reduced["g"].mean().item() r1_val = 
loss_reduced["r1"].mean().item() real_score_val = loss_reduced["real_score"].mean().item() fake_score_val = loss_reduced["fake_score"].mean().item() path_loss_val = loss_reduced["path"].mean().item() path_length_val = loss_reduced["path_length"].mean().item() if get_rank() == 0: pbar.set_description( (f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; path: {path_loss_val:.4f}") ) if i % 1000 == 0 or i == opt.start_iter: with torch.no_grad(): thumbs_samples = torch.Tensor(0, 3, opt.renderer_output_size, opt.renderer_output_size) samples = torch.Tensor(0, 3, opt.size, opt.size) step_size = 8 mean_latent = g_module.mean_latent(10000, device) for k in range(0, opt.val_n_sample * 8, step_size): curr_samples, curr_thumbs = g_ema([sample_z[0][k:k+step_size]], sample_cam_extrinsics[k:k+step_size], sample_focals[k:k+step_size], sample_near[k:k+step_size], sample_far[k:k+step_size], truncation=0.7, truncation_latent=mean_latent) samples = torch.cat([samples, curr_samples.cpu()], 0) thumbs_samples = torch.cat([thumbs_samples, curr_thumbs.cpu()], 0) if i % 10000 == 0: utils.save_image(samples, os.path.join(opt.checkpoints_dir, experiment_opt.expname, 'full_pipeline', f"samples/{str(i).zfill(7)}.png"), nrow=int(opt.val_n_sample), normalize=True, value_range=(-1, 1),) utils.save_image(thumbs_samples, os.path.join(opt.checkpoints_dir, experiment_opt.expname, 'full_pipeline', f"samples/{str(i).zfill(7)}_thumbs.png"), nrow=int(opt.val_n_sample), normalize=True, value_range=(-1, 1),) if wandb and opt.wandb: wandb_log_dict = {"Generator": g_loss_val, "Discriminator": d_loss_val, "R1": r1_val, "Real Score": real_score_val, "Fake Score": fake_score_val, "Path Length Regularization": path_loss_val, "Path Length": path_length_val, "Mean Path Length": mean_path_length, } if i % 5000 == 0: wandb_grid = utils.make_grid(samples, nrow=int(opt.val_n_sample), normalize=True, value_range=(-1, 1)) wandb_ndarr = (255 * wandb_grid.permute(1, 2, 0).numpy()).astype(np.uint8) wandb_images 
= Image.fromarray(wandb_ndarr) wandb_log_dict.update({"examples": [wandb.Image(wandb_images, caption="Generated samples for azimuth angles of: -0.35, -0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35 Radians.")]}) wandb_thumbs_grid = utils.make_grid(thumbs_samples, nrow=int(opt.val_n_sample), normalize=True, value_range=(-1, 1)) wandb_thumbs_ndarr = (255 * wandb_thumbs_grid.permute(1, 2, 0).numpy()).astype(np.uint8) wandb_thumbs = Image.fromarray(wandb_thumbs_ndarr) wandb_log_dict.update({"thumb_examples": [wandb.Image(wandb_thumbs, caption="Generated samples for azimuth angles of: -0.35, -0.25, -0.15, -0.05, 0.05, 0.15, 0.25, 0.35 Radians.")]}) wandb.log(wandb_log_dict) if i % 10000 == 0: torch.save( { "g": g_module.state_dict(), "d": d_module.state_dict(), "g_ema": g_ema.state_dict(), }, os.path.join(opt.checkpoints_dir, experiment_opt.expname, 'full_pipeline', f"models_{str(i).zfill(7)}.pt") ) print('Successfully saved checkpoint for iteration {}.'.format(i)) if get_rank() == 0: # create final model directory final_model_path = os.path.join('full_models', experiment_opt.expname) os.makedirs(final_model_path, exist_ok=True) torch.save( { "g": g_module.state_dict(), "d": d_module.state_dict(), "g_ema": g_ema.state_dict(), }, os.path.join(final_model_path, experiment_opt.expname + '.pt') ) print('Successfully saved final model.') if __name__ == "__main__": device = "cuda" opt = BaseOptions().parse() opt.training.camera = opt.camera opt.training.size = opt.model.size opt.training.renderer_output_size = opt.model.renderer_spatial_output_dim opt.training.style_dim = opt.model.style_dim opt.model.freeze_renderer = True n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 opt.training.distributed = n_gpu > 1 if opt.training.distributed: torch.cuda.set_device(opt.training.local_rank) torch.distributed.init_process_group(backend="nccl", init_method="env://") synchronize() # create checkpoints directories 
os.makedirs(os.path.join(opt.training.checkpoints_dir, opt.experiment.expname, 'full_pipeline'), exist_ok=True) os.makedirs(os.path.join(opt.training.checkpoints_dir, opt.experiment.expname, 'full_pipeline', 'samples'), exist_ok=True) discriminator = Discriminator(opt.model).to(device) generator = Generator(opt.model, opt.rendering).to(device) g_ema = Generator(opt.model, opt.rendering, ema=True).to(device) g_ema.eval() g_reg_ratio = opt.training.g_reg_every / (opt.training.g_reg_every + 1) if opt.training.g_reg_every > 0 else 1 d_reg_ratio = opt.training.d_reg_every / (opt.training.d_reg_every + 1) params_g = [] params_dict_g = dict(generator.named_parameters()) for key, value in params_dict_g.items(): decoder_cond = ('decoder' in key) if decoder_cond: params_g += [{'params':[value], 'lr':opt.training.lr * g_reg_ratio}] g_optim = optim.Adam(params_g, #generator.parameters(), lr=opt.training.lr * g_reg_ratio, betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio)) d_optim = optim.Adam(discriminator.parameters(), lr=opt.training.lr * d_reg_ratio,# * g_d_ratio, betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio)) opt.training.start_iter = 0 if opt.experiment.continue_training and opt.experiment.ckpt is not None: if get_rank() == 0: print("load model:", opt.experiment.ckpt) ckpt_path = os.path.join(opt.training.checkpoints_dir, opt.experiment.expname, 'full_pipeline', 'models_{}.pt'.format(opt.experiment.ckpt.zfill(7))) ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage) try: opt.training.start_iter = int(opt.experiment.ckpt) + 1 except ValueError: pass generator.load_state_dict(ckpt["g"]) discriminator.load_state_dict(ckpt["d"]) g_ema.load_state_dict(ckpt["g_ema"]) else: # save configuration opt_path = os.path.join(opt.training.checkpoints_dir, opt.experiment.expname, 'full_pipeline', f"opt.yaml") with open(opt_path,'w') as f: yaml.safe_dump(opt, f) if not opt.experiment.continue_training: if get_rank() == 0: print("loading pretrained renderer weights...") 
pretrained_renderer_path = os.path.join('./pretrained_renderer', opt.experiment.expname + '_vol_renderer.pt') try: ckpt = torch.load(pretrained_renderer_path, map_location=lambda storage, loc: storage) except: print('Pretrained volume renderer experiment name does not match the full pipeline experiment name.') vol_renderer_expname = str(input('Please enter the pretrained volume renderer experiment name:')) pretrained_renderer_path = os.path.join('./pretrained_renderer', vol_renderer_expname + '.pt') ckpt = torch.load(pretrained_renderer_path, map_location=lambda storage, loc: storage) pretrained_renderer_dict = ckpt["g_ema"] model_dict = generator.state_dict() for k, v in pretrained_renderer_dict.items(): if v.size() == model_dict[k].size(): model_dict[k] = v generator.load_state_dict(model_dict) # initialize g_ema weights to generator weights accumulate(g_ema, generator, 0) # set distributed models if opt.training.distributed: generator = nn.parallel.DistributedDataParallel( generator, device_ids=[opt.training.local_rank], output_device=opt.training.local_rank, broadcast_buffers=True, find_unused_parameters=True, ) discriminator = nn.parallel.DistributedDataParallel( discriminator, device_ids=[opt.training.local_rank], output_device=opt.training.local_rank, broadcast_buffers=False, find_unused_parameters=True ) transform = transforms.Compose( [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)]) dataset = MultiResolutionDataset(opt.dataset.dataset_path, transform, opt.model.size, opt.model.renderer_spatial_output_dim) loader = data.DataLoader( dataset, batch_size=opt.training.batch, sampler=data_sampler(dataset, shuffle=True, distributed=opt.training.distributed), drop_last=True, ) if get_rank() == 0 and wandb is not None and opt.training.wandb: wandb.init(project="StyleSDF") wandb.run.name = opt.experiment.expname wandb.config.dataset = os.path.basename(opt.dataset.dataset_path) wandb.config.update(opt.training) 
        wandb.config.update(opt.model)
        wandb.config.update(opt.rendering)

    train(opt.training, opt.experiment, loader, generator, discriminator, g_optim, d_optim, g_ema, device)


================================================
FILE: train_volume_renderer.py
================================================
import argparse
import math
import random
import os
import yaml
import numpy as np
import torch
import torch.distributed as dist
from torch import nn, autograd, optim
from torch.nn import functional as F
from torch.utils import data
from torchvision import transforms, utils
from tqdm import tqdm
from PIL import Image
from losses import *
from options import BaseOptions
from model import Generator, VolumeRenderDiscriminator
from dataset import MultiResolutionDataset
from utils import data_sampler, requires_grad, accumulate, sample_data, make_noise, mixing_noise, generate_camera_params
from distributed import get_rank, synchronize, reduce_loss_dict, reduce_sum, get_world_size

try:
    import wandb
except ImportError:
    wandb = None


def train(opt, experiment_opt, loader, generator, discriminator, g_optim, d_optim, g_ema, device):
    loader = sample_data(loader)

    mean_path_length = 0

    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    d_view_loss = torch.tensor(0.0, device=device)
    g_view_loss = torch.tensor(0.0, device=device)
    g_eikonal = torch.tensor(0.0, device=device)
    g_minimal_surface = torch.tensor(0.0, device=device)

    g_loss_val = 0
    loss_dict = {}

    viewpoint_condition = opt.view_lambda > 0

    if opt.distributed:
        g_module = generator.module
        d_module = discriminator.module
    else:
        g_module = generator
        d_module = discriminator

    accum = 0.5 ** (32 / (10 * 1000))

    sample_z = [torch.randn(opt.val_n_sample, opt.style_dim, device=device).repeat_interleave(8,dim=0)]
    sample_cam_extrinsics, sample_focals, sample_near, sample_far, _ = generate_camera_params(
        opt.renderer_output_size, device, batch=opt.val_n_sample, sweep=True,
        uniform=opt.camera.uniform, azim_range=opt.camera.azim,
        elev_range=opt.camera.elev, fov_ang=opt.camera.fov,
        dist_radius=opt.camera.dist_radius)

    if opt.with_sdf and opt.sphere_init and opt.start_iter == 0:
        init_pbar = range(10000)
        if get_rank() == 0:
            init_pbar = tqdm(init_pbar, initial=0, dynamic_ncols=True, smoothing=0.01)

        generator.zero_grad()
        for idx in init_pbar:
            noise = mixing_noise(3, opt.style_dim, opt.mixing, device)
            cam_extrinsics, focal, near, far, gt_viewpoints = generate_camera_params(
                opt.renderer_output_size, device, batch=3,
                uniform=opt.camera.uniform, azim_range=opt.camera.azim,
                elev_range=opt.camera.elev, fov_ang=opt.camera.fov,
                dist_radius=opt.camera.dist_radius)
            sdf, target_values = g_module.init_forward(noise, cam_extrinsics, focal, near, far)
            loss = F.l1_loss(sdf, target_values)
            loss.backward()
            g_optim.step()
            generator.zero_grad()
            if get_rank() == 0:
                init_pbar.set_description((f"MLP init to sphere procedure - Loss: {loss.item():.4f}"))

        accumulate(g_ema, g_module, 0)
        torch.save(
            {
                "g": g_module.state_dict(),
                "d": d_module.state_dict(),
                "g_ema": g_ema.state_dict(),
            },
            os.path.join(opt.checkpoints_dir, experiment_opt.expname, f"sdf_init_models_{str(0).zfill(7)}.pt")
        )
        print('Successfully saved checkpoint for SDF initialized MLP.')

    pbar = range(opt.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar, initial=opt.start_iter, dynamic_ncols=True, smoothing=0.01)

    for idx in pbar:
        i = idx + opt.start_iter

        if i > opt.iter:
            print("Done!")
            break

        requires_grad(generator, False)
        requires_grad(discriminator, True)
        discriminator.zero_grad()
        _, real_imgs = next(loader)
        real_imgs = real_imgs.to(device)
        noise = mixing_noise(opt.batch, opt.style_dim, opt.mixing, device)
        cam_extrinsics, focal, near, far, gt_viewpoints = generate_camera_params(
            opt.renderer_output_size, device, batch=opt.batch,
            uniform=opt.camera.uniform, azim_range=opt.camera.azim,
            elev_range=opt.camera.elev, fov_ang=opt.camera.fov,
            dist_radius=opt.camera.dist_radius)
        gen_imgs = []
        for j in range(0, opt.batch, opt.chunk):
            curr_noise = [n[j:j+opt.chunk] for n in noise]
            _, fake_img = generator(curr_noise, cam_extrinsics[j:j+opt.chunk], focal[j:j+opt.chunk],
                                    near[j:j+opt.chunk], far[j:j+opt.chunk])

            gen_imgs += [fake_img]

        gen_imgs = torch.cat(gen_imgs, 0)
        fake_pred, fake_viewpoint_pred = discriminator(gen_imgs.detach())
        if viewpoint_condition:
            d_view_loss = opt.view_lambda * viewpoints_loss(fake_viewpoint_pred, gt_viewpoints)

        real_imgs.requires_grad = True
        real_pred, _ = discriminator(real_imgs)
        d_gan_loss = d_logistic_loss(real_pred, fake_pred)
        grad_penalty = d_r1_loss(real_pred, real_imgs)
        r1_loss = opt.r1 * 0.5 * grad_penalty
        d_loss = d_gan_loss + r1_loss + d_view_loss
        d_loss.backward()
        d_optim.step()

        loss_dict["d"] = d_gan_loss
        loss_dict["r1"] = r1_loss
        loss_dict["d_view"] = d_view_loss
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        requires_grad(generator, True)
        requires_grad(discriminator, False)

        for j in range(0, opt.batch, opt.chunk):
            noise = mixing_noise(opt.chunk, opt.style_dim, opt.mixing, device)
            cam_extrinsics, focal, near, far, curr_gt_viewpoints = generate_camera_params(
                opt.renderer_output_size, device, batch=opt.chunk,
                uniform=opt.camera.uniform, azim_range=opt.camera.azim,
                elev_range=opt.camera.elev, fov_ang=opt.camera.fov,
                dist_radius=opt.camera.dist_radius)

            out = generator(noise, cam_extrinsics, focal, near, far,
                            return_sdf=opt.min_surf_lambda > 0,
                            return_eikonal=opt.eikonal_lambda > 0)
            fake_img = out[1]
            if opt.min_surf_lambda > 0:
                sdf = out[2]
            if opt.eikonal_lambda > 0:
                eikonal_term = out[3]

            fake_pred, fake_viewpoint_pred = discriminator(fake_img)
            if viewpoint_condition:
                g_view_loss = opt.view_lambda * viewpoints_loss(fake_viewpoint_pred, curr_gt_viewpoints)

            if opt.with_sdf and opt.eikonal_lambda > 0:
                g_eikonal, g_minimal_surface = eikonal_loss(eikonal_term, sdf=sdf if opt.min_surf_lambda > 0 else None,
                                                            beta=opt.min_surf_beta)
                g_eikonal = opt.eikonal_lambda * g_eikonal
                if opt.min_surf_lambda > 0:
                    g_minimal_surface = opt.min_surf_lambda * g_minimal_surface

            g_gan_loss = g_nonsaturating_loss(fake_pred)
            g_loss = g_gan_loss + g_view_loss + g_eikonal + g_minimal_surface
            g_loss.backward()

        g_optim.step()
        generator.zero_grad()
        loss_dict["g"] = g_gan_loss
        loss_dict["g_view"] = g_view_loss
        loss_dict["g_eikonal"] = g_eikonal
        loss_dict["g_minimal_surface"] = g_minimal_surface

        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)
        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        d_view_val = loss_reduced["d_view"].mean().item()
        g_view_val = loss_reduced["g_view"].mean().item()
        g_eikonal_loss = loss_reduced["g_eikonal"].mean().item()
        g_minimal_surface_loss = loss_reduced["g_minimal_surface"].mean().item()
        g_beta_val = g_module.renderer.sigmoid_beta.item() if opt.with_sdf else 0

        if get_rank() == 0:
            pbar.set_description(
                (f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; viewpoint: {d_view_val+g_view_val:.4f}; eikonal: {g_eikonal_loss:.4f}; surf: {g_minimal_surface_loss:.4f}")
            )

            if i % 1000 == 0:
                with torch.no_grad():
                    samples = torch.Tensor(0, 3, opt.renderer_output_size, opt.renderer_output_size)
                    step_size = 4
                    mean_latent = g_module.mean_latent(10000, device)
                    for k in range(0, opt.val_n_sample * 8, step_size):
                        _, curr_samples = g_ema([sample_z[0][k:k+step_size]],
                                                sample_cam_extrinsics[k:k+step_size],
                                                sample_focals[k:k+step_size],
                                                sample_near[k:k+step_size],
                                                sample_far[k:k+step_size],
                                                truncation=0.7,
                                                truncation_latent=mean_latent,)
                        samples = torch.cat([samples, curr_samples.cpu()], 0)

                    if i % 10000 == 0:
                        utils.save_image(samples,
                            os.path.join(opt.checkpoints_dir, experiment_opt.expname, 'volume_renderer', f"samples/{str(i).zfill(7)}.png"),
                            nrow=int(opt.val_n_sample),
                            normalize=True,
                            value_range=(-1, 1),)

            if wandb and opt.wandb:
                wandb_log_dict = {"Generator": g_loss_val,
                                  "Discriminator": d_loss_val,
                                  "R1": r1_val,
                                  "Real Score": real_score_val,
                                  "Fake Score": fake_score_val,
                                  "D viewpoint": d_view_val,
                                  "G viewpoint": g_view_val,
                                  "G eikonal loss": g_eikonal_loss,
                                  "G minimal surface loss": g_minimal_surface_loss,
                                  }
                if opt.with_sdf:
                    wandb_log_dict.update({"Beta value": g_beta_val})

                if i % 1000 == 0:
                    wandb_grid = utils.make_grid(samples, nrow=int(opt.val_n_sample),
                                                 normalize=True, value_range=(-1, 1))
                    wandb_ndarr = (255 * wandb_grid.permute(1, 2, 0).numpy()).astype(np.uint8)
                    wandb_images = Image.fromarray(wandb_ndarr)
                    wandb_log_dict.update({"examples": [wandb.Image(wandb_images,
                                           caption="Generated samples for azimuth angles of: -35, -25, -15, -5, 5, 15, 25, 35 degrees.")]})

                wandb.log(wandb_log_dict)

            if i % 10000 == 0 or (i < 10000 and i % 1000 == 0):
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                    },
                    os.path.join(opt.checkpoints_dir, experiment_opt.expname, 'volume_renderer', f"models_{str(i).zfill(7)}.pt")
                )
                print('Successfully saved checkpoint for iteration {}.'.format(i))

    if get_rank() == 0:
        # create final model directory
        final_model_path = 'pretrained_renderer'
        os.makedirs(final_model_path, exist_ok=True)
        torch.save(
            {
                "g": g_module.state_dict(),
                "d": d_module.state_dict(),
                "g_ema": g_ema.state_dict(),
            },
            os.path.join(final_model_path, experiment_opt.expname + '_vol_renderer.pt')
        )
        print('Successfully saved final model.')


if __name__ == "__main__":
    device = "cuda"
    opt = BaseOptions().parse()
    opt.model.freeze_renderer = False
    opt.model.no_viewpoint_loss = opt.training.view_lambda == 0.0
    opt.training.camera = opt.camera
    opt.training.renderer_output_size = opt.model.renderer_spatial_output_dim
    opt.training.style_dim = opt.model.style_dim
    opt.training.with_sdf = not opt.rendering.no_sdf
    if opt.training.with_sdf and opt.training.min_surf_lambda > 0:
        opt.rendering.return_sdf = True
    opt.training.iter = 200001
    opt.rendering.no_features_output = True

    n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    opt.training.distributed = n_gpu > 1

    if opt.training.distributed:
        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
        synchronize()

    # create checkpoints directories
    os.makedirs(os.path.join(opt.training.checkpoints_dir, opt.experiment.expname, 'volume_renderer'), exist_ok=True)
    os.makedirs(os.path.join(opt.training.checkpoints_dir, opt.experiment.expname, 'volume_renderer', 'samples'), exist_ok=True)

    discriminator = VolumeRenderDiscriminator(opt.model).to(device)
    generator = Generator(opt.model, opt.rendering, full_pipeline=False).to(device)
    g_ema = Generator(opt.model, opt.rendering, ema=True, full_pipeline=False).to(device)

    g_ema.eval()
    accumulate(g_ema, generator, 0)
    g_optim = optim.Adam(generator.parameters(), lr=2e-5, betas=(0, 0.9))
    d_optim = optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0, 0.9))

    opt.training.start_iter = 0

    if opt.experiment.continue_training and opt.experiment.ckpt is not None:
        if get_rank() == 0:
            print("load model:", opt.experiment.ckpt)
        ckpt_path = os.path.join(opt.training.checkpoints_dir,
                                 opt.experiment.expname,
                                 'models_{}.pt'.format(opt.experiment.ckpt.zfill(7)))
        ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)

        try:
            opt.training.start_iter = int(opt.experiment.ckpt) + 1
        except ValueError:
            pass

        generator.load_state_dict(ckpt["g"])
        discriminator.load_state_dict(ckpt["d"])
        g_ema.load_state_dict(ckpt["g_ema"])
        if "g_optim" in ckpt.keys():
            g_optim.load_state_dict(ckpt["g_optim"])
            d_optim.load_state_dict(ckpt["d_optim"])

    sphere_init_path = './pretrained_renderer/sphere_init.pt'
    if opt.training.no_sphere_init:
        opt.training.sphere_init = False
    elif not opt.experiment.continue_training and opt.training.with_sdf and os.path.isfile(sphere_init_path):
        if get_rank() == 0:
            print("loading sphere initialized model")
        ckpt = torch.load(sphere_init_path, map_location=lambda storage, loc: storage)
        generator.load_state_dict(ckpt["g"])
        discriminator.load_state_dict(ckpt["d"])
        g_ema.load_state_dict(ckpt["g_ema"])
        opt.training.sphere_init = False
    else:
        opt.training.sphere_init = True

    if opt.training.distributed:
        generator = nn.parallel.DistributedDataParallel(
            generator,
            device_ids=[opt.training.local_rank],
            output_device=opt.training.local_rank,
            broadcast_buffers=True,
            find_unused_parameters=True,
        )

        discriminator = nn.parallel.DistributedDataParallel(
            discriminator,
            device_ids=[opt.training.local_rank],
            output_device=opt.training.local_rank,
            broadcast_buffers=False,
            find_unused_parameters=True
        )

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)])

    dataset = MultiResolutionDataset(opt.dataset.dataset_path, transform, opt.model.size,
                                     opt.model.renderer_spatial_output_dim)
    loader = data.DataLoader(
        dataset,
        batch_size=opt.training.batch,
        sampler=data_sampler(dataset, shuffle=True, distributed=opt.training.distributed),
        drop_last=True,
    )
    opt.training.dataset_name = opt.dataset.dataset_path.lower()

    # save options
    opt_path = os.path.join(opt.training.checkpoints_dir, opt.experiment.expname, 'volume_renderer', f"opt.yaml")
    with open(opt_path,'w') as f:
        yaml.safe_dump(opt, f)

    # set wandb environment
    if get_rank() == 0 and wandb is not None and opt.training.wandb:
        wandb.init(project="StyleSDF")
        wandb.run.name = opt.experiment.expname
        wandb.config.dataset = os.path.basename(opt.dataset.dataset_path)
        wandb.config.update(opt.training)
        wandb.config.update(opt.model)
        wandb.config.update(opt.rendering)

    train(opt.training, opt.experiment, loader, generator, discriminator, g_optim, d_optim, g_ema, device)


================================================
FILE: utils.py
================================================
import torch
import random
import trimesh
import numpy as np
from torch import nn
from torch.nn import functional as F
from torch.utils import data
from scipy.spatial import Delaunay
from skimage.measure import marching_cubes
from pdb import set_trace as st

import pytorch3d.io
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
    look_at_view_transform,
    FoVPerspectiveCameras,
    PointLights,
    RasterizationSettings,
    MeshRenderer,
    MeshRasterizer,
    SoftPhongShader,
    TexturesVertex,
)

######################### Dataset util functions ###########################
# Get data sampler
def data_sampler(dataset, shuffle, distributed):
    if distributed:
        return data.distributed.DistributedSampler(dataset, shuffle=shuffle)

    if shuffle:
        return data.RandomSampler(dataset)
    else:
        return data.SequentialSampler(dataset)

# Get data minibatch
def sample_data(loader):
    while True:
        for batch in loader:
            yield batch

############################## Model weights util functions #################
# Turn model gradients on/off
def requires_grad(model, flag=True):
    for p in model.parameters():
        p.requires_grad = flag

# Exponential moving average for generator weights
def accumulate(model1, model2, decay=0.999):
    par1 = dict(model1.named_parameters())
    par2 = dict(model2.named_parameters())

    for k in par1.keys():
        par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)

################### Latent code (Z) sampling util functions ####################
# Sample Z space latent codes for the generator
def make_noise(batch, latent_dim, n_noise, device):
    if n_noise == 1:
        return torch.randn(batch, latent_dim, device=device)

    noises = torch.randn(n_noise, batch, latent_dim, device=device).unbind(0)

    return noises


def mixing_noise(batch, latent_dim, prob, device):
    if prob > 0 and random.random() < prob:
        return make_noise(batch, latent_dim, 2, device)
    else:
        return [make_noise(batch, latent_dim, 1, device)]

################# Camera parameters sampling ####################
def generate_camera_params(resolution, device, batch=1, locations=None, sweep=False,
                           uniform=False, azim_range=0.3, elev_range=0.15,
                           fov_ang=6, dist_radius=0.12):
    if locations is not None:
        azim = locations[:,0].view(-1,1)
        elev = locations[:,1].view(-1,1)

        # generate intrinsic parameters
        # fix distance to 1
        dist = torch.ones(azim.shape[0], 1, device=device)
        near, far = (dist - dist_radius).unsqueeze(-1), (dist + dist_radius).unsqueeze(-1)
        fov_angle = fov_ang * torch.ones(azim.shape[0], 1, device=device).view(-1,1) * np.pi / 180
        focal = 0.5 * resolution / torch.tan(fov_angle).unsqueeze(-1)
    elif sweep:
        # generate camera locations on the unit sphere
        azim = (-azim_range + (2 * azim_range / 7) * torch.arange(8, device=device)).view(-1,1).repeat(batch,1)
        elev = (-elev_range + 2 * elev_range * torch.rand(batch, 1, device=device).repeat(1,8).view(-1,1))

        # generate intrinsic parameters
        dist = (torch.ones(batch, 1, device=device)).repeat(1,8).view(-1,1)
        near, far = (dist - dist_radius).unsqueeze(-1), (dist + dist_radius).unsqueeze(-1)
        fov_angle = fov_ang * torch.ones(batch, 1, device=device).repeat(1,8).view(-1,1) * np.pi / 180
        focal = 0.5 * resolution / torch.tan(fov_angle).unsqueeze(-1)
    else:
        # sample camera locations on the unit sphere
        if uniform:
            azim = (-azim_range + 2 * azim_range * torch.rand(batch, 1, device=device))
            elev = (-elev_range + 2 * elev_range * torch.rand(batch, 1, device=device))
        else:
            azim = (azim_range * torch.randn(batch, 1, device=device))
            elev = (elev_range * torch.randn(batch, 1, device=device))

        # generate intrinsic parameters
        dist = torch.ones(batch, 1, device=device) # restrict camera position to be on the unit sphere
        near, far = (dist - dist_radius).unsqueeze(-1), (dist + dist_radius).unsqueeze(-1)
        fov_angle = fov_ang * torch.ones(batch, 1, device=device) * np.pi / 180 # full fov is 2 * fov_ang (12 degrees by default)
        focal = 0.5 * resolution / torch.tan(fov_angle).unsqueeze(-1)

    viewpoint = torch.cat([azim, elev], 1)

    #### Generate camera extrinsic matrix ##########

    # convert angles to xyz coordinates
    x = torch.cos(elev) * torch.sin(azim)
    y = torch.sin(elev)
    z = torch.cos(elev) * torch.cos(azim)
    camera_dir = torch.stack([x, y, z], dim=1).view(-1,3)
    camera_loc = dist * camera_dir

    # get rotation matrices (assume object is at the world coordinates origin)
    up = torch.tensor([[0,1,0]]).float().to(device) * torch.ones_like(dist)
    z_axis = F.normalize(camera_dir, eps=1e-5) # the -z direction points into the screen
    x_axis = F.normalize(torch.cross(up, z_axis, dim=1), eps=1e-5)
    y_axis = F.normalize(torch.cross(z_axis, x_axis, dim=1), eps=1e-5)
    is_close = torch.isclose(x_axis, torch.tensor(0.0), atol=5e-3).all(dim=1, keepdim=True)
    if is_close.any():
        replacement = F.normalize(torch.cross(y_axis, z_axis, dim=1), eps=1e-5)
        x_axis = torch.where(is_close, replacement, x_axis)
    R = torch.cat((x_axis[:, None, :], y_axis[:, None, :], z_axis[:, None, :]), dim=1)
    T = camera_loc[:, :, None]
    extrinsics = torch.cat((R.transpose(1,2),T), -1)

    return extrinsics, focal, near, far, viewpoint

#################### Mesh generation util functions ########################
# Reshape sampling volume to camera frustum
def align_volume(volume, near=0.88, far=1.12):
    b, h, w, d, c = volume.shape
    yy, xx, zz = torch.meshgrid(torch.linspace(-1, 1, h),
                                torch.linspace(-1, 1, w),
                                torch.linspace(-1, 1, d))

    grid = torch.stack([xx, yy, zz], -1).to(volume.device)

    frostum_adjustment_coeffs = torch.linspace(far / near, 1, d).view(1,1,1,-1,1).to(volume.device)
    frostum_grid = grid.unsqueeze(0)
    frostum_grid[...,:2] = frostum_grid[...,:2] * frostum_adjustment_coeffs
    out_of_boundary = torch.any((frostum_grid.lt(-1).logical_or(frostum_grid.gt(1))), -1, keepdim=True)
    frostum_grid = frostum_grid.permute(0,3,1,2,4).contiguous()
    permuted_volume = volume.permute(0,4,3,1,2).contiguous()
    final_volume = F.grid_sample(permuted_volume, frostum_grid, padding_mode="border", align_corners=True)
    final_volume = final_volume.permute(0,3,4,2,1).contiguous()
    # Set a non-zero value at grid locations outside the frustum to avoid marching cubes distortions;
    # otherwise these locations would keep the values produced by grid_sample's padding.
    final_volume[out_of_boundary] = 1

    return final_volume

# Extract mesh with marching cubes
def extract_mesh_with_marching_cubes(sdf):
    b, h, w, d, _ = sdf.shape

    # change coordinate order from (y,x,z) to (x,y,z)
    sdf_vol = sdf[0,...,0].permute(1,0,2).cpu().numpy()

    # scale vertices
    verts, faces, _, _ = marching_cubes(sdf_vol, 0)
    verts[:,0] = (verts[:,0]/float(w)-0.5)*0.24
    verts[:,1] = (verts[:,1]/float(h)-0.5)*0.24
    verts[:,2] = (verts[:,2]/float(d)-0.5)*0.24

    # fix normal direction
    verts[:,2] *= -1; verts[:,1] *= -1
    mesh = trimesh.Trimesh(verts, faces)

    return mesh

# Generate mesh from xyz point cloud
def xyz2mesh(xyz):
    b, _, h, w = xyz.shape
    x, y = np.meshgrid(np.arange(h), np.arange(w))

    # Extract mesh faces from xyz maps
    tri = Delaunay(np.concatenate((x.reshape((h*w, 1)), y.reshape((h*w, 1))), 1))
    faces = tri.simplices

    # invert normals
    faces[:,[0, 1]] = faces[:,[1, 0]]

    # generate mesh
    mesh = trimesh.Trimesh(xyz.squeeze(0).permute(1,2,0).view(h*w,3).cpu().numpy(), faces)

    return mesh

################# Mesh rendering util functions #############################
def add_textures(meshes:Meshes, vertex_colors=None) -> Meshes:
    verts = meshes.verts_padded()
    if vertex_colors is None:
        vertex_colors = torch.ones_like(verts) # (N, V, 3)
    textures = TexturesVertex(verts_features=vertex_colors)
    meshes_t = Meshes(
        verts=verts,
        faces=meshes.faces_padded(),
        textures=textures,
        verts_normals=meshes.verts_normals_padded(),
    )
    return meshes_t


def create_cameras(
    R=None, T=None,
    azim=0, elev=0., dist=1.,
    fov=12., znear=0.01,
    device="cuda") -> FoVPerspectiveCameras:
    """
    All camera parameters can be a single number, a list, or a torch tensor.
    """
    if R is None or T is None:
        R, T = look_at_view_transform(dist=dist, azim=azim, elev=elev, device=device)
    cameras = FoVPerspectiveCameras(device=device, R=R, T=T, znear=znear, fov=fov)
    return cameras


def create_mesh_renderer(
        cameras: FoVPerspectiveCameras,
        image_size: int = 256,
        blur_radius: float = 1e-6,
        light_location=((-0.5, 1., 5.0),),
        device="cuda",
        **light_kwargs,
):
    """
    To display the raw texture colors without shading, pass
    ambient_color=((1, 1, 1), ), diffuse_color=((0, 0, 0), ), specular_color=((0, 0, 0), )
    as light_kwargs.
    """
    # Create a Phong renderer; SoftPhongShader blends up to faces_per_pixel faces at each pixel.
    raster_settings = RasterizationSettings(
        image_size=image_size,
        blur_radius=blur_radius,
        faces_per_pixel=5,
    )
    # We can add a point light in front of the object.
    lights = PointLights(
        device=device, location=light_location, **light_kwargs
    )
    phong_renderer = MeshRenderer(
        rasterizer=MeshRasterizer(
            cameras=cameras, raster_settings=raster_settings
        ),
        shader=SoftPhongShader(device=device, cameras=cameras, lights=lights)
    )

    return phong_renderer


## custom renderer that also returns the rasterizer depth buffer (zbuf)
class MeshRendererWithDepth(nn.Module):
    def __init__(self, rasterizer, shader):
        super().__init__()
        self.rasterizer = rasterizer
        self.shader = shader

    def forward(self, meshes_world, **kwargs) -> torch.Tensor:
        fragments = self.rasterizer(meshes_world, **kwargs)
        images = self.shader(fragments, meshes_world, **kwargs)
        return images, fragments.zbuf


def create_depth_mesh_renderer(
        cameras: FoVPerspectiveCameras,
        image_size: int = 256,
        blur_radius: float = 1e-6,
        device="cuda",
        **light_kwargs,
):
    """
    To display the raw texture colors without shading, pass
    ambient_color=((1, 1, 1), ), diffuse_color=((0, 0, 0), ), specular_color=((0, 0, 0), )
    as light_kwargs.
    """
    # Create a Phong renderer; SoftPhongShader blends up to faces_per_pixel faces at each pixel.
    raster_settings = RasterizationSettings(
        image_size=image_size,
        blur_radius=blur_radius,
        faces_per_pixel=17,
    )
    # We can add a point light in front of the object.
    lights = PointLights(
        device=device, location=((-0.5, 1., 5.0),), **light_kwargs
    )
    renderer = MeshRendererWithDepth(
        rasterizer=MeshRasterizer(
            cameras=cameras,
            raster_settings=raster_settings,
            device=device,
        ),
        shader=SoftPhongShader(device=device, cameras=cameras, lights=lights)
    )

    return renderer


================================================
FILE: volume_renderer.py
================================================
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import numpy as np

from functools import partial
from pdb import set_trace as st


# Basic SIREN fully connected layer
class LinearLayer(nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, bias_init=0, std_init=1, freq_init=False, is_first=False):
        super().__init__()
        if is_first:
            self.weight = nn.Parameter(torch.empty(out_dim, in_dim).uniform_(-1 / in_dim, 1 / in_dim))
        elif freq_init:
            self.weight = nn.Parameter(torch.empty(out_dim, in_dim).uniform_(-np.sqrt(6 / in_dim) / 25, np.sqrt(6 / in_dim) / 25))
        else:
            self.weight = nn.Parameter(0.25 * nn.init.kaiming_normal_(torch.randn(out_dim, in_dim), a=0.2, mode='fan_in', nonlinearity='leaky_relu'))

        self.bias = nn.Parameter(nn.init.uniform_(torch.empty(out_dim), a=-np.sqrt(1/in_dim), b=np.sqrt(1/in_dim)))

        self.bias_init = bias_init
        self.std_init = std_init

    def forward(self, input):
        out = self.std_init * F.linear(input, self.weight, bias=self.bias) + self.bias_init

        return out

# Siren layer with frequency modulation and offset
class FiLMSiren(nn.Module):
    def __init__(self, in_channel, out_channel, style_dim, is_first=False):
        super().__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel

        if is_first:
            self.weight = nn.Parameter(torch.empty(out_channel, in_channel).uniform_(-1 / 3, 1 / 3))
        else:
            self.weight = nn.Parameter(torch.empty(out_channel, in_channel).uniform_(-np.sqrt(6 / in_channel) / 25, np.sqrt(6 / in_channel) / 25))

        self.bias = nn.Parameter(nn.Parameter(nn.init.uniform_(torch.empty(out_channel), a=-np.sqrt(1/in_channel), b=np.sqrt(1/in_channel))))
        self.activation = torch.sin

        self.gamma = LinearLayer(style_dim, out_channel, bias_init=30, std_init=15)
        self.beta = LinearLayer(style_dim, out_channel, bias_init=0, std_init=0.25)

    def forward(self, input, style):
        batch, features = style.shape
        out = F.linear(input, self.weight, bias=self.bias)
        gamma = self.gamma(style).view(batch, 1, 1, 1, features)
        beta = self.beta(style).view(batch, 1, 1, 1, features)

        out = self.activation(gamma * out + beta)

        return out


# Siren Generator Model
class SirenGenerator(nn.Module):
    def __init__(self, D=8, W=256, style_dim=256, input_ch=3, input_ch_views=3, output_ch=4,
                 output_features=True):
        super(SirenGenerator, self).__init__()
        self.D = D
        self.W = W
        self.input_ch = input_ch
        self.input_ch_views = input_ch_views
        self.style_dim = style_dim
        self.output_features = output_features

        self.pts_linears = nn.ModuleList(
            [FiLMSiren(3, W, style_dim=style_dim, is_first=True)] + \
            [FiLMSiren(W, W, style_dim=style_dim) for i in range(D-1)])

        self.views_linears = FiLMSiren(input_ch_views + W, W, style_dim=style_dim)
        self.rgb_linear = LinearLayer(W, 3, freq_init=True)
        self.sigma_linear = LinearLayer(W, 1, freq_init=True)

    def forward(self, x, styles):
        input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
        mlp_out = input_pts.contiguous()
        for i in range(len(self.pts_linears)):
            mlp_out = self.pts_linears[i](mlp_out, styles)

        sdf = self.sigma_linear(mlp_out)

        mlp_out = torch.cat([mlp_out, input_views], -1)
        out_features = self.views_linears(mlp_out, styles)
        rgb = self.rgb_linear(out_features)

        outputs = torch.cat([rgb, sdf], -1)
        if self.output_features:
            outputs = torch.cat([outputs, out_features], -1)

        return outputs


# Full volume renderer
class VolumeFeatureRenderer(nn.Module):
    def __init__(self, opt, style_dim=256, out_im_res=64, mode='train'):
        super().__init__()
        self.test = mode != 'train'
        self.perturb = opt.perturb
        self.offset_sampling = not opt.no_offset_sampling # Stratified sampling used otherwise
        self.N_samples = opt.N_samples
        self.raw_noise_std = opt.raw_noise_std
        self.return_xyz = opt.return_xyz
        self.return_sdf = opt.return_sdf
        self.static_viewdirs = opt.static_viewdirs
        self.z_normalize = not opt.no_z_normalize
        self.out_im_res = out_im_res
        self.force_background = opt.force_background
        self.with_sdf = not opt.no_sdf
        if 'no_features_output' in opt.keys():
            self.output_features = False
        else:
            self.output_features = True

        if self.with_sdf:
            self.sigmoid_beta = nn.Parameter(0.1 * torch.ones(1))

        # create meshgrid to generate rays
        i, j = torch.meshgrid(torch.linspace(0.5, self.out_im_res - 0.5, self.out_im_res),
                              torch.linspace(0.5, self.out_im_res - 0.5, self.out_im_res))

        self.register_buffer('i', i.t().unsqueeze(0), persistent=False)
        self.register_buffer('j', j.t().unsqueeze(0), persistent=False)

        # create integration values
        if self.offset_sampling:
            t_vals = torch.linspace(0., 1.-1/self.N_samples, steps=self.N_samples).view(1,1,1,-1)
        else: # Original NeRF Stratified sampling
            t_vals = torch.linspace(0., 1., steps=self.N_samples).view(1,1,1,-1)

        self.register_buffer('t_vals', t_vals, persistent=False)
        self.register_buffer('inf', torch.Tensor([1e10]), persistent=False)
        self.register_buffer('zero_idx', torch.LongTensor([0]), persistent=False)

        if self.test:
            self.perturb = False
            self.raw_noise_std = 0.
        self.channel_dim = -1
        self.samples_dim = 3
        self.input_ch = 3
        self.input_ch_views = 3
        self.feature_out_size = opt.width

        # set Siren Generator model
        self.network = SirenGenerator(D=opt.depth, W=opt.width, style_dim=style_dim, input_ch=self.input_ch,
                                      output_ch=4, input_ch_views=self.input_ch_views,
                                      output_features=self.output_features)

    def get_rays(self, focal, c2w):
        dirs = torch.stack([(self.i - self.out_im_res * .5) / focal,
                            -(self.j - self.out_im_res * .5) / focal,
                            -torch.ones_like(self.i).expand(focal.shape[0],self.out_im_res, self.out_im_res)], -1)

        # Rotate ray directions from camera frame to the world frame
        rays_d = torch.sum(dirs[..., None, :] * c2w[:,None,None,:3,:3], -1)  # dot product, equals to: [c2w.dot(dir) for dir in dirs]

        # Translate camera frame's origin to the world frame. It is the origin of all rays.
        rays_o = c2w[:,None,None,:3,-1].expand(rays_d.shape)
        if self.static_viewdirs:
            viewdirs = dirs
        else:
            viewdirs = rays_d

        return rays_o, rays_d, viewdirs

    def get_eikonal_term(self, pts, sdf):
        eikonal_term = autograd.grad(outputs=sdf, inputs=pts,
                                     grad_outputs=torch.ones_like(sdf),
                                     create_graph=True)[0]

        return eikonal_term

    def sdf_activation(self, input):
        sigma = torch.sigmoid(input / self.sigmoid_beta) / self.sigmoid_beta

        return sigma

    def volume_integration(self, raw, z_vals, rays_d, pts, return_eikonal=False):
        dists = z_vals[...,1:] - z_vals[...,:-1]
        rays_d_norm = torch.norm(rays_d.unsqueeze(self.samples_dim), dim=self.channel_dim)
        # dists still has 4 dimensions here instead of 5, hence, in this case samples dim is actually the channel dim
        dists = torch.cat([dists, self.inf.expand(rays_d_norm.shape)], self.channel_dim)  # [N_rays, N_samples]
        dists = dists * rays_d_norm

        # If sdf modeling is off, the sdf variable stores the
        # pre-integration raw sigma MLP outputs.
        if self.output_features:
            rgb, sdf, features = torch.split(raw, [3, 1, self.feature_out_size], dim=self.channel_dim)
        else:
            rgb, sdf = torch.split(raw, [3, 1], dim=self.channel_dim)

        noise = 0.
        if self.raw_noise_std > 0.:
            noise = torch.randn_like(sdf) * self.raw_noise_std

        if self.with_sdf:
            sigma = self.sdf_activation(-sdf)

            if return_eikonal:
                eikonal_term = self.get_eikonal_term(pts, sdf)
            else:
                eikonal_term = None

            sigma = 1 - torch.exp(-sigma * dists.unsqueeze(self.channel_dim))
        else:
            sigma = sdf
            eikonal_term = None

            sigma = 1 - torch.exp(-F.softplus(sigma + noise) * dists.unsqueeze(self.channel_dim))

        visibility = torch.cumprod(torch.cat([torch.ones_like(torch.index_select(sigma, self.samples_dim, self.zero_idx)),
                                              1.-sigma + 1e-10], self.samples_dim), self.samples_dim)
        visibility = visibility[...,:-1,:]
        weights = sigma * visibility

        if self.return_sdf:
            sdf_out = sdf
        else:
            sdf_out = None

        if self.force_background:
            weights[...,-1,:] = 1 - weights[...,:-1,:].sum(self.samples_dim)

        rgb_map = -1 + 2 * torch.sum(weights * torch.sigmoid(rgb), self.samples_dim)  # switch to [-1,1] value range

        if self.output_features:
            feature_map = torch.sum(weights * features, self.samples_dim)
        else:
            feature_map = None

        # Return surface point cloud in world coordinates.
        # This is used to generate the depth map visualizations.
        # We use world coordinates to avoid transformation errors between
        # surface renderings from different viewpoints.
        if self.return_xyz:
            xyz = torch.sum(weights * pts, self.samples_dim)
            mask = weights[...,-1,:]  # background probability map
        else:
            xyz = None
            mask = None

        return rgb_map, feature_map, sdf_out, mask, xyz, eikonal_term

    def run_network(self, inputs, viewdirs, styles=None):
        input_dirs = viewdirs.unsqueeze(self.samples_dim).expand(inputs.shape)
        net_inputs = torch.cat([inputs, input_dirs], self.channel_dim)
        outputs = self.network(net_inputs, styles=styles)

        return outputs

    def render_rays(self, ray_batch, styles=None, return_eikonal=False):
        batch, h, w, _ = ray_batch.shape
        split_pattern = [3, 3, 2]
        if ray_batch.shape[-1] > 8:
            split_pattern += [3]
            rays_o, rays_d, bounds, viewdirs = torch.split(ray_batch, split_pattern, dim=self.channel_dim)
        else:
            rays_o, rays_d, bounds = torch.split(ray_batch, split_pattern, dim=self.channel_dim)
            viewdirs = None

        near, far = torch.split(bounds, [1, 1], dim=self.channel_dim)
        z_vals = near * (1.-self.t_vals) + far * (self.t_vals)

        if self.perturb > 0.:
            if self.offset_sampling:
                # random offset samples
                upper = torch.cat([z_vals[...,1:], far], -1)
                lower = z_vals.detach()
                t_rand = torch.rand(batch, h, w).unsqueeze(self.channel_dim).to(z_vals.device)
            else:
                # get intervals between samples
                mids = .5 * (z_vals[...,1:] + z_vals[...,:-1])
                upper = torch.cat([mids, z_vals[...,-1:]], -1)
                lower = torch.cat([z_vals[...,:1], mids], -1)
                # stratified samples in those intervals
                t_rand = torch.rand(z_vals.shape).to(z_vals.device)

            z_vals = lower + (upper - lower) * t_rand

        pts = rays_o.unsqueeze(self.samples_dim) + rays_d.unsqueeze(self.samples_dim) * z_vals.unsqueeze(self.channel_dim)

        if return_eikonal:
            pts.requires_grad = True

        if self.z_normalize:
            normalized_pts = pts * 2 / ((far - near).unsqueeze(self.samples_dim))
        else:
            normalized_pts = pts

        raw = self.run_network(normalized_pts, viewdirs, styles=styles)
        rgb_map, features, sdf, mask, xyz, eikonal_term = self.volume_integration(raw, z_vals, rays_d, pts, return_eikonal=return_eikonal)

        return rgb_map, features, sdf, mask, xyz, eikonal_term

    def render(self, focal, c2w, near, far, styles, c2w_staticcam=None, return_eikonal=False):
        rays_o, rays_d, viewdirs = self.get_rays(focal, c2w)
        viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)

        # Create ray batch
        near = near.unsqueeze(-1) * torch.ones_like(rays_d[...,:1])
        far = far.unsqueeze(-1) * torch.ones_like(rays_d[...,:1])
        rays = torch.cat([rays_o, rays_d, near, far], -1)
        rays = torch.cat([rays, viewdirs], -1)
        rays = rays.float()
        rgb, features, sdf, mask, xyz, eikonal_term = self.render_rays(rays, styles=styles, return_eikonal=return_eikonal)

        return rgb, features, sdf, mask, xyz, eikonal_term

    def mlp_init_pass(self, cam_poses, focal, near, far, styles=None):
        rays_o, rays_d, viewdirs = self.get_rays(focal, cam_poses)
        viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)

        near = near.unsqueeze(-1) * torch.ones_like(rays_d[...,:1])
        far = far.unsqueeze(-1) * torch.ones_like(rays_d[...,:1])
        z_vals = near * (1.-self.t_vals) + far * (self.t_vals)

        # get intervals between samples
        mids = .5 * (z_vals[...,1:] + z_vals[...,:-1])
        upper = torch.cat([mids, z_vals[...,-1:]], -1)
        lower = torch.cat([z_vals[...,:1], mids], -1)
        # stratified samples in those intervals
        t_rand = torch.rand(z_vals.shape).to(z_vals.device)

        z_vals = lower + (upper - lower) * t_rand
        pts = rays_o.unsqueeze(self.samples_dim) + rays_d.unsqueeze(self.samples_dim) * z_vals.unsqueeze(self.channel_dim)
        if self.z_normalize:
            normalized_pts = pts * 2 / ((far - near).unsqueeze(self.samples_dim))
        else:
            normalized_pts = pts

        raw = self.run_network(normalized_pts, viewdirs, styles=styles)
        _, sdf = torch.split(raw, [3, 1], dim=self.channel_dim)
        sdf = sdf.squeeze(self.channel_dim)
        target_values = pts.detach().norm(dim=-1) - ((far - near) / 4)

        return sdf, target_values

    def forward(self, cam_poses, focal, near, far, styles=None, return_eikonal=False):
        rgb, features, sdf, mask, xyz, eikonal_term = self.render(focal, c2w=cam_poses, near=near, far=far, styles=styles,
                                                                  return_eikonal=return_eikonal)

        rgb = rgb.permute(0,3,1,2).contiguous()
        if self.output_features:
            features = features.permute(0,3,1,2).contiguous()

        if xyz is not None:
            xyz = xyz.permute(0,3,1,2).contiguous()
            mask = mask.permute(0,3,1,2).contiguous()

        return rgb, features, sdf, mask, xyz, eikonal_term
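
# -----------------------------------------------------------------------------
# Illustrative sketch, not part of the original StyleSDF implementation.
# This is a minimal single-ray version of two steps performed above. It makes
# these simplifying assumptions:
#   * 1-D tensors for one ray;
#   * a fixed scalar beta in place of the learned self.sigmoid_beta;
#   * no raw noise and no force_background.
# The _sketch_* helpers are hypothetical names introduced only for
# illustration and are not called anywhere else in the repository.
#   * _sketch_stratified_z_vals mirrors the stratified branch of render_rays
#     (offset_sampling disabled).
#   * _sketch_composite_weights mirrors volume_integration with with_sdf on:
#     SDF -> density via sdf_activation(-sdf), density -> per-sample alpha,
#     alpha -> compositing weights via the cumulative transmittance.
# -----------------------------------------------------------------------------
import torch


def _sketch_stratified_z_vals(near, far, n_samples, generator=None):
    # Evenly spaced depths between near and far...
    t_vals = torch.linspace(0., 1., n_samples)
    z_vals = near * (1. - t_vals) + far * t_vals
    # ...each jittered uniformly inside the interval bounded by the
    # midpoints to its neighbours, so the samples stay sorted.
    mids = .5 * (z_vals[1:] + z_vals[:-1])
    upper = torch.cat([mids, z_vals[-1:]])
    lower = torch.cat([z_vals[:1], mids])
    t_rand = torch.rand(n_samples, generator=generator)
    return lower + (upper - lower) * t_rand


def _sketch_composite_weights(sdf, z_vals, beta=0.1):
    # Same transform as sdf_activation(-sdf): density approaches 1/beta
    # inside the surface (sdf < 0) and decays towards 0 outside it.
    density = torch.sigmoid(-sdf / beta) / beta
    # Distances between consecutive samples. The last interval is padded with
    # 1e10, like self.inf, so the final sample absorbs whatever light remains.
    dists = torch.cat([z_vals[1:] - z_vals[:-1], z_vals.new_tensor([1e10])])
    alpha = 1 - torch.exp(-density * dists)
    # Transmittance: probability that the ray reaches sample i unoccluded.
    # The 1e-10 term is the same numerical guard used in volume_integration.
    visibility = torch.cumprod(torch.cat([alpha.new_ones(1), 1. - alpha + 1e-10]), 0)[:-1]
    return alpha * visibility


# Example: a ray hitting a surface at depth 1.0. Here sdf = 1.0 - z, which is
# positive in front of the surface and negative behind it. The weights sum to
# roughly 1 and peak right next to the surface.
# A ray that never enters the shape puts almost all of its weight on the last
# (background) sample, which is what the mask returned by volume_integration
# measures.
#   z = torch.linspace(0.5, 1.5, 64)
#   w = _sketch_composite_weights(1.0 - z, z, beta=0.01)
#   z[w.argmax()]  # close to 1.0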