[
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\n# lib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# poetry\n#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for libraries.\n#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control\n#poetry.lock\n\n# pdm\n#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.\n#pdm.lock\n#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it\n#   in version control.\n#   https://pdm.fming.dev/#use-with-ide\n.pdm.toml\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\n\n# PyCharm\n#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can\n#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore\n#  and can be added to the global gitignore or merged into this file.  For a more nuclear\n#  option (not recommended) you can uncomment the following to ignore the entire idea folder.\n#.idea/\n\n.DS_STORE\n"
  },
  {
    "path": "GMeshDiffusion/diffusion_configs/config_lower_occgrid_normalized.py",
    "content": "import ml_collections\nimport torch\nimport os\n\n\ndef get_config():\n    config = ml_collections.ConfigDict()\n\n    # data\n    data = config.data = ml_collections.ConfigDict()\n    data.root_dir = 'PLACEHOLDER'\n    # data.dataset_metapath = os.path.join(data.root_dir, 'metadata/lower_res64_train.txt')\n    data.num_workers = 4\n    data.grid_size = 128\n    data.tet_resolution = 64\n    data.num_channels = 4\n    data.use_occ_grid = True\n    data.grid_metafile = os.path.join(data.root_dir, 'metadata/lower_res64_grid_train.txt')\n    data.occgrid_metafile = os.path.join(data.root_dir, 'metadata/lower_res64_occgrid_train.txt')\n\n    data.occ_mask_path = os.path.join(data.root_dir, 'metadata/occ_mask_res64.pt')\n    data.tet_info_path = os.path.join(data.root_dir, 'metadata/tet_info.pt')\n\n    data.filter_meta_path = None\n    data.aug = True\n\n    # training\n    training = config.training = ml_collections.ConfigDict()\n    training.sde = 'vpsde'\n    training.continuous = False\n    training.reduce_mean = True\n    training.batch_size = 1 ### for DDP, global_batch_size = nproc * local_batch_size\n    training.num_grad_acc_steps = 4 \n    training.n_iters = 2400001\n    training.snapshot_freq = 1000\n    training.log_freq = 50\n    ## produce samples at each snapshot.\n    training.snapshot_sampling = True\n    training.likelihood_weighting = False\n    training.loss_type = 'l2'\n    training.train_dir = \"PLACEHOLDER\"\n    training.snapshot_freq_for_preemption = 1000\n    training.gradscaler_growth_interval = 1000\n    training.use_aux_loss = False\n\n\n    training.compile = True # PyTorch 2.0, torch.compile\n    training.enable_xformers_memory_efficient_attention = True\n\n    # sampling\n    sampling = config.sampling = ml_collections.ConfigDict()\n    sampling.method = 'pc'\n    sampling.predictor = 'ancestral_sampling'\n    sampling.corrector = 'none'\n    sampling.n_steps_each = 1\n    sampling.noise_removal = True\n    sampling.probability_flow = False\n    sampling.snr = 0.075\n\n\n    # model\n    model = config.model = ml_collections.ConfigDict()\n    model.name = 'unet3d_occgrid'\n    model.use_occ_grid = True\n    model.num_res_blocks = 2\n    model.num_res_blocks_1st_layer = 2\n    model.base_channels = 128\n    model.ch_mult = (1, 2, 2, 4, 4, 4)\n    model.down_block_types = (\n        \"ResBlock\", \"ResBlock\", \"ResBlock\", \"AttnResBlock\", \"ResBlock\", \"ResBlock\"\n    )\n    model.up_block_types = (\n       \"ResBlock\", \"ResBlock\", \"AttnResBlock\", \"ResBlock\", \"ResBlock\", \"ResBlock\"\n    )\n    model.scale_by_sigma = False\n    model.num_scales = 1000\n    model.ema_rate = 0.9999\n    model.normalization = 'GroupNorm'\n    model.act_fn = 'swish'\n    model.attn_resolutions = (16,)\n    model.resamp_with_conv = True\n    model.dropout = 0.1\n    model.sigma_max = 378\n    model.sigma_min = 0.01\n    model.beta_min = 0.1\n    model.beta_max = 20.\n    model.embedding_type = 'fourier'\n    model.pred_type = 'noise'\n    model.conditional = True\n\n    model.feature_mask_path = os.path.join(data.root_dir, 'metadata/global_mask_res64.pt')\n    model.pixcat_mask_path = os.path.join(data.root_dir, 'metadata/cat_mask_res64.pt')\n\n    # optimization\n    config.optim = optim = ml_collections.ConfigDict()\n    optim.weight_decay = 1e-5\n    optim.optimizer = 'AdamW'\n    optim.lr = 1e-5\n    optim.beta1 = 0.9\n    optim.eps = 1e-8\n    optim.warmup = 5000\n    optim.grad_clip = 1.\n\n    # eval\n    config.eval = eval_config = 
ml_collections.ConfigDict()\n    eval_config.batch_size = 2\n    eval_config.idx = 0\n    eval_config.bin_size = 30\n    eval_config.eval_dir = \"PLACEHOLDER\"\n    eval_config.ckpt_path = \"PLACEHOLDER\"\n    \n\n    config.seed = 42\n    config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')\n\n\n    return config\n"
  },
  {
    "path": "GMeshDiffusion/diffusion_configs/config_upper_occgrid_normalized.py",
    "content": "import ml_collections\nimport torch\nimport os\n\n\ndef get_config():\n    config = ml_collections.ConfigDict()\n\n    # data\n    data = config.data = ml_collections.ConfigDict()\n    data.root_dir = 'PLACEHOLDER'\n    # data.dataset_metapath = os.path.join(data.root_dir, 'metadata/upper_res64_train.txt')\n    data.num_workers = 4\n    data.grid_size = 128\n    data.tet_resolution = 64\n    data.num_channels = 4\n    data.use_occ_grid = True\n    data.grid_metafile = os.path.join(data.root_dir, 'metadata/upper_res64_grid_train.txt')\n    data.occgrid_metafile = os.path.join(data.root_dir, 'metadata/upper_res64_occgrid_train.txt')\n\n    data.occ_mask_path = os.path.join(data.root_dir, 'metadata/occ_mask_res64.pt')\n    data.tet_info_path = os.path.join(data.root_dir, 'metadata/tet_info.pt')\n\n    data.filter_meta_path = None\n    data.aug = True\n\n    # training\n    training = config.training = ml_collections.ConfigDict()\n    training.sde = 'vpsde'\n    training.continuous = False\n    training.reduce_mean = True\n    training.batch_size = 1 ### for DDP, global_batch_size = nproc * local_batch_size\n    training.num_grad_acc_steps = 4 \n    training.n_iters = 2400001\n    training.snapshot_freq = 1000\n    training.log_freq = 50\n    ## produce samples at each snapshot.\n    training.snapshot_sampling = True\n    training.likelihood_weighting = False\n    training.loss_type = 'l2'\n    training.train_dir = \"PLACEHOLDER\"\n    training.snapshot_freq_for_preemption = 1000\n    training.gradscaler_growth_interval = 1000\n    training.use_aux_loss = False\n\n\n    training.compile = True # PyTorch 2.0, torch.compile\n    training.enable_xformers_memory_efficient_attention = True\n\n    # sampling\n    sampling = config.sampling = ml_collections.ConfigDict()\n    sampling.method = 'pc'\n    sampling.predictor = 'ancestral_sampling'\n    sampling.corrector = 'none'\n    sampling.n_steps_each = 1\n    sampling.noise_removal = True\n    sampling.probability_flow = False\n    sampling.snr = 0.075\n\n\n    # model\n    model = config.model = ml_collections.ConfigDict()\n    model.name = 'unet3d_occgrid'\n    model.use_occ_grid = True\n    model.num_res_blocks = 2\n    model.num_res_blocks_1st_layer = 2\n    model.base_channels = 128\n    model.ch_mult = (1, 2, 2, 4, 4, 4)\n    model.down_block_types = (\n        \"ResBlock\", \"ResBlock\", \"ResBlock\", \"AttnResBlock\", \"ResBlock\", \"ResBlock\"\n    )\n    model.up_block_types = (\n       \"ResBlock\", \"ResBlock\", \"AttnResBlock\", \"ResBlock\", \"ResBlock\", \"ResBlock\"\n    )\n    model.scale_by_sigma = False\n    model.num_scales = 1000\n    model.ema_rate = 0.9999\n    model.normalization = 'GroupNorm'\n    model.act_fn = 'swish'\n    model.attn_resolutions = (16,)\n    model.resamp_with_conv = True\n    model.dropout = 0.1\n    model.sigma_max = 378\n    model.sigma_min = 0.01\n    model.beta_min = 0.1\n    model.beta_max = 20.\n    model.embedding_type = 'fourier'\n    model.pred_type = 'noise'\n    model.conditional = True\n\n    model.feature_mask_path = os.path.join(data.root_dir, 'metadata/global_mask_res64_occaug_normalized_v1.pt')\n    model.pixcat_mask_path = os.path.join(data.root_dir, 'metadata/cat_mask_res64_occaug_normalized_v1.pt')\n\n    # optimization\n    config.optim = optim = ml_collections.ConfigDict()\n    optim.weight_decay = 1e-5\n    optim.optimizer = 'AdamW'\n    optim.lr = 1e-5\n    optim.beta1 = 0.9\n    optim.eps = 1e-8\n    optim.warmup = 5000\n    optim.grad_clip = 1.\n\n    # eval\n    
config.eval = eval_config = ml_collections.ConfigDict()\n    eval_config.batch_size = 2\n    eval_config.idx = 0\n    eval_config.bin_size = 30\n    eval_config.eval_dir = \"PLACEHOLDER\"\n    eval_config.ckpt_path = \"PLACEHOLDER\"\n    \n\n    config.seed = 42\n    config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')\n\n\n    return config\n"
  },
  {
    "path": "GMeshDiffusion/lib/dataset/gshell_dataset.py",
    "content": "import torch\nimport numpy as np\nfrom torch.utils.data import Dataset\n\nclass GShellDataset(Dataset):\n    def __init__(self, filepath_metafile, extension='pt'):\n        super().__init__()\n        with open(filepath_metafile, 'r') as f:\n            self.filepath_list = [fpath.rstrip() for fpath in f]\n\n        self.extension = extension\n        assert self.extension in ['pt', 'npy']\n    \n    def __len__(self):\n        return len(self.filepath_list)\n\n    def __getitem__(self, idx):\n        with torch.no_grad():\n            if self.extension == 'pt':\n                datum = torch.load(self.filepath_list[idx], map_location='cpu')\n            else:\n                datum = torch.tensor(np.load(self.filepath_list[idx]))\n        return datum\n"
  },
  {
    "path": "GMeshDiffusion/lib/dataset/gshell_dataset_aug.py",
    "content": "import torch\nfrom torch.utils.data import Dataset\n\nclass GShellAugDataset(Dataset):\n    def __init__(self, FLAGS, extension='pt'):\n        super().__init__()\n        with open(FLAGS.data.grid_metafile, 'r') as f:\n            self.filepath_list = [fpath.rstrip() for fpath in f]\n        with open(FLAGS.data.occgrid_metafile, 'r') as f:\n            self.occ_filepath_list = [fpath.rstrip() for fpath in f]\n\n        self.extension = extension\n        self.num_channels = FLAGS.data.num_channels\n        print('num_channels: ', self.num_channels)\n        assert self.extension in ['pt', 'npy']\n    \n    def __len__(self):\n        return len(self.filepath_list)\n\n    def __getitem__(self, idx):\n        with torch.no_grad():\n            grid = torch.load(self.filepath_list[idx], map_location='cpu')\n            try:\n                occ_grid = torch.load(self.occ_filepath_list[idx], map_location='cpu')\n            except:\n                print(self.occ_filepath_list[idx])\n                raise\n        return (grid[:self.num_channels], occ_grid)\n    \n    @staticmethod\n    def collate(data):\n        return {\n            'grid': torch.stack([x[0] for x in data]),\n            'occgrid': torch.stack([x[1] for x in data]),\n        }\n"
  },
  {
    "path": "GMeshDiffusion/lib/diffusion/evaler.py",
    "content": "import os\nimport sys\nimport numpy as np\nimport tqdm\n\nimport logging\nfrom . import losses\nfrom .models import utils as mutils\nfrom .models.ema import ExponentialMovingAverage\nfrom . import sde_lib\nimport torch\nfrom .utils import restore_checkpoint\nfrom . import sampling\n\ndef uncond_gen(\n        config\n    ):\n    \"\"\"\n        Unconditional Generation\n    \"\"\"\n    with torch.no_grad():\n        eval_dir, ckpt_path = config.eval.eval_dir, config.eval.ckpt_path\n        idx = config.eval.idx\n        bin_size = config.eval.bin_size\n        print(f\"idx to save: {idx * bin_size} to {idx * bin_size + bin_size - 1}\")\n        # Create directory to eval_folder\n        os.makedirs(eval_dir, exist_ok=True)\n\n        scaler, inverse_scaler = lambda x: x, lambda x: x\n\n        # Initialize model\n        score_model = mutils.create_model(config, use_parallel=False)\n        optimizer = losses.get_optimizer(config, score_model.parameters())\n        ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)\n        state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0)\n\n        # Setup SDEs\n        sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)\n\n        sampling_eps = 1e-3\n        sampling_shape = (config.eval.batch_size,\n                        config.data.num_channels,\n                        config.data.grid_size, config.data.grid_size, config.data.grid_size)\n        sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape, inverse_scaler, sampling_eps)\n\n        assert os.path.exists(ckpt_path)\n        print('ckpt path:', ckpt_path)\n        try:\n            state = restore_checkpoint(ckpt_path, state, device=config.device)\n        except:\n            raise\n        ema.copy_to(score_model.parameters())\n\n        print(f\"loaded model is trained till iter {state['step'] // config.training.num_grad_acc_steps}\")\n\n\n        for k in range(bin_size):\n            save_file_path = os.path.join(eval_dir, f\"{idx * bin_size + k}\")\n            print(f'check: {save_file_path}')\n            if os.path.exists(save_file_path + '.pt'):\n                # continue\n                pass\n            print(f'will save to: {save_file_path}')\n            samples, n = sampling_fn(score_model)\n            if type(samples) != tuple:\n                print(samples[:, 0].unique())\n                torch.save(samples, save_file_path + '.pt')\n                samples = samples.cpu().numpy()\n                # np.save(save_file_path, samples)\n            else:\n                print(samples[0][:, 0].unique())\n                torch.save(samples[0], save_file_path + '.pt')\n                torch.save(samples[1], save_file_path + '_occ.pt')\n                # samples, occ = samples[0].cpu().numpy(), samples[1].cpu().numpy()\n            # np.save(save_file_path + '.npy, samples)\n\n\ndef slerp(z1, z2, alpha):\n    '''\n        Spherical Linear Interpolation\n    '''\n    theta = torch.acos(torch.sum(z1 * z2) / (torch.norm(z1) * torch.norm(z2)))\n    return (\n            torch.sin((1 - alpha) * theta) / torch.sin(theta) * z1\n            + torch.sin(alpha * theta) / torch.sin(theta) * z2\n    )\n\ndef uncond_gen_interp(\n        config,\n        idx=0,\n    ):\n    \"\"\"\n        Generation with interpolation between initial noises\n        Used for DDIM\n    \"\"\"\n    with torch.no_grad():\n        eval_dir, ckpt_path = 
config.eval.eval_dir, config.eval.ckpt_path\n        # Create directory to eval_folder\n        os.makedirs(eval_dir, exist_ok=True)\n\n        scaler, inverse_scaler = lambda x: x, lambda x: x\n\n        # Initialize model\n        score_model = mutils.create_model(config, use_parallel=False)\n        optimizer = losses.get_optimizer(config, score_model.parameters())\n        ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)\n        state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0)\n\n        # Setup SDEs\n        sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)\n\n        sampling_eps = 1e-3\n        sampling_shape = (config.eval.batch_size,\n                        config.data.num_channels,\n                        config.data.grid_size, config.data.grid_size, config.data.grid_size)\n        sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape, inverse_scaler, sampling_eps)\n\n        assert os.path.exists(ckpt_path)\n        print('ckpt path:', ckpt_path)\n        try:\n            state = restore_checkpoint(ckpt_path, state, device=config.device)\n        except:\n            raise\n        ema.copy_to(score_model.parameters())\n\n        print(f\"loaded model is trained till iter {state['step'] // config.training.num_grad_acc_steps}\")\n\n\n        idx = config.eval.idx\n        bin_size = config.eval.bin_size\n        config.eval.interp_batch_size = 32\n        print(f\"idx to save: {idx * bin_size} to {idx * bin_size + bin_size - 1}\")\n\n        for k in range(bin_size):\n            save_file_path = os.path.join(eval_dir, f\"{idx * bin_size + k}\")\n\n            noise = sde.prior_sampling(\n                (2, config.data.num_channels, config.data.grid_size, config.data.grid_size, config.data.grid_size)\n            ).to(config.device)\n        \n            interp_sampling_shape = (config.eval.interp_batch_size,\n                            config.data.num_channels,\n                            config.data.grid_size, config.data.grid_size, config.data.grid_size)\n            x0 = torch.zeros(interp_sampling_shape, device=config.device)\n            x0[0] = noise[0]\n            x0[-1] = noise[1]\n            for i in range(1, config.eval.interp_batch_size - 1):\n                x0[i] = slerp(x0[0], x0[-1], i / float(config.eval.interp_batch_size - 1))\n\n            if config.model.use_occ_grid:\n                noise_occ = sde.prior_sampling(\n                    (2, 1, config.data.grid_size * 2, config.data.grid_size * 2, config.data.grid_size * 2)\n                ).to(config.device)\n                interp_sampling_shape = (config.eval.interp_batch_size,\n                                1,\n                                config.data.grid_size * 2, config.data.grid_size * 2, config.data.grid_size * 2)\n                x0_occ = torch.zeros(interp_sampling_shape, device=config.device)\n                x0_occ[0] = noise_occ[0]\n                x0_occ[-1] = noise_occ[1]\n                for i in range(1, config.eval.interp_batch_size - 1):\n                    x0_occ[i] = slerp(x0_occ[0], x0_occ[-1], i / float(config.eval.interp_batch_size - 1))\n            else:\n                x0_occ = None\n\n            sample_list = []\n            sample_occ_list = []\n            for i in tqdm.trange(config.eval.interp_batch_size):\n                samples, n = sampling_fn(score_model, x0=x0[i:i+1], x0_occ=x0_occ[i:i+1])\n                if type(samples) 
!= tuple:\n                    # samples = samples.cpu()\n                    sample_list.append(samples.cpu())\n                else:\n                    # samples = samples.cpu()\n                    sample_list.append(samples[0].cpu())\n                    sample_occ_list.append(samples[1].cpu())\n\n            # np.save(save_file_path, np.concatenate(sample_list, axis=0))\n            torch.save(torch.cat(sample_list, dim=0), save_file_path + '.pt')\n            if config.model.use_occ_grid:\n                torch.save(torch.cat(sample_occ_list, dim=0), save_file_path + '_occ.pt')\n\n\ndef cond_gen(\n        config,\n        save_fname='0',\n    ):\n    \"\"\"\n        Conditional Generation with partially completed dmtet from a 2.5D view (converted into a cubic grid)\n    \"\"\"\n    with torch.no_grad():\n        eval_dir, ckpt_path = config.eval.eval_dir, config.eval.ckpt_path\n        # Create directory to eval_folder\n        os.makedirs(eval_dir, exist_ok=True)\n\n        scaler, inverse_scaler = lambda x: x, lambda x: x\n\n        # Initialize model\n        score_model = mutils.create_model(config)\n        optimizer = losses.get_optimizer(config, score_model.parameters())\n        ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)\n        state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0)\n\n        # Setup SDEs\n        sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)\n\n        resolution = config.data.image_size\n        grid_mask = torch.load(f'./data/grid_mask_{resolution}.pt').view(1, 1, resolution, resolution, resolution).to(\"cuda\")\n        grid_mask = grid_mask[:, :, :config.data.input_size, :config.data.input_size, :config.data.input_size]\n\n        sampling_eps = 1e-3\n        sampling_shape = (config.eval.batch_size,\n                        config.data.num_channels,\n                        # resolution, resolution, resolution)\n                        config.data.input_size, config.data.input_size, config.data.input_size)\n        sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape, inverse_scaler, sampling_eps, grid_mask=grid_mask)\n\n        assert os.path.exists(ckpt_path)\n        print('ckpt path:', ckpt_path)\n        try:\n            state = restore_checkpoint(ckpt_path, state, device=config.device)\n        except:\n            raise\n        ema.copy_to(score_model.parameters())\n\n        print(f\"loaded model is trained till iter {state['step'] // config.training.iter_size}\")\n\n        \n        save_file_path = os.path.join(eval_dir, f\"{save_fname}.npy\")\n\n        ### Conditional but free gradients; start from small t\n\n        partial_dict = torch.load(config.eval.partial_dmtet_path)\n        partial_sdf = partial_dict['sdf']\n        partial_mask = partial_dict['vis']\n\n\n        ### compute the mapping from tet indices to 3D cubic grid vertex indices\n        tet_path = config.eval.tet_path\n        tet = np.load(tet_path)\n        vertices = torch.tensor(tet['vertices'])\n        vertices_unique = vertices[:].unique()\n        dx = vertices_unique[1] - vertices_unique[0]\n\n        ind_to_coord = (torch.round(\n            (vertices - vertices.min()) / dx)\n        ).long()\n\n        \n        partial_sdf_grid = torch.zeros((1, 1, resolution, resolution, resolution))\n        partial_sdf_grid[0, 0, ind_to_coord[:, 0], ind_to_coord[:, 1], ind_to_coord[:, 2]] = partial_sdf\n        partial_mask_grid = 
torch.zeros((1, 1, resolution, resolution, resolution))\n        partial_mask_grid[0, 0, ind_to_coord[:, 0], ind_to_coord[:, 1], ind_to_coord[:, 2]] = partial_mask.float()\n\n        samples, n = sampling_fn(\n            score_model, \n            partial=partial_sdf_grid.cuda(), \n            partial_mask=partial_mask_grid.cuda(), \n            freeze_iters=config.eval.freeze_iters\n        )\n\n        samples = samples.cpu().numpy()\n        np.save(save_file_path, samples)\n\n"
  },
  {
    "path": "GMeshDiffusion/lib/diffusion/likelihood.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# pylint: skip-file\n# pytype: skip-file\n\"\"\"Various sampling methods.\"\"\"\n\nimport torch\nimport numpy as np\nfrom scipy import integrate\nfrom .models import utils as mutils\n\n\ndef get_div_fn(fn):\n  \"\"\"Create the divergence function of `fn` using the Hutchinson-Skilling trace estimator.\"\"\"\n\n  def div_fn(x, t, eps):\n    with torch.enable_grad():\n      x.requires_grad_(True)\n      fn_eps = torch.sum(fn(x, t) * eps)\n      grad_fn_eps = torch.autograd.grad(fn_eps, x)[0]\n    x.requires_grad_(False)\n    return torch.sum(grad_fn_eps * eps, dim=tuple(range(1, len(x.shape))))\n\n  return div_fn\n\n\ndef get_likelihood_fn(sde, inverse_scaler, hutchinson_type='Rademacher',\n                      rtol=1e-5, atol=1e-5, method='RK45', eps=1e-5):\n  \"\"\"Create a function to compute the unbiased log-likelihood estimate of a given data point.\n\n  Args:\n    sde: A `sde_lib.SDE` object that represents the forward SDE.\n    inverse_scaler: The inverse data normalizer.\n    hutchinson_type: \"Rademacher\" or \"Gaussian\". The type of noise for Hutchinson-Skilling trace estimator.\n    rtol: A `float` number. The relative tolerance level of the black-box ODE solver.\n    atol: A `float` number. The absolute tolerance level of the black-box ODE solver.\n    method: A `str`. The algorithm for the black-box ODE solver.\n      See documentation for `scipy.integrate.solve_ivp`.\n    eps: A `float` number. The probability flow ODE is integrated to `eps` for numerical stability.\n\n  Returns:\n    A function that a batch of data points and returns the log-likelihoods in bits/dim,\n      the latent code, and the number of function evaluations cost by computation.\n  \"\"\"\n\n  def drift_fn(model, x, t):\n    \"\"\"The drift function of the reverse-time SDE.\"\"\"\n    score_fn = mutils.get_score_fn(sde, model, train=False, continuous=True)\n    # Probability flow ODE is a special case of Reverse SDE\n    rsde = sde.reverse(score_fn, probability_flow=True)\n    return rsde.sde(x, t)[0]\n\n  def div_fn(model, x, t, noise):\n    return get_div_fn(lambda xx, tt: drift_fn(model, xx, tt))(x, t, noise)\n\n  def likelihood_fn(model, data):\n    \"\"\"Compute an unbiased estimate to the log-likelihood in bits/dim.\n\n    Args:\n      model: A score model.\n      data: A PyTorch tensor.\n\n    Returns:\n      bpd: A PyTorch tensor of shape [batch size]. The log-likelihoods on `data` in bits/dim.\n      z: A PyTorch tensor of the same shape as `data`. The latent representation of `data` under the\n        probability flow ODE.\n      nfe: An integer. 
The number of function evaluations used for running the black-box ODE solver.\n    \"\"\"\n    with torch.no_grad():\n      shape = data.shape\n      if hutchinson_type == 'Gaussian':\n        epsilon = torch.randn_like(data)\n      elif hutchinson_type == 'Rademacher':\n        epsilon = torch.randint_like(data, low=0, high=2).float() * 2 - 1.\n      else:\n        raise NotImplementedError(f\"Hutchinson type {hutchinson_type} unknown.\")\n\n      def ode_func(t, x):\n        sample = mutils.from_flattened_numpy(x[:-shape[0]], shape).to(data.device).type(torch.float32)\n        vec_t = torch.ones(sample.shape[0], device=sample.device) * t\n        drift = mutils.to_flattened_numpy(drift_fn(model, sample, vec_t))\n        logp_grad = mutils.to_flattened_numpy(div_fn(model, sample, vec_t, epsilon))\n        return np.concatenate([drift, logp_grad], axis=0)\n\n      init = np.concatenate([mutils.to_flattened_numpy(data), np.zeros((shape[0],))], axis=0)\n      solution = integrate.solve_ivp(ode_func, (eps, sde.T), init, rtol=rtol, atol=atol, method=method)\n      nfe = solution.nfev\n      zp = solution.y[:, -1]\n      z = mutils.from_flattened_numpy(zp[:-shape[0]], shape).to(data.device).type(torch.float32)\n      delta_logp = mutils.from_flattened_numpy(zp[-shape[0]:], (shape[0],)).to(data.device).type(torch.float32)\n      prior_logp = sde.prior_logp(z)\n      bpd = -(prior_logp + delta_logp) / np.log(2)\n      N = np.prod(shape[1:])\n      bpd = bpd / N\n      # A hack to convert log-likelihoods to bits/dim\n      offset = 7. - inverse_scaler(-1.)\n      bpd = bpd + offset\n      return bpd, z, nfe\n\n  return likelihood_fn\n"
  },
  {
    "path": "GMeshDiffusion/lib/diffusion/losses.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"All functions related to loss computation and optimization.\n\"\"\"\n\nimport torch\nimport torch.optim as optim\nimport numpy as np\nfrom .models import utils as mutils\n\n\ndef get_optimizer(config, params):\n  \"\"\"Returns a flax optimizer object based on `config`.\"\"\"\n  if config.optim.optimizer == 'Adam':\n    optimizer = optim.Adam(params, lr=config.optim.lr, betas=(config.optim.beta1, 0.999), eps=config.optim.eps,\n                           weight_decay=config.optim.weight_decay)\n  elif config.optim.optimizer == 'AdamW':\n    optimizer = optim.AdamW(params, lr=config.optim.lr, betas=(config.optim.beta1, 0.999), eps=config.optim.eps,\n                           weight_decay=config.optim.weight_decay)\n  else:\n    raise NotImplementedError(\n      f'Optimizer {config.optim.optimizer} not supported yet!')\n\n  return optimizer\n\n\ndef optimization_manager(config):\n  \"\"\"Returns an optimize_fn based on `config`.\"\"\"\n\n  def optimize_fn(optimizer, params, step, lr=config.optim.lr,\n                  warmup=config.optim.warmup,\n                  grad_clip=config.optim.grad_clip,\n                  gradscaler=None):\n    \"\"\"Optimizes with warmup and gradient clipping (disabled if negative).\"\"\"\n    if warmup > 0:\n      for g in optimizer.param_groups:\n        g['lr'] = lr * np.minimum(step / warmup, 1.0)\n    if grad_clip >= 0:\n      gradscaler.unscale_(optimizer)\n      torch.nn.utils.clip_grad_norm_(params, max_norm=grad_clip)\n    gradscaler.step(optimizer)\n    gradscaler.update()\n    # optimizer.step()\n\n  return optimize_fn\n\ndef get_ddpm_loss_fn(vpsde, train, loss_type='l2', pred_type='noise', use_vis_mask=False, use_occ=False, use_aux=False):\n  \"\"\"Legacy code to reproduce previous results on DDPM. 
Not recommended for new work.\"\"\"\n\n\n  if use_occ:\n    def loss_fn(model, batch, use_mesh_reg=False, verts_discretiezd=None, midpoints_discretiezd=None, edges=None):\n      batch, batch_occ = batch['grid'], batch['occgrid']\n      model_fn = mutils.get_model_fn(model, train=train)\n      labels = torch.randint(0, vpsde.N, (batch.shape[0],), device=batch.device)\n      sqrt_alphas_cumprod = vpsde.sqrt_alphas_cumprod.to(batch.device)\n      sqrt_1m_alphas_cumprod = vpsde.sqrt_1m_alphas_cumprod.to(batch.device)\n      with torch.no_grad():\n        noise = torch.randn_like(batch, device=batch.device)\n        noise_occ = torch.randn_like(batch_occ, device=batch.device)\n        perturbed_data = sqrt_alphas_cumprod[labels, None, None, None, None] * batch + \\\n                        sqrt_1m_alphas_cumprod[labels, None, None, None, None] * noise\n        perturbed_data = perturbed_data.type(batch.dtype)\n        perturbed_data_occ = sqrt_alphas_cumprod[labels, None, None, None, None] * batch_occ + \\\n                        sqrt_1m_alphas_cumprod[labels, None, None, None, None] * noise_occ\n        perturbed_data_occ = perturbed_data_occ.type(batch_occ.dtype)\n\n\n      with torch.cuda.amp.autocast(dtype=torch.bfloat16):\n        pred, pred_occ = model_fn((perturbed_data, perturbed_data_occ), labels)\n    \n      pred, pred_occ = pred.float(), pred_occ.float()\n      alphas1 = sqrt_alphas_cumprod[labels, None, None, None, None]\n      alphas2 = sqrt_1m_alphas_cumprod[labels, None, None, None, None]\n      if pred_type == 'noise':\n        score = pred\n        score_occ = pred_occ\n        x0 = (perturbed_data - score * alphas2) / alphas1\n        x0_occ = (perturbed_data_occ - score_occ * alphas2) / alphas1\n      elif pred_type == 'x0':\n        x0 = pred\n        x0_occ = pred_occ\n        score = (perturbed_data - x0 * alphas1) / alphas2\n        score_occ = (perturbed_data_occ - pred_occ * alphas1) / alphas2\n      \n      # noise = noise[:, :, :score.size(2), :score.size(3), :score.size(4)] ### to accommodate change of size due to arch\n      if loss_type == 'l2':\n        losses = torch.square(score - noise)\n        losses_occ = torch.square(score_occ - noise_occ)\n        assert losses_occ.size(1) == 1\n      elif loss_type == 'l1':\n        raise NotImplementedError\n        losses = torch.abs(score - noise)\n      else:\n        raise NotImplementedError\n\n      mask = model.module.feature_mask\n      occ_mask = model.module.occ_mask\n      assert len(mask.size()) == 5\n      assert mask.size(1) == losses.size(1)\n      assert occ_mask.size(1) == losses_occ.size(1)\n      assert losses.size(0) == losses_occ.size(0)\n      if mask is not None:\n        losses = losses * mask\n        losses_occ = losses_occ * occ_mask\n        occ_loss_scale = 1.0 if not use_aux else 1.0\n        loss = (torch.sum(losses) + torch.sum(losses_occ)) / (mask.sum() + occ_mask.sum()) / losses.size(0)\n      else:\n        raise NotImplementedError\n        \n      if use_aux:\n        pred_vis = model.module.extract_vis_from_cubicgrid(x0, x0_occ)\n        with torch.no_grad():\n          gt_vis = model.module.extract_vis_from_cubicgrid(batch, batch_occ.view(*x0_occ.size()))\n        reg_loss = (\n          (pred_vis - gt_vis).pow(2).view(x0.size(0), -1).mean(dim=-1) * sqrt_alphas_cumprod[labels]\n        ).mean()\n      else:\n        reg_loss = torch.zeros_like(loss)\n      total_loss = loss + reg_loss\n\n      return total_loss, loss, reg_loss\n  else:\n    def loss_fn(model, batch, 
use_mesh_reg=False, verts_discretiezd=None, midpoints_discretiezd=None, edges=None):\n      model_fn = mutils.get_model_fn(model, train=train)\n      labels = torch.randint(0, vpsde.N, (batch.shape[0],), device=batch.device)\n      sqrt_alphas_cumprod = vpsde.sqrt_alphas_cumprod.to(batch.device)\n      sqrt_1m_alphas_cumprod = vpsde.sqrt_1m_alphas_cumprod.to(batch.device)\n      noise = torch.randn_like(batch, device=batch.device)\n      perturbed_data = sqrt_alphas_cumprod[labels, None, None, None, None] * batch + \\\n                      sqrt_1m_alphas_cumprod[labels, None, None, None, None] * noise\n      perturbed_data = perturbed_data.type(batch.dtype)\n\n      with torch.cuda.amp.autocast(dtype=torch.bfloat16):\n        pred = model_fn(perturbed_data, labels)\n      pred = pred.float()\n      alphas1 = sqrt_alphas_cumprod[labels, None, None, None, None]\n      alphas2 = sqrt_1m_alphas_cumprod[labels, None, None, None, None]\n      if pred_type == 'noise':\n        score = pred\n        x0 = (perturbed_data - score * alphas2) / alphas1\n      elif pred_type == 'x0':\n        x0 = pred\n        score = (perturbed_data - x0 * alphas1) / alphas2\n      \n      if use_vis_mask:\n        assert x0.size(0) == 1\n        vis_mask = model.extract_vismask_from_cubicgrid(x0)\n        # noise = noise[:, :, :score.size(2), :score.size(3), :score.size(4)] ### to accommodate change of size due to arch\n        if loss_type == 'l2':\n          losses = torch.square((score - noise) * vis_mask)\n        elif loss_type == 'l1':\n          losses = torch.abs((score - noise) * vis_mask)\n        else:\n          raise NotImplementedError\n      else:\n        if loss_type == 'l2':\n          losses = torch.square(score - noise)\n        elif loss_type == 'l1':\n          losses = torch.abs(score - noise)\n        else:\n          raise NotImplementedError\n\n      mask = model.module.feature_mask\n      assert len(mask.size()) == 5\n      assert mask.size(1) == losses.size(1)\n      if mask is not None:\n        losses = losses * mask\n        loss = torch.sum(losses) / mask.sum() / losses.size(0)\n      else:\n        raise NotImplementedError\n\n\n      reg_loss = torch.zeros_like(loss)\n      total_loss = loss\n\n      return total_loss, loss, reg_loss\n\n  return loss_fn\n\ndef get_step_fn(sde, train, optimize_fn=None, loss_type='l2', pred_type='noise', use_vis_mask=False, use_occ=False, use_aux=False):\n  \"\"\"Create a one-step training/evaluation function.\n\n  Args:\n    sde: An `sde_lib.SDE` object that represents the forward SDE.\n    optimize_fn: An optimization function.\n    reduce_mean: If `True`, average the loss across data dimensions. 
Otherwise sum the loss across data dimensions.\n    continuous: `True` indicates that the model is defined to take continuous time steps.\n    likelihood_weighting: If `True`, weight the mixture of score matching losses according to\n      https://arxiv.org/abs/2101.09258; otherwise use the weighting recommended by our paper.\n\n  Returns:\n    A one-step function for training or evaluation.\n  \"\"\"\n  \n  loss_fn = get_ddpm_loss_fn(sde, train, loss_type=loss_type, pred_type=pred_type, use_vis_mask=use_vis_mask, use_occ=use_occ, use_aux=use_aux)\n\n  def step_fn(state, batch, clear_grad=True, update_param=True, gradscaler=None):\n    \"\"\"Running one step of training or evaluation.\n\n    Supports gradient accumulation (via `clear_grad` and `update_param`) and mixed-precision\n    training through the provided `gradscaler`.\n\n    Args:\n      state: A dictionary of training information, containing the score model, optimizer,\n       EMA status, and number of optimization steps.\n      batch: A mini-batch of training/evaluation data.\n\n    Returns:\n      A dict with the total, score-matching, and regularization losses of this step.\n    \"\"\"\n    model = state['model']\n    if train:\n      optimizer = state['optimizer']\n      if clear_grad:\n        optimizer.zero_grad()\n      loss_total, loss_score, loss_reg = loss_fn(model, batch)\n      gradscaler.scale(loss_total).backward()\n      if update_param:\n        optimize_fn(optimizer, model.parameters(), step=state['step'], gradscaler=gradscaler)\n      state['step'] += 1\n      state['ema'].update(model.parameters())\n    else:\n      with torch.no_grad():\n        ema = state['ema']\n        ema.store(model.parameters())\n        ema.copy_to(model.parameters())\n        loss_total, loss_score, loss_reg = loss_fn(model, batch)\n        ema.restore(model.parameters())\n\n    return {\n      'loss_total': loss_total,\n      'loss_score': loss_score,\n      'loss_reg': loss_reg,\n    }\n\n  return step_fn"
  },
  {
    "path": "GMeshDiffusion/lib/diffusion/models/__init__.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n"
  },
  {
    "path": "GMeshDiffusion/lib/diffusion/models/ema.py",
    "content": "# Modified from https://raw.githubusercontent.com/fadel/pytorch_ema/master/torch_ema/ema.py\n\nfrom __future__ import division\nfrom __future__ import unicode_literals\n\nimport torch\n\n\n# Partially based on: https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py\nclass ExponentialMovingAverage:\n  \"\"\"\n  Maintains (exponential) moving average of a set of parameters.\n  \"\"\"\n\n  def __init__(self, parameters, decay, use_num_updates=True):\n    \"\"\"\n    Args:\n      parameters: Iterable of `torch.nn.Parameter`; usually the result of\n        `model.parameters()`.\n      decay: The exponential decay.\n      use_num_updates: Whether to use number of updates when computing\n        averages.\n    \"\"\"\n    if decay < 0.0 or decay > 1.0:\n      raise ValueError('Decay must be between 0 and 1')\n    self.decay = decay\n    self.num_updates = 0 if use_num_updates else None\n    self.shadow_params = [p.clone().detach()\n                          for p in parameters if p.requires_grad]\n    self.collected_params = []\n\n  def update(self, parameters):\n    \"\"\"\n    Update currently maintained parameters.\n\n    Call this every time the parameters are updated, such as the result of\n    the `optimizer.step()` call.\n\n    Args:\n      parameters: Iterable of `torch.nn.Parameter`; usually the same set of\n        parameters used to initialize this object.\n    \"\"\"\n    decay = self.decay\n    if self.num_updates is not None:\n      self.num_updates += 1\n      decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))\n    one_minus_decay = 1.0 - decay\n    with torch.no_grad():\n      parameters = [p for p in parameters if p.requires_grad]\n      for s_param, param in zip(self.shadow_params, parameters):\n        # print(s_param.device, s_param.device, param.device)\n        s_param.sub_(one_minus_decay * (s_param - param))\n\n  def copy_to(self, parameters):\n    \"\"\"\n    Copy current parameters into given collection of parameters.\n\n    Args:\n      parameters: Iterable of `torch.nn.Parameter`; the parameters to be\n        updated with the stored moving averages.\n    \"\"\"\n    parameters = [p for p in parameters if p.requires_grad]\n    for s_param, param in zip(self.shadow_params, parameters):\n      if param.requires_grad:\n        param.data.copy_(s_param.data)\n\n  def store(self, parameters):\n    \"\"\"\n    Save the current parameters for restoring later.\n\n    Args:\n      parameters: Iterable of `torch.nn.Parameter`; the parameters to be\n        temporarily stored.\n    \"\"\"\n    self.collected_params = [param.clone() for param in parameters]\n\n  def restore(self, parameters):\n    \"\"\"\n    Restore the parameters stored with the `store` method.\n    Useful to validate the model with EMA parameters without affecting the\n    original optimization process. Store the parameters before the\n    `copy_to` method. 
After validation (or model saving), use this to\n    restore the former parameters.\n\n    Args:\n      parameters: Iterable of `torch.nn.Parameter`; the parameters to be\n        updated with the stored parameters.\n    \"\"\"\n    for c_param, param in zip(self.collected_params, parameters):\n      param.data.copy_(c_param.data)\n\n  def state_dict(self):\n    return dict(decay=self.decay, num_updates=self.num_updates,\n                shadow_params=self.shadow_params)\n\n  def load_state_dict(self, state_dict, device='cuda'):\n    self.decay = state_dict['decay']\n    self.num_updates = state_dict['num_updates']\n    self.shadow_params = state_dict['shadow_params']\n    for k, _ in enumerate(self.shadow_params):\n      self.shadow_params[k] = self.shadow_params[k].to(device)\n"
  },
  {
    "path": "GMeshDiffusion/lib/diffusion/models/functional.py",
    "content": "#################################################################################################\n# Copyright (c) 2023 Ali Hassani.\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n#\n#################################################################################################\nimport torch\nfrom torch.autograd import Function\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\ntry:\n    from natten import _C\nexcept ImportError:\n    raise ImportError(\n        f\"Failed to import NATTEN's CPP backend. \"\n        + f\"This could be due to an invalid/incomplete install. \"\n        + f\"Please uninstall NATTEN (pip uninstall natten) and re-install with the\"\n        f\" correct torch build: \"\n        + f\"shi-labs.com/natten\"\n    )\n\n\ndef has_cuda():\n    return _C.has_cuda()\n\n\ndef has_half():\n    return _C.has_half()\n\n\ndef has_bfloat():\n    return _C.has_bfloat()\n\n\ndef has_gemm():\n    return _C.has_gemm()\n\n\ndef enable_tf32():\n    return _C.set_gemm_tf32(True)\n\n\ndef disable_tf32():\n    return _C.set_gemm_tf32(False)\n\n\ndef enable_tiled_na():\n    return _C.set_tiled_na(True)\n\n\ndef disable_tiled_na():\n    return _C.set_tiled_na(False)\n\n\ndef enable_gemm_na():\n    return _C.set_gemm_na(True)\n\n\ndef disable_gemm_na():\n    return _C.set_gemm_na(False)\n\n\nclass NeighborhoodAttention1DQKAutogradFunction(Function):\n    @staticmethod\n    @custom_fwd\n    def forward(ctx, query, key, rpb, kernel_size, dilation):\n        query = query.contiguous()\n        key = key.contiguous()\n        attn = _C.na1d_qk_forward(query, key, rpb, kernel_size, dilation)\n        ctx.save_for_backward(query, key)\n        ctx.kernel_size = kernel_size\n        ctx.dilation = dilation\n        ctx.bias = rpb is not None\n        return attn\n\n    @staticmethod\n    @custom_bwd\n    def backward(ctx, grad_out):\n        outputs = _C.na1d_qk_backward(\n            grad_out.contiguous(),\n            ctx.saved_tensors[0],\n            ctx.saved_tensors[1],\n            ctx.bias,\n            ctx.kernel_size,\n            ctx.dilation,\n        )\n        d_query, d_key, d_rpb = outputs\n        return d_query, d_key, d_rpb, None, None\n\n\nclass NeighborhoodAttention1DAVAutogradFunction(Function):\n    @staticmethod\n    @custom_fwd\n    def forward(ctx, attn, value, kernel_size, dilation):\n        attn = attn.contiguous()\n        value = value.contiguous()\n        out = _C.na1d_av_forward(attn, value, kernel_size, dilation)\n        
ctx.save_for_backward(attn, value)\n        ctx.kernel_size = kernel_size\n        ctx.dilation = dilation\n        return out\n\n    @staticmethod\n    @custom_bwd\n    def backward(ctx, grad_out):\n        outputs = _C.na1d_av_backward(\n            grad_out.contiguous(),\n            ctx.saved_tensors[0],\n            ctx.saved_tensors[1],\n            ctx.kernel_size,\n            ctx.dilation,\n        )\n        d_attn, d_value = outputs\n        return d_attn, d_value, None, None\n\n\nclass NeighborhoodAttention2DQKAutogradFunction(Function):\n    @staticmethod\n    @custom_fwd\n    def forward(ctx, query, key, rpb, kernel_size, dilation):\n        query = query.contiguous()\n        key = key.contiguous()\n        if rpb is not None:\n            rpb = rpb.to(key.dtype)\n        attn = _C.na2d_qk_forward(query, key, rpb, kernel_size, dilation)\n        ctx.save_for_backward(query, key)\n        ctx.kernel_size = kernel_size\n        ctx.dilation = dilation\n        ctx.bias = rpb is not None\n        return attn\n\n    @staticmethod\n    @custom_bwd\n    def backward(ctx, grad_out):\n        outputs = _C.na2d_qk_backward(\n            grad_out.contiguous(),\n            ctx.saved_tensors[0],\n            ctx.saved_tensors[1],\n            ctx.bias,\n            ctx.kernel_size,\n            ctx.dilation,\n        )\n        d_query, d_key, d_rpb = outputs\n        return d_query, d_key, d_rpb, None, None\n\n\nclass NeighborhoodAttention2DAVAutogradFunction(Function):\n    @staticmethod\n    @custom_fwd\n    def forward(ctx, attn, value, kernel_size, dilation):\n        attn = attn.contiguous().to(value.dtype)\n        value = value.contiguous()\n        out = _C.na2d_av_forward(attn, value, kernel_size, dilation)\n        ctx.save_for_backward(attn, value)\n        ctx.kernel_size = kernel_size\n        ctx.dilation = dilation\n        return out\n\n    @staticmethod\n    @custom_bwd\n    def backward(ctx, grad_out):\n        outputs = _C.na2d_av_backward(\n            grad_out.contiguous(),\n            ctx.saved_tensors[0],\n            ctx.saved_tensors[1],\n            ctx.kernel_size,\n            ctx.dilation,\n        )\n        d_attn, d_value = outputs\n        return d_attn, d_value, None, None\n\n\nclass NeighborhoodAttention3DQKAutogradFunction(Function):\n    @staticmethod\n    @custom_fwd\n    def forward(ctx, query, key, rpb, kernel_size_d, kernel_size, dilation_d, dilation):\n        query = query.contiguous()\n        key = key.contiguous()\n        attn = _C.na3d_qk_forward(\n            query, key, rpb, kernel_size, dilation, kernel_size_d, dilation_d\n        )\n        ctx.save_for_backward(query, key)\n        ctx.kernel_size_d = kernel_size_d\n        ctx.kernel_size = kernel_size\n        ctx.dilation_d = dilation_d\n        ctx.dilation = dilation\n        ctx.bias = rpb is not None\n        return attn\n\n    @staticmethod\n    @custom_bwd\n    def backward(ctx, grad_out):\n        outputs = _C.na3d_qk_backward(\n            grad_out.contiguous(),\n            ctx.saved_tensors[0],\n            ctx.saved_tensors[1],\n            ctx.bias,\n            ctx.kernel_size,\n            ctx.dilation,\n            ctx.kernel_size_d,\n            ctx.dilation_d,\n        )\n        d_query, d_key, d_rpb = outputs\n        return d_query, d_key, d_rpb, None, None, None, None\n\n\nclass NeighborhoodAttention3DAVAutogradFunction(Function):\n    @staticmethod\n    @custom_fwd\n    def forward(ctx, attn, value, kernel_size_d, kernel_size, dilation_d, dilation):\n       
 attn = attn.contiguous()\n        value = value.contiguous()\n        out = _C.na3d_av_forward(\n            attn, value, kernel_size, dilation, kernel_size_d, dilation_d\n        )\n        ctx.save_for_backward(attn, value)\n        ctx.kernel_size_d = kernel_size_d\n        ctx.kernel_size = kernel_size\n        ctx.dilation_d = dilation_d\n        ctx.dilation = dilation\n        return out\n\n    @staticmethod\n    @custom_bwd\n    def backward(ctx, grad_out):\n        outputs = _C.na3d_av_backward(\n            grad_out.contiguous(),\n            ctx.saved_tensors[0],\n            ctx.saved_tensors[1],\n            ctx.kernel_size,\n            ctx.dilation,\n            ctx.kernel_size_d,\n            ctx.dilation_d,\n        )\n        d_attn, d_value = outputs\n        return d_attn, d_value, None, None, None, None\n\n\ndef natten1dqkrpb(query, key, rpb, kernel_size, dilation):\n    return NeighborhoodAttention1DQKAutogradFunction.apply(\n        query, key, rpb, kernel_size, dilation\n    )\n\n\ndef natten1dqk(query, key, kernel_size, dilation):\n    return NeighborhoodAttention1DQKAutogradFunction.apply(\n        query, key, None, kernel_size, dilation\n    )\n\n\ndef natten1dav(attn, value, kernel_size, dilation):\n    return NeighborhoodAttention1DAVAutogradFunction.apply(\n        attn, value, kernel_size, dilation\n    )\n\n\ndef natten2dqkrpb(query, key, rpb, kernel_size, dilation):\n    return NeighborhoodAttention2DQKAutogradFunction.apply(\n        query, key, rpb, kernel_size, dilation\n    )\n\n\ndef natten2dqk(query, key, kernel_size, dilation):\n    return NeighborhoodAttention2DQKAutogradFunction.apply(\n        query, key, None, kernel_size, dilation\n    )\n\n\ndef natten2dav(attn, value, kernel_size, dilation):\n    return NeighborhoodAttention2DAVAutogradFunction.apply(\n        attn, value, kernel_size, dilation\n    )\n\n\ndef natten3dqkrpb(query, key, rpb, kernel_size_d, kernel_size, dilation_d, dilation):\n    return NeighborhoodAttention3DQKAutogradFunction.apply(\n        query, key, rpb, kernel_size_d, kernel_size, dilation_d, dilation\n    )\n\n\ndef natten3dqk(query, key, kernel_size_d, kernel_size, dilation_d, dilation):\n    return NeighborhoodAttention3DQKAutogradFunction.apply(\n        query, key, None, kernel_size_d, kernel_size, dilation_d, dilation\n    )\n\n\ndef natten3dav(attn, value, kernel_size_d, kernel_size, dilation_d, dilation):\n    return NeighborhoodAttention3DAVAutogradFunction.apply(\n        attn, value, kernel_size_d, kernel_size, dilation_d, dilation\n    )"
  },
  {
    "path": "GMeshDiffusion/lib/diffusion/models/layers.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# pylint: skip-file\n\"\"\"Common layers for defining score networks.\n\"\"\"\nimport math\nimport string\nfrom functools import partial\nimport torch.nn as nn\nimport torch\nimport torch.nn.functional as F\nimport numpy as np\nfrom .normalization import ConditionalInstanceNorm3dPlus\n\nclass GroupNormFloat32(nn.GroupNorm):\n    def forward(self, input):\n        with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=False):\n            return F.group_norm(\n                input.float(), self.num_groups, self.weight, self.bias, self.eps)\n\n\ndef get_act_fn(act_name):\n    \"\"\"Get activation functions from the config file.\"\"\"\n\n    if act_name.lower() == 'elu':\n        return nn.ELU()\n    elif act_name.lower() == 'relu':\n        return nn.ReLU()\n    elif act_name.lower() == 'lrelu':\n        return nn.LeakyReLU(negative_slope=0.2)\n    elif act_name.lower() == 'swish' or act_name.lower() == 'silu':\n        return nn.SiLU()\n    else:\n        raise NotImplementedError('activation function does not exist!')\n\ndef variance_scaling(scale, mode, distribution,\n                     in_axis=1, out_axis=0,\n                     dtype=torch.float32,\n                     device='cpu'):\n    \"\"\"Ported from JAX. \"\"\"\n\n    def _compute_fans(shape, in_axis=1, out_axis=0):\n        receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]\n        fan_in = shape[in_axis] * receptive_field_size\n        fan_out = shape[out_axis] * receptive_field_size\n        return fan_in, fan_out\n\n    def init(shape, dtype=dtype, device=device):\n        fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)\n        if mode == \"fan_in\":\n            denominator = fan_in\n        elif mode == \"fan_out\":\n            denominator = fan_out\n        elif mode == \"fan_avg\":\n            denominator = (fan_in + fan_out) / 2\n        else:\n            raise ValueError(\n                \"invalid mode for variance scaling initializer: {}\".format(mode))\n        variance = scale / denominator\n        if distribution == \"normal\":\n            return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)\n        elif distribution == \"uniform\":\n            return (torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.) 
* np.sqrt(3 * variance)\n        else:\n            raise ValueError(\"invalid distribution for variance scaling initializer\")\n\n    return init\n\n\ndef default_init(scale=1.):\n    \"\"\"The same initialization used in DDPM.\"\"\"\n    scale = 1e-10 if scale == 0 else scale\n    return variance_scaling(scale, 'fan_avg', 'uniform')\n\n\nclass Dense(nn.Module):\n    \"\"\"Linear layer with `default_init`.\"\"\"\n    def __init__(self):\n        super().__init__()\n\n\ndef conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1., padding=0):\n    \"\"\"1x1 convolution with DDPM initialization.\"\"\"\n    conv = nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)\n    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)\n    nn.init.zeros_(conv.bias)\n    return conv\n\ndef conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1):\n    \"\"\"3x3 convolution with DDPM initialization.\"\"\"\n    conv = nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=padding,\n                    dilation=dilation, bias=bias)\n    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)\n    nn.init.zeros_(conv.bias)\n    return conv\n\ndef conv5x5(in_planes, out_planes, stride=2, bias=True, dilation=1, init_scale=1., padding=2):\n    \"\"\"5x5 convolution with DDPM initialization.\"\"\"\n    conv = nn.Conv3d(in_planes, out_planes, kernel_size=5, stride=stride, padding=padding,\n                    dilation=dilation, bias=bias)\n    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)\n    nn.init.zeros_(conv.bias)\n    return conv\n\n\ndef conv3x3_transposed(in_planes, out_planes, stride=2, bias=True, dilation=1, init_scale=1., padding=1):\n    \"\"\"3x3 transposed convolution with DDPM initialization.\"\"\"\n    conv = nn.ConvTranspose3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=padding,\n                    dilation=dilation, bias=bias)\n    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)\n    nn.init.zeros_(conv.bias)\n    return conv\n\ndef conv5x5_transposed(in_planes, out_planes, stride=2, bias=True, dilation=1, init_scale=1., padding=2):\n    \"\"\"5x5 transposed convolution with DDPM initialization.\"\"\"\n    conv = nn.ConvTranspose3d(in_planes, out_planes, kernel_size=5, stride=stride, padding=padding,\n                    dilation=dilation, bias=bias)\n    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)\n    nn.init.zeros_(conv.bias)\n    return conv\n\n\n###########################################################################\n# Functions below are ported over from the DDPM codebase:\n#  https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py\n###########################################################################\n\ndef get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):\n    with torch.no_grad():\n        assert len(timesteps.shape) == 1  # and timesteps.dtype == tf.int32\n        half_dim = embedding_dim // 2\n        # magic number 10000 is from transformers\n        emb = math.log(max_positions) / (half_dim - 1)\n        # emb = math.log(2.) / (half_dim - 1)\n        emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)\n        # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]\n        # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]\n        emb = timesteps[:, None] * emb[None, :]\n        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)\n        if embedding_dim % 2 == 1:  # zero pad\n            emb = F.pad(emb, (0, 1), mode='constant')\n        assert emb.shape == (timesteps.shape[0], embedding_dim)\n        return emb\n\nclass AttnBlock(nn.Module):\n    \"\"\"Channel-wise self-attention block.\"\"\"\n    def __init__(self, channels, num_groups=32):\n        super().__init__()\n        self.GroupNorm_0 = GroupNormFloat32(num_groups=num_groups, num_channels=channels, eps=1e-6)\n        self.NIN_0 = conv1x1(channels, channels)\n        self.NIN_1 = conv1x1(channels, channels)\n        self.NIN_2 = conv1x1(channels, channels)\n        self.NIN_3 = conv1x1(channels, channels, init_scale=0.)\n\n    def forward(self, x):\n        B, C, D, H, W = x.shape\n        h = self.GroupNorm_0(x)\n        q = self.NIN_0(h)\n        k = self.NIN_1(h)\n        v = self.NIN_2(h)\n\n        # q = q.view(B, C, -1).permute(0, 2, 1)\n        # k = k.view(B, C, -1).permute(0, 2, 1)\n        # v = v.view(B, C, -1).permute(0, 2, 1)\n        # with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=False):\n        #     h = F.scaled_dot_product_attention(q.float(), k.float(), v.float())\n        # h = h.permute(0, 2, 1).view(B, C, D, H, W)\n\n        with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=False):\n            w = torch.einsum('bcdhw,bckij->bdhwkij', q.float(), k.float()) * (int(C) ** (-0.5))\n            w = torch.reshape(w, (B, D, H, W, D * H * W))\n            w = F.softmax(w, dim=-1)\n            w = torch.reshape(w, (B, D, H, W, D, H, W))\n        h = torch.einsum('bdhwkij,bckij->bcdhw', w, v)\n\n        h = self.NIN_3(h)\n        return x + h\n\nclass Upsample(nn.Module):\n    def __init__(self, channels, with_conv=False):\n        super().__init__()\n        if with_conv:\n            self.Conv_0 = conv3x3(channels, channels)\n        self.with_conv = with_conv\n\n    def forward(self, x, temb=None):\n        B, C, D, H, W = x.shape\n        h = F.interpolate(x.float(), (D * 2, H * 2, W * 2), mode='nearest')\n        if self.with_conv:\n            h = self.Conv_0(h)\n        return h\n\n\nclass Downsample(nn.Module):\n    def __init__(self, channels, with_conv=False):\n        super().__init__()\n        if with_conv:\n            self.Conv_0 = conv3x3(channels, channels, stride=2, padding=0)\n        # Record the flag unconditionally so forward() can branch on it even when no conv is used.\n        self.with_conv = with_conv\n\n    def forward(self, x, temb=None):\n        B, C, D, H, W = x.shape\n        # Emulate 'SAME' padding\n        if self.with_conv:\n            x = F.pad(x, (0, 1, 0, 1, 0, 1))\n            x = self.Conv_0(x)\n        else:\n            x = F.avg_pool3d(x, kernel_size=2, stride=2, padding=0)\n\n        assert x.shape == (B, C, D // 2, H // 2, W // 2)\n        return x\n\n\nclass ResBlock(nn.Module):\n    \"\"\"The ResNet Blocks used in DDPM.\"\"\"\n    def __init__(self, act_fn, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False, dropout=0.1, num_groups=32):\n        super().__init__()\n        if out_ch is None:\n            out_ch = in_ch\n        self.GroupNorm_0 = GroupNormFloat32(num_groups=num_groups, num_channels=in_ch, eps=1e-6)\n        self.act = act_fn\n        self.Conv_0 = conv3x3(in_ch, out_ch)\n        if temb_dim is not None:\n            self.Dense_0 = nn.Linear(temb_dim, out_ch)\n            self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)\n            nn.init.zeros_(self.Dense_0.bias)\n\n        self.GroupNorm_1 = GroupNormFloat32(num_groups=num_groups, num_channels=out_ch, eps=1e-6)\n        self.Dropout_0 = nn.Dropout(dropout)\n        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=0.)\n        if in_ch != out_ch:\n            if conv_shortcut:\n                self.Conv_2 = conv3x3(in_ch, out_ch)\n            else:\n                self.NIN_0 = conv1x1(in_ch, out_ch)\n        self.out_ch = out_ch\n        self.in_ch = in_ch\n        self.conv_shortcut = conv_shortcut\n\n    def forward(self, x, temb=None):\n        B, C, D, H, W = x.shape\n        assert C == self.in_ch\n        out_ch = self.out_ch if self.out_ch else self.in_ch\n        h = self.act(self.GroupNorm_0(x))\n        h = self.Conv_0(h)\n        # Add bias to each feature map conditioned on the time embedding\n        if temb is not None:\n            h += self.Dense_0(self.act(temb))[:, :, None, None, None]\n        h = self.act(self.GroupNorm_1(h))\n        h = self.Dropout_0(h)\n        h = self.Conv_1(h)\n        if C != out_ch:\n            if self.conv_shortcut:\n                x = self.Conv_2(x)\n            else:\n                x = self.NIN_0(x)\n        return x + h\n\nclass AttnResBlock(ResBlock):\n    \"\"\"A DDPM ResNet block followed by channel-wise self-attention.\"\"\"\n    def __init__(self, act_fn, in_ch, out_ch, temb_dim=None, conv_shortcut=False, dropout=0.1, num_groups=32):\n        super().__init__(act_fn, in_ch, out_ch, temb_dim, conv_shortcut, dropout, num_groups=num_groups)\n        self.attn_block = AttnBlock(out_ch, num_groups=num_groups)\n\n    def forward(self, x, temb=None):\n        h = super().forward(x, temb)\n        return self.attn_block(h)\n"
  },
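  {
    "path": "GMeshDiffusion/examples/layers_usage_sketch.py",
    "content": "# NOTE: hypothetical usage sketch, not part of the original repository; the file name and\n# import path are assumptions (it expects GMeshDiffusion/lib to be on sys.path so that\n# `diffusion.models.layers` resolves). It exercises the timestep embedding and the\n# DDPM-style conv initializers from layers.py on CPU with small tensors.\nimport torch\n\nfrom diffusion.models.layers import get_timestep_embedding, conv3x3, default_init\n\n# Sinusoidal timestep embedding: one row per timestep, `embedding_dim` columns.\ntimesteps = torch.arange(4, dtype=torch.float32)\nemb = get_timestep_embedding(timesteps, embedding_dim=8)\nassert emb.shape == (4, 8)\n\n# conv3x3 returns an nn.Conv3d whose weights use the 'fan_avg'/uniform variance-scaling\n# initializer from `default_init` and whose bias starts at zero.\nconv = conv3x3(in_planes=2, out_planes=4)\nx = torch.randn(1, 2, 8, 8, 8)\nassert conv(x).shape == (1, 4, 8, 8, 8)\n\n# The initializer itself can be called on an arbitrary weight shape.\nw = default_init(1.)((4, 2, 3, 3, 3))\nassert w.shape == (4, 2, 3, 3, 3)\n"
  },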
  {
    "path": "GMeshDiffusion/lib/diffusion/models/normalization.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Normalization layers.\"\"\"\nimport torch.nn as nn\nimport torch\nimport functools\n\n\ndef get_normalization(config, conditional=False):\n  \"\"\"Obtain normalization modules from the config file.\"\"\"\n  norm = config.model.normalization\n  if conditional:\n    if norm == 'InstanceNorm++':\n      return functools.partial(ConditionalInstanceNorm3dPlus, num_classes=config.model.num_classes)\n    else:\n      raise NotImplementedError(f'{norm} not implemented yet.')\n  else:\n    if norm == 'InstanceNorm':\n      return nn.InstanceNorm3d\n    elif norm == 'InstanceNorm++':\n      return InstanceNorm3dPlus\n    elif norm == 'VarianceNorm':\n      return VarianceNorm3d\n    elif norm == 'GroupNorm':\n      return nn.GroupNorm\n    else:\n      raise ValueError('Unknown normalization: %s' % norm)\n\n\nclass ConditionalBatchNorm3d(nn.Module):\n  def __init__(self, num_features, num_classes, bias=True):\n    super().__init__()\n    self.num_features = num_features\n    self.bias = bias\n    self.bn = nn.BatchNorm3d(num_features, affine=False)\n    if self.bias:\n      self.embed = nn.Embedding(num_classes, num_features * 2)\n      self.embed.weight.data[:, :num_features].uniform_()  # Initialise scale at N(1, 0.02)\n      self.embed.weight.data[:, num_features:].zero_()  # Initialise bias at 0\n    else:\n      self.embed = nn.Embedding(num_classes, num_features)\n      self.embed.weight.data.uniform_()\n\n  def forward(self, x, y):\n    out = self.bn(x)\n    if self.bias:\n      gamma, beta = self.embed(y).chunk(2, dim=1)\n      out = gamma.view(-1, self.num_features, 1, 1, 1) * out + beta.view(-1, self.num_features, 1, 1, 1)\n    else:\n      gamma = self.embed(y)\n      out = gamma.view(-1, self.num_features, 1, 1, 1) * out\n    return out\n\n\nclass ConditionalInstanceNorm3d(nn.Module):\n  def __init__(self, num_features, num_classes, bias=True):\n    super().__init__()\n    self.num_features = num_features\n    self.bias = bias\n    self.instance_norm = nn.InstanceNorm3d(num_features, affine=False, track_running_stats=False)\n    if bias:\n      self.embed = nn.Embedding(num_classes, num_features * 2)\n      self.embed.weight.data[:, :num_features].uniform_()  # Initialise scale at N(1, 0.02)\n      self.embed.weight.data[:, num_features:].zero_()  # Initialise bias at 0\n    else:\n      self.embed = nn.Embedding(num_classes, num_features)\n      self.embed.weight.data.uniform_()\n\n  def forward(self, x, y):\n    h = self.instance_norm(x)\n    if self.bias:\n      gamma, beta = self.embed(y).chunk(2, dim=-1)\n      out = gamma.view(-1, self.num_features, 1, 1, 1) * h + beta.view(-1, self.num_features, 1, 1, 1)\n    else:\n      gamma = self.embed(y)\n      out = gamma.view(-1, self.num_features, 1, 1, 1) * h\n    return out\n\n\nclass ConditionalVarianceNorm3d(nn.Module):\n  def __init__(self, num_features, num_classes, bias=False):\n    
super().__init__()\n    self.num_features = num_features\n    self.bias = bias\n    self.embed = nn.Embedding(num_classes, num_features)\n    self.embed.weight.data.normal_(1, 0.02)\n\n  def forward(self, x, y):\n    vars = torch.var(x, dim=(2, 3, 4), keepdim=True)\n    h = x / torch.sqrt(vars + 1e-5)\n\n    gamma = self.embed(y)\n    out = gamma.view(-1, self.num_features, 1, 1, 1) * h\n    return out\n\n\nclass VarianceNorm3d(nn.Module):\n  def __init__(self, num_features, bias=False):\n    super().__init__()\n    self.num_features = num_features\n    self.bias = bias\n    self.alpha = nn.Parameter(torch.zeros(num_features))\n    self.alpha.data.normal_(1, 0.02)\n\n  def forward(self, x):\n    vars = torch.var(x, dim=(2, 3, 4), keepdim=True)\n    h = x / torch.sqrt(vars + 1e-5)\n\n    out = self.alpha.view(-1, self.num_features, 1, 1, 1) * h\n    return out\n\n\nclass ConditionalNoneNorm3d(nn.Module):\n  def __init__(self, num_features, num_classes, bias=True):\n    super().__init__()\n    self.num_features = num_features\n    self.bias = bias\n    if bias:\n      self.embed = nn.Embedding(num_classes, num_features * 2)\n      self.embed.weight.data[:, :num_features].uniform_()  # Initialise scale at N(1, 0.02)\n      self.embed.weight.data[:, num_features:].zero_()  # Initialise bias at 0\n    else:\n      self.embed = nn.Embedding(num_classes, num_features)\n      self.embed.weight.data.uniform_()\n\n  def forward(self, x, y):\n    if self.bias:\n      gamma, beta = self.embed(y).chunk(2, dim=-1)\n      out = gamma.view(-1, self.num_features, 1, 1, 1) * x + beta.view(-1, self.num_features, 1, 1, 1)\n    else:\n      gamma = self.embed(y)\n      out = gamma.view(-1, self.num_features, 1, 1, 1) * x\n    return out\n\n\nclass NoneNorm3d(nn.Module):\n  def __init__(self, num_features, bias=True):\n    super().__init__()\n\n  def forward(self, x):\n    return x\n\n\nclass InstanceNorm3dPlus(nn.Module):\n  def __init__(self, num_features, bias=True):\n    super().__init__()\n    self.num_features = num_features\n    self.bias = bias\n    self.instance_norm = nn.InstanceNorm3d(num_features, affine=False, track_running_stats=False)\n    self.alpha = nn.Parameter(torch.zeros(num_features))\n    self.gamma = nn.Parameter(torch.zeros(num_features))\n    self.alpha.data.normal_(1, 0.02)\n    self.gamma.data.normal_(1, 0.02)\n    if bias:\n      self.beta = nn.Parameter(torch.zeros(num_features))\n\n  def forward(self, x):\n    means = torch.mean(x, dim=(2, 3, 4))\n    m = torch.mean(means, dim=-1, keepdim=True)\n    v = torch.var(means, dim=-1, keepdim=True)\n    means = (means - m) / (torch.sqrt(v + 1e-5))\n    h = self.instance_norm(x)\n\n    if self.bias:\n      h = h + means[..., None, None, None] * self.alpha[..., None, None, None]\n      out = self.gamma.view(-1, self.num_features, 1, 1, 1) * h + self.beta.view(-1, self.num_features, 1, 1, 1)\n    else:\n      h = h + means[..., None, None, None] * self.alpha[..., None, None, None]\n      out = self.gamma.view(-1, self.num_features, 1, 1, 1) * h\n    return out\n\n\nclass ConditionalInstanceNorm3dPlus(nn.Module):\n  def __init__(self, num_features, num_classes, bias=True):\n    super().__init__()\n    self.num_features = num_features\n    self.bias = bias\n    self.instance_norm = nn.InstanceNorm3d(num_features, affine=False, track_running_stats=False)\n    if bias:\n      self.embed = nn.Embedding(num_classes, num_features * 3)\n      self.embed.weight.data[:, :2 * num_features].normal_(1, 0.02)  # Initialise scale at N(1, 0.02)\n      
self.embed.weight.data[:, 2 * num_features:].zero_()  # Initialise bias at 0\n    else:\n      self.embed = nn.Embedding(num_classes, 2 * num_features)\n      self.embed.weight.data.normal_(1, 0.02)\n\n  def forward(self, x, y):\n    means = torch.mean(x, dim=(2, 3, 4))\n    m = torch.mean(means, dim=-1, keepdim=True)\n    v = torch.var(means, dim=-1, keepdim=True)\n    means = (means - m) / (torch.sqrt(v + 1e-5))\n    h = self.instance_norm(x)\n\n    if self.bias:\n      gamma, alpha, beta = self.embed(y).chunk(3, dim=-1)\n      h = h + means[..., None, None, None] * alpha[..., None, None, None]\n      out = gamma.view(-1, self.num_features, 1, 1, 1) * h + beta.view(-1, self.num_features, 1, 1, 1)\n    else:\n      gamma, alpha = self.embed(y).chunk(2, dim=-1)\n      h = h + means[..., None, None, None] * alpha[..., None, None, None]\n      out = gamma.view(-1, self.num_features, 1, 1, 1) * h\n    return out\n"
  },
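  {
    "path": "GMeshDiffusion/examples/normalization_usage_sketch.py",
    "content": "# NOTE: hypothetical usage sketch, not part of the original repository; the file name and\n# import path are assumptions (GMeshDiffusion/lib on sys.path). It shows the call\n# signatures of the 3D normalization layers: unconditional variants take only `x`, while\n# conditional variants also take integer class labels `y` used to look up per-class\n# scale/shift embeddings.\nimport torch\n\nfrom diffusion.models.normalization import (\n    InstanceNorm3dPlus,\n    VarianceNorm3d,\n    ConditionalInstanceNorm3dPlus,\n)\n\nx = torch.randn(2, 8, 4, 4, 4)   # (batch, channels, D, H, W)\ny = torch.tensor([0, 3])         # integer class labels in [0, num_classes)\n\nassert InstanceNorm3dPlus(num_features=8)(x).shape == x.shape\nassert VarianceNorm3d(num_features=8)(x).shape == x.shape\n\ncond = ConditionalInstanceNorm3dPlus(num_features=8, num_classes=5)\nassert cond(x, y).shape == x.shape\n"
  },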
  {
    "path": "GMeshDiffusion/lib/diffusion/models/unet3d_occgrid.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# pylint: skip-file\n\"\"\"DDPM model.\n\nThis code is the pytorch equivalent of:\nhttps://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/models/unet.py\n\"\"\"\nimport torch\nimport torch.nn as nn\nimport functools\nimport numpy as np\n\nfrom . import utils\nfrom .layers import ResBlock, AttnResBlock, Upsample, Downsample, conv1x1, conv3x3, conv5x5, get_act_fn, default_init, get_timestep_embedding, GroupNormFloat32\n\nimport sys\n\n\ndef str_to_class(classname):\n    return getattr(sys.modules[__name__], classname)\n\n\n@utils.register_model(name='unet3d_occgrid')\nclass UNet3D(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        self.act_fn = get_act_fn(config.model.act_fn)\n        self.nf = nf = config.model.base_channels\n        data_ch = config.data.num_channels\n        ch_mult = config.model.ch_mult\n        feature_mask = torch.load(config.model.feature_mask_path, map_location='cpu').view(1, data_ch, 128, 128, 128)\n        pixcat_mask = torch.load(config.model.pixcat_mask_path, map_location='cpu').view(1, 1, 128, 128, 128)\n\n        occ_mask_path = config.data.occ_mask_path\n        occ_mask = torch.load(occ_mask_path, map_location='cpu').view(1, 1, 256, 256, 256)\n\n\n        tet_info = torch.load(config.data.tet_info_path)\n        self.tet_edge_vpos = tet_info['tet_edge_vpos'].cuda()\n        self.tet_edge_pix_loc = tet_info['tet_edge_pix_loc'].cuda().view(-1, 2, 3)\n        self.tet_edge_pix_loc = self.tet_edge_pix_loc.view(-1, 2, 3)\n        # self.tet_center_loc = tet_info['tet_center_loc'].cuda()\n        self.vis_edges = tet_info['vis_edges'].cuda()\n        self.occ_edge_cano_order = tet_info['occ_edge_cano_order'].cuda()\n        self.tet_edgenode_loc = self.tet_edge_pix_loc.float().mean(dim=1).long()\n        self.occ_edge_loc = self.tet_edgenode_loc.view(-1, 6, 3)[:, self.vis_edges.view(-1)].view(-1, 2, 3)\n        self.occ_node_loc = (self.occ_edge_loc.view(-1, 12, 2, 3).float().mean(dim=-2) * 2.0).long().view(-1, 3)\n        print(self.tet_edgenode_loc.size(), self.vis_edges.size(), self.occ_edge_loc.size(), self.occ_node_loc.size())\n        self.tet_edge_pix_loc = self.tet_edge_pix_loc.view(-1, 3)\n        \n        \n        self.feature_mask = torch.nn.Parameter(feature_mask, requires_grad=False)\n        self.pixcat_mask = torch.nn.Parameter(pixcat_mask, requires_grad=False)\n        self.occ_mask = torch.nn.Parameter(occ_mask, requires_grad=False)\n        self.down_block_types = config.model.down_block_types\n        self.up_block_types = config.model.up_block_types\n        self.num_res_blocks = config.model.num_res_blocks\n        self.num_res_blocks_1st_layer = config.model.num_res_blocks_1st_layer\n        resamp_with_conv = config.model.resamp_with_conv\n        dropout = config.model.dropout\n        assert len(self.down_block_types) == len(self.up_block_types)\n\n\n        
module_dict = {\n            module: functools.partial(str_to_class(module), act_fn=self.act_fn, temb_dim=4 * nf, dropout=dropout)\n            for module in [\"ResBlock\", \"AttnResBlock\"]\n        }\n\n\n        # Condition on noise levels.\n        noise_temb_layers = [nn.Linear(nf, nf * 4), nn.SiLU(), nn.Linear(nf * 4, nf * 4)]\n        noise_temb_layers[0].weight.data = default_init()(noise_temb_layers[0].weight.data.shape)\n        nn.init.zeros_(noise_temb_layers[0].bias)\n        noise_temb_layers[2].weight.data = default_init()(noise_temb_layers[2].weight.data.shape)\n        nn.init.zeros_(noise_temb_layers[2].bias)\n        self.noise_temb_layers = nn.Sequential(*noise_temb_layers)\n\n        self.occ_conv = conv3x3(1, nf, stride=2, padding=1)\n        self.occ_mask_conv = conv3x3(1, nf, stride=2, padding=1)\n\n        # Downsampling block\n        self.mask_layer = conv5x5(1, nf, stride=1, padding=2)\n        self.input_layer = conv5x5(data_ch, nf, stride=1, padding=2)\n        hs_c = [nf]\n        in_ch = nf\n        \n        modules = []\n        for i_level, down_block_type in enumerate(self.down_block_types):\n            curr_block = module_dict[down_block_type]\n            # Residual blocks for this resolution\n            num_res_blocks = self.num_res_blocks_1st_layer if i_level == 0 else self.num_res_blocks\n            for i_block in range(num_res_blocks):\n                out_ch = nf * ch_mult[i_level]\n                modules.append(curr_block(in_ch=in_ch, out_ch=out_ch))\n                in_ch = out_ch\n                hs_c.append(in_ch)\n        \n            if i_level != len(self.down_block_types) - 1:\n                modules.append(Downsample(channels=in_ch, with_conv=resamp_with_conv))\n                hs_c.append(in_ch)\n\n        in_ch = hs_c[-1]\n        modules.append(module_dict[\"AttnResBlock\"](in_ch=in_ch, out_ch=in_ch))\n        modules.append(module_dict[\"ResBlock\"](in_ch=in_ch))\n\n        # Upsampling block\n        for i_level, up_block_type in enumerate(self.up_block_types):\n            curr_block = module_dict[up_block_type]\n            num_res_blocks = self.num_res_blocks_1st_layer if i_level == len(self.up_block_types) - 1 else self.num_res_blocks\n            for i_block in range(num_res_blocks + 1):\n                out_ch = nf * ch_mult[len(self.up_block_types) - i_level - 1]\n                modules.append(curr_block(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))\n                in_ch = out_ch\n            if i_level != len(self.up_block_types) - 1:\n                modules.append(Upsample(channels=in_ch, with_conv=resamp_with_conv))\n\n        self.all_modules = nn.ModuleList(modules)\n\n        self.output_norm_layer = nn.Sequential(\n            GroupNormFloat32(num_channels=in_ch, num_groups=32, eps=1e-6),\n            nn.SiLU(),\n        )\n        self.output_layer = conv5x5(in_ch, data_ch, init_scale=0., stride=1, padding=2)\n\n\n        self.occ_output_layer = nn.ConvTranspose3d(in_ch, 1, 4, stride=2, padding=1)\n\n    def sequentially_call_module(self, idx, x, temb=None):\n        return idx + 1, self.all_modules[idx](x, temb)\n\n    def forward(self, x, labels):\n        modules = self.all_modules\n\n        with torch.no_grad():\n            x, occ_grid = x[0], x[1]\n            if True or self.centered:\n                # Input is in [-1, 1]\n                x = x\n            else:\n                # Input is in [0, 1]\n                x = 2 * x - 1.\n\n            # Mask out unused values\n            x = x * 
self.feature_mask\n\n            occ_grid = occ_grid * self.occ_mask\n        with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=False):\n            # timestep/scale embedding\n            timesteps = labels\n            temb = get_timestep_embedding(timesteps.float(), self.nf)\n            temb = self.noise_temb_layers(temb)\n\n        # Downsampling block\n        hs = [self.input_layer(x) + self.mask_layer(self.pixcat_mask) + self.occ_conv(occ_grid) + self.occ_mask_conv(self.occ_mask)]\n\n        m_idx = 0\n        for i_level in range(len(self.down_block_types)):\n            num_res_blocks = self.num_res_blocks_1st_layer if i_level == 0 else self.num_res_blocks\n            for i_block in range(num_res_blocks):\n                m_idx, h = self.sequentially_call_module(m_idx, hs[-1], temb)\n                hs.append(h)\n            if i_level != len(self.down_block_types) - 1:\n                m_idx, h = self.sequentially_call_module(m_idx, hs[-1])\n                hs.append(h)\n\n        h = hs[-1]\n        m_idx, h = self.sequentially_call_module(m_idx, h, temb)\n        m_idx, h = self.sequentially_call_module(m_idx, h, temb)\n\n        # Upsampling block\n        for i_level in range(len(self.up_block_types)):\n            num_res_blocks = self.num_res_blocks_1st_layer if i_level == len(self.up_block_types) - 1 else self.num_res_blocks\n            for i_block in range(num_res_blocks + 1):\n                hspop = hs.pop()\n                h = torch.cat([h, hspop], dim=1)\n                m_idx, h = self.sequentially_call_module(m_idx, h, temb)\n            if i_level != len(self.up_block_types) - 1:\n                m_idx, h = self.sequentially_call_module(m_idx, h, temb)\n\n        assert not hs\n        h = self.output_norm_layer(h)\n        grid = self.output_layer(h)\n        grid_occ = self.occ_output_layer(h)\n\n        # Mask out unused values\n        grid = grid * self.feature_mask\n        grid_occ = grid_occ * self.occ_mask\n\n        return grid, grid_occ\n"
  },
  {
    "path": "GMeshDiffusion/lib/diffusion/models/utils.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"All functions and modules related to model definition.\n\"\"\"\n\nimport torch\nfrom .. import sde_lib\nimport numpy as np\n\n\n_MODELS = {}\n\n\ndef register_model(cls=None, *, name=None):\n  \"\"\"A decorator for registering model classes.\"\"\"\n\n  def _register(cls):\n    if name is None:\n      local_name = cls.__name__\n    else:\n      local_name = name\n    if local_name in _MODELS:\n      raise ValueError(f'Already registered model with name: {local_name}')\n    _MODELS[local_name] = cls\n    return cls\n\n  if cls is None:\n    return _register\n  else:\n    return _register(cls)\n\n\ndef get_model(name):\n  return _MODELS[name]\n\n\ndef get_sigmas(config):\n  \"\"\"Get sigmas --- the set of noise levels for SMLD from config files.\n  Args:\n    config: A ConfigDict object parsed from the config file\n  Returns:\n    sigmas: a jax numpy arrary of noise levels\n  \"\"\"\n  sigmas = np.exp(\n    np.linspace(np.log(config.model.sigma_max), np.log(config.model.sigma_min), config.model.num_scales))\n\n  return sigmas\n\n\ndef get_ddpm_params(config):\n  \"\"\"Get betas and alphas --- parameters used in the original DDPM paper.\"\"\"\n  num_diffusion_timesteps = 1000\n  # parameters need to be adapted if number of time steps differs from 1000\n  beta_start = config.model.beta_min / config.model.num_scales\n  beta_end = config.model.beta_max / config.model.num_scales\n  betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)\n\n  alphas = 1. - betas\n  alphas_cumprod = np.cumprod(alphas, axis=0)\n  sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)\n  sqrt_1m_alphas_cumprod = np.sqrt(1. 
- alphas_cumprod)\n\n  return {\n    'betas': betas,\n    'alphas': alphas,\n    'alphas_cumprod': alphas_cumprod,\n    'sqrt_alphas_cumprod': sqrt_alphas_cumprod,\n    'sqrt_1m_alphas_cumprod': sqrt_1m_alphas_cumprod,\n    'beta_min': beta_start * (num_diffusion_timesteps - 1),\n    'beta_max': beta_end * (num_diffusion_timesteps - 1),\n    'num_diffusion_timesteps': num_diffusion_timesteps\n  }\n\n\ndef create_model(config, use_parallel=True, ddp=False, rank=None):\n  \"\"\"Create the score model.\"\"\"\n  model_name = config.model.name\n  score_model = get_model(model_name)(config)\n  if use_parallel:\n    if ddp:\n      score_model = score_model.to(rank)\n      score_model = torch.nn.parallel.DistributedDataParallel(\n        score_model, \n        find_unused_parameters=False,\n        # find_unused_parameters=True,\n        gradient_as_bucket_view=True,\n        # static_graph=True,\n        device_ids=[rank])\n      # score_model = torch.compile(score_model)\n      # score_model = torch.compile(score_model, fullgraph=True)\n    else:\n      score_model = torch.nn.DataParallel(score_model).to(config.device)\n  else:\n    score_model = score_model.to(config.device)\n  return score_model\n\n\ndef get_model_fn(model, train=False):\n  \"\"\"Create a function to give the output of the score-based model.\n\n  Args:\n    model: The score model.\n    train: `True` for training and `False` for evaluation.\n\n  Returns:\n    A model function.\n  \"\"\"\n\n  def model_fn(x, labels):\n    \"\"\"Compute the output of the score-based model.\n\n    Args:\n      x: A mini-batch of input data.\n      labels: A mini-batch of conditioning variables for time steps. Should be interpreted differently\n        for different models.\n\n    Returns:\n      A tuple of (model output, new mutable states)\n    \"\"\"\n    if not train:\n      model.eval()\n      return model(x, labels)\n    else:\n      model.train()\n      return model(x, labels)\n\n  return model_fn\n\ndef get_reg_fn(model, train=False):\n  \"\"\"Create a function to give the output of the score-based model.\n\n  Args:\n    model: The score model.\n    train: `True` for training and `False` for evaluation.\n\n  Returns:\n    A model function.\n  \"\"\"\n\n  def model_fn(x):\n    \"\"\"Compute the output of the score-based model.\n\n    Args:\n      x: A mini-batch of input data.\n      labels: A mini-batch of conditioning variables for time steps. Should be interpreted differently\n        for different models.\n\n    Returns:\n      A tuple of (model output, new mutable states)\n    \"\"\"\n    if not train:\n      model.eval()\n      try:\n        return model.get_reg(x)\n      except:\n        return torch.zeros_like(x, device=x.device)\n    else:\n      model.train()\n      try:\n        return model.get_reg(x)\n      except:\n        return torch.zeros_like(x, device=x.device)\n\n  return model_fn\n\ndef get_score_fn(sde, model, train=False, continuous=False, std_scale=True, pred_type='noise'):\n  \"\"\"Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.\n\n  Args:\n    sde: An `sde_lib.SDE` object that represents the forward SDE.\n    model: A score model.\n    train: `True` for training and `False` for evaluation.\n    continuous: If `True`, the score-based model is expected to directly take continuous time steps.\n    std_scale: whether to scale the score function by the inverse of std. 
Used for DDIM sampling\n\n  Returns:\n    A score function.\n  \"\"\"\n  model_fn = get_model_fn(model, train=train)\n  reg_fn = get_reg_fn(model, train=train)\n\n  assert not continuous\n  if isinstance(sde, sde_lib.VPSDE):\n    if not std_scale:\n      def score_fn(x, t):\n        labels = t * (sde.N - 1)\n        pred = model_fn(x, labels)\n\n        if pred_type == 'x0':\n          labels = labels.long()\n          alphas1 = sde.sqrt_alphas_cumprod[labels, None, None, None, None].cuda()\n          alphas2 = sde.sqrt_1m_alphas_cumprod[labels, None, None, None, None].cuda()\n          score = (x - pred * alphas1) / alphas2\n        elif pred_type == 'noise':\n          score = pred\n        return score\n    else:\n      def score_fn(x, t):\n        # For VP-trained models, t=0 corresponds to the lowest noise level\n        labels = t * (sde.N - 1)\n        pred = model_fn(x, labels)\n\n\n        if pred_type == 'x0':\n          labels = labels.long()\n          alphas1 = sde.sqrt_alphas_cumprod[labels, None, None, None, None].cuda()\n          alphas2 = sde.sqrt_1m_alphas_cumprod[labels, None, None, None, None].cuda()\n          score = (x - pred * alphas1) / alphas2\n        elif pred_type == 'noise':\n          score = pred\n\n        std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[labels.long()]\n\n        score = -score / std[:, None, None, None, None]\n        return score\n\n  else:\n    raise NotImplementedError(f\"SDE class {sde.__class__.__name__} not yet supported.\")\n\n  return score_fn\n\n\ndef to_flattened_numpy(x):\n  \"\"\"Flatten a torch tensor `x` and convert it to numpy.\"\"\"\n  return x.detach().cpu().numpy().reshape((-1,))\n\n\ndef from_flattened_numpy(x, shape):\n  \"\"\"Form a torch tensor with the given `shape` from a flattened numpy array `x`.\"\"\"\n  return torch.from_numpy(x.reshape(shape))"
  },
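  {
    "path": "GMeshDiffusion/examples/ddpm_schedule_sketch.py",
    "content": "# NOTE: hypothetical sketch, not part of the original repository. It re-derives, in plain\n# numpy, the discrete DDPM schedule that `get_ddpm_params` builds from the config:\n# beta_min and beta_max are divided by num_scales and swept linearly over 1000 steps.\n# The illustrative numbers below are assumptions for this sketch only.\nimport numpy as np\n\nbeta_min, beta_max, num_scales = 0.1, 20.0, 1000\nnum_diffusion_timesteps = 1000\n\nbeta_start = beta_min / num_scales\nbeta_end = beta_max / num_scales\nbetas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)\n\nalphas = 1.0 - betas\nalphas_cumprod = np.cumprod(alphas)\nsqrt_alphas_cumprod = np.sqrt(alphas_cumprod)\nsqrt_1m_alphas_cumprod = np.sqrt(1.0 - alphas_cumprod)\n\n# These two arrays are what `get_score_fn` indexes with the integer timestep to convert\n# between a noise/x0 prediction and a score: x_t has mean sqrt_alphas_cumprod[t] * x_0\n# and standard deviation sqrt_1m_alphas_cumprod[t].\nprint(betas[0], betas[-1])\nprint(sqrt_alphas_cumprod[-1], sqrt_1m_alphas_cumprod[-1])\n"
  },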
  {
    "path": "GMeshDiffusion/lib/diffusion/sampling.py",
    "content": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# pylint: skip-file\n# pytype: skip-file\n\"\"\"Various sampling methods.\"\"\"\nimport functools\n\nimport torch\nimport numpy as np\nimport abc\n\nfrom .models.utils import from_flattened_numpy, to_flattened_numpy, get_score_fn\nfrom scipy import integrate\nfrom . import sde_lib\nfrom .models import utils as mutils\n\nimport logging\nimport tqdm\n\n_CORRECTORS = {}\n_PREDICTORS = {}\n\n\ndef register_predictor(cls=None, *, name=None):\n    \"\"\"A decorator for registering predictor classes.\"\"\"\n\n    def _register(cls):\n        if name is None:\n            local_name = cls.__name__\n        else:\n            local_name = name\n        if local_name in _PREDICTORS:\n            raise ValueError(f'Already registered model with name: {local_name}')\n        _PREDICTORS[local_name] = cls\n        return cls\n\n    if cls is None:\n        return _register\n    else:\n        return _register(cls)\n\n\ndef register_corrector(cls=None, *, name=None):\n    \"\"\"A decorator for registering corrector classes.\"\"\"\n\n    def _register(cls):\n        if name is None:\n            local_name = cls.__name__\n        else:\n            local_name = name\n        if local_name in _CORRECTORS:\n            raise ValueError(f'Already registered model with name: {local_name}')\n        _CORRECTORS[local_name] = cls\n        return cls\n\n    if cls is None:\n        return _register\n    else:\n        return _register(cls)\n\n\ndef get_predictor(name):\n    return _PREDICTORS[name]\n\n\ndef get_corrector(name):\n    return _CORRECTORS[name]\n\n\ndef get_sampling_fn(config, sde, shape, inverse_scaler, eps, grid_mask=None, return_traj=False, pred_type='noise'):\n    \"\"\"Create a sampling function.\n\n    Args:\n        config: A `ml_collections.ConfigDict` object that contains all configuration information.\n        sde: A `sde_lib.SDE` object that represents the forward SDE.\n        shape: A sequence of integers representing the expected shape of a single sample.\n        inverse_scaler: The inverse data normalizer function.\n        eps: A `float` number. The reverse-time SDE is only integrated to `eps` for numerical stability.\n\n    Returns:\n        A function that takes random states and a replicated training state and outputs samples with the\n            trailing dimensions matching `shape`.\n    \"\"\"\n\n    sampler_name = config.sampling.method\n    # Probability flow ODE sampling with black-box ODE solvers\n    # Predictor-Corrector sampling. 
Predictor-only and Corrector-only samplers are special cases.\n    if sampler_name.lower() == 'pc':\n        predictor = get_predictor(config.sampling.predictor.lower())\n        corrector = get_corrector(config.sampling.corrector.lower())\n        sampling_fn = get_pc_sampler(sde=sde,\n                                    shape=shape,\n                                    predictor=predictor,\n                                    corrector=corrector,\n                                    inverse_scaler=inverse_scaler,\n                                    snr=config.sampling.snr,\n                                    n_steps=config.sampling.n_steps_each,\n                                    probability_flow=config.sampling.probability_flow,\n                                    continuous=config.training.continuous,\n                                    denoise=config.sampling.noise_removal,\n                                    eps=eps,\n                                    device=config.device,\n                                    grid_mask=grid_mask,\n                                    return_traj=return_traj,\n                                    pred_type=pred_type,\n                                    use_occ=config.model.use_occ_grid)\n    elif sampler_name.lower() == 'ddim':\n        predictor = get_predictor('ddim')\n        sampling_fn = get_ddim_sampler(sde=sde,\n                                    shape=shape,\n                                    predictor=predictor,\n                                    inverse_scaler=inverse_scaler,\n                                    n_steps=config.sampling.n_steps_each,\n                                    denoise=config.sampling.noise_removal,\n                                    eps=eps,\n                                    device=config.device,\n                                    grid_mask=grid_mask,\n                                    pred_type=pred_type,\n                                    use_occ=config.model.use_occ_grid)\n    else:\n        raise ValueError(f\"Sampler name {sampler_name} unknown.\")\n\n    return sampling_fn\n\n\nclass Predictor(abc.ABC):\n    \"\"\"The abstract class for a predictor algorithm.\"\"\"\n\n    def __init__(self, sde, score_fn, probability_flow=False):\n        super().__init__()\n        self.sde = sde\n        # Compute the reverse SDE/ODE\n        self.rsde = sde.reverse(score_fn, probability_flow)\n        self.score_fn = score_fn\n\n    @abc.abstractmethod\n    def update_fn(self, x, t):\n        \"\"\"One update of the predictor.\n\n        Args:\n            x: A PyTorch tensor representing the current state\n            t: A Pytorch tensor representing the current time step.\n\n        Returns:\n            x: A PyTorch tensor of the next state.\n            x_mean: A PyTorch tensor. The next state without random noise. 
Useful for denoising.\n        \"\"\"\n        pass\n\n\nclass Corrector(abc.ABC):\n    \"\"\"The abstract class for a corrector algorithm.\"\"\"\n\n    def __init__(self, sde, score_fn, snr, n_steps):\n        super().__init__()\n        self.sde = sde\n        self.score_fn = score_fn\n        self.snr = snr\n        self.n_steps = n_steps\n\n    @abc.abstractmethod\n    def update_fn(self, x, t):\n        \"\"\"One update of the corrector.\n\n        Args:\n            x: A PyTorch tensor representing the current state\n            t: A PyTorch tensor representing the current time step.\n\n        Returns:\n            x: A PyTorch tensor of the next state.\n            x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.\n        \"\"\"\n        pass\n\n\n@register_predictor(name='euler_maruyama')\nclass EulerMaruyamaPredictor(Predictor):\n    def __init__(self, sde, score_fn, probability_flow=False):\n        super().__init__(sde, score_fn, probability_flow)\n\n    def update_fn(self, x, t):\n        dt = -1. / self.rsde.N\n        z = torch.randn_like(x)\n        drift, diffusion = self.rsde.sde(x, t)\n        x_mean = x + drift * dt\n        x = x_mean + diffusion[:, None, None, None, None] * np.sqrt(-dt) * z\n        return x, x_mean\n\n\n@register_predictor(name='reverse_diffusion')\nclass ReverseDiffusionPredictor(Predictor):\n    def __init__(self, sde, score_fn, probability_flow=False):\n        super().__init__(sde, score_fn, probability_flow)\n\n    def update_fn(self, x, t):\n        f, G = self.rsde.discretize(x, t)\n        z = torch.randn_like(x)\n        x_mean = x - f\n        x = x_mean + G[:, None, None, None, None] * z\n        return x, x_mean\n\n\n@register_predictor(name='ancestral_sampling')\nclass AncestralSamplingPredictor(Predictor):\n    \"\"\"The ancestral sampling predictor. Currently only supports VE/VP SDEs.\"\"\"\n\n    def __init__(self, sde, score_fn, probability_flow=False):\n        super().__init__(sde, score_fn, probability_flow)\n        if not isinstance(sde, sde_lib.VPSDE):\n            raise NotImplementedError(f\"SDE class {sde.__class__.__name__} not yet supported.\")\n        assert not probability_flow, \"Probability flow not supported by ancestral sampling\"\n\n    def vpsde_update_fn(self, x, t):\n        sde = self.sde\n        timestep = (t * (sde.N - 1) / sde.T).long()\n        beta = sde.discrete_betas.to(t.device)[timestep]\n        score = self.score_fn(x, t)\n        x_mean = (x + beta[:, None, None, None, None] * score) / torch.sqrt(1. 
- beta)[:, None, None, None, None]\n        noise = torch.randn_like(x)\n        x = x_mean + torch.sqrt(beta)[:, None, None, None, None] * noise\n        return x, x_mean\n\n    def update_fn(self, x, t):\n        if isinstance(self.sde, sde_lib.VPSDE):\n            return self.vpsde_update_fn(x, t)\n        else:\n            raise NotImplementedError\n\n\n@register_predictor(name='none')\nclass NonePredictor(Predictor):\n    \"\"\"An empty predictor that does nothing.\"\"\"\n\n    def __init__(self, sde, score_fn, probability_flow=False):\n        pass\n\n    def update_fn(self, x, t):\n        return x, x\n\n@register_predictor(name='ddim')\nclass DDIMPredictor(Predictor):\n    def __init__(self, sde, score_fn, probability_flow=False):\n        super().__init__(sde, score_fn, probability_flow)\n\n\n    def update_fn(self, x, t, tprev=None):\n        x, x0_pred = self.rsde.discretize_ddim(x, t, tprev=tprev)\n        return x, x0_pred\n\n@register_corrector(name='langevin')\nclass LangevinCorrector(Corrector):\n    def __init__(self, sde, score_fn, snr, n_steps):\n        super().__init__(sde, score_fn, snr, n_steps)\n        if not isinstance(sde, sde_lib.VPSDE):\n            raise NotImplementedError(f\"SDE class {sde.__class__.__name__} not yet supported.\")\n\n    def update_fn(self, x, t):\n        sde = self.sde\n        score_fn = self.score_fn\n        n_steps = self.n_steps\n        target_snr = self.snr\n        if isinstance(sde, sde_lib.VPSDE):\n            timestep = (t * (sde.N - 1) / sde.T).long()\n            alpha = sde.alphas.to(t.device)[timestep]\n        else:\n            alpha = torch.ones_like(t)\n\n        for i in range(n_steps):\n            grad = score_fn(x, t)\n            noise = torch.randn_like(x)\n            grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()\n            noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()\n            step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha\n            x_mean = x + step_size[:, None, None, None, None] * grad\n            x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None, None] * noise\n\n        return x, x_mean\n\n\n@register_corrector(name='ald')\nclass AnnealedLangevinDynamics(Corrector):\n    \"\"\"The original annealed Langevin dynamics predictor in NCSN/NCSNv2.\n\n    We include this corrector only for completeness. 
It was not directly used in our paper.\n    \"\"\"\n\n    def __init__(self, sde, score_fn, snr, n_steps):\n        super().__init__(sde, score_fn, snr, n_steps)\n        if not isinstance(sde, sde_lib.VPSDE):\n            raise NotImplementedError(f\"SDE class {sde.__class__.__name__} not yet supported.\")\n\n    def update_fn(self, x, t):\n        sde = self.sde\n        score_fn = self.score_fn\n        n_steps = self.n_steps\n        target_snr = self.snr\n        if isinstance(sde, sde_lib.VPSDE):\n            timestep = (t * (sde.N - 1) / sde.T).long()\n            alpha = sde.alphas.to(t.device)[timestep]\n        else:\n            alpha = torch.ones_like(t)\n\n        std = self.sde.marginal_prob(x, t)[1]\n\n        for i in range(n_steps):\n            grad = score_fn(x, t)\n            noise = torch.randn_like(x)\n            step_size = (target_snr * std) ** 2 * 2 * alpha\n            x_mean = x + step_size[:, None, None, None, None] * grad\n            x = x_mean + noise * torch.sqrt(step_size * 2)[:, None, None, None, None]\n\n        return x, x_mean\n\n\n@register_corrector(name='none')\nclass NoneCorrector(Corrector):\n    \"\"\"An empty corrector that does nothing.\"\"\"\n\n    def __init__(self, sde, score_fn, snr, n_steps):\n        pass\n\n    def update_fn(self, x, t):\n        return x, x\n\n\ndef shared_predictor_update_fn(x, t, sde, model, predictor, probability_flow, continuous, pred_type='noise'):\n    \"\"\"A wrapper that configures and returns the update function of predictors.\"\"\"\n    score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous, pred_type=pred_type)\n    if predictor is None:\n        # Corrector-only sampler\n        predictor_obj = NonePredictor(sde, score_fn, probability_flow)\n    else:\n        predictor_obj = predictor(sde, score_fn, probability_flow)\n    return predictor_obj.update_fn(x, t)\n\n\ndef shared_corrector_update_fn(x, t, sde, model, corrector, continuous, snr, n_steps, pred_type='noise'):\n    \"\"\"A wrapper tha configures and returns the update function of correctors.\"\"\"\n    score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous, pred_type=pred_type)\n    if corrector is None:\n        # Predictor-only sampler\n        corrector_obj = NoneCorrector(sde, score_fn, snr, n_steps)\n    else:\n        corrector_obj = corrector(sde, score_fn, snr, n_steps)\n    return corrector_obj.update_fn(x, t)\n\n\ndef get_pc_sampler(sde, shape, predictor, corrector, inverse_scaler, snr,\n                                     n_steps=1, probability_flow=False, continuous=False,\n                                     denoise=True, eps=1e-3, device='cuda', \n                                     grid_mask=None, return_traj=False, pred_type='noise', use_occ=False):\n    \"\"\"Create a Predictor-Corrector (PC) sampler.\n\n    Args:\n        sde: An `sde_lib.SDE` object representing the forward SDE.\n        shape: A sequence of integers. The expected shape of a single sample.\n        predictor: A subclass of `sampling.Predictor` representing the predictor algorithm.\n        corrector: A subclass of `sampling.Corrector` representing the corrector algorithm.\n        inverse_scaler: The inverse data normalizer.\n        snr: A `float` number. The signal-to-noise ratio for configuring correctors.\n        n_steps: An integer. 
The number of corrector steps per predictor update.\n        probability_flow: If `True`, solve the reverse-time probability flow ODE when running the predictor.\n        continuous: `True` indicates that the score model was continuously trained.\n        denoise: If `True`, add one-step denoising to the final samples.\n        eps: A `float` number. The reverse-time SDE and ODE are integrated to `epsilon` to avoid numerical issues.\n        device: PyTorch device.\n\n    Returns:\n        A sampling function that returns samples and the number of function evaluations during sampling.\n    \"\"\"\n    # Create predictor & corrector update functions\n    predictor_update_fn = functools.partial(shared_predictor_update_fn,\n                                            sde=sde,\n                                            predictor=predictor,\n                                            probability_flow=probability_flow,\n                                            continuous=continuous,\n                                            pred_type=pred_type)\n    corrector_update_fn = functools.partial(shared_corrector_update_fn,\n                                            sde=sde,\n                                            corrector=corrector,\n                                            continuous=continuous,\n                                            snr=snr,\n                                            n_steps=n_steps,\n                                            pred_type=pred_type)\n\n    def pc_sampler(model, \n            partial=None, partial_grid_mask=None, partial_channel=0, \n            freeze_iters=None):\n        \"\"\" The PC sampler funciton.\n\n        Args:\n            model: A score model.\n        Returns:\n            Samples, number of function evaluations.\n        \"\"\"\n        with torch.no_grad():\n\n            if freeze_iters is None:\n                freeze_iters = sde.N + 10 # just some randomly large number greater than sde.N\n            timesteps = torch.linspace(sde.T, eps, sde.N, device=device)\n\n            mask = model.feature_mask\n\n            if pred_type == 'noise':\n                def compute_xzero(sde, model, x, t, grid_mask_input):\n                    timestep_int = (t * (sde.N - 1) / sde.T).long()\n                    alphas1 = sde.sqrt_alphas_cumprod[timestep_int].cuda()\n                    alphas2 = sde.sqrt_1m_alphas_cumprod[timestep_int].cuda()\n                    alphas1_prev = sde.sqrt_alphas_cumprod[timestep_int - 1].cuda()\n                    alphas2_prev = sde.sqrt_1m_alphas_cumprod[timestep_int - 1].cuda()\n                    score_pred = model(x, t * torch.ones(shape[0], device=x.device))\n                    x0_pred_scaled = (x - alphas2 * score_pred)\n                    x0_pred = x0_pred_scaled / alphas1\n                    x0_pred = x0_pred.clamp(-1, 1)\n                    return x0_pred * grid_mask_input\n            elif pred_type == 'x0':\n                def compute_xzero(sde, model, x, t, grid_mask_input):\n                    timestep_int = (t * (sde.N - 1) / sde.T).long()\n                    alphas1 = sde.sqrt_alphas_cumprod[timestep_int].cuda()\n                    alphas2 = sde.sqrt_1m_alphas_cumprod[timestep_int].cuda()\n                    alphas1_prev = sde.sqrt_alphas_cumprod[timestep_int - 1].cuda()\n                    alphas2_prev = sde.sqrt_1m_alphas_cumprod[timestep_int - 1].cuda()\n                    x0_pred = model(x, t * torch.ones(shape[0], device=x.device))\n                    return x0_pred * 
grid_mask_input\n\n        \n            # Initial sample\n            x = sde.prior_sampling(shape).to(device)\n            assert len(x.size()) == 5\n\n            traj_buffer = []\n        \n            if partial is not None:\n                assert len(partial.size()) == 5\n                t = timesteps[0]\n                vec_t = torch.ones(shape[0], device=t.device) * t\n                x[:, partial_channel] = partial[:, partial_channel] * grid_mask[:, partial_channel]\n\n                partial_mean, partial_std = sde.marginal_prob(x, vec_t)\n                sampled_update = partial_mean[:, partial_channel] + partial_std[:, None, None, None, None] * torch.randn_like(partial_mean[:, partial_channel], device=partial_std.device)\n                x[:, partial_channel] = (\n                    x[:, partial_channel] * (1 - partial_mask[:, partial_channel]) \n                    + sampled_update[:, partial_channel] * partial_mask[:, partial_channel]\n                ) * grid_mask[:, partial_channel]\n\n\n            if partial is not None:\n                x_mean = x\n                for i in tqdm.trange(sde.N):\n                    t = timesteps[i]\n                    vec_t = torch.ones(shape[0], device=t.device) * t\n\n                    x, x_mean = corrector_update_fn(x, vec_t, model=model)\n                    x, x_mean = x * grid_mask, x_mean * grid_mask\n                    x, x_mean = predictor_update_fn(x, vec_t, model=model)\n                    x, x_mean = x * grid_mask, x_mean * grid_mask\n\n\n                    if i != sde.N - 1 and i < freeze_iters:\n\n                        x[:, partial_channel] = (x[:, partial_channel] * (1 - partial_mask[:, partial_channel]) + partial[:, partial_channel] * partial_mask[:, partial_channel]) * grid_mask[:, partial_channel]\n                        x_mean[:, partial_channel] = (x_mean[:, partial_channel] * (1 - partial_mask[:, partial_channel]) + partial[:, partial_channel] * partial_mask[:, partial_channel]) * grid_mask[:, partial_channel]\n\n                        ### add noise to the condition x0_star\n                        partial_mean, partial_std = sde.marginal_prob(x, timesteps[i] * torch.ones(shape[0], device=t.device))\n                        sampled_update = partial_mean[:, partial_channel] + partial_std[:, None, None, None] * torch.randn_like(partial_mean[:, partial_channel], device=partial_std.device)\n                        x[:, partial_channel] = (\n                            x[:, partial_channel] * (1 - partial_mask[:, partial_channel]) \n                            + sampled_update * partial_mask[:, partial_channel]\n                        ) * grid_mask[:, partial_channel]\n                        x_mean[:, partial_channel] = x[:, partial_channel]\n\n            else:\n\n                for i in tqdm.trange(sde.N - 1):\n                    t = timesteps[i]\n\n                    with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=False):\n                        vec_t = torch.ones(shape[0], device=t.device) * t\n                        x, x_mean = corrector_update_fn(x, vec_t, model=model)\n                        x, x_mean = x * mask, x_mean * mask\n                        x, x_mean = predictor_update_fn(x, vec_t, model=model)\n                        x, x_mean = x * mask, x_mean * mask\n                        print(x.min(), x.max())\n\n                    if return_traj and i >= 700 and i % 10 == 0:\n                        traj_buffer.append(compute_xzero(sde, model, x, t, grid_mask))\n\n        
    if return_traj:\n                return traj_buffer, sde.N * (n_steps + 1)\n            return inverse_scaler(x_mean * mask if denoise else x * mask), sde.N * (n_steps + 1)\n\n    return pc_sampler\n\ndef ddim_predictor_update_fn(x, t, tprev, sde, model, predictor, probability_flow, continuous, pred_type='noise'):\n    \"\"\"A wrapper that configures and returns the update function of predictors.\"\"\"\n    assert not continuous\n    score_fn = mutils.get_score_fn(sde, model, train=False, continuous=False, std_scale=False, pred_type=pred_type)\n    if predictor is None:\n        # Corrector-only sampler\n        predictor_obj = NonePredictor(sde, score_fn, probability_flow)\n    else:\n        predictor_obj = predictor(sde, score_fn, probability_flow)\n    return predictor_obj.update_fn(x, t, tprev)\n\ndef get_ddim_sampler(sde, shape, predictor, inverse_scaler, n_steps=1,\n                    denoise=False, eps=1e-3, device='cuda', grid_mask=None, pred_type='noise', use_occ=False):\n    \"\"\"Create a DDIM sampler.\n\n    Args:\n        sde: An `sde_lib.SDE` object that represents the forward SDE.\n        shape: A sequence of integers. The expected shape of a single sample.\n        inverse_scaler: The inverse data normalizer.\n        denoise: If `True`, add one-step denoising to final samples.\n        eps: A `float` number. The reverse-time SDE/ODE will be integrated to `eps` for numerical stability.\n        device: PyTorch device.\n\n    Returns:\n        A sampling function that returns samples and the number of function evaluations during sampling.\n    \"\"\"\n\n    predictor_update_fn = functools.partial(ddim_predictor_update_fn,\n                                            sde=sde,\n                                            predictor=predictor,\n                                            probability_flow=False,\n                                            continuous=False,\n                                            pred_type=pred_type)\n\n    def ddim_sampler(model, schedule='quad', num_steps=100, x0=None, x0_occ=None,\n            partial=None, partial_grid_mask=None, partial_channel=0):\n        \"\"\" The DDIM sampler function.\n\n        Args:\n            model: A score model.\n        Returns:\n            Samples, number of function evaluations.\n        \"\"\"\n        with torch.no_grad():\n            print(device)\n            if x0 is not None:\n                x = x0.to(device)\n            else:\n                # Initial sample\n                x = sde.prior_sampling(shape).to(device)\n            \n            mask = model.feature_mask\n            if use_occ:\n                occ_mask = model.occ_mask.float()\n                if x0_occ is not None:\n                    x_occ = x0_occ.to(device)\n                else:\n                    # Initial sample\n                    x_occ = sde.prior_sampling((x.size(0), 1, x.size(2)*2, x.size(3)*2, x.size(4)*2)).to(device)\n        \n                x = (x.float() * mask, x_occ.float() * occ_mask)\n\n            if partial is not None:\n                x[:, partial_channel] = x[:, partial_channel] * (1 - partial_mask) + partial * partial_mask\n\n            timesteps = torch.linspace(sde.T, eps, sde.N, device=device)\n\n            if schedule == 'uniform':\n                skip = sde.N // num_steps\n                seq = range(0, sde.N, skip)\n            elif schedule == 'quad':\n                seq = (\n                    np.linspace(\n                        0, 
np.sqrt(sde.N * 0.8), 100\n                    )\n                    ** 2\n                )\n                seq = [int(s) for s in list(seq)]\n\n            timesteps = torch.tensor(seq, device=device) / sde.N\n\n            for i in tqdm.tqdm(list(reversed(range(1, len(timesteps)))), leave=False):\n                t = timesteps[i]\n                tprev = timesteps[i - 1]\n                vec_t = torch.ones(1, device=t.device) * t\n                vec_tprev = torch.ones(1, device=t.device) * tprev\n                with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=False):\n                    x, x0_pred = predictor_update_fn(x, vec_t, model=model, tprev=vec_tprev)\n                    if use_occ:\n                        x = (x[0] * mask, x[1] * occ_mask)\n                        x0_pred = (x0_pred[0] * mask, x0_pred[1] * occ_mask)\n                        # print(x[0].min(), x[0].max())\n                    else:\n                        x, x0_pred = x * mask, x0_pred * mask\n                        # print(x.min(), x.max())\n                if partial is not None:\n                    x[:, partial_channel] = x[:, partial_channel] * (1 - partial_mask) + partial * partial_mask\n                    x0_pred[:, partial_channel] = x0_pred[:, partial_channel] * (1 - partial_mask) + partial * partial_mask\n\n            if use_occ:\n                encode = False\n                return (\n                    inverse_scaler(x0_pred[0] * mask if (denoise and not encode) else x[0] * mask),\n                    inverse_scaler(x0_pred[1] * occ_mask if (denoise and not encode) else x[1] * occ_mask)\n                ), sde.N * (n_steps + 1)\n            else:\n                encode = False\n                return inverse_scaler(x0_pred * mask if (denoise and not encode) else x * mask), sde.N * (n_steps + 1)\n    return ddim_sampler\n"
  },
  {
    "path": "GMeshDiffusion/lib/diffusion/sde_lib.py",
    "content": "\"\"\"Abstract SDE classes, Reverse SDE, and VE/VP SDEs.\"\"\"\nimport abc\nimport torch\nimport numpy as np\nimport torch.nn.functional as F\nimport time\n\n\nclass SDE(abc.ABC):\n  \"\"\"SDE abstract class. Functions are designed for a mini-batch of inputs.\"\"\"\n\n  def __init__(self, N):\n    \"\"\"Construct an SDE.\n\n    Args:\n      N: number of discretization time steps.\n    \"\"\"\n    super().__init__()\n    self.N = N\n\n  @property\n  @abc.abstractmethod\n  def T(self):\n    \"\"\"End time of the SDE.\"\"\"\n    pass\n\n  @abc.abstractmethod\n  def sde(self, x, t):\n    pass\n\n  @abc.abstractmethod\n  def marginal_prob(self, x, t):\n    \"\"\"Parameters to determine the marginal distribution of the SDE, $p_t(x)$.\"\"\"\n    pass\n\n  @abc.abstractmethod\n  def prior_sampling(self, shape):\n    \"\"\"Generate one sample from the prior distribution, $p_T(x)$.\"\"\"\n    pass\n\n  @abc.abstractmethod\n  def prior_logp(self, z):\n    \"\"\"Compute log-density of the prior distribution.\n\n    Useful for computing the log-likelihood via probability flow ODE.\n\n    Args:\n      z: latent code\n    Returns:\n      log probability density\n    \"\"\"\n    pass\n\n  def discretize(self, x, t):\n    \"\"\"Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.\n\n    Useful for reverse diffusion sampling and probabiliy flow sampling.\n    Defaults to Euler-Maruyama discretization.\n\n    Args:\n      x: a torch tensor\n      t: a torch float representing the time step (from 0 to `self.T`)\n\n    Returns:\n      f, G\n    \"\"\"\n    dt = 1 / self.N\n    drift, diffusion = self.sde(x, t)\n    f = drift * dt\n    G = diffusion * torch.sqrt(torch.tensor(dt, device=t.device))\n    return f, G\n\n  def reverse(self, score_fn, probability_flow=False):\n    \"\"\"Create the reverse-time SDE/ODE.\n\n    Args:\n      score_fn: A time-dependent score-based model that takes x and t and returns the score.\n      probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling.\n    \"\"\"\n    N = self.N\n    T = self.T\n    sde_fn = self.sde\n    discretize_fn = self.discretize\n    sqrt_alphas_cumprod = self.sqrt_alphas_cumprod\n    sqrt_1m_alphas_cumprod = self.sqrt_1m_alphas_cumprod\n\n    # Build the class for reverse-time SDE.\n    class RSDE(self.__class__):\n      def __init__(self):\n        self.N = N\n        self.probability_flow = probability_flow\n\n      @property\n      def T(self):\n        return T\n\n      def sde(self, x, t):\n        \"\"\"Create the drift and diffusion functions for the reverse SDE/ODE.\"\"\"\n        drift, diffusion = sde_fn(x, t)\n        score = score_fn(x, t)\n        drift = drift - diffusion[:, None, None, None, None] ** 2 * score * (0.5 if self.probability_flow else 1.)\n        # Set the diffusion function to zero for ODEs.\n        diffusion = 0. 
if self.probability_flow else diffusion\n        return drift, diffusion\n\n      def discretize(self, x, t):\n        \"\"\"Create discretized iteration rules for the reverse diffusion sampler.\"\"\"\n        f, G = discretize_fn(x, t)\n        rev_f = f - G[:, None, None, None, None] ** 2 * score_fn(x, t) * (0.5 if self.probability_flow else 1.)\n        rev_G = torch.zeros_like(G) if self.probability_flow else G\n        return rev_f, rev_G\n\n      def discretize_ddim(self, x, t, tprev=None, encode=False):\n        \"\"\"DDPM discretization.\"\"\"\n        timestep = (t * (N - 1) / T).long()\n        timestep_prev = (tprev * (N - 1) / T).long()\n\n        if type(x) == torch.Tensor:\n          score = score_fn(x.float(), t.float())\n\n          # alphas1prev_div_alphas1 = torch.exp(log_diff)\n          alphas1 = sqrt_alphas_cumprod[timestep].cuda()[:, None, None, None, None]\n          alphas2 = sqrt_1m_alphas_cumprod[timestep].cuda()[:, None, None, None, None]\n          alphas1_prev = sqrt_alphas_cumprod[timestep_prev].cuda()[:, None, None, None, None]\n          alphas2_prev = sqrt_1m_alphas_cumprod[timestep_prev].cuda()[:, None, None, None, None]\n          alphas1prev_div_alphas1 = alphas1_prev.double() / alphas1.double()\n          alphas2prev_div_alphas2 = alphas2_prev.double() / alphas2.double()\n\n\n          x0_pred_scaled = (x.double() - alphas2.double() * score.double())\n          use_clip = False\n          if use_clip:\n            # raise NotImplementedError\n            x0_pred_scaled = x0_pred_scaled.clamp(-alphas1[0].squeeze(), alphas1[0].squeeze())\n          score_scaled_t = x - x0_pred_scaled\n          x0_pred = x0_pred_scaled / alphas1\n\n          x_new = (\n            alphas1prev_div_alphas1.double() * x + \n            (-alphas1prev_div_alphas1 + alphas2prev_div_alphas2.double()) * score_scaled_t.double()\n          )\n          return x_new, x0_pred\n        else:\n          score, score_occ = score_fn(x, t.float())\n          x, x_occ = x\n\n          alphas1 = sqrt_alphas_cumprod[timestep].cuda()[:, None, None, None, None]\n          alphas2 = sqrt_1m_alphas_cumprod[timestep].cuda()[:, None, None, None, None]\n          alphas1_prev = sqrt_alphas_cumprod[timestep_prev].cuda()[:, None, None, None, None]\n          alphas2_prev = sqrt_1m_alphas_cumprod[timestep_prev].cuda()[:, None, None, None, None]\n          alphas1prev_div_alphas1 = alphas1_prev / alphas1\n          alphas2prev_div_alphas2 = alphas2_prev / alphas2\n\n          x0_pred_scaled = (x - alphas2 * score)\n          x0_occ_pred_scaled = (x_occ - alphas2 * score_occ)\n          use_clip = False\n          if use_clip:\n            x0_pred_scaled = x0_pred_scaled.clamp(-alphas1[0].squeeze(), alphas1[0].squeeze())\n            x0_occ_pred_scaled = x0_occ_pred_scaled.clamp(-alphas1[0].squeeze(), alphas1[0].squeeze())\n          score_scaled_t = x - x0_pred_scaled\n          x0_pred = x0_pred_scaled / alphas1\n          score_occ_scaled_t = x_occ - x0_occ_pred_scaled\n          x0_occ_pred = x0_occ_pred_scaled / alphas1\n\n          x_new = (\n            alphas1prev_div_alphas1 * x + \n            (-alphas1prev_div_alphas1 + alphas2prev_div_alphas2) * score_scaled_t\n          )\n          x_occ_new = (\n            alphas1prev_div_alphas1 * x_occ + \n            (-alphas1prev_div_alphas1 + alphas2prev_div_alphas2) * score_occ_scaled_t\n          )\n          return (x_new, x_occ_new), (x0_pred, x0_occ_pred)\n\n\n      def discretize_conditional_ddpm(self, x, t, tprev=None, condition_func=None, 
condition=False):\n        \"\"\"DDPM discretization.\"\"\"\n        timestep = (t * (N - 1) / T).long()\n        timestep_prev = (tprev * (N - 1) / T).long()\n\n        score = score_fn(x.float(), t.float())\n\n        # alphas1prev_div_alphas1 = torch.exp(log_diff)\n        alphas1 = sqrt_alphas_cumprod[timestep].cuda()[:, None, None, None, None]\n        alphas2 = sqrt_1m_alphas_cumprod[timestep].cuda()[:, None, None, None, None]\n        alphas1_prev = sqrt_alphas_cumprod[timestep_prev].cuda()[:, None, None, None, None]\n        alphas1prev_div_alphas1 = alphas1_prev.double() / alphas1.double()\n\n        x0_pred_scaled = (x.double() - alphas2.double() * score.double())\n        x0_pred_scaled = x0_pred_scaled.clamp(-alphas1[0].squeeze(), alphas1[0].squeeze())\n        x0_pred = x0_pred_scaled / alphas1\n\n        if condition is None:\n          condition_update = 0\n        else:\n          if (t - 0.99).mean() < 1e-3:\n            x = x0_pred\n          condition_update = condition_func(x.float(), condition)\n\n        x_new = (\n          x - alphas1prev_div_alphas1.double() * condition_update\n        )\n        return x_new, x0_pred\n\n\n    return RSDE()\n\n\nclass VPSDE(SDE):\n  def __init__(self, beta_min=0.1, beta_max=20, N=1000):\n    \"\"\"Construct a Variance Preserving SDE.\n\n    Args:\n      beta_min: value of beta(0)\n      beta_max: value of beta(1)\n      N: number of discretization steps\n    \"\"\"\n    super().__init__(N)\n    self.beta_0 = beta_min\n    self.beta_1 = beta_max\n    self.N = N\n    self.discrete_betas = torch.linspace(beta_min / N, beta_max / N, N).cuda()\n    self.alphas = 1. - self.discrete_betas\n    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)\n    self.alphas_cumprod_ext = torch.cat([torch.tensor([1.0 - 1e-4]).cuda(), torch.cumprod(self.alphas, dim=0)], dim=0)\n\n    self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)\n    self.sqrt_1m_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)\n\n    self.alphas_cumprod = self.alphas_cumprod\n    self.alphas_cumprod_ext = self.alphas_cumprod_ext\n\n  @property\n  def T(self):\n    return 1\n\n  def sde(self, x, t):\n    beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0)\n    drift = -0.5 * beta_t[:, None, None, None, None] * x\n    diffusion = torch.sqrt(beta_t)\n    return drift, diffusion\n\n  def marginal_prob(self, x, t):\n    log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0\n    mean = torch.exp(log_mean_coeff[:, None, None, None, None]) * x\n    std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))\n    return mean, std\n\n  def prior_sampling(self, shape):\n    return torch.randn(*shape)\n\n  def prior_logp(self, z):\n    shape = z.shape\n    N = np.prod(shape[1:])\n    logps = -N / 2. * np.log(2 * np.pi) - torch.sum(z ** 2, dim=(1, 2, 3, 4)) / 2.\n    return logps\n\n  def discretize(self, x, t):\n    \"\"\"DDPM discretization.\"\"\"\n    timestep = (t * (self.N - 1) / self.T).long()\n    beta = self.discrete_betas.to(x.device)[timestep]\n    alpha = self.alphas.to(x.device)[timestep]\n    sqrt_beta = torch.sqrt(beta)\n    f = torch.sqrt(alpha)[:, None, None, None, None] * x - x\n    G = sqrt_beta\n    return f, G"
  },
  {
    "path": "GMeshDiffusion/lib/diffusion/trainer.py",
    "content": "import os\nimport sys\nimport numpy as np\n\nimport logging\n# Keep the import below for registering all model definitions\nfrom .models import unet3d, unet3d_occgrid, unet3d_tet_aware, unet3d_occgrid_v2, unet3d_meshdiffusion\n\nfrom . import losses\nfrom .models import utils as mutils\nfrom .models.ema import ExponentialMovingAverage\nfrom . import sde_lib\nimport torch\nfrom torch.utils import tensorboard\nfrom .utils import save_checkpoint, restore_checkpoint\nfrom ..dataset.gshell_dataset import GShellDataset\nfrom ..dataset.gshell_dataset_aug import GShellAugDataset\n\n\ndef train(config):\n    \"\"\"Runs the training pipeline.\n\n    Args:\n    config: Configuration to use.\n    workdir: Working directory for checkpoints and TF summaries. If this\n        contains checkpoint training will be resumed from the latest checkpoint.\n    \"\"\"\n\n    workdir = config.training.train_dir\n    # Create directories for experimental logs\n    logging.info(\"working dir: {:s}\".format(workdir))\n\n\n    tb_dir = os.path.join(workdir, \"tensorboard\")\n    writer = tensorboard.SummaryWriter(tb_dir)\n\n    # Initialize model.\n    score_model = mutils.create_model(config)\n    ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)\n    optimizer = losses.get_optimizer(config, score_model.parameters())\n    gradscaler = torch.cuda.amp.GradScaler(enabled=True)\n\n    state = dict(optimizer=optimizer, model=score_model, ema=ema, gradscaler=gradscaler, step=0)\n\n\n    # Create checkpoints directory\n    checkpoint_dir = os.path.join(workdir, \"checkpoints\")\n    # Intermediate checkpoints to resume training after pre-emption in cloud environments\n    checkpoint_meta_dir = os.path.join(workdir, \"checkpoints-meta\", \"checkpoint.pth\")\n    os.makedirs(checkpoint_dir, exist_ok=True)\n    os.makedirs(os.path.dirname(checkpoint_meta_dir), exist_ok=True)\n\n    # Resume training when intermediate checkpoints are detected\n    state = restore_checkpoint(checkpoint_meta_dir, state, config.device)\n    initial_step = int(state['step'])\n\n    print(f\"work dir: {workdir}\")\n\n    \n    try:\n        use_occ_grid = config.data.use_occ_grid\n    except:\n        use_occ_grid = False\n    if use_occ_grid:\n        train_dataset = GShellAugDataset(config)\n    else:\n        train_dataset = GShellDataset(config.data.dataset_metapath)\n\n\n    try:\n        collate_fn = train_dataset.collate\n    except:\n        collate_fn = None\n\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset, \n        batch_size=config.training.batch_size, \n        shuffle=True,\n        num_workers=config.data.num_workers,\n        collate_fn=collate_fn,\n        pin_memory=True\n    )\n\n    data_iter = iter(train_loader)\n\n    print(\"data loader set\")\n\n    # Setup SDEs\n    sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)\n\n    # Build one-step training and evaluation functions\n    optimize_fn = losses.optimization_manager(config)\n    try:\n        use_vis_mask = config.model.use_vis_mask\n    except:\n        use_vis_mask = False\n    print('use_vis_mask', use_vis_mask)\n    train_step_fn = losses.get_step_fn(sde, train=True, optimize_fn=optimize_fn,\n                                        loss_type=config.training.loss_type,\n                                        pred_type=config.model.pred_type,\n                                        use_vis_mask=use_vis_mask,\n                         
               use_occ=use_occ_grid,\n                                        use_aux=config.training.use_aux_loss)\n\n    num_train_steps = config.training.n_iters\n\n    # In case there are multiple hosts (e.g., TPU pods), only log to host 0\n    logging.info(\"Starting training loop at step %d.\" % (initial_step // config.training.num_grad_acc_steps,))\n\n\n    iter_size = config.training.num_grad_acc_steps\n    for step in range(initial_step // iter_size, num_train_steps + 1):\n        tmp_loss_dict = {\n            'loss_total': 0.0,\n            'loss_score': 0.0,\n            'loss_reg': 0.0,\n        }\n        for step_inner in range(iter_size):\n            try:\n                # batch, batch_mask = next(data_iter)\n                batch = next(data_iter)\n            except StopIteration:\n                # StopIteration is thrown if dataset ends\n                # reinitialize data loader \n                data_iter = iter(train_loader)\n                batch = next(data_iter)\n\n            \n            if type(batch) == dict:\n                for k in batch:\n                    batch[k] = batch[k].to('cuda', non_blocking=False)\n            else:\n                batch = batch.to('cuda', non_blocking=False)\n\n            # Execute one training step\n            clear_grad_flag = (step_inner == 0)\n            update_param_flag = (step_inner == iter_size - 1)\n            loss_dict = train_step_fn(state, batch, clear_grad=clear_grad_flag, update_param=update_param_flag, gradscaler=gradscaler)\n            for key in loss_dict:\n                tmp_loss_dict[key] += loss_dict[key].item() / iter_size\n\n            # print(torch.cuda.memory_summary())\n\n        if step % config.training.log_freq == 0:\n            # logging.info(\"step: %d, training_loss: %.5e\" % (step, tmp_loss))\n            logging.info(\n                \"step: %d, loss_total: %.5e, loss_score: %.5e, loss_reg: %.5e\" \n                % (step, tmp_loss_dict['loss_total'], tmp_loss_dict['loss_score'], tmp_loss_dict['loss_reg'])\n            )\n            sys.stdout.flush()\n            writer.add_scalar(\"loss_total\", tmp_loss_dict['loss_total'], step)\n            writer.add_scalar(\"loss_score\", tmp_loss_dict['loss_score'], step)\n            writer.add_scalar(\"loss_reg\", tmp_loss_dict['loss_reg'], step)\n\n        # Save a temporary checkpoint to resume training after pre-emption periodically\n        if step != 0 and step % config.training.snapshot_freq_for_preemption == 0:\n            logging.info(f\"save meta at iter {step}\")\n            save_checkpoint(checkpoint_meta_dir, state)\n\n        # Save a checkpoint periodically and generate samples if needed\n        if step != 0 and step % config.training.snapshot_freq == 0 or step == num_train_steps:\n            logging.info(f\"save model: {step}-th\")\n            save_checkpoint(os.path.join(checkpoint_dir, f'checkpoint_{step}.pth'), state)\n"
  },
  {
    "path": "GMeshDiffusion/lib/diffusion/trainer_ddp.py",
    "content": "import os\nimport sys\nimport numpy as np\n\nimport logging\n# Keep the import below for registering all model definitions\nfrom .models import unet3d, unet3d_occgrid, unet3d_tet_aware, unet3d_occgrid_v2, unet3d_meshdiffusion\n\nfrom . import losses\nfrom .models import utils as mutils\nfrom .models.ema import ExponentialMovingAverage\nfrom . import sde_lib\nimport torch\nfrom torch.utils import tensorboard\nfrom .utils import save_checkpoint, restore_checkpoint\nfrom ..dataset.gshell_dataset import GShellDataset\nfrom ..dataset.gshell_dataset_aug import GShellAugDataset\n\nfrom .lion.lion import Lion\nimport torch.distributed as dist\n\ndef train(config):\n    \"\"\"Runs the training pipeline.\n\n    Args:\n    config: Configuration to use.\n    workdir: Working directory for checkpoints and TF summaries. If this\n        contains checkpoint training will be resumed from the latest checkpoint.\n    \"\"\"\n    dist.init_process_group(\"nccl\")\n    rank = dist.get_rank()\n    torch.cuda.set_device(rank)\n    device = torch.device(\"cuda\", rank)\n    print(f\"Start running basic DDP example on rank {rank}.\")\n\n    # create model and move it to GPU with id rank\n    world_size = torch.cuda.device_count()\n    device_id = rank % torch.cuda.device_count()\n\n    workdir = config.training.train_dir\n    # Create directories for experimental logs\n    logging.info(\"working dir: {:s}\".format(workdir))\n\n\n    tb_dir = os.path.join(workdir, \"tensorboard\")\n    writer = tensorboard.SummaryWriter(tb_dir)\n\n    # Initialize model.\n    score_model = mutils.create_model(config, ddp=True, rank=rank)\n    ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)\n    optimizer = losses.get_optimizer(config, score_model.parameters())\n    gradscaler = torch.cuda.amp.GradScaler(growth_interval=config.training.gradscaler_growth_interval)\n\n    state = dict(optimizer=optimizer, model=score_model, ema=ema, gradscaler=gradscaler, step=0)\n\n\n    # Create checkpoints directory\n    checkpoint_dir = os.path.join(workdir, \"checkpoints\")\n    # Intermediate checkpoints to resume training after pre-emption in cloud environments\n    checkpoint_meta_dir = os.path.join(workdir, \"checkpoints-meta\", \"checkpoint.pth\")\n    os.makedirs(checkpoint_dir, exist_ok=True)\n    os.makedirs(os.path.dirname(checkpoint_meta_dir), exist_ok=True)\n\n    # Resume training when intermediate checkpoints are detected\n    state = restore_checkpoint(checkpoint_meta_dir, state, config.device, rank=rank)\n    initial_step = int(state['step'])\n\n    print(f\"work dir: {workdir}\")\n\n    try:\n        use_occ_grid = config.data.use_occ_grid\n    except:\n        use_occ_grid = False\n    if use_occ_grid:\n        train_dataset = GShellAugDataset(config)\n    else:\n        train_dataset = GShellDataset(config.data.dataset_metapath)\n\n    train_sampler = torch.utils.data.distributed.DistributedSampler(\n    \ttrain_dataset,\n    \tnum_replicas=world_size,\n    \trank=rank\n    )\n\n    try:\n        collate_fn = train_dataset.collate\n    except:\n        collate_fn = None\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset, \n        batch_size=config.training.batch_size, \n        num_workers=config.data.num_workers,\n        # pin_memory=True,\n        sampler=train_sampler,\n        collate_fn=collate_fn\n    )\n\n    data_iter = iter(train_loader)\n\n    print(\"data loader set\")\n\n    # Setup SDEs\n    sde = 
sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)\n\n    # Build one-step training and evaluation functions\n    optimize_fn = losses.optimization_manager(config)\n    try:\n        use_vis_mask = config.model.use_vis_mask\n    except:\n        use_vis_mask = False\n    print('use_vis_mask', use_vis_mask)\n    train_step_fn = losses.get_step_fn(sde, train=True, optimize_fn=optimize_fn,\n                                        loss_type=config.training.loss_type,\n                                        pred_type=config.model.pred_type,\n                                        use_vis_mask=use_vis_mask,\n                                        use_occ=use_occ_grid,\n                                        use_aux=config.training.use_aux_loss)\n\n    num_train_steps = config.training.n_iters\n\n    # In case there are multiple hosts (e.g., TPU pods), only log to host 0\n    logging.info(\"Starting training loop at step %d.\" % (initial_step // config.training.num_grad_acc_steps,))\n\n    iter_size = config.training.num_grad_acc_steps\n    epoch = 0\n    train_sampler.set_epoch(epoch)\n    for step in range(initial_step // iter_size, num_train_steps + 1):\n        tmp_loss_dict = {\n            'loss_total': 0.0,\n            'loss_score': 0.0,\n            'loss_reg': 0.0,\n        }\n        for step_inner in range(iter_size):\n            try:\n                # batch, batch_mask = next(data_iter)\n                batch = next(data_iter)\n            except StopIteration:\n                # StopIteration is thrown if dataset ends\n                # reinitialize data loader \n                epoch += 1\n                train_sampler.set_epoch(epoch)\n                data_iter = iter(train_loader)\n                batch = next(data_iter)\n\n            if type(batch) == dict:\n                for k in batch:\n                    batch[k] = batch[k].to(rank, non_blocking=False)\n            else:\n                batch = batch.to(rank, non_blocking=False)\n\n            # Execute one training step\n            clear_grad_flag = (step_inner == 0)\n            update_param_flag = (step_inner == iter_size - 1)\n            if not update_param_flag:\n                with score_model.no_sync():\n                    loss_dict = train_step_fn(state, batch, clear_grad=clear_grad_flag, update_param=update_param_flag, gradscaler=gradscaler)\n            else:\n                loss_dict = train_step_fn(state, batch, clear_grad=clear_grad_flag, update_param=update_param_flag, gradscaler=gradscaler)\n            for key in loss_dict:\n                tmp_loss_dict[key] += loss_dict[key].item() / iter_size\n\n            # print(torch.cuda.memory_summary())\n\n        if step % config.training.log_freq == 0:\n            loss = tmp_loss_dict['loss_total']\n            loss = torch.tensor(loss / world_size).to(rank)\n\n            # logging.info(\"step: %d, training_loss: %.5e\" % (step, tmp_loss))\n            dist.reduce(loss, dst=0, op=dist.ReduceOp.SUM)\n            if rank == 0:\n                loss = loss.item()\n                logging.info(\"step: %d, loss: %.5e, scale: %.5e\" % (step, loss, gradscaler.get_scale()))\n                sys.stdout.flush()\n                writer.add_scalar(\"loss\", loss, step)\n\n        if rank == 0:\n            # Save a temporary checkpoint to resume training after pre-emption periodically\n            if step != 0 and step % config.training.snapshot_freq_for_preemption == 0:\n                
logging.info(f\"save meta at iter {step}\")\n                save_checkpoint(checkpoint_meta_dir, state)\n\n            # Save a checkpoint periodically and generate samples if needed\n            if step != 0 and step % config.training.snapshot_freq == 0 or step == num_train_steps:\n                logging.info(f\"save model: {step}-th\")\n                save_checkpoint(os.path.join(checkpoint_dir, f'checkpoint_{step}.pth'), state)\n\n    dist.destroy_process_group()"
  },
  {
    "path": "GMeshDiffusion/lib/diffusion/utils.py",
    "content": "import torch\nimport os\nimport logging\n\n\ndef restore_checkpoint(ckpt_dir, state, device, strict=False, rank=None):\n  if not os.path.exists(ckpt_dir):\n    os.makedirs(os.path.dirname(ckpt_dir), exist_ok=True)\n    logging.warning(f\"No checkpoint found at {ckpt_dir}. \"\n                    f\"Returned the same state as input\")\n    if strict:\n      raise\n    return state\n  else:\n    if rank is not None:\n      device = f\"cuda:{rank}\"\n    # loaded_state = torch.load(ckpt_dir, map_location=device)\n    loaded_state = torch.load(ckpt_dir, map_location='cpu')\n    state['optimizer'].load_state_dict(loaded_state['optimizer'])\n    try:\n      state['model'].load_state_dict(loaded_state['model'], strict=False)\n    except:\n      consume_prefix_in_state_dict_if_present(loaded_state['model'])\n      state['model'].load_state_dict(loaded_state['model'], strict=False)\n    state['ema'].load_state_dict(loaded_state['ema'], device=device)\n    state['step'] = loaded_state['step']\n    state['model'].to(device)\n    try:\n      state['gradscaler'].load_state_dict(loaded_state['gradscaler'])\n      # state['gradscaler'].to(device)\n    except:\n      # raise\n      pass\n    torch.cuda.empty_cache()\n    return state\n\n\ndef save_checkpoint(ckpt_dir, state):\n  saved_state = {\n    'optimizer': state['optimizer'].state_dict(),\n    'model': state['model'].state_dict(),\n    'ema': state['ema'].state_dict(),\n    'step': state['step'],\n    'gradscaler': state['gradscaler'].state_dict()\n  }\n  torch.save(saved_state, ckpt_dir)"
  },
  {
    "path": "GMeshDiffusion/main_diffusion.py",
    "content": "\"\"\"Training and evaluation\"\"\"\n\nfrom absl import app\nfrom absl import flags\nfrom ml_collections.config_flags import config_flags\n\nimport lib.diffusion.trainer as trainer\nimport lib.diffusion.evaler as evaler\n\n\nFLAGS = flags.FLAGS\n\nconfig_flags.DEFINE_config_file(\n    \"config\", None, \"diffusion configs\", lock_config=False)\nflags.DEFINE_enum(\"mode\", None, [\"train\", \"uncond_gen\", \"cond_gen\", \"uncond_gen_interp\"], \"Running mode\")\nflags.mark_flags_as_required([\"config\", \"mode\"])\n\n\ndef main(argv):\n    if FLAGS.mode == 'train':\n        trainer.train(FLAGS.config)\n    elif FLAGS.mode == 'uncond_gen':\n        evaler.uncond_gen(FLAGS.config)\n    elif FLAGS.mode == 'uncond_gen_interp':\n        evaler.uncond_gen_interp(FLAGS.config)\n    elif FLAGS.mode == 'cond_gen':\n        evaler.cond_gen(FLAGS.config)\n\nif __name__ == \"__main__\":\n  app.run(main)\n"
  },
  {
    "path": "GMeshDiffusion/main_diffusion_ddp.py",
    "content": "\"\"\"Training and evaluation\"\"\"\n\nfrom absl import app\nfrom absl import flags\nfrom ml_collections.config_flags import config_flags\n\nimport lib.diffusion.trainer_ddp as trainer\nimport lib.diffusion.evaler as evaler\n\n\n\n\nFLAGS = flags.FLAGS\n\nconfig_flags.DEFINE_config_file(\n    \"config\", None, \"diffusion configs\", lock_config=False)\nflags.DEFINE_enum(\"mode\", None, [\"train\", \"uncond_gen\", \"cond_gen\", \"uncond_gen_interp\"], \"Running mode\")\nflags.mark_flags_as_required([\"config\", \"mode\"])\n\ndef main(argv):\n    print(FLAGS.config)\n    if FLAGS.mode == 'train':\n        trainer.train(FLAGS.config)\n\nif __name__ == \"__main__\":\n  app.run(main)\n"
  },
  {
    "path": "GMeshDiffusion/metadata/get_splits_lower.py",
    "content": "import os\nimport random\n\nrandom.seed(42)\n\nsplit_ratio = 0.9\ndata_root = 'PLACEHOLDER'\ngrid_root = os.path.join(data_root, 'grid')\noccgrid_root = os.path.join(data_root, 'grid_aug')\ndata_path_list = sorted([os.path.join(data_root, fpath) for fpath in os.listdir(data_root)])\n\nrandom.shuffle(data_path_list)\n\nn_train = int(len(data_path_list) * split_ratio)\ntrain_list = data_path_list[:n_train]\ntest_list = data_path_list[n_train:]\n\nwith open('lower_res64_grid_train.txt', 'w') as f:\n    f.write('\\n'.join(train_list))\n\nwith open('lower_res64_grid_test.txt', 'w') as f:\n    f.write('\\n'.join(test_list))\n\n\noccgrid_train_list = [os.path.join(occgrid_root, x.split('/')[-1]) for x in train_list]\noccgrid_test_list = [os.path.join(occgrid_root, x.split('/')[-1]) for x in test_list]\n\nwith open('lower_res64_occgrid_train.txt', 'w') as f:\n    f.write('\\n'.join(occgrid_train_list))\n\nwith open('lower_res64_occgrid_test.txt', 'w') as f:\n    f.write('\\n'.join(occgrid_test_list))\n\n"
  },
  {
    "path": "GMeshDiffusion/metadata/get_splits_upper.py",
    "content": "import os\nimport random\n\nrandom.seed(42)\n\nsplit_ratio = 0.9\ndata_root = 'PLACEHOLDER'\ngrid_root = os.path.join(data_root, 'grid')\noccgrid_root = os.path.join(data_root, 'grid_aug')\ndata_path_list = sorted([os.path.join(data_root, fpath) for fpath in os.listdir(data_root)])\n\nrandom.shuffle(data_path_list)\n\nn_train = int(len(data_path_list) * split_ratio)\ntrain_list = data_path_list[:n_train]\ntest_list = data_path_list[n_train:]\n\nwith open('upper_res64_grid_train.txt', 'w') as f:\n    f.write('\\n'.join(train_list))\n\nwith open('upper_res64_grid_test.txt', 'w') as f:\n    f.write('\\n'.join(test_list))\n\n\noccgrid_train_list = [os.path.join(occgrid_root, x.split('/')[-1]) for x in train_list]\noccgrid_test_list = [os.path.join(occgrid_root, x.split('/')[-1]) for x in test_list]\n\nwith open('upper_res64_occgrid_train.txt', 'w') as f:\n    f.write('\\n'.join(occgrid_train_list))\n\nwith open('upper_res64_occgrid_test.txt', 'w') as f:\n    f.write('\\n'.join(occgrid_test_list))\n\n"
  },
  {
    "path": "GMeshDiffusion/metadata/save_tet_info.py",
    "content": "'''\n    Storing tet-grid related meta-info into a single file\n'''\n\nimport numpy as np\nimport torch\nimport os\nimport tqdm\nimport argparse\n\nfrom itertools import combinations\n\n\ndef tet_to_grids(vertices, values_list, grid_size):\n    grid = torch.zeros(12, grid_size, grid_size, grid_size, device=vertices.device)\n    with torch.no_grad():\n        for k, values in enumerate(values_list):\n            if k == 0:\n                grid[k, vertices[:, 0], vertices[:, 1], vertices[:, 2]] = values.squeeze()\n            else:\n                grid[1:4, vertices[:, 0], vertices[:, 1], vertices[:, 2]] = values.transpose(0, 1)\n    return grid\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description='nvdiffrec')\n    parser.add_argument('-res', '--resolution', type=int)\n    parser.add_argument('-r', '--root', type=str)\n    parser.add_argument('-s', '--source', type=str)\n    parser.add_argument('-t', '--target', type=str)\n    FLAGS = parser.parse_args()\n\n    tet_path = f'./tets/{FLAGS.resolution}_tets_cropped_reordered.npz'\n    tet = np.load(tet_path)\n    vertices = torch.tensor(tet['vertices']).cuda()\n    indices = torch.tensor(tet['indices']).long().cuda()\n\n    edges = torch.tensor(tet['edges']).long().cuda()\n    tet_edges = torch.tensor(tet['tet_edges']).long().view(-1, 2).cuda()\n\n    vertices_unique = vertices[:].unique()\n    dx = vertices_unique[1] - vertices_unique[0]\n    dx = dx / 2.0 ### denser grid\n    vertices_discretized = (\n        ((vertices - vertices.min()) / dx)\n    ).long()\n\n    midpoints = (vertices_discretized[edges[:, 0]] + vertices_discretized[edges[:, 1]]) / 2.0\n    midpoints_dicretized = midpoints.long()\n\n    tet_verts = vertices_discretized[indices.view(-1)].view(-1, 4, 3)\n    tet_center = tet_verts.float().mean(dim=1)\n    tet_center_discretized = tet_center.long()\n\n\n    edge_ind_list = [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]\n    msdf_tetedges = []\n    msdf_from_tetverts = []\n    for i in range(5):\n        for j in range(i+1, 6):\n            if (edge_ind_list[i][0] == edge_ind_list[j][0]\n                or edge_ind_list[i][0] == edge_ind_list[j][1]\n                or edge_ind_list[i][1] == edge_ind_list[j][0]\n                or edge_ind_list[i][1] == edge_ind_list[j][1]\n            ):\n                msdf_tetedges.append(i)\n                msdf_tetedges.append(j)\n                msdf_from_tetverts.extend([edge_ind_list[i][0], edge_ind_list[i][1], edge_ind_list[j][0], edge_ind_list[j][1]])\n    msdf_tetedges = torch.tensor(msdf_tetedges)\n    msdf_from_tetverts = torch.tensor(msdf_from_tetverts)\n    print(msdf_tetedges)\n    print(msdf_tetedges.size())\n\n\n\n    tet_edges = tet_edges.view(-1, 2)\n    msdf_tetedges = msdf_tetedges.view(-1)\n    tet_edgenodes_pos = (vertices_discretized[tet_edges[:, 0]] + vertices_discretized[tet_edges[:, 1]]) / 2.0\n    tet_edgenodes_pos = tet_edgenodes_pos.view(-1, 6, 2)\n    occ_edge_pos = tet_edgenodes_pos[:, msdf_tetedges].view(-1, 12, 2, 3)\n    \n\n    edge_twopoint_order = torch.sign(occ_edge_pos[:, :, 0, :] - occ_edge_pos[:, :, 1, :])\n    edge_twopoint_order_binary_code = (edge_twopoint_order * torch.tensor([16, 4, 1], device=edge_twopoint_order.device).view(1, 1, -1)).sum(dim=-1)\n    edge_twopoint_order_binary_code = torch.stack([edge_twopoint_order_binary_code, -edge_twopoint_order_binary_code], dim=-1)\n    _, edge_twopoint_order = edge_twopoint_order_binary_code.sort(dim=-1)\n\n    occ_edge_cano_order = torch.arange(2).view(1, 
1, 2).expand(occ_edge_pos.size(0), 12, 2).cuda()\n    occ_edge_cano_order = torch.gather(\n        input=occ_edge_cano_order,\n        dim=-1,\n        index=edge_twopoint_order\n    )\n\n    tet_edges = tet_edges.view(-1)\n\n    torch.save({\n        'tet_v_pos': vertices,\n        'tet_edge_vpos': vertices[tet_edges].view(-1, 2, 3),\n        'tet_edge_pix_loc': vertices_discretized[tet_edges].view(-1, 2, 3),\n        'tet_center_loc': tet_center_discretized,\n        'msdf_edges': msdf_tetedges.view(12, 2),\n        'occ_edge_cano_order': occ_edge_cano_order\n    }, 'tet_info.pt')\n"
  },
  {
    "path": "GMeshDiffusion/metadata/tet_to_cubic_grid_dataset.py",
    "content": "import numpy as np\nimport torch\nimport os\nimport tqdm\nimport argparse\n\ndef tet_to_grids(vertices, values_list, grid_size):\n    grid = torch.zeros(4, grid_size, grid_size, grid_size, device=vertices.device)\n    with torch.no_grad():\n        for k, values in enumerate(values_list):\n            if k == 0:\n                grid[k, vertices[:, 0], vertices[:, 1], vertices[:, 2]] = values.squeeze()\n            else:\n                grid[1:4, vertices[:, 0], vertices[:, 1], vertices[:, 2]] = values.transpose(0, 1)\n    return grid\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description='nvdiffrec')\n    parser.add_argument('-res', '--resolution', type=int)\n    parser.add_argument('-ss', '--split-size', type=int, default=int(1e8))\n    parser.add_argument('-ind', '--index', type=int)\n    parser.add_argument('-r', '--root', type=str)\n    parser.add_argument('-s', '--source', type=str)\n    parser.add_argument('-t', '--target', type=str)\n    FLAGS = parser.parse_args()\n\n    tet_path = f'./tets/{FLAGS.resolution}_tets_cropped_reordered.npz'\n    tet = np.load(tet_path)\n    vertices = torch.tensor(tet['vertices']).cuda()\n    indices = torch.tensor(tet['indices']).long().cuda()\n\n    edges = torch.tensor(tet['edges']).long().cuda()\n    tet_edges = torch.tensor(tet['tet_edges']).long().view(-1, 2).cuda()\n    \n    vertices_unique = vertices[:].unique()\n    dx = vertices_unique[1] - vertices_unique[0]\n    dx = dx / 2.0 ### denser grid\n    vertices_discretized = (\n        ((vertices - vertices.min()) / dx)\n    ).long()\n\n    print(vertices_discretized.size())\n    midpoints = (vertices_discretized[edges[:, 0]] + vertices_discretized[edges[:, 1]]) / 2.0\n    midpoints_dicretized = midpoints.long()\n\n    tet_verts = vertices_discretized[indices.view(-1)].view(-1, 4, 3)\n    tet_center = tet_verts.float().mean(dim=1)\n    tet_center_discretized = tet_center.long()\n\n\n    global_mask = torch.zeros(4, FLAGS.resolution * 2, FLAGS.resolution * 2, FLAGS.resolution * 2).cuda()\n    cat_mask = torch.zeros(FLAGS.resolution * 2, FLAGS.resolution * 2, FLAGS.resolution * 2).cuda()\n    global_mask[:4, vertices_discretized[:, 0], vertices_discretized[:, 1], vertices_discretized[:, 2]] += 1.0\n    cat_mask[vertices_discretized[:, 0], vertices_discretized[:, 1], vertices_discretized[:, 2]] = 1\n    global_mask[0, midpoints_dicretized[:, 0], midpoints_dicretized[:, 1], midpoints_dicretized[:, 2]] += 1.0\n    cat_mask[midpoints_dicretized[:, 0], midpoints_dicretized[:, 1], midpoints_dicretized[:, 2]] = -1\n\n\n    torch.save(global_mask, f'global_mask_res{FLAGS.resolution}.pt')\n    torch.save(cat_mask, f'cat_mask_res{FLAGS.resolution}.pt')\n\n    save_folder = FLAGS.root\n\n    grid_folder_base = os.path.join(save_folder, FLAGS.target)\n    os.makedirs(grid_folder_base, exist_ok=True)\n\n    print(grid_folder_base)\n\n    edge_ind_list = [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]\n    msdf_tetedges = []\n    msdf_from_tetverts = []\n    for i in range(5):\n        for j in range(i+1, 6):\n            if (edge_ind_list[i][0] == edge_ind_list[j][0]\n                or edge_ind_list[i][0] == edge_ind_list[j][1]\n                or edge_ind_list[i][1] == edge_ind_list[j][0]\n                or edge_ind_list[i][1] == edge_ind_list[j][1]\n            ):\n                msdf_tetedges.append(i)\n                msdf_tetedges.append(j)\n                msdf_from_tetverts.extend([edge_ind_list[i][0], edge_ind_list[i][1], edge_ind_list[j][0], 
edge_ind_list[j][1]])\n    msdf_tetedges = torch.tensor(msdf_tetedges)\n    msdf_from_tetverts = torch.tensor(msdf_from_tetverts)\n    print(msdf_tetedges)\n    print(msdf_tetedges.size())\n\n\n\n    occgrid_mask_already_saved = False\n    tets_folder = os.path.join(save_folder, FLAGS.source)\n\n    with torch.no_grad():\n        for k in tqdm.trange(FLAGS.split_size):\n            global_index = k + FLAGS.index * FLAGS.split_size\n            tet_path = os.path.join(tets_folder, 'dmt_dict_{:05d}.pt'.format(global_index))\n            if os.path.exists(os.path.join(grid_folder_base, 'grid_{:05d}.pt'.format(global_index))):\n                # continue\n                pass\n            try:\n                if os.path.exists(tet_path):\n                    tet = torch.load(tet_path, map_location=\"cuda\")\n\n                    sdf = tet['sdf'].view(-1, 1)\n                    msdf = tet['msdf'].view(-1, 1)\n                    deform = tet['deform']\n\n\n                    ### resetting sdfs and offsets of all non-mesh-generating tet vertices\n                    tet_edges = tet_edges.view(-1, 2)\n                    tet_edge_mask = ((torch.sign(sdf[tet_edges[:, 0]]) - torch.sign(sdf[tet_edges[:, 1]])) != 0).bool().squeeze(-1).view(-1, 6)\n                    tet_sdf_coeff = (\n                        torch.abs(sdf[tet_edges[:, 0]]) \n                        / (torch.abs(sdf[tet_edges[:, 0]] - sdf[tet_edges[:, 1]]) + 1e-10)\n                    ).squeeze(-1)\n                    tet_sdf_coeff = tet_sdf_coeff.view(-1, 1)\n                    midpoint_msdf_tet = msdf[tet_edges[:, 0]] * (1 - tet_sdf_coeff) + msdf[tet_edges[:, 1]] * tet_sdf_coeff\n                    midpoint_msdf_tet = midpoint_msdf_tet.view(-1, 6)\n                    tet_mask = ((midpoint_msdf_tet > 0) & tet_edge_mask).sum(dim=-1).bool()\n                    vert_mask = torch.zeros_like(sdf.squeeze())\n                    vert_mask[indices[tet_mask].view(-1)] = 1.0\n                    vert_mask = ~vert_mask.bool()\n                    msdf[vert_mask] = -1.0\n                    deform[vert_mask] = 0.0\n\n                    tet_nonallnegmsdf = (torch.sign(msdf[indices.view(-1)].view(-1, 4)).sum(dim=-1) != -4)\n                    vert_mask_nonallnegmsdf = torch.zeros_like(sdf.squeeze())\n                    vert_mask_nonallnegmsdf[indices[tet_nonallnegmsdf].view(-1)] = 1.0\n                    vert_mask_nonallnegmsdf = ~vert_mask_nonallnegmsdf.bool()\n                    sdf[vert_mask_nonallnegmsdf] = 1.0\n                    \n\n                    \n\n                    mask = (\n                        (torch.sign(sdf[edges[:, 0]]) - torch.sign(sdf[edges[:, 1]]) != 0).bool()\n                    )\n\n                    nan_mask = (\n                            ((torch.sign(sdf[edges[:, 0]]) + torch.sign(sdf[edges[:, 1]])) == 2)\n                            | ((torch.sign(sdf[edges[:, 0]]) + torch.sign(sdf[edges[:, 1]])) == -2) \n                        ).bool().squeeze(-1)\n\n                    original_sdf_coeff = torch.abs(sdf[edges[:, 0]]) / (torch.abs(sdf[edges[:, 0]] - sdf[edges[:, 1]]) + 1e-10)\n\n\n                    original_sdf_coeff[nan_mask] = torch.nan\n\n                    normalized_sdf_coeff = ((original_sdf_coeff - 0.5) * 2.0)\n                    normalized_sdf_coeff = torch.nan_to_num(normalized_sdf_coeff)\n                    assert torch.all(normalized_sdf_coeff.abs() <= 1.0)\n\n\n                    sdf_sign = torch.sign(sdf)\n                    sdf_sign[sdf_sign == 0] = 1\n\n                 
   midpoint_msdf = msdf[edges[:, 0]] * (1 - original_sdf_coeff.view(-1, 1)) + msdf[edges[:, 1]] * original_sdf_coeff.view(-1, 1)\n                    midpoint_msdf_sign = torch.sign(midpoint_msdf)\n                    midpoint_msdf_sign[midpoint_msdf_sign == 0] = -1\n                    midpoint_msdf_sign = midpoint_msdf_sign * mask - (1.0 - mask.float())\n\n                    ############################ Occ Grid ############################\n\n\n                    tet_edges = tet_edges.view(-1, 2)\n                    tet_edge_mask = ((torch.sign(sdf[tet_edges[:, 0]]) - torch.sign(sdf[tet_edges[:, 1]])) != 0).bool().squeeze(-1).view(-1, 6)\n                    tet_sdf_coeff = (\n                        torch.abs(sdf[tet_edges[:, 0]]) \n                        / (torch.abs(sdf[tet_edges[:, 0]] - sdf[tet_edges[:, 1]]) + 1e-10)\n                    ).squeeze(-1)\n                    tet_sdf_coeff = tet_sdf_coeff * tet_edge_mask.view(-1)\n                    tet_sdf_coeff = tet_sdf_coeff.view(-1, 1)\n                    nan_mask = (\n                            ((torch.sign(sdf[tet_edges[:, 0]]) + torch.sign(sdf[tet_edges[:, 1]])) == 2)\n                            | ((torch.sign(sdf[tet_edges[:, 0]]) + torch.sign(sdf[tet_edges[:, 1]])) == -2) \n                        ).bool().squeeze(-1)\n                    tet_sdf_coeff[nan_mask] = torch.nan\n                    midpoint_msdf_tet = msdf[tet_edges[:, 0]] * (1 - tet_sdf_coeff) + msdf[tet_edges[:, 1]] * tet_sdf_coeff\n                    midpoint_msdf_tet = midpoint_msdf_tet.view(-1, 6)\n                    inscribed_edge_twopoint_msdf = midpoint_msdf_tet[:, msdf_tetedges.view(-1)].view(-1, 12, 2)\n\n                    assert ((\n                        (tet_edges.view(-1, 6, 2)[:, msdf_tetedges.view(-1), :].view(-1, 24, 2).sum(dim=-1)) - indices[:, msdf_from_tetverts].view(-1, 24, 2).sum(dim=-1)\n                    ).sum().item() == 0)\n\n                    assert msdf_tetedges.view(-1).size(0) == 24\n                    inscribed_tet_fourpoint_pos = vertices_discretized[indices[:, msdf_from_tetverts].view(-1)].view(-1, 12, 4, 3).to(torch.float64)\n                    inscribed_edge_twopoint_pos = inscribed_tet_fourpoint_pos.view(-1, 12, 2, 2, 3).mean(dim=-2)\n                    occgrid_loc = inscribed_edge_twopoint_pos.mean(dim=-2)\n                    occgrid_loc = (occgrid_loc * 2).to(torch.int64).view(-1, 3)\n\n\n                    edge_twopoint_order = torch.sign(inscribed_edge_twopoint_pos[:, :, 0, :] - inscribed_edge_twopoint_pos[:, :, 1, :])\n                    edge_twopoint_order_binary_code = (edge_twopoint_order * torch.tensor([16, 4, 1], device=edge_twopoint_order.device).view(1, 1, -1)).sum(dim=-1)\n                    edge_twopoint_order_binary_code = torch.stack([edge_twopoint_order_binary_code, -edge_twopoint_order_binary_code], dim=-1)\n                    _, edge_twopoint_order = edge_twopoint_order_binary_code.sort(dim=-1)\n\n                    inscribed_edge_twopoint_msdf = torch.gather(\n                        input=inscribed_edge_twopoint_msdf,\n                        dim=-1,\n                        index=edge_twopoint_order\n                    )\n\n                    mask_msdf = (\n                        ((inscribed_edge_twopoint_msdf[:, :, 0] > 0) & (inscribed_edge_twopoint_msdf[:, :, 1] <= 0)) |\n                        ((inscribed_edge_twopoint_msdf[:, :, 0] <= 0) & (inscribed_edge_twopoint_msdf[:, :, 1] > 0)) \n                    )\n                    msdf_coeff_12 = (\n                        
torch.abs(inscribed_edge_twopoint_msdf[:, :, 0]) \n                        / (\n                            torch.abs(inscribed_edge_twopoint_msdf[:, :, 0] - inscribed_edge_twopoint_msdf[:, :, 1])\n                            + 1e-10\n                        )\n                    )\n\n                    msdf_coeff_12 = (msdf_coeff_12 - 0.5) * 2.0 * mask_msdf\n                    msdf_coeff_12 = torch.nan_to_num(msdf_coeff_12)\n\n                    occ_grid = torch.zeros(256, 256, 256, dtype=torch.float, device=msdf_coeff_12.device)\n                    occ_grid[occgrid_loc[:, 0], occgrid_loc[:, 1], occgrid_loc[:, 2]] = msdf_coeff_12.view(-1).to(torch.float)\n\n                    if not occgrid_mask_already_saved:\n                        occ_grid_mask = torch.zeros(256, 256, 256, dtype=torch.float, device=msdf_coeff_12.device)\n                        occ_grid_mask[occgrid_loc[:, 0], occgrid_loc[:, 1], occgrid_loc[:, 2]] = 1\n                        torch.save(occ_grid_mask, f'occ_mask_res{FLAGS.resolution}.pt')\n                        occgrid_mask_already_saved = True\n\n\n\n                    # #################\n\n                    torch.cuda.empty_cache()\n                    grid = torch.zeros(4, FLAGS.resolution * 2, FLAGS.resolution * 2, FLAGS.resolution * 2).cuda()\n                    grid[0, vertices_discretized[:, 0], vertices_discretized[:, 1], vertices_discretized[:, 2]] = sdf_sign.squeeze()\n                    grid[1:4, vertices_discretized[:, 0], vertices_discretized[:, 1], vertices_discretized[:, 2]] = deform.transpose(0, 1)\n                    grid[0, midpoints_dicretized[:, 0], midpoints_dicretized[:, 1], midpoints_dicretized[:, 2]] = midpoint_msdf_sign.squeeze()\n\n                    assert grid.abs().max() <= 1\n\n                    save_path = os.path.join(grid_folder_base, 'grid_{:05d}.pt'.format(global_index))\n                    torch.save(grid, save_path)\n\n                    save_path = os.path.join(grid_folder_base, 'occgrid_{:05d}.pt'.format(global_index))\n                    torch.save(occ_grid, save_path)\n                \n            except:\n                raise"
  },
  {
    "path": "GMeshDiffusion/scripts/run_eval_lower_occgrid_normalized.sh",
    "content": "python main_diffusion.py --mode uncond_gen --config diffusion_configs/config_lower_occgrid_normalized.py \\\n--config.eval.eval_dir=$EVAL_DIR \\\n--config.data.root_dir=$REPO_ROOT_DIR \\\n--config.sampling.method=ddim \\\n--config.eval.ckpt_path=$CKPT_PATH \\\n--config.eval.bin_size=30 \\\n--config.eval.idx $1"
  },
  {
    "path": "GMeshDiffusion/scripts/run_eval_upper_occgrid_normalized.sh",
    "content": "python main_diffusion.py --mode uncond_gen --config diffusion_configs/config_upper_occgrid_normalized.py \\\n--config.eval.eval_dir=$EVAL_DIR \\\n--config.data.root_dir=$REPO_ROOT_DIR \\\n--config.sampling.method=ddim \\\n--config.eval.ckpt_path=$CKPT_PATH \\\n--config.eval.bin_size=10 \\\n--config.eval.idx $1"
  },
  {
    "path": "GMeshDiffusion/scripts/run_lower_occgrid_normalized_ddp.sh",
    "content": "torchrun --nnodes=1 --nproc_per_node=8 main_diffusion_ddp.py --mode=train --config=diffusion_configs/config_lower_occgrid_normalized.py \\\n--config.training.train_dir=$SAVE_DIR --config.data.root_dir=$REPO_ROOT_DIR"
  },
  {
    "path": "GMeshDiffusion/scripts/run_upper_occgrid_normalized_ddp.sh",
    "content": "torchrun --nnodes=1 --nproc_per_node=8 main_diffusion_ddp.py --mode=train --config=diffusion_configs/config_upper_occgrid_normalized.py \\\n--config.training.train_dir=$SAVE_DIR --config.data.root_dir=$REPO_ROOT_DIR\n"
  },
  {
    "path": "README.md",
    "content": "<div align=\"center\">\n  <img src=\"assets/gshell_logo.png\" width=\"900\"/>\n</div>\n\n# Ghost on the Shell: An Expressive Representation of General 3D Shapes\n\n\n<div align=\"center\">\n  <img src=\"assets/teaser.png\" width=\"900\"/>\n</div>\n\n## Introduction\n\nThis is the official implementation of our paper (ICLR 2024 oral) \"Ghost on the Shell: An Expressive Representation of General 3D Shapes\" (G-Shell).\n\nG-Shell is a generic and differentiable representation for both watertight and non-watertight meshes. It enables 1) efficient and robust rasterization-based multiview reconstruction and 2) template-free generation of non-watertight meshes.\n\nPlease refer to [our project page](https://gshell3d.github.io) and [our paper](https://gshell3d.github.io/static/paper/gshell.pdf) for more details.\n\n\n## Getting Started\n\n### Requirements\n\n\n- Python >= 3.8\n- CUDA 11.8\n- PyTorch == 1.13.1\n\n(Conda installation recommended)\n\n#### Reconstruction\n\nRun the following\n\n```\npip install ninja imageio PyOpenGL glfw xatlas gdown\npip install git+https://github.com/NVlabs/nvdiffrast/\npip install --global-option=\"--no-networks\" git+https://github.com/NVlabs/tiny-cuda-nn#subdirectory=bindings/torch\n```\n\nFollow the instructions [here](https://github.com/NVIDIAGameWorks/kaolin/) to install kaolin.\n\nDownload the tet-grid files ([res128](https://drive.google.com/file/d/1u5FzpuY_BOAg8-g9lRwvah7mbCBOfNVg/view?usp=sharing), [res256](https://drive.google.com/file/d/1JnFoPEGcTLFJ7OHSWrI72h1H9_yOxUP6/view?usp=sharing)) & [res64 for G-MeshDiffusion](https://drive.google.com/file/d/1YQuU4D-0q8kwrzEfla3hGzBg4erBhand/view?usp=drive_link) to `data/tets` folder under the root directory. Alternatively, you may follow https://github.com/crawforddoran/quartet and `data/tets/generate_tets.py` to create the tet-grid files.\n\n#### Generation\n\nInstall the following\n\n- Pytorch3D\n- ml_collections\n\n## To-dos\n\n- [x] Code for reconstruction\n- [ ] DeepFashion3D multiview image dataset for metallic surfaces\n- [x] Code for generative models\n- [ ] Code for DeepFashion3D dataset preparation\n- [ ] Evaluation code for generative models\n\n## Reconstruction\n\n### Datasets\n\n#### DeepFashion3D mesh dataset\n\nWe provide ground-truth images (rendered under realistic environment light with Blender) for 9 instances in [DeepFashion3D-v2 dataset](https://github.com/GAP-LAB-CUHK-SZ/deepFashion3D). The download links for the raw meshes can be found in their repo.\n\nnon-metallic material: [training data](https://drive.google.com/file/d/1LwBqLYzamFLyBIiNpD6kEkvySrq2nruG/view?usp=sharing), [test data](https://drive.google.com/file/d/1-47dH_yJrUzKVdKbJenslpdyHwDI6QVo/view?usp=sharing)\n\n\n#### NeRF synthetic dataset\n\nDownload the [NeRF synthetic dataset archive](https://drive.google.com/uc?export=download&id=18JxhpWD-4ZmuFKLzKlAw-w5PpzZxXOcG) and unzip it into the `data/` folder.\n\n#### Hat dataset\n\nDownload link: https://drive.google.com/file/d/18UmT1NM5wJQ-ZM-rtUXJHXkDc-ba-xVk/view?usp=sharing\n\nRGB images, segmentation masks and the corresponding camera poses are included. Alternatively, you may choose to 1) generate the camera poses with COLMAP and 2) create binary segmentation masks by yourself.\n\n### Training\n\n#### DeepFashion3D-v2 instances\n\nThe mesh instances' IDs are [30, 92, 117, 133, 164, 320, 448, 522, 591]. 
To reconstruct the `$INDEX`-th mesh (0-8) in the list using tet-based G-Shell, run\n\n```\n  python train_gshelltet_deepfashion.py --config config/deepfashion_mc_256.json --index $INDEX --trainset_path $TRAINSET_PATH --testset_path $TESTSET_PATH --o $OUTPUT_PATH\n```\n\nFor FlexiCubes + G-Shell, run\n\n```\n  python train_gflexicubes_deepfashion.py --config config/deepfashion_mc_80.json --index $INDEX --trainset_path $TRAINSET_PATH --testset_path $TESTSET_PATH --o $OUTPUT_PATH\n```\n\n#### Synthetic data\n\n```\n  python train_gshelltet_synthetic.py --config config/nerf_chair.json --o $OUTPUT_PATH\n```\n\n#### Hat data\n\n```\n  python train_gshelltet_polycam.py --config config/polycam_mc_128.json --trainset_path $TRAINSET_PATH --o $OUTPUT_PATH\n```\n\n#### On config files\n\nYou may consider modifying the following, depending on your needs:\n\n- `gshell_grid`: the G-Shell grid size. For tet-based G-Shell, please make sure the corresponding tet-grid file exists under `data/tets` (e.g., `256_tets.npz`). Otherwise, follow https://github.com/crawforddoran/quartet and `data/tets/generate_tets.py` to generate the desired tet-grid file.\n- `n_samples`: the number of MC samples for light rays per rasterized pixel. The higher the better (at a cost of memory and speed).\n- `batch_size`: how many views are sampled in each iteration.\n- `iteration`: total number of iterations.\n- `kd_min`, `kd_max`, etc.: the min/max of the corresponding PBR material parameter.\n\n\n## Generation\n\n### Preparation\n\nDownload the info files for the underlying tet grids and the binary masks indicating which locations store useful values in the cubic grids from [tet_info.pt](https://drive.google.com/file/d/19Dw_hOpcVHazpm2_1qA7T7j5xABOUWxv/view?usp=drive_link), [global_mask_res64.pt](https://drive.google.com/file/d/1mlSnu23_u08HH5aO3x5z1V9GzguFzoiT/view?usp=drive_link), [cat_mask_res64.pt](https://drive.google.com/file/d/11Bm4CQX-y1X7R47AfQQz20s7oP6AbbNK/view?usp=drive_link) and [occ_mask_res64.pt](https://drive.google.com/file/d/1qEqqLfZe633GdVkj5MGOCON_kf0l4e4G/view?usp=drive_link). Put these files under `GMeshDiffusion/metadata/`.\n\n#### For inference\n\nDownload the pretrained models for upper-body and lower-body garments [here](https://huggingface.co/lzzcd001/GMeshDiffusion-Models).\n\n#### For training\n\n1) Download the processed Cloth3D garment dataset (for upper-body & lower-body garments) from [link](https://huggingface.co/datasets/lzzcd001/Cloth3D-GShell-Dataset). Alternatively, you may create a grid dataset for your own objects by a) normalizing your datapoints by re-centering and re-scaling the meshes, b) fitting G-Shell representations, and c) turning these representations into cubic grids by running `GMeshDiffusion/metadata/tet_to_cubic_grid_dataset.py`.\n\n2) Run `GMeshDiffusion/metadata/get_splits_lower.py` and/or `GMeshDiffusion/metadata/get_splits_upper.py` to generate lists of training and test datapoints.\n\n\n### Inference\n\n1. Modify the batch size in the config files in `GMeshDiffusion/diffusion_configs/` and enter the desired directories and values (for model checkpoints, where to store generated samples, etc.) in `GMeshDiffusion/scripts`.\n2. Run the eval scripts in `GMeshDiffusion/scripts`.\n3. Run `eval_gmeshdiffusion_generated_samples.py` to extract triangular meshes.\n\n### Training\n\n1. 
Enter the desired directories (for model checkpoints and where to store generated samples) in `GMeshDiffusion/scripts`.\n2. Modify the config files if necessary.\n3. Run the training scripts in `GMeshDiffusion/scripts`.\n\n\n## Citation\n\nIf you find our work useful for your research, please consider citing:\n\n```\n@inproceedings{liu2024gshell,\n    title={Ghost on the Shell: An Expressive Representation of General 3D Shapes},\n    author={Liu, Zhen and Feng, Yao and Xiu, Yuliang and Liu, Weiyang and Paull, Liam and Black, Michael J. and Sch{\\\"o}lkopf, Bernhard},\n    booktitle={The Twelfth International Conference on Learning Representations},\n    year={2024},\n}\n```\n\n\n## Acknowledgement\n\nWe sincerely thank the authors of [Nvdiffrecmc](https://github.com/NVlabs/nvdiffrecmc), [FlexiCubes](https://github.com/nv-tlabs/FlexiCubes) and https://github.com/yang-song/score_sde_pytorch for sharing their code. Our repo is adapted from [MeshDiffusion](https://github.com/lzzcd001/MeshDiffusion/).\n"
  },
  {
    "path": "configs/deepfashion_mc.json",
    "content": "{\n    \"ref_mesh\": \"data/spot/spot.obj\",\n    \"random_textures\": true,\n    \"iter\": 5000,\n    \"save_interval\": 100,\n    \"texture_res\": [ 1024, 1024 ],\n    \"train_res\": [1024, 1024],\n    \"batch\": 2,\n    \"learning_rate\": [0.03, 0.005],\n    \"ks_min\" : [0, 0.001, 0.0],\n    \"ks_max\" : [0, 1.0, 1.0],\n    \"envlight\": \"data/irrmaps/aerodynamics_workshop_2k.hdr\",\n    \"lock_pos\" : false,\n    \"display\": [{\"latlong\" : true}],\n    \"background\" : \"white\",\n    \"denoiser\": \"bilateral\",\n    \"n_samples\" : 24,\n    \"env_scale\" : 2.0,\n    \"gshell_grid\" : 128,\n    \"validate\" : true,\n    \"laplace_scale\" : 6000,\n    \"boxscale\": [1, 1, 1],\n    \"aabb\": [-1, -1, -1, 1, 1, 1]\n}"
  },
  {
    "path": "configs/deepfashion_mc_256.json",
    "content": "{\n    \"ref_mesh\": \"data/spot/spot.obj\",\n    \"random_textures\": true,\n    \"iter\": 5000,\n    \"save_interval\": 100,\n    \"texture_res\": [ 1024, 1024 ],\n    \"train_res\": [1024, 1024],\n    \"batch\": 2,\n    \"learning_rate\": [0.03, 0.005],\n    \"ks_min\" : [0, 0.001, 0.0],\n    \"ks_max\" : [0, 1.0, 1.0],\n    \"envlight\": \"data/irrmaps/aerodynamics_workshop_2k.hdr\",\n    \"lock_pos\" : false,\n    \"display\": [{\"latlong\" : true}],\n    \"background\" : \"white\",\n    \"denoiser\": \"bilateral\",\n    \"n_samples\" : 24,\n    \"env_scale\" : 2.0,\n    \"gshell_grid\" : 256,\n    \"validate\" : true,\n    \"laplace_scale\" : 6000,\n    \"boxscale\": [1, 1, 1],\n    \"aabb\": [-1, -1, -1, 1, 1, 1]\n}"
  },
  {
    "path": "configs/deepfashion_mc_512.json",
    "content": "{\n    \"ref_mesh\": \"data/spot/spot.obj\",\n    \"random_textures\": true,\n    \"iter\": 5000,\n    \"save_interval\": 100,\n    \"texture_res\": [ 1024, 1024 ],\n    \"train_res\": [1024, 1024],\n    \"batch\": 2,\n    \"learning_rate\": [0.03, 0.005],\n    \"ks_min\" : [0, 0.001, 0.0],\n    \"ks_max\" : [0, 1.0, 1.0],\n    \"envlight\": \"data/irrmaps/aerodynamics_workshop_2k.hdr\",\n    \"lock_pos\" : false,\n    \"display\": [{\"latlong\" : true}],\n    \"background\" : \"white\",\n    \"denoiser\": \"bilateral\",\n    \"n_samples\" : 12,\n    \"env_scale\" : 2.0,\n    \"gshell_grid\" : 512,\n    \"validate\" : true,\n    \"laplace_scale\" : 6000,\n    \"boxscale\": [1, 1, 1],\n    \"aabb\": [-1, -1, -1, 1, 1, 1]\n}"
  },
  {
    "path": "configs/deepfashion_mc_80.json",
    "content": "{\n    \"ref_mesh\": \"data/spot/spot.obj\",\n    \"random_textures\": true,\n    \"iter\": 5000,\n    \"save_interval\": 100,\n    \"texture_res\": [ 1024, 1024 ],\n    \"train_res\": [1024, 1024],\n    \"batch\": 2,\n    \"learning_rate\": [0.03, 0.005],\n    \"ks_min\" : [0, 0.001, 0.0],\n    \"ks_max\" : [0, 1.0, 1.0],\n    \"envlight\": \"data/irrmaps/aerodynamics_workshop_2k.hdr\",\n    \"lock_pos\" : false,\n    \"display\": [{\"latlong\" : true}],\n    \"background\" : \"white\",\n    \"denoiser\": \"bilateral\",\n    \"n_samples\" : 24,\n    \"env_scale\" : 2.0,\n    \"gshell_grid\" : 80,\n    \"validate\" : true,\n    \"laplace_scale\" : 6000,\n    \"boxscale\": [1, 1, 1],\n    \"aabb\": [-1, -1, -1, 1, 1, 1]\n}"
  },
  {
    "path": "configs/nerf_chair.json",
    "content": "{\n    \"ref_mesh\": \"data/nerf_synthetic/chair\",\n    \"random_textures\": true,\n    \"iter\": 5000,\n    \"save_interval\": 100,\n    \"texture_res\": [ 1024, 1024 ],\n    \"train_res\": [800, 800],\n    \"batch\": 2,\n    \"learning_rate\": [0.03, 0.005],\n    \"gshell_grid\" : 128,\n    \"mesh_scale\" : 2.1,\n    \"validate\" : true,\n    \"n_samples\" : 8,\n    \"denoiser\" : \"bilateral\",\n    \"display\": [{\"latlong\" : true}, {\"bsdf\" : \"kd\"}, {\"bsdf\" : \"ks\"}, {\"bsdf\" : \"normal\"}],\n    \"background\" : \"white\",\n    \"boxscale\": [1, 1, 1],\n    \"aabb\": [-1, -1, -1, 1, 1, 1]\n}"
  },
  {
    "path": "configs/polycam_mc.json",
    "content": "{\n    \"ref_mesh\": \"data/spot/spot.obj\",\n    \"random_textures\": true,\n    \"iter\": 5000,\n    \"save_interval\": 100,\n    \"texture_res\": [ 1024, 1024 ],\n    \"train_res\": [768, 1024],\n    \"batch\": 2,\n    \"learning_rate\": [0.03, 0.005],\n    \"ks_min\" : [0, 0.001, 0.0],\n    \"ks_max\" : [0, 1.0, 1.0],\n    \"envlight\": \"data/irrmaps/aerodynamics_workshop_2k.hdr\",\n    \"lock_pos\" : false,\n    \"display\": [{\"latlong\" : true}],\n    \"background\" : \"white\",\n    \"denoiser\": \"bilateral\",\n    \"n_samples\" : 8,\n    \"env_scale\" : 2.0,\n    \"gshell_grid\" : 256,\n    \"validate\" : true,\n    \"laplace_scale\" : 6000,\n    \"boxscale\": [1, 1, 1],\n    \"aabb\": [-1, -1, -1, 1, 1, 1]\n}"
  },
  {
    "path": "configs/polycam_mc_128.json",
    "content": "{\n    \"ref_mesh\": \"data/spot/spot.obj\",\n    \"random_textures\": true,\n    \"iter\": 5000,\n    \"save_interval\": 100,\n    \"texture_res\": [ 1024, 1024 ],\n    \"train_res\": [768, 1024],\n    \"batch\": 2,\n    \"learning_rate\": [0.03, 0.005],\n    \"ks_min\" : [0, 0.001, 0.0],\n    \"ks_max\" : [0, 1.0, 1.0],\n    \"envlight\": \"data/irrmaps/aerodynamics_workshop_2k.hdr\",\n    \"lock_pos\" : false,\n    \"display\": [{\"latlong\" : true}],\n    \"background\" : \"white\",\n    \"denoiser\": \"bilateral\",\n    \"n_samples\" : 8,\n    \"env_scale\" : 2.0,\n    \"gshell_grid\" : 128,\n    \"validate\" : true,\n    \"laplace_scale\" : 6000,\n    \"boxscale\": [1, 1, 1],\n    \"aabb\": [-1, -1, -1, 1, 1, 1]\n}"
  },
  {
    "path": "configs/polycam_mc_16samples.json",
    "content": "{\n    \"ref_mesh\": \"data/spot/spot.obj\",\n    \"random_textures\": true,\n    \"iter\": 5000,\n    \"save_interval\": 100,\n    \"texture_res\": [ 1024, 1024 ],\n    \"train_res\": [768, 1024],\n    \"batch\": 2,\n    \"learning_rate\": [0.03, 0.005],\n    \"ks_min\" : [0, 0.001, 0.0],\n    \"ks_max\" : [0, 1.0, 1.0],\n    \"envlight\": \"data/irrmaps/aerodynamics_workshop_2k.hdr\",\n    \"lock_pos\" : false,\n    \"display\": [{\"latlong\" : true}],\n    \"background\" : \"white\",\n    \"denoiser\": \"bilateral\",\n    \"n_samples\" : 16,\n    \"env_scale\" : 2.0,\n    \"gshell_grid\" : 256,\n    \"validate\" : true,\n    \"laplace_scale\" : 6000,\n    \"boxscale\": [1, 1, 1],\n    \"aabb\": [-1, -1, -1, 1, 1, 1]\n}"
  },
  {
    "path": "data/tets/generate_tets.py",
    "content": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n# property and proprietary rights in and to this material, related\n# documentation and any modifications thereto. Any use, reproduction, \n# disclosure or distribution of this material and related documentation \n# without an express license agreement from NVIDIA CORPORATION or \n# its affiliates is strictly prohibited.\n\nimport os\nimport numpy as np\n\n\n'''\nThis code segment shows how to use Quartet: https://github.com/crawforddoran/quartet, \nto generate a tet grid \n1) Download, compile and run Quartet as described in the link above. Example usage `quartet meshes/cube.obj 0.5 cube_5.tet`\n2) Run the function below to generate a file `cube_32_tet.tet`\n'''\n\ndef generate_tetrahedron_grid_file(res=32, root='..'):\n    frac = 1.0 / res\n    command = 'cd %s/quartet; ' % (root) + \\\n                './quartet meshes/cube.obj %f meshes/cube_%f_tet.tet -s meshes/cube_boundary_%f.obj' % (frac, res, res)\n    os.system(command)\n\n\n'''\nThis code segment shows how to convert from a quartet .tet file to compressed npz file\n'''\ndef convert_from_quartet_to_npz(quartetfile = 'cube_32_tet.tet', npzfile = '32_tets'):\n\n    file1 = open(quartetfile, 'r')\n    header = file1.readline()\n    numvertices = int(header.split(\" \")[1])\n    numtets     = int(header.split(\" \")[2])\n    print(numvertices, numtets)\n\n    # load vertices\n    vertices = np.loadtxt(quartetfile, skiprows=1, max_rows=numvertices)\n    print(vertices.shape)\n\n    # load indices\n    indices = np.loadtxt(quartetfile, dtype=int, skiprows=1+numvertices, max_rows=numtets)\n    print(indices.shape)\n\n    np.savez_compressed(npzfile, vertices=vertices, indices=indices)"
  },
  {
    "path": "dataset/__init__.py",
    "content": "# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto. Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\nfrom .dataset import Dataset\nfrom .dataset_mesh import DatasetMesh\nfrom .dataset_nerf import DatasetNERF\nfrom .dataset_llff import DatasetLLFF"
  },
  {
    "path": "dataset/dataset.py",
    "content": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n# property and proprietary rights in and to this material, related\n# documentation and any modifications thereto. Any use, reproduction, \n# disclosure or distribution of this material and related documentation \n# without an express license agreement from NVIDIA CORPORATION or \n# its affiliates is strictly prohibited.\n\nimport torch\n\nclass Dataset(torch.utils.data.Dataset):\n    \"\"\"Basic dataset interface\"\"\"\n    def __init__(self): \n        super().__init__()\n\n    def __len__(self):\n        raise NotImplementedError\n\n    def __getitem__(self):\n        raise NotImplementedError\n\n    def collate(self, batch):\n        iter_res, iter_spp = batch[0]['resolution'], batch[0]['spp']\n        return {\n            'mv' : torch.cat(list([item['mv'] for item in batch]), dim=0),\n            'mvp' : torch.cat(list([item['mvp'] for item in batch]), dim=0),\n            'campos' : torch.cat(list([item['campos'] for item in batch]), dim=0),\n            'resolution' : iter_res,\n            'spp' : iter_spp,\n            'img' : torch.cat(list([item['img'] for item in batch]), dim=0) if 'img' in batch[0] else None,\n            'img_second' : torch.cat(list([item['img_second'] for item in batch]), dim=0) if 'img_second' in batch[0] else None,\n            'invdepth' : torch.cat(list([item['invdepth'] for item in batch]), dim=0) if 'invdepth' in batch[0] else None,\n            'invdepth_second' : torch.cat(list([item['invdepth_second'] for item in batch]), dim=0) if 'invdepth_second' in batch[0] else None,\n            'envlight_transform': torch.cat(list([item['envlight_transform'] for item in batch]), dim=0) if 'envlight_transform' in batch[0] and batch[0]['envlight_transform'] is not None else None,\n        }"
  },
  {
    "path": "dataset/dataset_deepfashion.py",
    "content": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n# property and proprietary rights in and to this material, related\n# documentation and any modifications thereto. Any use, reproduction, \n# disclosure or distribution of this material and related documentation \n# without an express license agreement from NVIDIA CORPORATION or \n# its affiliates is strictly prohibited.\n\nimport os\nimport glob\nimport json\n\nimport torch\nimport numpy as np\n\nfrom render import util\n\nfrom .dataset import Dataset\n\nimport cv2 as cv\n\n# This function is borrowed from IDR: https://github.com/lioryariv/idr\ndef load_K_Rt_from_P(filename, P=None):\n    if P is None:\n        lines = open(filename).read().splitlines()\n        if len(lines) == 4:\n            lines = lines[1:]\n        lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(\" \") for x in lines)]\n        P = np.asarray(lines).astype(np.float32).squeeze()\n\n    out = cv.decomposeProjectionMatrix(P)\n    K = out[0]\n    R = out[1]\n    t = out[2]\n\n    K = K / K[2, 2]\n    intrinsics = np.eye(4)\n    intrinsics[:3, :3] = K\n\n\n    pose = np.eye(4, dtype=np.float32)\n    pose[:3, :3] = R.transpose()\n    pose[:3, 3] = (t[:3] / t[3])[:, 0]\n\n    return intrinsics, pose\n\ndef _load_img(path):\n    img = util.load_image_raw(path)\n    if img.dtype != np.float32: # LDR image\n        img = torch.tensor(img / 255, dtype=torch.float32)\n        img[..., 0:3] = util.srgb_to_rgb(img[..., 0:3])\n    else:\n        img = torch.tensor(img, dtype=torch.float32)\n    return img\n\n\n\nclass DatasetDeepFashion(Dataset):\n    def __init__(self, base_dir, FLAGS, examples=None):\n        self.FLAGS = FLAGS\n        self.examples = examples\n        self.base_dir = base_dir\n\n        # Load config / transforms\n        self.n_images = 72 ### hardcoded\n\n        self.fovy               = np.deg2rad(60)\n        self.proj_mtx = util.perspective(\n            self.fovy, self.FLAGS.display_res[1] / self.FLAGS.display_res[0], self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1]\n        )\n\n\n\n        camera_dict = np.load(os.path.join(self.base_dir, 'cameras_sphere.npz'))\n\n        # world_mat is a projection matrix from world to image\n        self.world_mats_np = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]\n        self.scale_mats_np = []\n\n\n        # scale_mat: used for coordinate normalization, we assume the scene to render is inside a unit sphere at origin.\n        self.scale_mats_np = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]\n        self.intrinsics_all = []\n        self.pose_all = []\n\n        for scale_mat, world_mat in zip(self.scale_mats_np, self.world_mats_np):\n            P = world_mat @ scale_mat\n            P = P[:3, :4]\n            intrinsics, pose = load_K_Rt_from_P(None, P)\n            self.intrinsics_all.append(torch.from_numpy(intrinsics).float())\n            self.pose_all.append(torch.from_numpy(pose).float())\n\n        # Determine resolution & aspect ratio\n        self.resolution = _load_img(os.path.join(self.base_dir, '{:03d}.png'.format(0))).shape[0:2]\n        self.aspect = self.resolution[1] / self.resolution[0]\n\n        if self.FLAGS.local_rank == 0:\n            print(\"DatasetNERF: %d images with shape [%d, %d]\" % (self.n_images, self.resolution[0], self.resolution[1]))\n\n    def 
_parse_frame(self, idx):\n        # Load image data and modelview matrix\n        img    = _load_img(os.path.join(self.base_dir, '{:03d}.png'.format(idx)))\n        img[:,:,:3] = img[:,:,:3] * img[:,:,3:]\n        img[:,:,3] = torch.sign(img[:,:,3])\n        assert img.size(-1) == 4\n\n        flip_mat = torch.tensor([\n            [ 1,  0,  0,  0],\n            [ 0, -1,  0,  0],\n            [ 0,  0, -1,  0],\n            [ 0,  0,  0,  1]\n        ], dtype=torch.float)\n\n        mv = flip_mat @ torch.linalg.inv(self.pose_all[idx])\n        campos = torch.linalg.inv(mv)[:3, 3]\n        mvp = self.proj_mtx @ mv\n\n        return img[None, ...].cuda(), mv[None, ...].cuda(), mvp[None, ...].cuda(), campos[None, ...].cuda() # Add batch dimension\n\n    def __len__(self):\n        return self.n_images if self.examples is None else self.examples\n\n    def __getitem__(self, itr):\n        iter_res = self.FLAGS.train_res\n        \n        img      = []\n\n        img, mv, mvp, campos = self._parse_frame(itr % self.n_images)\n\n        return {\n            'mv' : mv,\n            'mvp' : mvp,\n            'campos' : campos,\n            'resolution' : iter_res,\n            'spp' : self.FLAGS.spp,\n            'img' : img\n        }\n"
  },
  {
    "path": "dataset/dataset_deepfashion_testset.py",
    "content": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n# property and proprietary rights in and to this material, related\n# documentation and any modifications thereto. Any use, reproduction, \n# disclosure or distribution of this material and related documentation \n# without an express license agreement from NVIDIA CORPORATION or \n# its affiliates is strictly prohibited.\n\nimport os\nimport glob\nimport json\n\nimport torch\nimport numpy as np\n\nfrom render import util\n\nfrom .dataset import Dataset\n\nimport cv2 as cv\n\n# This function is borrowed from IDR: https://github.com/lioryariv/idr\ndef load_K_Rt_from_P(filename, P=None):\n    if P is None:\n        lines = open(filename).read().splitlines()\n        if len(lines) == 4:\n            lines = lines[1:]\n        lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(\" \") for x in lines)]\n        P = np.asarray(lines).astype(np.float32).squeeze()\n\n    out = cv.decomposeProjectionMatrix(P)\n    K = out[0]\n    R = out[1]\n    t = out[2]\n\n    K = K / K[2, 2]\n    intrinsics = np.eye(4)\n    intrinsics[:3, :3] = K\n\n\n    pose = np.eye(4, dtype=np.float32)\n    pose[:3, :3] = R.transpose()\n    pose[:3, 3] = (t[:3] / t[3])[:, 0]\n\n    return intrinsics, pose\n\ndef _load_img(path):\n    img = util.load_image_raw(path)\n    if img.dtype != np.float32: # LDR image\n        img = torch.tensor(img / 255, dtype=torch.float32)\n        img[..., 0:3] = util.srgb_to_rgb(img[..., 0:3])\n    else:\n        img = torch.tensor(img, dtype=torch.float32)\n    return img\n\n\ndef _load_mask(path):\n    img = util.load_image_raw(path)\n    if img.dtype != np.float32: # LDR image\n        img = torch.tensor(img / 255, dtype=torch.float32)\n    else:\n        img = torch.tensor(img, dtype=torch.float32)\n    return img\n\n\nclass DatasetDeepFashionTestset(Dataset):\n    def __init__(self, base_dir, FLAGS, examples=None):\n        self.FLAGS = FLAGS\n        self.examples = examples\n        self.base_dir = base_dir\n\n        # Load config / transforms\n        self.n_images = 200 ### hardcoded\n\n\n        proj_mtx_all = np.load(os.path.join(self.base_dir, 'proj_mtx_all.npy'))\n        self.intrinsics_all = []\n        self.pose_all = []\n\n\n        self.fovy               = np.deg2rad(60)\n        self.proj_mtx = util.perspective(\n            self.fovy, self.FLAGS.display_res[1] / self.FLAGS.display_res[0], self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1]\n        )\n\n        for i in range(proj_mtx_all.shape[0]):\n            P = proj_mtx_all[i]\n            P = P[:3, :4]\n            intrinsics, pose = load_K_Rt_from_P(None, P)\n            self.intrinsics_all.append(torch.from_numpy(intrinsics).float())\n            self.pose_all.append(torch.from_numpy(pose).float())\n\n        # Determine resolution & aspect ratio\n        self.resolution = _load_img(os.path.join(self.base_dir, '{:03d}.png'.format(0))).shape[0:2]\n        self.aspect = self.resolution[1] / self.resolution[0]\n\n        if self.FLAGS.local_rank == 0:\n            print(\"DatasetNERF: %d images with shape [%d, %d]\" % (self.n_images, self.resolution[0], self.resolution[1]))\n\n    def _parse_frame(self, idx):\n        # Load image data and modelview matrix\n        img    = _load_img(os.path.join(self.base_dir, '{:03d}.png'.format(idx)))\n        assert img.size(-1) == 4\n\n        flip_mat = torch.tensor([\n            [ 1,  0,  0,  0],\n   
         [ 0, -1,  0,  0],\n            [ 0,  0, -1,  0],\n            [ 0,  0,  0,  1]\n        ], dtype=torch.float)\n\n        mv = flip_mat @ torch.linalg.inv(self.pose_all[idx])\n        campos = torch.linalg.inv(mv)[:3, 3]\n        mvp = self.proj_mtx @ mv\n\n        return img[None, ...].cuda(), mv[None, ...].cuda(), mvp[None, ...].cuda(), campos[None, ...].cuda() # Add batch dimension\n\n    def __len__(self):\n        return self.n_images if self.examples is None else self.examples\n\n    def __getitem__(self, itr):\n        iter_res = self.FLAGS.train_res\n        \n        img      = []\n\n        img, mv, mvp, campos = self._parse_frame(itr % self.n_images)\n        \n\n        return {\n            'mv' : mv,\n            'mvp' : mvp,\n            'campos' : campos,\n            'resolution' : iter_res,\n            'spp' : self.FLAGS.spp,\n            'img' : img\n        }\n"
  },
  {
    "path": "dataset/dataset_llff.py",
    "content": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n# property and proprietary rights in and to this material, related\n# documentation and any modifications thereto. Any use, reproduction, \n# disclosure or distribution of this material and related documentation \n# without an express license agreement from NVIDIA CORPORATION or \n# its affiliates is strictly prohibited.\n\nimport os\nimport glob\n\nimport torch\nimport numpy as np\n\nfrom render import util\n\nfrom .dataset import Dataset\n\ndef _load_mask(fn):\n    img = torch.tensor(util.load_image(fn), dtype=torch.float32)\n    if len(img.shape) == 2:\n        img = img[..., None].repeat(1, 1, 3)\n    return img\n\ndef _load_img(fn):\n    img = util.load_image_raw(fn)\n    if img.dtype != np.float32: # LDR image\n        img = torch.tensor(img / 255, dtype=torch.float32)\n        img[..., 0:3] = util.srgb_to_rgb(img[..., 0:3])\n    else:\n        img = torch.tensor(img, dtype=torch.float32)\n    return img\n\n###############################################################################\n# LLFF datasets (real world camera lightfields)\n###############################################################################\n\nclass DatasetLLFF(Dataset):\n    def __init__(self, base_dir, FLAGS, examples=None):\n        self.FLAGS = FLAGS\n        self.base_dir = base_dir\n        self.examples = examples\n\n        # Enumerate all image files and get resolution\n        all_img = [f for f in sorted(glob.glob(os.path.join(self.base_dir, \"images\", \"*\"))) if f.lower().endswith('png') or f.lower().endswith('jpg') or f.lower().endswith('jpeg')]\n        self.resolution = _load_img(all_img[0]).shape[0:2]\n\n        # Load camera poses\n        poses_bounds = np.load(os.path.join(self.base_dir, 'poses_bounds.npy'))\n        \n        poses        = poses_bounds[:, :-2].reshape([-1, 3, 5]).transpose([1,2,0])\n        poses        = np.concatenate([poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1) # Taken from nerf, swizzles from LLFF to expected coordinate system\n        poses        = np.moveaxis(poses, -1, 0).astype(np.float32)\n        \n        lcol         = np.array([0,0,0,1], dtype=np.float32)[None, None, :].repeat(poses.shape[0], 0)\n        self.imvs    = torch.tensor(np.concatenate((poses[:, :, 0:4], lcol), axis=1), dtype=torch.float32)\n        self.aspect  = self.resolution[1] / self.resolution[0] # width / height\n        self.fovy    = util.focal_length_to_fovy(poses[:, 2, 4], poses[:, 0, 4])\n\n        # Recenter scene so lookat position is origin\n        center                = util.lines_focal(self.imvs[..., :3, 3], -self.imvs[..., :3, 2])\n        self.imvs[..., :3, 3] = self.imvs[..., :3, 3] - center[None, ...]\n\n        if self.FLAGS.local_rank == 0:\n            print(\"DatasetLLFF: %d images with shape [%d, %d]\" % (len(all_img), self.resolution[0], self.resolution[1]))\n            print(\"DatasetLLFF: auto-centering at %s\" % (center.cpu().numpy()))\n\n        # Pre-load from disc to avoid slow png parsing\n        if self.FLAGS.pre_load:\n            self.preloaded_data = []\n            for i in range(self.imvs.shape[0]):\n                self.preloaded_data += [self._parse_frame(i)]\n\n    def _parse_frame(self, idx):\n        all_img  = [f for f in sorted(glob.glob(os.path.join(self.base_dir, \"images\", \"*\"))) if f.lower().endswith('png') or f.lower().endswith('jpg') or 
f.lower().endswith('jpeg')]\n        all_mask = [f for f in sorted(glob.glob(os.path.join(self.base_dir, \"masks\", \"*\"))) if f.lower().endswith('png') or f.lower().endswith('jpg') or f.lower().endswith('jpeg')]\n        assert len(all_img) == self.imvs.shape[0] and len(all_mask) == self.imvs.shape[0]\n\n        # Load image+mask data\n        img  = _load_img(all_img[idx])\n        mask = _load_mask(all_mask[idx])\n        img  = torch.cat((img, mask[..., 0:1]), dim=-1)\n\n        # Setup transforms\n        proj   = util.perspective(self.fovy[idx, ...], self.aspect, self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1])\n        mv     = torch.linalg.inv(self.imvs[idx, ...])\n        campos = torch.linalg.inv(mv)[:3, 3]\n        mvp    = proj @ mv\n\n        return img[None, ...], mv[None, ...], mvp[None, ...], campos[None, ...] # Add batch dimension\n\n    def __len__(self):\n        return self.imvs.shape[0] if self.examples is None else self.examples\n\n    def __getitem__(self, itr):\n        if self.FLAGS.pre_load:\n            img, mv, mvp, campos = self.preloaded_data[itr % self.imvs.shape[0]]\n        else:\n            img, mv, mvp, campos = self._parse_frame(itr % self.imvs.shape[0])\n\n        return {\n            'mv' : mv,\n            'mvp' : mvp,\n            'campos' : campos,\n            'resolution' : self.resolution,\n            'spp' : self.FLAGS.spp,\n            'img' : img\n        }\n"
  },
  {
    "path": "dataset/dataset_mesh.py",
    "content": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n# property and proprietary rights in and to this material, related\n# documentation and any modifications thereto. Any use, reproduction, \n# disclosure or distribution of this material and related documentation \n# without an express license agreement from NVIDIA CORPORATION or \n# its affiliates is strictly prohibited.\n\nimport numpy as np\nimport torch\n\nfrom render import util\nfrom render import mesh\nfrom render import render\nfrom render import light\n\nfrom .dataset import Dataset\n\n###############################################################################\n# Reference dataset using mesh & rendering\n###############################################################################\n\nclass DatasetMesh(Dataset):\n\n    def __init__(self, ref_mesh, glctx, cam_radius, FLAGS, validate=False, mesh_center=None):\n        # Init \n        self.glctx              = glctx\n        self.cam_radius         = cam_radius\n        self.FLAGS              = FLAGS\n        self.validate           = validate\n        self.fovy               = np.deg2rad(45)\n        self.aspect             = FLAGS.train_res[1] / FLAGS.train_res[0]\n        self.random_lgt         = FLAGS.random_lgt\n        self.camera_lgt         = False\n\n        self.mesh_center = mesh_center\n\n        if self.FLAGS.local_rank == 0:\n            print(\"DatasetMesh: ref mesh has %d triangles and %d vertices\" % (ref_mesh.t_pos_idx.shape[0], ref_mesh.v_pos.shape[0]))\n\n        # Sanity test training texture resolution\n        ref_texture_res = np.maximum(ref_mesh.material['kd'].getRes(), ref_mesh.material['ks'].getRes())\n        if 'normal' in ref_mesh.material:\n            ref_texture_res = np.maximum(ref_texture_res, ref_mesh.material['normal'].getRes())\n        if self.FLAGS.local_rank == 0 and FLAGS.texture_res[0] < ref_texture_res[0] or FLAGS.texture_res[1] < ref_texture_res[1]:\n            print(\"---> WARNING: Picked a texture resolution lower than the reference mesh [%d, %d] < [%d, %d]\" % (FLAGS.texture_res[0], FLAGS.texture_res[1], ref_texture_res[0], ref_texture_res[1]))\n\n        # Load environment map texture\n        self.envlight = light.load_env(FLAGS.envmap, scale=FLAGS.env_scale)\n        \n        self.ref_mesh = mesh.compute_tangents(ref_mesh)\n\n    def _rotate_scene(self, itr):\n        proj_mtx = util.perspective(self.fovy, self.FLAGS.display_res[1] / self.FLAGS.display_res[0], self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1])\n\n        # Smooth rotation for display.\n        ang    = (itr / 50) * np.pi * 2\n        mv     = util.translate(0, 0, -self.cam_radius) @ (util.rotate_x(-0.4) @ util.rotate_y(ang))\n        mvp    = proj_mtx @ mv\n        campos = torch.linalg.inv(mv)[:3, 3]\n\n        return mv[None, ...].cuda(), mvp[None, ...].cuda(), campos[None, ...].cuda(), self.FLAGS.display_res, self.FLAGS.spp\n\n    def _random_scene(self):\n        # ==============================================================================================\n        #  Setup projection matrix\n        # ==============================================================================================\n        iter_res = self.FLAGS.train_res\n        proj_mtx = util.perspective(self.fovy, iter_res[1] / iter_res[0], self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1])\n\n        # 
==============================================================================================\n        #  Random camera & light position\n        # ==============================================================================================\n\n        # Random rotation/translation matrix for optimization.\n        if self.mesh_center is not None:\n            mv     = (\n                util.translate(-self.mesh_center[0], -self.mesh_center[1], -self.mesh_center[2]-self.cam_radius) \n                @ util.random_rotation_translation(0.25)\n            )\n        else:\n            mv     = util.translate(0, 0, -self.cam_radius) @ util.random_rotation_translation(0.25)\n        mvp    = proj_mtx @ mv\n        campos = torch.linalg.inv(mv)[:3, 3]\n\n        return mv[None, ...].cuda(), mvp[None, ...].cuda(), campos[None, ...].cuda(), iter_res, self.FLAGS.spp # Add batch dimension\n\n    def __len__(self):\n        return 50 if self.validate else (self.FLAGS.iter + 1) * self.FLAGS.batch\n\n    def __getitem__(self, itr):\n        # ==============================================================================================\n        #  Randomize scene parameters\n        # ==============================================================================================\n\n        if self.validate:\n            mv, mvp, campos, iter_res, iter_spp = self._rotate_scene(itr)\n            camera_mv = None\n        else:\n            mv, mvp, campos, iter_res, iter_spp = self._random_scene()\n            if self.random_lgt:\n                rnd_rot = util.random_rotation()\n                camera_mv = rnd_rot.unsqueeze(0).clone()\n            elif self.camera_lgt:\n                camera_mv = mv.clone()\n            else:\n                camera_mv = None\n\n        with torch.no_grad():\n            rendered = render.render_mesh(self.glctx, self.ref_mesh, mvp, campos, self.envlight, iter_res, spp=iter_spp, \n                                    num_layers=self.FLAGS.layers, msaa=True, background=None, shade_data=True)\n        return {\n            'mv' : mv,\n            'mvp' : mvp,\n            'campos' : campos,\n            'resolution' : iter_res,\n            'spp' : iter_spp,\n            'img' : rendered['shaded'],\n            'img_second' : rendered['shaded_second'],\n            'invdepth' : rendered['invdepth'],\n            'invdepth_second' : rendered['invdepth_second'],\n            'envlight_transform': camera_mv\n        }\n"
  },
  {
    "path": "dataset/dataset_nerf.py",
    "content": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n# property and proprietary rights in and to this material, related\n# documentation and any modifications thereto. Any use, reproduction, \n# disclosure or distribution of this material and related documentation \n# without an express license agreement from NVIDIA CORPORATION or \n# its affiliates is strictly prohibited.\n\nimport os\nimport glob\nimport json\n\nimport torch\nimport numpy as np\n\nfrom render import util\n\nfrom .dataset import Dataset\n\n###############################################################################\n# NERF image based dataset (synthetic)\n###############################################################################\n\ndef _load_img(path):\n    files = glob.glob(path + '.*')\n    assert len(files) > 0, \"Tried to find image file for: %s, but found 0 files\" % (path)\n    img = util.load_image_raw(files[0])\n    if img.dtype != np.float32: # LDR image\n        img = torch.tensor(img / 255, dtype=torch.float32)\n        img[..., 0:3] = util.srgb_to_rgb(img[..., 0:3])\n    else:\n        img = torch.tensor(img, dtype=torch.float32)\n    return img\n\nclass DatasetNERF(Dataset):\n    def __init__(self, cfg_path, FLAGS, examples=None):\n        self.FLAGS = FLAGS\n        self.examples = examples\n        self.base_dir = os.path.dirname(cfg_path)\n\n        # Load config / transforms\n        self.cfg = json.load(open(cfg_path, 'r'))\n        self.n_images = len(self.cfg['frames'])\n\n        # Determine resolution & aspect ratio\n        self.resolution = _load_img(os.path.join(self.base_dir, self.cfg['frames'][0]['file_path'])).shape[0:2]\n        self.aspect = self.resolution[1] / self.resolution[0]\n\n        if self.FLAGS.local_rank == 0:\n            print(\"DatasetNERF: %d images with shape [%d, %d]\" % (self.n_images, self.resolution[0], self.resolution[1]))\n\n        # Pre-load from disc to avoid slow png parsing\n        if self.FLAGS.pre_load:\n            self.preloaded_data = []\n            for i in range(self.n_images):\n                self.preloaded_data += [self._parse_frame(self.cfg, i)]\n\n    def _parse_frame(self, cfg, idx):\n        # Config projection matrix (static, so could be precomputed)\n        fovy   = util.fovx_to_fovy(cfg['camera_angle_x'], self.aspect)\n        proj   = util.perspective(fovy, self.aspect, self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1])\n\n        # Load image data and modelview matrix\n        img    = _load_img(os.path.join(self.base_dir, cfg['frames'][idx]['file_path']))\n        mv     = torch.linalg.inv(torch.tensor(cfg['frames'][idx]['transform_matrix'], dtype=torch.float32))\n        mv     = mv @ util.rotate_x(-np.pi / 2)\n        campos = torch.linalg.inv(mv)[:3, 3]\n        mvp    = proj @ mv\n\n        return img[None, ...], mv[None, ...], mvp[None, ...], campos[None, ...] 
# Add batch dimension\n\n    def __len__(self):\n        return self.n_images if self.examples is None else self.examples\n\n    def __getitem__(self, itr):\n        iter_res = self.FLAGS.train_res\n        \n        img      = []\n        fovy     = util.fovx_to_fovy(self.cfg['camera_angle_x'], self.aspect)\n\n        if self.FLAGS.pre_load:\n            img, mv, mvp, campos = self.preloaded_data[itr % self.n_images]\n        else:\n            img, mv, mvp, campos = self._parse_frame(self.cfg, itr % self.n_images)\n\n        return {\n            'mv' : mv,\n            'mvp' : mvp,\n            'campos' : campos,\n            'resolution' : iter_res,\n            'spp' : self.FLAGS.spp,\n            'img' : img\n        }\n"
  },
  {
    "path": "dataset/dataset_nerf_colmap.py",
    "content": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n# property and proprietary rights in and to this material, related\n# documentation and any modifications thereto. Any use, reproduction, \n# disclosure or distribution of this material and related documentation \n# without an express license agreement from NVIDIA CORPORATION or \n# its affiliates is strictly prohibited.\n\nimport os\nimport glob\nimport json\n\nimport torch\nimport numpy as np\n\nfrom render import util\n\nfrom .dataset import Dataset\n\n###############################################################################\n# NERF image based dataset (synthetic)\n###############################################################################\n\ndef _load_img(path):\n    img = util.load_image_raw(path)\n    if img.dtype != np.float32: # LDR image\n        img = torch.tensor(img / 255, dtype=torch.float32)\n        img[..., 0:3] = util.srgb_to_rgb(img[..., 0:3])\n    else:\n        img = torch.tensor(img, dtype=torch.float32)\n    return img\n\nclass DatasetNERF(Dataset):\n    def __init__(self, cfg_path, FLAGS, examples=None):\n        self.FLAGS = FLAGS\n        self.examples = examples\n        self.base_dir = os.path.dirname(cfg_path)\n\n        # Load config / transforms\n        self.cfg = json.load(open(cfg_path, 'r'))\n        self.n_images = len(self.cfg['frames'])\n\n        # Determine resolution & aspect ratio\n        self.resolution = _load_img(os.path.join(self.base_dir, self.cfg['frames'][0]['file_path'])).shape[0:2]\n        self.aspect = self.resolution[1] / self.resolution[0]\n\n        if self.FLAGS.local_rank == 0:\n            print(\"DatasetNERF: %d images with shape [%d, %d]\" % (self.n_images, self.resolution[0], self.resolution[1]))\n\n        # Pre-load from disc to avoid slow png parsing\n        if self.FLAGS.pre_load:\n            self.preloaded_data = []\n            for i in range(self.n_images):\n                self.preloaded_data += [self._parse_frame(self.cfg, i)]\n\n    def _parse_frame(self, cfg, idx):\n        # Config projection matrix (static, so could be precomputed)\n        fovy   = util.fovx_to_fovy(cfg['frames'][idx]['camera_angle_x'], self.aspect)\n        proj   = util.perspective(fovy, self.aspect, self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1])\n\n        # Load image data and modelview matrix\n        img    = _load_img(os.path.join(self.base_dir, cfg['frames'][idx]['file_path']))\n        mask   = _load_img(os.path.join(self.base_dir, cfg['frames'][idx]['file_path']).replace('/image/', '/mask/').replace('.jpg', '.png'))\n        img    = torch.cat([img, mask[:,:,:1]], dim=-1)\n        mv     = torch.linalg.inv(torch.tensor(cfg['frames'][idx]['transform_matrix'], dtype=torch.float32))\n        mv     = mv @ util.rotate_x(-np.pi / 2)\n        campos = torch.linalg.inv(mv)[:3, 3]\n        mvp    = proj @ mv\n\n        return img[None, ...], mv[None, ...], mvp[None, ...], campos[None, ...] 
# Add batch dimension\n\n    def __len__(self):\n        return self.n_images if self.examples is None else self.examples\n\n    def __getitem__(self, itr):\n        iter_res = self.FLAGS.train_res\n        \n        img      = []\n        fovy     = util.fovx_to_fovy(self.cfg['frames'][itr % self.n_images]['camera_angle_x'], self.aspect)\n\n        if self.FLAGS.pre_load:\n            img, mv, mvp, campos = self.preloaded_data[itr % self.n_images]\n        else:\n            img, mv, mvp, campos = self._parse_frame(self.cfg, itr % self.n_images)\n\n        return {\n            'mv' : mv,\n            'mvp' : mvp,\n            'campos' : campos,\n            'resolution' : iter_res,\n            'spp' : self.FLAGS.spp,\n            'img' : img\n        }\n"
  },
  {
    "path": "denoiser/denoiser.py",
    "content": "import os\n\nimport torch\nimport numpy as np\nimport math\n\nfrom render import util\nif \"TWOSIDED_TEXTURE\" not in os.environ or os.environ[\"TWOSIDED_TEXTURE\"] == \"True\":\n\tfrom render import optixutils as ou\nelse:\n\tfrom render import optixutils_single_sided as ou\n\n\n###############################################################################\n# Bilateral denoiser\n#\n# Loosely based on SVGF, but removing temporal components and variance stopping guides.\n# https://research.nvidia.com/publication/2017-07_spatiotemporal-variance-guided-filtering-real-time-reconstruction-path-traced\n###############################################################################\n\nclass BilateralDenoiser(torch.nn.Module):\n\tdef __init__(self, influence=1.0):\n\t\tsuper(BilateralDenoiser, self).__init__()\n\t\tself.set_influence(influence)\n\n\tdef set_influence(self, factor):\n\t\tself.sigma = max(factor * 2, 0.0001)\n\t\tself.variance = self.sigma**2.\n\t\tself.N = 2 * math.ceil(self.sigma * 2.5) + 1\n\n\tdef forward(self, input):\n\t\tcol    = input[..., 0:3]\n\t\tnrm    = util.safe_normalize(input[..., 3:6]) # Bent normals can produce normals of length < 1 here\n\t\tzdz    = input[..., 6:8]\n\t\treturn ou.bilateral_denoiser(col, nrm, zdz, self.sigma)\n"
  },
  {
    "path": "eval_gmeshdiffusion_generated_samples.py",
    "content": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n# property and proprietary rights in and to this material, related\n# documentation and any modifications thereto. Any use, reproduction, \n# disclosure or distribution of this material and related documentation \n# without an express license agreement from NVIDIA CORPORATION or \n# its affiliates is strictly prohibited.\n\nimport os\nimport argparse\nimport json\n\nimport numpy as np\nimport torch\n\n# Import topology / geometry trainers\nfrom geometry.gshell_tets_geometry import GShellTetsGeometry\n\nfrom render import texture\n\nimport pymeshlab\nfrom pytorch3d.io import save_obj\n\nimport tqdm\n\nRADIUS = 4.0\n# RADIUS = 2.5\n\n# Enable to debug back-prop anomalies\n# torch.autograd.set_detect_anomaly(True)\n\n#----------------------------------------------------------------------------\n# Main function.\n#----------------------------------------------------------------------------\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description='nvdiffrec')\n    parser.add_argument('--config', type=str, default=None, help='Config file')\n    parser.add_argument('-i', '--iter', type=int, default=5000)\n    parser.add_argument('-b', '--batch', type=int, default=1)\n    parser.add_argument('-s', '--spp', type=int, default=1)\n    parser.add_argument('-l', '--layers', type=int, default=1)\n    parser.add_argument('-r', '--train-res', nargs=2, type=int, default=[512, 512])\n    parser.add_argument('-dr', '--display-res', type=int, default=None)\n    parser.add_argument('-tr', '--texture-res', nargs=2, type=int, default=[1024, 1024])\n    parser.add_argument('-di', '--display-interval', type=int, default=0)\n    parser.add_argument('-si', '--save-interval', type=int, default=1000)\n    parser.add_argument('-lr', '--learning-rate', type=float, default=0.01)\n    parser.add_argument('-mr', '--min-roughness', type=float, default=0.08)\n    parser.add_argument('-mip', '--custom-mip', action='store_true', default=False)\n    parser.add_argument('-rt', '--random-textures', action='store_true', default=False)\n    parser.add_argument('-bg', '--background', default='checker', choices=['black', 'white', 'checker', 'reference'])\n    parser.add_argument('--loss', default='logl1', choices=['logl1', 'logl2', 'mse', 'smape', 'relmse'])\n    parser.add_argument('-o', '--out-dir', type=str, default=None)\n    parser.add_argument('-rm', '--ref_mesh', type=str)\n    parser.add_argument('-bm', '--base-mesh', type=str, default=None)\n    parser.add_argument('--validate', type=bool, default=True)\n    parser.add_argument('--grid_root', type=str)\n    \n    FLAGS = parser.parse_args()\n\n    FLAGS.mtl_override        = None                     # Override material of model\n    FLAGS.dmtet_grid          = 64                      # Resolution of initial tet grid. We provide 64 and 128 resolution grids. Other resolutions can be generated with https://github.com/crawforddoran/quartet\n    FLAGS.mesh_scale          = 2.3                      # Scale of tet grid box. Adjust to cover the model\n    FLAGS.env_scale           = 1.0                      # Env map intensity multiplier\n    FLAGS.envmap              = None                     # HDR environment probe\n    FLAGS.display             = None                     # Conf validation window/display. E.g. 
[{\"relight\" : <path to envlight>}]\n    FLAGS.camera_space_light  = False                    # Fixed light in camera space. This is needed for setups like ethiopian head where the scanned object rotates on a stand.\n    FLAGS.lock_light          = False                    # Disable light optimization in the second pass\n    FLAGS.lock_pos            = False                    # Disable vertex position optimization in the second pass\n    FLAGS.sdf_regularizer     = 0.2                      # Weight for sdf regularizer (see paper for details)\n    FLAGS.laplace             = \"relative\"               # Mesh Laplacian [\"absolute\", \"relative\"]\n    FLAGS.laplace_scale       = 10000.0                  # Weight for Laplacian regularizer. Default is relative with large weight\n    FLAGS.pre_load            = True                     # Pre-load entire dataset into memory for faster training\n    FLAGS.kd_min              = [ 0.0,  0.0,  0.0,  0.0] # Limits for kd\n    FLAGS.kd_max              = [ 1.0,  1.0,  1.0,  1.0]\n    FLAGS.ks_min              = [ 0.0, 0.08,  0.0]       # Limits for ks\n    FLAGS.ks_max              = [ 1.0,  1.0,  1.0]\n    FLAGS.nrm_min             = [-1.0, -1.0,  0.0]       # Limits for normal map\n    FLAGS.nrm_max             = [ 1.0,  1.0,  1.0]\n    FLAGS.cam_near_far        = [0.1, 1000.0]\n    FLAGS.use_tanh_deform     = False\n    FLAGS.use_sdf_mlp         = False\n    FLAGS.force_default_mtl   = True\n    FLAGS.twosided_texture    = True\n    FLAGS.random_lgt          = False\n    FLAGS.sphere_init         = False\n    FLAGS.num_smooth_steps    = 3\n    FLAGS.use_msdf_mlp        = False\n\n    if FLAGS.config is not None:\n        data = json.load(open(FLAGS.config, 'r'))\n        for key in data:\n            FLAGS.__dict__[key] = data[key]\n\n\n    os.makedirs(FLAGS.out_dir, exist_ok=True)\n\n    mtl_default_diffuse = {\n        'name' : '_default_mat',\n        'bsdf': 'diffuse',\n        'uniform': True,\n        'kd'   : texture.Texture2D(torch.tensor([0.75, 0.3, 0.6], dtype=torch.float32, device='cuda')),\n        'ks'   : texture.Texture2D(torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device='cuda'))\n    }\n\n    if FLAGS.force_default_mtl:\n        mtl_default = mtl_default_diffuse\n    else:\n        mtl_default = None\n\n\n    tet_path = './data/tets/64_tets_cropped_reordered.npz'\n    tet = np.load(tet_path)\n    vertices = torch.tensor(tet['vertices'])\n    edges = torch.tensor(tet['edges']).long()\n    vertices_unique = vertices[:].unique()\n    dx = (vertices_unique[1] - vertices_unique[0]) / 2.0\n\n    vertices_discretized = (torch.round(\n        (vertices - vertices.min()) / dx)\n    ).long()\n\n\n    midpoints = (vertices[edges[:, 0]] + vertices[edges[:, 1]]) / 2.0\n    midpoints_dicretized = (torch.round(\n        (midpoints - vertices.min()) / dx)\n    ).long()\n\n    aabb = torch.tensor(FLAGS.aabb, dtype=torch.float).cuda().view(2, 3)\n    center = aabb.mean(0, keepdim=True) / 2.0\n\n    mesh_scale = 3.8\n    mesh_scale = mesh_scale / torch.max(aabb[1] - aabb[0]).item()\n\n    count = 0\n    grid_root = FLAGS.grid_root\n    geometry = GShellTetsGeometry(FLAGS.dmtet_grid, FLAGS.mesh_scale, FLAGS, tet_init_file=tet_path, extract_from_generative=True)\n    with torch.no_grad():\n        for grid_name in tqdm.tqdm(sorted(list(os.listdir(grid_root)))):\n            if '_occ' in grid_name:\n                continue\n\n\n            grid_all = torch.load(\n                os.path.join(grid_root, grid_name), map_location='cuda'\n        
    )\n            occgrid_all = torch.load(\n                os.path.join(grid_root, grid_name).replace('.pt', '_occ.pt'), map_location='cuda'\n            )[:, 0]\n            for i in tqdm.trange(grid_all.size(0), leave=False):\n                mesh_path = FLAGS.out_dir\n                os.makedirs(mesh_path, exist_ok=True)\n                mesh_savepath = os.path.join(mesh_path, '{:06d}.obj'.format(count))\n\n                if os.path.exists(mesh_savepath):\n                    count += 1\n                    continue\n                grid = grid_all[i]\n                occgrid = occgrid_all[i]\n\n                sdf_sign = (\n                        grid[0, vertices_discretized[:, 0], vertices_discretized[:, 1], vertices_discretized[:, 2]]\n                    ).cuda().float()\n                geometry.deform.data[:] = (\n                        grid[1:4, vertices_discretized[:, 0], vertices_discretized[:, 1], vertices_discretized[:, 2]]\n                    ).cuda().transpose(0, 1).float().clamp(-1, 1)\n                \n\n                sdf_coeff = torch.ones(128, 128, 128).float().cuda() * 0.5\n\n                msdf_sign = torch.zeros(128, 128, 128).float().cuda()\n                msdf_sign[midpoints_dicretized[:, 0], midpoints_dicretized[:, 1], midpoints_dicretized[:, 2]] = torch.sign(\n                    grid[0, midpoints_dicretized[:, 0], midpoints_dicretized[:, 1], midpoints_dicretized[:, 2]].cuda()\n                ).float()\n                geometry.deform.data[:] = geometry.deform.data[:].clip(-1.0, 1.0)\n                geometry.deform_scale = 2.0\n\n                base_mesh = geometry.getMesh_from_augmented_grid_withocc(mtl_default, torch.sign(sdf_sign), sdf_coeff, msdf_sign, occgrid=occgrid)['imesh']\n\n                ### rescale and translate back to align with the dataset\n                base_mesh.v_pos = (base_mesh.v_pos / mesh_scale) + center\n\n                ### save post-processed mesh\n                save_obj(\n                    verts=base_mesh.v_pos,\n                    faces=base_mesh.t_pos_idx,\n                    f=mesh_savepath\n                )\n\n                ms = pymeshlab.MeshSet()\n                ms.load_new_mesh(mesh_savepath)\n                ms.meshing_remove_unreferenced_vertices()\n                ms.meshing_isotropic_explicit_remeshing()\n                ms.apply_coord_laplacian_smoothing(stepsmoothnum=FLAGS.num_smooth_steps, cotangentweight=True)\n                # ms.apply_coord_hc_laplacian_smoothing()\n                # ms.apply_coord_laplacian_smoothing(stepsmoothnum=3, cotangentweight=True) ## for smoother surface\n                ms.meshing_isotropic_explicit_remeshing()\n                ms.apply_filter_script()\n                ms.save_current_mesh(mesh_savepath)\n\n                count += 1\n"
  },
  {
    "path": "geometry/embedding.py",
    "content": "import torch\nfrom torch import nn\n\nclass Embedding(nn.Module):\n    def __init__(self, in_channels, N_freqs, logscale=True):\n        \"\"\"\n        Defines a function that embeds x to (x, sin(2^k x), cos(2^k x), ...)\n        in_channels: number of input channels (3 for both xyz and direction)\n        \"\"\"\n        super(Embedding, self).__init__()\n        self.N_freqs = N_freqs\n        self.in_channels = in_channels\n        self.funcs = [torch.sin, torch.cos]\n        self.out_channels = in_channels*(len(self.funcs)*N_freqs+1)\n\n        if logscale:\n            self.freq_bands = 2**torch.linspace(0, N_freqs-1, N_freqs)\n        else:\n            self.freq_bands = torch.linspace(1, 2**(N_freqs-1), N_freqs)\n\n    def forward(self, x):\n        \"\"\"\n        Embeds x to (x, sin(2^k x), cos(2^k x), ...) \n        Different from the paper, \"x\" is also in the output\n        See https://github.com/bmild/nerf/issues/12\n\n        Inputs:\n            x: (B, self.in_channels)\n\n        Outputs:\n            out: (B, self.out_channels)\n        \"\"\"\n        out = [x]\n        for freq in self.freq_bands:\n            for func in self.funcs:\n                out += [func(freq*x)]\n\n        return torch.cat(out, -1)\n\n"
  },
  {
    "path": "geometry/flexicubes_table.py",
    "content": "# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.\n#\n# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.\ndmc_table = [\n[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 2, 8, 11, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, 
-1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 4, 5, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[5, 7, 8, 9, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 5, 7, 8, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 4, 7, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 2, 9, 10, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[5, 7, 8, 9, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 5, 7, 9, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[4, 7, 8, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 9, 10, 11, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], 
[-1, -1, -1, -1, -1, -1, -1]],\n[[4, 5, 9, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 8, 10, 11, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[5, 7, 8, 9, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 5, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 5, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 8, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 3, 8, 9, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 9, -1, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 4, 5, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 3, 4, 5, 8, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, 
-1, -1, -1, -1, -1]],\n[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[4, 5, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 2, 6, 7, 8, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 4, 5, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 2, 4, 5, 6, 7, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 2, 3, 5, 6, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 2, 9, 10, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[2, 3, 8, 9, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[4, 6, 8, 11, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 4, 6, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 2, 9, 10, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[2, 3, 4, 6, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1]],\n[[0, 2, 4, 5, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[2, 3, 4, 5, 8, 10, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[5, 6, 8, 9, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 5, 6, 9, 11, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 2, 5, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[2, 3, 5, 6, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 3, 4, 6, 8, 10, 
-1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 4, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[4, 5, 9, -1, -1, -1, -1], [1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 6, 7, 8, 10, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 4, 5, 6, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[4, 5, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 3, 5, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 5, 6, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 9, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 3, 8, 9, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 4, 7, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 3, 4, 7, 9, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 8, -1, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 2, 8, 11, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 9, -1, -1, -1, -1], [2, 3, 
11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 2, 8, 9, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 2, 4, 7, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1]],\n[[1, 2, 4, 7, 9, 11, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[4, 6, 9, 10, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 2, 8, 11, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 4, 6, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 2, 4, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[6, 7, 8, 9, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 2, 6, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 6, 7, 8, 10, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 2, 6, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[4, 7, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 4, 7, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 2, 5, 6, 9, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[2, 3, 4, 5, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 2, 3, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, 
-1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[4, 7, 8, -1, -1, -1, -1], [1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 4, 5, 6, 7, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 5, 6, 9, 11, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[4, 5, 6, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 4, 6, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 3, 6, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 6, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 8, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 9, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 3, 8, 9, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 9, -1, -1, -1, -1], [4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 3, 4, 5, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 8, -1, -1, -1, -1], [4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 3, 4, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, 
-1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 9, -1, -1, -1, -1], [2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 2, 5, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 2, 4, 5, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 2, 4, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 2, 3, 4, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 2, 3, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[2, 3, 5, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 2, 3, 4, 5, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 2, 4, 5, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[2, 3, 4, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, 
-1, -1, -1, -1, -1]],\n[[0, 1, 2, 3, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 4, 5, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 4, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]],\n[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]]\n]\nnum_vd_table = [0, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 3, 1, 2, 2,\n2, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 1, 2, 3, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2,\n1, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 2, 3, 2, 2, 1, 1, 1, 1,\n1, 1, 2, 1, 1, 1, 2, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 2, 2, 2, 2, 1, 3, 4, 2,\n2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2,\n3, 2, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 3, 2, 3, 2, 4, 2, 2, 2, 2, 1, 2, 1, 2, 1, 1,\n2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1,\n1, 2, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2,\n1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1,\n1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]\ncheck_table = [\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 
0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[1, 1, 0, 0, 194],\n[1, -1, 0, 0, 193],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[1, 0, 1, 0, 164],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[1, 0, -1, 0, 161],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[1, 0, 0, 1, 152],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[1, 0, 0, 1, 145],\n[1, 0, 0, 1, 144],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[1, 0, 0, -1, 137],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[1, 0, 1, 0, 133],\n[1, 0, 1, 0, 132],\n[1, 1, 0, 0, 131],\n[1, 1, 0, 0, 130],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[1, 0, 0, 1, 100],\n[0, 0, 0, 0, 0],\n[1, 0, 0, 1, 98],\n[0, 0, 0, 0, 0],\n[1, 0, 0, 1, 96],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[1, 0, 1, 0, 88],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[1, 0, -1, 0, 82],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[1, 0, 1, 0, 74],\n[0, 0, 0, 0, 0],\n[1, 0, 1, 0, 72],\n[0, 0, 0, 0, 0],\n[1, 0, 0, -1, 70],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[1, -1, 0, 0, 67],\n[0, 0, 0, 0, 0],\n[1, -1, 0, 0, 65],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[1, 1, 0, 0, 56],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[1, -1, 0, 0, 52],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[1, 1, 0, 0, 44],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[1, 1, 0, 0, 40],\n[0, 0, 0, 0, 0],\n[1, 0, 0, -1, 38],\n[1, 0, -1, 0, 37],\n[0, 0, 0, 0, 
0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[1, 0, -1, 0, 33],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[1, -1, 0, 0, 28],\n[0, 0, 0, 0, 0],\n[1, 0, -1, 0, 26],\n[1, 0, 0, -1, 25],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[1, -1, 0, 0, 20],\n[0, 0, 0, 0, 0],\n[1, 0, -1, 0, 18],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[1, 0, 0, -1, 9],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[1, 0, 0, -1, 6],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0]\n]\ntet_table = [\n[-1, -1, -1, -1, -1, -1],\n[0, 0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0, 0],\n[1, 1, 1, 1, 1, 1],\n[4, 4, 4, 4, 4, 4],\n[0, 0, 0, 0, 0, 0],\n[4, 0, 0, 4, 4, -1],\n[1, 1, 1, 1, 1, 1],\n[4, 4, 4, 4, 4, 4],\n[0, 4, 0, 4, 4, -1],\n[0, 0, 0, 0, 0, 0],\n[1, 1, 1, 1, 1, 1],\n[5, 5, 5, 5, 5, 5],\n[0, 0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0, 0],\n[1, 1, 1, 1, 1, 1],\n[2, 2, 2, 2, 2, 2],\n[0, 0, 0, 0, 0, 0],\n[2, 0, 2, -1, 0, 2],\n[1, 1, 1, 1, 1, 1],\n[2, -1, 2, 4, 4, 2],\n[0, 0, 0, 0, 0, 0],\n[2, 0, 2, 4, 4, 2],\n[1, 1, 1, 1, 1, 1],\n[2, 4, 2, 4, 4, 2],\n[0, 4, 0, 4, 4, 0],\n[2, 0, 2, 0, 0, 2],\n[1, 1, 1, 1, 1, 1],\n[2, 5, 2, 5, 5, 2],\n[0, 0, 0, 0, 0, 0],\n[2, 0, 2, 0, 0, 2],\n[1, 1, 1, 1, 1, 1],\n[1, 1, 1, 1, 1, 1],\n[0, 1, 1, -1, 0, 1],\n[0, 0, 0, 0, 0, 0],\n[2, 2, 2, 2, 2, 2],\n[4, 1, 1, 4, 4, 1],\n[0, 1, 1, 0, 0, 1],\n[4, 0, 0, 4, 4, 0],\n[2, 2, 2, 2, 2, 2],\n[-1, 1, 1, 4, 4, 1],\n[0, 1, 1, 4, 4, 1],\n[0, 0, 0, 0, 0, 0],\n[2, 2, 2, 2, 2, 2],\n[5, 1, 1, 5, 5, 1],\n[0, 1, 1, 0, 0, 1],\n[0, 0, 0, 0, 0, 0],\n[2, 2, 2, 2, 2, 2],\n[1, 1, 1, 1, 1, 1],\n[0, 0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0, 0],\n[8, 8, 8, 8, 8, 8],\n[1, 1, 1, 4, 4, 1],\n[0, 0, 0, 0, 0, 0],\n[4, 0, 0, 4, 4, 0],\n[4, 4, 4, 4, 4, 4],\n[1, 1, 1, 4, 4, 1],\n[0, 4, 0, 4, 4, 0],\n[0, 0, 0, 0, 0, 0],\n[4, 4, 4, 4, 4, 4],\n[1, 1, 1, 5, 5, 1],\n[0, 0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0, 0],\n[5, 5, 5, 5, 5, 5],\n[6, 6, 6, 6, 6, 6],\n[6, -1, 0, 6, 0, 6],\n[6, 0, 0, 6, 0, 6],\n[6, 1, 1, 6, 1, 6],\n[4, 4, 4, 4, 4, 4],\n[0, 0, 0, 0, 0, 0],\n[4, 0, 0, 4, 4, 4],\n[1, 1, 1, 1, 1, 1],\n[6, 4, -1, 6, 4, 6],\n[6, 4, 0, 6, 4, 6],\n[6, 0, 0, 6, 0, 6],\n[6, 1, 1, 6, 1, 6],\n[5, 5, 5, 5, 5, 5],\n[0, 0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0, 0],\n[1, 1, 1, 1, 1, 1],\n[2, 2, 2, 2, 2, 2],\n[0, 0, 0, 0, 0, 0],\n[2, 0, 2, 2, 0, 2],\n[1, 1, 1, 1, 1, 1],\n[2, 2, 2, 2, 2, 2],\n[0, 0, 0, 0, 0, 0],\n[2, 0, 2, 2, 2, 2],\n[1, 1, 1, 1, 1, 1],\n[2, 4, 2, 2, 4, 2],\n[0, 4, 0, 4, 4, 0],\n[2, 0, 2, 2, 0, 2],\n[1, 1, 1, 1, 1, 1],\n[2, 2, 2, 2, 2, 2],\n[0, 0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0, 0],\n[1, 1, 1, 1, 1, 1],\n[6, 1, 1, 6, -1, 6],\n[6, 1, 1, 6, 0, 6],\n[6, 0, 0, 6, 0, 6],\n[6, 2, 2, 6, 2, 6],\n[4, 1, 1, 4, 4, 1],\n[0, 1, 1, 0, 0, 1],\n[4, 0, 0, 4, 4, 4],\n[2, 2, 2, 2, 2, 2],\n[6, 1, 1, 6, 4, 6],\n[6, 1, 1, 6, 4, 6],\n[6, 0, 0, 6, 0, 6],\n[6, 2, 2, 6, 2, 6],\n[5, 1, 1, 5, 5, 1],\n[0, 1, 1, 0, 0, 1],\n[0, 0, 0, 0, 0, 0],\n[2, 2, 2, 2, 2, 2],\n[1, 1, 1, 1, 1, 1],\n[0, 0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0, 0],\n[6, 6, 6, 6, 6, 6],\n[1, 1, 1, 1, 1, 1],\n[0, 0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0, 0],\n[4, 4, 4, 4, 4, 4],\n[1, 1, 1, 1, 4, 1],\n[0, 4, 0, 4, 4, 0],\n[0, 0, 0, 0, 0, 0],\n[4, 4, 4, 4, 4, 4],\n[1, 1, 1, 1, 1, 1],\n[0, 0, 0, 0, 0, 0],\n[0, 5, 0, 5, 0, 5],\n[5, 5, 5, 5, 5, 5],\n[5, 5, 5, 5, 5, 5],\n[0, 5, 0, 5, 0, 5],\n[-1, 5, 0, 5, 0, 5],\n[1, 5, 1, 5, 1, 5],\n[4, 5, -1, 5, 4, 5],\n[0, 5, 0, 5, 0, 5],\n[4, 5, 0, 5, 4, 5],\n[1, 5, 1, 5, 1, 
5],\n[4, 4, 4, 4, 4, 4],\n[0, 4, 0, 4, 4, 4],\n[0, 0, 0, 0, 0, 0],\n[1, 1, 1, 1, 1, 1],\n[6, 6, 6, 6, 6, 6],\n[0, 0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0, 0],\n[1, 1, 1, 1, 1, 1],\n[2, 5, 2, 5, -1, 5],\n[0, 5, 0, 5, 0, 5],\n[2, 5, 2, 5, 0, 5],\n[1, 5, 1, 5, 1, 5],\n[2, 5, 2, 5, 4, 5],\n[0, 5, 0, 5, 0, 5],\n[2, 5, 2, 5, 4, 5],\n[1, 5, 1, 5, 1, 5],\n[2, 4, 2, 4, 4, 2],\n[0, 4, 0, 4, 4, 4],\n[2, 0, 2, 0, 0, 2],\n[1, 1, 1, 1, 1, 1],\n[2, 6, 2, 6, 6, 2],\n[0, 0, 0, 0, 0, 0],\n[2, 0, 2, 0, 0, 2],\n[1, 1, 1, 1, 1, 1],\n[1, 1, 1, 1, 1, 1],\n[0, 1, 1, 1, 0, 1],\n[0, 0, 0, 0, 0, 0],\n[2, 2, 2, 2, 2, 2],\n[4, 1, 1, 1, 4, 1],\n[0, 1, 1, 1, 0, 1],\n[4, 0, 0, 4, 4, 0],\n[2, 2, 2, 2, 2, 2],\n[1, 1, 1, 1, 1, 1],\n[0, 1, 1, 1, 1, 1],\n[0, 0, 0, 0, 0, 0],\n[2, 2, 2, 2, 2, 2],\n[1, 1, 1, 1, 1, 1],\n[0, 0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0, 0],\n[2, 2, 2, 2, 2, 2],\n[1, 1, 1, 1, 1, 1],\n[0, 0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0, 0],\n[5, 5, 5, 5, 5, 5],\n[1, 1, 1, 1, 4, 1],\n[0, 0, 0, 0, 0, 0],\n[4, 0, 0, 4, 4, 0],\n[4, 4, 4, 4, 4, 4],\n[1, 1, 1, 1, 1, 1],\n[0, 0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0, 0],\n[4, 4, 4, 4, 4, 4],\n[1, 1, 1, 1, 1, 1],\n[6, 0, 0, 6, 0, 6],\n[0, 0, 0, 0, 0, 0],\n[6, 6, 6, 6, 6, 6],\n[5, 5, 5, 5, 5, 5],\n[5, 5, 0, 5, 0, 5],\n[5, 5, 0, 5, 0, 5],\n[5, 5, 1, 5, 1, 5],\n[4, 4, 4, 4, 4, 4],\n[0, 0, 0, 0, 0, 0],\n[4, 4, 0, 4, 4, 4],\n[1, 1, 1, 1, 1, 1],\n[4, 4, 4, 4, 4, 4],\n[4, 4, 0, 4, 4, 4],\n[0, 0, 0, 0, 0, 0],\n[1, 1, 1, 1, 1, 1],\n[8, 8, 8, 8, 8, 8],\n[0, 0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0, 0],\n[1, 1, 1, 1, 1, 1],\n[2, 2, 2, 2, 2, 2],\n[0, 0, 0, 0, 0, 0],\n[2, 2, 2, 2, 0, 2],\n[1, 1, 1, 1, 1, 1],\n[2, 2, 2, 2, 2, 2],\n[0, 0, 0, 0, 0, 0],\n[2, 2, 2, 2, 2, 2],\n[1, 1, 1, 1, 1, 1],\n[2, 2, 2, 2, 2, 2],\n[0, 0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0, 0],\n[4, 1, 1, 4, 4, 1],\n[2, 2, 2, 2, 2, 2],\n[0, 0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0, 0],\n[1, 1, 1, 1, 1, 1],\n[1, 1, 1, 1, 1, 1],\n[1, 1, 1, 1, 0, 1],\n[0, 0, 0, 0, 0, 0],\n[2, 2, 2, 2, 2, 2],\n[1, 1, 1, 1, 1, 1],\n[0, 0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0, 0],\n[2, 4, 2, 4, 4, 2],\n[1, 1, 1, 1, 1, 1],\n[1, 1, 1, 1, 1, 1],\n[0, 0, 0, 0, 0, 0],\n[2, 2, 2, 2, 2, 2],\n[1, 1, 1, 1, 1, 1],\n[0, 0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0, 0],\n[2, 2, 2, 2, 2, 2],\n[1, 1, 1, 1, 1, 1],\n[0, 0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0, 0],\n[5, 5, 5, 5, 5, 5],\n[1, 1, 1, 1, 1, 1],\n[0, 0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0, 0],\n[4, 4, 4, 4, 4, 4],\n[1, 1, 1, 1, 1, 1],\n[0, 0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0, 0],\n[4, 4, 4, 4, 4, 4],\n[1, 1, 1, 1, 1, 1],\n[0, 0, 0, 0, 0, 0],\n[0, 0, 0, 0, 0, 0],\n[12, 12, 12, 12, 12, 12]\n]\n\n\ngflex_num_triangles_table = [0,1,1,2,1,2,2,1]\n\ngflex_configuration_table = [\n    ## 000\n        [-1, -1, -1, -1, -1, -1],\n    ## 001\n        [ 4,  2,  5, -1, -1, -1],\n    ## 010\n        [ 3,  1,  4, -1, -1, -1],\n    ## 011\n        [ 3,  1,  2,  3,  2,  5],\n    ## 100\n        [ 0,  3,  5, -1, -1, -1],\n    ## 101\n        [ 0,  3,  4,  0,  4,  2],\n    ## 110\n        [ 0,  1,  4,  0,  4,  5],\n    ## 111\n        [ 0,  1,  2, -1, -1, -1],\n]"
  },
  {
    "path": "geometry/gshell_flexicubes.py",
    "content": "# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.\n#\n# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.\nimport torch\nfrom .flexicubes_table import *\n\n__all__ = [\n    'GShellFlexiCubes'\n]\n\n\nclass GShellFlexiCubes:\n    \"\"\"\n    This class implements the FlexiCubes method for extracting meshes from scalar fields. \n    It maintains a series of lookup tables and indices to support the mesh extraction process. \n    FlexiCubes, a differentiable variant of the Dual Marching Cubes (DMC) scheme, enhances \n    the geometric fidelity and mesh quality of reconstructed meshes by dynamically adjusting \n    the surface representation through gradient-based optimization.\n\n    During instantiation, the class loads DMC tables from a file and transforms them into \n    PyTorch tensors on the specified device.\n\n    Attributes:\n        device (str): Specifies the computational device (default is \"cuda\").\n        dmc_table (torch.Tensor): Dual Marching Cubes (DMC) table that encodes the edges \n            associated with each dual vertex in 256 Marching Cubes (MC) configurations.\n        num_vd_table (torch.Tensor): Table holding the number of dual vertices in each of \n            the 256 MC configurations.\n        check_table (torch.Tensor): Table resolving ambiguity in cases C16 and C19 \n            of the DMC configurations.\n        tet_table (torch.Tensor): Lookup table used in tetrahedralizing the isosurface.\n        quad_split_1 (torch.Tensor): Indices for splitting a quad into two triangles \n            along one diagonal.\n        quad_split_2 (torch.Tensor): Alternative indices for splitting a quad into \n            two triangles along the other diagonal.\n        quad_split_train (torch.Tensor): Indices for splitting a quad into four triangles \n            during training by connecting all edges to their midpoints.\n        cube_corners (torch.Tensor): Defines the positions of a standard unit cube's \n            eight corners in 3D space, ordered starting from the origin (0,0,0), \n            moving along the x-axis, then y-axis, and finally z-axis. \n            Used as a blueprint for generating a voxel grid.\n        cube_corners_idx (torch.Tensor): Cube corners indexed as powers of 2, used \n            to retrieve the case id.\n        cube_edges (torch.Tensor): Edge connections in a cube, listed in pairs. \n            Used to retrieve edge vertices in DMC.\n        edge_dir_table (torch.Tensor): A mapping tensor that associates edge indices with \n            their corresponding axis. For instance, edge_dir_table[0] = 0 indicates that the \n            first edge is oriented along the x-axis. \n        dir_faces_table (torch.Tensor): A tensor that maps the corresponding axis of shared edges \n            across four adjacent cubes to the shared faces of these cubes. For instance, \n            dir_faces_table[0] = [5, 4] implies that for four cubes sharing an edge along \n            the x-axis, the first and second cubes share faces indexed as 5 and 4, respectively. 
\n            This tensor is only utilized during isosurface tetrahedralization.\n        adj_pairs (torch.Tensor): \n            A tensor containing index pairs that correspond to neighboring cubes that share the same edge.\n        qef_reg_scale (float):\n            The scaling factor applied to the regularization loss to prevent issues with singularity \n            when solving the QEF. This parameter is only used when a 'grad_func' is specified.\n        weight_scale (float):\n            The scale of weights in FlexiCubes. Should be between 0 and 1.\n    \"\"\"\n\n    def __init__(self, device=\"cuda\", qef_reg_scale=1e-3, weight_scale=0.99):\n\n        self.device = device\n        self.dmc_table = torch.tensor(dmc_table, dtype=torch.long, device=device, requires_grad=False)\n        self.num_vd_table = torch.tensor(num_vd_table,\n                                         dtype=torch.long, device=device, requires_grad=False)\n        self.check_table = torch.tensor(\n            check_table,\n            dtype=torch.long, device=device, requires_grad=False)\n\n        self.tet_table = torch.tensor(tet_table, dtype=torch.long, device=device, requires_grad=False)\n        self.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False)\n        self.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False)\n        self.quad_split_train = torch.tensor(\n            [0, 1, 1, 2, 2, 3, 3, 0], dtype=torch.long, device=device, requires_grad=False)\n\n        self.cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [\n                                         1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.float, device=device)\n        self.cube_corners_idx = torch.pow(2, torch.arange(8, requires_grad=False))\n        self.cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6,\n                                       2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, device=device, requires_grad=False)\n\n        self.edge_dir_table = torch.tensor([0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1],\n                                           dtype=torch.long, device=device)\n        self.dir_faces_table = torch.tensor([\n            [[5, 4], [3, 2], [4, 5], [2, 3]],\n            [[5, 4], [1, 0], [4, 5], [0, 1]],\n            [[3, 2], [1, 0], [2, 3], [0, 1]]\n        ], dtype=torch.long, device=device)\n        self.adj_pairs = torch.tensor([0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.long, device=device)\n        self.qef_reg_scale = qef_reg_scale\n        self.weight_scale = weight_scale\n\n        self.gflex_num_triangles_table = torch.tensor(gflex_num_triangles_table, dtype=torch.long, device=device, requires_grad=False)\n        self.gflex_configuration_table = torch.tensor(gflex_configuration_table, dtype=torch.long, device=device, requires_grad=False)\n\n    def construct_voxel_grid(self, res):\n        \"\"\"\n        Generates a voxel grid based on the specified resolution.\n\n        Args:\n            res (int or list[int]): The resolution of the voxel grid. If an integer\n                is provided, it is used for all three dimensions. If a list or tuple \n                of 3 integers is provided, they define the resolution for the x, \n                y, and z dimensions respectively.\n\n        Returns:\n            (torch.Tensor, torch.Tensor): Returns the vertices and the indices of the \n                cube corners (index into vertices) of the constructed voxel grid. 
\n                The vertices are centered at the origin, with the length of each \n                dimension in the grid being one.\n        \"\"\"\n        base_cube_f = torch.arange(8).to(self.device)\n        if isinstance(res, int):\n            res = (res, res, res)\n        voxel_grid_template = torch.ones(res, device=self.device)\n\n        res = torch.tensor([res], dtype=torch.float, device=self.device)\n        coords = torch.nonzero(voxel_grid_template).float() / res  # N, 3\n        verts = (self.cube_corners.unsqueeze(0) / res + coords.unsqueeze(1)).reshape(-1, 3)\n        cubes = (base_cube_f.unsqueeze(0) +\n                 torch.arange(coords.shape[0], device=self.device).unsqueeze(1) * 8).reshape(-1)\n\n        verts_rounded = torch.round(verts * 10**5) / (10**5)\n        verts_unique, inverse_indices = torch.unique(verts_rounded, dim=0, return_inverse=True)\n        cubes = inverse_indices[cubes.reshape(-1)].reshape(-1, 8)\n\n        return verts_unique - 0.5, cubes\n\n    def __call__(self, x_nx3, s_n, nu_n, cube_fx8, res, beta_fx12=None, alpha_fx8=None,\n                 gamma_f=None, training=False, output_tetmesh=False, grad_func=None):\n        r\"\"\"\n        Main function for mesh extraction from scalar field using FlexiCubes. This function converts \n        discrete signed distance fields, encoded on voxel grids and additional per-cube parameters, \n        to triangle or tetrahedral meshes using a differentiable operation as described in \n        `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_. FlexiCubes enhances \n        mesh quality and geometric fidelity by adjusting the surface representation based on gradient \n        optimization. The output surface is differentiable with respect to the input vertex positions, \n        scalar field values, and weight parameters.\n\n        If you intend to extract a surface mesh from a fixed Signed Distance Field without the \n        optimization of parameters, it is suggested to provide the \"grad_func\" which should \n        return the surface gradient at any given 3D position. When grad_func is provided, the process \n        to determine the dual vertex position adapts to solve a Quadratic Error Function (QEF), as \n        described in the `Manifold Dual Contouring`_ paper, and employs an smart splitting strategy. \n        Please note, this approach is non-differentiable.\n\n        For more details and example usage in optimization, refer to the \n        `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_ SIGGRAPH 2023 paper.\n\n        Args:\n            x_nx3 (torch.Tensor): Coordinates of the voxel grid vertices, can be deformed.\n            s_n (torch.Tensor): Scalar field values at each vertex of the voxel grid. Negative values \n                denote that the corresponding vertex resides inside the isosurface. This affects \n                the directions of the extracted triangle faces and volume to be tetrahedralized.\n            cube_fx8 (torch.Tensor): Indices of 8 vertices for each cube in the voxel grid.\n            res (int or list[int]): The resolution of the voxel grid. If an integer is provided, it \n                is used for all three dimensions. If a list or tuple of 3 integers is provided, they \n                specify the resolution for the x, y, and z dimensions respectively.\n            beta_fx12 (torch.Tensor, optional): Weight parameters for the cube edges to adjust dual \n                vertices positioning. 
Defaults to uniform value for all edges.\n            alpha_fx8 (torch.Tensor, optional): Weight parameters for the cube corners to adjust dual \n                vertices positioning. Defaults to uniform value for all vertices.\n            gamma_f (torch.Tensor, optional): Weight parameters to control the splitting of \n                quadrilaterals into triangles. Defaults to uniform value for all cubes.\n            training (bool, optional): If set to True, applies differentiable quad splitting for \n                training. Defaults to False.\n            output_tetmesh (bool, optional): If set to True, outputs a tetrahedral mesh, otherwise, \n                outputs a triangular mesh. Defaults to False.\n            grad_func (callable, optional): A function to compute the surface gradient at specified \n                3D positions (input: Nx3 positions). The function should return gradients as an Nx3 \n                tensor. If None, the original FlexiCubes algorithm is utilized. Defaults to None.\n\n        Returns:\n            (torch.Tensor, torch.LongTensor, torch.Tensor): Tuple containing:\n                - Vertices for the extracted triangular/tetrahedral mesh.\n                - Faces for the extracted triangular/tetrahedral mesh.\n                - Regularizer L_dev, computed per dual vertex.\n\n        .. _Flexible Isosurface Extraction for Gradient-Based Mesh Optimization:\n            https://research.nvidia.com/labs/toronto-ai/flexicubes/\n        .. _Manifold Dual Contouring:\n            https://people.engr.tamu.edu/schaefer/research/dualsimp_tvcg.pdf\n        \"\"\"\n\n        surf_cubes, occ_fx8 = self._identify_surf_cubes(s_n, cube_fx8)\n        if surf_cubes.sum() == 0:\n            return torch.zeros(\n                (0, 3),\n                device=self.device), torch.zeros(\n                (0, 4),\n                dtype=torch.long, device=self.device) if output_tetmesh else torch.zeros(\n                (0, 3),\n                dtype=torch.long, device=self.device), torch.zeros(\n                (0),\n                device=self.device)\n        beta_fx12, alpha_fx8, gamma_f = self._normalize_weights(beta_fx12, alpha_fx8, gamma_f, surf_cubes)\n\n        case_ids = self._get_case_id(occ_fx8, surf_cubes, res)\n\n        surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges(s_n, cube_fx8, surf_cubes)\n\n        vd, nu_d, nu_d_stopvgd, L_dev, vd_gamma, vd_idx_map = self._compute_vd(\n            x_nx3, cube_fx8[surf_cubes], surf_edges, s_n, nu_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func)\n        vertices, nus, nus_stopvgd, faces, s_edges, edge_indices = self._triangulate(\n            s_n, surf_edges, vd, nu_d, nu_d_stopvgd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func)\n        vertices_open, faces_open, nus_open_stopvgd, nus_boundary_stopvgd = self._triangulate_msdf(vertices, faces, nus, nus_stopvgd)\n        if not output_tetmesh:\n            extra = {\n                'n_verts_watertight': vertices.size(0),\n                'vertices_watertight': vertices,\n                'faces_watertight': faces, \n                'msdf': nus_open_stopvgd,\n                'msdf_watertight': nus,\n                'msdf_boundary': nus_boundary_stopvgd,\n            }\n            # print(torch.any(torch.isnan(nus_open_stopvgd)))\n            return vertices_open, faces_open, L_dev, extra\n        else:\n            raise NotImplementedError\n            vertices, tets = 
self._tetrahedralize(\n                x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices,\n                surf_cubes, training)\n            return vertices, tets, L_dev\n\n    def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges):\n        \"\"\"\n        Regularizer L_dev as in Equation 8\n        \"\"\"\n        dist = torch.norm(ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1)\n        mean_l2 = torch.zeros_like(vd[:, 0])\n        mean_l2 = (mean_l2).index_add_(0, edge_group_to_vd, dist) / vd_num_edges.squeeze(1).float()\n        mad = (dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0)).abs()\n        return mad\n\n    def _normalize_weights(self, beta_fx12, alpha_fx8, gamma_f, surf_cubes):\n        \"\"\"\n        Normalizes the given weights to be non-negative. If input weights are None, it creates and returns a set of weights of ones.\n        \"\"\"\n        n_cubes = surf_cubes.shape[0]\n\n        if beta_fx12 is not None:\n            beta_fx12 = (torch.tanh(beta_fx12) * self.weight_scale + 1)\n        else:\n            beta_fx12 = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device)\n\n        if alpha_fx8 is not None:\n            alpha_fx8 = (torch.tanh(alpha_fx8) * self.weight_scale + 1)\n        else:\n            alpha_fx8 = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device)\n\n        if gamma_f is not None:\n            gamma_f = torch.sigmoid(gamma_f) * self.weight_scale + (1 - self.weight_scale)/2\n        else:\n            gamma_f = torch.ones((n_cubes), dtype=torch.float, device=self.device)\n\n        return beta_fx12[surf_cubes], alpha_fx8[surf_cubes], gamma_f[surf_cubes]\n\n    @torch.no_grad()\n    def _get_case_id(self, occ_fx8, surf_cubes, res):\n        \"\"\"\n        Obtains the ID of topology cases based on cell corner occupancy. This function resolves the \n        ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the \n        supplementary material. It should be noted that this function assumes a regular grid.\n        \"\"\"\n        case_ids = (occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0)).sum(-1)\n\n        problem_config = self.check_table.to(self.device)[case_ids]\n        to_check = problem_config[..., 0] == 1\n        problem_config = problem_config[to_check]\n        if not isinstance(res, (list, tuple)):\n            res = [res, res, res]\n\n        # The 'problematic_configs' only contain configurations for surface cubes. 
Next, we construct a 3D array,\n        # 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes).\n        # This allows efficient checking on adjacent cubes.\n        problem_config_full = torch.zeros(list(res) + [5], device=self.device, dtype=torch.long)\n        vol_idx = torch.nonzero(problem_config_full[..., 0] == 0)  # N, 3\n        vol_idx_problem = vol_idx[surf_cubes][to_check]\n        problem_config_full[vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2]] = problem_config\n        vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4]\n\n        within_range = (\n            vol_idx_problem_adj[..., 0] >= 0) & (\n            vol_idx_problem_adj[..., 0] < res[0]) & (\n            vol_idx_problem_adj[..., 1] >= 0) & (\n            vol_idx_problem_adj[..., 1] < res[1]) & (\n            vol_idx_problem_adj[..., 2] >= 0) & (\n            vol_idx_problem_adj[..., 2] < res[2])\n\n        vol_idx_problem = vol_idx_problem[within_range]\n        vol_idx_problem_adj = vol_idx_problem_adj[within_range]\n        problem_config = problem_config[within_range]\n        problem_config_adj = problem_config_full[vol_idx_problem_adj[..., 0],\n                                                 vol_idx_problem_adj[..., 1], vol_idx_problem_adj[..., 2]]\n        # If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted.\n        to_invert = (problem_config_adj[..., 0] == 1)\n        idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][within_range][to_invert]\n        case_ids.index_put_((idx,), problem_config[to_invert][..., -1])\n        return case_ids\n\n    @torch.no_grad()\n    def _identify_surf_edges(self, s_n, cube_fx8, surf_cubes):\n        \"\"\"\n        Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge \n        can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge \n        and marks the cube edges with this index.\n        \"\"\"\n        occ_n = s_n < 0\n        all_edges = cube_fx8[surf_cubes][:, self.cube_edges].reshape(-1, 2)\n        unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True)\n\n        unique_edges = unique_edges.long()\n        mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1\n\n        surf_edges_mask = mask_edges[_idx_map]\n        counts = counts[_idx_map]\n\n        mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=cube_fx8.device) * -1\n        mapping[mask_edges] = torch.arange(mask_edges.sum(), device=cube_fx8.device)\n        # Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index\n        # for a surface-intersecting edge. 
Non-surface-intersecting edges are marked with -1.\n        idx_map = mapping[_idx_map]\n        surf_edges = unique_edges[mask_edges]\n        return surf_edges, idx_map, counts, surf_edges_mask\n\n    @torch.no_grad()\n    def _identify_surf_cubes(self, s_n, cube_fx8):\n        \"\"\"\n        Identifies grid cubes that intersect with the underlying surface by checking if the signs at \n        all corners are not identical.\n        \"\"\"\n        occ_n = s_n < 0\n        occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8)\n        _occ_sum = torch.sum(occ_fx8, -1)\n        surf_cubes = (_occ_sum > 0) & (_occ_sum < 8)\n        return surf_cubes, occ_fx8\n\n    def _linear_interp(self, edges_weight, edges_x):\n        \"\"\"\n        Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'.\n        \"\"\"\n        edge_dim = edges_weight.dim() - 2\n        assert edges_weight.shape[edge_dim] == 2\n        edges_weight = torch.cat([torch.index_select(input=edges_weight, index=torch.tensor(1, device=self.device), dim=edge_dim), -\n                                 torch.index_select(input=edges_weight, index=torch.tensor(0, device=self.device), dim=edge_dim)], edge_dim)\n        denominator = edges_weight.sum(edge_dim)\n        ue = (edges_x * edges_weight).sum(edge_dim) / denominator\n        return ue\n\n    def _linear_interp_nonan(self, edges_weight, edges_x):\n        \"\"\"\n        Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'.\n        \"\"\"\n        edge_dim = edges_weight.dim() - 2\n        assert edges_weight.shape[edge_dim] == 2\n        edges_weight = torch.cat([torch.index_select(input=edges_weight, index=torch.tensor(1, device=self.device), dim=edge_dim), -\n                                 torch.index_select(input=edges_weight, index=torch.tensor(0, device=self.device), dim=edge_dim)], edge_dim)\n        denominator = edges_weight.sum(edge_dim, keepdim=True).expand(-1, 2, 1)\n        with torch.no_grad():\n            nonzero_mask = (denominator.abs() > 0)\n        scale = torch.zeros_like(edges_weight)\n        scale[nonzero_mask] = edges_weight[nonzero_mask] / denominator[nonzero_mask]\n        ue = (edges_x * scale).sum(edge_dim)\n        return ue\n\n    def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3=None):\n        p_bxnx3 = p_bxnx3.reshape(-1, 7, 3)\n        norm_bxnx3 = norm_bxnx3.reshape(-1, 7, 3)\n        c_bx3 = c_bx3.reshape(-1, 3)\n        A = norm_bxnx3\n        B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True)\n\n        A_reg = (torch.eye(3, device=p_bxnx3.device) * self.qef_reg_scale).unsqueeze(0).repeat(p_bxnx3.shape[0], 1, 1)\n        B_reg = (self.qef_reg_scale * c_bx3).unsqueeze(-1)\n        A = torch.cat([A, A_reg], 1)\n        B = torch.cat([B, B_reg], 1)\n        dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1)\n        return dual_verts\n\n    def _compute_vd(self, x_nx3, surf_cubes_fx8, surf_edges, s_n, nu_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func):\n        \"\"\"\n        Computes the location of dual vertices as described in Section 4.2\n        \"\"\"\n        alpha_nx12x2 = torch.index_select(input=alpha_fx8, index=self.cube_edges, dim=1).reshape(-1, 12, 2)\n        surf_edges_x = torch.index_select(input=x_nx3, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 3)\n        surf_edges_s = torch.index_select(input=s_n, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 1)\n        surf_edges_nu = 
torch.index_select(input=nu_n, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 1)\n        zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x)\n\n        idx_map = idx_map.reshape(-1, 12)\n        num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0)\n        edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = [], [], [], [], []\n\n        total_num_vd = 0\n        vd_idx_map = torch.zeros((case_ids.shape[0], 12), dtype=torch.long, device=self.device, requires_grad=False)\n        if grad_func is not None:\n            normals = torch.nn.functional.normalize(grad_func(zero_crossing), dim=-1)\n            vd = []\n        for num in torch.unique(num_vd):\n            cur_cubes = (num_vd == num)  # consider cubes with the same numbers of vd emitted (for batching)\n            curr_num_vd = cur_cubes.sum() * num\n            curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape(-1, num * 7)\n            curr_edge_group_to_vd = torch.arange(\n                curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7) + total_num_vd\n            total_num_vd += curr_num_vd\n            curr_edge_group_to_cube = torch.arange(idx_map.shape[0], device=self.device)[\n                cur_cubes].unsqueeze(-1).repeat(1, num * 7).reshape_as(curr_edge_group)\n\n            curr_mask = (curr_edge_group != -1)\n            edge_group.append(torch.masked_select(curr_edge_group, curr_mask))\n            edge_group_to_vd.append(torch.masked_select(curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask))\n            edge_group_to_cube.append(torch.masked_select(curr_edge_group_to_cube, curr_mask))\n            vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True))\n            vd_gamma.append(torch.masked_select(gamma_f, cur_cubes).unsqueeze(-1).repeat(1, num).reshape(-1))\n\n            if grad_func is not None:\n                with torch.no_grad():\n                    cube_e_verts_idx = idx_map[cur_cubes]\n                    curr_edge_group[~curr_mask] = 0\n\n                    verts_group_idx = torch.gather(input=cube_e_verts_idx, dim=1, index=curr_edge_group)\n                    verts_group_idx[verts_group_idx == -1] = 0\n                    verts_group_pos = torch.index_select(\n                        input=zero_crossing, index=verts_group_idx.reshape(-1), dim=0).reshape(-1, num.item(), 7, 3)\n                    v0 = x_nx3[surf_cubes_fx8[cur_cubes][:, 0]].reshape(-1, 1, 1, 3).repeat(1, num.item(), 1, 1)\n                    curr_mask = curr_mask.reshape(-1, num.item(), 7, 1)\n                    verts_centroid = (verts_group_pos * curr_mask).sum(2) / (curr_mask.sum(2))\n\n                    normals_bx7x3 = torch.index_select(input=normals, index=verts_group_idx.reshape(-1), dim=0).reshape(\n                        -1, num.item(), 7,\n                        3)\n                    curr_mask = curr_mask.squeeze(2)\n                    vd.append(self._solve_vd_QEF((verts_group_pos - v0) * curr_mask, normals_bx7x3 * curr_mask,\n                                                 verts_centroid - v0.squeeze(2)) + v0.reshape(-1, 3))\n        edge_group = torch.cat(edge_group)\n        edge_group_to_vd = torch.cat(edge_group_to_vd)\n        edge_group_to_cube = torch.cat(edge_group_to_cube)\n        vd_num_edges = torch.cat(vd_num_edges)\n        vd_gamma = torch.cat(vd_gamma)\n\n        if grad_func is not None:\n            vd = torch.cat(vd)\n            L_dev = torch.zeros([1], device=self.device)\n        
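# When grad_func is supplied, the dual vertices have already been solved above via the\n        # per-cube QEF and L_dev stays zero. Otherwise (branch below), each dual vertex is the\n        # beta-weighted average of its edges' interpolated zero-crossings, the interpolated mSDF\n        # values are carried along as nu_d, and the deviation loss L_dev is computed.\n        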
else:\n            vd = torch.zeros((total_num_vd, 3), device=self.device)\n            nu_d = torch.zeros((total_num_vd, 1), device=self.device)\n            beta_sum = torch.zeros((total_num_vd, 1), device=self.device)\n\n            idx_group = torch.gather(input=idx_map.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group)\n\n            x_group = torch.index_select(input=surf_edges_x, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 3)\n            s_group = torch.index_select(input=surf_edges_s, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 1)\n            nu_group = torch.index_select(input=surf_edges_nu, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 1)\n\n            zero_crossing_group = torch.index_select(\n                input=zero_crossing, index=idx_group.reshape(-1), dim=0).reshape(-1, 3)\n\n            alpha_group = torch.index_select(input=alpha_nx12x2.reshape(-1, 2), dim=0,\n                                             index=edge_group_to_cube * 12 + edge_group).reshape(-1, 2, 1)\n            interp_coeff_group = s_group * alpha_group\n            ue_group = self._linear_interp(interp_coeff_group, x_group)\n            nu_e_group = self._linear_interp(interp_coeff_group, nu_group)\n            nu_e_stopvgd_group = self._linear_interp(interp_coeff_group.detach(), nu_group)\n\n            beta_group = torch.gather(input=beta_fx12.reshape(-1), dim=0,\n                                      index=edge_group_to_cube * 12 + edge_group).reshape(-1, 1)\n            beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group)\n            vd = vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) / beta_sum\n            nu_d = nu_d.index_add_(0, index=edge_group_to_vd, source=nu_e_group * beta_group) / beta_sum\n            nu_d_stopvgd = nu_d.index_add_(0, index=edge_group_to_vd, source=nu_e_stopvgd_group * beta_group.detach()) / beta_sum.detach()\n            L_dev = self._compute_reg_loss(vd, zero_crossing_group, edge_group_to_vd, vd_num_edges)\n\n        v_idx = torch.arange(vd.shape[0], device=self.device)  # + total_num_vd\n\n        vd_idx_map = (vd_idx_map.reshape(-1)).scatter(dim=0, index=edge_group_to_cube *\n                                                      12 + edge_group, src=v_idx[edge_group_to_vd])\n\n        return vd, nu_d, nu_d_stopvgd, L_dev, vd_gamma, vd_idx_map\n\n    def _triangulate(self, s_n, surf_edges, vd, nu_d, nu_d_stopvgd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func):\n        \"\"\"\n        Connects four neighboring dual vertices to form a quadrilateral. 
The quadrilaterals are then split into \n        triangles based on the gamma parameter, as described in Section 4.3.\n        \"\"\"\n        with torch.no_grad():\n            group_mask = (edge_counts == 4) & surf_edges_mask  # surface edges shared by 4 cubes.\n            group = idx_map.reshape(-1)[group_mask]\n            vd_idx = vd_idx_map[group_mask]\n            edge_indices, indices = torch.sort(group, stable=True)\n            quad_vd_idx = vd_idx[indices].reshape(-1, 4)\n\n            # Ensure all face directions point towards the positive SDF to maintain consistent winding.\n            s_edges = s_n[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)].reshape(-1, 2)\n            flip_mask = s_edges[:, 0] > 0\n            quad_vd_idx = torch.cat((quad_vd_idx[flip_mask][:, [0, 1, 3, 2]],\n                                     quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]]))\n        if grad_func is not None:\n            # when grad_func is given, split quadrilaterals along the diagonals with more consistent gradients.\n            with torch.no_grad():\n                vd_gamma = torch.nn.functional.normalize(grad_func(vd), dim=-1)\n                quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3)\n                gamma_02 = (quad_gamma[:, 0] * quad_gamma[:, 2]).sum(-1, keepdims=True)\n                gamma_13 = (quad_gamma[:, 1] * quad_gamma[:, 3]).sum(-1, keepdims=True)\n        else:\n            quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4)\n            gamma_02 = torch.index_select(input=quad_gamma, index=torch.tensor(\n                0, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(2, device=self.device), dim=1)\n            gamma_13 = torch.index_select(input=quad_gamma, index=torch.tensor(\n                1, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(3, device=self.device), dim=1)\n        if not training:\n            mask = (gamma_02 > gamma_13).squeeze(1)\n            faces = torch.zeros((quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device)\n            faces[mask] = quad_vd_idx[mask][:, self.quad_split_1]\n            faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2]\n            faces = faces.reshape(-1, 3)\n        else:\n            vd_quad = torch.index_select(input=vd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3)\n            nu_d_quad = torch.index_select(input=nu_d, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 1)\n            nu_d_stopvgd_quad = torch.index_select(input=nu_d_stopvgd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 1)\n            vd_02 = (torch.index_select(input=vd_quad, index=torch.tensor(0, device=self.device), dim=1) +\n                     torch.index_select(input=vd_quad, index=torch.tensor(2, device=self.device), dim=1)) / 2\n            nu_d_02 = (torch.index_select(input=nu_d_quad, index=torch.tensor(0, device=self.device), dim=1) +\n                     torch.index_select(input=nu_d_quad, index=torch.tensor(2, device=self.device), dim=1)) / 2\n            nu_d_stopvgd_02 = (torch.index_select(input=nu_d_stopvgd_quad, index=torch.tensor(0, device=self.device), dim=1) +\n                     torch.index_select(input=nu_d_stopvgd_quad, index=torch.tensor(2, device=self.device), dim=1)) / 2\n            vd_13 = (torch.index_select(input=vd_quad, index=torch.tensor(1, device=self.device), dim=1) +\n  
                   torch.index_select(input=vd_quad, index=torch.tensor(3, device=self.device), dim=1)) / 2\n            nu_d_13 = (torch.index_select(input=nu_d_quad, index=torch.tensor(1, device=self.device), dim=1) +\n                     torch.index_select(input=nu_d_quad, index=torch.tensor(3, device=self.device), dim=1)) / 2\n            nu_d_stopvgd_13 = (torch.index_select(input=nu_d_stopvgd_quad, index=torch.tensor(1, device=self.device), dim=1) +\n                     torch.index_select(input=nu_d_stopvgd_quad, index=torch.tensor(3, device=self.device), dim=1)) / 2\n            weight_sum = (gamma_02 + gamma_13) + 1e-8\n            vd_center = ((vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) /\n                         weight_sum.unsqueeze(-1)).squeeze(1)\n            nu_d_center = ((nu_d_02 * gamma_02.unsqueeze(-1) + nu_d_13 * gamma_13.unsqueeze(-1)) /\n                         weight_sum.unsqueeze(-1)).squeeze(1)\n            nu_d_stopvgd_center = ((nu_d_stopvgd_02 * gamma_02.unsqueeze(-1).detach() + nu_d_stopvgd_13 * gamma_13.unsqueeze(-1).detach()) /\n                         weight_sum.unsqueeze(-1).detach()).squeeze(1)\n            vd_center_idx = torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0]\n            vd = torch.cat([vd, vd_center])\n            nu_d = torch.cat([nu_d, nu_d_center])\n            nu_d_stopvgd = torch.cat([nu_d_stopvgd, nu_d_stopvgd_center])\n            faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2)\n            faces = torch.cat([faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1).reshape(-1, 3)\n        return vd, nu_d, nu_d_stopvgd, faces, s_edges, edge_indices\n    \n    def _triangulate_msdf(self, vertices, faces, nu_n, nu_n_stopvgd):\n        with torch.no_grad():\n            mocc_n = nu_n >= 0\n            mocc_fx3 = mocc_n[faces.reshape(-1)].reshape(-1,3)\n            mocc_sum = torch.sum(mocc_fx3, -1)\n        \n            \n            uncut_faces_mask = (mocc_sum == 3)\n            cut_faces_mask = (mocc_sum < 3) & (mocc_sum > 0)\n            uncut_faces = faces[uncut_faces_mask]\n            cut_faces = faces[cut_faces_mask]\n        \n        if uncut_faces.size(0) == 0:\n            return vertices, faces, nu_n, nu_n[:1].detach() * 0.0\n\n        vertices_cut_edges_fx2 = vertices[cut_faces[:, [0,1,1,2,2,0]].view(-1)].view(-1, 2, 3)\n        nu_cut_edges_fx2 = nu_n[cut_faces[:, [0,1,1,2,2,0]].view(-1)].view(-1, 2, 1)\n        nu_cut_edges_fx2_stopvgd = nu_n_stopvgd[cut_faces[:, [0,1,1,2,2,0]].view(-1)].view(-1, 2, 1)\n        assert vertices_cut_edges_fx2.size(0) == cut_faces.size(0) * 3 ### DEBUG\n        msdf_zero_crossing = self._linear_interp_nonan(nu_cut_edges_fx2, vertices_cut_edges_fx2)\n        nu_boundary_stopvgd = self._linear_interp_nonan(nu_cut_edges_fx2_stopvgd.detach(), nu_cut_edges_fx2_stopvgd)\n\n        vertices_open = torch.cat([vertices, msdf_zero_crossing], dim=0)\n        nus_open_stopvgd = torch.cat([nu_n_stopvgd, nu_boundary_stopvgd], dim=0)\n\n        with torch.no_grad():\n            v_id = torch.flip(torch.pow(2, torch.arange(3, dtype=torch.long, device=\"cuda\")), dims=[0]) ## do this flip because the triangle table uses a different assumption by mistake..\n            configuration_idx = (mocc_fx3[cut_faces_mask] * v_id.unsqueeze(0)).sum(-1)\n\n        idx_map = torch.cat([cut_faces, vertices.size(0) + torch.arange(cut_faces.size(0) * 3, device='cuda').view(-1, 3)], dim=-1)\n        num_triangles = 
self.gflex_num_triangles_table[configuration_idx]\n        faces_open = torch.cat([\n            uncut_faces,\n            torch.gather(input=idx_map[num_triangles == 1], dim=1, index=self.gflex_configuration_table[configuration_idx[num_triangles == 1]][:, :3]).view(-1, 3),\n            torch.gather(input=idx_map[num_triangles == 2], dim=1, index=self.gflex_configuration_table[configuration_idx[num_triangles == 2]][:, :6]).view(-1, 3),\n        ])\n\n        return vertices_open, faces_open, nus_open_stopvgd, nu_boundary_stopvgd\n\n    def _tetrahedralize(\n            self, x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices,\n            surf_cubes, training):\n        \"\"\"\n        Tetrahedralizes the interior volume to produce a tetrahedral mesh, as described in Section 4.5.\n        \"\"\"\n        occ_n = s_n < 0\n        occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8)\n        occ_sum = torch.sum(occ_fx8, -1)\n\n        inside_verts = x_nx3[occ_n]\n        mapping_inside_verts = torch.ones((occ_n.shape[0]), dtype=torch.long, device=self.device) * -1\n        mapping_inside_verts[occ_n] = torch.arange(occ_n.sum(), device=self.device) + vertices.shape[0]\n        \"\"\" \n        For each grid edge connecting two grid vertices with different\n        signs, we first form a four-sided pyramid by connecting one\n        of the grid vertices with four mesh vertices that correspond\n        to the grid edge and then subdivide the pyramid into two tetrahedra\n        \"\"\"\n        inside_verts_idx = mapping_inside_verts[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1, 2)[\n            s_edges < 0]]\n        if not training:\n            inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 2).reshape(-1)\n        else:\n            inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 4).reshape(-1)\n\n        tets_surface = torch.cat([faces, inside_verts_idx.unsqueeze(-1)], -1)\n        \"\"\" \n        For each grid edge connecting two grid vertices with the\n        same sign, the tetrahedron is formed by the two grid vertices\n        and two vertices in consecutive adjacent cells\n        \"\"\"\n        inside_cubes = (occ_sum == 8)\n        inside_cubes_center = x_nx3[cube_fx8[inside_cubes].reshape(-1)].reshape(-1, 8, 3).mean(1)\n        inside_cubes_center_idx = torch.arange(\n            inside_cubes_center.shape[0], device=inside_cubes.device) + vertices.shape[0] + inside_verts.shape[0]\n\n        surface_n_inside_cubes = surf_cubes | inside_cubes\n        edge_center_vertex_idx = torch.ones(((surface_n_inside_cubes).sum(), 13),\n                                            dtype=torch.long, device=x_nx3.device) * -1\n        surf_cubes = surf_cubes[surface_n_inside_cubes]\n        inside_cubes = inside_cubes[surface_n_inside_cubes]\n        edge_center_vertex_idx[surf_cubes, :12] = vd_idx_map.reshape(-1, 12)\n        edge_center_vertex_idx[inside_cubes, 12] = inside_cubes_center_idx\n\n        all_edges = cube_fx8[surface_n_inside_cubes][:, self.cube_edges].reshape(-1, 2)\n        unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True)\n        unique_edges = unique_edges.long()\n        mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 2\n        mask = mask_edges[_idx_map]\n        counts = counts[_idx_map]\n        mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=self.device) * -1\n        mapping[mask_edges] = 
torch.arange(mask_edges.sum(), device=self.device)\n        idx_map = mapping[_idx_map]\n\n        group_mask = (counts == 4) & mask\n        group = idx_map.reshape(-1)[group_mask]\n        edge_indices, indices = torch.sort(group)\n        cube_idx = torch.arange((_idx_map.shape[0] // 12), dtype=torch.long,\n                                device=self.device).unsqueeze(1).expand(-1, 12).reshape(-1)[group_mask]\n        edge_idx = torch.arange((12), dtype=torch.long, device=self.device).unsqueeze(\n            0).expand(_idx_map.shape[0] // 12, -1).reshape(-1)[group_mask]\n        # Identify the face shared by the adjacent cells.\n        cube_idx_4 = cube_idx[indices].reshape(-1, 4)\n        edge_dir = self.edge_dir_table[edge_idx[indices]].reshape(-1, 4)[..., 0]\n        shared_faces_4x2 = self.dir_faces_table[edge_dir].reshape(-1)\n        cube_idx_4x2 = cube_idx_4[:, self.adj_pairs].reshape(-1)\n        # Identify an edge of the face with different signs and\n        # select the mesh vertex corresponding to the identified edge.\n        case_ids_expand = torch.ones((surface_n_inside_cubes).sum(), dtype=torch.long, device=x_nx3.device) * 255\n        case_ids_expand[surf_cubes] = case_ids\n        cases = case_ids_expand[cube_idx_4x2]\n        quad_edge = edge_center_vertex_idx[cube_idx_4x2, self.tet_table[cases, shared_faces_4x2]].reshape(-1, 2)\n        mask = (quad_edge == -1).sum(-1) == 0\n        inside_edge = mapping_inside_verts[unique_edges[mask_edges][edge_indices].reshape(-1)].reshape(-1, 2)\n        tets_inside = torch.cat([quad_edge, inside_edge], -1)[mask]\n\n        tets = torch.cat([tets_surface, tets_inside])\n        vertices = torch.cat([vertices, inside_verts, inside_cubes_center])\n        return vertices, tets"
  },
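  {
    "path": "examples/sketch_gshell_flexicubes_extraction.py",
    "content": "### Illustrative sketch, not part of the original repository: a minimal extraction call for\n### GShellFlexiCubes, mirroring how geometry/gshell_flexicubes_geometry.py uses it\n### (construct_voxel_grid plus the 12/8/1 split of the 21 per-cube weights). The toy sphere\n### SDF, the all-positive mSDF and this file itself are made-up; only the call signatures\n### are taken from that geometry class.\nimport torch\n\nfrom geometry.gshell_flexicubes import GShellFlexiCubes\n\n\ndef extract_toy_sphere(grid_res=32):\n    gflexicubes = GShellFlexiCubes()\n    grid_verts, cube_indices = gflexicubes.construct_voxel_grid(grid_res)\n    grid_verts = grid_verts.cuda()\n    cube_indices = cube_indices.cuda()\n\n    # Toy signed distance: a sphere of radius 0.3; an all-positive mSDF keeps the surface closed.\n    sdf = grid_verts.norm(dim=-1) - 0.3\n    msdf = torch.ones_like(sdf)\n\n    # 21 weights per cube: 12 edge weights (beta), 8 corner weights (alpha), 1 gamma.\n    weights = torch.ones((cube_indices.shape[0], 21), dtype=torch.float, device='cuda')\n\n    verts, faces, reg_loss, extra = gflexicubes(\n        grid_verts, sdf, msdf, cube_indices, grid_res,\n        weights[:, :12], weights[:, 12:20], weights[:, 20],\n        training=False)\n    return verts, faces, extra\n\n\nif __name__ == '__main__':\n    verts, faces, extra = extract_toy_sphere()\n    print(verts.shape, faces.shape, extra['n_verts_watertight'])\n"
  },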
  {
    "path": "geometry/gshell_flexicubes_geometry.py",
    "content": "# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.\n#\n# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited\n\nimport os\nfrom tqdm import trange\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\n\nfrom render import mesh\nfrom render import render\nimport render.optixutils as ou\nfrom render import regularizer\n\nfrom .gshell_flexicubes import GShellFlexiCubes\nfrom render import util\n\nimport kaolin\n\nfrom .mlp import MLP\n\n\n###############################################################################\n# Regularizer\n###############################################################################\n\ndef compute_sdf_reg_loss(sdf, all_edges):\n    sdf_f1x6x2 = sdf[all_edges.reshape(-1)].reshape(-1,2)\n    mask = torch.sign(sdf_f1x6x2[...,0]) != torch.sign(sdf_f1x6x2[...,1])\n    sdf_f1x6x2 = sdf_f1x6x2[mask]\n    sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[...,0], (sdf_f1x6x2[...,1] > 0).float()) + \\\n            torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[...,1], (sdf_f1x6x2[...,0] > 0).float())\n    return sdf_diff\n\n###############################################################################\n#  Geometry interface\n###############################################################################\n\nclass GShellFlexiCubesGeometry(torch.nn.Module):\n    def __init__(self, grid_res, scale, FLAGS):\n        super(GShellFlexiCubesGeometry, self).__init__()\n\n        self.FLAGS         = FLAGS\n        self.grid_res      = grid_res\n        self.gflexicubes    = GShellFlexiCubes()\n        verts, indices     = self.gflexicubes.construct_voxel_grid(grid_res) \n        self.boxscale      = torch.tensor(FLAGS.boxscale).view(1, 3).cuda()\n\n        with torch.no_grad():\n            self.optix_ctx = ou.OptiXContext()\n\n        n_cubes = indices.shape[0]\n        per_cube_weights = torch.ones((n_cubes, 21),dtype=torch.float,device='cuda')\n\n        self.verts    = verts * scale * self.boxscale\n        self.indices  = indices\n        print(\"FlexiCubes grid min/max\", torch.min(self.verts).item(), torch.max(self.verts).item())\n        self.generate_edges()\n\n        if self.FLAGS.use_sdf_mlp:\n            self.sdf    = torch.nn.Parameter(torch.zeros_like(self.verts[:, 0]), requires_grad=True) ## placeholder\n            self.register_parameter('sdf', self.sdf)\n            self.sdf_net = MLP(\n                skip_in=self.FLAGS.skip_in,\n                n_freq=self.FLAGS.n_freq,\n                n_hidden=self.FLAGS.n_hidden,\n                d_hidden=self.FLAGS.d_hidden,\n                use_float16=self.FLAGS.use_float16\n            )\n            self.sdf_net.cuda()\n\n            optimizer = torch.optim.Adam(self.sdf_net.parameters(), lr=1e-3)\n            for _ in trange(self.FLAGS.sdf_mlp_pretrain_steps):\n                scaled_verts = self.verts / self.boxscale\n                loss = (self.sdf_net(self.verts) - (scaled_verts.norm(dim=1, keepdim=True) - self.FLAGS.sphere_init_norm)).pow(2).mean()\n                optimizer.zero_grad()\n                loss.backward()\n                optimizer.step()\n            print('sdf net trained 
with loss:', loss)\n\n        else:\n            # Random init\n            if not self.FLAGS.sphere_init:\n                sdf = torch.rand_like(self.verts[:,0]) - 0.1\n            else:\n                scaled_verts = self.verts / self.boxscale\n                sdf = scaled_verts.norm(dim=1) - 0.5\n            self.sdf    = torch.nn.Parameter(sdf.clone().detach(), requires_grad=True)\n            self.register_parameter('sdf', self.sdf)\n        self.per_cube_weights = torch.nn.Parameter(torch.ones_like(per_cube_weights), requires_grad=True)\n        self.register_parameter('weight', self.per_cube_weights)\n\n\n\n        msdf         = (torch.rand_like(self.verts[:,0]) - 0.01).clamp(-1, 1)\n        self.msdf    = torch.nn.Parameter(msdf.clone().detach(), requires_grad=True)\n        self.register_parameter('msdf', self.msdf)\n\n        self.deform = torch.nn.Parameter(torch.zeros_like(self.verts), requires_grad=True)\n        self.register_parameter('deform', self.deform)\n\n        self.clamp_deform()\n\n    @torch.no_grad()\n    def generate_edges(self):\n        with torch.no_grad():\n            edges = self.gflexicubes.cube_edges\n            all_edges = self.indices[:,edges].reshape(-1,2)\n            all_edges_sorted = torch.sort(all_edges, dim=1)[0]\n            self.all_edges = torch.unique(all_edges_sorted, dim=0)\n            self.max_displacement = util.length(self.verts[self.all_edges[:, 0]] - self.verts[self.all_edges[:, 1]]).mean() / 4\n\n    @torch.no_grad()\n    def getAABB(self):\n        return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values\n\n    @torch.no_grad()\n    def clamp_deform(self):\n        if not self.FLAGS.use_tanh_deform:\n            self.deform.data[:] = self.deform.clamp(-1.0, 1.0)\n        self.msdf.data[:] = self.msdf.clamp(-2.0, 2.0)\n\n    @torch.no_grad()\n    def map_uv2(self, faces):\n        uvs = torch.tensor([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]], dtype=torch.float, device='cuda')\n        uv_idx = torch.tensor([0,1,2], dtype=torch.long, device='cuda').repeat(faces.shape[0],1)\n        return uvs, uv_idx\n\n    @torch.no_grad()\n    def map_uv(self, face_gidx, max_idx):\n        N = int(np.ceil(np.sqrt((max_idx+1)//2)))\n        tex_y, tex_x = torch.meshgrid(\n            torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device=\"cuda\"),\n            torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device=\"cuda\")\n        )\n\n        pad = 0.9 / N\n\n        uvs = torch.stack([\n            tex_x      , tex_y,\n            tex_x + pad, tex_y,\n            tex_x + pad, tex_y + pad,\n            tex_x      , tex_y + pad\n        ], dim=-1).view(-1, 2)\n\n        def _idx(tet_idx, N):\n            x = tet_idx % N\n            y = torch.div(tet_idx, N, rounding_mode='floor')\n            return y * N + x\n\n        tet_idx = _idx(torch.div(face_gidx, N, rounding_mode='floor'), N)\n        tri_idx = face_gidx % 2\n\n        uv_idx = torch.stack((\n            tet_idx * 4, tet_idx * 4 + tri_idx + 1, tet_idx * 4 + tri_idx + 2\n        ), dim = -1). 
view(-1, 3)\n\n        return uvs, uv_idx\n\n    def getMesh(self, material, _training=False):\n        v_deformed = self.verts + self.max_displacement * self.deform\n        if self.FLAGS.use_sdf_mlp:\n            sdf = self.sdf_net(v_deformed)\n        else:\n            sdf = self.sdf\n\n\n        if self.FLAGS.use_msdf_mlp:\n            msdf = self.msdf_net(v_deformed)\n        else:\n            msdf = self.msdf\n\n        verts, faces, reg_loss, extra = self.gflexicubes(v_deformed, sdf, msdf, self.indices, self.grid_res, \n                            self.per_cube_weights[:,:12], self.per_cube_weights[:,12:20], self.per_cube_weights[:,20],\n                            training=_training)\n\n        self.gflexi_reg_loss = reg_loss.mean()\n\n        face_gidx = torch.arange(faces.shape[0], dtype=torch.long, device=\"cuda\")\n        uvs, uv_idx = self.map_uv(face_gidx, faces.shape[0])\n\n        imesh = mesh.Mesh(verts, faces, v_tex=uvs, t_tex_idx=uv_idx, material=material)\n\n        with torch.no_grad():\n            ou.optix_build_bvh(self.optix_ctx, imesh.v_pos.contiguous(), imesh.t_pos_idx.int(), rebuild=1)\n\n        # Run mesh operations to generate tangent space\n        imesh = mesh.auto_normals(imesh)\n        return_dict = {\n            'imesh': imesh,\n            'sdf': sdf,\n            'msdf': extra['msdf'],\n            'msdf_watertight': extra['msdf_watertight'],\n            'msdf_boundary': extra['msdf_boundary'],\n            'n_verts_watertight': extra['n_verts_watertight'],\n        }\n    \n        if self.FLAGS.visualize_watertight:\n            imesh_watertight = mesh.Mesh(extra['vertices_watertight'], extra['faces_watertight'], v_tex=None, t_tex_idx=None, material=material)\n            imesh_watertight = mesh.auto_normals(imesh_watertight)\n            return_dict['imesh_watertight'] = imesh_watertight\n        return return_dict\n\n    def render(self, glctx, target, lgt, opt_material, bsdf=None, denoiser=None, shadow_scale=1.0,\n            use_uv=False, training=False):\n        opt_mesh_dict = self.getMesh(opt_material)\n        opt_mesh = opt_mesh_dict['imesh']\n        opt_mesh_watertight = opt_mesh_dict['imesh_watertight'] if 'imesh_watertight' in opt_mesh_dict else None\n        if opt_mesh.v_pos.size(0) != 0:\n            sampled_pts = kaolin.ops.mesh.sample_points(opt_mesh.v_pos[None,...], opt_mesh.t_pos_idx, 50000)[0][0]\n            opt_mesh_dict['sampled_pts'] = sampled_pts\n        else:\n            opt_mesh_dict['sampled_pts'] = None\n\n        extra_dict = {\n            'msdf': opt_mesh_dict['msdf'],\n        }\n        opt_mesh_dict['buffers'] = render.render_mesh(\n            self.FLAGS, glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'], spp=target['spp'], \n            msaa=True, background=target['background'], bsdf=bsdf, use_uv=use_uv,\n            optix_ctx=self.optix_ctx, denoiser=denoiser, shadow_scale=shadow_scale,\n            extra_dict=extra_dict)\n        if self.FLAGS.visualize_watertight:\n            opt_mesh_dict['buffers_watertight'] = render.render_mesh(\n                self.FLAGS, glctx, opt_mesh_watertight, target['mvp'], target['campos'], lgt, target['resolution'], spp=target['spp'], \n                msaa=True, background=target['background'], bsdf=bsdf, use_uv=use_uv,\n                optix_ctx=self.optix_ctx, denoiser=denoiser, shadow_scale=shadow_scale,\n                extra_dict=extra_dict)\n        return opt_mesh_dict\n\n    def tick(self, glctx, target, lgt, opt_material, loss_fn, 
iteration, denoiser):\n\n        t_iter = iteration / self.FLAGS.iter\n\n        # ==============================================================================================\n        #  Render optimizable object with identical conditions\n        # ==============================================================================================\n        shadow_ramp = min(iteration / 1000, 1.0)\n        if denoiser is not None: denoiser.set_influence(shadow_ramp)\n        opt_mesh_dict = self.render(glctx, target, lgt, opt_material, \n            denoiser=denoiser,\n            shadow_scale=shadow_ramp, \n            training=True)\n        buffers = opt_mesh_dict['buffers']\n\n        # ==============================================================================================\n        #  Compute loss\n        # ==============================================================================================\n\n        with torch.no_grad():\n            # Image-space loss, split into a coverage component and a color component\n            color_ref = target['img']\n            gt_mask = color_ref[..., 3:]\n\n        img_loss = F.mse_loss(buffers['shaded'][..., 3:], color_ref[..., 3:]) \n        img_loss = img_loss + loss_fn(buffers['shaded'][..., 0:3] * color_ref[..., 3:], color_ref[..., 0:3] * color_ref[..., 3:])\n\n\n        img_loss = img_loss + 5e-1 * F.l1_loss(buffers['msdf_image'].clamp(min=0) * (gt_mask == 0).float(), torch.zeros_like(gt_mask))\n        img_loss = img_loss + 5e-1 * F.l1_loss(buffers['msdf_image'].clamp(max=0) * (gt_mask == 1).float(), torch.ones_like(gt_mask))\n\n\n        if self.FLAGS.use_img_2nd_layer:\n            color_ref_2nd = target['img_second']\n            img_loss = img_loss + F.mse_loss(buffers['shaded_second'][..., 3:], color_ref_2nd[..., 3:]) \n            img_loss = img_loss + loss_fn(buffers['shaded_second'][..., 0:3] * color_ref_2nd[..., 3:], color_ref_2nd[..., 0:3] * color_ref_2nd[..., 3:])\n\n        if self.FLAGS.use_depth:\n            depth_loss_scale = 100.\n            depth_loss = depth_loss_scale * ((buffers['invdepth'][:, :, :, :1] - target['invdepth'][:, :, :, :1]).abs()).mean()\n\n            if self.FLAGS.use_depth_2nd_layer:\n                depth_loss += 0.1 * depth_loss_scale * ((buffers['invdepth_second'][:, :, :, :1] - target['invdepth_second'][:, :, :, :1]).abs()).mean()\n        else:\n            depth_loss = torch.tensor(0., device=img_loss.device)\n\n        # Eikonal\n        if self.FLAGS.use_sdf_mlp and self.FLAGS.use_eikonal and opt_mesh_dict['sampled_pts'] is not None:\n            v = opt_mesh_dict['sampled_pts'].detach()\n            v.requires_grad = True\n\n            sdf_eik = self.sdf_net(v)\n            if self.FLAGS.eikonal_scale is None:\n                ### Default hardcoded Eikonal loss schedule\n                if iteration < 500:\n                    eik_coeff = 3e-1\n                elif iteration < 1000:\n                    eik_coeff = 1e-1\n                elif iteration < 2000:\n                    eik_coeff = 1e-1\n                else:\n                    eik_coeff = 1e-2\n            else:\n                eik_coeff = self.FLAGS.eikonal_scale\n\n            eik_loss = eik_coeff * (\n                torch.autograd.grad(sdf_eik.sum(), v, create_graph=True)[0].pow(2).sum(dim=-1).sqrt() - 1\n            ).pow(2).mean()\n        else:\n            eik_loss = torch.tensor(0., device=img_loss.device)\n\n        if self.FLAGS.use_mesh_msdf_reg:\n            mesh_msdf_regscale = (64 / self.grid_res) ** 3 # 
scale inversely proportional to grid_res^3\n            eps = 1e-3\n            open_scale = self.FLAGS.msdf_reg_open_scale\n            close_scale = self.FLAGS.msdf_reg_close_scale\n            eps = torch.tensor([eps]).cuda()\n            mesh_msdf_reg_loss = open_scale * mesh_msdf_regscale * F.huber_loss(\n                opt_mesh_dict['msdf'].clamp(min=-eps).squeeze(), \n                -eps.expand(opt_mesh_dict['msdf'].size(0)), \n                reduction='sum'\n            )\n\n            if close_scale != 0:\n                with torch.no_grad():\n                    visible_verts = (opt_mesh_dict['imesh'].t_pos_idx[buffers['visible_triangles']]).unique()\n                    visible_boundary_verts = visible_verts[visible_verts >= opt_mesh_dict['n_verts_watertight']] - opt_mesh_dict['n_verts_watertight']\n                    visible_boundary_mask = torch.zeros(opt_mesh_dict['msdf_boundary'].size(0)).cuda()\n                    visible_boundary_mask[visible_boundary_verts] = 1\n                    visible_boundary_mask = visible_boundary_mask.bool()\n\n                boundary_msdf = opt_mesh_dict['msdf_boundary']\n                boundary_msdf = boundary_msdf[visible_boundary_mask]\n                mesh_msdf_reg_loss += close_scale * mesh_msdf_regscale * F.huber_loss(\n                    boundary_msdf.clamp(max=eps).squeeze(), \n                    eps.expand(boundary_msdf.size(0)), \n                    reduction='sum'\n                )\n        else:\n            mesh_msdf_reg_loss = torch.tensor(0., device=img_loss.device)\n\n        # SDF regularizer\n        sdf_weight = self.FLAGS.sdf_regularizer - (self.FLAGS.sdf_regularizer - 0.01) * min(1.0, 4.0 * t_iter)\n        sdf_reg_loss = compute_sdf_reg_loss(opt_mesh_dict['sdf'], self.all_edges).mean() * sdf_weight\n\n        # Monochrome shading regularizer\n        if 'diffuse_light' not in buffers:\n            monochrome_loss = torch.zeros_like(img_loss)\n        else:\n            monochrome_loss = regularizer.shading_loss(buffers['diffuse_light'], buffers['specular_light'], color_ref, self.FLAGS.lambda_diffuse, self.FLAGS.lambda_specular)\n\n        # Material smoothness regularizer\n        mtl_smooth_loss = regularizer.material_smoothness_grad(\n            buffers['kd_grad'], buffers['ks_grad'], buffers['normal_grad'], \n            lambda_kd=self.FLAGS.lambda_kd, lambda_ks=self.FLAGS.lambda_ks, lambda_nrm=self.FLAGS.lambda_nrm)\n\n        # Chroma regularizer\n        chroma_loss = regularizer.chroma_loss(buffers['kd'], color_ref, self.FLAGS.lambda_chroma)\n        assert 'perturbed_nrm' not in buffers # disable normal map in first pass\n\n        # FlexiCubes reg loss\n        flexicube_reg_loss = self.gflexi_reg_loss * 0.25\n\n        geo_reg_loss = sdf_reg_loss + eik_loss + mesh_msdf_reg_loss + flexicube_reg_loss\n        shading_reg_loss =  monochrome_loss + mtl_smooth_loss + chroma_loss\n        reg_loss = geo_reg_loss + shading_reg_loss\n\n        return img_loss, depth_loss, reg_loss\n"
  },
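  {
    "path": "examples/sketch_eikonal_regularizer.py",
    "content": "### Illustrative sketch, not part of the original repository: it isolates the Eikonal\n### regularizer and the default hard-coded coefficient schedule used in\n### GShellFlexiCubesGeometry.tick (geometry/gshell_flexicubes_geometry.py) so they can be\n### read or tested on their own. 'toy_sdf' below is a made-up stand-in for the repository\n### SDF MLP; only the loss form and the schedule follow the original code.\nimport torch\n\n\ndef default_eik_coeff(iteration):\n    # Default schedule from tick(): 3e-1 before iteration 500, 1e-1 up to 2000\n    # (the original lists 500-1000 and 1000-2000 separately, both at 1e-1), then 1e-2.\n    if iteration < 500:\n        return 3e-1\n    elif iteration < 2000:\n        return 1e-1\n    return 1e-2\n\n\ndef eikonal_loss(sdf_net, pts, coeff):\n    # Penalize the deviation of the SDF gradient norm from 1 at the sampled points.\n    pts = pts.detach()\n    pts.requires_grad = True\n    sdf = sdf_net(pts)\n    grad = torch.autograd.grad(sdf.sum(), pts, create_graph=True)[0]\n    return coeff * (grad.pow(2).sum(dim=-1).sqrt() - 1).pow(2).mean()\n\n\nif __name__ == '__main__':\n    toy_sdf = torch.nn.Sequential(torch.nn.Linear(3, 64), torch.nn.Softplus(beta=100), torch.nn.Linear(64, 1))\n    pts = torch.rand(4096, 3) * 2 - 1\n    loss = eikonal_loss(toy_sdf, pts, default_eik_coeff(iteration=1500))\n    loss.backward()\n    print(float(loss))\n"
  },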
  {
    "path": "geometry/gshell_tets.py",
    "content": "import numpy as np\nimport torch\n\nfrom render import util\n\n######################################################################################\n# Simple smooth vertex normal computation\n######################################################################################\ndef auto_normals(v_pos, t_pos_idx):\n\n    i0 = t_pos_idx[:, 0]\n    i1 = t_pos_idx[:, 1]\n    i2 = t_pos_idx[:, 2]\n\n    v0 = v_pos[i0, :]\n    v1 = v_pos[i1, :]\n    v2 = v_pos[i2, :]\n\n    face_normals = torch.cross(v1 - v0, v2 - v0)\n\n    # Splat face normals to vertices\n    v_nrm = torch.zeros_like(v_pos)\n    v_nrm.scatter_add_(0, i0[:, None].repeat(1,3), face_normals)\n    v_nrm.scatter_add_(0, i1[:, None].repeat(1,3), face_normals)\n    v_nrm.scatter_add_(0, i2[:, None].repeat(1,3), face_normals)\n\n    # Normalize, replace zero (degenerated) normals with some default value\n    v_nrm = torch.where(util.dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device='cuda'))\n    v_nrm = util.safe_normalize(v_nrm)\n\n    if torch.is_anomaly_enabled():\n        assert torch.all(torch.isfinite(v_nrm))\n\n    return v_nrm, t_pos_idx\n\n######################################################################################\n# Compute tangent space from texture map coordinates\n# Follows http://www.mikktspace.com/ conventions\n######################################################################################\ndef compute_tangents(v_pos, v_tex, v_nrm, t_pos_idx, t_tex_idx, t_nrm_idx):\n    vn_idx = [None] * 3\n    pos = [None] * 3\n    tex = [None] * 3\n    for i in range(0,3):\n        pos[i] = v_pos[t_pos_idx[:, i]]\n        tex[i] = v_tex[t_tex_idx[:, i]]\n        vn_idx[i] = t_nrm_idx[:, i]\n\n    tangents = torch.zeros_like(v_nrm)\n    tansum   = torch.zeros_like(v_nrm)\n\n    # Compute tangent space for each triangle\n    uve1 = tex[1] - tex[0]\n    uve2 = tex[2] - tex[0]\n    pe1  = pos[1] - pos[0]\n    pe2  = pos[2] - pos[0]\n    \n    nom   = (pe1 * uve2[..., 1:2] - pe2 * uve1[..., 1:2])\n    denom = (uve1[..., 0:1] * uve2[..., 1:2] - uve1[..., 1:2] * uve2[..., 0:1])\n    \n    # Avoid dimsdfion by zero for degenerated texture coordinates\n    tang = nom / torch.where(denom > 0.0, torch.clamp(denom, min=1e-6), torch.clamp(denom, max=-1e-6))\n\n    # Update all 3 vertices\n    for i in range(0,3):\n        idx = vn_idx[i][:, None].repeat(1,3)\n        tangents.scatter_add_(0, idx, tang)                # tangents[n_i] = tangents[n_i] + tang\n        tansum.scatter_add_(0, idx, torch.ones_like(tang)) # tansum[n_i] = tansum[n_i] + 1\n    tangents = tangents / tansum\n\n    # Normalize and make sure tangent is perpendicular to normal\n    tangents = util.safe_normalize(tangents)\n    tangents = util.safe_normalize(tangents - util.dot(tangents, v_nrm) * v_nrm)\n\n    if torch.is_anomaly_enabled():\n        assert torch.all(torch.isfinite(tangents))\n\n    return tangents, t_nrm_idx\n\nclass GShell_Tets:\n    def __init__(self):\n        self.triangle_table = torch.tensor([\n                [-1, -1, -1, -1, -1, -1],\n                [ 1,  0,  2, -1, -1, -1],\n                [ 4,  0,  3, -1, -1, -1],\n                [ 1,  4,  2,  1,  3,  4],\n                [ 3,  1,  5, -1, -1, -1],\n                [ 2,  3,  0,  2,  5,  3],\n                [ 1,  4,  0,  1,  5,  4],\n                [ 4,  2,  5, -1, -1, -1],\n                [ 4,  5,  2, -1, -1, -1],\n                [ 4,  1,  0,  4,  5,  1],\n                [ 3,  2,  0,  3,  5,  2],\n                [ 1, 
 3,  5, -1, -1, -1],\n                [ 4,  1,  2,  4,  3,  1],\n                [ 3,  0,  4, -1, -1, -1],\n                [ 2,  0,  1, -1, -1, -1],\n                [-1, -1, -1, -1, -1, -1]\n                ], dtype=torch.long, device='cuda')\n\n        self.mesh_edge_table = torch.tensor([\n                [-1, -1, -1, -1, -1, -1],\n                [ 1,  0,  2,  1, -1, -1],\n                [ 4,  0,  3,  4, -1, -1],\n                [ 1,  3,  4,  2,  1, -1],\n                [ 3,  1,  5,  3, -1, -1],\n                [ 2,  5,  3,  0,  2, -1],\n                [ 1,  5,  4,  0,  1, -1],\n                [ 4,  2,  5,  4, -1, -1],\n                [ 4,  5,  2,  4, -1, -1],\n                [ 4,  5,  1,  0,  4, -1],\n                [ 3,  5,  2,  0,  3, -1],\n                [ 1,  3,  5,  1, -1, -1],\n                [ 4,  3,  1,  2,  4, -1],\n                [ 3,  0,  4,  3, -1, -1],\n                [ 2,  0,  1,  2, -1, -1],\n                [-1, -1, -1, -1, -1, -1]\n                ], dtype=torch.long, device='cuda')\n\n\n        self.triangle_table_tri = torch.tensor([\n            ## 000\n                [-1, -1, -1, -1, -1, -1],\n            ## 001\n                [ 4,  2,  5, -1, -1, -1],\n            ## 010\n                [ 3,  1,  4, -1, -1, -1],\n            ## 011\n                [ 3,  1,  2,  3,  2,  5],\n            ## 100\n                [ 0,  3,  5, -1, -1, -1],\n            ## 101\n                [ 0,  3,  4,  0,  4,  2],\n            ## 110\n                [ 0,  1,  4,  0,  4,  5],\n            ## 111\n                [ 0,  1,  2, -1, -1, -1],\n        ], dtype=torch.long, device='cuda')\n\n        self.triangle_table_quad = torch.tensor([\n            ### in the order of [0, 1, 2, 3]\n            ### so 1000 corresponds to single positive mSDF vertex of index 0\n            ## 0000\n                [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],\n            ## 0001\n                [ 6,  3,  7, -1, -1, -1, -1, -1, -1, -1, -1, -1],\n            ## 0010\n                [ 5,  2,  6, -1, -1, -1, -1, -1, -1, -1, -1, -1],\n            ## 0011\n                [ 5,  2,  7,  3,  7,  2, -1, -1, -1, -1, -1, -1],\n            ## 0100\n                [ 4,  1,  5, -1, -1, -1, -1, -1, -1, -1, -1, -1],\n            ## 0101\n                [ 4,  1,  5,  4,  5,  7,  5,  6,  7,  7,  6,  3],\n            ## 0110\n                [ 4,  1,  2,  6,  4,  2, -1, -1, -1, -1, -1, -1],\n            ## 0111\n                [ 4,  1,  2,  7,  4,  2,  7,  2,  3, -1, -1, -1],\n            ## 1000\n                [ 0,  4,  7, -1, -1, -1, -1, -1, -1, -1, -1, -1],\n            ## 1001\n                [ 0,  4,  6,  3,  0,  6, -1, -1, -1, -1, -1, -1],\n            ## 1010\n                [ 0,  4,  5,  0,  5,  2,  0,  2,  6,  0,  6,  7],\n            ## 1011\n                [ 0,  4,  5,  0,  5,  2,  0,  2,  3, -1, -1, -1],\n            ## 1100\n                [ 0,  1,  5,  7,  0,  5, -1, -1, -1, -1, -1, -1],\n            ## 1101\n                [ 0,  1,  5,  0,  5,  6,  0,  6,  3, -1, -1, -1],\n            ## 1110\n                [ 0,  1,  2,  0,  2,  6,  0,  6,  7, -1, -1, -1],\n            ## 1111\n                [ 0,  1,  2,  0,  2,  3, -1, -1, -1, -1, -1, -1],\n        ], dtype=torch.long, device='cuda')\n\n        self.num_triangles_table = torch.tensor([0,1,1,2,1,2,2,1,1,2,2,1,2,1,1,0], dtype=torch.long, device='cuda')\n        self.base_tet_edges = torch.tensor([0,1,0,2,0,3,1,2,1,3,2,3], dtype=torch.long, device='cuda')\n\n        self.num_triangles_tri_table = 
torch.tensor([0,1,1,2,1,2,2,1], dtype=torch.long, device='cuda')\n        self.num_triangles_quad_table = torch.tensor([0,1,1,2,1,4,2,3,1,2,4,3,2,3,3,2], dtype=torch.long, device='cuda')\n\n        edge_ind_list = [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]\n        msdf_from_tetverts = []\n        for i in range(5):\n            for j in range(i+1, 6):\n                if (edge_ind_list[i][0] == edge_ind_list[j][0]\n                    or edge_ind_list[i][0] == edge_ind_list[j][1]\n                    or edge_ind_list[i][1] == edge_ind_list[j][0]\n                    or edge_ind_list[i][1] == edge_ind_list[j][1]\n                ):\n                    msdf_from_tetverts.extend([edge_ind_list[i][0], edge_ind_list[i][1], edge_ind_list[j][0], edge_ind_list[j][1]])\n\n        self.msdf_from_tetverts = torch.tensor(msdf_from_tetverts)\n\n    ###############################################################################\n    # Utility functions\n    ###############################################################################\n\n    def sort_edges(self, edges_ex2):\n        with torch.no_grad():\n            order = (edges_ex2[:,0] > edges_ex2[:,1]).long()\n            order = order.unsqueeze(dim=1)\n\n            a = torch.gather(input=edges_ex2, index=order, dim=1)      \n            b = torch.gather(input=edges_ex2, index=1-order, dim=1)  \n\n        return torch.stack([a, b],-1)\n\n    def map_uv(self, face_gidx, max_idx):\n        N = int(np.ceil(np.sqrt((max_idx+1)//2)))\n        tex_y, tex_x = torch.meshgrid(\n            torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device=\"cuda\"),\n            torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device=\"cuda\"),\n            indexing='ij'\n        )\n\n        pad = 0.9 / N\n\n        uvs = torch.stack([\n            tex_x      , tex_y,\n            tex_x + pad, tex_y,\n            tex_x + pad, tex_y + pad,\n            tex_x      , tex_y + pad\n        ], dim=-1).view(-1, 2)\n\n        def _idx(tet_idx, N):\n            x = tet_idx % N\n            y = torch.div(tet_idx, N, rounding_mode='trunc')\n            return y * N + x\n\n        tet_idx = _idx(torch.div(face_gidx, 2, rounding_mode='trunc'), N)\n        tri_idx = face_gidx % 2\n\n        uv_idx = torch.stack((\n            tet_idx * 4, tet_idx * 4 + tri_idx + 1, tet_idx * 4 + tri_idx + 2\n        ), dim = -1). 
view(-1, 3)\n\n        return uvs, uv_idx\n\n    ###############################################################################\n    # Marching tets implementation\n    ###############################################################################\n\n    def __call__(self, pos_nx3, sdf_n, msdf_n, tet_fx4, output_watertight_template=True):\n        sdf_n = sdf_n.float()\n        with torch.no_grad():\n            ### To determine if tets are valid\n            ### Step 1: SDF criteria\n            occ_n = sdf_n > 0\n            occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1,4)\n            occ_sum = torch.sum(occ_fx4, -1)\n\n\n            ### Step 2: pre-filtering with mSDF - mSDF cannot be all non-negative\n            msdf_fx4 = msdf_n[tet_fx4.reshape(-1)].reshape(-1,4)\n            msdf_sign_fx4 = msdf_fx4 > 0\n            msdf_sign_sum = torch.sum(msdf_sign_fx4, -1)\n\n            if output_watertight_template:\n                valid_tets = (occ_sum>0) & (occ_sum<4) \n            else:\n                valid_tets = (occ_sum>0) & (occ_sum<4) & (msdf_sign_sum > 0)\n\n            # find all vertices\n            all_edges = tet_fx4[valid_tets][:,self.base_tet_edges].reshape(-1,2)\n            all_edges = self.sort_edges(all_edges)\n            unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)  \n\n            unique_edges = unique_edges.long()\n            mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1,2).sum(-1) == 1\n            mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=\"cuda\") * -1\n            mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long, device=\"cuda\")\n            idx_map = mapping[idx_map] # map edges to verts\n\n            interp_v = unique_edges[mask_edges]\n        edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1,2,3)\n        edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1,2,1)\n        edges_to_interp_sdf[:,-1] *= -1\n\n        denominator = edges_to_interp_sdf.sum(1, keepdim = True)\n        denominator = torch.sign(denominator) * (denominator.abs() + 1e-12)\n        denominator[denominator == 0] = 1e-12\n\n        edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator\n        verts = (edges_to_interp * edges_to_interp_sdf).sum(1)\n\n        msdf_to_interp = msdf_n[interp_v.reshape(-1)].reshape(-1,2)\n        msdf_vert = (msdf_to_interp * edges_to_interp_sdf.squeeze(-1)).sum(1)\n        msdf_vert_stopvgd = (msdf_to_interp * edges_to_interp_sdf.squeeze(-1).detach()).sum(1)\n\n\n        # (M, 6), M: num of pre-filtered tets, storing indices (besides -1) from 0 to num_mask_edges\n        idx_map = idx_map.reshape(-1,6)\n\n        v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=\"cuda\"))\n        tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)\n        # triangle count\n        num_triangles = self.num_triangles_table[tetindex]\n\n        # Get global face index (static, does not depend on topology), before mSDF processing\n        num_tets = tet_fx4.shape[0]\n        tet_gidx = torch.arange(num_tets, dtype=torch.long, device=\"cuda\")[valid_tets]\n        face_gidx_pre = torch.cat((\n            tet_gidx[num_triangles == 1]*2,\n            torch.stack((tet_gidx[num_triangles == 2]*2, tet_gidx[num_triangles == 2]*2 + 1), dim=-1).view(-1)\n        ), dim=0)\n\n        # Get uv before mSDF processing\n        uvs_pre, uv_idx_pre = self.map_uv(face_gidx_pre, num_tets*2)\n\n        # Generate triangle indices before msdf 
processing\n        faces = torch.cat((\n            torch.gather(input=idx_map[num_triangles == 1], dim=1, index=self.triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1,3),\n            torch.gather(input=idx_map[num_triangles == 2], dim=1, index=self.triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1,3),\n        ), dim=0)\n\n        v_nrm, t_nrm_idx = auto_normals(verts, faces)\n        v_tng, _ = compute_tangents(verts, uvs_pre, v_nrm, faces, faces, faces)\n        \n        ###### Triangulation with mSDF\n        ### Note: we allow area-0 triangular faces for convenience. Can always remove them during post-processing\n        with torch.no_grad():\n            mesh_edge_tri = torch.gather(input=idx_map[num_triangles == 1], dim=1, \n                    index=self.mesh_edge_table[tetindex[num_triangles == 1]][:, [0, 1, 1, 2, 2, 0]]\n                ).view(-1, 3, 2)\n            mesh_edge_quad = torch.gather(input=idx_map[num_triangles == 2], dim=1, \n                    index=self.mesh_edge_table[tetindex[num_triangles == 2]][:, [0, 1, 1, 2, 2, 3, 3, 0]]\n                ).view(-1, 4, 2)\n            mocc_fx3 = (msdf_vert[mesh_edge_tri[:, :, 0].reshape(-1)].reshape(-1, 3) > 0).long()\n            mocc_fx4 = (msdf_vert[mesh_edge_quad[:, :, 0].reshape(-1)].reshape(-1, 4) > 0).long()\n\n\n        ### Attributes to be interpolated for (non-watertight) mesh vertices on the boundary\n        edges_to_interp_vpos_tri = verts[mesh_edge_tri.reshape(-1)].reshape(-1,2,3)\n        edges_to_interp_vpos_quad = verts[mesh_edge_quad.reshape(-1)].reshape(-1,2,3)\n        edges_to_interp_tng_tri = v_tng[mesh_edge_tri.reshape(-1)].reshape(-1,2,3)\n        edges_to_interp_tng_quad = v_tng[mesh_edge_quad.reshape(-1)].reshape(-1,2,3)\n        edges_to_interp_msdf_tri = msdf_vert[mesh_edge_tri.reshape(-1)].reshape(-1,2,1)\n        edges_to_interp_msdf_quad = msdf_vert[mesh_edge_quad.reshape(-1)].reshape(-1,2,1)\n        edges_to_interp_msdf_tri_stopvgd = msdf_vert_stopvgd[mesh_edge_tri.reshape(-1)].reshape(-1,2,1)\n        edges_to_interp_msdf_quad_stopvgd = msdf_vert_stopvgd[mesh_edge_quad.reshape(-1)].reshape(-1,2,1)\n\n\n        ### Linear interpolation on mesh edges (triangle / quad faces)\n        denominator_tri_nonzero = torch.sign(edges_to_interp_msdf_tri[:,:,0]).sum(dim=1).abs() != 2\n        denominator_quad_nonzero = torch.sign(edges_to_interp_msdf_quad[:,:,0]).sum(dim=1).abs() != 2\n\n        edges_to_interp_msdf_tri[:,-1] *= -1\n        edges_to_interp_msdf_quad[:,-1] *= -1\n        denominator_tri = edges_to_interp_msdf_tri.sum(1, keepdim=True)\n        denominator_quad = edges_to_interp_msdf_quad.sum(1, keepdim=True)\n\n        denominator_tri_nonzero = (denominator_tri[:,0,0].abs() > 1e-12) & denominator_tri_nonzero\n        denominator_quad_nonzero = (denominator_quad[:,0,0].abs() > 1e-12) & denominator_quad_nonzero\n\n\n\n        edges_to_interp_msdf_tri_new = torch.zeros_like(edges_to_interp_msdf_tri)\n        edges_to_interp_msdf_quad_new = torch.zeros_like(edges_to_interp_msdf_quad)\n        edges_to_interp_msdf_tri_new[denominator_tri_nonzero] = torch.flip(edges_to_interp_msdf_tri[denominator_tri_nonzero], [1]) / denominator_tri[denominator_tri_nonzero]\n        edges_to_interp_msdf_quad_new[denominator_quad_nonzero] = torch.flip(edges_to_interp_msdf_quad[denominator_quad_nonzero], [1]) / denominator_quad[denominator_quad_nonzero]\n\n        edges_to_interp_msdf_tri = edges_to_interp_msdf_tri_new\n        edges_to_interp_msdf_quad = 
edges_to_interp_msdf_quad_new\n\n        ### Append additional boundary vertices (with negligible corner cases). Notice that unused vertices are included for efficiency reasons.\n        verts_aug = torch.cat([\n                    verts,\n                    (edges_to_interp_vpos_tri * edges_to_interp_msdf_tri).sum(1), \n                    (edges_to_interp_vpos_quad * edges_to_interp_msdf_quad).sum(1)\n                ],\n            dim=0)\n\n        v_tng_aug = torch.cat([\n                    v_tng,\n                    (edges_to_interp_tng_tri * edges_to_interp_msdf_tri).sum(1), \n                    (edges_to_interp_tng_quad * edges_to_interp_msdf_quad).sum(1)\n                ],\n            dim=0)\n\n        ### NOTE: important to stop gradients from passing through the 'interpolation coefficients' (basically the 'coordinates' of boundary vertices)\n        msdf_vert_tri_stopvgd = (edges_to_interp_msdf_tri_stopvgd * edges_to_interp_msdf_tri.detach()).sum(1).squeeze(dim=-1)\n        msdf_vert_quad_stopvgd = (edges_to_interp_msdf_quad_stopvgd * edges_to_interp_msdf_quad.detach()).sum(1).squeeze(dim=-1)\n\n        msdf_vert_aug_stopvgd = torch.cat([\n            msdf_vert_stopvgd,\n            msdf_vert_tri_stopvgd,\n            msdf_vert_quad_stopvgd,\n        ])\n\n        msdf_vert_boundary_stopvgd = msdf_vert_aug_stopvgd[msdf_vert.size(0):] ## not all boundary vertices but good enough\n\n        ### Determine how to cut polygon faces by checking the look-up tables\n        with torch.no_grad():\n            v_id_msdf_tri = torch.flip(torch.pow(2, torch.arange(3, dtype=torch.long, device=\"cuda\")), dims=[0])\n            v_id_msdf_quad = torch.flip(torch.pow(2, torch.arange(4, dtype=torch.long, device=\"cuda\")), dims=[0])\n            mesh_index_tri = (mocc_fx3 * v_id_msdf_tri.unsqueeze(0)).sum(-1)\n            mesh_index_quad = (mocc_fx4 * v_id_msdf_quad.unsqueeze(0)).sum(-1)\n\n\n        idx_map_tri = torch.cat([mesh_edge_tri[:, :, 0], verts.size(0) + torch.arange(mesh_edge_tri.size(0) * 3, device='cuda').view(-1, 3)], dim=-1)\n        idx_map_quad = torch.cat([mesh_edge_quad[:, :, 0], verts.size(0) + mesh_edge_tri.size(0) * 3 + torch.arange(mesh_edge_quad.size(0) * 4, device='cuda').view(-1, 4)], dim=-1)\n\n        num_triangles_tri = self.num_triangles_tri_table[mesh_index_tri]\n        num_triangles_quad = self.num_triangles_quad_table[mesh_index_quad]\n\n        ### Cut the polygon faces (case-by-case)\n        faces_aug = torch.cat((\n            torch.gather(input=idx_map_tri[num_triangles_tri == 1], dim=1, index=self.triangle_table_tri[mesh_index_tri[num_triangles_tri == 1]][:, :3]).view(-1, 3),\n            torch.gather(input=idx_map_tri[num_triangles_tri == 2], dim=1, index=self.triangle_table_tri[mesh_index_tri[num_triangles_tri == 2]][:, :6]).view(-1, 3),\n            torch.gather(input=idx_map_quad[num_triangles_quad == 1], dim=1, index=self.triangle_table_quad[mesh_index_quad[num_triangles_quad == 1]][:, :3]).view(-1, 3),\n            torch.gather(input=idx_map_quad[num_triangles_quad == 2], dim=1, index=self.triangle_table_quad[mesh_index_quad[num_triangles_quad == 2]][:, :6]).view(-1, 3),\n            torch.gather(input=idx_map_quad[num_triangles_quad == 3], dim=1, index=self.triangle_table_quad[mesh_index_quad[num_triangles_quad == 3]][:, :9]).view(-1, 3),\n            torch.gather(input=idx_map_quad[num_triangles_quad == 4], dim=1, index=self.triangle_table_quad[mesh_index_quad[num_triangles_quad == 4]][:, :12]).view(-1, 3),\n        ), dim=0)\n\n        ### 
Mark all unused vertices (only for convenience in visualization; not necessary)\n        with torch.no_grad():\n            referenced_vert_idx = faces_aug.unique()\n            mask = torch.ones(verts_aug.size(0))\n            mask[referenced_vert_idx] = 0\n        verts_aug[mask.bool()] = 0\n\n\n        if output_watertight_template:\n            extra = {\n                'n_verts_watertight': verts.size(0),\n                'vertices_watertight': verts,\n                'faces_watertight': faces, \n                'v_tng_watertight': v_tng,\n                'msdf': msdf_vert_aug_stopvgd,\n                'msdf_watertight': msdf_vert_stopvgd,\n                'msdf_boundary': msdf_vert_boundary_stopvgd,\n            }\n        else:\n            extra = {\n                'msdf': msdf_vert_aug_stopvgd,\n                'msdf_watertight': msdf_vert_stopvgd,\n                'msdf_boundary': msdf_vert_boundary_stopvgd,\n            }\n\n        return verts_aug, faces_aug, None, None, v_tng_aug, extra\n    \n\n    @torch.no_grad()\n    def marching_from_auggrid(self, pos_nx3, sdf_n, tet_fx4, \n                          sorted_tet_edges_fx6x2, coeff_sdf_interp, verts_discretized, \n                          midpoint_msdf_sign_n, occgrid\n                          ):\n        sdf_n = sdf_n.float()\n        ### To determine if tets are valid\n        ### Step 1: SDF criteria\n        occ_n = sdf_n > 0\n        occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1,4)\n        occ_sum = torch.sum(occ_fx4, -1)\n\n        valid_tets = (occ_sum>0) & (occ_sum<4)\n        v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=\"cuda\"))\n        tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)\n\n\n        # find all vertices\n        all_edges = sorted_tet_edges_fx6x2.reshape(-1, 6, 2)[valid_tets].reshape(-1, 2)\n        all_edges = all_edges.view(-1, 1, 2)\n        unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)\n\n\n        unique_edges = unique_edges.long()\n        mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1,2).sum(-1) == 1\n        mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=\"cuda\") * -1\n        mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long, device=\"cuda\")\n        idx_map = mapping[idx_map] # map edges to verts\n\n        interp_v = unique_edges[mask_edges]\n\n\n        edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1,2,3)\n        edges_to_interp_canonical = verts_discretized[interp_v.reshape(-1)].reshape(-1,2,3).float()\n        verts_canonical = (edges_to_interp_canonical[:, 0] + edges_to_interp_canonical[:, 1]) / 2.0\n\n        tetedge_cano_midpts = verts_discretized[interp_v.reshape(-1)].float().reshape(-1,2,3).mean(dim=1).long()\n\n        coeff_sdf_interp = coeff_sdf_interp[tetedge_cano_midpts[:, 0], tetedge_cano_midpts[:, 1], tetedge_cano_midpts[:, 2]].view(-1, 1).clamp(0, 1)\n        verts = edges_to_interp[:, 1] * coeff_sdf_interp + edges_to_interp[:, 0] * (1 - coeff_sdf_interp)\n\n        msdf_vert = midpoint_msdf_sign_n[tetedge_cano_midpts[:, 0], tetedge_cano_midpts[:, 1], tetedge_cano_midpts[:, 2]]\n\n        # (M, 6), M: num of pre-filtered tets, storing indices (besides -1) from 0 to num_mask_edges\n        idx_map = idx_map.reshape(-1,6)\n\n        v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=\"cuda\"))\n        tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)\n        # triangle count\n        num_triangles = 
self.num_triangles_table[tetindex]\n\n        # Get global face index (static, does not depend on topology), before mSDF processing\n        num_tets = tet_fx4.shape[0]\n        tet_gidx = torch.arange(num_tets, dtype=torch.long, device=\"cuda\")[valid_tets]\n        face_gidx_pre = torch.cat((\n            tet_gidx[num_triangles == 1]*2,\n            torch.stack((tet_gidx[num_triangles == 2]*2, tet_gidx[num_triangles == 2]*2 + 1), dim=-1).view(-1)\n        ), dim=0)\n\n        valid_tet_gidx = torch.cat([tet_gidx[num_triangles == 1], tet_gidx[num_triangles == 2]], dim=0)\n\n        # Get uv before mSDF processing\n        uvs_pre, uv_idx_pre = self.map_uv(face_gidx_pre, num_tets*2)\n\n        # Generate triangle indices before vis processing\n        faces = torch.cat((\n            torch.gather(input=idx_map[num_triangles == 1], dim=1, index=self.triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1,3),\n            torch.gather(input=idx_map[num_triangles == 2], dim=1, index=self.triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1,3),\n        ), dim=0)\n\n        v_nrm, t_nrm_idx = auto_normals(verts, faces)\n        v_tng, _ = compute_tangents(verts, uvs_pre, v_nrm, faces, faces, faces)\n\n        ###### Triangulation with mSDF\n        # edge_indices_tri = self.pre_mesh_edge_table[tetindex[num_triangles == 1]][:, [0, 1, 1, 2, 2, 0]]\n        # edge_indices_quad = self.pre_mesh_edge_table[tetindex[num_triangles == 2]][:, [0, 1, 1, 2, 2, 3, 3, 0]]\n        # pre_mesh_edge_tri = torch.gather(input=idx_map[num_triangles == 1], dim=1, \n        #         index=edge_indices_tri\n        #     ).view(-1, 3, 2)\n        # pre_mesh_edge_quad = torch.gather(input=idx_map[num_triangles == 2], dim=1, \n        #         index=edge_indices_quad\n        #     ).view(-1, 4, 2)\n                              \n        pre_mesh_edge_tri = torch.gather(input=idx_map[num_triangles == 1], dim=1, \n                index=self.mesh_edge_table[tetindex[num_triangles == 1]][:, [0, 1, 1, 2, 2, 0]]\n            ).view(-1, 3, 2)\n        pre_mesh_edge_quad = torch.gather(input=idx_map[num_triangles == 2], dim=1, \n                index=self.mesh_edge_table[tetindex[num_triangles == 2]][:, [0, 1, 1, 2, 2, 3, 3, 0]]\n            ).view(-1, 4, 2)\n\n        msdf_positive_fx3 = (msdf_vert[pre_mesh_edge_tri[:, :, 0].reshape(-1)].reshape(-1, 3) > 0).long()\n        msdf_positive_fx4 = (msdf_vert[pre_mesh_edge_quad[:, :, 0].reshape(-1)].reshape(-1, 4) > 0).long()\n\n                              \n\n\n        edges_to_interp_prevert_tri = verts[pre_mesh_edge_tri.reshape(-1)].reshape(-1,2,3)\n        edges_to_interp_prevert_quad = verts[pre_mesh_edge_quad.reshape(-1)].reshape(-1,2,3)\n        edges_to_interp_pretng_tri = v_tng[pre_mesh_edge_tri.reshape(-1)].reshape(-1,2,3)\n        edges_to_interp_pretng_quad = v_tng[pre_mesh_edge_quad.reshape(-1)].reshape(-1,2,3)\n\n\n        edges_to_interp_sort_tri = verts_canonical[pre_mesh_edge_tri.reshape(-1)].reshape(-1,2,3)\n        edges_to_interp_sort_quad = verts_canonical[pre_mesh_edge_quad.reshape(-1)].reshape(-1,2,3)\n\n\n        meshocc_loc_tri = (edges_to_interp_sort_tri.mean(dim=1) * 2.0).long()\n        meshocc_loc_quad = (edges_to_interp_sort_quad.mean(dim=1) * 2.0).long()\n\n\n        msdf_coeff_tri = occgrid[meshocc_loc_tri[:, 0], meshocc_loc_tri[:, 1], meshocc_loc_tri[:, 2]] * 0.5 + 0.5\n        msdf_coeff_quad = occgrid[meshocc_loc_quad[:, 0], meshocc_loc_quad[:, 1], meshocc_loc_quad[:, 2]] * 0.5 + 0.5\n\n\n        msdf_coeff_tri = 
torch.stack([msdf_coeff_tri, 1 - msdf_coeff_tri], dim=-1)\n        msdf_coeff_quad = torch.stack([msdf_coeff_quad, 1 - msdf_coeff_quad], dim=-1)\n\n\n        inscribed_edge_twopoint_order_tri = torch.sign(edges_to_interp_sort_tri[:, 0, :] - edges_to_interp_sort_tri[:, 1, :])\n        inscribed_edge_twopoint_order_tri = (inscribed_edge_twopoint_order_tri * torch.tensor([16, 4, 1], device=inscribed_edge_twopoint_order_tri.device).view(1, -1)).sum(dim=-1)\n        inscribed_edge_twopoint_order_tri = torch.stack([inscribed_edge_twopoint_order_tri, -inscribed_edge_twopoint_order_tri], dim=-1)\n        _, inscribed_edge_twopoint_order_tri = inscribed_edge_twopoint_order_tri.sort(dim=-1, descending=True)\n\n        inscribed_edge_twopoint_order_quad = torch.sign(edges_to_interp_sort_quad[:, 0, :] - edges_to_interp_sort_quad[:, 1, :])\n        inscribed_edge_twopoint_order_quad = (inscribed_edge_twopoint_order_quad * torch.tensor([16, 4, 1], device=inscribed_edge_twopoint_order_quad.device).view(1, -1)).sum(dim=-1)\n        inscribed_edge_twopoint_order_quad = torch.stack([inscribed_edge_twopoint_order_quad, -inscribed_edge_twopoint_order_quad], dim=-1)\n        _, inscribed_edge_twopoint_order_quad = inscribed_edge_twopoint_order_quad.sort(dim=-1, descending=True)\n\n        msdf_coeff_tri = torch.gather(\n            input=msdf_coeff_tri, \n            dim=-1, \n            index=inscribed_edge_twopoint_order_tri.view(-1, 2)\n        ).view(-1, 2, 1)\n\n        msdf_coeff_quad = torch.gather(\n            input=msdf_coeff_quad, \n            dim=-1, \n            index=inscribed_edge_twopoint_order_quad.view(-1, 2)\n        ).view(-1, 2, 1)\n\n        msdf_coeff_tri = msdf_coeff_tri.view(-1, 2, 1)\n        msdf_coeff_quad = msdf_coeff_quad.view(-1, 2, 1)\n\n\n        verts_aug = torch.cat([\n                    verts,\n                    (edges_to_interp_prevert_tri * msdf_coeff_tri).sum(1), \n                    (edges_to_interp_prevert_quad * msdf_coeff_quad).sum(1),\n                ],\n            dim=0)\n\n        v_tng_aug = torch.cat([\n                    v_tng,\n                    (edges_to_interp_pretng_tri * msdf_coeff_tri).sum(1), \n                    (edges_to_interp_pretng_quad * msdf_coeff_quad).sum(1),\n                ],\n            dim=0)\n\n        msdf_vert_aug = torch.cat([\n            msdf_vert,\n            torch.zeros(v_tng_aug.size(0) - v_tng.size(0)).cuda()\n        ])\n\n        v_id_msdf_tri = torch.flip(torch.pow(2, torch.arange(3, dtype=torch.long, device=\"cuda\")), dims=[0]) ## do this flip because the triangle table uses a different assumption by mistake..\n        v_id_msdf_quad = torch.flip(torch.pow(2, torch.arange(4, dtype=torch.long, device=\"cuda\")), dims=[0])\n        premesh_index_tri = (msdf_positive_fx3 * v_id_msdf_tri.unsqueeze(0)).sum(-1)\n        premesh_index_quad = (msdf_positive_fx4 * v_id_msdf_quad.unsqueeze(0)).sum(-1)\n\n        idx_map_tri = torch.cat([pre_mesh_edge_tri[:, :, 0], verts.size(0) + torch.arange(pre_mesh_edge_tri.size(0) * 3, device='cuda').view(-1, 3)], dim=-1)\n        idx_map_quad = torch.cat([pre_mesh_edge_quad[:, :, 0], verts.size(0) + pre_mesh_edge_tri.size(0) * 3 + torch.arange(pre_mesh_edge_quad.size(0) * 4, device='cuda').view(-1, 4)], dim=-1)\n\n        num_triangles_tri = self.num_triangles_tri_table[premesh_index_tri]\n        num_triangles_quad = self.num_triangles_quad_table[premesh_index_quad]\n\n        faces_aug = torch.cat((\n            torch.gather(input=idx_map_tri[num_triangles_tri == 1], dim=1, 
index=self.triangle_table_tri[premesh_index_tri[num_triangles_tri == 1]][:, :3]).view(-1, 3),\n            torch.gather(input=idx_map_tri[num_triangles_tri == 2], dim=1, index=self.triangle_table_tri[premesh_index_tri[num_triangles_tri == 2]][:, :6]).view(-1, 3),\n            torch.gather(input=idx_map_quad[num_triangles_quad == 1], dim=1, index=self.triangle_table_quad[premesh_index_quad[num_triangles_quad == 1]][:, :3]).view(-1, 3),\n            torch.gather(input=idx_map_quad[num_triangles_quad == 2], dim=1, index=self.triangle_table_quad[premesh_index_quad[num_triangles_quad == 2]][:, :6]).view(-1, 3),\n            torch.gather(input=idx_map_quad[num_triangles_quad == 3], dim=1, index=self.triangle_table_quad[premesh_index_quad[num_triangles_quad == 3]][:, :9]).view(-1, 3),\n            torch.gather(input=idx_map_quad[num_triangles_quad == 4], dim=1, index=self.triangle_table_quad[premesh_index_quad[num_triangles_quad == 4]][:, :12]).view(-1, 3),\n        ), dim=0)\n\n        return verts_aug, faces_aug, None, None, v_tng_aug, verts, valid_tet_gidx, msdf_vert_aug, msdf_vert\n"
  },
  {
    "path": "geometry/gshell_tets_geometry.py",
    "content": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n# property and proprietary rights in and to this material, related\n# documentation and any modifications thereto. Any use, reproduction, \n# disclosure or distribution of this material and related documentation \n# without an express license agreement from NVIDIA CORPORATION or \n# its affiliates is strictly prohibited.\n\nimport os\nfrom tqdm import trange\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\n\nfrom render import mesh\nfrom render import render\nimport render.optixutils as ou\nfrom render import regularizer\n\nfrom .gshell_tets import GShell_Tets\n\nimport kaolin\n\nfrom .mlp import MLP\n\n\n###############################################################################\n# Regularizer\n###############################################################################\n\ndef compute_sdf_reg_loss(sdf, all_edges):\n    sdf_f1x6x2 = sdf[all_edges.reshape(-1)].reshape(-1,2)\n    mask = torch.sign(sdf_f1x6x2[...,0]) != torch.sign(sdf_f1x6x2[...,1])\n    sdf_f1x6x2 = sdf_f1x6x2[mask]\n    sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[...,0], (sdf_f1x6x2[...,1] > 0).float()) + \\\n            torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[...,1], (sdf_f1x6x2[...,0] > 0).float())\n    return sdf_diff\n\n###############################################################################\n#  Geometry interface\n###############################################################################\n\nclass GShellTetsGeometry(torch.nn.Module):\n    def __init__(self, grid_res, scale, FLAGS, offset=None, tet_init_file=None, extract_from_generative=False):\n        super(GShellTetsGeometry, self).__init__()\n\n        self.FLAGS         = FLAGS\n        self.grid_res      = grid_res\n        self.gshell_tets   = GShell_Tets()\n        self.scale         = scale\n        self.boxscale      = torch.tensor(FLAGS.boxscale).view(1, 3).cuda()\n\n        with torch.no_grad():\n            self.optix_ctx = ou.OptiXContext()\n\n            if tet_init_file is None:\n                tets = np.load('data/tets/{}_tets.npz'.format(self.grid_res))\n            else:\n                tets = np.load(tet_init_file)\n            print(f'using resolution {self.grid_res}')\n            self.verts    = torch.tensor(tets['vertices'], dtype=torch.float32, device='cuda')\n            self.original_verts = self.verts.clone() if extract_from_generative else None\n            self.verts    = self.verts - self.verts.mean(dim=0)\n            self.verts    = self.verts * scale * self.boxscale\n            self.indices  = torch.tensor(tets['indices'], dtype=torch.long, device='cuda')\n            self.generate_edges()\n\n            if extract_from_generative:\n                self.sorted_tetedges = torch.tensor(tets['tet_edges'], dtype=torch.long, device='cuda')\n                vertices = torch.tensor(tets['vertices'], dtype=torch.float32, device='cuda')\n                vertices_unique = vertices.view(-1).unique()\n                dx = (vertices_unique[1] - vertices_unique[0]) / 2.0 ### denser grid for edge + tet features\n                vertices_discretized = (\n                    ((vertices - vertices.min()) / dx)\n                ).long()\n                self.verts_discretized = vertices_discretized.long().float() ### used to identify where to store edge + tet features\n\n            if 
offset is None:\n                offset = 0.0\n            else:\n                offset = torch.tensor(offset).cuda().view(1, 3)\n            self.offset = offset\n\n        if self.FLAGS.use_sdf_mlp:\n            self.sdf    = torch.nn.Parameter(torch.zeros_like(self.verts[:, 0]), requires_grad=True) ## placeholder\n            self.register_parameter('sdf', self.sdf)\n            self.sdf_net = MLP(\n                skip_in=self.FLAGS.skip_in,\n                n_freq=self.FLAGS.n_freq,\n                n_hidden=self.FLAGS.n_hidden,\n                d_hidden=self.FLAGS.d_hidden,\n                use_float16=self.FLAGS.use_float16\n            )\n            self.sdf_net.cuda()\n\n            optimizer = torch.optim.Adam(self.sdf_net.parameters(), lr=1e-3)\n            for _ in trange(self.FLAGS.sdf_mlp_pretrain_steps):\n                scaled_verts = self.verts / self.boxscale\n                loss = (self.sdf_net(self.verts) - (scaled_verts.norm(dim=1, keepdim=True) - self.FLAGS.sphere_init_norm)).pow(2).mean()\n                optimizer.zero_grad()\n                loss.backward()\n                optimizer.step()\n            print('sdf net trained with loss:', loss)\n\n        else:\n            # Random init\n            if not self.FLAGS.sphere_init:\n                sdf = torch.rand_like(self.verts[:,0]) - 0.1\n            else:\n                scaled_verts = self.verts / self.boxscale\n                sdf = scaled_verts.norm(dim=1) - 0.5\n            self.sdf    = torch.nn.Parameter(sdf.clone().detach(), requires_grad=True)\n            self.register_parameter('sdf', self.sdf)\n\n\n        if self.FLAGS.use_msdf_mlp:\n            self.msdf    = torch.nn.Parameter(torch.zeros_like(self.verts[:, 0]), requires_grad=True) ## placeholder\n            self.register_parameter('msdf', self.msdf)\n            self.msdf_net = MLP(\n                skip_in=self.FLAGS.skip_in,\n                n_freq=self.FLAGS.n_freq,\n                n_hidden=self.FLAGS.n_hidden,\n                d_hidden=self.FLAGS.d_hidden,\n                use_float16=self.FLAGS.use_float16\n            )\n            self.msdf_net.cuda()\n            optimizer = torch.optim.Adam(self.msdf_net.parameters(), lr=1e-3)\n            for _ in trange(100):\n                scaled_verts = self.verts / self.boxscale\n                loss = (self.msdf_net(self.verts) - 0.1).pow(2).mean()\n                optimizer.zero_grad()\n                loss.backward()\n                optimizer.step()\n            print('msdf net trained with loss:', loss)\n            del optimizer\n        else:\n            msdf         = (torch.rand_like(self.verts[:,0]) - 0.01).clamp(-1, 1)\n            self.msdf    = torch.nn.Parameter(msdf.clone().detach(), requires_grad=True)\n            self.register_parameter('msdf', self.msdf)\n\n        self.deform = torch.nn.Parameter(torch.zeros_like(self.verts), requires_grad=True)\n        self.register_parameter('deform', self.deform)\n\n        self.clamp_deform()\n\n    @torch.no_grad()\n    def generate_edges(self):\n        with torch.no_grad():\n            edges = torch.tensor([0,1,0,2,0,3,1,2,1,3,2,3], dtype = torch.long, device = \"cuda\")\n            all_edges = self.indices[:,edges].reshape(-1,2)\n            all_edges_sorted = torch.sort(all_edges, dim=1)[0]\n            self.all_edges = torch.unique(all_edges_sorted, dim=0)\n            self.max_displacement = 1.0 / self.grid_res * self.scale / 2.1\n\n    @torch.no_grad()\n    def getAABB(self):\n        return torch.min(self.verts, 
dim=0).values, torch.max(self.verts, dim=0).values\n\n    @torch.no_grad()\n    def clamp_deform(self):\n        if not self.FLAGS.use_tanh_deform:\n            self.deform.data[:] = self.deform.clamp(-1.0, 1.0)\n        self.msdf.data[:] = self.msdf.clamp(-2.0, 2.0)\n\n    def getMesh_from_augmented_grid_withocc(self, material, sdf_sign, sdf_coeff, msdf_sign, occgrid):\n        # Run DM tet to get a base mesh\n        v_deformed = self.verts + self.max_displacement * self.deform\n        if self.FLAGS.use_sdf_mlp:\n            sdf = self.sdf_net(v_deformed)\n        else:\n            sdf = self.sdf\n\n        verts, faces, uvs, uv_idx, v_tng, v_pos_original, tet_gidx, v_msdf, msdf_vert_original = self.gshell_tets.marching_from_auggrid(\n            v_deformed, sdf_sign, self.indices,\n            self.sorted_tetedges, sdf_coeff, self.verts_discretized, \n            msdf_sign, \n            occgrid)\n        imesh = mesh.Mesh(verts, faces, v_tex=uvs, t_tex_idx=uv_idx, material=material)\n\n        # Run mesh operations to generate tangent space\n        imesh = mesh.auto_normals(imesh)\n        imesh = mesh.compute_tangents(imesh, v_tng=v_tng)\n        return {\n            'imesh': imesh,\n            'sdf': sdf,\n            'v_msdf': v_msdf,\n        }\n\n    def getMesh(self, material):\n        v_deformed = self.verts + self.max_displacement * self.deform\n        if self.FLAGS.use_sdf_mlp:\n            sdf = self.sdf_net(v_deformed)\n        else:\n            sdf = self.sdf\n        \n\n        if self.FLAGS.use_msdf_mlp:\n            msdf = self.msdf_net(v_deformed)\n        else:\n            msdf = self.msdf\n\n        v_deformed = v_deformed + self.offset\n\n        verts, faces, uvs, uv_idx, v_tng, extra = self.gshell_tets(\n            v_deformed, sdf, msdf, self.indices)\n        imesh = mesh.Mesh(verts, faces, v_tex=uvs, t_tex_idx=uv_idx, material=material)\n\n        with torch.no_grad():\n            ou.optix_build_bvh(self.optix_ctx, imesh.v_pos.contiguous(), imesh.t_pos_idx.int(), rebuild=1)\n\n        # Run mesh operations to generate tangent space\n        imesh = mesh.auto_normals(imesh)\n        return_dict = {\n            'imesh': imesh,\n            'sdf': sdf,\n            'msdf': extra['msdf'],\n            'msdf_watertight': extra['msdf_watertight'],\n            'msdf_boundary': extra['msdf_boundary'],\n            'n_verts_watertight': extra['n_verts_watertight'],\n        }\n\n        if self.FLAGS.visualize_watertight:\n            imesh_watertight = mesh.Mesh(extra['vertices_watertight'], extra['faces_watertight'], v_tex=None, t_tex_idx=None, material=material)\n            imesh_watertight = mesh.auto_normals(imesh_watertight)\n            return_dict['imesh_watertight'] = imesh_watertight\n        return return_dict\n\n    def render(self, glctx, target, lgt, opt_material, bsdf=None, denoiser=None, shadow_scale=1.0,\n            use_uv=False):\n        opt_mesh_dict = self.getMesh(opt_material)\n        opt_mesh = opt_mesh_dict['imesh']\n        opt_mesh_watertight = opt_mesh_dict['imesh_watertight'] if 'imesh_watertight' in opt_mesh_dict else None\n        if opt_mesh.v_pos.size(0) != 0:\n            sampled_pts = kaolin.ops.mesh.sample_points(opt_mesh.v_pos[None,...], opt_mesh.t_pos_idx, 50000)[0][0]\n            opt_mesh_dict['sampled_pts'] = sampled_pts\n        else:\n            opt_mesh_dict['sampled_pts'] = None\n    \n        extra_dict = {\n            'msdf': opt_mesh_dict['msdf'],\n        }\n        opt_mesh_dict['buffers'] = 
render.render_mesh(\n            self.FLAGS, glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'], spp=target['spp'], \n            msaa=True, background=target['background'], bsdf=bsdf, use_uv=use_uv,\n            optix_ctx=self.optix_ctx, denoiser=denoiser, shadow_scale=shadow_scale,\n            extra_dict=extra_dict)\n        if self.FLAGS.visualize_watertight:\n            opt_mesh_dict['buffers_watertight'] = render.render_mesh(\n                self.FLAGS, glctx, opt_mesh_watertight, target['mvp'], target['campos'], lgt, target['resolution'], spp=target['spp'], \n                msaa=True, background=target['background'], bsdf=bsdf, use_uv=use_uv,\n                optix_ctx=self.optix_ctx, denoiser=denoiser, shadow_scale=shadow_scale,\n                extra_dict=extra_dict)\n        return opt_mesh_dict\n\n    def tick(self, glctx, target, lgt, opt_material, loss_fn, iteration, denoiser):\n\n        t_iter = iteration / self.FLAGS.iter\n\n        # ==============================================================================================\n        #  Render optimizable object with identical conditions\n        # ==============================================================================================\n        shadow_ramp = min(iteration / 1000, 1.0) ### set occlusion ray influence\n        if denoiser is not None: denoiser.set_influence(shadow_ramp)\n        opt_mesh_dict = self.render(glctx, target, lgt, opt_material, \n            denoiser=denoiser,\n            shadow_scale=shadow_ramp)\n        buffers = opt_mesh_dict['buffers']\n\n        # ==============================================================================================\n        #  Compute loss\n        # ==============================================================================================\n\n        with torch.no_grad():\n            # Image-space loss, split into a coverage component and a color component\n            color_ref = target['img']\n            gt_mask = color_ref[..., 3:]\n\n        img_loss = F.mse_loss(buffers['shaded'][..., 3:], color_ref[..., 3:]) \n        img_loss = img_loss + loss_fn(buffers['shaded'][..., 0:3] * color_ref[..., 3:], color_ref[..., 0:3] * color_ref[..., 3:])\n\n\n        img_loss = img_loss + 5e-1 * F.l1_loss(buffers['msdf_image'].clamp(min=0) * (gt_mask == 0).float(), torch.zeros_like(gt_mask))\n        img_loss = img_loss + 5e-1 * F.l1_loss(buffers['msdf_image'].clamp(max=0) * (gt_mask == 1).float(), torch.ones_like(gt_mask))\n\n        if self.FLAGS.use_img_2nd_layer:\n            color_ref_2nd = target['img_second']\n            img_loss = img_loss + F.mse_loss(buffers['shaded_second'][..., 3:], color_ref_2nd[..., 3:]) \n            img_loss = img_loss + loss_fn(buffers['shaded_second'][..., 0:3] * color_ref_2nd[..., 3:], color_ref_2nd[..., 0:3] * color_ref_2nd[..., 3:])\n\n        if self.FLAGS.use_depth:\n            depth_loss_scale = 100.\n            depth_loss = depth_loss_scale * ((buffers['invdepth'][:, :, :, :1] - target['invdepth'][:, :, :, :1]).abs()).mean()\n\n            if self.FLAGS.use_depth_2nd_layer:\n                depth_loss += 0.1 * depth_loss_scale * ((buffers['invdepth_second'][:, :, :, :1] - target['invdepth_second'][:, :, :, :1]).abs()).mean()\n        else:\n            depth_loss = torch.tensor(0., device=img_loss.device)\n\n        # Eikonal\n        if self.FLAGS.use_sdf_mlp and self.FLAGS.use_eikonal and opt_mesh_dict['sampled_pts'] is not None:\n            v = opt_mesh_dict['sampled_pts'].detach()\n   
         v.requires_grad = True\n\n            sdf_eik = self.sdf_net(v)\n            if self.FLAGS.eikonal_scale is None:\n                ### Default hardcoded Eikonal loss schedule\n                if iteration < 500:\n                    eik_coeff = 3e-1\n                elif iteration < 1000:\n                    eik_coeff = 1e-1\n                elif iteration < 2000:\n                    eik_coeff = 1e-1\n                else:\n                    eik_coeff = 1e-2\n            else:\n                eik_coeff = self.FLAGS.eikonal_scale\n\n            eik_loss = eik_coeff * (\n                torch.autograd.grad(sdf_eik.sum(), v, create_graph=True)[0].pow(2).sum(dim=-1).sqrt() - 1\n            ).pow(2).mean()\n        else:\n            eik_loss = torch.tensor(0., device=img_loss.device)\n\n        if self.FLAGS.use_mesh_msdf_reg:\n            mesh_msdf_regscale = (64 / self.grid_res) ** 3 # scale inversely proportional to grid_res^3\n            eps = 1e-3\n            open_scale = self.FLAGS.msdf_reg_open_scale\n            close_scale = self.FLAGS.msdf_reg_close_scale\n            eps = torch.tensor([eps]).cuda()\n\n            if open_scale > 0:\n                mesh_msdf_reg_loss = open_scale * mesh_msdf_regscale * F.huber_loss(\n                    opt_mesh_dict['msdf'].clamp(min=-eps).squeeze(), \n                    -eps.expand(opt_mesh_dict['msdf'].size(0)), \n                    reduction='sum'\n                )\n            else:\n                mesh_msdf_reg_loss = torch.tensor(0., device=img_loss.device)\n\n            if close_scale != 0:\n                with torch.no_grad():\n                    visible_verts = (opt_mesh_dict['imesh'].t_pos_idx[buffers['visible_triangles']]).unique()\n                    visible_boundary_verts = visible_verts[visible_verts >= opt_mesh_dict['n_verts_watertight']] - opt_mesh_dict['n_verts_watertight']\n                    visible_boundary_mask = torch.zeros(opt_mesh_dict['msdf_boundary'].size(0)).cuda()\n                    visible_boundary_mask[visible_boundary_verts] = 1\n                    visible_boundary_mask = visible_boundary_mask.bool()\n\n                boundary_msdf = opt_mesh_dict['msdf_boundary']\n                boundary_msdf = boundary_msdf[visible_boundary_mask]\n                mesh_msdf_reg_loss += close_scale * mesh_msdf_regscale * F.huber_loss(\n                    boundary_msdf.clamp(max=eps).squeeze(), \n                    eps.expand(boundary_msdf.size(0)), \n                    reduction='sum'\n                )\n        else:\n            mesh_msdf_reg_loss = torch.tensor(0., device=img_loss.device)\n\n        # SDF regularizer\n        sdf_weight = self.FLAGS.sdf_regularizer - (self.FLAGS.sdf_regularizer - 0.01) * min(1.0, 4.0 * t_iter)\n        sdf_reg_loss = compute_sdf_reg_loss(opt_mesh_dict['sdf'], self.all_edges).mean() * sdf_weight\n\n        # Monochrome shading regularizer\n        if 'diffuse_light' not in buffers:\n            monochrome_loss = torch.zeros_like(img_loss)\n        else:\n            monochrome_loss = regularizer.shading_loss(buffers['diffuse_light'], buffers['specular_light'], color_ref, self.FLAGS.lambda_diffuse, self.FLAGS.lambda_specular)\n\n        # Material smoothness regularizer\n        mtl_smooth_loss = regularizer.material_smoothness_grad(\n            buffers['kd_grad'], buffers['ks_grad'], buffers['normal_grad'], \n            lambda_kd=self.FLAGS.lambda_kd, lambda_ks=self.FLAGS.lambda_ks, lambda_nrm=self.FLAGS.lambda_nrm)\n\n        # Chroma regularizer\n        
chroma_loss = regularizer.chroma_loss(buffers['kd'], color_ref, self.FLAGS.lambda_chroma)\n        assert 'perturbed_nrm' not in buffers # disable normal map in first pass\n\n\n        geo_reg_loss = sdf_reg_loss + eik_loss + mesh_msdf_reg_loss\n        shading_reg_loss =  monochrome_loss + mtl_smooth_loss + chroma_loss\n        reg_loss = geo_reg_loss + shading_reg_loss\n\n        return img_loss, depth_loss, reg_loss\n"
  },
  {
    "path": "geometry/mlp.py",
    "content": "import torch\nimport torch.nn as nn\nimport numpy as np\n\nfrom .embedding import Embedding\n\nclass MLP(nn.Module):\n    def __init__(self, n_freq=6, d_hidden=128, d_out=1, n_hidden=3, skip_in=[], use_float16=False):\n        super().__init__()\n        self.emb = Embedding(3, n_freq)\n        layers = [\n            nn.Linear(self.emb.out_channels, d_hidden),\n            nn.Softplus(beta=100)\n        ]\n        count = 2\n        self.skip_count = []\n        self.skip_in = skip_in\n        for i in range(n_hidden):\n            if i in skip_in:\n                layers.append(nn.Linear(d_hidden + self.emb.out_channels, d_hidden))\n                self.skip_count.append(count)\n            else:\n                layers.append(nn.Linear(d_hidden, d_hidden))\n            count += 1\n            layers.append(nn.Softplus(beta=100))\n            count += 1\n        layers.append(nn.Linear(d_hidden, d_out))\n        count += 1\n        self.net = nn.ModuleList(layers)\n        self.use_float16 = use_float16\n    \n    def forward(self, x):\n        emb = self.emb(x)\n        x = emb\n        with torch.autocast('cuda', dtype=torch.float16, enabled=self.use_float16):\n            for i, module in enumerate(self.net):\n                if i in self.skip_count:\n                    x = module(torch.cat([x, emb], dim=-1))\n                else:\n                    x = module(x)\n        return x"
  },
  {
    "path": "render/light.py",
    "content": "# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto. Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\nimport os\nimport numpy as np\nimport torch\nimport nvdiffrast.torch as dr\n\nfrom . import util\nfrom . import renderutils as ru\n\n######################################################################################\n# Monte-carlo sampled environment light with PDF / CDF computation\n######################################################################################\n\nclass EnvironmentLight:\n    LIGHT_MIN_RES = 16\n\n    MIN_ROUGHNESS = 0.08\n    MAX_ROUGHNESS = 0.5\n\n    def __init__(self, base):\n        self.mtx = None\n        self.base = base\n\n        self.pdf_scale = (self.base.shape[0] * self.base.shape[1]) / (2 * np.pi * np.pi)\n        self.update_pdf()\n\n    def xfm(self, mtx):\n        self.mtx = mtx\n\n    def parameters(self):\n        return [self.base]\n\n    def clone(self):\n        return EnvironmentLight(self.base.clone().detach())\n\n    def clamp_(self, min=None, max=None):\n        self.base.clamp_(min, max)\n\n    def update_pdf(self):\n        with torch.no_grad():\n            # Compute PDF\n            Y = util.pixel_grid(self.base.shape[1], self.base.shape[0])[..., 1]\n            self._pdf = torch.max(self.base, dim=-1)[0] * torch.sin(Y * np.pi) # Scale by sin(theta) for lat-long, https://cs184.eecs.berkeley.edu/sp18/article/25\n            self._pdf = self._pdf / torch.sum(self._pdf)\n\n            # Compute cumulative sums over the columns and rows\n            self.cols = torch.cumsum(self._pdf, dim=1)\n            self.rows = torch.cumsum(self.cols[:, -1:].repeat([1, self.cols.shape[1]]), dim=0)\n\n            # Normalize\n            self.cols = self.cols / torch.where(self.cols[:, -1:] > 0, self.cols[:, -1:], torch.ones_like(self.cols))\n            self.rows = self.rows / torch.where(self.rows[-1:, :] > 0, self.rows[-1:, :], torch.ones_like(self.rows))\n\n    @torch.no_grad()\n    def generate_image(self, res):\n        texcoord = util.pixel_grid(res[1], res[0])\n        return dr.texture(self.base[None, ...].contiguous(), texcoord[None, ...].contiguous(), filter_mode='linear')[0]\n\n######################################################################################\n# Load and store\n######################################################################################\n\n@torch.no_grad()\ndef _load_env_hdr(fn, scale=1.0, res=None, trainable=False):\n    latlong_img = torch.tensor(util.load_image(fn), dtype=torch.float32, device='cuda')*scale\n\n    if res is not None:\n        texcoord = util.pixel_grid(res[1], res[0])\n        latlong_img = torch.clamp(dr.texture(latlong_img[None, ...], texcoord[None, ...], filter_mode='linear')[0], min=0.0001)\n\n    print(\"EnvProbe,\", latlong_img.shape, \", min/max\", torch.min(latlong_img).item(), torch.max(latlong_img).item())\n    if trainable:\n        print(\"trainable light loaded\")\n        return EnvironmentLight(base=latlong_img.clone().detach().requires_grad_(True))\n    else:\n        return EnvironmentLight(base=latlong_img)\n\n@torch.no_grad()\ndef load_env(fn, scale=1.0, res=None, trainable=False):\n    if os.path.splitext(fn)[1].lower() 
== \".hdr\":\n        return _load_env_hdr(fn, scale, res, trainable=trainable)\n    else:\n        assert False, \"Unknown envlight extension %s\" % os.path.splitext(fn)[1]\n\n@torch.no_grad()\ndef save_env_map(fn, light):\n    assert isinstance(light, EnvironmentLight)\n    color = light.generate_image([512, 1024])\n    util.save_image_raw(fn, color.detach().cpu().numpy())\n\n######################################################################################\n# Create trainable with random initialization\n######################################################################################\n\ndef create_trainable_env_rnd(base_res, scale=0.5, bias=0.25):  \n    base = torch.rand(base_res, base_res, 3, dtype=torch.float32, device='cuda') * scale + bias\n    l = EnvironmentLight(base.clone().detach().requires_grad_(True))\n    return l\n    "
  },
  {
    "path": "render/material.py",
    "content": "# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto. Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\nimport os\nimport numpy as np\nimport torch\n\nfrom . import util\nfrom . import texture\nfrom . import mlptexture\n\n######################################################################################\n# .mtl material format loading / storing\n######################################################################################\n\ndef load_mtl(fn, clear_ks=True):\n    import re\n    mtl_path = os.path.dirname(fn)\n\n    # Read file\n    with open(fn, 'r') as f:\n        lines = f.readlines()\n\n    # Parse materials\n    materials = []\n    for line in lines:\n        split_line = re.split(' +|\\t+|\\n+', line.strip())\n        prefix = split_line[0].lower()\n        data = split_line[1:]\n        if 'newmtl' in prefix:\n            material = {'name' : data[0]}\n            materials += [material]\n        elif materials:\n            if 'bsdf' in prefix or 'map_kd' in prefix or 'map_ks' in prefix or 'bump' in prefix:\n                material[prefix] = data[0]\n            else:\n                material[prefix] = torch.tensor(tuple(float(d) for d in data), dtype=torch.float32, device='cuda')\n\n    # Convert everything to textures. Our code expects 'kd' and 'ks' to be texture maps. So replace constants with 1x1 maps\n    for mat in materials:\n        if not 'bsdf' in mat:\n            mat['bsdf'] = 'pbr'\n\n        if 'map_kd' in mat:\n            mat['kd'] = texture.load_texture2D(os.path.join(mtl_path, mat['map_kd']))\n        else:\n            mat['kd'] = texture.Texture2D(mat['kd'])\n        \n        if 'map_ks' in mat:\n            mat['ks'] = texture.load_texture2D(os.path.join(mtl_path, mat['map_ks']), channels=3)\n        else:\n            mat['ks'] = texture.Texture2D(mat['ks'])\n\n        if 'bump' in mat:\n            mat['normal'] = texture.load_texture2D(os.path.join(mtl_path, mat['bump']), lambda_fn=lambda x: x * 2 - 1, channels=3)\n\n        # Convert Kd from sRGB to linear RGB\n        mat['kd'] = texture.srgb_to_rgb(mat['kd'])\n\n        if clear_ks:\n            # Override ORM occlusion (red) channel by zeros. 
 We hijack this channel\n            for mip in mat['ks'].getMips():\n                mip[..., 0] = 0.0 \n\n    return materials\n\ndef save_mtl(fn, material):\n    folder = os.path.dirname(fn)\n    with open(fn, \"w\") as f:\n        f.write('newmtl defaultMat\\n')\n        if material is not None:\n            f.write('bsdf   %s\\n' % material['bsdf'])\n            if 'kd' in material.keys():\n                f.write('map_Kd texture_kd.png\\n')\n                texture.save_texture2D(os.path.join(folder, 'texture_kd.png'), texture.rgb_to_srgb(material['kd']))\n            if 'ks' in material.keys():\n                f.write('map_Ks texture_ks.png\\n')\n                texture.save_texture2D(os.path.join(folder, 'texture_ks.png'), material['ks'])\n            if 'normal' in material.keys():\n                f.write('bump texture_n.png\\n')\n                texture.save_texture2D(os.path.join(folder, 'texture_n.png'), material['normal'], lambda_fn=lambda x:(util.safe_normalize(x)+1)*0.5)\n        else:\n            f.write('Kd 1 1 1\\n')\n            f.write('Ks 0 0 0\\n')\n            f.write('Ka 0 0 0\\n')\n            f.write('Tf 1 1 1\\n')\n            f.write('Ni 1\\n')\n            f.write('Ns 0\\n')\n\n######################################################################################\n# Utility function to convert an existing material and make all textures trainable\n######################################################################################\n\ndef create_trainable(material):\n    result = material.copy()\n    for key, val in result.items():\n        if isinstance(val, texture.Texture2D):\n            result[key] = texture.create_trainable(val)\n    return result\n\ndef get_parameters(material):\n    trainable = []\n    for key, val in material.items():\n        if isinstance(val, texture.Texture2D) or isinstance(val, mlptexture.MLPTexture3D):\n            trainable += val.parameters()\n    return trainable\n\n######################################################################################\n# Merge multiple materials into a single uber-material\n######################################################################################\n\ndef _upscale_replicate(x, full_res):\n    x = x.permute(0, 3, 1, 2)\n    x = torch.nn.functional.pad(x, (0, full_res[1] - x.shape[3], 0, full_res[0] - x.shape[2]), 'replicate')\n    return x.permute(0, 2, 3, 1).contiguous()\n\ndef merge_materials(materials, texcoords, tfaces, mfaces):\n    assert len(materials) > 0\n    for mat in materials:\n        assert mat['bsdf'] == materials[0]['bsdf'], \"All materials must have the same BSDF (uber shader)\"\n        assert ('normal' in mat) is ('normal' in materials[0]), \"All materials must have either normal map enabled or disabled\"\n\n    uber_material = {\n        'name' : 'uber_material',\n        'bsdf' : materials[0]['bsdf'],\n    }\n\n    textures = ['kd', 'ks', 'normal']\n\n    # Find maximum texture resolution across all materials and textures\n    max_res = None\n    for mat in materials:\n        for tex in textures:\n            tex_res = np.array(mat[tex].getRes()) if tex in mat else np.array([1, 1])\n            max_res = np.maximum(max_res, tex_res) if max_res is not None else tex_res\n    \n    # Compute size of compound texture and round up to nearest PoT\n    full_res = 2**np.ceil(np.log2(max_res * np.array([1, len(materials)]))).astype(int)\n\n    # Normalize texture resolution across all materials & combine into a single large texture\n    for tex in textures:\n        
if tex in materials[0]:\n            tex_data = torch.cat(tuple(util.scale_img_nhwc(mat[tex].data, tuple(max_res)) for mat in materials), dim=2) # Lay out all textures horizontally, NHWC so dim2 is x\n            tex_data = _upscale_replicate(tex_data, full_res)\n            uber_material[tex] = texture.Texture2D(tex_data)\n\n    # Compute scaling values for used / unused texture area\n    s_coeff = [full_res[0] / max_res[0], full_res[1] / max_res[1]]\n\n    # Recompute texture coordinates to coincide with new composite texture\n    new_tverts = {}\n    new_tverts_data = []\n    for fi in range(len(tfaces)):\n        matIdx = mfaces[fi]\n        for vi in range(3):\n            ti = tfaces[fi][vi]\n            if not (ti in new_tverts):\n                new_tverts[ti] = {}\n            if not (matIdx in new_tverts[ti]): # create new vertex\n                new_tverts_data.append([(matIdx + texcoords[ti][0]) / s_coeff[1], texcoords[ti][1] / s_coeff[0]]) # Offset texture coordinate (x direction) by material id & scale to local space. Note, texcoords are (u,v) but texture is stored (w,h) so the indexes swap here\n                new_tverts[ti][matIdx] = len(new_tverts_data) - 1\n            tfaces[fi][vi] = new_tverts[ti][matIdx] # reindex vertex\n\n    return uber_material, new_tverts_data, tfaces\n"
  },
  {
    "path": "render/mesh.py",
    "content": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n# property and proprietary rights in and to this material, related\n# documentation and any modifications thereto. Any use, reproduction, \n# disclosure or distribution of this material and related documentation \n# without an express license agreement from NVIDIA CORPORATION or \n# its affiliates is strictly prohibited.\n\nimport os\nimport numpy as np\nimport torch\n\nfrom . import obj\nfrom . import util\n\n######################################################################################\n# Base mesh class\n######################################################################################\nclass Mesh:\n    def __init__(self, v_pos=None, t_pos_idx=None, v_nrm=None, t_nrm_idx=None, v_tex=None, t_tex_idx=None, v_tng=None, t_tng_idx=None, material=None, base=None):\n        self.v_pos = v_pos\n        self.v_nrm = v_nrm\n        self.v_tex = v_tex\n        self.v_tng = v_tng\n        self.t_pos_idx = t_pos_idx\n        self.t_nrm_idx = t_nrm_idx\n        self.t_tex_idx = t_tex_idx\n        self.t_tng_idx = t_tng_idx\n        self.material = material\n\n        if base is not None:\n            self.copy_none(base)\n\n    def copy_none(self, other):\n        if self.v_pos is None:\n            self.v_pos = other.v_pos\n        if self.t_pos_idx is None:\n            self.t_pos_idx = other.t_pos_idx\n        if self.v_nrm is None:\n            self.v_nrm = other.v_nrm\n        if self.t_nrm_idx is None:\n            self.t_nrm_idx = other.t_nrm_idx\n        if self.v_tex is None:\n            self.v_tex = other.v_tex\n        if self.t_tex_idx is None:\n            self.t_tex_idx = other.t_tex_idx\n        if self.v_tng is None:\n            self.v_tng = other.v_tng\n        if self.t_tng_idx is None:\n            self.t_tng_idx = other.t_tng_idx\n        if self.material is None:\n            self.material = other.material\n\n    def clone(self):\n        out = Mesh(base=self)\n        if out.v_pos is not None:\n            out.v_pos = out.v_pos.clone().detach()\n        if out.t_pos_idx is not None:\n            out.t_pos_idx = out.t_pos_idx.clone().detach()\n        if out.v_nrm is not None:\n            out.v_nrm = out.v_nrm.clone().detach()\n        if out.t_nrm_idx is not None:\n            out.t_nrm_idx = out.t_nrm_idx.clone().detach()\n        if out.v_tex is not None:\n            out.v_tex = out.v_tex.clone().detach()\n        if out.t_tex_idx is not None:\n            out.t_tex_idx = out.t_tex_idx.clone().detach()\n        if out.v_tng is not None:\n            out.v_tng = out.v_tng.clone().detach()\n        if out.t_tng_idx is not None:\n            out.t_tng_idx = out.t_tng_idx.clone().detach()\n        return out\n\n######################################################################################\n# Mesh loeading helper\n######################################################################################\n\ndef load_mesh(filename, mtl_override=None, mtl_default=None, mtl_type_override=None):\n    name, ext = os.path.splitext(filename)\n    if ext == \".obj\":\n        return obj.load_obj(filename, clear_ks=True, mtl_override=mtl_override, mtl_default=mtl_default, mtl_type_override=mtl_type_override)\n    assert False, \"Invalid mesh file extension\"\n\n######################################################################################\n# Compute 
AABB\n######################################################################################\ndef aabb(mesh):\n    return torch.min(mesh.v_pos, dim=0).values, torch.max(mesh.v_pos, dim=0).values\n\n######################################################################################\n# Compute AABB with only used vertices\n######################################################################################\ndef aabb_clean(mesh):\n    v_pos_clean = mesh.v_pos[mesh.t_pos_idx.unique()]\n    return torch.min(v_pos_clean, dim=0).values, torch.max(v_pos_clean, dim=0).values\n\n######################################################################################\n# Compute unique edge list from attribute/vertex index list\n######################################################################################\ndef compute_edges(attr_idx, return_inverse=False):\n    with torch.no_grad():\n        # Create all edges, packed by triangle\n        all_edges = torch.cat((\n            torch.stack((attr_idx[:, 0], attr_idx[:, 1]), dim=-1),\n            torch.stack((attr_idx[:, 1], attr_idx[:, 2]), dim=-1),\n            torch.stack((attr_idx[:, 2], attr_idx[:, 0]), dim=-1),\n        ), dim=-1).view(-1, 2)\n\n        # Swap edge order so min index is always first\n        order = (all_edges[:, 0] > all_edges[:, 1]).long().unsqueeze(dim=1)\n        sorted_edges = torch.cat((\n            torch.gather(all_edges, 1, order),\n            torch.gather(all_edges, 1, 1 - order)\n        ), dim=-1)\n\n        # Eliminate duplicates and return inverse mapping\n        return torch.unique(sorted_edges, dim=0, return_inverse=return_inverse)\n\n######################################################################################\n# Compute unique edge to face mapping from attribute/vertex index list\n######################################################################################\ndef compute_edge_to_face_mapping(attr_idx, return_inverse=False):\n    with torch.no_grad():\n        # Get unique edges\n        # Create all edges, packed by triangle\n        all_edges = torch.cat((\n            torch.stack((attr_idx[:, 0], attr_idx[:, 1]), dim=-1),\n            torch.stack((attr_idx[:, 1], attr_idx[:, 2]), dim=-1),\n            torch.stack((attr_idx[:, 2], attr_idx[:, 0]), dim=-1),\n        ), dim=-1).view(-1, 2)\n\n        # Swap edge order so min index is always first\n        order = (all_edges[:, 0] > all_edges[:, 1]).long().unsqueeze(dim=1)\n        sorted_edges = torch.cat((\n            torch.gather(all_edges, 1, order),\n            torch.gather(all_edges, 1, 1 - order)\n        ), dim=-1)\n\n        # Elliminate duplicates and return inverse mapping\n        unique_edges, idx_map = torch.unique(sorted_edges, dim=0, return_inverse=True)\n\n        tris = torch.arange(attr_idx.shape[0]).repeat_interleave(3).cuda()\n\n        tris_per_edge = torch.zeros((unique_edges.shape[0], 2), dtype=torch.int64).cuda()\n\n        # Compute edge to face table\n        mask0 = order[:,0] == 0\n        mask1 = order[:,0] == 1\n        tris_per_edge[idx_map[mask0], 0] = tris[mask0]\n        tris_per_edge[idx_map[mask1], 1] = tris[mask1]\n\n        return tris_per_edge\n\n######################################################################################\n# Align base mesh to reference mesh:move & rescale to match bounding boxes.\n######################################################################################\ndef unit_size(mesh):\n    with torch.no_grad():\n        vmin, vmax = aabb(mesh)\n        scale = 2 / 
torch.max(vmax - vmin).item()\n        v_pos = mesh.v_pos - (vmax + vmin) / 2 # Center mesh on origin\n        v_pos = v_pos * scale                  # Rescale to unit size\n\n        return Mesh(v_pos, base=mesh)\n\n######################################################################################\n# Center & scale mesh for rendering\n######################################################################################\ndef center_by_reference(base_mesh, ref_aabb, scale):\n    center = (ref_aabb[0] + ref_aabb[1]).cuda() * 0.5\n    scale = scale / torch.max(ref_aabb[1] - ref_aabb[0]).item()\n    v_pos = (base_mesh.v_pos - center[None, ...]) * scale\n    return Mesh(v_pos, base=base_mesh)\n\n\ndef center_by_reference_noscale(base_mesh, ref_aabb, scale=None):\n    center = (ref_aabb[0] + ref_aabb[1]) * 0.5\n    v_pos = (base_mesh.v_pos - center[None, ...])\n    return Mesh(v_pos, base=base_mesh)\n\n\ndef center_with_global_aabb(base_mesh, ref_aabb, scale):\n    # center = (base_mesh.v_pos.min(dim=0).values + base_mesh.v_pos.max(dim=0).values) * 0.5 ### used for the experiments... wrong\n    center = ref_aabb.mean(dim=0).cuda()\n    scale = scale / torch.max(ref_aabb[1] - ref_aabb[0]).item() * 2.0\n    v_pos = (base_mesh.v_pos - center[None, ...]) * scale\n    return Mesh(v_pos, base=base_mesh)\n\n\ndef center_with_global_aabb_perdim(base_mesh, ref_aabb, scale):\n    center = (base_mesh.v_pos.min(dim=0).values + base_mesh.v_pos.max(dim=0).values) * 0.5\n    scale = scale / (ref_aabb[1] - ref_aabb[0])\n    v_pos = (base_mesh.v_pos - center[None, ...]) * scale.view(1, 3)\n    return Mesh(v_pos, base=base_mesh)\n\n\ndef scale_with_global_aabb(base_mesh, ref_aabb, scale):\n    scale = scale / torch.max(ref_aabb[1] - ref_aabb[0]).item() * 0.5\n    v_pos = base_mesh.v_pos * scale\n    return Mesh(v_pos, base=base_mesh)\n\n\ndef scale_with_global_aabb_perdim(base_mesh, ref_aabb, scale):\n    scale = scale / (ref_aabb[1] - ref_aabb[0])\n    v_pos = base_mesh.v_pos * scale.view(1, 3)\n    return Mesh(v_pos, base=base_mesh)\n\n######################################################################################\n# Simple smooth vertex normal computation\n######################################################################################\ndef auto_normals(imesh):\n\n    i0 = imesh.t_pos_idx[:, 0]\n    i1 = imesh.t_pos_idx[:, 1]\n    i2 = imesh.t_pos_idx[:, 2]\n\n    v0 = imesh.v_pos[i0, :]\n    v1 = imesh.v_pos[i1, :]\n    v2 = imesh.v_pos[i2, :]\n\n    face_normals = torch.cross(v1 - v0, v2 - v0)\n\n    # Splat face normals to vertices\n    v_nrm = torch.zeros_like(imesh.v_pos)\n    v_nrm.scatter_add_(0, i0[:, None].repeat(1,3), face_normals)\n    v_nrm.scatter_add_(0, i1[:, None].repeat(1,3), face_normals)\n    v_nrm.scatter_add_(0, i2[:, None].repeat(1,3), face_normals)\n\n    # Normalize, replace zero (degenerated) normals with some default value\n    v_nrm = torch.where(util.dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device='cuda'))\n    v_nrm = util.safe_normalize(v_nrm)\n\n    if torch.is_anomaly_enabled():\n        assert torch.all(torch.isfinite(v_nrm))\n\n    return Mesh(v_nrm=v_nrm, t_nrm_idx=imesh.t_pos_idx, base=imesh)\n\n######################################################################################\n# Compute tangent space from texture map coordinates\n# Follows http://www.mikktspace.com/ conventions\n######################################################################################\ndef compute_tangents(imesh, v_tng=None):\n   
 if v_tng is not None:\n        v_tng = util.safe_normalize(v_tng)\n        v_tng = util.safe_normalize(v_tng - util.dot(v_tng, imesh.v_nrm) * imesh.v_nrm)\n        return Mesh(v_tng=v_tng, t_tng_idx=imesh.t_nrm_idx, base=imesh)\n\n    vn_idx = [None] * 3\n    pos = [None] * 3\n    tex = [None] * 3\n    for i in range(0,3):\n        pos[i] = imesh.v_pos[imesh.t_pos_idx[:, i]]\n        tex[i] = imesh.v_tex[imesh.t_tex_idx[:, i]]\n        vn_idx[i] = imesh.t_nrm_idx[:, i]\n\n    tangents = torch.zeros_like(imesh.v_nrm)\n    tansum   = torch.zeros_like(imesh.v_nrm)\n\n    # Compute tangent space for each triangle\n    uve1 = tex[1] - tex[0]\n    uve2 = tex[2] - tex[0]\n    pe1  = pos[1] - pos[0]\n    pe2  = pos[2] - pos[0]\n    \n    nom   = (pe1 * uve2[..., 1:2] - pe2 * uve1[..., 1:2])\n    denom = (uve1[..., 0:1] * uve2[..., 1:2] - uve1[..., 1:2] * uve2[..., 0:1])\n    \n    # Avoid division by zero for degenerated texture coordinates\n    tang = nom / torch.where(denom > 0.0, torch.clamp(denom, min=1e-6), torch.clamp(denom, max=-1e-6))\n\n    # Update all 3 vertices\n    for i in range(0,3):\n        idx = vn_idx[i][:, None].repeat(1,3)\n        tangents.scatter_add_(0, idx, tang)                # tangents[n_i] = tangents[n_i] + tang\n        tansum.scatter_add_(0, idx, torch.ones_like(tang)) # tansum[n_i] = tansum[n_i] + 1\n    tangents = tangents / tansum\n\n    # Normalize and make sure tangent is perpendicular to normal\n    tangents = util.safe_normalize(tangents)\n    tangents = util.safe_normalize(tangents - util.dot(tangents, imesh.v_nrm) * imesh.v_nrm)\n\n    if torch.is_anomaly_enabled():\n        assert torch.all(torch.isfinite(tangents))\n\n    return Mesh(v_tng=tangents, t_tng_idx=imesh.t_nrm_idx, base=imesh)\n"
  },
  {
    "path": "render/mlptexture.py",
    "content": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n# property and proprietary rights in and to this material, related\n# documentation and any modifications thereto. Any use, reproduction, \n# disclosure or distribution of this material and related documentation \n# without an express license agreement from NVIDIA CORPORATION or \n# its affiliates is strictly prohibited.\n\nimport torch\nimport tinycudann as tcnn\nimport numpy as np\n\n#######################################################################################################################################################\n# Small MLP using PyTorch primitives, internal helper class\n#######################################################################################################################################################\n\nclass _MLP(torch.nn.Module):\n    def __init__(self, cfg, loss_scale=1.0):\n        super(_MLP, self).__init__()\n        self.loss_scale = loss_scale\n        net = (torch.nn.Linear(cfg['n_input_dims'], cfg['n_neurons'], bias=False), torch.nn.ReLU())\n        for i in range(cfg['n_hidden_layers']-1):\n            net = net + (torch.nn.Linear(cfg['n_neurons'], cfg['n_neurons'], bias=False), torch.nn.ReLU())\n        net = net + (torch.nn.Linear(cfg['n_neurons'], cfg['n_output_dims'], bias=False),)\n        self.net = torch.nn.Sequential(*net).cuda()\n        \n        self.net.apply(self._init_weights)\n        \n        if self.loss_scale != 1.0:\n            self.net.register_full_backward_hook(lambda module, grad_i, grad_o: (grad_i[0] * self.loss_scale, ))\n\n    def forward(self, x):\n        return self.net(x.to(torch.float32))\n\n    @staticmethod\n    def _init_weights(m):\n        if type(m) == torch.nn.Linear:\n            torch.nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')\n            if hasattr(m.bias, 'data'):\n                m.bias.data.fill_(0.0)\n\n#######################################################################################################################################################\n# Outward visible MLP class\n#######################################################################################################################################################\n\nclass MLPTexture3D(torch.nn.Module):\n    def __init__(self, AABB, channels = 3, internal_dims = 32, hidden = 2, min_max = None, use_float16=False):\n        super(MLPTexture3D, self).__init__()\n\n        self.channels = channels\n        self.internal_dims = internal_dims\n        self.AABB = AABB\n        self.min_max = min_max\n        self.use_float16 = use_float16\n\n        # Setup positional encoding, see https://github.com/NVlabs/tiny-cuda-nn for details\n        desired_resolution = 4096\n        base_grid_resolution = 16\n        num_levels = 16\n        per_level_scale = np.exp(np.log(desired_resolution / base_grid_resolution) / (num_levels-1))\n\n        enc_cfg =  {\n            \"otype\": \"HashGrid\",\n            \"n_levels\": num_levels,\n            \"n_features_per_level\": 2,\n            \"log2_hashmap_size\": 19,\n            \"base_resolution\": base_grid_resolution,\n            \"per_level_scale\" : per_level_scale\n\t    }\n\n        gradient_scaling = 128.0\n        self.encoder = tcnn.Encoding(3, enc_cfg)\n        self.encoder.register_full_backward_hook(lambda module, grad_i, grad_o: (grad_i[0] / gradient_scaling, ))\n\n        # Setup MLP\n  
      mlp_cfg = {\n            \"n_input_dims\" : self.encoder.n_output_dims,\n            \"n_output_dims\" : self.channels,\n            \"n_hidden_layers\" : hidden,\n            \"n_neurons\" : self.internal_dims\n        }\n        self.net = _MLP(mlp_cfg, gradient_scaling)\n        print(\"Encoder output: %d dims\" % (self.encoder.n_output_dims))\n\n    # Sample texture at a given location\n    def sample(self, texc):\n        _texc = (texc.view(-1, 3) - self.AABB[0][None, ...]) / (self.AABB[1][None, ...] - self.AABB[0][None, ...])\n        _texc = torch.clamp(_texc, min=0, max=1)\n        \n        p_enc = self.encoder(_texc.contiguous())\n        with torch.autocast('cuda', dtype=torch.float16, enabled=self.use_float16):\n            out = self.net.forward(p_enc)\n\n        # Sigmoid limit and scale to the allowed range\n        out = torch.sigmoid(out) * (self.min_max[1][None, :] - self.min_max[0][None, :]) + self.min_max[0][None, :]\n\n        return out.view(*texc.shape[:-1], self.channels) # Remap to [n, h, w, c]\n\n    # In-place clamp with no derivative to make sure values are in valid range after training\n    def clamp_(self):\n        pass\n\n    def cleanup(self):\n        tcnn.free_temporary_memory()\n\n"
  },
  {
    "path": "render/obj.py",
    "content": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n# property and proprietary rights in and to this material, related\n# documentation and any modifications thereto. Any use, reproduction, \n# disclosure or distribution of this material and related documentation \n# without an express license agreement from NVIDIA CORPORATION or \n# its affiliates is strictly prohibited.\n\nimport os\nimport torch\n\nfrom . import texture\nfrom . import mesh\nfrom . import material\n\n######################################################################################\n# Utility functions\n######################################################################################\n\ndef _find_mat(materials, name):\n    for mat in materials:\n        if mat['name'] == name:\n            return mat\n    return materials[0] # Materials 0 is the default\n\n######################################################################################\n# Create mesh object from objfile\n######################################################################################\n\ndef load_obj(filename, clear_ks=True, mtl_override=None, mtl_default=None, mtl_type_override=None):\n    obj_path = os.path.dirname(filename)\n\n    # Read entire file\n    with open(filename, 'r') as f:\n        lines = f.readlines()\n\n    # Load materials\n    if mtl_default is None:\n        all_materials = [\n            {\n                'name' : '_default_mat',\n                'bsdf' : 'pbr',\n                'kd'   : texture.Texture2D(torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32, device='cuda')),\n                'ks'   : texture.Texture2D(torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device='cuda'))\n            }\n        ]\n        if mtl_override is None: \n            for line in lines:\n                if len(line.split()) == 0:\n                    continue\n                if line.split()[0] == 'mtllib':\n                    all_materials += material.load_mtl(os.path.join(obj_path, line.split()[1]), clear_ks) # Read in entire material library\n        else:\n            all_materials += material.load_mtl(mtl_override)\n    else:\n        print(\"Load use-defined default mtl\")\n        all_materials = [mtl_default]\n\n    if mtl_type_override is not None:\n        for k in range(len(all_materials)):\n            all_materials[k]['bsdf'] = mtl_type_override\n\n\n    # load vertices\n    vertices, texcoords, normals  = [], [], []\n    for line in lines:\n        if len(line.split()) == 0:\n            continue\n        \n        prefix = line.split()[0].lower()\n        if prefix == 'v':\n            vertices.append([float(v) for v in line.split()[1:]])\n        elif prefix == 'vt':\n            val = [float(v) for v in line.split()[1:]]\n            texcoords.append([val[0], 1.0 - val[1]])\n        elif prefix == 'vn':\n            normals.append([float(v) for v in line.split()[1:]])\n\n    # load faces\n    activeMatIdx = None\n    used_materials = []\n    faces, tfaces, nfaces, mfaces = [], [], [], []\n    for line in lines:\n        if len(line.split()) == 0:\n            continue\n\n        prefix = line.split()[0].lower()\n        if prefix == 'usemtl': # Track used materials\n            mat = _find_mat(all_materials, line.split()[1])\n            if not mat in used_materials:\n                used_materials.append(mat)\n            activeMatIdx = used_materials.index(mat)\n        elif prefix == 
'f': # Parse face\n            vs = line.split()[1:]\n            nv = len(vs)\n            vv = vs[0].split('/')\n            v0 = int(vv[0]) - 1\n            t0 = int(vv[1]) - 1 if vv[1] != \"\" else -1\n            n0 = int(vv[2]) - 1 if vv[2] != \"\" else -1\n            for i in range(nv - 2): # Triangulate polygons\n                vv = vs[i + 1].split('/')\n                v1 = int(vv[0]) - 1\n                t1 = int(vv[1]) - 1 if vv[1] != \"\" else -1\n                n1 = int(vv[2]) - 1 if vv[2] != \"\" else -1\n                vv = vs[i + 2].split('/')\n                v2 = int(vv[0]) - 1\n                t2 = int(vv[1]) - 1 if vv[1] != \"\" else -1\n                n2 = int(vv[2]) - 1 if vv[2] != \"\" else -1\n                mfaces.append(activeMatIdx)\n                faces.append([v0, v1, v2])\n                tfaces.append([t0, t1, t2])\n                nfaces.append([n0, n1, n2])\n    assert len(tfaces) == len(faces) and len(nfaces) == len (faces)\n\n    # Create an \"uber\" material by combining all textures into a larger texture\n    if len(used_materials) > 1:\n        uber_material, texcoords, tfaces = material.merge_materials(used_materials, texcoords, tfaces, mfaces)\n    else:\n        uber_material = used_materials[0]\n    # elif len(used_materials) == 1:\n    #     uber_material = used_materials[0]\n    # else:\n    #     uber_material = [all_materials[0]]\n\n\n    vertices = torch.tensor(vertices, dtype=torch.float32, device='cuda')\n    texcoords = torch.tensor(texcoords, dtype=torch.float32, device='cuda') if len(texcoords) > 0 else None\n    normals = torch.tensor(normals, dtype=torch.float32, device='cuda') if len(normals) > 0 else None\n    \n    faces = torch.tensor(faces, dtype=torch.int64, device='cuda')\n    tfaces = torch.tensor(tfaces, dtype=torch.int64, device='cuda') if texcoords is not None else None\n    nfaces = torch.tensor(nfaces, dtype=torch.int64, device='cuda') if normals is not None else None\n\n    vertices = vertices[:, :3]\n\n    return mesh.Mesh(vertices, faces, normals, nfaces, texcoords, tfaces, material=uber_material)\n\n######################################################################################\n# Save mesh object to objfile\n######################################################################################\n\ndef write_obj(folder, mesh, save_material=True):\n    obj_file = os.path.join(folder, 'mesh.obj')\n    print(\"Writing mesh: \", obj_file)\n    with open(obj_file, \"w\") as f:\n        f.write(\"mtllib mesh.mtl\\n\")\n        f.write(\"g default\\n\")\n\n        v_pos = mesh.v_pos.detach().cpu().numpy() if mesh.v_pos is not None else None\n        v_nrm = mesh.v_nrm.detach().cpu().numpy() if mesh.v_nrm is not None else None\n        v_tex = mesh.v_tex.detach().cpu().numpy() if mesh.v_tex is not None else None\n\n        t_pos_idx = mesh.t_pos_idx.detach().cpu().numpy() if mesh.t_pos_idx is not None else None\n        t_nrm_idx = mesh.t_nrm_idx.detach().cpu().numpy() if mesh.t_nrm_idx is not None else None\n        t_tex_idx = mesh.t_tex_idx.detach().cpu().numpy() if mesh.t_tex_idx is not None else None\n\n        print(\"    writing %d vertices\" % len(v_pos))\n        for v in v_pos:\n            f.write('v {} {} {} \\n'.format(v[0], v[1], v[2]))\n       \n        if v_tex is not None:\n            print(\"    writing %d texcoords\" % len(v_tex))\n            assert(len(t_pos_idx) == len(t_tex_idx))\n            for v in v_tex:\n                f.write('vt {} {} \\n'.format(v[0], 1.0 - v[1]))\n\n        if 
v_nrm is not None:\n            print(\"    writing %d normals\" % len(v_nrm))\n            assert(len(t_pos_idx) == len(t_nrm_idx))\n            for v in v_nrm:\n                f.write('vn {} {} {}\\n'.format(v[0], v[1], v[2]))\n\n        # faces\n        f.write(\"s 1 \\n\")\n        f.write(\"g pMesh1\\n\")\n        f.write(\"usemtl defaultMat\\n\")\n\n        # Write faces\n        print(\"    writing %d faces\" % len(t_pos_idx))\n        for i in range(len(t_pos_idx)):\n            f.write(\"f \")\n            for j in range(3):\n                f.write(' %s/%s/%s' % (str(t_pos_idx[i][j]+1), '' if v_tex is None else str(t_tex_idx[i][j]+1), '' if v_nrm is None else str(t_nrm_idx[i][j]+1)))\n            f.write(\"\\n\")\n\n    if save_material:\n        mtl_file = os.path.join(folder, 'mesh.mtl')\n        print(\"Writing material: \", mtl_file)\n        material.save_mtl(mtl_file, mesh.material)\n\n    print(\"Done exporting mesh\")\n"
  },
  {
    "path": "render/optixutils/__init__.py",
    "content": "# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto. Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\nfrom .ops import OptiXContext, optix_build_bvh, optix_env_shade, bilateral_denoiser\n__all__ = [\"OptiXContext\", \"optix_build_bvh\", \"optix_env_shade\", 'bilateral_denoiser']\n"
  },
  {
    "path": "render/optixutils/c_src/accessor.h",
    "content": "// Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain all intellectual property\n// and proprietary rights in and to this software, related documentation\n// and any modifications thereto. Any use, reproduction, disclosure or\n// distribution of this software and related documentation without an express\n// license agreement from NVIDIA CORPORATION is strictly prohibited.\n\n// Stripped down version from pytorch. Made to work with optix kernels where it's \n// hard to include dependencies\n// https://github.com/pytorch/pytorch/blob/dc169d53aa266560750ea25ee0cf31c7e614550d/aten/src/ATen/core/TensorAccessor.h\n\n/////////////////////////////////////////////////////////////////////////////\n// From PyTorch:\n\n// Copyright (c) 2016-     Facebook, Inc            (Adam Paszke)\n// Copyright (c) 2014-     Facebook, Inc            (Soumith Chintala)\n// Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)\n// Copyright (c) 2012-2014 Deepmind Technologies    (Koray Kavukcuoglu)\n// Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)\n// Copyright (c) 2011-2013 NYU                      (Clement Farabet)\n// Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)\n// Copyright (c) 2006      Idiap Research Institute (Samy Bengio)\n// Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)\n\n// From Caffe2:\n\n// Copyright (c) 2016-present, Facebook Inc. All rights reserved.\n\n// All contributions by Facebook:\n// Copyright (c) 2016 Facebook Inc.\n\n// All contributions by Google:\n// Copyright (c) 2015 Google Inc.\n// All rights reserved.\n\n// All contributions by Yangqing Jia:\n// Copyright (c) 2015 Yangqing Jia\n// All rights reserved.\n\n// All contributions by Kakao Brain:\n// Copyright 2019-2020 Kakao Brain\n\n// All contributions by Cruise LLC:\n// Copyright (c) 2022 Cruise LLC.\n// All rights reserved.\n\n// All contributions from Caffe:\n// Copyright(c) 2013, 2014, 2015, the respective contributors\n// All rights reserved.\n\n// All other contributions:\n// Copyright(c) 2015, 2016 the respective contributors\n// All rights reserved.\n\n// Caffe2 uses a copyright model similar to Caffe: each contributor holds\n// copyright over their contributions to Caffe2. The project versioning records\n// all such contribution and copyright details. If a contributor wants to further\n// mark their specific copyright on a particular contribution, they should\n// indicate their copyright solely in the commit message of the change when it is\n// committed.\n\n// All rights reserved.\n\n// Redistribution and use in source and binary forms, with or without\n// modification, are permitted provided that the following conditions are met:\n\n// 1. Redistributions of source code must retain the above copyright\n//    notice, this list of conditions and the following disclaimer.\n\n// 2. Redistributions in binary form must reproduce the above copyright\n//    notice, this list of conditions and the following disclaimer in the\n//    documentation and/or other materials provided with the distribution.\n\n// 3. 
Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America\n//    and IDIAP Research Institute nor the names of its contributors may be\n//    used to endorse or promote products derived from this software without\n//    specific prior written permission.\n\n// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE\n// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE\n// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR\n// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF\n// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS\n// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN\n// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)\n// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE\n// POSSIBILITY OF SUCH DAMAGE.\n/////////////////////////////////////////////////////////////////////////////\n\n#pragma once\n\n#if defined(__OPTIX__)\n    typedef int int32_t;\n    typedef long long int64_t;\n#else\n    #include <stdint.h>\n#endif\n\n#ifdef __CUDACC__\n    #ifdef __CUDA_ARCH__\n        #define C10_DEVICE __device__\n        #define C10_HOST_DEVICE __device__\n    #else\n        #define C10_DEVICE __device__\n        #define C10_HOST __host__\n        #define C10_HOST_DEVICE __host__ __device__\n    #endif\n#else\n    #include <algorithm>\n    #define C10_HOST_DEVICE\n    #define C10_HOST\n#endif\n\n// The PtrTraits argument to the TensorAccessor/GenericPackedTensorAccessor\n// is used to enable the __restrict__ keyword/modifier for the data\n// passed to cuda.\ntemplate <typename T>\nstruct DefaultPtrTraits {\n    typedef T* PtrType;\n};\n\n#if defined(__CUDACC__) || defined(__HIPCC__)\ntemplate <typename T>\nstruct RestrictPtrTraits {\n    typedef T* __restrict__ PtrType;\n};\n#endif\n\n// TensorAccessorBase and TensorAccessor are used for both CPU and CUDA tensors.\n// For CUDA tensors it is used in device code (only). This means that we restrict ourselves\n// to functions and types available there (e.g. 
IntArrayRef isn't).\n\n// The PtrTraits argument is only relevant to cuda to support `__restrict__` pointers.\ntemplate<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>\nclass TensorAccessorBase {\npublic:\n    typedef typename PtrTraits<T>::PtrType PtrType;\n\n    C10_HOST_DEVICE TensorAccessorBase(\n        PtrType data_,\n        const index_t* sizes_,\n        const index_t* strides_)\n        : data_(data_), sizes_(sizes_), strides_(strides_) {}\n    C10_HOST_DEVICE index_t stride(index_t i) const {\n        return strides_[i];\n    }\n    C10_HOST_DEVICE index_t size(index_t i) const {\n        return sizes_[i];\n    }\n    C10_HOST_DEVICE PtrType data() {\n        return data_;\n    }\n    C10_HOST_DEVICE const PtrType data() const {\n        return data_;\n    }\nprotected:\n    PtrType data_;\n    const index_t* sizes_;\n    const index_t* strides_;\n};\n\n// The `TensorAccessor` is typically instantiated for CPU `Tensor`s using\n// `Tensor.accessor<T, N>()`.\n// For CUDA `Tensor`s, `GenericPackedTensorAccessor` is used on the host and only\n// indexing on the device uses `TensorAccessor`s.\ntemplate<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>\nclass TensorAccessor : public TensorAccessorBase<T,N,PtrTraits,index_t> {\npublic:\n    typedef typename PtrTraits<T>::PtrType PtrType;\n\n    C10_HOST_DEVICE TensorAccessor(\n        PtrType data_,\n        const index_t* sizes_,\n        const index_t* strides_)\n        : TensorAccessorBase<T, N, PtrTraits, index_t>(data_,sizes_,strides_) {}\n\n    C10_HOST_DEVICE TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) {\n        return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i,this->sizes_+1,this->strides_+1);\n    }\n\n    C10_HOST_DEVICE const TensorAccessor<T, N-1, PtrTraits, index_t> operator[](index_t i) const {\n        return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i,this->sizes_+1,this->strides_+1);\n    }\n};\n\ntemplate<typename T, template <typename U> class PtrTraits, typename index_t>\nclass TensorAccessor<T,1,PtrTraits,index_t> : public TensorAccessorBase<T,1,PtrTraits,index_t> {\npublic:\n    typedef typename PtrTraits<T>::PtrType PtrType;\n\n    C10_HOST_DEVICE TensorAccessor(\n        PtrType data_,\n        const index_t* sizes_,\n        const index_t* strides_)\n        : TensorAccessorBase<T, 1, PtrTraits, index_t>(data_,sizes_,strides_) {}\n    C10_HOST_DEVICE T & operator[](index_t i) {\n        // NOLINTNEXTLINE(clang-analyzer-core.NullDereference)\n        return this->data_[this->strides_[0]*i];\n    }\n    C10_HOST_DEVICE const T & operator[](index_t i) const {\n        return this->data_[this->strides_[0]*i];\n    }\n};\n\n// GenericPackedTensorAccessorBase and GenericPackedTensorAccessor are used on for CUDA `Tensor`s on the host\n// and as\n// In contrast to `TensorAccessor`s, they copy the strides and sizes on instantiation (on the host)\n// in order to transfer them on the device when calling kernels.\n// On the device, indexing of multidimensional tensors gives to `TensorAccessor`s.\n// Use RestrictPtrTraits as PtrTraits if you want the tensor's data pointer to be marked as __restrict__.\n// Instantiation from data, sizes, strides is only needed on the host and std::copy isn't available\n// on the device, so those functions are host only.\ntemplate<typename T, size_t N, template <typename U> class PtrTraits = 
DefaultPtrTraits, typename index_t = int64_t>\nclass GenericPackedTensorAccessorBase {\npublic:\n    typedef typename PtrTraits<T>::PtrType PtrType;\n\n#if !defined(__CUDACC__)\n    C10_HOST GenericPackedTensorAccessorBase() {}\n\n    C10_HOST GenericPackedTensorAccessorBase(\n        PtrType data_,\n        const index_t* sizes_,\n        const index_t* strides_)\n        : data_(data_) {\n        std::copy(sizes_, sizes_ + N, std::begin(this->sizes_));\n        std::copy(strides_, strides_ + N, std::begin(this->strides_));\n    }\n\n    // if index_t is not int64_t, we want to have an int64_t constructor\n    template <typename source_index_t, class = typename std::enable_if<std::is_same<source_index_t, int64_t>::value>::type>\n    C10_HOST GenericPackedTensorAccessorBase(\n        PtrType data_,\n        const source_index_t* sizes_,\n        const source_index_t* strides_)\n        : data_(data_) {\n        for (const auto i : c10::irange(N)) {\n            this->sizes_[i] = sizes_[i];\n            this->strides_[i] = strides_[i];\n        }\n    }\n#endif\n    C10_HOST_DEVICE index_t stride(index_t i) const {\n        return strides_[i];\n    }\n    C10_HOST_DEVICE index_t size(index_t i) const {\n        return sizes_[i];\n    }\n    C10_HOST_DEVICE PtrType data() {\n        return data_;\n    }\n    C10_HOST_DEVICE const PtrType data() const {\n        return data_;\n    }\nprotected:\n    PtrType data_;\n    index_t sizes_[N];\n    index_t strides_[N];\n};\n\ntemplate<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>\nclass GenericPackedTensorAccessor : public GenericPackedTensorAccessorBase<T,N,PtrTraits,index_t> {\npublic:\n    typedef typename PtrTraits<T>::PtrType PtrType;\n\n#if !defined(__CUDACC__)\n    C10_HOST GenericPackedTensorAccessor() : GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t>() {}\n\n    C10_HOST GenericPackedTensorAccessor(\n        PtrType data_,\n        const index_t* sizes_,\n        const index_t* strides_)\n        : GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t>(data_, sizes_, strides_) {}\n\n    // if index_t is not int64_t, we want to have an int64_t constructor\n    template <typename source_index_t, class = typename std::enable_if<std::is_same<source_index_t, int64_t>::value>::type>\n    C10_HOST GenericPackedTensorAccessor(\n        PtrType data_,\n        const source_index_t* sizes_,\n        const source_index_t* strides_)\n        : GenericPackedTensorAccessorBase<T, N, PtrTraits, index_t>(data_, sizes_, strides_) {}\n#else\n    C10_DEVICE TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) {\n        index_t* new_sizes = this->sizes_ + 1;\n        index_t* new_strides = this->strides_ + 1;\n        return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i, new_sizes, new_strides);\n    }\n\n    C10_DEVICE const TensorAccessor<T, N - 1, PtrTraits, index_t> operator[](index_t i) const {\n        const index_t* new_sizes = this->sizes_ + 1;\n        const index_t* new_strides = this->strides_ + 1;\n        return TensorAccessor<T,N-1,PtrTraits,index_t>(this->data_ + this->strides_[0]*i, new_sizes, new_strides);\n    }\n#endif\n\n};\n\ntemplate<typename T, template <typename U> class PtrTraits, typename index_t>\nclass GenericPackedTensorAccessor<T,1,PtrTraits,index_t> : public GenericPackedTensorAccessorBase<T,1,PtrTraits,index_t> {\npublic:\n    typedef typename PtrTraits<T>::PtrType PtrType;\n\n#if !defined(__CUDACC__)\n    
C10_HOST GenericPackedTensorAccessor() : GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t>() {}\n\n    C10_HOST GenericPackedTensorAccessor(\n        PtrType data_,\n        const index_t* sizes_,\n        const index_t* strides_)\n        : GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t>(data_, sizes_, strides_) {}\n\n    // if index_t is not int64_t, we want to have an int64_t constructor\n    template <typename source_index_t, class = typename std::enable_if<std::is_same<source_index_t, int64_t>::value>::type>\n    C10_HOST GenericPackedTensorAccessor(\n        PtrType data_,\n        const source_index_t* sizes_,\n        const source_index_t* strides_)\n        : GenericPackedTensorAccessorBase<T, 1, PtrTraits, index_t>(data_, sizes_, strides_) {}\n#else\n    C10_DEVICE T & operator[](index_t i) {\n        return this->data_[this->strides_[0] * i];\n    }\n    C10_DEVICE const T& operator[](index_t i) const {\n        return this->data_[this->strides_[0]*i];\n    }\n#endif\n};\n\ntemplate <typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>\nusing PackedTensorAccessor32 = GenericPackedTensorAccessor<T, N, PtrTraits, int32_t>;\n\ntemplate <typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>\nusing PackedTensorAccessor64 = GenericPackedTensorAccessor<T, N, PtrTraits, int64_t>;"
  },
  {
    "path": "render/optixutils/c_src/bsdf.h",
    "content": "// Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain all intellectual property\n// and proprietary rights in and to this software, related documentation\n// and any modifications thereto. Any use, reproduction, disclosure or\n// distribution of this software and related documentation without an express\n// license agreement from NVIDIA CORPORATION is strictly prohibited.\n\n#pragma once\n\n#ifdef __CUDACC__\n\n#define SPECULAR_EPSILON 1e-4f\n#ifndef M_PI\n    #define M_PI 3.14159265358979323846f\n#endif\n\n//------------------------------------------------------------------------\n// Lambert functions\n\n__device__ inline float fwdLambert(const float3 nrm, const float3 wi)\n{\n    return max(dot(nrm, wi) / M_PI, 0.0f);\n}\n\n__device__ inline void bwdLambert(const float3 nrm, const float3 wi, float3& d_nrm, float3& d_wi, const float d_out)\n{\n    if (dot(nrm, wi) > 0.0f)\n        bwd_dot(nrm, wi, d_nrm, d_wi, d_out / M_PI);\n}\n\n//------------------------------------------------------------------------\n// Fresnel Schlick \n\n__device__ inline float fwdFresnelSchlick(const float f0, const float f90, const float cosTheta)\n{\n    float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);\n    float scale = powf(1.0f - _cosTheta, 5.0f);\n    return f0 * (1.0f - scale) + f90 * scale;\n}\n\n__device__ inline void bwdFresnelSchlick(const float f0, const float f90, const float cosTheta, float& d_f0, float& d_f90, float& d_cosTheta, const float d_out)\n{\n    float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);\n    float scale = pow(max(1.0f - _cosTheta, 0.0f), 5.0f);\n    d_f0 += d_out * (1.0 - scale);\n    d_f90 += d_out * scale;\n    if (cosTheta >= SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)\n    {\n        d_cosTheta += d_out * (f90 - f0) * -5.0f * powf(1.0f - cosTheta, 4.0f);\n    }\n}\n\n__device__ inline float3 fwdFresnelSchlick(const float3 f0, const float3 f90, const float cosTheta)\n{\n    float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);\n    float scale = powf(1.0f - _cosTheta, 5.0f);\n    return f0 * (1.0f - scale) + f90 * scale;\n}\n\n__device__ inline void bwdFresnelSchlick(const float3 f0, const float3 f90, const float cosTheta, float3& d_f0, float3& d_f90, float& d_cosTheta, const float3 d_out)\n{\n    float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);\n    float scale = pow(max(1.0f - _cosTheta, 0.0f), 5.0f);\n    d_f0 += d_out * (1.0 - scale);\n    d_f90 += d_out * scale;\n    if (cosTheta >= SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)\n    {\n        d_cosTheta += sum(d_out * (f90 - f0) * -5.0f * powf(1.0f - cosTheta, 4.0f));\n    }\n}\n\n//------------------------------------------------------------------------\n// Ndf GGX\n\n__device__ inline float fwdNdfGGX(const float alphaSqr, const float cosTheta)\n{\n    float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);\n    float d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1.0f;\n    return alphaSqr / (d * d * M_PI);\n}\n\n__device__ inline void bwdNdfGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out)\n{\n    // Torch only back propagates if clamp doesn't trigger\n    float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);\n    float cosThetaSqr = _cosTheta * _cosTheta;\n    d_alphaSqr += d_out * (1.0f - (alphaSqr + 
1.0f) * cosThetaSqr) / (M_PI * powf((alphaSqr - 1.0) * cosThetaSqr + 1.0f, 3.0f));\n    if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)\n    {\n        d_cosTheta += d_out * -(4.0f * (alphaSqr - 1.0f) * alphaSqr * cosTheta) / (M_PI * powf((alphaSqr - 1.0) * cosThetaSqr + 1.0f, 3.0f));\n    }\n}\n\n//------------------------------------------------------------------------\n// Lambda GGX\n\n__device__ inline float fwdLambdaGGX(const float alphaSqr, const float cosTheta)\n{\n    float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);\n    float cosThetaSqr = _cosTheta * _cosTheta;\n    float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr;\n    float res = 0.5f * (sqrtf(1.0f + alphaSqr * tanThetaSqr) - 1.0f);\n    return res;\n}\n\n__device__ inline void bwdLambdaGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out)\n{\n    float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);\n    float cosThetaSqr = _cosTheta * _cosTheta;\n    float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr;\n    float res = 0.5f * (sqrtf(1.0f + alphaSqr * tanThetaSqr) - 1.0f);\n\n    d_alphaSqr += d_out * (0.25 * tanThetaSqr) / sqrtf(alphaSqr * tanThetaSqr + 1.0f);\n    if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)\n        d_cosTheta += d_out * -(0.5 * alphaSqr) / (powf(_cosTheta, 3.0f) * sqrtf(alphaSqr / cosThetaSqr - alphaSqr + 1.0f));\n}\n\n//------------------------------------------------------------------------\n// Masking GGX\n\n__device__ inline float fwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO)\n{\n    float lambdaI = fwdLambdaGGX(alphaSqr, cosThetaI);\n    float lambdaO = fwdLambdaGGX(alphaSqr, cosThetaO);\n    return 1.0f / (1.0f + lambdaI + lambdaO);\n}\n\n__device__ inline void bwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO, float& d_alphaSqr, float& d_cosThetaI, float& d_cosThetaO, const float d_out)\n{\n    // FWD eval\n    float lambdaI = fwdLambdaGGX(alphaSqr, cosThetaI);\n    float lambdaO = fwdLambdaGGX(alphaSqr, cosThetaO);\n\n    // BWD eval\n    float d_lambdaIO = -d_out / powf(1.0f + lambdaI + lambdaO, 2.0f);\n    bwdLambdaGGX(alphaSqr, cosThetaI, d_alphaSqr, d_cosThetaI, d_lambdaIO);\n    bwdLambdaGGX(alphaSqr, cosThetaO, d_alphaSqr, d_cosThetaO, d_lambdaIO);\n}\n\n//------------------------------------------------------------------------\n// GGX specular\n\n__device__ float3 fwdPbrSpecular(const float3 col, const float3 nrm, const float3 wo, const float3 wi, const float alpha, const float min_roughness)\n{\n    float _alpha = clamp(alpha, min_roughness * min_roughness, 1.0f);\n    float alphaSqr = _alpha * _alpha;\n\n    float3 h = safe_normalize(wo + wi);\n    float woDotN = dot(wo, nrm);\n    float wiDotN = dot(wi, nrm);\n    float woDotH = dot(wo, h);\n    float nDotH = dot(nrm, h);\n\n    float D = fwdNdfGGX(alphaSqr, nDotH);\n    float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN);\n    float3 F = fwdFresnelSchlick(col, make_float3(1.0f), woDotH);\n    float3 w = F * D * G * 0.25 / woDotN;\n\n    bool frontfacing = (woDotN > SPECULAR_EPSILON) & (wiDotN > SPECULAR_EPSILON);\n    return frontfacing ? 
w : make_float3(0.0f);\n}\n\n__device__ void bwdPbrSpecular(\n    const float3 col, const float3 nrm, const float3 wo, const float3 wi, const float alpha, const float min_roughness,\n    float3& d_col, float3& d_nrm, float3& d_wo, float3& d_wi, float& d_alpha, const float3 d_out)\n{\n    ///////////////////////////////////////////////////////////////////////\n    // FWD eval\n\n    float _alpha = clamp(alpha, min_roughness * min_roughness, 1.0f);\n    float alphaSqr = _alpha * _alpha;\n\n    float3 h = safe_normalize(wo + wi);\n    float woDotN = dot(wo, nrm);\n    float wiDotN = dot(wi, nrm);\n    float woDotH = dot(wo, h);\n    float nDotH = dot(nrm, h);\n\n    float D = fwdNdfGGX(alphaSqr, nDotH);\n    float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN);\n    float3 F = fwdFresnelSchlick(col, make_float3(1.0f), woDotH);\n    float3 w = F * D * G * 0.25 / woDotN;\n    bool frontfacing = (woDotN > SPECULAR_EPSILON) & (wiDotN > SPECULAR_EPSILON);\n\n    if (frontfacing)\n    {\n        ///////////////////////////////////////////////////////////////////////\n        // BWD eval\n\n        float3 d_F = d_out * D * G * 0.25f / woDotN;\n        float d_D = sum(d_out * F * G * 0.25f / woDotN);\n        float d_G = sum(d_out * F * D * 0.25f / woDotN);\n\n        float d_woDotN = -sum(d_out * F * D * G * 0.25f / (woDotN * woDotN));\n\n        float3 d_f90 = make_float3(0);\n        float d_woDotH = 0, d_wiDotN = 0, d_nDotH = 0, d_alphaSqr = 0;\n        bwdFresnelSchlick(col, make_float3(1.0f), woDotH, d_col, d_f90, d_woDotH, d_F);\n        bwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN, d_alphaSqr, d_woDotN, d_wiDotN, d_G);\n        bwdNdfGGX(alphaSqr, nDotH, d_alphaSqr, d_nDotH, d_D);\n\n        float3 d_h = make_float3(0);\n        bwd_dot(nrm, h, d_nrm, d_h, d_nDotH);\n        bwd_dot(wo, h, d_wo, d_h, d_woDotH);\n        bwd_dot(wi, nrm, d_wi, d_nrm, d_wiDotN);\n        bwd_dot(wo, nrm, d_wo, d_nrm, d_woDotN);\n\n        float3 d_h_unnorm = make_float3(0);\n        bwd_safe_normalize(wo + wi, d_h_unnorm, d_h);\n        d_wo += d_h_unnorm;\n        d_wi += d_h_unnorm;\n\n        if (alpha > min_roughness * min_roughness)\n            d_alpha += d_alphaSqr * 2 * alpha;\n    }\n}\n\n//------------------------------------------------------------------------\n// Full PBR BSDF\n\n__device__ void fwdPbrBSDF(const float3 kd, const float3 arm, const float3 pos, const float3 nrm, const float3 view_pos, const float3 wi, const float min_roughness, float3 &diffuse, float3 &specular)\n{\n    float3 wo = safe_normalize(view_pos - pos);\n\n    float alpha = arm.y * arm.y;\n    float3 spec_col = (make_float3(0.04f) * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x);\n    // Removed because of demodulated albedo.\n    // float3 diff_col = kd * (1.0f - arm.z);\n\n    float diff = 0.0f;\n    diff = fwdLambert(nrm, wi);\n    \n    diffuse = make_float3(diff);//diff_col * diff;\n    specular = fwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness);\n}\n\n__device__ void bwdPbrBSDF(\n    const float3 kd, const float3 arm, const float3 pos, const float3 nrm, const float3 view_pos, const float3 wi, const float min_roughness,\n    float3& d_kd, float3& d_arm, float3& d_pos, float3& d_nrm, float3& d_view_pos, float3& d_wi, const float3 d_diffuse, float3 d_specular)\n{\n    ////////////////////////////////////////////////////////////////////////\n    // FWD\n    float3 _wo = view_pos - pos;\n    float3 wo = safe_normalize(_wo);\n\n    float alpha = arm.y * arm.y;\n    float3 spec_col = 
(make_float3(0.04f) * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x);\n\n    ////////////////////////////////////////////////////////////////////////\n    // BWD\n\n    float d_alpha = 0;\n    d_wi = make_float3(0);\n    float3 d_spec_col = make_float3(0), d_wo = make_float3(0);\n    bwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness, d_spec_col, d_nrm, d_wo, d_wi, d_alpha, d_specular);\n\n    // float d_diff = sum(diff_col * d_diffuse);\n    float d_diff = sum(d_diffuse);\n    bwdLambert(nrm, wi, d_nrm, d_wi, d_diff);\n\n    // Backprop: spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x)\n    d_kd -= d_spec_col * (arm.x - 1.0f) * arm.z;\n    d_arm.x += sum(d_spec_col * (arm.z * (make_float3(0.04f) - kd) - 0.04f));\n    d_arm.z -= sum(d_spec_col * (kd - make_float3(0.04f)) * (arm.x - 1.0f));\n\n    // Backprop: alpha = arm.y * arm.y\n    d_arm.y += d_alpha * 2 * arm.y;\n\n    // Backprop: float3 wo = safe_normalize(view_pos - pos);\n    float3 d__wo = make_float3(0);\n    bwd_safe_normalize(_wo, d__wo, d_wo);\n    d_view_pos += d__wo;\n    d_pos -= d__wo;\n}\n\n#endif"
  },
  {
    "path": "render/optixutils/c_src/common.h",
    "content": "// Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain all intellectual property\n// and proprietary rights in and to this software, related documentation\n// and any modifications thereto. Any use, reproduction, disclosure or\n// distribution of this software and related documentation without an express\n// license agreement from NVIDIA CORPORATION is strictly prohibited.\n\n#pragma once\n\n// Helper functions to do broadcast guarded fetches\n#if defined(__CUDACC__)\n    template<class T, typename U, typename... Args>\n    static __device__ inline float3 fetch3(const T &tensor, U idx, Args... args) {\n    return tensor.size(0) == 1 ? fetch3(tensor[0], args...) : fetch3(tensor[idx], args...);\n    }\n    template<class T> static __device__ inline float3 fetch3(const T &tensor) {\n    return tensor.size(0) == 1 ? make_float3(tensor[0], tensor[0], tensor[0]) : make_float3(tensor[0], tensor[1], tensor[2]);\n    }\n\n    template<class T, typename U, typename... Args>\n    static __device__ inline float2 fetch2(const T &tensor, U idx, Args... args) {\n    return tensor.size(0) == 1 ? fetch2(tensor[0], args...) : fetch2(tensor[idx], args...);\n    }\n    template<class T> static __device__ inline float2 fetch2(const T &tensor) {\n    return tensor.size(0) == 1 ? make_float2(tensor[0], tensor[0]) : make_float2(tensor[0], tensor[1]);\n    }\n\n    #include \"math_utils.h\"\n    #include \"bsdf.h\"\n#endif\n\n//------------------------------------------------------------------------------\n// CUDA error-checking macros\n//------------------------------------------------------------------------------\n\n#define CUDA_CHECK( call )                                                     \\\n    do                                                                         \\\n    {                                                                          \\\n        cudaError_t error = call;                                              \\\n        if( error != cudaSuccess )                                             \\\n        {                                                                      \\\n            std::stringstream ss;                                              \\\n            ss << \"CUDA call (\" << #call << \" ) failed with error: '\"          \\\n               << cudaGetErrorString( error )                                  \\\n               << \"' (\" __FILE__ << \":\" << __LINE__ << \")\\n\";                  \\\n        }                                                                      \\\n    } while( 0 )\n\n\n#define OPTIX_CHECK( call )                                                    \\\n    do                                                                         \\\n    {                                                                          \\\n        OptixResult res = call;                                                \\\n        if( res != OPTIX_SUCCESS )                                             \\\n        {                                                                      \\\n            std::stringstream ss;                                              \\\n            ss << \"Optix call '\" << #call << \"' failed: \" __FILE__ \":\"         \\\n               << __LINE__ << \")\\n\";                                           \\\n        }                                                                      \\\n    } while( 0 )\n\n#define OPTIX_CHECK_LOG( call )               
                                 \\\n    do                                                                         \\\n    {                                                                          \\\n        OptixResult res = call;                                                \\\n        const size_t sizeof_log_returned = sizeof_log;                         \\\n        sizeof_log = sizeof( log ); /* reset sizeof_log for future calls */    \\\n        if( res != OPTIX_SUCCESS )                                             \\\n        {                                                                      \\\n            std::stringstream ss;                                              \\\n            ss << \"Optix call '\" << #call << \"' failed: \" __FILE__ \":\"         \\\n               << __LINE__ << \")\\nLog:\\n\" << log                               \\\n               << ( sizeof_log_returned > sizeof( log ) ? \"<TRUNCATED>\" : \"\" ) \\\n               << \"\\n\";                                                        \\\n        }                                                                      \\\n    } while( 0 )\n\n#define NVRTC_CHECK_ERROR( func )                                                                                           \\\n    do                                                                                                                      \\\n    {                                                                                                                       \\\n        nvrtcResult code = func;                                                                                            \\\n        if( code != NVRTC_SUCCESS )                                                                                         \\\n            throw std::runtime_error( \"ERROR: \" __FILE__ \"(): \" + std::string( nvrtcGetErrorString( code ) ) );             \\\n    } while( 0 )\n"
  },
  {
    "path": "render/optixutils/c_src/denoising.cu",
    "content": "// Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain all intellectual property\n// and proprietary rights in and to this software, related documentation\n// and any modifications thereto. Any use, reproduction, disclosure or\n// distribution of this software and related documentation without an express\n// license agreement from NVIDIA CORPORATION is strictly prohibited.\n\n#include \"common.h\"\n#include \"denoising.h\"\n\n#define FLT_EPS 0.0001f\n\n__global__ void bilateral_denoiser_fwd_kernel(BilateralDenoiserParams params)\n{\n    uint3 idx = make_uint3(blockIdx.x * blockDim.x + threadIdx.x, blockIdx.y * blockDim.y + threadIdx.y, blockIdx.z * blockDim.z + threadIdx.z);\n\n    if (idx.z >= params.col.size(0) || idx.y >= params.col.size(1) || idx.x >= params.col.size(2))\n        return;\n\n    // Fetch central tap\n    float3 c_nrm = fetch3(params.nrm, idx.z, idx.y, idx.x);\n    float2 c_zdz = fetch2(params.zdz, idx.z, idx.y, idx.x);\n\n    float variance = params.sigma * params.sigma;\n    int filter_rad = 2 * ceil(params.sigma * 2.5) + 1;\n\n    float accum_w = 0.0f;\n    float3 accum_col = make_float3(0.0f);\n    for (int32_t fy = -filter_rad; fy <= filter_rad; ++fy)\n    {\n        for (int32_t fx = -filter_rad; fx <= filter_rad; ++fx)\n        {\n            // Compute tap coordinates, used for input activations and bilateral guides\n            int32_t y = idx.y + fy;\n            int32_t x = idx.x + fx;\n\n            if (y < 0 || x < 0 || y >= params.col.size(1) || x >= params.col.size(2))\n                continue;\n\n            // Fetch current tap\n            float3 t_col = fetch3(params.col, idx.z, y, x);\n            float3 t_nrm = fetch3(params.nrm, idx.z, y, x);\n            float2 t_zdz = fetch2(params.zdz, idx.z, y, x);\n\n            /////////////////////////////////////////////////////////\n            // Compute bilateral weight\n            /////////////////////////////////////////////////////////\n\n            // Distance\n            float dist_sqr = fx * fx + fy * fy;\n            float dist = sqrtf(dist_sqr);\n            float w_xy = expf(-dist_sqr / (2.0f * variance));\n\n            // Normal\n            float w_normal = powf(min(max(dot(t_nrm, c_nrm), FLT_EPS), 1.0f), 128.0f);\n\n            // Depth\n            float w_depth = expf(-(abs(t_zdz.x - c_zdz.x) / max(c_zdz.y * dist, FLT_EPS)));\n\n            float w = w_xy * w_normal * w_depth;\n\n            accum_col = accum_col + t_col * w;\n            accum_w += w;\n        }\n    }\n\n    params.out[idx.z][idx.y][idx.x][0] = accum_col.x;\n    params.out[idx.z][idx.y][idx.x][1] = accum_col.y;\n    params.out[idx.z][idx.y][idx.x][2] = accum_col.z;\n    params.out[idx.z][idx.y][idx.x][3] = max(accum_w, 0.0001f);\n}\n\n__global__ void bilateral_denoiser_bwd_kernel(BilateralDenoiserParams params)\n{\n    uint3 idx = make_uint3(blockIdx.x * blockDim.x + threadIdx.x, blockIdx.y * blockDim.y + threadIdx.y, blockIdx.z * blockDim.z + threadIdx.z);\n\n    if (idx.z >= params.col.size(0) || idx.y >= params.col.size(1) || idx.x >= params.col.size(2))\n        return;\n\n    // Fetch central tap\n    float3 c_nrm = fetch3(params.nrm, idx.z, idx.y, idx.x);\n    float2 c_zdz = fetch2(params.zdz, idx.z, idx.y, idx.x);\n\n    float variance = params.sigma * params.sigma;\n    int filter_rad = 2 * ceil(params.sigma * 2.5) + 1;\n\n    float3 accum_grad = make_float3(0.0f);\n    for (int32_t fy = -filter_rad; fy <= filter_rad; ++fy)\n    
{\n        for (int32_t fx = -filter_rad; fx <= filter_rad; ++fx)\n        {\n            // Compute tap coordinates, used for input activations and bilateral guides\n            int32_t y = idx.y + fy;\n            int32_t x = idx.x + fx;\n\n            if (y < 0 || x < 0 || y >= params.col.size(1) || x >= params.col.size(2))\n                continue;\n\n            // Fetch current tap\n            float3 t_col = fetch3(params.col, idx.z, y, x);\n            float3 t_nrm = fetch3(params.nrm, idx.z, y, x);\n            float2 t_zdz = fetch2(params.zdz, idx.z, y, x);\n\n            /////////////////////////////////////////////////////////\n            // Compute bilateral weight\n            /////////////////////////////////////////////////////////\n\n            // Distance, transposing fx & fy doesn't affect distance\n            float dist_sqr = fx * fx + fy * fy;\n            float dist = sqrtf(dist_sqr);\n            float w_xy = expf(-dist_sqr / (2.0f * variance));\n\n            // Normal, transpose c_ and t_ (it's symmetric so doesn't matter)\n            float w_normal = powf(min(max(dot(t_nrm, c_nrm), FLT_EPS), 1.0f), 128.0f);\n\n            // Depth, transpose c_ and t_ (matters for the denominator)\n            float w_depth = expf(-(abs(t_zdz.x - c_zdz.x) / max(t_zdz.y * dist, FLT_EPS)));\n\n            float w = w_xy * w_normal * w_depth;\n\n            float3 t_col_grad = w * fetch3(params.out_grad, idx.z, y, x);\n            accum_grad += t_col_grad;\n        }\n    }\n\n    params.col_grad[idx.z][idx.y][idx.x][0] = accum_grad.x;\n    params.col_grad[idx.z][idx.y][idx.x][1] = accum_grad.y;\n    params.col_grad[idx.z][idx.y][idx.x][2] = accum_grad.z;\n}\n"
  },
  {
    "path": "render/optixutils/c_src/denoising.h",
    "content": "// Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain all intellectual property\n// and proprietary rights in and to this software, related documentation\n// and any modifications thereto. Any use, reproduction, disclosure or\n// distribution of this software and related documentation without an express\n// license agreement from NVIDIA CORPORATION is strictly prohibited.\n\n#pragma once\n#include \"accessor.h\"\n\nstruct BilateralDenoiserParams\n{\n    PackedTensorAccessor32<float, 4> col;\n    PackedTensorAccessor32<float, 4> col_grad;  \n    PackedTensorAccessor32<float, 4> nrm;\n    PackedTensorAccessor32<float, 4> zdz;\n    PackedTensorAccessor32<float, 4> out;\n    PackedTensorAccessor32<float, 4> out_grad;\n    float sigma;\n};\n"
  },
  {
    "path": "render/optixutils/c_src/envsampling/kernel.cu",
    "content": "// Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain all intellectual property\n// and proprietary rights in and to this software, related documentation\n// and any modifications thereto. Any use, reproduction, disclosure or\n// distribution of this software and related documentation without an express\n// license agreement from NVIDIA CORPORATION is strictly prohibited.\n\n#define OPTIXU_MATH_DEFINE_IN_NAMESPACE\n\n#include <optix.h>\n#include <math_constants.h>\n\n#include \"params.h\"\n#include \"../common.h\"\n\n#define MIN_ROUGHNESS 0.08f\n\nextern \"C\" {\n__constant__ EnvSamplingParams params;\n}\n\n//==============================================================================\n// Math / utility functions\n//==============================================================================\n\n#include \"../bsdf.h\"\n\n// from https://www.reedbeta.com/blog/hash-functions-for-gpu-rendering/\n__device__ unsigned int rand_pcg(unsigned int &rng_state)\n{\n    unsigned int word = ((rng_state >> ((rng_state >> 28u) + 4u)) ^ rng_state) * 277803737u;\n    rng_state = rng_state * 747796405u + 2891336453u;\n    return (word >> 22u) ^ word;\n}\n\n__device__ unsigned int hash_pcg(unsigned int global_seed, unsigned int sample_seed)\n{\n    return rand_pcg(global_seed) ^ rand_pcg(sample_seed);\n}\n\n__device__ float uniform_pcg(unsigned int &rng_state)\n{\n    return (float)(rand_pcg(rng_state) & 0xFFFFFF) / (float)0x1000000;\n}\n\n__device__ float3 tolocal(const float3& a, const float3& u, const float3& v, const float3& w) \n{\n    return make_float3(dot(a, u), dot(a, v), dot(a, w));\n}\n\n__device__ float3 toworld(const float3& a, const float3& u, const float3& v, const float3& w) \n{\n    return u * a.x + v * a.y + w * a.z;\n}\n\n__device__ float3 cosine_sample(float3 N, float u, float v, float& pdf)\n{   \n    // construct local frame\n    N = safe_normalize(N);\n    float3 dx, dy;\n    branchlessONB(N, dx, dy); \n\n    // cosine sampling in local frame\n    float phi = 2.0 * CUDART_PI * u;\n    float costheta = sqrt(v);\n    float sintheta = sqrt(1.0 - v);\n\n    // Cartesian vector in local space\n    float x = cos(phi)*sintheta;\n    float y = sin(phi)*sintheta;\n    float z = costheta;\n\n    pdf = max(0.000001f, costheta / CUDART_PI);\n\n    // Local to world\n    float3 vec = dx*x + dy*y + N*z;\n    return safe_normalize(vec);\n}\n\n__device__ float albedo(const float3& baseColor, const float eta, const float3& wo, const float3& N)\n{\n    // Construct tangent frame\n    float3 W = safe_normalize(N);\n    float3 U,V;\n    branchlessONB(W, U, V);\n    float3 wo_l = safe_normalize(tolocal(wo, U, V, W));\n\n    const float cosNO = wo_l.z;\n    if (!(cosNO > 0))\n        return 0.0f;\n\n    return luminance(fwdFresnelSchlick(baseColor, make_float3(1.f, 1.f, 1.f), cosNO));\n}\n\n//==============================================================================\n// Shadow ray test. Note: This code ignores the shadow gradient boundary term.\n// We saw no benefit to boundary term gradients in our experiments. 
\n//==============================================================================\n\n__device__ float shadow_test(uint3 idx, float3 ray_origin, float3 ray_dir, float vis_grad)\n{\n    unsigned int isVisible = 0;\n    optixTrace(\n        params.handle,\n        ray_origin,\n        ray_dir,\n        0.0f,                                       // Min intersection distance\n        1e16,                                       // Max intersection distance\n        0.0f,                                       // rayTime -- used for motion blur\n        OptixVisibilityMask(0xFF),\n        OPTIX_RAY_FLAG_DISABLE_ANYHIT | OPTIX_RAY_FLAG_DISABLE_CLOSESTHIT | OPTIX_RAY_FLAG_TERMINATE_ON_FIRST_HIT,\n        0,                                          // SBT offset\n        0,                                          // SBT stride\n        0,                                          // missSBTIndex\n        isVisible);\n    return isVisible ? 1.0f : 0.0f;\n}\n\n//==============================================================================\n// Light probe functions\n//==============================================================================\n\n__device__ float2 _dir_to_tc(float3 dir)\n{\n    float u = atan2f(dir.x, -dir.z) / (2.0f * CUDART_PI) + 0.5f;\n    float v = acosf(clamp(dir.y, -1.0f, 1.0f)) / CUDART_PI;\n    return make_float2(u, v);\n}\n\n__device__ float3 _tc_to_dir(float2 uv)\n{\n    float sinphi, cosphi;\n    sincos((uv.x * 2.0f - 1.0f) * CUDART_PI, &sinphi, &cosphi);\n    float sintheta, costheta;\n    sincos(uv.y * CUDART_PI, &sintheta, &costheta);\n    return make_float3(sintheta*sinphi, costheta, -sintheta*cosphi);\n}\n\ntemplate<class T> __device__ float sample_cdf(const T &cdf, float x, unsigned int &idx, float &pdf)\n{\n    x = min(x, 0.99999994f);\n\n    // Binary search to find next index above\n    unsigned int _min = 0;\n    unsigned int _max = cdf.size(0) - 1;\n    unsigned int m = int(ceil(log2((float)_max))) + 1;\n    for (int i=0; i<m; ++i)\n    {\n        unsigned int mid = (_min + _max) / 2;\n        _min = x >= cdf[mid] ? mid :_min;\n        _max = x < cdf[mid] ? 
mid : _max;\n    }\n    idx = _max;\n\n    float sample;\n    if (idx == 0) {\n        pdf = cdf[0];\n        sample = x;\n    }\n    else {\n        float data0 = cdf[idx];\n        float data1 = cdf[idx-1];\n        pdf = data0 - data1;\n        sample = (x - data1);\n    }\n    // keep result in [0,1)\n    return min(sample / pdf, 0.99999994f);\n}\n\n__device__ float lightPDF(const float3& dir)\n{\n    // Sample light\n    float2 coord = _dir_to_tc(dir);\n\n    // retrieve nearest neighbor\n    int x = clamp((int)(coord.x * params.pdf.size(1)), 0, params.pdf.size(1) - 1);\n    int y = clamp((int)(coord.y * params.pdf.size(0)), 0, params.pdf.size(0) - 1);\n\n    float pdf_weight = params.cols.size(0) * params.cols.size(1) / (2.0f * CUDART_PI * CUDART_PI * max(sinf(coord.y * CUDART_PI), 0.0001f));\n    return params.pdf[y][x] * pdf_weight;\n}\n\n__device__ float3 lightSample(float u, float v, float& pdf)\n{\n    float row_pdf, col_pdf;\n    unsigned int x, y;\n    float ry = sample_cdf(params.rows, v, y, row_pdf);\n    float rx = sample_cdf(params.cols[y], u, x, col_pdf);\n    float3 rnd_dir = _tc_to_dir(make_float2((x+rx)/params.cols.size(1), (y+ry)/params.cols.size(0)));\n    pdf = lightPDF(rnd_dir);\n    return rnd_dir;\n}\n\n__device__ float3 eval_light_fwd(float2 coord)\n{\n    coord = coord * make_float2(params.light.size(1), params.light.size(0)); \n    int x = clamp((int)coord.x, 0, params.light.size(1) - 1);\n    int y = clamp((int)coord.y, 0, params.light.size(0) - 1);\n    return fetch3(params.light, y, x);\n}\n\n__device__ void eval_light_bwd(float2 coord, float3 light_grad)\n{\n    coord = coord * make_float2(params.light.size(1), params.light.size(0)); \n    int x = clamp((int)coord.x, 0, params.light.size(1) - 1);\n    int y = clamp((int)coord.y, 0, params.light.size(0) - 1);\n    atomicAdd(&params.light_grad[y][x][0], light_grad.x);\n    atomicAdd(&params.light_grad[y][x][1], light_grad.y);\n    atomicAdd(&params.light_grad[y][x][2], light_grad.z);\n}\n\n//==============================================================================\n// BSDF evaluation & importance sampling\n//==============================================================================\n\n__device__ float evalNdfGGX(float alpha, float cosTheta)\n{\n    float a2 = alpha * alpha;\n    float d = ((cosTheta * a2 - cosTheta) * cosTheta + 1);\n    return a2 / (d * d * CUDART_PI);\n}\n\n__device__ float evalG1GGX(float alphaSqr, float cosTheta)\n{\n    if (cosTheta <= 0) return 0;\n    float cosThetaSqr = cosTheta * cosTheta;\n    float tanThetaSqr = max(1.0f - cosThetaSqr, 0.0f) / cosThetaSqr;\n    return 2 / (1 + sqrt(1 + alphaSqr * tanThetaSqr));\n}\n\n__device__ float evalPdfGGX_VNDF(float alpha, float3 wo, float3 h)\n{\n    float G1 = evalG1GGX(alpha * alpha, wo.z);\n    float D = evalNdfGGX(alpha, h.z);\n    return G1 * D * max(0.f, dot(wo, h)) / wo.z;\n}\n\n// Samples the GGX (Trowbridge-Reitz) using the distribution of visible normals (VNDF).\n// See http://jcgt.org/published/0007/04/01/paper.pdf\n__device__ float3 sampleGGX_VNDF(float alpha, float3 wo, float ux, float uy, float& pdf)\n{\n    // Transform the view vector to the hemisphere configuration.\n    float3 Vh = safe_normalize(make_float3(alpha * wo.x, alpha * wo.y, wo.z));\n\n    // Construct orthonormal basis (Vh,T1,T2).\n    float3 T1 = (Vh.z < 0.9999f) ? 
safe_normalize(cross(make_float3(0.f, 0.f, 1.f), Vh)) : make_float3(1.f, 0.f, 0.f);\n    float3 T2 = cross(Vh, T1);\n\n    // Parameterization of the projected area of the hemisphere.\n    float r = sqrtf(ux);\n    float phi = (2.f * M_PI) * uy;\n    float t1 = r * cos(phi);\n    float t2 = r * sin(phi);\n    float s = 0.5f * (1.f + Vh.z);\n    t2 = (1.f - s) * sqrtf(1.f - t1 * t1) + s * t2;\n\n    // Reproject onto hemisphere.\n    float3 Nh = T1 * t1 + T2* t2 + Vh * sqrtf(max(0.f, 1.f - t1 * t1 - t2 * t2));\n\n    // Transform the normal back to the ellipsoid configuration. This is our half vector.\n    float3 h = safe_normalize(make_float3(alpha * Nh.x, alpha * Nh.y, max(0.f, Nh.z)));\n\n    pdf = evalPdfGGX_VNDF(alpha, wo, h);\n    return h;\n}\n\n__device__ float3 ggx_sample(float3 N, float3 wo, float u, float v, float alpha, float& pdf)\n{\n    // Construct tangent frame\n    float3 W = safe_normalize(N);\n    float3 U,V;\n    branchlessONB(W, U, V);\n\n    float3 wo_l = safe_normalize(tolocal(wo, U, V, W));\n    const float cosNO = wo_l.z;\n    if (!(cosNO > 0)) {\n        pdf = 0.f;\n        return make_float3(0.f, 0.f, 0.f);\n    }\n\n    float3 h = sampleGGX_VNDF(alpha, wo_l, u, v, pdf);    // pdf = G1(wo) * D(h) * max(0,dot(wo,h)) / wo.z\n\n    // Reflect the outgoing direction to find the incident direction.\n    float woDotH = dot(wo_l, h);\n    float3 wi_l = h * woDotH * 2.0f - wo_l;\n    pdf /= (4.0f * woDotH); // Jacobian of the reflection operator.\n\n    float3 wi_o = toworld(wi_l, U, V, W);\n    return safe_normalize(wi_o);\n}\n\n__device__ float evalLambdaGGX(float alphaSqr, float cosTheta)\n{\n    if (cosTheta <= 0) return 0;\n    float cosThetaSqr = cosTheta * cosTheta;\n    float tanThetaSqr = max(1 - cosThetaSqr, 0.0f) / cosThetaSqr;\n    return 0.5 * (-1 + sqrt(1 + alphaSqr * tanThetaSqr));\n}\n\n__device__ float ggx_pdf(float3 N, const float3 wo, const float3 wi, float alpha)\n{\n    // Construct tangent frame\n    float3 W = safe_normalize(N);\n    float3 U,V;\n    branchlessONB(W, U, V);\n\n    // wo_l : V\n    // wi_l : L\n    float3 wo_l = tolocal(wo, U, V, W);\n    float3 wi_l = tolocal(wi, U, V, W);\n\n    float pdf = 0.0f;\n    if (wo_l.z > 0 && wi_l.z > 0) {\n        float3 m = safe_normalize(wi_l + wo_l);\n        const float woDotH = dot(m, wo_l);\n        const float D = evalNdfGGX(alpha, m.z);\n        float G1 = evalG1GGX(alpha * alpha, wo_l.z);\n        pdf = G1 * D * max(0.f, dot(wo_l, m)) / wo_l.z;\n        pdf /= (4 * woDotH);\n    }\n    return pdf; \n}\n\n__device__ void update_pdf(float* pdf, float opdf, float b)\n{\n    if (b > 0.000001f)\n    {\n        opdf *= b;\n        *pdf += opdf;\n    }\n}\n\n__device__ float3 bsdf_sample(float pDiffuse, float pSpecular, float3 N, float3 wo, float3 s, float alpha, float& pdf)\n{\n    float3 rnd = s;\n    pdf = 0.0f;\n    float3 wi_o; \n\n    if (rnd.z < pDiffuse) // Sample diffuse lobe\n    {\n        if (pDiffuse < 0.0001f)\n        { \n            pdf = 1.0f;\n            return N;\n        }\n\n        wi_o = cosine_sample(N, rnd.x, rnd.y, pdf);\n        pdf *= pDiffuse;\n\n        // we sampled the diffuse lobe, now figure out how much the other bsdf contribute to the chosen direction\n        if (pSpecular > 0)\n        {\n            float bsdf_pdf = ggx_pdf(N, wo, wi_o, alpha);\n            update_pdf(&pdf, bsdf_pdf, 1.0f - pDiffuse);\n        }\n    }\n    else // Sample specular lobe\n    {\n        wi_o = ggx_sample(N, wo, rnd.x, rnd.y, alpha, pdf);\n        pdf *= 1.f - pDiffuse;\n\n        
// we sampled PDF 1, now figure out how much the other bsdf contribute to the chosen direction\n        if (pDiffuse > 0)\n        {\n            float bsdf_pdf = max(dot(N, wi_o), 0.0) / CUDART_PI; // cosine sampling pdf\n            update_pdf(&pdf, bsdf_pdf, pDiffuse);\n        }\n    }\n\n    return wi_o;\n}\n\n__device__ float bsdf_pdf(float pDiffuse, float pSpecular, float3 N, const float3 wo, const float3 wi, float alpha)\n{\n    // Check that L and V are in the positive hemisphere.\n    // The G term on the correlated form is not robust for NdotL = NdotV = 0.0.\n    float NdotL = dot(N, wi);\n    float NdotV = dot(N, wo);\n    static const float kMinCosTheta = 1e-6f;\n    float pdf = 0.0f;\n    if (min(NdotV, NdotL) < kMinCosTheta)\n        return 1.0f;\n\n    if (pDiffuse > 0)\n    {\n        float bsdf_pdf = max(dot(N, wi), 0.0) / CUDART_PI; // cosine sampling pdf\n        update_pdf(&pdf, bsdf_pdf, pDiffuse);\n    }\n\n    if (pSpecular > 0)\n    {\n        float bsdf_pdf = ggx_pdf(N, wo, wi, alpha); // ggx sampling pdf\n        update_pdf(&pdf, bsdf_pdf, 1.0f - pDiffuse);\n    }\n    return pdf;\n}\n\n//==============================================================================\n// Optix kernels\n//==============================================================================\n\n__device__ void process_sample(uint3 idx, float3 ray_origin, float3 ray_dir, float3 gb_pos, float3 gb_normal, float3 gb_view_pos, \n    float3 gb_kd, float3 gb_ks, float pdfSum, float weight, float3 &diff, float3 &spec, float3 diff_grad, float3 spec_grad)\n{\n    float2 coord = _dir_to_tc(ray_dir);\n    float3 light_col = eval_light_fwd(coord);\n\n    float mis_weight = 1.0 / max(pdfSum, 0.0001f); // MIS balance heuristic\n    // float alpha = gb_ks.y * gb_ks.y;\n\n    float3 _diff = make_float3(0), _spec = make_float3(0);\n    if (params.BSDF == 1 || params.BSDF == 2)\n        _diff = make_float3(fwdLambert(gb_normal, ray_dir));\n    else\n        fwdPbrBSDF(gb_kd, gb_ks, gb_pos, gb_normal, gb_view_pos, ray_dir, 0.08f, _diff, _spec);\n\n    // Trace shadow ray for current sample\n    float V_grad = sum((diff_grad * _diff + spec_grad * _spec) * light_col * mis_weight * weight) * params.shadow_scale;\n    float V = shadow_test(idx, ray_origin, ray_dir, V_grad) * params.shadow_scale + (1 - params.shadow_scale);\n\n    if (params.backward)\n    {\n        float3 light_grad = (diff_grad * _diff + spec_grad * _spec) * V * mis_weight * weight;\n        eval_light_bwd(coord, light_grad);\n\n        float3 _diff_grad = diff_grad * light_col * V * mis_weight * weight;\n        float3 _spec_grad = spec_grad * light_col * V * mis_weight * weight;\n        float3 gb_kd_grad = make_float3(0), gb_ks_grad = make_float3(0), gb_pos_grad = make_float3(0), gb_normal_grad = make_float3(0), gb_view_pos_grad = make_float3(0), ray_dir_grad = make_float3(0);\n        if (params.BSDF == 1 || params.BSDF == 2) // params.BSDF : 0 : 'pbr', 1 : 'diffuse', 2 : 'white'\n        {\n            float3 wi_grad = make_float3(0);\n            float lambert = fwdLambert(gb_normal, ray_dir);\n            float lambert_grad = sum(_diff_grad);\n            bwdLambert(gb_normal, ray_dir, gb_normal_grad, wi_grad, lambert_grad);\n        }\n        else\n        {\n            bwdPbrBSDF( gb_kd, gb_ks, gb_pos, gb_normal, gb_view_pos, ray_dir, 0.08f,  \n                        gb_kd_grad, gb_ks_grad, gb_pos_grad, gb_normal_grad, gb_view_pos_grad, ray_dir_grad, _diff_grad, _spec_grad);\n        }\n        
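// Accumulate per-pixel gradients into the gradient tensors. Each launch index owns\n        // one pixel and process_sample() runs serially inside the sample loop, so the\n        // plain += accumulation below needs no atomics.\n        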
params.gb_pos_grad[idx.z][idx.y][idx.x][0] += gb_pos_grad.x;\n        params.gb_pos_grad[idx.z][idx.y][idx.x][1] += gb_pos_grad.y;\n        params.gb_pos_grad[idx.z][idx.y][idx.x][2] += gb_pos_grad.z;\n\n        params.gb_normal_grad[idx.z][idx.y][idx.x][0] += gb_normal_grad.x;\n        params.gb_normal_grad[idx.z][idx.y][idx.x][1] += gb_normal_grad.y;\n        params.gb_normal_grad[idx.z][idx.y][idx.x][2] += gb_normal_grad.z;\n\n        params.gb_kd_grad[idx.z][idx.y][idx.x][0] += gb_kd_grad.x;\n        params.gb_kd_grad[idx.z][idx.y][idx.x][1] += gb_kd_grad.y;\n        params.gb_kd_grad[idx.z][idx.y][idx.x][2] += gb_kd_grad.z;\n\n        params.gb_ks_grad[idx.z][idx.y][idx.x][0] += gb_ks_grad.x;\n        params.gb_ks_grad[idx.z][idx.y][idx.x][1] += gb_ks_grad.y;\n        params.gb_ks_grad[idx.z][idx.y][idx.x][2] += gb_ks_grad.z;\n    }\n\n    diff = _diff * light_col * V * mis_weight * weight;\n    spec = _spec * light_col * V * mis_weight * weight;\n}\n\nextern \"C\" __global__ void __raygen__rg()\n{\n    // Lookup our location within the launch grid\n    const uint3 idx = optixGetLaunchIndex();\n    const uint3 dim = optixGetLaunchDimensions();\n\n    // Read per-pixel constant input tensors, ray_origin, g-buffer entries etc.\n    float  mask        = params.mask[idx.z][idx.y][idx.x];\n    float3 ray_origin  = fetch3(params.ro, idx.z, idx.y, idx.x);\n    float3 gb_pos      = fetch3(params.gb_pos, idx.z, idx.y, idx.x);\n    float3 gb_normal   = fetch3(params.gb_normal, idx.z, idx.y, idx.x);\n    float3 gb_view_pos = fetch3(params.gb_view_pos, idx.z, idx.y, idx.x);\n    float3 gb_kd       = fetch3(params.gb_kd, idx.z, idx.y, idx.x);\n    float3 gb_ks       = fetch3(params.gb_ks, idx.z, idx.y, idx.x);\n\n    if (mask <= 0) return; // Early exit masked pixels\n\n    float3 diff_grad, spec_grad;\n    if (params.backward)\n    {\n        diff_grad = fetch3(params.diff_grad, idx.z, idx.y, idx.x);\n        spec_grad = fetch3(params.spec_grad, idx.z, idx.y, idx.x);\n    }\n\n    float3 diffAccum = make_float3(0.0f, 0.0f, 0.0f);\n    float3 specAccum = make_float3(0.0f, 0.0f, 0.0f);\n\n    float strata_frac = 1.0f / params.n_samples_x;\n    float sample_frac = 1.0f / (params.n_samples_x * params.n_samples_x);\n    float alpha = gb_ks.y * gb_ks.y; // roughness squared\n    float3 wo = safe_normalize(gb_view_pos - gb_pos); // view direction\n\n    float metallic = gb_ks.z;\n    float3 baseColor = gb_kd;\n    float3 specColor = make_float3(0.04f, 0.04f, 0.04f) * (1.0f - metallic) + baseColor * metallic;\n    float diffuseWeight = (1.f - metallic) * luminance(baseColor);\n    float eta = 1.0f;\n    float specularWeight = albedo(specColor, eta, wo, gb_normal);\n    float pDiffuse = (diffuseWeight + specularWeight) > 0.f ? 
diffuseWeight / (diffuseWeight + specularWeight) : 1.f;\n    float pSpecular = 1.0f - pDiffuse;\n\n    unsigned int rng_state = hash_pcg(params.rnd_seed, (idx.z * dim.y + idx.y) * dim.x + idx.x);\n    unsigned int lightIdx = rand_pcg(rng_state) % params.perms.size(0), bsdfIdx = rand_pcg(rng_state) % params.perms.size(0);\n\n    for (int i = 0; i < params.n_samples_x * params.n_samples_x; ++i)\n    {\n        float3 ray_dir, diff, spec;\n        float sx, sy, sz = 0.f, pdf_light, pdf_bsdf;\n\n        // Light importance sampling\n        sx = ((float)(params.perms[lightIdx][i] % params.n_samples_x) + uniform_pcg(rng_state)) * strata_frac;\n        sy = ((float)(params.perms[lightIdx][i] / params.n_samples_x) + uniform_pcg(rng_state)) * strata_frac;\n        ray_dir = lightSample(sx, sy, pdf_light);\n        pdf_bsdf = bsdf_pdf(pDiffuse, pSpecular, gb_normal, wo, ray_dir, alpha);\n        process_sample(idx, ray_origin, ray_dir, gb_pos, gb_normal, gb_view_pos, gb_kd, gb_ks, pdf_light + pdf_bsdf, sample_frac, diff, spec, diff_grad, spec_grad);\n        diffAccum = diffAccum + diff;\n        specAccum = specAccum + spec;\n\n        // BSDF sampling (sample either the diffuse or specular lobe)\n        sx = ((float)(params.perms[bsdfIdx][i] % params.n_samples_x) + uniform_pcg(rng_state)) * strata_frac;\n        sy = ((float)(params.perms[bsdfIdx][i] / params.n_samples_x) + uniform_pcg(rng_state)) * strata_frac;\n        sz = uniform_pcg(rng_state);\n        ray_dir = bsdf_sample(pDiffuse, pSpecular, gb_normal, wo, make_float3(sx, sy, sz), alpha, pdf_bsdf);\n        pdf_light = lightPDF(ray_dir);\n        process_sample(idx, ray_origin, ray_dir, gb_pos, gb_normal, gb_view_pos, gb_kd, gb_ks, pdf_light + pdf_bsdf, sample_frac, diff, spec, diff_grad, spec_grad);\n        diffAccum = diffAccum + diff;\n        specAccum = specAccum + spec;\n    }\n\n    // Record results in our output raster\n    if (!params.backward)\n    {\n        params.diff[idx.z][idx.y][idx.x][0] = diffAccum.x;\n        params.diff[idx.z][idx.y][idx.x][1] = diffAccum.y;\n        params.diff[idx.z][idx.y][idx.x][2] = diffAccum.z;\n        params.spec[idx.z][idx.y][idx.x][0] = specAccum.x;\n        params.spec[idx.z][idx.y][idx.x][1] = specAccum.y;\n        params.spec[idx.z][idx.y][idx.x][2] = specAccum.z;\n    }\n}\n\nextern \"C\" __global__ void __miss__ms()\n{\n    optixSetPayload_0(1);\n}"
  },
  {
    "path": "render/optixutils/c_src/envsampling/params.h",
    "content": "// Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain all intellectual property\n// and proprietary rights in and to this software, related documentation\n// and any modifications thereto. Any use, reproduction, disclosure or\n// distribution of this software and related documentation without an express\n// license agreement from NVIDIA CORPORATION is strictly prohibited.\n\n#include \"../accessor.h\"\n\nstruct EnvSamplingParams\n{\n    // Ray data\n    PackedTensorAccessor32<float, 4>    ro;             // ray origin\n    \n    // GBuffer\n    PackedTensorAccessor32<float, 3>    mask;\n    PackedTensorAccessor32<float, 4>    gb_pos;\n    PackedTensorAccessor32<float, 4>    gb_pos_grad;\n    PackedTensorAccessor32<float, 4>    gb_normal;\n    PackedTensorAccessor32<float, 4>    gb_normal_grad;\n    PackedTensorAccessor32<float, 4>    gb_view_pos;\n    PackedTensorAccessor32<float, 4>    gb_kd;\n    PackedTensorAccessor32<float, 4>    gb_kd_grad;\n    PackedTensorAccessor32<float, 4>    gb_ks;\n    PackedTensorAccessor32<float, 4>    gb_ks_grad;\n    \n    // Light\n    PackedTensorAccessor32<float, 3>    light;\n    PackedTensorAccessor32<float, 3>    light_grad;\n    PackedTensorAccessor32<float, 2>    pdf;        // light pdf\n    PackedTensorAccessor32<float, 1>    rows;       // light sampling cdf\n    PackedTensorAccessor32<float, 2>    cols;       // light sampling cdf\n\n    // Output\n    PackedTensorAccessor32<float, 4>    diff;\n    PackedTensorAccessor32<float, 4>    diff_grad;\n    PackedTensorAccessor32<float, 4>    spec;\n    PackedTensorAccessor32<float, 4>    spec_grad;\n\n    // Table with random permutations for stratified sampling\n    PackedTensorAccessor32<int, 2>      perms;\n\n    OptixTraversableHandle              handle;\n    unsigned int                        BSDF;\n    unsigned int                        n_samples_x;\n    unsigned int                        rnd_seed;\n    unsigned int                        backward;\n    float                               shadow_scale;\n};"
  },
  {
    "path": "render/optixutils/c_src/math_utils.h",
    "content": "// Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain all intellectual property\n// and proprietary rights in and to this software, related documentation\n// and any modifications thereto. Any use, reproduction, disclosure or\n// distribution of this software and related documentation without an express\n// license agreement from NVIDIA CORPORATION is strictly prohibited.\n\n#pragma once \n\n#ifdef __CUDACC__\n\ntemplate<class T> static __device__ __inline__ T clamp(T x, T _min, T _max) { return min(_max, max(_min, x)); }\nstatic __device__ inline float3 make_float3(float a) { return make_float3(a, a, a); }\n\nstatic __device__ inline float2&   operator/=  (float2& a, const float2& b)       { a.x /= b.x; a.y /= b.y; return a; }\nstatic __device__ inline float2&   operator*=  (float2& a, const float2& b)       { a.x *= b.x; a.y *= b.y; return a; }\nstatic __device__ inline float2&   operator+=  (float2& a, const float2& b)       { a.x += b.x; a.y += b.y; return a; }\nstatic __device__ inline float2&   operator-=  (float2& a, const float2& b)       { a.x -= b.x; a.y -= b.y; return a; }\nstatic __device__ inline float2&   operator/=  (float2& a, float b)               { a.x /= b; a.y /= b; return a; }\nstatic __device__ inline float2&   operator*=  (float2& a, float b)               { a.x *= b; a.y *= b; return a; }\nstatic __device__ inline float2&   operator+=  (float2& a, float b)               { a.x += b; a.y += b; return a; }\nstatic __device__ inline float2&   operator-=  (float2& a, float b)               { a.x -= b; a.y -= b; return a; }\nstatic __device__ inline float2    operator/   (const float2& a, const float2& b) { return make_float2(a.x / b.x, a.y / b.y); }\nstatic __device__ inline float2    operator*   (const float2& a, const float2& b) { return make_float2(a.x * b.x, a.y * b.y); }\nstatic __device__ inline float2    operator+   (const float2& a, const float2& b) { return make_float2(a.x + b.x, a.y + b.y); }\nstatic __device__ inline float2    operator-   (const float2& a, const float2& b) { return make_float2(a.x - b.x, a.y - b.y); }\nstatic __device__ inline float2    operator/   (const float2& a, float b)         { return make_float2(a.x / b, a.y / b); }\nstatic __device__ inline float2    operator*   (const float2& a, float b)         { return make_float2(a.x * b, a.y * b); }\nstatic __device__ inline float2    operator+   (const float2& a, float b)         { return make_float2(a.x + b, a.y + b); }\nstatic __device__ inline float2    operator-   (const float2& a, float b)         { return make_float2(a.x - b, a.y - b); }\nstatic __device__ inline float2    operator/   (float a, const float2& b)         { return make_float2(a / b.x, a / b.y); }\nstatic __device__ inline float2    operator*   (float a, const float2& b)         { return make_float2(a * b.x, a * b.y); }\nstatic __device__ inline float2    operator+   (float a, const float2& b)         { return make_float2(a + b.x, a + b.y); }\nstatic __device__ inline float2    operator-   (float a, const float2& b)         { return make_float2(a - b.x, a - b.y); }\nstatic __device__ inline float2    operator-   (const float2& a)                  { return make_float2(-a.x, -a.y); }\nstatic __device__ inline float3&   operator/=  (float3& a, const float3& b)       { a.x /= b.x; a.y /= b.y; a.z /= b.z; return a; }\nstatic __device__ inline float3&   operator*=  (float3& a, const float3& b)       { a.x *= b.x; a.y *= b.y; a.z *= b.z; return a; 
}\nstatic __device__ inline float3&   operator+=  (float3& a, const float3& b)       { a.x += b.x; a.y += b.y; a.z += b.z; return a; }\nstatic __device__ inline float3&   operator-=  (float3& a, const float3& b)       { a.x -= b.x; a.y -= b.y; a.z -= b.z; return a; }\nstatic __device__ inline float3&   operator/=  (float3& a, float b)               { a.x /= b; a.y /= b; a.z /= b; return a; }\nstatic __device__ inline float3&   operator*=  (float3& a, float b)               { a.x *= b; a.y *= b; a.z *= b; return a; }\nstatic __device__ inline float3&   operator+=  (float3& a, float b)               { a.x += b; a.y += b; a.z += b; return a; }\nstatic __device__ inline float3&   operator-=  (float3& a, float b)               { a.x -= b; a.y -= b; a.z -= b; return a; }\nstatic __device__ inline float3    operator/   (const float3& a, const float3& b) { return make_float3(a.x / b.x, a.y / b.y, a.z / b.z); }\nstatic __device__ inline float3    operator*   (const float3& a, const float3& b) { return make_float3(a.x * b.x, a.y * b.y, a.z * b.z); }\nstatic __device__ inline float3    operator+   (const float3& a, const float3& b) { return make_float3(a.x + b.x, a.y + b.y, a.z + b.z); }\nstatic __device__ inline float3    operator-   (const float3& a, const float3& b) { return make_float3(a.x - b.x, a.y - b.y, a.z - b.z); }\nstatic __device__ inline float3    operator/   (const float3& a, float b)         { return make_float3(a.x / b, a.y / b, a.z / b); }\nstatic __device__ inline float3    operator*   (const float3& a, float b)         { return make_float3(a.x * b, a.y * b, a.z * b); }\nstatic __device__ inline float3    operator+   (const float3& a, float b)         { return make_float3(a.x + b, a.y + b, a.z + b); }\nstatic __device__ inline float3    operator-   (const float3& a, float b)         { return make_float3(a.x - b, a.y - b, a.z - b); }\nstatic __device__ inline float3    operator/   (float a, const float3& b)         { return make_float3(a / b.x, a / b.y, a / b.z); }\nstatic __device__ inline float3    operator*   (float a, const float3& b)         { return make_float3(a * b.x, a * b.y, a * b.z); }\nstatic __device__ inline float3    operator+   (float a, const float3& b)         { return make_float3(a + b.x, a + b.y, a + b.z); }\nstatic __device__ inline float3    operator-   (float a, const float3& b)         { return make_float3(a - b.x, a - b.y, a - b.z); }\nstatic __device__ inline float3    operator-   (const float3& a)                  { return make_float3(-a.x, -a.y, -a.z); }\nstatic __device__ inline float4&   operator/=  (float4& a, const float4& b)       { a.x /= b.x; a.y /= b.y; a.z /= b.z; a.w /= b.w; return a; }\nstatic __device__ inline float4&   operator*=  (float4& a, const float4& b)       { a.x *= b.x; a.y *= b.y; a.z *= b.z; a.w *= b.w; return a; }\nstatic __device__ inline float4&   operator+=  (float4& a, const float4& b)       { a.x += b.x; a.y += b.y; a.z += b.z; a.w += b.w; return a; }\nstatic __device__ inline float4&   operator-=  (float4& a, const float4& b)       { a.x -= b.x; a.y -= b.y; a.z -= b.z; a.w -= b.w; return a; }\nstatic __device__ inline float4&   operator/=  (float4& a, float b)               { a.x /= b; a.y /= b; a.z /= b; a.w /= b; return a; }\nstatic __device__ inline float4&   operator*=  (float4& a, float b)               { a.x *= b; a.y *= b; a.z *= b; a.w *= b; return a; }\nstatic __device__ inline float4&   operator+=  (float4& a, float b)               { a.x += b; a.y += b; a.z += b; a.w += b; return a; }\nstatic __device__ inline 
float4&   operator-=  (float4& a, float b)               { a.x -= b; a.y -= b; a.z -= b; a.w -= b; return a; }\nstatic __device__ inline float4    operator/   (const float4& a, const float4& b) { return make_float4(a.x / b.x, a.y / b.y, a.z / b.z, a.w / b.w); }\nstatic __device__ inline float4    operator*   (const float4& a, const float4& b) { return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w); }\nstatic __device__ inline float4    operator+   (const float4& a, const float4& b) { return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); }\nstatic __device__ inline float4    operator-   (const float4& a, const float4& b) { return make_float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w); }\nstatic __device__ inline float4    operator/   (const float4& a, float b)         { return make_float4(a.x / b, a.y / b, a.z / b, a.w / b); }\nstatic __device__ inline float4    operator*   (const float4& a, float b)         { return make_float4(a.x * b, a.y * b, a.z * b, a.w * b); }\nstatic __device__ inline float4    operator+   (const float4& a, float b)         { return make_float4(a.x + b, a.y + b, a.z + b, a.w + b); }\nstatic __device__ inline float4    operator-   (const float4& a, float b)         { return make_float4(a.x - b, a.y - b, a.z - b, a.w - b); }\nstatic __device__ inline float4    operator/   (float a, const float4& b)         { return make_float4(a / b.x, a / b.y, a / b.z, a / b.w); }\nstatic __device__ inline float4    operator*   (float a, const float4& b)         { return make_float4(a * b.x, a * b.y, a * b.z, a * b.w); }\nstatic __device__ inline float4    operator+   (float a, const float4& b)         { return make_float4(a + b.x, a + b.y, a + b.z, a + b.w); }\nstatic __device__ inline float4    operator-   (float a, const float4& b)         { return make_float4(a - b.x, a - b.y, a - b.z, a - b.w); }\nstatic __device__ inline float4    operator-   (const float4& a)                  { return make_float4(-a.x, -a.y, -a.z, -a.w); }\n\nstatic __device__ inline float sum(float3 a)\n{\n    return a.x + a.y + a.z;\n}\n\nstatic __device__ inline float dot(float3 a, float3 b) { return a.x * b.x + a.y * b.y + a.z * b.z; }\n\nstatic __device__ inline void bwd_dot(float3 a, float3 b, float3& d_a, float3& d_b, float d_out)\n{\n    d_a.x += d_out * b.x; d_a.y += d_out * b.y; d_a.z += d_out * b.z;\n    d_b.x += d_out * a.x; d_b.y += d_out * a.y; d_b.z += d_out * a.z;\n}\n\nstatic __device__ inline float luminance(const float3 rgb)\n{\n    return dot(rgb, make_float3(0.2126f, 0.7152f, 0.0722f));\n}\n\nstatic __device__ inline float3 cross(float3 a, float3 b)\n{\n    float3 out;\n    out.x = a.y * b.z - a.z * b.y;\n    out.y = a.z * b.x - a.x * b.z;\n    out.z = a.x * b.y - a.y * b.x;\n    return out;\n}\n\nstatic __device__ inline void bwd_cross(float3 a, float3 b, float3 &d_a, float3 &d_b, float3 d_out)\n{\n    d_a.x += d_out.z * b.y - d_out.y * b.z;\n    d_a.y += d_out.x * b.z - d_out.z * b.x;\n    d_a.z += d_out.y * b.x - d_out.x * b.y;\n\n    d_b.x += d_out.y * a.z - d_out.z * a.y;\n    d_b.y += d_out.z * a.x - d_out.x * a.z;\n    d_b.z += d_out.x * a.y - d_out.y * a.x;\n}\n\nstatic __device__ inline float3 reflect(float3 x, float3 n)\n{\n    return n * 2.0f * dot(n, x) - x;\n}\n\nstatic __device__ inline void bwd_reflect(float3 x, float3 n, float3& d_x, float3& d_n, float3 d_out)\n{\n    d_x.x += d_out.x * (2 * n.x * n.x - 1) + d_out.y * (2 * n.x * n.y) + d_out.z * (2 * n.x * n.z);\n    d_x.y += d_out.x * (2 * n.x * n.y) + d_out.y * (2 * n.y * n.y - 1) + d_out.z * (2 * n.y 
* n.z);\n    d_x.z += d_out.x * (2 * n.x * n.z) + d_out.y * (2 * n.y * n.z) + d_out.z * (2 * n.z * n.z - 1);\n\n    d_n.x += d_out.x * (2 * (2 * n.x * x.x + n.y * x.y + n.z * x.z)) + d_out.y * (2 * n.y * x.x) + d_out.z * (2 * n.z * x.x);\n    d_n.y += d_out.x * (2 * n.x * x.y) + d_out.y * (2 * (n.x * x.x + 2 * n.y * x.y + n.z * x.z)) + d_out.z * (2 * n.z * x.y);\n    d_n.z += d_out.x * (2 * n.x * x.z) + d_out.y * (2 * n.y * x.z) + d_out.z * (2 * (n.x * x.x + n.y * x.y + 2 * n.z * x.z));\n}\n\nstatic __device__ inline float3 safe_normalize(float3 v)\n{\n    float l = sqrtf(v.x * v.x + v.y * v.y + v.z * v.z);\n    return l > 0.0f ? (v / l) : make_float3(0.0f);\n}\n\nstatic __device__ inline void bwd_safe_normalize(const float3 v, float3& d_v, float3 d_out)\n{\n\n    float l = sqrtf(v.x * v.x + v.y * v.y + v.z * v.z);\n    if (l > 0.0f)\n    {\n        float fac = 1.0 / powf(v.x * v.x + v.y * v.y + v.z * v.z, 1.5f);\n        d_v.x += (d_out.x * (v.y * v.y + v.z * v.z) - d_out.y * (v.x * v.y) - d_out.z * (v.x * v.z)) * fac;\n        d_v.y += (d_out.y * (v.x * v.x + v.z * v.z) - d_out.x * (v.y * v.x) - d_out.z * (v.y * v.z)) * fac;\n        d_v.z += (d_out.z * (v.x * v.x + v.y * v.y) - d_out.x * (v.z * v.x) - d_out.y * (v.z * v.y)) * fac;\n    }\n}\n\n// Code from \n// https://graphics.pixar.com/library/OrthonormalB/paper.pdf\nstatic __device__ inline void branchlessONB(const float3 &n, float3 &b1, float3 &b2)\n{\n    float sign = copysignf(1.0f, n.z);\n    const float a = -1.0f / (sign + n.z);\n    const float b = n.x * n.y * a;\n    b1 = make_float3(1.0f + sign * n.x * n.x * a, sign * b, -sign * n.x);\n    b2 = make_float3(b, sign + n.y * n.y * a, -n.y);\n}\n\n#endif"
  },
  {
    "path": "render/optixutils/c_src/optix_wrapper.cpp",
    "content": "// Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain all intellectual property\n// and proprietary rights in and to this software, related documentation\n// and any modifications thereto. Any use, reproduction, disclosure or\n// distribution of this software and related documentation without an express\n// license agreement from NVIDIA CORPORATION is strictly prohibited.\n\n#ifdef _MSC_VER \n#pragma warning(push, 0)\n#include <torch/extension.h>\n#pragma warning(pop)\n#else\n#include <torch/extension.h>\n#endif\n\n#include <algorithm>\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/CUDAUtils.h>\n#include <optix.h>\n#include <optix_function_table_definition.h>\n#include <optix_stack_size.h>\n#include <optix_stubs.h>\n#include <nvrtc.h>\n#include <cuda_runtime.h>\n\n#include \"common.h\"\n#include \"optix_wrapper.h\"\n\n// NVRTC compiler options\n#define CUDA_NVRTC_OPTIONS  \\\n  \"-std=c++11\", \\\n  \"-arch\", \\\n  \"compute_70\", \\\n  \"-use_fast_math\", \\\n  \"-lineinfo\", \\\n  \"-default-device\", \\\n  \"-rdc\", \\\n  \"true\", \\\n  \"-D__x86_64\", \\\n  \"-D__OPTIX__\"\n\nstatic void context_log_cb( unsigned int level, const char* tag, const char* message, void* /*cbdata */)\n{\n    std::cerr << \"[\" << std::setw( 2 ) << level << \"][\" << std::setw( 12 ) << tag << \"]: \"\n              << message << \"\\n\";\n}\n\nstatic bool readSourceFile( std::string& str, const std::string& filename )\n{\n    // Try to open file\n    std::ifstream file( filename.c_str(), std::ios::binary );\n    if( file.good() )\n    {\n        // Found usable source file\n        std::vector<unsigned char> buffer = std::vector<unsigned char>( std::istreambuf_iterator<char>( file ), {} );\n        str.assign(buffer.begin(), buffer.end());\n        return true;\n    }\n    return false;\n}\n\nstatic void getCuStringFromFile( std::string& cu, const char* filename )\n{\n    // Try to get source code from file\n    if( readSourceFile( cu, filename ) )\n        return;\n\n    // Wasn't able to find or open the requested file\n    throw std::runtime_error( \"Couldn't open source file \" + std::string( filename ) );\n}\n\nstatic void getPtxFromCuString( std::string& ptx, const char* include_dir, const char* optix_include_dir, const char* cuda_include_dir, const char* cu_source, \n    const char* name, const char** log_string )\n{\n    // Create program\n    nvrtcProgram prog = 0;\n    NVRTC_CHECK_ERROR( nvrtcCreateProgram( &prog, cu_source, name, 0, NULL, NULL ) );\n\n    // Gather NVRTC options\n    std::vector<const char*> options;\n\n    std::string sample_dir;\n    sample_dir = std::string( \"-I\" ) + include_dir;\n    options.push_back( sample_dir.c_str() );\n\n    // Collect include dirs\n    std::vector<std::string> include_dirs;\n    \n    include_dirs.push_back( std::string( \"-I\" ) + optix_include_dir );\n    include_dirs.push_back( std::string( \"-I\" ) + cuda_include_dir );\n\n    for( const std::string& dir : include_dirs)\n        options.push_back( dir.c_str() );\n\n    // Collect NVRTC options\n    const char*  compiler_options[] = {CUDA_NVRTC_OPTIONS};\n    std::copy( std::begin( compiler_options ), std::end( compiler_options ), std::back_inserter( options ) );\n\n    // JIT compile CU to PTX\n    const nvrtcResult compileRes = nvrtcCompileProgram( prog, (int)options.size(), options.data() );\n\n    // Retrieve log output\n    std::string g_nvrtcLog;\n    size_t log_size = 0;\n    NVRTC_CHECK_ERROR( 
nvrtcGetProgramLogSize( prog, &log_size ) );\n    g_nvrtcLog.resize( log_size );\n    if( log_size > 1 )\n    {\n        NVRTC_CHECK_ERROR( nvrtcGetProgramLog( prog, &g_nvrtcLog[0] ) );\n        if( log_string )\n            *log_string = g_nvrtcLog.c_str();\n    }\n    if( compileRes != NVRTC_SUCCESS )\n        throw std::runtime_error( \"NVRTC Compilation failed.\\n\" + g_nvrtcLog );\n\n    // Retrieve PTX code\n    size_t ptx_size = 0;\n    NVRTC_CHECK_ERROR( nvrtcGetPTXSize( prog, &ptx_size ) );\n    ptx.resize( ptx_size );\n    NVRTC_CHECK_ERROR( nvrtcGetPTX( prog, &ptx[0] ) );\n\n    // Cleanup\n    NVRTC_CHECK_ERROR( nvrtcDestroyProgram( &prog ) );\n}\n\nconst char* getInputData( const char* filename, const char* include_dir, const char* optix_include_dir, \n                          const char* cuda_include_dir, const char* name, size_t& dataSize, const char** log)\n{\n    if( log )\n        *log = NULL;\n\n    std::string * ptx, cu;\n    ptx = new std::string();\n\n    getCuStringFromFile( cu, filename );\n    getPtxFromCuString( *ptx, include_dir, optix_include_dir, cuda_include_dir, cu.c_str(), name, log );\n\n    dataSize = ptx->size();\n    return ptx->c_str();\n}\n\n\nstruct SbtRecord\n{\n    __align__( OPTIX_SBT_RECORD_ALIGNMENT ) char header[OPTIX_SBT_RECORD_HEADER_SIZE];\n};\n\nvoid createPipeline(const OptixDeviceContext context, const std::string& path, const std::string& cuda_path, \n                const std::string& kernel_name, OptixModule* module, OptixPipeline* pipeline, OptixShaderBindingTable& sbt)\n{\n    char log[2048];\n    OptixPipelineCompileOptions pipeline_compile_options = {};\n    {\n        OptixModuleCompileOptions module_compile_options = {};\n        module_compile_options.maxRegisterCount     = OPTIX_COMPILE_DEFAULT_MAX_REGISTER_COUNT;\n        module_compile_options.optLevel             = OPTIX_COMPILE_OPTIMIZATION_DEFAULT;\n        module_compile_options.debugLevel           = OPTIX_COMPILE_DEBUG_LEVEL_DEFAULT;\n\n        pipeline_compile_options.usesMotionBlur        = false;\n        pipeline_compile_options.traversableGraphFlags = OPTIX_TRAVERSABLE_GRAPH_FLAG_ALLOW_SINGLE_GAS;\n        pipeline_compile_options.numPayloadValues      = 1;\n        pipeline_compile_options.numAttributeValues    = 2;\n        pipeline_compile_options.exceptionFlags = OPTIX_EXCEPTION_FLAG_NONE;\n        pipeline_compile_options.pipelineLaunchParamsVariableName = \"params\";\n        pipeline_compile_options.usesPrimitiveTypeFlags = OPTIX_PRIMITIVE_TYPE_FLAGS_TRIANGLE;\n\n        size_t      inputSize  = 0;\n        std::string shaderFile = path + \"/c_src/\" + kernel_name + \"/kernel.cu\";\n        std::string includeDir = path + \"/c_src/\" + kernel_name;\n        std::string optix_include_dir = path + \"/include\";\n        std::string cuda_include_dir = cuda_path + \"/include\";\n\n        const char* input = getInputData(shaderFile.c_str(), includeDir.c_str(), optix_include_dir.c_str(),\n                                         cuda_include_dir.c_str(), \"kernel\", inputSize, (const char**)&log);\n        size_t sizeof_log = sizeof( log );\n\n        OPTIX_CHECK_LOG( optixModuleCreateFromPTX(\n                    context,\n                    &module_compile_options,\n                    &pipeline_compile_options,\n                    input,\n                    inputSize,\n                    log,\n                    &sizeof_log,\n                    module) );\n    }\n\n    //\n    // Create program groups\n    //\n    OptixProgramGroup raygen_prog_group   = 
nullptr;\n    OptixProgramGroup miss_prog_group     = nullptr;\n    {\n        OptixProgramGroupOptions program_group_options   = {}; // Initialize to zeros\n\n        OptixProgramGroupDesc raygen_prog_group_desc    = {}; //\n        raygen_prog_group_desc.kind                     = OPTIX_PROGRAM_GROUP_KIND_RAYGEN;\n        raygen_prog_group_desc.raygen.module            = *module;\n        raygen_prog_group_desc.raygen.entryFunctionName = \"__raygen__rg\";\n        size_t sizeof_log = sizeof( log );\n        OPTIX_CHECK_LOG( optixProgramGroupCreate(\n                    context,\n                    &raygen_prog_group_desc,\n                    1,   // num program groups\n                    &program_group_options,\n                    log,\n                    &sizeof_log,\n                    &raygen_prog_group\n                    ) );\n\n        OptixProgramGroupDesc miss_prog_group_desc  = {};\n        miss_prog_group_desc.kind                   = OPTIX_PROGRAM_GROUP_KIND_MISS;\n        miss_prog_group_desc.miss.module            = *module;\n        miss_prog_group_desc.miss.entryFunctionName = \"__miss__ms\";\n        sizeof_log = sizeof( log );\n        OPTIX_CHECK_LOG( optixProgramGroupCreate(\n                    context,\n                    &miss_prog_group_desc,\n                    1,   // num program groups\n                    &program_group_options,\n                    log,\n                    &sizeof_log,\n                    &miss_prog_group\n                    ) );\n    }\n\n    //\n    // Link pipeline\n    //\n    {\n        const uint32_t    max_trace_depth  = 1;\n        OptixProgramGroup program_groups[] = { raygen_prog_group, miss_prog_group };\n\n        OptixPipelineLinkOptions pipeline_link_options = {};\n        pipeline_link_options.maxTraceDepth          = max_trace_depth;\n        pipeline_link_options.debugLevel             = OPTIX_COMPILE_DEBUG_LEVEL_DEFAULT;\n        size_t sizeof_log = sizeof( log );\n        OPTIX_CHECK_LOG( optixPipelineCreate(\n                    context,\n                    &pipeline_compile_options,\n                    &pipeline_link_options,\n                    program_groups,\n                    sizeof( program_groups ) / sizeof( program_groups[0] ),\n                    log,\n                    &sizeof_log,\n                    pipeline\n                    ) );\n\n        OptixStackSizes stack_sizes = {};\n        for( auto& prog_group : program_groups )\n        {\n            OPTIX_CHECK( optixUtilAccumulateStackSizes( prog_group, &stack_sizes ) );\n        }\n\n        uint32_t direct_callable_stack_size_from_traversal;\n        uint32_t direct_callable_stack_size_from_state;\n        uint32_t continuation_stack_size;\n        OPTIX_CHECK( optixUtilComputeStackSizes( &stack_sizes, max_trace_depth,\n                                                 0,  // maxCCDepth\n                                                 0,  // maxDCDEpth\n                                                 &direct_callable_stack_size_from_traversal,\n                                                 &direct_callable_stack_size_from_state, &continuation_stack_size ) );\n        OPTIX_CHECK( optixPipelineSetStackSize( *pipeline, direct_callable_stack_size_from_traversal,\n                                                direct_callable_stack_size_from_state, continuation_stack_size,\n                                                1  // maxTraversableDepth\n                                                ) );\n    }\n\n    //\n    // Set up 
shader binding table\n    //\n    {\n        CUdeviceptr  raygen_record;\n        const size_t raygen_record_size = sizeof( SbtRecord );\n        CUDA_CHECK( cudaMalloc( reinterpret_cast<void**>( &raygen_record ), raygen_record_size ) );\n        SbtRecord rg_sbt;\n        OPTIX_CHECK( optixSbtRecordPackHeader( raygen_prog_group, &rg_sbt ) );\n        CUDA_CHECK( cudaMemcpy(\n                    reinterpret_cast<void*>( raygen_record ),\n                    &rg_sbt,\n                    raygen_record_size,\n                    cudaMemcpyHostToDevice\n                    ) );\n\n        CUdeviceptr miss_record;\n        size_t      miss_record_size = sizeof( SbtRecord );\n        CUDA_CHECK( cudaMalloc( reinterpret_cast<void**>( &miss_record ), miss_record_size ) );\n        SbtRecord ms_sbt;\n        OPTIX_CHECK( optixSbtRecordPackHeader( miss_prog_group, &ms_sbt ) );\n        CUDA_CHECK( cudaMemcpy(\n                    reinterpret_cast<void*>( miss_record ),\n                    &ms_sbt,\n                    miss_record_size,\n                    cudaMemcpyHostToDevice\n                    ) );\n\n        sbt.raygenRecord                = raygen_record;\n        sbt.missRecordBase              = miss_record;\n        sbt.missRecordStrideInBytes     = sizeof( SbtRecord );\n        sbt.missRecordCount             = 1;\n    }    \n}\n\nOptiXStateWrapper::OptiXStateWrapper(const std::string& path, const std::string& cuda_path)\n{\n    pState = new OptiXState();\n    memset(pState, 0, sizeof(OptiXState));\n\n    // create OptiX context\n    pState->context = nullptr;\n    {\n        // Initialize the OptiX API, loading all API entry points\n        OPTIX_CHECK( optixInit() );\n\n        // Specify context options\n        OptixDeviceContextOptions options = {};\n        options.logCallbackFunction       = &context_log_cb;\n        options.logCallbackLevel          = 0;\n\n        // Associate a CUDA context (and therefore a specific GPU) with this\n        // device context\n        CUcontext cuCtx = 0;  // zero means take the current context\n        OPTIX_CHECK( optixDeviceContextCreate( cuCtx, &options, &pState->context ) );\n    }\n\n    // Create pipelines\n    pState->moduleEnvSampling = nullptr;\n    pState->pipelineEnvSampling = nullptr;\n    pState->sbtEnvSampling = {};\n    createPipeline(pState->context, path, cuda_path, \"envsampling\", &pState->moduleEnvSampling, &pState->pipelineEnvSampling, pState->sbtEnvSampling);\n\n    printf(\"End of OptiXStateWrapper \\n\");\n}\n\nOptiXStateWrapper::~OptiXStateWrapper(void)\n{\n    OPTIX_CHECK( optixPipelineDestroy( pState->pipelineEnvSampling ) );\n    CUDA_CHECK( cudaFree( reinterpret_cast<void*>( pState->sbtEnvSampling.raygenRecord       ) ) );\n    CUDA_CHECK( cudaFree( reinterpret_cast<void*>( pState->sbtEnvSampling.missRecordBase     ) ) );\n    OPTIX_CHECK( optixModuleDestroy( pState->moduleEnvSampling ) );\n\n    CUDA_CHECK( cudaFree( reinterpret_cast<void*>( pState->d_gas_output_buffer ) ) );\n    OPTIX_CHECK( optixDeviceContextDestroy( pState->context ) ); \n    delete pState;\n    printf(\"OptiXStateWrapper destructor \\n\");\n}\n"
  },
  {
    "path": "render/optixutils/c_src/optix_wrapper.h",
    "content": "// Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain all intellectual property\n// and proprietary rights in and to this software, related documentation\n// and any modifications thereto. Any use, reproduction, disclosure or\n// distribution of this software and related documentation without an express\n// license agreement from NVIDIA CORPORATION is strictly prohibited.\n\n#pragma once\n\n#include <optix.h>\n#include <string>\n\n//------------------------------------------------------------------------\n// Python OptiX state wrapper.\n\nstruct OptiXState\n{\n    OptixDeviceContext context;\n    OptixTraversableHandle gas_handle;\n    CUdeviceptr            d_gas_output_buffer;\n\n    // Differentiable env sampling\n    OptixPipeline pipelineEnvSampling;\n    OptixShaderBindingTable sbtEnvSampling;\n    OptixModule moduleEnvSampling;\n};\n\n\nclass OptiXStateWrapper\n{\npublic:\n    OptiXStateWrapper     (const std::string &path, const std::string &cuda_path);\n    ~OptiXStateWrapper    (void);\n    \n    OptiXState*           pState;\n};\n\n"
  },
  {
    "path": "render/optixutils/c_src/torch_bindings.cpp",
    "content": "// Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain all intellectual property\n// and proprietary rights in and to this software, related documentation\n// and any modifications thereto. Any use, reproduction, disclosure or\n// distribution of this software and related documentation without an express\n// license agreement from NVIDIA CORPORATION is strictly prohibited.\n\n#ifdef _MSC_VER \n#pragma warning(push, 0)\n#include <torch/extension.h>\n#pragma warning(pop)\n#else\n#include <torch/extension.h>\n#endif\n\n#include <algorithm>\n#include <string>\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/CUDAUtils.h>\n#include <optix_stubs.h>\n\n#include \"common.h\"\n#include \"optix_wrapper.h\"\n#include \"denoising.h\"\n#include \"envsampling/params.h\"\n\n//------------------------------------------------------------------------\n// CUDA kernels\n\nvoid bilateral_denoiser_fwd_kernel(BilateralDenoiserParams params);\nvoid bilateral_denoiser_bwd_kernel(BilateralDenoiserParams params);\n\n//------------------------------------------------------------------------\n// OptiX tracer\n\nvoid optix_build_bvh(OptiXStateWrapper& stateWrapper,torch::Tensor grid_verts, torch::Tensor grid_tris, unsigned int rebuild)\n{\n    //\n    // accel handling\n    //\n\n    // Clear BVH GPU memory\n    {\n        // Use default options for simplicity.  In a real use case we would want to\n        // enable compaction, etc\n        OptixAccelBuildOptions accel_options = {};\n        accel_options.buildFlags = OPTIX_BUILD_FLAG_ALLOW_COMPACTION | OPTIX_BUILD_FLAG_ALLOW_UPDATE | OPTIX_BUILD_FLAG_ALLOW_RANDOM_VERTEX_ACCESS;\n\n        if (rebuild > 0)\n        {\n            CUDA_CHECK( cudaFree( reinterpret_cast<void*>( stateWrapper.pState->d_gas_output_buffer ) ) ); \n            accel_options.operation  = OPTIX_BUILD_OPERATION_BUILD;\n        }\n        else \n        {\n            accel_options.operation = OPTIX_BUILD_OPERATION_UPDATE;\n        }\n        CUdeviceptr d_vertices = (CUdeviceptr)grid_verts.data_ptr<float>();\n        CUdeviceptr d_indices = (CUdeviceptr)grid_tris.data_ptr<int>();\n\n        // Our build input is a simple list of non-indexed triangle vertices\n        const uint32_t triangle_input_flags[1] = { OPTIX_GEOMETRY_FLAG_NONE };\n        OptixBuildInput triangle_input = {};\n        triangle_input.type                        = OPTIX_BUILD_INPUT_TYPE_TRIANGLES;\n        triangle_input.triangleArray.vertexFormat  = OPTIX_VERTEX_FORMAT_FLOAT3;\n        triangle_input.triangleArray.numVertices   = (uint32_t)grid_verts.size(0);\n        triangle_input.triangleArray.vertexBuffers = &d_vertices;\n        triangle_input.triangleArray.indexFormat   = OPTIX_INDICES_FORMAT_UNSIGNED_INT3;\n        triangle_input.triangleArray.numIndexTriplets = (uint32_t)grid_tris.size(0);\n        triangle_input.triangleArray.indexBuffer   = d_indices;\n        triangle_input.triangleArray.flags         = triangle_input_flags;\n        triangle_input.triangleArray.numSbtRecords = 1;\n\n        OptixAccelBufferSizes gas_buffer_sizes;\n        OPTIX_CHECK( optixAccelComputeMemoryUsage(\n                    stateWrapper.pState->context,\n                    &accel_options,\n                    &triangle_input,\n                    1, // Number of build inputs\n                    &gas_buffer_sizes\n                    ) );\n        CUdeviceptr d_temp_buffer_gas;\n        CUDA_CHECK( cudaMalloc(\n                    
reinterpret_cast<void**>( &d_temp_buffer_gas ),\n                    gas_buffer_sizes.tempSizeInBytes\n                    ) );\n\n        if (rebuild > 0)\n        {\n            CUDA_CHECK( cudaMalloc(\n                        reinterpret_cast<void**>( &stateWrapper.pState->d_gas_output_buffer ),\n                        gas_buffer_sizes.outputSizeInBytes\n                        ) );\n        }\n\n        OPTIX_CHECK( optixAccelBuild(\n                    stateWrapper.pState->context,\n                    0,                  // CUDA stream\n                    &accel_options,\n                    &triangle_input,\n                    1,                  // num build inputs\n                    d_temp_buffer_gas,\n                    gas_buffer_sizes.tempSizeInBytes,\n                    stateWrapper.pState->d_gas_output_buffer,\n                    gas_buffer_sizes.outputSizeInBytes,\n                    &stateWrapper.pState->gas_handle,\n                    nullptr,            // emitted property list\n                    0                   // num emitted properties\n                    ) );\n\n        // We can now free the scratch space buffer used during build and the vertex\n        // inputs, since they are not needed by our trivial shading method\n        CUDA_CHECK( cudaFree( reinterpret_cast<void*>( d_temp_buffer_gas ) ) );\n    }\n}\n\ntemplate<class T, int N, template <typename U> class PtrTraits = DefaultPtrTraits> PackedTensorAccessor32<T, N> packed_accessor32(torch::Tensor tensor)\n{\n    return PackedTensorAccessor32<T,N,PtrTraits>(static_cast<typename PtrTraits<T>::PtrType>(tensor.data_ptr<T>()), tensor.sizes().data(), tensor.strides().data());\n}\n\nstd::tuple<torch::Tensor, torch::Tensor> env_shade_fwd(\n    OptiXStateWrapper& stateWrapper, \n    torch::Tensor mask, \n    torch::Tensor ro, \n    torch::Tensor gb_pos, \n    torch::Tensor gb_normal, \n    torch::Tensor gb_view_pos, \n    torch::Tensor gb_kd, \n    torch::Tensor gb_ks, \n    torch::Tensor light, \n    torch::Tensor pdf, \n    torch::Tensor rows, \n    torch::Tensor cols,\n    torch::Tensor perms,\n    unsigned int BSDF,\n    unsigned int n_samples_x,\n    unsigned int rnd_seed,\n    float shadow_scale)\n{\n    //\n    // launch OptiX kernel\n    //\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    // Allocate output tensors.\n    torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);\n    torch::Tensor diff = torch::zeros({ ro.size(0), ro.size(1), ro.size(2), 3 }, opts) ;\n    torch::Tensor spec = torch::zeros({ ro.size(0), ro.size(1), ro.size(2), 3 }, opts) ;\n\n    EnvSamplingParams params;\n    params.handle       = stateWrapper.pState->gas_handle;\n    params.mask         = packed_accessor32<float, 3>(mask);\n    params.ro           = packed_accessor32<float, 4>(ro);\n    params.gb_pos       = packed_accessor32<float, 4>(gb_pos);\n    params.gb_normal    = packed_accessor32<float, 4>(gb_normal);\n    params.gb_view_pos  = packed_accessor32<float, 4>(gb_view_pos);\n    params.gb_kd        = packed_accessor32<float, 4>(gb_kd);\n    params.gb_ks        = packed_accessor32<float, 4>(gb_ks);\n    params.light        = packed_accessor32<float, 3>(light);\n    params.pdf          = packed_accessor32<float, 2>(pdf);\n    params.rows         = packed_accessor32<float, 1>(rows);\n    params.cols         = packed_accessor32<float, 2>(cols);\n    params.diff         = packed_accessor32<float, 4>(diff);\n    params.spec         = packed_accessor32<float, 
4>(spec);\n    params.perms        = packed_accessor32<int, 2>(perms);\n    params.BSDF         = BSDF;\n    params.n_samples_x  = n_samples_x;\n    params.rnd_seed     = rnd_seed;\n    params.backward     = 0;\n    params.shadow_scale = shadow_scale;\n\n    CUdeviceptr d_param;\n    CUDA_CHECK( cudaMalloc( reinterpret_cast<void**>( &d_param ), sizeof( EnvSamplingParams ) ) );\n    CUDA_CHECK( cudaMemcpy(\n                reinterpret_cast<void*>( d_param ),\n                &params, sizeof( params ),\n                cudaMemcpyHostToDevice\n                ) );\n\n    OPTIX_CHECK( optixLaunch( stateWrapper.pState->pipelineEnvSampling, stream, d_param, sizeof( EnvSamplingParams ), \n                              &stateWrapper.pState->sbtEnvSampling, ro.size(2), ro.size(1), ro.size(0) ) );\n\n    CUDA_CHECK( cudaStreamSynchronize( stream ) );\n\n    return std::tuple<torch::Tensor, torch::Tensor>(diff, spec);\n}\n\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> env_shade_bwd(\n    OptiXStateWrapper& stateWrapper, \n    torch::Tensor mask, \n    torch::Tensor ro, \n    torch::Tensor gb_pos, \n    torch::Tensor gb_normal, \n    torch::Tensor gb_view_pos, \n    torch::Tensor gb_kd, \n    torch::Tensor gb_ks, \n    torch::Tensor light,\n    torch::Tensor pdf,\n    torch::Tensor rows, \n    torch::Tensor cols, \n    torch::Tensor perms,\n    unsigned int BSDF,\n    unsigned int n_samples_x,\n    unsigned int rnd_seed,\n    float shadow_scale,\n    torch::Tensor diff_grad,\n    torch::Tensor spec_grad)\n{\n    //\n    // launch OptiX kernel\n    //\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    // Allocate output tensors.\n    torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);\n\n    EnvSamplingParams params;\n    params.handle       = stateWrapper.pState->gas_handle;\n    params.mask         = packed_accessor32<float, 3>(mask);\n    params.ro           = packed_accessor32<float, 4>(ro);\n    params.gb_pos       = packed_accessor32<float, 4>(gb_pos);\n    params.gb_normal    = packed_accessor32<float, 4>(gb_normal);\n    params.gb_view_pos  = packed_accessor32<float, 4>(gb_view_pos);\n    params.gb_kd        = packed_accessor32<float, 4>(gb_kd);\n    params.gb_ks        = packed_accessor32<float, 4>(gb_ks);\n    params.light        = packed_accessor32<float, 3>(light);\n    params.pdf          = packed_accessor32<float, 2>(pdf);\n    params.rows         = packed_accessor32<float, 1>(rows);\n    params.cols         = packed_accessor32<float, 2>(cols);\n    params.diff_grad    = packed_accessor32<float, 4>(diff_grad);\n    params.spec_grad    = packed_accessor32<float, 4>(spec_grad);\n    params.perms        = packed_accessor32<int, 2>(perms);\n    params.BSDF         = BSDF;\n    params.n_samples_x  = n_samples_x;\n    params.rnd_seed     = rnd_seed;\n    params.backward     = 1;\n    params.shadow_scale = shadow_scale;\n\n    // Create gradient tensor for pos\n    torch::Tensor gb_pos_grad = torch::zeros({ ro.size(0), ro.size(1), ro.size(2), gb_pos.size(3) }, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA));\n    params.gb_pos_grad = packed_accessor32<float, 4>(gb_pos_grad);\n\n    torch::Tensor gb_normal_grad = torch::zeros({ ro.size(0), ro.size(1), ro.size(2), gb_normal.size(3) }, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA));\n    params.gb_normal_grad = packed_accessor32<float, 4>(gb_normal_grad);\n\n    torch::Tensor gb_kd_grad = torch::zeros({ 
ro.size(0), ro.size(1), ro.size(2), gb_kd.size(3) }, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA));\n    params.gb_kd_grad = packed_accessor32<float, 4>(gb_kd_grad);\n\n    torch::Tensor gb_ks_grad = torch::zeros({ ro.size(0), ro.size(1), ro.size(2), gb_ks.size(3) }, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA));\n    params.gb_ks_grad = packed_accessor32<float, 4>(gb_ks_grad);\n\n    // Create gradient tensor for light\n    torch::Tensor light_grad = torch::zeros({ light.size(0), light.size(1), light.size(2) }, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA));\n    params.light_grad = packed_accessor32<float, 3>(light_grad);\n\n    CUdeviceptr d_param;\n    CUDA_CHECK( cudaMalloc( reinterpret_cast<void**>( &d_param ), sizeof( EnvSamplingParams ) ) );\n    CUDA_CHECK( cudaMemcpy(\n                reinterpret_cast<void*>( d_param ),\n                &params, sizeof( params ),\n                cudaMemcpyHostToDevice\n                ) );\n\n    OPTIX_CHECK( optixLaunch( stateWrapper.pState->pipelineEnvSampling, stream, d_param, sizeof( EnvSamplingParams ), \n                              &stateWrapper.pState->sbtEnvSampling, ro.size(2), ro.size(1), ro.size(0) ) );\n\n    CUDA_CHECK( cudaStreamSynchronize( stream ) );\n\n    return std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>(gb_pos_grad, gb_normal_grad, gb_kd_grad, gb_ks_grad, light_grad);\n}\n\ntorch::Tensor bilateral_denoiser_fwd(torch::Tensor col, torch::Tensor nrm, torch::Tensor zdz, float sigma)\n{\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);\n    torch::Tensor out = torch::zeros({ col.size(0), col.size(1), col.size(2), 4 }, opts);\n\n    dim3 blockSize(8, 8, 1);\n    dim3 gridSize((col.size(2) - 1) / blockSize.x + 1, (col.size(1) - 1) / blockSize.y + 1, (col.size(0) - 1) / blockSize.z + 1);\n\n    BilateralDenoiserParams params;\n    params.col = packed_accessor32<float, 4>(col);\n    params.nrm = packed_accessor32<float, 4>(nrm);\n    params.zdz = packed_accessor32<float, 4>(zdz);\n    params.out = packed_accessor32<float, 4>(out);\n    params.sigma = sigma;\n\n    void *args[] = {&params};\n    CUDA_CHECK(cudaLaunchKernel((const void *)bilateral_denoiser_fwd_kernel, gridSize, blockSize, args, 0, stream));\n\n    return out;\n}\n\ntorch::Tensor bilateral_denoiser_bwd(torch::Tensor col, torch::Tensor nrm, torch::Tensor zdz, float sigma, torch::Tensor out_grad)\n{\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);\n    torch::Tensor col_grad = torch::zeros({ col.size(0), col.size(1), col.size(2), col.size(3) }, opts);\n\n    dim3 blockSize(8, 8, 1);\n    dim3 gridSize((col.size(2) - 1) / blockSize.x + 1, (col.size(1) - 1) / blockSize.y + 1, (col.size(0) - 1) / blockSize.z + 1);\n\n    BilateralDenoiserParams params;\n    params.col = packed_accessor32<float, 4>(col);\n    params.nrm = packed_accessor32<float, 4>(nrm);\n    params.zdz = packed_accessor32<float, 4>(zdz);\n    params.out_grad = packed_accessor32<float, 4>(out_grad);\n    params.col_grad = packed_accessor32<float, 4>(col_grad);\n    params.sigma = sigma;\n\n    void *args[] = {&params};\n    CUDA_CHECK(cudaLaunchKernel((const void *)bilateral_denoiser_bwd_kernel, gridSize, blockSize, args, 0, stream));\n\n    return 
col_grad;\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n    pybind11::class_<OptiXStateWrapper>(m, \"OptiXStateWrapper\").def(pybind11::init<const std::string &, const std::string &>());\n    m.def(\"env_shade_fwd\", &env_shade_fwd, \"env_shade_fwd\");\n    m.def(\"env_shade_bwd\", &env_shade_bwd, \"env_shade_bwd\");\n    // m.def(\"env_shade_single_sided_fwd\", &env_shade_fwd, \"env_shade_single_sided_fwd\");\n    // m.def(\"env_shade_single_sided_bwd\", &env_shade_bwd, \"env_shade_single_sided_bwd\");\n    m.def(\"optix_build_bvh\", &optix_build_bvh, \"optix_build_bvh\");\n    m.def(\"bilateral_denoiser_fwd\", &bilateral_denoiser_fwd, \"bilateral_denoiser_fwd\");\n    m.def(\"bilateral_denoiser_bwd\", &bilateral_denoiser_bwd, \"bilateral_denoiser_bwd\");    \n}"
  },
  {
    "path": "render/optixutils/include/internal/optix_7_device_impl.h",
    "content": "/*\n* Copyright (c) 2021 NVIDIA Corporation.  All rights reserved.\n*\n* NVIDIA Corporation and its licensors retain all intellectual property and proprietary\n* rights in and to this software, related documentation and any modifications thereto.\n* Any use, reproduction, disclosure or distribution of this software and related\n* documentation without an express license agreement from NVIDIA Corporation is strictly\n* prohibited.\n*\n* TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, THIS SOFTWARE IS PROVIDED *AS IS*\n* AND NVIDIA AND ITS SUPPLIERS DISCLAIM ALL WARRANTIES, EITHER EXPRESS OR IMPLIED,\n* INCLUDING, BUT NOT LIMITED TO, IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A\n* PARTICULAR PURPOSE.  IN NO EVENT SHALL NVIDIA OR ITS SUPPLIERS BE LIABLE FOR ANY\n* SPECIAL, INCIDENTAL, INDIRECT, OR CONSEQUENTIAL DAMAGES WHATSOEVER (INCLUDING, WITHOUT\n* LIMITATION, DAMAGES FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF\n* BUSINESS INFORMATION, OR ANY OTHER PECUNIARY LOSS) ARISING OUT OF THE USE OF OR\n* INABILITY TO USE THIS SOFTWARE, EVEN IF NVIDIA HAS BEEN ADVISED OF THE POSSIBILITY OF\n* SUCH DAMAGES\n*/\n\n/**\n* @file   optix_7_device_impl.h\n* @author NVIDIA Corporation\n* @brief  OptiX public API\n*\n* OptiX public API Reference - Device side implementation\n*/\n\n#if !defined( __OPTIX_INCLUDE_INTERNAL_HEADERS__ )\n#error(\"optix_7_device_impl.h is an internal header file and must not be used directly.  Please use optix_device.h or optix.h instead.\")\n#endif\n\n#ifndef __optix_optix_7_device_impl_h__\n#define __optix_optix_7_device_impl_h__\n\n#include \"internal/optix_7_device_impl_exception.h\"\n#include \"internal/optix_7_device_impl_transformations.h\"\n\nstatic __forceinline__ __device__ void optixTrace( OptixTraversableHandle handle,\n                                                   float3                 rayOrigin,\n                                                   float3                 rayDirection,\n                                                   float                  tmin,\n                                                   float                  tmax,\n                                                   float                  rayTime,\n                                                   OptixVisibilityMask    visibilityMask,\n                                                   unsigned int           rayFlags,\n                                                   unsigned int           SBToffset,\n                                                   unsigned int           SBTstride,\n                                                   unsigned int           missSBTIndex )\n{\n    float        ox = rayOrigin.x, oy = rayOrigin.y, oz = rayOrigin.z;\n    float        dx = rayDirection.x, dy = rayDirection.y, dz = rayDirection.z;\n    unsigned int p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11, p12, p13, p14, p15, p16, p17, p18, p19, p20, p21,\n        p22, p23, p24, p25, p26, p27, p28, p29, p30, p31;\n    asm volatile(\n        \"call\"\n        \"(%0,%1,%2,%3,%4,%5,%6,%7,%8,%9,%10,%11,%12,%13,%14,%15,%16,%17,%18,%19,%20,%21,%22,%23,%24,%25,%26,%27,%28,%\"\n        \"29,%30,%31),\"\n        \"_optix_trace_typed_32,\"\n        \"(%32,%33,%34,%35,%36,%37,%38,%39,%40,%41,%42,%43,%44,%45,%46,%47,%48,%49,%50,%51,%52,%53,%54,%55,%56,%57,%58,%\"\n        \"59,%60,%61,%62,%63,%64,%65,%66,%67,%68,%69,%70,%71,%72,%73,%74,%75,%76,%77,%78,%79,%80);\"\n        : \"=r\"( p0 ), \"=r\"( p1 ), \"=r\"( p2 ), \"=r\"( p3 ), \"=r\"( p4 ), \"=r\"( p5 ), 
\"=r\"( p6 ), \"=r\"( p7 ), \"=r\"( p8 ),\n          \"=r\"( p9 ), \"=r\"( p10 ), \"=r\"( p11 ), \"=r\"( p12 ), \"=r\"( p13 ), \"=r\"( p14 ), \"=r\"( p15 ), \"=r\"( p16 ),\n          \"=r\"( p17 ), \"=r\"( p18 ), \"=r\"( p19 ), \"=r\"( p20 ), \"=r\"( p21 ), \"=r\"( p22 ), \"=r\"( p23 ), \"=r\"( p24 ),\n          \"=r\"( p25 ), \"=r\"( p26 ), \"=r\"( p27 ), \"=r\"( p28 ), \"=r\"( p29 ), \"=r\"( p30 ), \"=r\"( p31 )\n        : \"r\"( 0 ), \"l\"( handle ), \"f\"( ox ), \"f\"( oy ), \"f\"( oz ), \"f\"( dx ), \"f\"( dy ), \"f\"( dz ), \"f\"( tmin ),\n          \"f\"( tmax ), \"f\"( rayTime ), \"r\"( visibilityMask ), \"r\"( rayFlags ), \"r\"( SBToffset ), \"r\"( SBTstride ),\n          \"r\"( missSBTIndex ), \"r\"( 0 ), \"r\"( p0 ), \"r\"( p1 ), \"r\"( p2 ), \"r\"( p3 ), \"r\"( p4 ), \"r\"( p5 ), \"r\"( p6 ),\n          \"r\"( p7 ), \"r\"( p8 ), \"r\"( p9 ), \"r\"( p10 ), \"r\"( p11 ), \"r\"( p12 ), \"r\"( p13 ), \"r\"( p14 ), \"r\"( p15 ),\n          \"r\"( p16 ), \"r\"( p17 ), \"r\"( p18 ), \"r\"( p19 ), \"r\"( p20 ), \"r\"( p21 ), \"r\"( p22 ), \"r\"( p23 ), \"r\"( p24 ),\n          \"r\"( p25 ), \"r\"( p26 ), \"r\"( p27 ), \"r\"( p28 ), \"r\"( p29 ), \"r\"( p30 ), \"r\"( p31 )\n        : );\n    (void)p0, (void)p1, (void)p2, (void)p3, (void)p4, (void)p5, (void)p6, (void)p7, (void)p8, (void)p9, (void)p10, (void)p11,\n        (void)p12, (void)p13, (void)p14, (void)p15, (void)p16, (void)p17, (void)p18, (void)p19, (void)p20, (void)p21,\n        (void)p22, (void)p23, (void)p24, (void)p25, (void)p26, (void)p27, (void)p28, (void)p29, (void)p30, (void)p31;\n}\n\nstatic __forceinline__ __device__ void optixTrace( OptixTraversableHandle handle,\n                                                   float3                 rayOrigin,\n                                                   float3                 rayDirection,\n                                                   float                  tmin,\n                                                   float                  tmax,\n                                                   float                  rayTime,\n                                                   OptixVisibilityMask    visibilityMask,\n                                                   unsigned int           rayFlags,\n                                                   unsigned int           SBToffset,\n                                                   unsigned int           SBTstride,\n                                                   unsigned int           missSBTIndex,\n                                                   unsigned int&          p0 )\n{\n    float        ox = rayOrigin.x, oy = rayOrigin.y, oz = rayOrigin.z;\n    float        dx = rayDirection.x, dy = rayDirection.y, dz = rayDirection.z;\n    unsigned int p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11, p12, p13, p14, p15, p16, p17, p18, p19, p20, p21, p22,\n        p23, p24, p25, p26, p27, p28, p29, p30, p31;\n    asm volatile(\n        \"call\"\n        \"(%0,%1,%2,%3,%4,%5,%6,%7,%8,%9,%10,%11,%12,%13,%14,%15,%16,%17,%18,%19,%20,%21,%22,%23,%24,%25,%26,%27,%28,%\"\n        \"29,%30,%31),\"\n        \"_optix_trace_typed_32,\"\n        \"(%32,%33,%34,%35,%36,%37,%38,%39,%40,%41,%42,%43,%44,%45,%46,%47,%48,%49,%50,%51,%52,%53,%54,%55,%56,%57,%58,%\"\n        \"59,%60,%61,%62,%63,%64,%65,%66,%67,%68,%69,%70,%71,%72,%73,%74,%75,%76,%77,%78,%79,%80);\"\n        : \"=r\"( p0 ), \"=r\"( p1 ), \"=r\"( p2 ), \"=r\"( p3 ), \"=r\"( p4 ), \"=r\"( p5 ), \"=r\"( p6 ), \"=r\"( p7 ), \"=r\"( p8 ),\n          \"=r\"( p9 ), \"=r\"( p10 ), 
\"=r\"( p11 ), \"=r\"( p12 ), \"=r\"( p13 ), \"=r\"( p14 ), \"=r\"( p15 ), \"=r\"( p16 ),\n          \"=r\"( p17 ), \"=r\"( p18 ), \"=r\"( p19 ), \"=r\"( p20 ), \"=r\"( p21 ), \"=r\"( p22 ), \"=r\"( p23 ), \"=r\"( p24 ),\n          \"=r\"( p25 ), \"=r\"( p26 ), \"=r\"( p27 ), \"=r\"( p28 ), \"=r\"( p29 ), \"=r\"( p30 ), \"=r\"( p31 )\n        : \"r\"( 0 ), \"l\"( handle ), \"f\"( ox ), \"f\"( oy ), \"f\"( oz ), \"f\"( dx ), \"f\"( dy ), \"f\"( dz ), \"f\"( tmin ),\n          \"f\"( tmax ), \"f\"( rayTime ), \"r\"( visibilityMask ), \"r\"( rayFlags ), \"r\"( SBToffset ), \"r\"( SBTstride ),\n          \"r\"( missSBTIndex ), \"r\"( 1 ), \"r\"( p0 ), \"r\"( p1 ), \"r\"( p2 ), \"r\"( p3 ), \"r\"( p4 ), \"r\"( p5 ), \"r\"( p6 ),\n          \"r\"( p7 ), \"r\"( p8 ), \"r\"( p9 ), \"r\"( p10 ), \"r\"( p11 ), \"r\"( p12 ), \"r\"( p13 ), \"r\"( p14 ), \"r\"( p15 ),\n          \"r\"( p16 ), \"r\"( p17 ), \"r\"( p18 ), \"r\"( p19 ), \"r\"( p20 ), \"r\"( p21 ), \"r\"( p22 ), \"r\"( p23 ), \"r\"( p24 ),\n          \"r\"( p25 ), \"r\"( p26 ), \"r\"( p27 ), \"r\"( p28 ), \"r\"( p29 ), \"r\"( p30 ), \"r\"( p31 )\n        : );\n    (void)p1, (void)p2, (void)p3, (void)p4, (void)p5, (void)p6, (void)p7, (void)p8, (void)p9, (void)p10, (void)p11,\n        (void)p12, (void)p13, (void)p14, (void)p15, (void)p16, (void)p17, (void)p18, (void)p19, (void)p20, (void)p21,\n        (void)p22, (void)p23, (void)p24, (void)p25, (void)p26, (void)p27, (void)p28, (void)p29, (void)p30, (void)p31;\n}\n\nstatic __forceinline__ __device__ void optixTrace( OptixTraversableHandle handle,\n                                                   float3                 rayOrigin,\n                                                   float3                 rayDirection,\n                                                   float                  tmin,\n                                                   float                  tmax,\n                                                   float                  rayTime,\n                                                   OptixVisibilityMask    visibilityMask,\n                                                   unsigned int           rayFlags,\n                                                   unsigned int           SBToffset,\n                                                   unsigned int           SBTstride,\n                                                   unsigned int           missSBTIndex,\n                                                   unsigned int&          p0,\n                                                   unsigned int&          p1 )\n{\n    float        ox = rayOrigin.x, oy = rayOrigin.y, oz = rayOrigin.z;\n    float        dx = rayDirection.x, dy = rayDirection.y, dz = rayDirection.z;\n    unsigned int p2, p3, p4, p5, p6, p7, p8, p9, p10, p11, p12, p13, p14, p15, p16, p17, p18, p19, p20, p21, p22, p23,\n        p24, p25, p26, p27, p28, p29, p30, p31;\n    asm volatile(\n        \"call\"\n        \"(%0,%1,%2,%3,%4,%5,%6,%7,%8,%9,%10,%11,%12,%13,%14,%15,%16,%17,%18,%19,%20,%21,%22,%23,%24,%25,%26,%27,%28,%\"\n        \"29,%30,%31),\"\n        \"_optix_trace_typed_32,\"\n        \"(%32,%33,%34,%35,%36,%37,%38,%39,%40,%41,%42,%43,%44,%45,%46,%47,%48,%49,%50,%51,%52,%53,%54,%55,%56,%57,%58,%\"\n        \"59,%60,%61,%62,%63,%64,%65,%66,%67,%68,%69,%70,%71,%72,%73,%74,%75,%76,%77,%78,%79,%80);\"\n        : \"=r\"( p0 ), \"=r\"( p1 ), \"=r\"( p2 ), \"=r\"( p3 ), \"=r\"( p4 ), \"=r\"( p5 ), \"=r\"( p6 ), \"=r\"( p7 ), \"=r\"( p8 ),\n          \"=r\"( p9 ), \"=r\"( p10 ), \"=r\"( p11 ), 
\"=r\"( p12 ), \"=r\"( p13 ), \"=r\"( p14 ), \"=r\"( p15 ), \"=r\"( p16 ),\n          \"=r\"( p17 ), \"=r\"( p18 ), \"=r\"( p19 ), \"=r\"( p20 ), \"=r\"( p21 ), \"=r\"( p22 ), \"=r\"( p23 ), \"=r\"( p24 ),\n          \"=r\"( p25 ), \"=r\"( p26 ), \"=r\"( p27 ), \"=r\"( p28 ), \"=r\"( p29 ), \"=r\"( p30 ), \"=r\"( p31 )\n        : \"r\"( 0 ), \"l\"( handle ), \"f\"( ox ), \"f\"( oy ), \"f\"( oz ), \"f\"( dx ), \"f\"( dy ), \"f\"( dz ), \"f\"( tmin ),\n          \"f\"( tmax ), \"f\"( rayTime ), \"r\"( visibilityMask ), \"r\"( rayFlags ), \"r\"( SBToffset ), \"r\"( SBTstride ),\n          \"r\"( missSBTIndex ), \"r\"( 2 ), \"r\"( p0 ), \"r\"( p1 ), \"r\"( p2 ), \"r\"( p3 ), \"r\"( p4 ), \"r\"( p5 ), \"r\"( p6 ),\n          \"r\"( p7 ), \"r\"( p8 ), \"r\"( p9 ), \"r\"( p10 ), \"r\"( p11 ), \"r\"( p12 ), \"r\"( p13 ), \"r\"( p14 ), \"r\"( p15 ),\n          \"r\"( p16 ), \"r\"( p17 ), \"r\"( p18 ), \"r\"( p19 ), \"r\"( p20 ), \"r\"( p21 ), \"r\"( p22 ), \"r\"( p23 ), \"r\"( p24 ),\n          \"r\"( p25 ), \"r\"( p26 ), \"r\"( p27 ), \"r\"( p28 ), \"r\"( p29 ), \"r\"( p30 ), \"r\"( p31 )\n        : );\n    (void)p2, (void)p3, (void)p4, (void)p5, (void)p6, (void)p7, (void)p8, (void)p9, (void)p10, (void)p11, (void)p12,\n        (void)p13, (void)p14, (void)p15, (void)p16, (void)p17, (void)p18, (void)p19, (void)p20, (void)p21, (void)p22,\n        (void)p23, (void)p24, (void)p25, (void)p26, (void)p27, (void)p28, (void)p29, (void)p30, (void)p31;\n}\n\nstatic __forceinline__ __device__ void optixTrace( OptixTraversableHandle handle,\n                                                   float3                 rayOrigin,\n                                                   float3                 rayDirection,\n                                                   float                  tmin,\n                                                   float                  tmax,\n                                                   float                  rayTime,\n                                                   OptixVisibilityMask    visibilityMask,\n                                                   unsigned int           rayFlags,\n                                                   unsigned int           SBToffset,\n                                                   unsigned int           SBTstride,\n                                                   unsigned int           missSBTIndex,\n                                                   unsigned int&          p0,\n                                                   unsigned int&          p1,\n                                                   unsigned int&          p2 )\n{\n    float        ox = rayOrigin.x, oy = rayOrigin.y, oz = rayOrigin.z;\n    float        dx = rayDirection.x, dy = rayDirection.y, dz = rayDirection.z;\n    unsigned int p3, p4, p5, p6, p7, p8, p9, p10, p11, p12, p13, p14, p15, p16, p17, p18, p19, p20, p21, p22, p23, p24,\n        p25, p26, p27, p28, p29, p30, p31;\n    asm volatile(\n        \"call\"\n        \"(%0,%1,%2,%3,%4,%5,%6,%7,%8,%9,%10,%11,%12,%13,%14,%15,%16,%17,%18,%19,%20,%21,%22,%23,%24,%25,%26,%27,%28,%\"\n        \"29,%30,%31),\"\n        \"_optix_trace_typed_32,\"\n        \"(%32,%33,%34,%35,%36,%37,%38,%39,%40,%41,%42,%43,%44,%45,%46,%47,%48,%49,%50,%51,%52,%53,%54,%55,%56,%57,%58,%\"\n        \"59,%60,%61,%62,%63,%64,%65,%66,%67,%68,%69,%70,%71,%72,%73,%74,%75,%76,%77,%78,%79,%80);\"\n        : \"=r\"( p0 ), \"=r\"( p1 ), \"=r\"( p2 ), \"=r\"( p3 ), \"=r\"( p4 ), \"=r\"( p5 ), \"=r\"( p6 ), \"=r\"( p7 ), \"=r\"( p8 ),\n        
  \"=r\"( p9 ), \"=r\"( p10 ), \"=r\"( p11 ), \"=r\"( p12 ), \"=r\"( p13 ), \"=r\"( p14 ), \"=r\"( p15 ), \"=r\"( p16 ),\n          \"=r\"( p17 ), \"=r\"( p18 ), \"=r\"( p19 ), \"=r\"( p20 ), \"=r\"( p21 ), \"=r\"( p22 ), \"=r\"( p23 ), \"=r\"( p24 ),\n          \"=r\"( p25 ), \"=r\"( p26 ), \"=r\"( p27 ), \"=r\"( p28 ), \"=r\"( p29 ), \"=r\"( p30 ), \"=r\"( p31 )\n        : \"r\"( 0 ), \"l\"( handle ), \"f\"( ox ), \"f\"( oy ), \"f\"( oz ), \"f\"( dx ), \"f\"( dy ), \"f\"( dz ), \"f\"( tmin ),\n          \"f\"( tmax ), \"f\"( rayTime ), \"r\"( visibilityMask ), \"r\"( rayFlags ), \"r\"( SBToffset ), \"r\"( SBTstride ),\n          \"r\"( missSBTIndex ), \"r\"( 3 ), \"r\"( p0 ), \"r\"( p1 ), \"r\"( p2 ), \"r\"( p3 ), \"r\"( p4 ), \"r\"( p5 ), \"r\"( p6 ),\n          \"r\"( p7 ), \"r\"( p8 ), \"r\"( p9 ), \"r\"( p10 ), \"r\"( p11 ), \"r\"( p12 ), \"r\"( p13 ), \"r\"( p14 ), \"r\"( p15 ),\n          \"r\"( p16 ), \"r\"( p17 ), \"r\"( p18 ), \"r\"( p19 ), \"r\"( p20 ), \"r\"( p21 ), \"r\"( p22 ), \"r\"( p23 ), \"r\"( p24 ),\n          \"r\"( p25 ), \"r\"( p26 ), \"r\"( p27 ), \"r\"( p28 ), \"r\"( p29 ), \"r\"( p30 ), \"r\"( p31 )\n        : );\n    (void)p3, (void)p4, (void)p5, (void)p6, (void)p7, (void)p8, (void)p9, (void)p10, (void)p11, (void)p12, (void)p13,\n        (void)p14, (void)p15, (void)p16, (void)p17, (void)p18, (void)p19, (void)p20, (void)p21, (void)p22, (void)p23,\n        (void)p24, (void)p25, (void)p26, (void)p27, (void)p28, (void)p29, (void)p30, (void)p31;\n}\n\nstatic __forceinline__ __device__ void optixTrace( OptixTraversableHandle handle,\n                                                   float3                 rayOrigin,\n                                                   float3                 rayDirection,\n                                                   float                  tmin,\n                                                   float                  tmax,\n                                                   float                  rayTime,\n                                                   OptixVisibilityMask    visibilityMask,\n                                                   unsigned int           rayFlags,\n                                                   unsigned int           SBToffset,\n                                                   unsigned int           SBTstride,\n                                                   unsigned int           missSBTIndex,\n                                                   unsigned int&          p0,\n                                                   unsigned int&          p1,\n                                                   unsigned int&          p2,\n                                                   unsigned int&          p3 )\n{\n    float        ox = rayOrigin.x, oy = rayOrigin.y, oz = rayOrigin.z;\n    float        dx = rayDirection.x, dy = rayDirection.y, dz = rayDirection.z;\n    unsigned int p4, p5, p6, p7, p8, p9, p10, p11, p12, p13, p14, p15, p16, p17, p18, p19, p20, p21, p22, p23, p24, p25,\n        p26, p27, p28, p29, p30, p31;\n    asm volatile(\n        \"call\"\n        \"(%0,%1,%2,%3,%4,%5,%6,%7,%8,%9,%10,%11,%12,%13,%14,%15,%16,%17,%18,%19,%20,%21,%22,%23,%24,%25,%26,%27,%28,%\"\n        \"29,%30,%31),\"\n        \"_optix_trace_typed_32,\"\n        \"(%32,%33,%34,%35,%36,%37,%38,%39,%40,%41,%42,%43,%44,%45,%46,%47,%48,%49,%50,%51,%52,%53,%54,%55,%56,%57,%58,%\"\n        \"59,%60,%61,%62,%63,%64,%65,%66,%67,%68,%69,%70,%71,%72,%73,%74,%75,%76,%77,%78,%79,%80);\"\n        : \"=r\"( p0 ), \"=r\"( 
p1 ), \"=r\"( p2 ), \"=r\"( p3 ), \"=r\"( p4 ), \"=r\"( p5 ), \"=r\"( p6 ), \"=r\"( p7 ), \"=r\"( p8 ),\n          \"=r\"( p9 ), \"=r\"( p10 ), \"=r\"( p11 ), \"=r\"( p12 ), \"=r\"( p13 ), \"=r\"( p14 ), \"=r\"( p15 ), \"=r\"( p16 ),\n          \"=r\"( p17 ), \"=r\"( p18 ), \"=r\"( p19 ), \"=r\"( p20 ), \"=r\"( p21 ), \"=r\"( p22 ), \"=r\"( p23 ), \"=r\"( p24 ),\n          \"=r\"( p25 ), \"=r\"( p26 ), \"=r\"( p27 ), \"=r\"( p28 ), \"=r\"( p29 ), \"=r\"( p30 ), \"=r\"( p31 )\n        : \"r\"( 0 ), \"l\"( handle ), \"f\"( ox ), \"f\"( oy ), \"f\"( oz ), \"f\"( dx ), \"f\"( dy ), \"f\"( dz ), \"f\"( tmin ),\n          \"f\"( tmax ), \"f\"( rayTime ), \"r\"( visibilityMask ), \"r\"( rayFlags ), \"r\"( SBToffset ), \"r\"( SBTstride ),\n          \"r\"( missSBTIndex ), \"r\"( 4 ), \"r\"( p0 ), \"r\"( p1 ), \"r\"( p2 ), \"r\"( p3 ), \"r\"( p4 ), \"r\"( p5 ), \"r\"( p6 ),\n          \"r\"( p7 ), \"r\"( p8 ), \"r\"( p9 ), \"r\"( p10 ), \"r\"( p11 ), \"r\"( p12 ), \"r\"( p13 ), \"r\"( p14 ), \"r\"( p15 ),\n          \"r\"( p16 ), \"r\"( p17 ), \"r\"( p18 ), \"r\"( p19 ), \"r\"( p20 ), \"r\"( p21 ), \"r\"( p22 ), \"r\"( p23 ), \"r\"( p24 ),\n          \"r\"( p25 ), \"r\"( p26 ), \"r\"( p27 ), \"r\"( p28 ), \"r\"( p29 ), \"r\"( p30 ), \"r\"( p31 )\n        : );\n    (void)p4, (void)p5, (void)p6, (void)p7, (void)p8, (void)p9, (void)p10, (void)p11, (void)p12, (void)p13, (void)p14,\n        (void)p15, (void)p16, (void)p17, (void)p18, (void)p19, (void)p20, (void)p21, (void)p22, (void)p23, (void)p24,\n        (void)p25, (void)p26, (void)p27, (void)p28, (void)p29, (void)p30, (void)p31;\n}\n\nstatic __forceinline__ __device__ void optixTrace( OptixTraversableHandle handle,\n                                                   float3                 rayOrigin,\n                                                   float3                 rayDirection,\n                                                   float                  tmin,\n                                                   float                  tmax,\n                                                   float                  rayTime,\n                                                   OptixVisibilityMask    visibilityMask,\n                                                   unsigned int           rayFlags,\n                                                   unsigned int           SBToffset,\n                                                   unsigned int           SBTstride,\n                                                   unsigned int           missSBTIndex,\n                                                   unsigned int&          p0,\n                                                   unsigned int&          p1,\n                                                   unsigned int&          p2,\n                                                   unsigned int&          p3,\n                                                   unsigned int&          p4 )\n{\n    float        ox = rayOrigin.x, oy = rayOrigin.y, oz = rayOrigin.z;\n    float        dx = rayDirection.x, dy = rayDirection.y, dz = rayDirection.z;\n    unsigned int p5, p6, p7, p8, p9, p10, p11, p12, p13, p14, p15, p16, p17, p18, p19, p20, p21, p22, p23, p24, p25,\n        p26, p27, p28, p29, p30, p31;\n    asm volatile(\n        \"call\"\n        \"(%0,%1,%2,%3,%4,%5,%6,%7,%8,%9,%10,%11,%12,%13,%14,%15,%16,%17,%18,%19,%20,%21,%22,%23,%24,%25,%26,%27,%28,%\"\n        \"29,%30,%31),\"\n        \"_optix_trace_typed_32,\"\n        
\"(%32,%33,%34,%35,%36,%37,%38,%39,%40,%41,%42,%43,%44,%45,%46,%47,%48,%49,%50,%51,%52,%53,%54,%55,%56,%57,%58,%\"\n        \"59,%60,%61,%62,%63,%64,%65,%66,%67,%68,%69,%70,%71,%72,%73,%74,%75,%76,%77,%78,%79,%80);\"\n        : \"=r\"( p0 ), \"=r\"( p1 ), \"=r\"( p2 ), \"=r\"( p3 ), \"=r\"( p4 ), \"=r\"( p5 ), \"=r\"( p6 ), \"=r\"( p7 ), \"=r\"( p8 ),\n          \"=r\"( p9 ), \"=r\"( p10 ), \"=r\"( p11 ), \"=r\"( p12 ), \"=r\"( p13 ), \"=r\"( p14 ), \"=r\"( p15 ), \"=r\"( p16 ),\n          \"=r\"( p17 ), \"=r\"( p18 ), \"=r\"( p19 ), \"=r\"( p20 ), \"=r\"( p21 ), \"=r\"( p22 ), \"=r\"( p23 ), \"=r\"( p24 ),\n          \"=r\"( p25 ), \"=r\"( p26 ), \"=r\"( p27 ), \"=r\"( p28 ), \"=r\"( p29 ), \"=r\"( p30 ), \"=r\"( p31 )\n        : \"r\"( 0 ), \"l\"( handle ), \"f\"( ox ), \"f\"( oy ), \"f\"( oz ), \"f\"( dx ), \"f\"( dy ), \"f\"( dz ), \"f\"( tmin ),\n          \"f\"( tmax ), \"f\"( rayTime ), \"r\"( visibilityMask ), \"r\"( rayFlags ), \"r\"( SBToffset ), \"r\"( SBTstride ),\n          \"r\"( missSBTIndex ), \"r\"( 5 ), \"r\"( p0 ), \"r\"( p1 ), \"r\"( p2 ), \"r\"( p3 ), \"r\"( p4 ), \"r\"( p5 ), \"r\"( p6 ),\n          \"r\"( p7 ), \"r\"( p8 ), \"r\"( p9 ), \"r\"( p10 ), \"r\"( p11 ), \"r\"( p12 ), \"r\"( p13 ), \"r\"( p14 ), \"r\"( p15 ),\n          \"r\"( p16 ), \"r\"( p17 ), \"r\"( p18 ), \"r\"( p19 ), \"r\"( p20 ), \"r\"( p21 ), \"r\"( p22 ), \"r\"( p23 ), \"r\"( p24 ),\n          \"r\"( p25 ), \"r\"( p26 ), \"r\"( p27 ), \"r\"( p28 ), \"r\"( p29 ), \"r\"( p30 ), \"r\"( p31 )\n        : );\n    (void)p5, (void)p6, (void)p7, (void)p8, (void)p9, (void)p10, (void)p11, (void)p12, (void)p13, (void)p14, (void)p15,\n        (void)p16, (void)p17, (void)p18, (void)p19, (void)p20, (void)p21, (void)p22, (void)p23, (void)p24, (void)p25,\n        (void)p26, (void)p27, (void)p28, (void)p29, (void)p30, (void)p31;\n}\n\nstatic __forceinline__ __device__ void optixTrace( OptixTraversableHandle handle,\n                                                   float3                 rayOrigin,\n                                                   float3                 rayDirection,\n                                                   float                  tmin,\n                                                   float                  tmax,\n                                                   float                  rayTime,\n                                                   OptixVisibilityMask    visibilityMask,\n                                                   unsigned int           rayFlags,\n                                                   unsigned int           SBToffset,\n                                                   unsigned int           SBTstride,\n                                                   unsigned int           missSBTIndex,\n                                                   unsigned int&          p0,\n                                                   unsigned int&          p1,\n                                                   unsigned int&          p2,\n                                                   unsigned int&          p3,\n                                                   unsigned int&          p4,\n                                                   unsigned int&          p5 )\n{\n    float        ox = rayOrigin.x, oy = rayOrigin.y, oz = rayOrigin.z;\n    float        dx = rayDirection.x, dy = rayDirection.y, dz = rayDirection.z;\n    unsigned int p6, p7, p8, p9, p10, p11, p12, p13, p14, p15, p16, p17, p18, p19, p20, p21, p22, p23, p24, p25, p26,\n        p27, p28, 
p29, p30, p31;\n    asm volatile(\n        \"call\"\n        \"(%0,%1,%2,%3,%4,%5,%6,%7,%8,%9,%10,%11,%12,%13,%14,%15,%16,%17,%18,%19,%20,%21,%22,%23,%24,%25,%26,%27,%28,%\"\n        \"29,%30,%31),\"\n        \"_optix_trace_typed_32,\"\n        \"(%32,%33,%34,%35,%36,%37,%38,%39,%40,%41,%42,%43,%44,%45,%46,%47,%48,%49,%50,%51,%52,%53,%54,%55,%56,%57,%58,%\"\n        \"59,%60,%61,%62,%63,%64,%65,%66,%67,%68,%69,%70,%71,%72,%73,%74,%75,%76,%77,%78,%79,%80);\"\n        : \"=r\"( p0 ), \"=r\"( p1 ), \"=r\"( p2 ), \"=r\"( p3 ), \"=r\"( p4 ), \"=r\"( p5 ), \"=r\"( p6 ), \"=r\"( p7 ), \"=r\"( p8 ),\n          \"=r\"( p9 ), \"=r\"( p10 ), \"=r\"( p11 ), \"=r\"( p12 ), \"=r\"( p13 ), \"=r\"( p14 ), \"=r\"( p15 ), \"=r\"( p16 ),\n          \"=r\"( p17 ), \"=r\"( p18 ), \"=r\"( p19 ), \"=r\"( p20 ), \"=r\"( p21 ), \"=r\"( p22 ), \"=r\"( p23 ), \"=r\"( p24 ),\n          \"=r\"( p25 ), \"=r\"( p26 ), \"=r\"( p27 ), \"=r\"( p28 ), \"=r\"( p29 ), \"=r\"( p30 ), \"=r\"( p31 )\n        : \"r\"( 0 ), \"l\"( handle ), \"f\"( ox ), \"f\"( oy ), \"f\"( oz ), \"f\"( dx ), \"f\"( dy ), \"f\"( dz ), \"f\"( tmin ),\n          \"f\"( tmax ), \"f\"( rayTime ), \"r\"( visibilityMask ), \"r\"( rayFlags ), \"r\"( SBToffset ), \"r\"( SBTstride ),\n          \"r\"( missSBTIndex ), \"r\"( 6 ), \"r\"( p0 ), \"r\"( p1 ), \"r\"( p2 ), \"r\"( p3 ), \"r\"( p4 ), \"r\"( p5 ), \"r\"( p6 ),\n          \"r\"( p7 ), \"r\"( p8 ), \"r\"( p9 ), \"r\"( p10 ), \"r\"( p11 ), \"r\"( p12 ), \"r\"( p13 ), \"r\"( p14 ), \"r\"( p15 ),\n          \"r\"( p16 ), \"r\"( p17 ), \"r\"( p18 ), \"r\"( p19 ), \"r\"( p20 ), \"r\"( p21 ), \"r\"( p22 ), \"r\"( p23 ), \"r\"( p24 ),\n          \"r\"( p25 ), \"r\"( p26 ), \"r\"( p27 ), \"r\"( p28 ), \"r\"( p29 ), \"r\"( p30 ), \"r\"( p31 )\n        : );\n    (void)p6, (void)p7, (void)p8, (void)p9, (void)p10, (void)p11, (void)p12, (void)p13, (void)p14, (void)p15, (void)p16,\n        (void)p17, (void)p18, (void)p19, (void)p20, (void)p21, (void)p22, (void)p23, (void)p24, (void)p25, (void)p26,\n        (void)p27, (void)p28, (void)p29, (void)p30, (void)p31;\n}\n\nstatic __forceinline__ __device__ void optixTrace( OptixTraversableHandle handle,\n                                                   float3                 rayOrigin,\n                                                   float3                 rayDirection,\n                                                   float                  tmin,\n                                                   float                  tmax,\n                                                   float                  rayTime,\n                                                   OptixVisibilityMask    visibilityMask,\n                                                   unsigned int           rayFlags,\n                                                   unsigned int           SBToffset,\n                                                   unsigned int           SBTstride,\n                                                   unsigned int           missSBTIndex,\n                                                   unsigned int&          p0,\n                                                   unsigned int&          p1,\n                                                   unsigned int&          p2,\n                                                   unsigned int&          p3,\n                                                   unsigned int&          p4,\n                                                   unsigned int&          p5,\n                                                   unsigned 
int&          p6 )\n{\n    float        ox = rayOrigin.x, oy = rayOrigin.y, oz = rayOrigin.z;\n    float        dx = rayDirection.x, dy = rayDirection.y, dz = rayDirection.z;\n    unsigned int p7, p8, p9, p10, p11, p12, p13, p14, p15, p16, p17, p18, p19, p20, p21, p22, p23, p24, p25, p26, p27,\n        p28, p29, p30, p31;\n    asm volatile(\n        \"call\"\n        \"(%0,%1,%2,%3,%4,%5,%6,%7,%8,%9,%10,%11,%12,%13,%14,%15,%16,%17,%18,%19,%20,%21,%22,%23,%24,%25,%26,%27,%28,%\"\n        \"29,%30,%31),\"\n        \"_optix_trace_typed_32,\"\n        \"(%32,%33,%34,%35,%36,%37,%38,%39,%40,%41,%42,%43,%44,%45,%46,%47,%48,%49,%50,%51,%52,%53,%54,%55,%56,%57,%58,%\"\n        \"59,%60,%61,%62,%63,%64,%65,%66,%67,%68,%69,%70,%71,%72,%73,%74,%75,%76,%77,%78,%79,%80);\"\n        : \"=r\"( p0 ), \"=r\"( p1 ), \"=r\"( p2 ), \"=r\"( p3 ), \"=r\"( p4 ), \"=r\"( p5 ), \"=r\"( p6 ), \"=r\"( p7 ), \"=r\"( p8 ),\n          \"=r\"( p9 ), \"=r\"( p10 ), \"=r\"( p11 ), \"=r\"( p12 ), \"=r\"( p13 ), \"=r\"( p14 ), \"=r\"( p15 ), \"=r\"( p16 ),\n          \"=r\"( p17 ), \"=r\"( p18 ), \"=r\"( p19 ), \"=r\"( p20 ), \"=r\"( p21 ), \"=r\"( p22 ), \"=r\"( p23 ), \"=r\"( p24 ),\n          \"=r\"( p25 ), \"=r\"( p26 ), \"=r\"( p27 ), \"=r\"( p28 ), \"=r\"( p29 ), \"=r\"( p30 ), \"=r\"( p31 )\n        : \"r\"( 0 ), \"l\"( handle ), \"f\"( ox ), \"f\"( oy ), \"f\"( oz ), \"f\"( dx ), \"f\"( dy ), \"f\"( dz ), \"f\"( tmin ),\n          \"f\"( tmax ), \"f\"( rayTime ), \"r\"( visibilityMask ), \"r\"( rayFlags ), \"r\"( SBToffset ), \"r\"( SBTstride ),\n          \"r\"( missSBTIndex ), \"r\"( 7 ), \"r\"( p0 ), \"r\"( p1 ), \"r\"( p2 ), \"r\"( p3 ), \"r\"( p4 ), \"r\"( p5 ), \"r\"( p6 ),\n          \"r\"( p7 ), \"r\"( p8 ), \"r\"( p9 ), \"r\"( p10 ), \"r\"( p11 ), \"r\"( p12 ), \"r\"( p13 ), \"r\"( p14 ), \"r\"( p15 ),\n          \"r\"( p16 ), \"r\"( p17 ), \"r\"( p18 ), \"r\"( p19 ), \"r\"( p20 ), \"r\"( p21 ), \"r\"( p22 ), \"r\"( p23 ), \"r\"( p24 ),\n          \"r\"( p25 ), \"r\"( p26 ), \"r\"( p27 ), \"r\"( p28 ), \"r\"( p29 ), \"r\"( p30 ), \"r\"( p31 )\n        : );\n    (void)p7, (void)p8, (void)p9, (void)p10, (void)p11, (void)p12, (void)p13, (void)p14, (void)p15, (void)p16,\n        (void)p17, (void)p18, (void)p19, (void)p20, (void)p21, (void)p22, (void)p23, (void)p24, (void)p25, (void)p26,\n        (void)p27, (void)p28, (void)p29, (void)p30, (void)p31;\n}\n\nstatic __forceinline__ __device__ void optixTrace( OptixTraversableHandle handle,\n                                                   float3                 rayOrigin,\n                                                   float3                 rayDirection,\n                                                   float                  tmin,\n                                                   float                  tmax,\n                                                   float                  rayTime,\n                                                   OptixVisibilityMask    visibilityMask,\n                                                   unsigned int           rayFlags,\n                                                   unsigned int           SBToffset,\n                                                   unsigned int           SBTstride,\n                                                   unsigned int           missSBTIndex,\n                                                   unsigned int&          p0,\n                                                   unsigned int&          p1,\n                                                   unsigned int&          p2,\n 
                                                  unsigned int&          p3,\n                                                   unsigned int&          p4,\n                                                   unsigned int&          p5,\n                                                   unsigned int&          p6,\n                                                   unsigned int&          p7 )\n{\n    float        ox = rayOrigin.x, oy = rayOrigin.y, oz = rayOrigin.z;\n    float        dx = rayDirection.x, dy = rayDirection.y, dz = rayDirection.z;\n    unsigned int p8, p9, p10, p11, p12, p13, p14, p15, p16, p17, p18, p19, p20, p21, p22, p23, p24, p25, p26, p27, p28,\n        p29, p30, p31;\n    asm volatile(\n        \"call\"\n        \"(%0,%1,%2,%3,%4,%5,%6,%7,%8,%9,%10,%11,%12,%13,%14,%15,%16,%17,%18,%19,%20,%21,%22,%23,%24,%25,%26,%27,%28,%\"\n        \"29,%30,%31),\"\n        \"_optix_trace_typed_32,\"\n        \"(%32,%33,%34,%35,%36,%37,%38,%39,%40,%41,%42,%43,%44,%45,%46,%47,%48,%49,%50,%51,%52,%53,%54,%55,%56,%57,%58,%\"\n        \"59,%60,%61,%62,%63,%64,%65,%66,%67,%68,%69,%70,%71,%72,%73,%74,%75,%76,%77,%78,%79,%80);\"\n        : \"=r\"( p0 ), \"=r\"( p1 ), \"=r\"( p2 ), \"=r\"( p3 ), \"=r\"( p4 ), \"=r\"( p5 ), \"=r\"( p6 ), \"=r\"( p7 ), \"=r\"( p8 ),\n          \"=r\"( p9 ), \"=r\"( p10 ), \"=r\"( p11 ), \"=r\"( p12 ), \"=r\"( p13 ), \"=r\"( p14 ), \"=r\"( p15 ), \"=r\"( p16 ),\n          \"=r\"( p17 ), \"=r\"( p18 ), \"=r\"( p19 ), \"=r\"( p20 ), \"=r\"( p21 ), \"=r\"( p22 ), \"=r\"( p23 ), \"=r\"( p24 ),\n          \"=r\"( p25 ), \"=r\"( p26 ), \"=r\"( p27 ), \"=r\"( p28 ), \"=r\"( p29 ), \"=r\"( p30 ), \"=r\"( p31 )\n        : \"r\"( 0 ), \"l\"( handle ), \"f\"( ox ), \"f\"( oy ), \"f\"( oz ), \"f\"( dx ), \"f\"( dy ), \"f\"( dz ), \"f\"( tmin ),\n          \"f\"( tmax ), \"f\"( rayTime ), \"r\"( visibilityMask ), \"r\"( rayFlags ), \"r\"( SBToffset ), \"r\"( SBTstride ),\n          \"r\"( missSBTIndex ), \"r\"( 8 ), \"r\"( p0 ), \"r\"( p1 ), \"r\"( p2 ), \"r\"( p3 ), \"r\"( p4 ), \"r\"( p5 ), \"r\"( p6 ),\n          \"r\"( p7 ), \"r\"( p8 ), \"r\"( p9 ), \"r\"( p10 ), \"r\"( p11 ), \"r\"( p12 ), \"r\"( p13 ), \"r\"( p14 ), \"r\"( p15 ),\n          \"r\"( p16 ), \"r\"( p17 ), \"r\"( p18 ), \"r\"( p19 ), \"r\"( p20 ), \"r\"( p21 ), \"r\"( p22 ), \"r\"( p23 ), \"r\"( p24 ),\n          \"r\"( p25 ), \"r\"( p26 ), \"r\"( p27 ), \"r\"( p28 ), \"r\"( p29 ), \"r\"( p30 ), \"r\"( p31 )\n        : );\n    (void)p8, (void)p9, (void)p10, (void)p11, (void)p12, (void)p13, (void)p14, (void)p15, (void)p16, (void)p17,\n        (void)p18, (void)p19, (void)p20, (void)p21, (void)p22, (void)p23, (void)p24, (void)p25, (void)p26, (void)p27,\n        (void)p28, (void)p29, (void)p30, (void)p31;\n}\n\n\nstatic __forceinline__ __device__ void optixSetPayload_0( unsigned int p )\n{\n    asm volatile( \"call _optix_set_payload, (%0, %1);\" : : \"r\"( 0 ), \"r\"( p ) : );\n}\n\nstatic __forceinline__ __device__ void optixSetPayload_1( unsigned int p )\n{\n    asm volatile( \"call _optix_set_payload, (%0, %1);\" : : \"r\"( 1 ), \"r\"( p ) : );\n}\n\nstatic __forceinline__ __device__ void optixSetPayload_2( unsigned int p )\n{\n    asm volatile( \"call _optix_set_payload, (%0, %1);\" : : \"r\"( 2 ), \"r\"( p ) : );\n}\n\nstatic __forceinline__ __device__ void optixSetPayload_3( unsigned int p )\n{\n    asm volatile( \"call _optix_set_payload, (%0, %1);\" : : \"r\"( 3 ), \"r\"( p ) : );\n}\n\nstatic __forceinline__ __device__ void optixSetPayload_4( unsigned int p )\n{\n    asm volatile( \"call 
_optix_set_payload, (%0, %1);\" : : \"r\"( 4 ), \"r\"( p ) : );\n}\n\nstatic __forceinline__ __device__ void optixSetPayload_5( unsigned int p )\n{\n    asm volatile( \"call _optix_set_payload, (%0, %1);\" : : \"r\"( 5 ), \"r\"( p ) : );\n}\n\nstatic __forceinline__ __device__ void optixSetPayload_6( unsigned int p )\n{\n    asm volatile( \"call _optix_set_payload, (%0, %1);\" : : \"r\"( 6 ), \"r\"( p ) : );\n}\n\nstatic __forceinline__ __device__ void optixSetPayload_7( unsigned int p )\n{\n    asm volatile( \"call _optix_set_payload, (%0, %1);\" : : \"r\"( 7 ), \"r\"( p ) : );\n}\n\n\nstatic __forceinline__ __device__ unsigned int optixGetPayload_0()\n{\n    unsigned int result;\n    asm volatile( \"call (%0), _optix_get_payload, (%1);\" : \"=r\"( result ) : \"r\"( 0 ) : );\n    return result;\n}\n\nstatic __forceinline__ __device__ unsigned int optixGetPayload_1()\n{\n    unsigned int result;\n    asm volatile( \"call (%0), _optix_get_payload, (%1);\" : \"=r\"( result ) : \"r\"( 1 ) : );\n    return result;\n}\n\nstatic __forceinline__ __device__ unsigned int optixGetPayload_2()\n{\n    unsigned int result;\n    asm volatile( \"call (%0), _optix_get_payload, (%1);\" : \"=r\"( result ) : \"r\"( 2 ) : );\n    return result;\n}\n\nstatic __forceinline__ __device__ unsigned int optixGetPayload_3()\n{\n    unsigned int result;\n    asm volatile( \"call (%0), _optix_get_payload, (%1);\" : \"=r\"( result ) : \"r\"( 3 ) : );\n    return result;\n}\n\nstatic __forceinline__ __device__ unsigned int optixGetPayload_4()\n{\n    unsigned int result;\n    asm volatile( \"call (%0), _optix_get_payload, (%1);\" : \"=r\"( result ) : \"r\"( 4 ) : );\n    return result;\n}\n\nstatic __forceinline__ __device__ unsigned int optixGetPayload_5()\n{\n    unsigned int result;\n    asm volatile( \"call (%0), _optix_get_payload, (%1);\" : \"=r\"( result ) : \"r\"( 5 ) : );\n    return result;\n}\n\nstatic __forceinline__ __device__ unsigned int optixGetPayload_6()\n{\n    unsigned int result;\n    asm volatile( \"call (%0), _optix_get_payload, (%1);\" : \"=r\"( result ) : \"r\"( 6 ) : );\n    return result;\n}\n\nstatic __forceinline__ __device__ unsigned int optixGetPayload_7()\n{\n    unsigned int result;\n    asm volatile( \"call (%0), _optix_get_payload, (%1);\" : \"=r\"( result ) : \"r\"( 7 ) : );\n    return result;\n}\n\n\nstatic __forceinline__ __device__ unsigned int optixUndefinedValue()\n{\n    unsigned int u0;\n    asm( \"call (%0), _optix_undef_value, ();\" : \"=r\"( u0 ) : );\n    return u0;\n}\n\nstatic __forceinline__ __device__ float3 optixGetWorldRayOrigin()\n{\n    float f0, f1, f2;\n    asm( \"call (%0), _optix_get_world_ray_origin_x, ();\" : \"=f\"( f0 ) : );\n    asm( \"call (%0), _optix_get_world_ray_origin_y, ();\" : \"=f\"( f1 ) : );\n    asm( \"call (%0), _optix_get_world_ray_origin_z, ();\" : \"=f\"( f2 ) : );\n    return make_float3( f0, f1, f2 );\n}\n\nstatic __forceinline__ __device__ float3 optixGetWorldRayDirection()\n{\n    float f0, f1, f2;\n    asm( \"call (%0), _optix_get_world_ray_direction_x, ();\" : \"=f\"( f0 ) : );\n    asm( \"call (%0), _optix_get_world_ray_direction_y, ();\" : \"=f\"( f1 ) : );\n    asm( \"call (%0), _optix_get_world_ray_direction_z, ();\" : \"=f\"( f2 ) : );\n    return make_float3( f0, f1, f2 );\n}\n\nstatic __forceinline__ __device__ float3 optixGetObjectRayOrigin()\n{\n    float f0, f1, f2;\n    asm( \"call (%0), _optix_get_object_ray_origin_x, ();\" : \"=f\"( f0 ) : );\n    asm( \"call (%0), _optix_get_object_ray_origin_y, ();\" : \"=f\"( f1 ) : 
);\n    asm( \"call (%0), _optix_get_object_ray_origin_z, ();\" : \"=f\"( f2 ) : );\n    return make_float3( f0, f1, f2 );\n}\n\nstatic __forceinline__ __device__ float3 optixGetObjectRayDirection()\n{\n    float f0, f1, f2;\n    asm( \"call (%0), _optix_get_object_ray_direction_x, ();\" : \"=f\"( f0 ) : );\n    asm( \"call (%0), _optix_get_object_ray_direction_y, ();\" : \"=f\"( f1 ) : );\n    asm( \"call (%0), _optix_get_object_ray_direction_z, ();\" : \"=f\"( f2 ) : );\n    return make_float3( f0, f1, f2 );\n}\n\nstatic __forceinline__ __device__ float optixGetRayTmin()\n{\n    float f0;\n    asm( \"call (%0), _optix_get_ray_tmin, ();\" : \"=f\"( f0 ) : );\n    return f0;\n}\n\nstatic __forceinline__ __device__ float optixGetRayTmax()\n{\n    float f0;\n    asm( \"call (%0), _optix_get_ray_tmax, ();\" : \"=f\"( f0 ) : );\n    return f0;\n}\n\nstatic __forceinline__ __device__ float optixGetRayTime()\n{\n    float f0;\n    asm( \"call (%0), _optix_get_ray_time, ();\" : \"=f\"( f0 ) : );\n    return f0;\n}\n\nstatic __forceinline__ __device__ unsigned int optixGetRayFlags()\n{\n    unsigned int u0;\n    asm( \"call (%0), _optix_get_ray_flags, ();\" : \"=r\"( u0 ) : );\n    return u0;\n}\n\nstatic __forceinline__ __device__ unsigned int optixGetRayVisibilityMask()\n{\n    unsigned int u0;\n    asm( \"call (%0), _optix_get_ray_visibility_mask, ();\" : \"=r\"( u0 ) : );\n    return u0;\n}\n\nstatic __forceinline__ __device__ OptixTraversableHandle optixGetInstanceTraversableFromIAS( OptixTraversableHandle ias,\n                                                                                             unsigned int           instIdx )\n{\n    unsigned long long handle;\n    asm( \"call (%0), _optix_get_instance_traversable_from_ias, (%1, %2);\"\n         : \"=l\"( handle ) : \"l\"( ias ), \"r\"( instIdx ) );\n    return (OptixTraversableHandle)handle;\n}\n\n\nstatic __forceinline__ __device__ void optixGetTriangleVertexData( OptixTraversableHandle gas,\n                                                                   unsigned int           primIdx,\n                                                                   unsigned int           sbtGASIndex,\n                                                                   float                  time,\n                                                                   float3                 data[3] )\n{\n    asm( \"call (%0, %1, %2, %3, %4, %5, %6, %7, %8), _optix_get_triangle_vertex_data, \"\n         \"(%9, %10, %11, %12);\"\n         : \"=f\"( data[0].x ), \"=f\"( data[0].y ), \"=f\"( data[0].z ), \"=f\"( data[1].x ), \"=f\"( data[1].y ),\n           \"=f\"( data[1].z ), \"=f\"( data[2].x ), \"=f\"( data[2].y ), \"=f\"( data[2].z )\n         : \"l\"( gas ), \"r\"( primIdx ), \"r\"( sbtGASIndex ), \"f\"( time )\n         : );\n}\n\nstatic __forceinline__ __device__ void optixGetLinearCurveVertexData( OptixTraversableHandle gas,\n                                                                      unsigned int           primIdx,\n                                                                      unsigned int           sbtGASIndex,\n                                                                      float                  time,\n                                                                      float4                 data[2] )\n{\n    asm( \"call (%0, %1, %2, %3,  %4, %5, %6, %7), _optix_get_linear_curve_vertex_data, \"\n         \"(%8, %9, %10, %11);\"\n         : \"=f\"( data[0].x ), \"=f\"( data[0].y ), \"=f\"( data[0].z ), \"=f\"( 
data[0].w ),\n           \"=f\"( data[1].x ), \"=f\"( data[1].y ), \"=f\"( data[1].z ), \"=f\"( data[1].w )\n         : \"l\"( gas ), \"r\"( primIdx ), \"r\"( sbtGASIndex ), \"f\"( time )\n         : );\n}\n\nstatic __forceinline__ __device__ void optixGetQuadraticBSplineVertexData( OptixTraversableHandle gas,\n                                                                           unsigned int         primIdx,\n                                                                           unsigned int         sbtGASIndex,\n                                                                           float                time,\n                                                                           float4               data[3] )\n{\n    asm( \"call (%0, %1, %2, %3,  %4, %5, %6, %7,  %8, %9, %10, %11), _optix_get_quadratic_bspline_vertex_data, \"\n         \"(%12, %13, %14, %15);\"\n         : \"=f\"( data[0].x ), \"=f\"( data[0].y ), \"=f\"( data[0].z ), \"=f\"( data[0].w ), \n           \"=f\"( data[1].x ), \"=f\"( data[1].y ), \"=f\"( data[1].z ), \"=f\"( data[1].w ),\n           \"=f\"( data[2].x ), \"=f\"( data[2].y ), \"=f\"( data[2].z ), \"=f\"( data[2].w )\n         : \"l\"( gas ), \"r\"( primIdx ), \"r\"( sbtGASIndex ), \"f\"( time )\n         : );\n}\n\nstatic __forceinline__ __device__ void optixGetCubicBSplineVertexData( OptixTraversableHandle gas,\n                                                                       unsigned int         primIdx,\n                                                                       unsigned int         sbtGASIndex,\n                                                                       float                time,\n                                                                       float4               data[4] )\n{\n    asm( \"call (%0, %1, %2, %3,  %4, %5, %6, %7,  %8, %9, %10, %11,  %12, %13, %14, %15), \"\n         \"_optix_get_cubic_bspline_vertex_data, \"\n         \"(%16, %17, %18, %19);\"\n         : \"=f\"( data[0].x ), \"=f\"( data[0].y ), \"=f\"( data[0].z ), \"=f\"( data[0].w ), \n           \"=f\"( data[1].x ), \"=f\"( data[1].y ), \"=f\"( data[1].z ), \"=f\"( data[1].w ),\n           \"=f\"( data[2].x ), \"=f\"( data[2].y ), \"=f\"( data[2].z ), \"=f\"( data[2].w ),\n           \"=f\"( data[3].x ), \"=f\"( data[3].y ), \"=f\"( data[3].z ), \"=f\"( data[3].w )\n         : \"l\"( gas ), \"r\"( primIdx ), \"r\"( sbtGASIndex ), \"f\"( time )\n         : );\n}\n\nstatic __forceinline__ __device__ OptixTraversableHandle optixGetGASTraversableHandle()\n{\n    unsigned long long handle;\n    asm( \"call (%0), _optix_get_gas_traversable_handle, ();\" : \"=l\"( handle ) : );\n    return (OptixTraversableHandle)handle;\n}\n\nstatic __forceinline__ __device__ float optixGetGASMotionTimeBegin( OptixTraversableHandle handle )\n{\n    float f0;\n    asm( \"call (%0), _optix_get_gas_motion_time_begin, (%1);\" : \"=f\"( f0 ) : \"l\"( handle ) : );\n    return f0;\n}\n\nstatic __forceinline__ __device__ float optixGetGASMotionTimeEnd( OptixTraversableHandle handle )\n{\n    float f0;\n    asm( \"call (%0), _optix_get_gas_motion_time_end, (%1);\" : \"=f\"( f0 ) : \"l\"( handle ) : );\n    return f0;\n}\n\nstatic __forceinline__ __device__ unsigned int optixGetGASMotionStepCount( OptixTraversableHandle handle )\n{\n    unsigned int u0;\n    asm( \"call (%0), _optix_get_gas_motion_step_count, (%1);\" : \"=r\"( u0 ) : \"l\"( handle ) : );\n    return u0;\n}\n\nstatic __forceinline__ __device__ void optixGetWorldToObjectTransformMatrix( float m[12] 
)\n{\n    if( optixGetTransformListSize() == 0 )\n    {\n        m[0]  = 1.0f;\n        m[1]  = 0.0f;\n        m[2]  = 0.0f;\n        m[3]  = 0.0f;\n        m[4]  = 0.0f;\n        m[5]  = 1.0f;\n        m[6]  = 0.0f;\n        m[7]  = 0.0f;\n        m[8]  = 0.0f;\n        m[9]  = 0.0f;\n        m[10] = 1.0f;\n        m[11] = 0.0f;\n        return;\n    }\n\n    float4 m0, m1, m2;\n    optix_impl::optixGetWorldToObjectTransformMatrix( m0, m1, m2 );\n    m[0]  = m0.x;\n    m[1]  = m0.y;\n    m[2]  = m0.z;\n    m[3]  = m0.w;\n    m[4]  = m1.x;\n    m[5]  = m1.y;\n    m[6]  = m1.z;\n    m[7]  = m1.w;\n    m[8]  = m2.x;\n    m[9]  = m2.y;\n    m[10] = m2.z;\n    m[11] = m2.w;\n}\n\nstatic __forceinline__ __device__ void optixGetObjectToWorldTransformMatrix( float m[12] )\n{\n    if( optixGetTransformListSize() == 0 )\n    {\n        m[0]  = 1.0f;\n        m[1]  = 0.0f;\n        m[2]  = 0.0f;\n        m[3]  = 0.0f;\n        m[4]  = 0.0f;\n        m[5]  = 1.0f;\n        m[6]  = 0.0f;\n        m[7]  = 0.0f;\n        m[8]  = 0.0f;\n        m[9]  = 0.0f;\n        m[10] = 1.0f;\n        m[11] = 0.0f;\n        return;\n    }\n\n    float4 m0, m1, m2;\n    optix_impl::optixGetObjectToWorldTransformMatrix( m0, m1, m2 );\n    m[0]  = m0.x;\n    m[1]  = m0.y;\n    m[2]  = m0.z;\n    m[3]  = m0.w;\n    m[4]  = m1.x;\n    m[5]  = m1.y;\n    m[6]  = m1.z;\n    m[7]  = m1.w;\n    m[8]  = m2.x;\n    m[9]  = m2.y;\n    m[10] = m2.z;\n    m[11] = m2.w;\n}\n\nstatic __forceinline__ __device__ float3 optixTransformPointFromWorldToObjectSpace( float3 point )\n{\n    if( optixGetTransformListSize() == 0 )\n        return point;\n\n    float4 m0, m1, m2;\n    optix_impl::optixGetWorldToObjectTransformMatrix( m0, m1, m2 );\n    return optix_impl::optixTransformPoint( m0, m1, m2, point );\n}\n\nstatic __forceinline__ __device__ float3 optixTransformVectorFromWorldToObjectSpace( float3 vec )\n{\n    if( optixGetTransformListSize() == 0 )\n        return vec;\n\n    float4 m0, m1, m2;\n    optix_impl::optixGetWorldToObjectTransformMatrix( m0, m1, m2 );\n    return optix_impl::optixTransformVector( m0, m1, m2, vec );\n}\n\nstatic __forceinline__ __device__ float3 optixTransformNormalFromWorldToObjectSpace( float3 normal )\n{\n    if( optixGetTransformListSize() == 0 )\n        return normal;\n\n    float4 m0, m1, m2;\n    optix_impl::optixGetObjectToWorldTransformMatrix( m0, m1, m2 );  // inverse of optixGetWorldToObjectTransformMatrix()\n    return optix_impl::optixTransformNormal( m0, m1, m2, normal );\n}\n\nstatic __forceinline__ __device__ float3 optixTransformPointFromObjectToWorldSpace( float3 point )\n{\n    if( optixGetTransformListSize() == 0 )\n        return point;\n\n    float4 m0, m1, m2;\n    optix_impl::optixGetObjectToWorldTransformMatrix( m0, m1, m2 );\n    return optix_impl::optixTransformPoint( m0, m1, m2, point );\n}\n\nstatic __forceinline__ __device__ float3 optixTransformVectorFromObjectToWorldSpace( float3 vec )\n{\n    if( optixGetTransformListSize() == 0 )\n        return vec;\n\n    float4 m0, m1, m2;\n    optix_impl::optixGetObjectToWorldTransformMatrix( m0, m1, m2 );\n    return optix_impl::optixTransformVector( m0, m1, m2, vec );\n}\n\nstatic __forceinline__ __device__ float3 optixTransformNormalFromObjectToWorldSpace( float3 normal )\n{\n    if( optixGetTransformListSize() == 0 )\n        return normal;\n\n    float4 m0, m1, m2;\n    optix_impl::optixGetWorldToObjectTransformMatrix( m0, m1, m2 );  // inverse of optixGetObjectToWorldTransformMatrix()\n    return 
optix_impl::optixTransformNormal( m0, m1, m2, normal );\n}\n\nstatic __forceinline__ __device__ unsigned int optixGetTransformListSize()\n{\n    unsigned int u0;\n    asm( \"call (%0), _optix_get_transform_list_size, ();\" : \"=r\"( u0 ) : );\n    return u0;\n}\n\nstatic __forceinline__ __device__ OptixTraversableHandle optixGetTransformListHandle( unsigned int index )\n{\n    unsigned long long u0;\n    asm( \"call (%0), _optix_get_transform_list_handle, (%1);\" : \"=l\"( u0 ) : \"r\"( index ) : );\n    return u0;\n}\n\nstatic __forceinline__ __device__ OptixTransformType optixGetTransformTypeFromHandle( OptixTraversableHandle handle )\n{\n    int i0;\n    asm( \"call (%0), _optix_get_transform_type_from_handle, (%1);\" : \"=r\"( i0 ) : \"l\"( handle ) : );\n    return (OptixTransformType)i0;\n}\n\nstatic __forceinline__ __device__ const OptixStaticTransform* optixGetStaticTransformFromHandle( OptixTraversableHandle handle )\n{\n    unsigned long long ptr;\n    asm( \"call (%0), _optix_get_static_transform_from_handle, (%1);\" : \"=l\"( ptr ) : \"l\"( handle ) : );\n    return (const OptixStaticTransform*)ptr;\n}\n\nstatic __forceinline__ __device__ const OptixSRTMotionTransform* optixGetSRTMotionTransformFromHandle( OptixTraversableHandle handle )\n{\n    unsigned long long ptr;\n    asm( \"call (%0), _optix_get_srt_motion_transform_from_handle, (%1);\" : \"=l\"( ptr ) : \"l\"( handle ) : );\n    return (const OptixSRTMotionTransform*)ptr;\n}\n\nstatic __forceinline__ __device__ const OptixMatrixMotionTransform* optixGetMatrixMotionTransformFromHandle( OptixTraversableHandle handle )\n{\n    unsigned long long ptr;\n    asm( \"call (%0), _optix_get_matrix_motion_transform_from_handle, (%1);\" : \"=l\"( ptr ) : \"l\"( handle ) : );\n    return (const OptixMatrixMotionTransform*)ptr;\n}\n\nstatic __forceinline__ __device__ unsigned int optixGetInstanceIdFromHandle( OptixTraversableHandle handle )\n{\n    int i0;\n    asm( \"call (%0), _optix_get_instance_id_from_handle, (%1);\" : \"=r\"( i0 ) : \"l\"( handle ) : );\n    return i0;\n}\n\nstatic __forceinline__ __device__ OptixTraversableHandle optixGetInstanceChildFromHandle( OptixTraversableHandle handle )\n{\n    unsigned long long i0;\n    asm( \"call (%0), _optix_get_instance_child_from_handle, (%1);\" : \"=l\"( i0 ) : \"l\"( handle ) : );\n    return (OptixTraversableHandle)i0;\n}\n\nstatic __forceinline__ __device__ const float4* optixGetInstanceTransformFromHandle( OptixTraversableHandle handle )\n{\n    unsigned long long ptr;\n    asm( \"call (%0), _optix_get_instance_transform_from_handle, (%1);\" : \"=l\"( ptr ) : \"l\"( handle ) : );\n    return (const float4*)ptr;\n}\n\nstatic __forceinline__ __device__ const float4* optixGetInstanceInverseTransformFromHandle( OptixTraversableHandle handle )\n{\n    unsigned long long ptr;\n    asm( \"call (%0), _optix_get_instance_inverse_transform_from_handle, (%1);\" : \"=l\"( ptr ) : \"l\"( handle ) : );\n    return (const float4*)ptr;\n}\n\nstatic __forceinline__ __device__ bool optixReportIntersection( float hitT, unsigned int hitKind )\n{\n    int ret;\n    asm volatile(\n        \"call (%0), _optix_report_intersection_0\"\n        \", (%1, %2);\"\n        : \"=r\"( ret )\n        : \"f\"( hitT ), \"r\"( hitKind )\n        : );\n    return ret;\n}\n\nstatic __forceinline__ __device__ bool optixReportIntersection( float hitT, unsigned int hitKind, unsigned int a0 )\n{\n    int ret;\n    asm volatile(\n        \"call (%0), _optix_report_intersection_1\"\n        \", (%1, %2, %3);\"\n       
 : \"=r\"( ret )\n        : \"f\"( hitT ), \"r\"( hitKind ), \"r\"( a0 )\n        : );\n    return ret;\n}\n\nstatic __forceinline__ __device__ bool optixReportIntersection( float hitT, unsigned int hitKind, unsigned int a0, unsigned int a1 )\n{\n    int ret;\n    asm volatile(\n        \"call (%0), _optix_report_intersection_2\"\n        \", (%1, %2, %3, %4);\"\n        : \"=r\"( ret )\n        : \"f\"( hitT ), \"r\"( hitKind ), \"r\"( a0 ), \"r\"( a1 )\n        : );\n    return ret;\n}\n\nstatic __forceinline__ __device__ bool optixReportIntersection( float hitT, unsigned int hitKind, unsigned int a0, unsigned int a1, unsigned int a2 )\n{\n    int ret;\n    asm volatile(\n        \"call (%0), _optix_report_intersection_3\"\n        \", (%1, %2, %3, %4, %5);\"\n        : \"=r\"( ret )\n        : \"f\"( hitT ), \"r\"( hitKind ), \"r\"( a0 ), \"r\"( a1 ), \"r\"( a2 )\n        : );\n    return ret;\n}\n\nstatic __forceinline__ __device__ bool optixReportIntersection( float        hitT,\n                                                                unsigned int hitKind,\n                                                                unsigned int a0,\n                                                                unsigned int a1,\n                                                                unsigned int a2,\n                                                                unsigned int a3 )\n{\n    int ret;\n    asm volatile(\n        \"call (%0), _optix_report_intersection_4\"\n        \", (%1, %2, %3, %4, %5, %6);\"\n        : \"=r\"( ret )\n        : \"f\"( hitT ), \"r\"( hitKind ), \"r\"( a0 ), \"r\"( a1 ), \"r\"( a2 ), \"r\"( a3 )\n        : );\n    return ret;\n}\n\nstatic __forceinline__ __device__ bool optixReportIntersection( float        hitT,\n                                                                unsigned int hitKind,\n                                                                unsigned int a0,\n                                                                unsigned int a1,\n                                                                unsigned int a2,\n                                                                unsigned int a3,\n                                                                unsigned int a4 )\n{\n    int ret;\n    asm volatile(\n        \"call (%0), _optix_report_intersection_5\"\n        \", (%1, %2, %3, %4, %5, %6, %7);\"\n        : \"=r\"( ret )\n        : \"f\"( hitT ), \"r\"( hitKind ), \"r\"( a0 ), \"r\"( a1 ), \"r\"( a2 ), \"r\"( a3 ), \"r\"( a4 )\n        : );\n    return ret;\n}\n\nstatic __forceinline__ __device__ bool optixReportIntersection( float        hitT,\n                                                                unsigned int hitKind,\n                                                                unsigned int a0,\n                                                                unsigned int a1,\n                                                                unsigned int a2,\n                                                                unsigned int a3,\n                                                                unsigned int a4,\n                                                                unsigned int a5 )\n{\n    int ret;\n    asm volatile(\n        \"call (%0), _optix_report_intersection_6\"\n        \", (%1, %2, %3, %4, %5, %6, %7, %8);\"\n        : \"=r\"( ret )\n        : \"f\"( hitT ), \"r\"( hitKind ), \"r\"( a0 ), \"r\"( a1 ), \"r\"( a2 ), \"r\"( a3 ), \"r\"( a4 ), \"r\"( a5 )\n        : );\n    
return ret;\n}\n\nstatic __forceinline__ __device__ bool optixReportIntersection( float        hitT,\n                                                                unsigned int hitKind,\n                                                                unsigned int a0,\n                                                                unsigned int a1,\n                                                                unsigned int a2,\n                                                                unsigned int a3,\n                                                                unsigned int a4,\n                                                                unsigned int a5,\n                                                                unsigned int a6 )\n{\n    int ret;\n    asm volatile(\n        \"call (%0), _optix_report_intersection_7\"\n        \", (%1, %2, %3, %4, %5, %6, %7, %8, %9);\"\n        : \"=r\"( ret )\n        : \"f\"( hitT ), \"r\"( hitKind ), \"r\"( a0 ), \"r\"( a1 ), \"r\"( a2 ), \"r\"( a3 ), \"r\"( a4 ), \"r\"( a5 ), \"r\"( a6 )\n        : );\n    return ret;\n}\n\nstatic __forceinline__ __device__ bool optixReportIntersection( float        hitT,\n                                                                unsigned int hitKind,\n                                                                unsigned int a0,\n                                                                unsigned int a1,\n                                                                unsigned int a2,\n                                                                unsigned int a3,\n                                                                unsigned int a4,\n                                                                unsigned int a5,\n                                                                unsigned int a6,\n                                                                unsigned int a7 )\n{\n    int ret;\n    asm volatile(\n        \"call (%0), _optix_report_intersection_8\"\n        \", (%1, %2, %3, %4, %5, %6, %7, %8, %9, %10);\"\n        : \"=r\"( ret )\n        : \"f\"( hitT ), \"r\"( hitKind ), \"r\"( a0 ), \"r\"( a1 ), \"r\"( a2 ), \"r\"( a3 ), \"r\"( a4 ), \"r\"( a5 ), \"r\"( a6 ), \"r\"( a7 )\n        : );\n    return ret;\n}\n\n#define OPTIX_DEFINE_optixGetAttribute_BODY( which )                                                                   \\\n    unsigned int ret;                                                                                                  \\\n    asm( \"call (%0), _optix_get_attribute_\" #which \", ();\" : \"=r\"( ret ) : );                                          \\\n    return ret;\n\nstatic __forceinline__ __device__ unsigned int optixGetAttribute_0()\n{\n    OPTIX_DEFINE_optixGetAttribute_BODY( 0 );\n}\n\nstatic __forceinline__ __device__ unsigned int optixGetAttribute_1()\n{\n    OPTIX_DEFINE_optixGetAttribute_BODY( 1 );\n}\n\nstatic __forceinline__ __device__ unsigned int optixGetAttribute_2()\n{\n    OPTIX_DEFINE_optixGetAttribute_BODY( 2 );\n}\n\nstatic __forceinline__ __device__ unsigned int optixGetAttribute_3()\n{\n    OPTIX_DEFINE_optixGetAttribute_BODY( 3 );\n}\n\nstatic __forceinline__ __device__ unsigned int optixGetAttribute_4()\n{\n    OPTIX_DEFINE_optixGetAttribute_BODY( 4 );\n}\n\nstatic __forceinline__ __device__ unsigned int optixGetAttribute_5()\n{\n    OPTIX_DEFINE_optixGetAttribute_BODY( 5 );\n}\n\nstatic __forceinline__ __device__ unsigned int optixGetAttribute_6()\n{\n    OPTIX_DEFINE_optixGetAttribute_BODY( 6 
);\n}\n\nstatic __forceinline__ __device__ unsigned int optixGetAttribute_7()\n{\n    OPTIX_DEFINE_optixGetAttribute_BODY( 7 );\n}\n\n#undef OPTIX_DEFINE_optixGetAttribute_BODY\n\nstatic __forceinline__ __device__ void optixTerminateRay()\n{\n    asm volatile( \"call _optix_terminate_ray, ();\" );\n}\n\nstatic __forceinline__ __device__ void optixIgnoreIntersection()\n{\n    asm volatile( \"call _optix_ignore_intersection, ();\" );\n}\n\nstatic __forceinline__ __device__ unsigned int optixGetPrimitiveIndex()\n{\n    unsigned int u0;\n    asm( \"call (%0), _optix_read_primitive_idx, ();\" : \"=r\"( u0 ) : );\n    return u0;\n}\n\nstatic __forceinline__ __device__ unsigned int optixGetSbtGASIndex()\n{\n    unsigned int u0;\n    asm( \"call (%0), _optix_read_sbt_gas_idx, ();\" : \"=r\"( u0 ) : );\n    return u0;\n}\n\nstatic __forceinline__ __device__ unsigned int optixGetInstanceId()\n{\n    unsigned int u0;\n    asm( \"call (%0), _optix_read_instance_id, ();\" : \"=r\"( u0 ) : );\n    return u0;\n}\n\nstatic __forceinline__ __device__ unsigned int optixGetInstanceIndex()\n{\n    unsigned int u0;\n    asm( \"call (%0), _optix_read_instance_idx, ();\" : \"=r\"( u0 ) : );\n    return u0;\n}\n\nstatic __forceinline__ __device__ unsigned int optixGetHitKind()\n{\n    unsigned int u0;\n    asm( \"call (%0), _optix_get_hit_kind, ();\" : \"=r\"( u0 ) : );\n    return u0;\n}\n\nstatic __forceinline__ __device__ OptixPrimitiveType optixGetPrimitiveType(unsigned int hitKind)\n{\n    unsigned int u0;\n    asm( \"call (%0), _optix_get_primitive_type_from_hit_kind, (%1);\" : \"=r\"( u0 ) : \"r\"( hitKind ) );\n    return (OptixPrimitiveType)u0;\n}\n\nstatic __forceinline__ __device__ bool optixIsBackFaceHit( unsigned int hitKind )\n{\n    unsigned int u0;\n    asm( \"call (%0), _optix_get_backface_from_hit_kind, (%1);\" : \"=r\"( u0 ) : \"r\"( hitKind ) );\n    return (u0 == 0x1);\n}\n\nstatic __forceinline__ __device__ bool optixIsFrontFaceHit( unsigned int hitKind )\n{\n    return !optixIsBackFaceHit( hitKind );\n}\n\n\nstatic __forceinline__ __device__ OptixPrimitiveType optixGetPrimitiveType()\n{\n    return optixGetPrimitiveType( optixGetHitKind() );\n}\n\nstatic __forceinline__ __device__ bool optixIsBackFaceHit()\n{\n    return optixIsBackFaceHit( optixGetHitKind() );\n}\n\nstatic __forceinline__ __device__ bool optixIsFrontFaceHit()\n{\n    return optixIsFrontFaceHit( optixGetHitKind() );\n}\n\nstatic __forceinline__ __device__ bool optixIsTriangleHit()\n{\n    return optixIsTriangleFrontFaceHit() || optixIsTriangleBackFaceHit();\n}\n\nstatic __forceinline__ __device__ bool optixIsTriangleFrontFaceHit()\n{\n    return optixGetHitKind() == OPTIX_HIT_KIND_TRIANGLE_FRONT_FACE;\n}\n\nstatic __forceinline__ __device__ bool optixIsTriangleBackFaceHit()\n{\n    return optixGetHitKind() == OPTIX_HIT_KIND_TRIANGLE_BACK_FACE;\n}\n\nstatic __forceinline__ __device__ float optixGetCurveParameter()\n{\n    return __int_as_float( optixGetAttribute_0() );\n}\n\nstatic __forceinline__ __device__ float2 optixGetTriangleBarycentrics()\n{\n    float f0, f1;\n    asm( \"call (%0, %1), _optix_get_triangle_barycentrics, ();\" : \"=f\"( f0 ), \"=f\"( f1 ) : );\n    return make_float2( f0, f1 );\n}\n\nstatic __forceinline__ __device__ uint3 optixGetLaunchIndex()\n{\n    unsigned int u0, u1, u2;\n    asm( \"call (%0), _optix_get_launch_index_x, ();\" : \"=r\"( u0 ) : );\n    asm( \"call (%0), _optix_get_launch_index_y, ();\" : \"=r\"( u1 ) : );\n    asm( \"call (%0), _optix_get_launch_index_z, ();\" : \"=r\"( u2 ) : );\n 
   return make_uint3( u0, u1, u2 );\n}\n\nstatic __forceinline__ __device__ uint3 optixGetLaunchDimensions()\n{\n    unsigned int u0, u1, u2;\n    asm( \"call (%0), _optix_get_launch_dimension_x, ();\" : \"=r\"( u0 ) : );\n    asm( \"call (%0), _optix_get_launch_dimension_y, ();\" : \"=r\"( u1 ) : );\n    asm( \"call (%0), _optix_get_launch_dimension_z, ();\" : \"=r\"( u2 ) : );\n    return make_uint3( u0, u1, u2 );\n}\n\nstatic __forceinline__ __device__ CUdeviceptr optixGetSbtDataPointer()\n{\n    unsigned long long ptr;\n    asm( \"call (%0), _optix_get_sbt_data_ptr_64, ();\" : \"=l\"( ptr ) : );\n    return (CUdeviceptr)ptr;\n}\n\nstatic __forceinline__ __device__ void optixThrowException( int exceptionCode )\n{\n    asm volatile(\n        \"call _optix_throw_exception_0, (%0);\"\n        : /* no return value */\n        : \"r\"( exceptionCode )\n        : );\n}\n\nstatic __forceinline__ __device__ void optixThrowException( int exceptionCode, unsigned int exceptionDetail0 )\n{\n    asm volatile(\n        \"call _optix_throw_exception_1, (%0, %1);\"\n        : /* no return value */\n        : \"r\"( exceptionCode ), \"r\"( exceptionDetail0 )\n        : );\n}\n\nstatic __forceinline__ __device__ void optixThrowException( int exceptionCode, unsigned int exceptionDetail0, unsigned int exceptionDetail1 )\n{\n    asm volatile(\n        \"call _optix_throw_exception_2, (%0, %1, %2);\"\n        : /* no return value */\n        : \"r\"( exceptionCode ), \"r\"( exceptionDetail0 ), \"r\"( exceptionDetail1 )\n        : );\n}\n\nstatic __forceinline__ __device__ void optixThrowException( int exceptionCode, unsigned int exceptionDetail0, unsigned int exceptionDetail1, unsigned int exceptionDetail2 )\n{\n    asm volatile(\n        \"call _optix_throw_exception_3, (%0, %1, %2, %3);\"\n        : /* no return value */\n        : \"r\"( exceptionCode ), \"r\"( exceptionDetail0 ), \"r\"( exceptionDetail1 ), \"r\"( exceptionDetail2 )\n        : );\n}\n\nstatic __forceinline__ __device__ void optixThrowException( int exceptionCode, unsigned int exceptionDetail0, unsigned int exceptionDetail1, unsigned int exceptionDetail2, unsigned int exceptionDetail3 )\n{\n    asm volatile(\n        \"call _optix_throw_exception_4, (%0, %1, %2, %3, %4);\"\n        : /* no return value */\n        : \"r\"( exceptionCode ), \"r\"( exceptionDetail0 ), \"r\"( exceptionDetail1 ), \"r\"( exceptionDetail2 ), \"r\"( exceptionDetail3 )\n        : );\n}\n\nstatic __forceinline__ __device__ void optixThrowException( int exceptionCode, unsigned int exceptionDetail0, unsigned int exceptionDetail1, unsigned int exceptionDetail2, unsigned int exceptionDetail3, unsigned int exceptionDetail4 )\n{\n    asm volatile(\n        \"call _optix_throw_exception_5, (%0, %1, %2, %3, %4, %5);\"\n        : /* no return value */\n        : \"r\"( exceptionCode ), \"r\"( exceptionDetail0 ), \"r\"( exceptionDetail1 ), \"r\"( exceptionDetail2 ), \"r\"( exceptionDetail3 ), \"r\"( exceptionDetail4 )\n        : );\n}\n\nstatic __forceinline__ __device__ void optixThrowException( int exceptionCode, unsigned int exceptionDetail0, unsigned int exceptionDetail1, unsigned int exceptionDetail2, unsigned int exceptionDetail3, unsigned int exceptionDetail4, unsigned int exceptionDetail5 )\n{\n    asm volatile(\n        \"call _optix_throw_exception_6, (%0, %1, %2, %3, %4, %5, %6);\"\n        : /* no return value */\n        : \"r\"( exceptionCode ), \"r\"( exceptionDetail0 ), \"r\"( exceptionDetail1 ), \"r\"( exceptionDetail2 ), \"r\"( exceptionDetail3 ), \"r\"( 
exceptionDetail4 ), \"r\"( exceptionDetail5 )\n        : );\n}\n\nstatic __forceinline__ __device__ void optixThrowException( int exceptionCode, unsigned int exceptionDetail0, unsigned int exceptionDetail1, unsigned int exceptionDetail2, unsigned int exceptionDetail3, unsigned int exceptionDetail4, unsigned int exceptionDetail5, unsigned int exceptionDetail6 )\n{\n    asm volatile(\n        \"call _optix_throw_exception_7, (%0, %1, %2, %3, %4, %5, %6, %7);\"\n        : /* no return value */\n        : \"r\"( exceptionCode ), \"r\"( exceptionDetail0 ), \"r\"( exceptionDetail1 ), \"r\"( exceptionDetail2 ), \"r\"( exceptionDetail3 ), \"r\"( exceptionDetail4 ), \"r\"( exceptionDetail5 ), \"r\"( exceptionDetail6 )\n        : );\n}\n\nstatic __forceinline__ __device__ void optixThrowException( int exceptionCode, unsigned int exceptionDetail0, unsigned int exceptionDetail1, unsigned int exceptionDetail2, unsigned int exceptionDetail3, unsigned int exceptionDetail4, unsigned int exceptionDetail5, unsigned int exceptionDetail6, unsigned int exceptionDetail7 )\n{\n    asm volatile(\n        \"call _optix_throw_exception_8, (%0, %1, %2, %3, %4, %5, %6, %7, %8);\"\n        : /* no return value */\n        : \"r\"( exceptionCode ), \"r\"( exceptionDetail0 ), \"r\"( exceptionDetail1 ), \"r\"( exceptionDetail2 ), \"r\"( exceptionDetail3 ), \"r\"( exceptionDetail4 ), \"r\"( exceptionDetail5 ), \"r\"( exceptionDetail6 ), \"r\"( exceptionDetail7 )\n        : );\n}\n\nstatic __forceinline__ __device__ int optixGetExceptionCode()\n{\n    int s0;\n    asm( \"call (%0), _optix_get_exception_code, ();\" : \"=r\"( s0 ) : );\n    return s0;\n}\n\n#define OPTIX_DEFINE_optixGetExceptionDetail_BODY( which )                                                             \\\n    unsigned int ret;                                                                                                  \\\n    asm( \"call (%0), _optix_get_exception_detail_\" #which \", ();\" : \"=r\"( ret ) : );                                   \\\n    return ret;\n\nstatic __forceinline__ __device__ unsigned int optixGetExceptionDetail_0()\n{\n    OPTIX_DEFINE_optixGetExceptionDetail_BODY( 0 );\n}\n\nstatic __forceinline__ __device__ unsigned int optixGetExceptionDetail_1()\n{\n    OPTIX_DEFINE_optixGetExceptionDetail_BODY( 1 );\n}\n\nstatic __forceinline__ __device__ unsigned int optixGetExceptionDetail_2()\n{\n    OPTIX_DEFINE_optixGetExceptionDetail_BODY( 2 );\n}\n\nstatic __forceinline__ __device__ unsigned int optixGetExceptionDetail_3()\n{\n    OPTIX_DEFINE_optixGetExceptionDetail_BODY( 3 );\n}\n\nstatic __forceinline__ __device__ unsigned int optixGetExceptionDetail_4()\n{\n    OPTIX_DEFINE_optixGetExceptionDetail_BODY( 4 );\n}\n\nstatic __forceinline__ __device__ unsigned int optixGetExceptionDetail_5()\n{\n    OPTIX_DEFINE_optixGetExceptionDetail_BODY( 5 );\n}\n\nstatic __forceinline__ __device__ unsigned int optixGetExceptionDetail_6()\n{\n    OPTIX_DEFINE_optixGetExceptionDetail_BODY( 6 );\n}\n\nstatic __forceinline__ __device__ unsigned int optixGetExceptionDetail_7()\n{\n    OPTIX_DEFINE_optixGetExceptionDetail_BODY( 7 );\n}\n\n#undef OPTIX_DEFINE_optixGetExceptionDetail_BODY\n\nstatic __forceinline__ __device__ OptixTraversableHandle optixGetExceptionInvalidTraversable()\n{\n    unsigned long long handle;\n    asm( \"call (%0), _optix_get_exception_invalid_traversable, ();\" : \"=l\"( handle ) : );\n    return (OptixTraversableHandle)handle;\n}\n\nstatic __forceinline__ __device__ int optixGetExceptionInvalidSbtOffset()\n{\n    int 
s0;\n    asm( \"call (%0), _optix_get_exception_invalid_sbt_offset, ();\" : \"=r\"( s0 ) : );\n    return s0;\n}\n\nstatic __forceinline__ __device__ OptixInvalidRayExceptionDetails optixGetExceptionInvalidRay()\n{\n    float rayOriginX, rayOriginY, rayOriginZ, rayDirectionX, rayDirectionY, rayDirectionZ, tmin, tmax, rayTime;\n    asm( \"call (%0, %1, %2, %3, %4, %5, %6, %7, %8), _optix_get_exception_invalid_ray, ();\"\n         : \"=f\"( rayOriginX ), \"=f\"( rayOriginY ), \"=f\"( rayOriginZ ), \"=f\"( rayDirectionX ), \"=f\"( rayDirectionY ),\n           \"=f\"( rayDirectionZ ), \"=f\"( tmin ), \"=f\"( tmax ), \"=f\"( rayTime )\n         : );\n    OptixInvalidRayExceptionDetails ray;\n    ray.origin    = make_float3( rayOriginX, rayOriginY, rayOriginZ );\n    ray.direction = make_float3( rayDirectionX, rayDirectionY, rayDirectionZ );\n    ray.tmin      = tmin;\n    ray.tmax      = tmax;\n    ray.time      = rayTime;\n    return ray;\n}\n\nstatic __forceinline__ __device__ OptixParameterMismatchExceptionDetails optixGetExceptionParameterMismatch()\n{\n    unsigned int expected, actual, sbtIdx;\n    unsigned long long calleeName;\n    asm(\n        \"call (%0, %1, %2, %3), _optix_get_exception_parameter_mismatch, ();\"\n        : \"=r\"(expected), \"=r\"(actual), \"=r\"(sbtIdx), \"=l\"(calleeName) : );\n    OptixParameterMismatchExceptionDetails details;\n    details.expectedParameterCount = expected;\n    details.passedArgumentCount = actual;\n    details.sbtIndex = sbtIdx;\n    details.callableName = (char*)calleeName;\n    return details;\n}\n\nstatic __forceinline__ __device__ char* optixGetExceptionLineInfo()\n{\n    unsigned long long ptr;\n    asm( \"call (%0), _optix_get_exception_line_info, ();\" : \"=l\"(ptr) : );\n    return (char*)ptr;\n}\n\ntemplate <typename ReturnT, typename... ArgTypes>\nstatic __forceinline__ __device__ ReturnT optixDirectCall( unsigned int sbtIndex, ArgTypes... args )\n{\n    unsigned long long func;\n    asm( \"call (%0), _optix_call_direct_callable,(%1);\" : \"=l\"( func ) : \"r\"( sbtIndex ) : );\n    using funcT = ReturnT ( * )( ArgTypes... );\n    funcT call  = ( funcT )( func );\n    return call( args... );\n}\n\ntemplate <typename ReturnT, typename... ArgTypes>\nstatic __forceinline__ __device__ ReturnT optixContinuationCall( unsigned int sbtIndex, ArgTypes... args )\n{\n    unsigned long long func;\n    asm( \"call (%0), _optix_call_continuation_callable,(%1);\" : \"=l\"( func ) : \"r\"( sbtIndex ) : );\n    using funcT = ReturnT ( * )( ArgTypes... );\n    funcT call  = ( funcT )( func );\n    return call( args... 
);\n}\n#endif\n\nstatic __forceinline__ __device__ uint4 optixTexFootprint2D( unsigned long long tex, unsigned int texInfo, float x, float y, unsigned int* singleMipLevel )\n{\n    uint4              result;\n    unsigned long long resultPtr         = reinterpret_cast<unsigned long long>( &result );\n    unsigned long long singleMipLevelPtr = reinterpret_cast<unsigned long long>( singleMipLevel );\n    // Cast float args to integers, because the intrinics take .b32 arguments when compiled to PTX.\n    asm volatile(\n        \"call _optix_tex_footprint_2d_v2\"\n        \", (%0, %1, %2, %3, %4, %5);\"\n        :\n        : \"l\"( tex ), \"r\"( texInfo ), \"r\"( __float_as_uint( x ) ), \"r\"( __float_as_uint( y ) ),\n          \"l\"( singleMipLevelPtr ), \"l\"( resultPtr )\n        : );\n    return result;\n}\n\nstatic __forceinline__ __device__ uint4 optixTexFootprint2DGrad( unsigned long long tex,\n                                                                 unsigned int       texInfo,\n                                                                 float              x,\n                                                                 float              y,\n                                                                 float              dPdx_x,\n                                                                 float              dPdx_y,\n                                                                 float              dPdy_x,\n                                                                 float              dPdy_y,\n                                                                 bool               coarse,\n                                                                 unsigned int*      singleMipLevel )\n{\n    uint4              result;\n    unsigned long long resultPtr         = reinterpret_cast<unsigned long long>( &result );\n    unsigned long long singleMipLevelPtr = reinterpret_cast<unsigned long long>( singleMipLevel );\n    // Cast float args to integers, because the intrinics take .b32 arguments when compiled to PTX.\n    asm volatile(\n        \"call _optix_tex_footprint_2d_grad_v2\"\n        \", (%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10);\"\n        :\n        : \"l\"( tex ), \"r\"( texInfo ), \"r\"( __float_as_uint( x ) ), \"r\"( __float_as_uint( y ) ),\n          \"r\"( __float_as_uint( dPdx_x ) ), \"r\"( __float_as_uint( dPdx_y ) ), \"r\"( __float_as_uint( dPdy_x ) ),\n          \"r\"( __float_as_uint( dPdy_y ) ), \"r\"( static_cast<unsigned int>( coarse ) ), \"l\"( singleMipLevelPtr ), \"l\"( resultPtr )\n        : );\n\n    return result;\n}\n\nstatic __forceinline__ __device__ uint4\noptixTexFootprint2DLod( unsigned long long tex, unsigned int texInfo, float x, float y, float level, bool coarse, unsigned int* singleMipLevel )\n{\n    uint4              result;\n    unsigned long long resultPtr         = reinterpret_cast<unsigned long long>( &result );\n    unsigned long long singleMipLevelPtr = reinterpret_cast<unsigned long long>( singleMipLevel );\n    // Cast float args to integers, because the intrinics take .b32 arguments when compiled to PTX.\n    asm volatile(\n        \"call _optix_tex_footprint_2d_lod_v2\"\n        \", (%0, %1, %2, %3, %4, %5, %6, %7);\"\n        :\n        : \"l\"( tex ), \"r\"( texInfo ), \"r\"( __float_as_uint( x ) ), \"r\"( __float_as_uint( y ) ),\n          \"r\"( __float_as_uint( level ) ), \"r\"( static_cast<unsigned int>( coarse ) ), \"l\"( singleMipLevelPtr ), \"l\"( resultPtr )\n        : );\n    return result;\n}\n"
  },
  {
    "path": "render/optixutils/include/internal/optix_7_device_impl_exception.h",
    "content": "/*\n* Copyright (c) 2021 NVIDIA Corporation.  All rights reserved.\n*\n* NVIDIA Corporation and its licensors retain all intellectual property and proprietary\n* rights in and to this software, related documentation and any modifications thereto.\n* Any use, reproduction, disclosure or distribution of this software and related\n* documentation without an express license agreement from NVIDIA Corporation is strictly\n* prohibited.\n*\n* TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, THIS SOFTWARE IS PROVIDED *AS IS*\n* AND NVIDIA AND ITS SUPPLIERS DISCLAIM ALL WARRANTIES, EITHER EXPRESS OR IMPLIED,\n* INCLUDING, BUT NOT LIMITED TO, IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A\n* PARTICULAR PURPOSE.  IN NO EVENT SHALL NVIDIA OR ITS SUPPLIERS BE LIABLE FOR ANY\n* SPECIAL, INCIDENTAL, INDIRECT, OR CONSEQUENTIAL DAMAGES WHATSOEVER (INCLUDING, WITHOUT\n* LIMITATION, DAMAGES FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF\n* BUSINESS INFORMATION, OR ANY OTHER PECUNIARY LOSS) ARISING OUT OF THE USE OF OR\n* INABILITY TO USE THIS SOFTWARE, EVEN IF NVIDIA HAS BEEN ADVISED OF THE POSSIBILITY OF\n* SUCH DAMAGES\n*/\n\n/**\n* @file   optix_7_device_impl_exception.h\n* @author NVIDIA Corporation\n* @brief  OptiX public API\n*\n* OptiX public API Reference - Device side implementation for exception helper function.\n*/\n\n#if !defined( __OPTIX_INCLUDE_INTERNAL_HEADERS__ )\n#error(\"optix_7_device_impl_exception.h is an internal header file and must not be used directly.  Please use optix_device.h or optix.h instead.\")\n#endif\n\n#ifndef __optix_optix_7_device_impl_exception_h__\n#define __optix_optix_7_device_impl_exception_h__\n\n#if !defined(__CUDACC_RTC__)\n#include <cstdio> /* for printf */\n#endif\n\nnamespace optix_impl {\n\n    static __forceinline__ __device__ void optixDumpStaticTransformFromHandle( OptixTraversableHandle handle )\n    {\n        const OptixStaticTransform* traversable = optixGetStaticTransformFromHandle( handle );\n        if( traversable )\n        {\n            const uint3 index = optixGetLaunchIndex();\n            printf( \"(%4i,%4i,%4i)     OptixStaticTransform@%p = {\\n\"\n                    \"                       child        = %p,\\n\"\n                    \"                       transform    = { %f,%f,%f,%f,\\n\"\n                    \"                                        %f,%f,%f,%f,\\n\"\n                    \"                                        %f,%f,%f,%f } }\\n\",\n                index.x,index.y,index.z,\n                traversable,\n                (void*)traversable->child,\n                traversable->transform[0], traversable->transform[1], traversable->transform[2], traversable->transform[3],\n                traversable->transform[4], traversable->transform[5], traversable->transform[6], traversable->transform[7],\n                traversable->transform[8], traversable->transform[9], traversable->transform[10], traversable->transform[11] );\n        }\n    }\n\n    static __forceinline__ __device__ void optixDumpMotionMatrixTransformFromHandle( OptixTraversableHandle handle )\n    {\n        const OptixMatrixMotionTransform* traversable =  optixGetMatrixMotionTransformFromHandle( handle );\n        if( traversable )\n        {\n            const uint3 index = optixGetLaunchIndex();\n            printf( \"(%4i,%4i,%4i)     OptixMatrixMotionTransform@%p = {\\n\"\n                    \"                       child         = %p,\\n\"\n                    \"                       motionOptions = { 
numKeys = %i, flags = %i, timeBegin = %f, timeEnd = %f },\\n\"\n                    \"                       transform     = { { %f,%f,%f,%f,\\n\"\n                    \"                                           %f,%f,%f,%f,\\n\"\n                    \"                                           %f,%f,%f,%f }, ... }\\n\",\n                index.x,index.y,index.z,\n                traversable,\n                (void*)traversable->child,\n                (int)traversable->motionOptions.numKeys, (int)traversable->motionOptions.flags, traversable->motionOptions.timeBegin, traversable->motionOptions.timeEnd,\n                traversable->transform[0][0], traversable->transform[0][1], traversable->transform[0][2],  traversable->transform[0][3],\n                traversable->transform[0][4], traversable->transform[0][5], traversable->transform[0][6],  traversable->transform[0][7],\n                traversable->transform[0][8], traversable->transform[0][9], traversable->transform[0][10], traversable->transform[0][11] );\n        }\n    }\n\n    static __forceinline__ __device__ void optixDumpSrtMatrixTransformFromHandle( OptixTraversableHandle handle )\n    {\n        const OptixSRTMotionTransform* traversable =  optixGetSRTMotionTransformFromHandle( handle );\n        if( traversable )\n        {\n            const uint3 index = optixGetLaunchIndex();\n            printf( \"(%4i,%4i,%4i)     OptixSRTMotionTransform@%p = {\\n\"\n                    \"                       child         = %p,\\n\"\n                    \"                       motionOptions = { numKeys = %i, flags = %i, timeBegin = %f, timeEnd = %f },\\n\"\n                    \"                       srtData       = { { sx  = %f,  a = %f,   b = %f, pvx = %f,\\n\"\n                    \"                                           sy  = %f,  c = %f, pvy = %f,  sz = %f,\\n\"\n                    \"                                           pvz = %f, qx = %f,  qy = %f,  qz = %f,\\n\"\n                    \"                                           qw  = %f, tx = %f,  ty = %f,  tz = %f }, ... 
}\\n\",\n                index.x,index.y,index.z,\n                traversable,\n                (void*)traversable->child,\n                (int)traversable->motionOptions.numKeys, (int)traversable->motionOptions.flags, traversable->motionOptions.timeBegin, traversable->motionOptions.timeEnd,\n                traversable->srtData[0].sx, traversable->srtData[0].a, traversable->srtData[0].b,  traversable->srtData[0].pvx,\n                traversable->srtData[0].sy, traversable->srtData[0].c, traversable->srtData[0].pvy,traversable->srtData[0].sz,\n                traversable->srtData[0].pvz,traversable->srtData[0].qx,traversable->srtData[0].qy, traversable->srtData[0].qz,\n                traversable->srtData[0].qw, traversable->srtData[0].tx,traversable->srtData[0].ty, traversable->srtData[0].tz );\n        }\n    }\n\n    static __forceinline__ __device__ void optixDumpInstanceFromHandle( OptixTraversableHandle handle )\n    {\n        if( optixGetTransformTypeFromHandle( handle ) == OPTIX_TRANSFORM_TYPE_INSTANCE )\n        {\n            unsigned int instanceId = optixGetInstanceIdFromHandle( handle );\n            const float4* transform = optixGetInstanceTransformFromHandle( handle );\n\n            const uint3 index = optixGetLaunchIndex();\n            printf( \"(%4i,%4i,%4i)     OptixInstance = {\\n\"\n                    \"                       instanceId = %i,\\n\"\n                    \"                       transform  = { %f,%f,%f,%f,\\n\"\n                    \"                                      %f,%f,%f,%f,\\n\"\n                    \"                                      %f,%f,%f,%f } }\\n\",\n                index.x,index.y,index.z,\n                instanceId,\n                transform[0].x, transform[0].y, transform[0].z,  transform[0].w,\n                transform[1].x, transform[1].y, transform[1].z,  transform[1].w,\n                transform[2].x, transform[2].y, transform[2].z,  transform[2].w );\n        }\n    }\n\n    static __forceinline__ __device__ void optixDumpTransform( OptixTraversableHandle handle )\n    {\n        const OptixTransformType type = optixGetTransformTypeFromHandle( handle );\n        const uint3 index = optixGetLaunchIndex();\n\n        switch( type )\n        {\n            case OPTIX_TRANSFORM_TYPE_NONE:\n                break;\n            case OPTIX_TRANSFORM_TYPE_STATIC_TRANSFORM:\n                optixDumpStaticTransformFromHandle( handle );\n                break;\n            case OPTIX_TRANSFORM_TYPE_MATRIX_MOTION_TRANSFORM:\n                optixDumpMotionMatrixTransformFromHandle( handle );\n                break;\n            case OPTIX_TRANSFORM_TYPE_SRT_MOTION_TRANSFORM:\n                optixDumpSrtMatrixTransformFromHandle( handle );\n                break;\n            case OPTIX_TRANSFORM_TYPE_INSTANCE:\n                optixDumpInstanceFromHandle( handle );\n                break;\n            default:\n                break;\n        }\n    }\n\n    static __forceinline__ __device__ void optixDumpTransformList()\n    {\n        const int tlistSize = optixGetTransformListSize();\n        const uint3 index = optixGetLaunchIndex();\n\n        printf(\"(%4i,%4i,%4i) transform list of size %i:\\n\", index.x,index.y,index.z, tlistSize);\n\n        for( unsigned int i = 0 ; i < tlistSize ; ++i )\n        {\n            OptixTraversableHandle handle = optixGetTransformListHandle( i );\n            printf(\"(%4i,%4i,%4i)   transform[%i] = %p\\n\", index.x, index.y, index.z, i, (void*)handle);\n            
optixDumpTransform(handle);\n        }\n    }\n\n    static __forceinline__ __device__ void optixDumpExceptionDetails()\n    {\n        bool dumpTlist = false;\n        const int exceptionCode = optixGetExceptionCode();\n        const uint3 index = optixGetLaunchIndex();\n\n        if( exceptionCode == OPTIX_EXCEPTION_CODE_STACK_OVERFLOW )\n        {\n            printf(\"(%4i,%4i,%4i) error: stack overflow\\n\", index.x,index.y,index.z);\n        }\n        else if( exceptionCode == OPTIX_EXCEPTION_CODE_TRACE_DEPTH_EXCEEDED )\n        {\n            printf(\"(%4i,%4i,%4i) error: trace depth exceeded\\n\", index.x,index.y,index.z);\n        }\n        else if( exceptionCode == OPTIX_EXCEPTION_CODE_TRAVERSAL_DEPTH_EXCEEDED )\n        {\n            printf(\"(%4i,%4i,%4i) error: traversal depth exceeded\\n\", index.x,index.y,index.z);\n            dumpTlist = true;\n        }\n        else if( exceptionCode == OPTIX_EXCEPTION_CODE_TRAVERSAL_INVALID_TRAVERSABLE )\n        {\n            OptixTraversableHandle handle = optixGetExceptionInvalidTraversable();\n            printf(\"(%4i,%4i,%4i) error: invalid traversable %p\\n\", index.x,index.y,index.z, (void*)handle);\n            dumpTlist = true;\n        }\n        else if( exceptionCode == OPTIX_EXCEPTION_CODE_TRAVERSAL_INVALID_MISS_SBT )\n        {\n            int sbtOffset = optixGetExceptionInvalidSbtOffset();\n            printf(\"(%4i,%4i,%4i) error: invalid miss sbt of %i\\n\", index.x,index.y,index.z, sbtOffset);\n        }\n        else if( exceptionCode == OPTIX_EXCEPTION_CODE_TRAVERSAL_INVALID_HIT_SBT )\n        {\n            int sbtOffset = optixGetExceptionInvalidSbtOffset();\n            printf(\"(%4i,%4i,%4i) error: invalid hit sbt of %i at primitive with gas sbt index %i\\n\", index.x,index.y,index.z, sbtOffset, optixGetSbtGASIndex() );\n            dumpTlist = true;\n        }\n        else if( exceptionCode == OPTIX_EXCEPTION_CODE_UNSUPPORTED_PRIMITIVE_TYPE )\n        {\n            dumpTlist = true;\n            printf( \"(%4i,%4i,%4i) error: shader encountered unsupported builtin type\\n\"\n                    \"       call location:   %s\\n\", index.x, index.y, index.z, optixGetExceptionLineInfo() );\n        }\n        else if( exceptionCode == OPTIX_EXCEPTION_CODE_INVALID_RAY )\n        {\n            OptixInvalidRayExceptionDetails ray = optixGetExceptionInvalidRay();\n            printf( \"(%4i,%4i,%4i) error: encountered ray with nan or inf values:\\n\", index.x, index.y, index.z );\n            printf(\n                \"       origin:          [%f, %f, %f]\\n\"\n                \"       direction:       [%f, %f, %f]\\n\"\n                \"       tmin:            %f\\n\"\n                \"       tmax:            %f\\n\"\n                \"       rayTime:         %f\\n\"\n                \"       call location:   %s\\n\",\n                ray.origin.x, ray.origin.y, ray.origin.z, ray.direction.x, ray.direction.y,\n                ray.direction.z, ray.tmin, ray.tmax, ray.time, optixGetExceptionLineInfo() );\n        }\n        else if( exceptionCode == OPTIX_EXCEPTION_CODE_CALLABLE_PARAMETER_MISMATCH )\n        {\n             OptixParameterMismatchExceptionDetails details = optixGetExceptionParameterMismatch();\n             printf( \"(%4i,%4i,%4i) error: parameter mismatch in callable call.\\n\", index.x, index.y, index.z );\n             printf(\n                \"       passed packed arguments:       %u 32 Bit values\\n\"\n                \"       expected packed parameters:    %u 32 Bit values\\n\"\n        
        \"       SBT index:                     %u\\n\"\n                \"       called function:               %s\\n\"\n                \"       call location:                 %s\\n\",\n                details.passedArgumentCount, details.expectedParameterCount, details.sbtIndex,\n                details.callableName, optixGetExceptionLineInfo() );\n        }\n        else if( exceptionCode == OPTIX_EXCEPTION_CODE_BUILTIN_IS_MISMATCH )\n        {\n            dumpTlist = true;\n            printf(\"(%4i,%4i,%4i) error: mismatch between builtin IS shader and build input\\n\"\n                   \"       call location:   %s\\n\", index.x,index.y,index.z, optixGetExceptionLineInfo() );\n        }\n        else if( exceptionCode == OPTIX_EXCEPTION_CODE_CALLABLE_INVALID_SBT )\n        {\n            int sbtOffset = optixGetExceptionInvalidSbtOffset();\n            printf( \"(%4i,%4i,%4i) error: invalid sbt offset of %i for callable program\\n\", index.x, index.y, index.z, sbtOffset );\n        }\n        else if( exceptionCode == OPTIX_EXCEPTION_CODE_CALLABLE_NO_DC_SBT_RECORD )\n        {\n            int sbtOffset = optixGetExceptionInvalidSbtOffset();\n            printf( \"(%4i,%4i,%4i) error: invalid sbt offset of %i for direct callable program\\n\", index.x, index.y, index.z, sbtOffset );\n        }\n        else if( exceptionCode == OPTIX_EXCEPTION_CODE_CALLABLE_NO_CC_SBT_RECORD )\n        {\n            int sbtOffset = optixGetExceptionInvalidSbtOffset();\n            printf( \"(%4i,%4i,%4i) error: invalid sbt offset of %i for continuation callable program\\n\", index.x, index.y, index.z, sbtOffset );\n        }\n        else if( exceptionCode == OPTIX_EXCEPTION_CODE_UNSUPPORTED_SINGLE_LEVEL_GAS )\n        {\n            OptixTraversableHandle handle = optixGetExceptionInvalidTraversable();\n            printf(\"(%4i,%4i,%4i) error: unsupported single GAS traversable graph %p\\n\", index.x,index.y,index.z, (void*)handle);\n            dumpTlist = true;\n        }\n        else if( ( exceptionCode <= OPTIX_EXCEPTION_CODE_INVALID_VALUE_ARGUMENT_0 ) && ( exceptionCode >= OPTIX_EXCEPTION_CODE_INVALID_VALUE_ARGUMENT_2 ) )\n        {\n            printf(\"(%4i,%4i,%4i) error: invalid value for argument %i\\n\", index.x,index.y,index.z, -(exceptionCode - OPTIX_EXCEPTION_CODE_INVALID_VALUE_ARGUMENT_0) );\n        }\n        else if( exceptionCode == OPTIX_EXCEPTION_CODE_UNSUPPORTED_DATA_ACCESS )\n        {\n            printf(\"(%4i,%4i,%4i) error: unsupported random data access\\n\", index.x,index.y,index.z);\n        }\n        else if( exceptionCode >= 0 )\n        {\n            dumpTlist = true;\n            printf( \"(%4i,%4i,%4i) error: user exception with error code %i\\n\"\n                    \"       call location:   %s\\n\", index.x, index.y, index.z, exceptionCode, optixGetExceptionLineInfo() );\n        }\n        else\n        {\n            printf(\"(%4i,%4i,%4i) error: unknown exception with error code %i\\n\", index.x,index.y,index.z, exceptionCode);\n        }\n\n        if( dumpTlist )\n            optixDumpTransformList();\n    }\n\n}  // namespace optix_impl\n\n#endif\n"
  },
  {
    "path": "render/optixutils/include/internal/optix_7_device_impl_transformations.h",
    "content": "/*\n* Copyright (c) 2021 NVIDIA Corporation.  All rights reserved.\n*\n* NVIDIA Corporation and its licensors retain all intellectual property and proprietary\n* rights in and to this software, related documentation and any modifications thereto.\n* Any use, reproduction, disclosure or distribution of this software and related\n* documentation without an express license agreement from NVIDIA Corporation is strictly\n* prohibited.\n*\n* TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, THIS SOFTWARE IS PROVIDED *AS IS*\n* AND NVIDIA AND ITS SUPPLIERS DISCLAIM ALL WARRANTIES, EITHER EXPRESS OR IMPLIED,\n* INCLUDING, BUT NOT LIMITED TO, IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A\n* PARTICULAR PURPOSE.  IN NO EVENT SHALL NVIDIA OR ITS SUPPLIERS BE LIABLE FOR ANY\n* SPECIAL, INCIDENTAL, INDIRECT, OR CONSEQUENTIAL DAMAGES WHATSOEVER (INCLUDING, WITHOUT\n* LIMITATION, DAMAGES FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF\n* BUSINESS INFORMATION, OR ANY OTHER PECUNIARY LOSS) ARISING OUT OF THE USE OF OR\n* INABILITY TO USE THIS SOFTWARE, EVEN IF NVIDIA HAS BEEN ADVISED OF THE POSSIBILITY OF\n* SUCH DAMAGES\n*/\n\n/**\n* @file   optix_7_device_impl_transformations.h\n* @author NVIDIA Corporation\n* @brief  OptiX public API\n*\n* OptiX public API Reference - Device side implementation for transformation helper functions.\n*/\n\n#if !defined( __OPTIX_INCLUDE_INTERNAL_HEADERS__ )\n#error(\"optix_7_device_impl_transformations.h is an internal header file and must not be used directly.  Please use optix_device.h or optix.h instead.\")\n#endif\n\n#ifndef __optix_optix_7_device_impl_transformations_h__\n#define __optix_optix_7_device_impl_transformations_h__\n\nnamespace optix_impl {\n\nstatic __forceinline__ __device__ float4 optixAddFloat4( const float4& a, const float4& b )\n{\n    return make_float4( a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w );\n}\n\nstatic __forceinline__ __device__ float4 optixMulFloat4( const float4& a, float b )\n{\n    return make_float4( a.x * b, a.y * b, a.z * b, a.w * b );\n}\n\nstatic __forceinline__ __device__ uint4 optixLdg( unsigned long long addr )\n{\n    const uint4* ptr;\n    asm volatile( \"cvta.to.global.u64 %0, %1;\" : \"=l\"( ptr ) : \"l\"( addr ) );\n    uint4 ret;\n    asm volatile( \"ld.global.v4.u32 {%0,%1,%2,%3}, [%4];\"\n                  : \"=r\"( ret.x ), \"=r\"( ret.y ), \"=r\"( ret.z ), \"=r\"( ret.w )\n                  : \"l\"( ptr ) );\n    return ret;\n}\n\ntemplate <class T>\nstatic __forceinline__ __device__ T optixLoadReadOnlyAlign16( const T* ptr )\n{\n    T v;\n    for( int ofs                     = 0; ofs < sizeof( T ); ofs += 16 )\n        *(uint4*)( (char*)&v + ofs ) = optixLdg( (unsigned long long)( (char*)ptr + ofs ) );\n    return v;\n}\n\n// Multiplies the row vector vec with the 3x4 matrix with rows m0, m1, and m2\nstatic __forceinline__ __device__ float4 optixMultiplyRowMatrix( const float4 vec, const float4 m0, const float4 m1, const float4 m2 )\n{\n    float4 result;\n\n    result.x = vec.x * m0.x + vec.y * m1.x + vec.z * m2.x;\n    result.y = vec.x * m0.y + vec.y * m1.y + vec.z * m2.y;\n    result.z = vec.x * m0.z + vec.y * m1.z + vec.z * m2.z;\n    result.w = vec.x * m0.w + vec.y * m1.w + vec.z * m2.w + vec.w;\n\n    return result;\n}\n\n// Converts the SRT transformation srt into a 3x4 matrix with rows m0, m1, and m2\nstatic __forceinline__ __device__ void optixGetMatrixFromSrt( float4& m0, float4& m1, float4& m2, const OptixSRTData& srt )\n{\n    const float4 q = {srt.qx, srt.qy, 
srt.qz, srt.qw};\n\n    // normalize\n    const float  inv_sql = 1.f / ( srt.qx * srt.qx + srt.qy * srt.qy + srt.qz * srt.qz + srt.qw * srt.qw );\n    const float4 nq      = optixMulFloat4( q, inv_sql );\n\n    const float sqw = q.w * nq.w;\n    const float sqx = q.x * nq.x;\n    const float sqy = q.y * nq.y;\n    const float sqz = q.z * nq.z;\n\n    const float xy = q.x * nq.y;\n    const float zw = q.z * nq.w;\n    const float xz = q.x * nq.z;\n    const float yw = q.y * nq.w;\n    const float yz = q.y * nq.z;\n    const float xw = q.x * nq.w;\n\n    m0.x = ( sqx - sqy - sqz + sqw );\n    m0.y = 2.0f * ( xy - zw );\n    m0.z = 2.0f * ( xz + yw );\n\n    m1.x = 2.0f * ( xy + zw );\n    m1.y = ( -sqx + sqy - sqz + sqw );\n    m1.z = 2.0f * ( yz - xw );\n\n    m2.x = 2.0f * ( xz - yw );\n    m2.y = 2.0f * ( yz + xw );\n    m2.z = ( -sqx - sqy + sqz + sqw );\n\n    m0.w = m0.x * srt.pvx + m0.y * srt.pvy + m0.z * srt.pvz + srt.tx;\n    m1.w = m1.x * srt.pvx + m1.y * srt.pvy + m1.z * srt.pvz + srt.ty;\n    m2.w = m2.x * srt.pvx + m2.y * srt.pvy + m2.z * srt.pvz + srt.tz;\n\n    m0.z = m0.x * srt.b + m0.y * srt.c + m0.z * srt.sz;\n    m1.z = m1.x * srt.b + m1.y * srt.c + m1.z * srt.sz;\n    m2.z = m2.x * srt.b + m2.y * srt.c + m2.z * srt.sz;\n\n    m0.y = m0.x * srt.a + m0.y * srt.sy;\n    m1.y = m1.x * srt.a + m1.y * srt.sy;\n    m2.y = m2.x * srt.a + m2.y * srt.sy;\n\n    m0.x = m0.x * srt.sx;\n    m1.x = m1.x * srt.sx;\n    m2.x = m2.x * srt.sx;\n}\n\n// Inverts a 3x4 matrix in place\nstatic __forceinline__ __device__ void optixInvertMatrix( float4& m0, float4& m1, float4& m2 )\n{\n    const float det3 =\n        m0.x * ( m1.y * m2.z - m1.z * m2.y ) - m0.y * ( m1.x * m2.z - m1.z * m2.x ) + m0.z * ( m1.x * m2.y - m1.y * m2.x );\n\n    const float inv_det3 = 1.0f / det3;\n\n    float inv3[3][3];\n    inv3[0][0] = inv_det3 * ( m1.y * m2.z - m2.y * m1.z );\n    inv3[0][1] = inv_det3 * ( m0.z * m2.y - m2.z * m0.y );\n    inv3[0][2] = inv_det3 * ( m0.y * m1.z - m1.y * m0.z );\n\n    inv3[1][0] = inv_det3 * ( m1.z * m2.x - m2.z * m1.x );\n    inv3[1][1] = inv_det3 * ( m0.x * m2.z - m2.x * m0.z );\n    inv3[1][2] = inv_det3 * ( m0.z * m1.x - m1.z * m0.x );\n\n    inv3[2][0] = inv_det3 * ( m1.x * m2.y - m2.x * m1.y );\n    inv3[2][1] = inv_det3 * ( m0.y * m2.x - m2.y * m0.x );\n    inv3[2][2] = inv_det3 * ( m0.x * m1.y - m1.x * m0.y );\n\n    const float b[3] = {m0.w, m1.w, m2.w};\n\n    m0.x = inv3[0][0];\n    m0.y = inv3[0][1];\n    m0.z = inv3[0][2];\n    m0.w = -inv3[0][0] * b[0] - inv3[0][1] * b[1] - inv3[0][2] * b[2];\n\n    m1.x = inv3[1][0];\n    m1.y = inv3[1][1];\n    m1.z = inv3[1][2];\n    m1.w = -inv3[1][0] * b[0] - inv3[1][1] * b[1] - inv3[1][2] * b[2];\n\n    m2.x = inv3[2][0];\n    m2.y = inv3[2][1];\n    m2.z = inv3[2][2];\n    m2.w = -inv3[2][0] * b[0] - inv3[2][1] * b[1] - inv3[2][2] * b[2];\n}\n\nstatic __forceinline__ __device__ void optixLoadInterpolatedMatrixKey( float4& m0, float4& m1, float4& m2, const float4* matrix, const float t1 )\n{\n    m0 = optixLoadReadOnlyAlign16( &matrix[0] );\n    m1 = optixLoadReadOnlyAlign16( &matrix[1] );\n    m2 = optixLoadReadOnlyAlign16( &matrix[2] );\n\n    // The conditional prevents concurrent loads leading to spills\n    if( t1 > 0.0f )\n    {\n        const float t0 = 1.0f - t1;\n        m0 = optixAddFloat4( optixMulFloat4( m0, t0 ), optixMulFloat4( optixLoadReadOnlyAlign16( &matrix[3] ), t1 ) );\n        m1 = optixAddFloat4( optixMulFloat4( m1, t0 ), optixMulFloat4( optixLoadReadOnlyAlign16( &matrix[4] ), t1 ) );\n        m2 = 
optixAddFloat4( optixMulFloat4( m2, t0 ), optixMulFloat4( optixLoadReadOnlyAlign16( &matrix[5] ), t1 ) );\n    }\n}\n\nstatic __forceinline__ __device__ void optixLoadInterpolatedSrtKey( float4&       srt0,\n                                                                    float4&       srt1,\n                                                                    float4&       srt2,\n                                                                    float4&       srt3,\n                                                                    const float4* srt,\n                                                                    const float   t1 )\n{\n    srt0 = optixLoadReadOnlyAlign16( &srt[0] );\n    srt1 = optixLoadReadOnlyAlign16( &srt[1] );\n    srt2 = optixLoadReadOnlyAlign16( &srt[2] );\n    srt3 = optixLoadReadOnlyAlign16( &srt[3] );\n\n    // The conditional prevents concurrent loads leading to spills\n    if( t1 > 0.0f )\n    {\n        const float t0 = 1.0f - t1;\n        srt0 = optixAddFloat4( optixMulFloat4( srt0, t0 ), optixMulFloat4( optixLoadReadOnlyAlign16( &srt[4] ), t1 ) );\n        srt1 = optixAddFloat4( optixMulFloat4( srt1, t0 ), optixMulFloat4( optixLoadReadOnlyAlign16( &srt[5] ), t1 ) );\n        srt2 = optixAddFloat4( optixMulFloat4( srt2, t0 ), optixMulFloat4( optixLoadReadOnlyAlign16( &srt[6] ), t1 ) );\n        srt3 = optixAddFloat4( optixMulFloat4( srt3, t0 ), optixMulFloat4( optixLoadReadOnlyAlign16( &srt[7] ), t1 ) );\n\n        float inv_length = 1.f / sqrt( srt2.y * srt2.y + srt2.z * srt2.z + srt2.w * srt2.w + srt3.x * srt3.x );\n        srt2.y *= inv_length;\n        srt2.z *= inv_length;\n        srt2.w *= inv_length;\n        srt3.x *= inv_length;\n    }\n}\n\nstatic __forceinline__ __device__ void optixResolveMotionKey( float& localt, int& key, const OptixMotionOptions& options, const float globalt )\n{\n    const float timeBegin    = options.timeBegin;\n    const float timeEnd      = options.timeEnd;\n    const float numIntervals = (float)( options.numKeys - 1 );\n\n    // No need to check the motion flags. 
If data originates from a valid transform list handle, then globalt is in\n    // range, or vanish flags are not set.\n\n    const float time = max( 0.f, min( numIntervals, ( globalt - timeBegin ) * numIntervals / ( timeEnd - timeBegin ) ) );\n    const float fltKey = floorf( time );\n\n    localt = time - fltKey;\n    key    = (int)fltKey;\n}\n\n// Returns the interpolated transformation matrix for a particular matrix motion transformation and point in time.\nstatic __forceinline__ __device__ void optixGetInterpolatedTransformation( float4&                           trf0,\n                                                                           float4&                           trf1,\n                                                                           float4&                           trf2,\n                                                                           const OptixMatrixMotionTransform* transformData,\n                                                                           const float                       time )\n{\n    // Compute key and intra key time\n    float keyTime;\n    int   key;\n    optixResolveMotionKey( keyTime, key, optixLoadReadOnlyAlign16( transformData ).motionOptions, time );\n\n    // Get pointer to left key\n    const float4* transform = (const float4*)( &transformData->transform[key][0] );\n\n    // Load and interpolate matrix keys\n    optixLoadInterpolatedMatrixKey( trf0, trf1, trf2, transform, keyTime );\n}\n\n// Returns the interpolated transformation matrix for a particular SRT motion transformation and point in time.\nstatic __forceinline__ __device__ void optixGetInterpolatedTransformation( float4&                        trf0,\n                                                                           float4&                        trf1,\n                                                                           float4&                        trf2,\n                                                                           const OptixSRTMotionTransform* transformData,\n                                                                           const float                    time )\n{\n    // Compute key and intra key time\n    float keyTime;\n    int   key;\n    optixResolveMotionKey( keyTime, key, optixLoadReadOnlyAlign16( transformData ).motionOptions, time );\n\n    // Get pointer to left key\n    const float4* dataPtr = reinterpret_cast<const float4*>( &transformData->srtData[key] );\n\n    // Load and interpolated SRT keys\n    float4 data[4];\n    optixLoadInterpolatedSrtKey( data[0], data[1], data[2], data[3], dataPtr, keyTime );\n\n    OptixSRTData srt = {data[0].x, data[0].y, data[0].z, data[0].w, data[1].x, data[1].y, data[1].z, data[1].w,\n                        data[2].x, data[2].y, data[2].z, data[2].w, data[3].x, data[3].y, data[3].z, data[3].w};\n\n    // Convert SRT into a matrix\n    optixGetMatrixFromSrt( trf0, trf1, trf2, srt );\n}\n\n// Returns the interpolated transformation matrix for a particular traversable handle and point in time.\nstatic __forceinline__ __device__ void optixGetInterpolatedTransformationFromHandle( float4&                      trf0,\n                                                                                     float4&                      trf1,\n                                                                                     float4&                      trf2,\n                                                                                     const OptixTraversableHandle handle,\n   
                                                                                  const float                  time,\n                                                                                     const bool objectToWorld )\n{\n    const OptixTransformType type = optixGetTransformTypeFromHandle( handle );\n\n    if( type == OPTIX_TRANSFORM_TYPE_MATRIX_MOTION_TRANSFORM || type == OPTIX_TRANSFORM_TYPE_SRT_MOTION_TRANSFORM )\n    {\n        if( type == OPTIX_TRANSFORM_TYPE_MATRIX_MOTION_TRANSFORM )\n        {\n            const OptixMatrixMotionTransform* transformData = optixGetMatrixMotionTransformFromHandle( handle );\n            optixGetInterpolatedTransformation( trf0, trf1, trf2, transformData, time );\n        }\n        else\n        {\n            const OptixSRTMotionTransform* transformData = optixGetSRTMotionTransformFromHandle( handle );\n            optixGetInterpolatedTransformation( trf0, trf1, trf2, transformData, time );\n        }\n\n        if( !objectToWorld )\n            optixInvertMatrix( trf0, trf1, trf2 );\n    }\n    else if( type == OPTIX_TRANSFORM_TYPE_INSTANCE || type == OPTIX_TRANSFORM_TYPE_STATIC_TRANSFORM )\n    {\n        const float4* transform;\n\n        if( type == OPTIX_TRANSFORM_TYPE_INSTANCE )\n        {\n            transform = ( objectToWorld ) ? optixGetInstanceTransformFromHandle( handle ) :\n                                            optixGetInstanceInverseTransformFromHandle( handle );\n        }\n        else\n        {\n            const OptixStaticTransform* traversable = optixGetStaticTransformFromHandle( handle );\n            transform = (const float4*)( ( objectToWorld ) ? traversable->transform : traversable->invTransform );\n        }\n\n        trf0 = optixLoadReadOnlyAlign16( &transform[0] );\n        trf1 = optixLoadReadOnlyAlign16( &transform[1] );\n        trf2 = optixLoadReadOnlyAlign16( &transform[2] );\n    }\n    else\n    {\n        trf0 = {1.0f, 0.0f, 0.0f, 0.0f};\n        trf1 = {0.0f, 1.0f, 0.0f, 0.0f};\n        trf2 = {0.0f, 0.0f, 1.0f, 0.0f};\n    }\n}\n\n// Returns the world-to-object transformation matrix resulting from the current transform stack and current ray time.\nstatic __forceinline__ __device__ void optixGetWorldToObjectTransformMatrix( float4& m0, float4& m1, float4& m2 )\n{\n    const unsigned int size = optixGetTransformListSize();\n    const float        time = optixGetRayTime();\n\n#pragma unroll 1\n    for( unsigned int i = 0; i < size; ++i )\n    {\n        OptixTraversableHandle handle = optixGetTransformListHandle( i );\n\n        float4 trf0, trf1, trf2;\n        optixGetInterpolatedTransformationFromHandle( trf0, trf1, trf2, handle, time, /*objectToWorld*/ false );\n\n        if( i == 0 )\n        {\n            m0 = trf0;\n            m1 = trf1;\n            m2 = trf2;\n        }\n        else\n        {\n            // m := trf * m\n            float4 tmp0 = m0, tmp1 = m1, tmp2 = m2;\n            m0 = optixMultiplyRowMatrix( trf0, tmp0, tmp1, tmp2 );\n            m1 = optixMultiplyRowMatrix( trf1, tmp0, tmp1, tmp2 );\n            m2 = optixMultiplyRowMatrix( trf2, tmp0, tmp1, tmp2 );\n        }\n    }\n}\n\n// Returns the object-to-world transformation matrix resulting from the current transform stack and current ray time.\nstatic __forceinline__ __device__ void optixGetObjectToWorldTransformMatrix( float4& m0, float4& m1, float4& m2 )\n{\n    const int   size = optixGetTransformListSize();\n    const float time = optixGetRayTime();\n\n#pragma unroll 1\n    for( int i = size - 1; i >= 0; --i 
)\n    {\n        OptixTraversableHandle handle = optixGetTransformListHandle( i );\n\n        float4 trf0, trf1, trf2;\n        optixGetInterpolatedTransformationFromHandle( trf0, trf1, trf2, handle, time, /*objectToWorld*/ true );\n\n        if( i == size - 1 )\n        {\n            m0 = trf0;\n            m1 = trf1;\n            m2 = trf2;\n        }\n        else\n        {\n            // m := trf * m\n            float4 tmp0 = m0, tmp1 = m1, tmp2 = m2;\n            m0 = optixMultiplyRowMatrix( trf0, tmp0, tmp1, tmp2 );\n            m1 = optixMultiplyRowMatrix( trf1, tmp0, tmp1, tmp2 );\n            m2 = optixMultiplyRowMatrix( trf2, tmp0, tmp1, tmp2 );\n        }\n    }\n}\n\n// Multiplies the 3x4 matrix with rows m0, m1, m2 with the point p.\nstatic __forceinline__ __device__ float3 optixTransformPoint( const float4& m0, const float4& m1, const float4& m2, const float3& p )\n{\n    float3 result;\n    result.x = m0.x * p.x + m0.y * p.y + m0.z * p.z + m0.w;\n    result.y = m1.x * p.x + m1.y * p.y + m1.z * p.z + m1.w;\n    result.z = m2.x * p.x + m2.y * p.y + m2.z * p.z + m2.w;\n    return result;\n}\n\n// Multiplies the 3x3 linear submatrix of the 3x4 matrix with rows m0, m1, m2 with the vector v.\nstatic __forceinline__ __device__ float3 optixTransformVector( const float4& m0, const float4& m1, const float4& m2, const float3& v )\n{\n    float3 result;\n    result.x = m0.x * v.x + m0.y * v.y + m0.z * v.z;\n    result.y = m1.x * v.x + m1.y * v.y + m1.z * v.z;\n    result.z = m2.x * v.x + m2.y * v.y + m2.z * v.z;\n    return result;\n}\n\n// Multiplies the transpose of the 3x3 linear submatrix of the 3x4 matrix with rows m0, m1, m2 with the normal n.\n// Note that the given matrix is supposed to be the inverse of the actual transformation matrix.\nstatic __forceinline__ __device__ float3 optixTransformNormal( const float4& m0, const float4& m1, const float4& m2, const float3& n )\n{\n    float3 result;\n    result.x = m0.x * n.x + m1.x * n.y + m2.x * n.z;\n    result.y = m0.y * n.x + m1.y * n.y + m2.y * n.z;\n    result.z = m0.z * n.x + m1.z * n.y + m2.z * n.z;\n    return result;\n}\n\n}  // namespace optix_impl\n\n#endif\n"
  },
  {
    "path": "render/optixutils/include/optix.h",
    "content": "\n/*\n * Copyright (c) 2021 NVIDIA Corporation.  All rights reserved.\n *\n * NVIDIA Corporation and its licensors retain all intellectual property and proprietary\n * rights in and to this software, related documentation and any modifications thereto.\n * Any use, reproduction, disclosure or distribution of this software and related\n * documentation without an express license agreement from NVIDIA Corporation is strictly\n * prohibited.\n *\n * TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, THIS SOFTWARE IS PROVIDED *AS IS*\n * AND NVIDIA AND ITS SUPPLIERS DISCLAIM ALL WARRANTIES, EITHER EXPRESS OR IMPLIED,\n * INCLUDING, BUT NOT LIMITED TO, IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A\n * PARTICULAR PURPOSE.  IN NO EVENT SHALL NVIDIA OR ITS SUPPLIERS BE LIABLE FOR ANY\n * SPECIAL, INCIDENTAL, INDIRECT, OR CONSEQUENTIAL DAMAGES WHATSOEVER (INCLUDING, WITHOUT\n * LIMITATION, DAMAGES FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF\n * BUSINESS INFORMATION, OR ANY OTHER PECUNIARY LOSS) ARISING OUT OF THE USE OF OR\n * INABILITY TO USE THIS SOFTWARE, EVEN IF NVIDIA HAS BEEN ADVISED OF THE POSSIBILITY OF\n * SUCH DAMAGES\n */\n\n/// @file\n/// @author NVIDIA Corporation\n/// @brief  OptiX public API header\n///\n/// Includes the host api if compiling host code, includes the cuda api if compiling device code.\n/// For the math library routines include optix_math.h\n\n#ifndef __optix_optix_h__\n#define __optix_optix_h__\n\n/// The OptiX version.\n///\n/// - major =  OPTIX_VERSION/10000\n/// - minor = (OPTIX_VERSION%10000)/100\n/// - micro =  OPTIX_VERSION%100\n#define OPTIX_VERSION 70300\n\n\n#ifdef __CUDACC__\n#include \"optix_device.h\"\n#else\n#include \"optix_host.h\"\n#endif\n\n\n#endif  // __optix_optix_h__\n"
  },
  {
    "path": "render/optixutils/include/optix_7_device.h",
    "content": "/*\n* Copyright (c) 2021 NVIDIA Corporation.  All rights reserved.\n*\n* NVIDIA Corporation and its licensors retain all intellectual property and proprietary\n* rights in and to this software, related documentation and any modifications thereto.\n* Any use, reproduction, disclosure or distribution of this software and related\n* documentation without an express license agreement from NVIDIA Corporation is strictly\n* prohibited.\n*\n* TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, THIS SOFTWARE IS PROVIDED *AS IS*\n* AND NVIDIA AND ITS SUPPLIERS DISCLAIM ALL WARRANTIES, EITHER EXPRESS OR IMPLIED,\n* INCLUDING, BUT NOT LIMITED TO, IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A\n* PARTICULAR PURPOSE.  IN NO EVENT SHALL NVIDIA OR ITS SUPPLIERS BE LIABLE FOR ANY\n* SPECIAL, INCIDENTAL, INDIRECT, OR CONSEQUENTIAL DAMAGES WHATSOEVER (INCLUDING, WITHOUT\n* LIMITATION, DAMAGES FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF\n* BUSINESS INFORMATION, OR ANY OTHER PECUNIARY LOSS) ARISING OUT OF THE USE OF OR\n* INABILITY TO USE THIS SOFTWARE, EVEN IF NVIDIA HAS BEEN ADVISED OF THE POSSIBILITY OF\n* SUCH DAMAGES\n*/\n\n/// @file\n/// @author NVIDIA Corporation\n/// @brief  OptiX public API header\n///\n/// OptiX public API Reference - Device API declarations\n\n#if !defined( __OPTIX_INCLUDE_INTERNAL_HEADERS__ )\n#error(\"optix_7_device.h is an internal header file and must not be used directly.  Please use optix_device.h or optix.h instead.\")\n#endif\n\n\n#ifndef __optix_optix_7_device_h__\n#define __optix_optix_7_device_h__\n\n#if defined( __cplusplus ) && ( __cplusplus < 201103L ) && !defined( _WIN32 )\n#error Device code for OptiX requires at least C++11. Consider adding \"--std c++11\" to the nvcc command-line.\n#endif\n\n#include \"optix_7_types.h\"\n\n/// \\defgroup optix_device_api Device API\n/// \\brief OptiX Device API\n\n/** \\addtogroup optix_device_api\n@{\n*/\n\n/// Initiates a ray tracing query starting with the given traversable (overload without payload).\n///\n/// \\param[in] handle\n/// \\param[in] rayOrigin\n/// \\param[in] rayDirection\n/// \\param[in] tmin\n/// \\param[in] tmax\n/// \\param[in] rayTime\n/// \\param[in] visibilityMask really only 8 bits\n/// \\param[in] rayFlags       really only 8 bits, combination of OptixRayFlags\n/// \\param[in] SBToffset      really only 8 bits\n/// \\param[in] SBTstride      really only 8 bits\n/// \\param[in] missSBTIndex   specifies the miss program invoked on a miss\nstatic __forceinline__ __device__ void optixTrace( OptixTraversableHandle handle,\n                                                   float3                 rayOrigin,\n                                                   float3                 rayDirection,\n                                                   float                  tmin,\n                                                   float                  tmax,\n                                                   float                  rayTime,\n                                                   OptixVisibilityMask    visibilityMask,\n                                                   unsigned int           rayFlags,\n                                                   unsigned int           SBToffset,\n                                                   unsigned int           SBTstride,\n                                                   unsigned int           missSBTIndex );\n/// Initiates a ray tracing query starting with the given traversable (overload with 1 payload 
registers).\n///\n/// \\see #optixTrace(OptixTraversableHandle,float3,float3,float,float,float,OptixVisibilityMask,unsigned int,unsigned int,unsigned int,unsigned int)\nstatic __forceinline__ __device__ void optixTrace( OptixTraversableHandle handle,\n                                                   float3                 rayOrigin,\n                                                   float3                 rayDirection,\n                                                   float                  tmin,\n                                                   float                  tmax,\n                                                   float                  rayTime,\n                                                   OptixVisibilityMask    visibilityMask,\n                                                   unsigned int           rayFlags,\n                                                   unsigned int           SBToffset,\n                                                   unsigned int           SBTstride,\n                                                   unsigned int           missSBTIndex,\n                                                   unsigned int&          p0 );\n\n/// Initiates a ray tracing query starting with the given traversable (overload with 2 payload registers).\n///\n/// \\see #optixTrace(OptixTraversableHandle,float3,float3,float,float,float,OptixVisibilityMask,unsigned int,unsigned int,unsigned int,unsigned int)\nstatic __forceinline__ __device__ void optixTrace( OptixTraversableHandle handle,\n                                                   float3                 rayOrigin,\n                                                   float3                 rayDirection,\n                                                   float                  tmin,\n                                                   float                  tmax,\n                                                   float                  rayTime,\n                                                   OptixVisibilityMask    visibilityMask,\n                                                   unsigned int           rayFlags,\n                                                   unsigned int           SBToffset,\n                                                   unsigned int           SBTstride,\n                                                   unsigned int           missSBTIndex,\n                                                   unsigned int&          p0,\n                                                   unsigned int&          p1 );\n\n/// Initiates a ray tracing query starting with the given traversable (overload with 3 payload registers).\n///\n/// \\see #optixTrace(OptixTraversableHandle,float3,float3,float,float,float,OptixVisibilityMask,unsigned int,unsigned int,unsigned int,unsigned int)\nstatic __forceinline__ __device__ void optixTrace( OptixTraversableHandle handle,\n                                                   float3                 rayOrigin,\n                                                   float3                 rayDirection,\n                                                   float                  tmin,\n                                                   float                  tmax,\n                                                   float                  rayTime,\n                                                   OptixVisibilityMask    visibilityMask,\n                                                   unsigned int           rayFlags,\n                                                   
unsigned int           SBToffset,\n                                                   unsigned int           SBTstride,\n                                                   unsigned int           missSBTIndex,\n                                                   unsigned int&          p0,\n                                                   unsigned int&          p1,\n                                                   unsigned int&          p2 );\n\n/// Initiates a ray tracing query starting with the given traversable (overload with 4 payload registers).\n///\n/// \\see #optixTrace(OptixTraversableHandle,float3,float3,float,float,float,OptixVisibilityMask,unsigned int,unsigned int,unsigned int,unsigned int)\nstatic __forceinline__ __device__ void optixTrace( OptixTraversableHandle handle,\n                                                   float3                 rayOrigin,\n                                                   float3                 rayDirection,\n                                                   float                  tmin,\n                                                   float                  tmax,\n                                                   float                  rayTime,\n                                                   OptixVisibilityMask    visibilityMask,\n                                                   unsigned int           rayFlags,\n                                                   unsigned int           SBToffset,\n                                                   unsigned int           SBTstride,\n                                                   unsigned int           missSBTIndex,\n                                                   unsigned int&          p0,\n                                                   unsigned int&          p1,\n                                                   unsigned int&          p2,\n                                                   unsigned int&          p3 );\n\n/// Initiates a ray tracing query starting with the given traversable (overload with 5 payload registers).\n///\n/// \\see #optixTrace(OptixTraversableHandle,float3,float3,float,float,float,OptixVisibilityMask,unsigned int,unsigned int,unsigned int,unsigned int)\nstatic __forceinline__ __device__ void optixTrace( OptixTraversableHandle handle,\n                                                   float3                 rayOrigin,\n                                                   float3                 rayDirection,\n                                                   float                  tmin,\n                                                   float                  tmax,\n                                                   float                  rayTime,\n                                                   OptixVisibilityMask    visibilityMask,\n                                                   unsigned int           rayFlags,\n                                                   unsigned int           SBToffset,\n                                                   unsigned int           SBTstride,\n                                                   unsigned int           missSBTIndex,\n                                                   unsigned int&          p0,\n                                                   unsigned int&          p1,\n                                                   unsigned int&          p2,\n                                                   unsigned int&          p3,\n                                                   
unsigned int&          p4 );\n\n/// Initiates a ray tracing query starting with the given traversable (overload with 6 payload registers).\n///\n/// \\see #optixTrace(OptixTraversableHandle,float3,float3,float,float,float,OptixVisibilityMask,unsigned int,unsigned int,unsigned int,unsigned int)\nstatic __forceinline__ __device__ void optixTrace( OptixTraversableHandle handle,\n                                                   float3                 rayOrigin,\n                                                   float3                 rayDirection,\n                                                   float                  tmin,\n                                                   float                  tmax,\n                                                   float                  rayTime,\n                                                   OptixVisibilityMask    visibilityMask,\n                                                   unsigned int           rayFlags,\n                                                   unsigned int           SBToffset,\n                                                   unsigned int           SBTstride,\n                                                   unsigned int           missSBTIndex,\n                                                   unsigned int&          p0,\n                                                   unsigned int&          p1,\n                                                   unsigned int&          p2,\n                                                   unsigned int&          p3,\n                                                   unsigned int&          p4,\n                                                   unsigned int&          p5 );\n\n/// Initiates a ray tracing query starting with the given traversable (overload with 7 payload registers).\n///\n/// \\see #optixTrace(OptixTraversableHandle,float3,float3,float,float,float,OptixVisibilityMask,unsigned int,unsigned int,unsigned int,unsigned int)\nstatic __forceinline__ __device__ void optixTrace( OptixTraversableHandle handle,\n                                                   float3                 rayOrigin,\n                                                   float3                 rayDirection,\n                                                   float                  tmin,\n                                                   float                  tmax,\n                                                   float                  rayTime,\n                                                   OptixVisibilityMask    visibilityMask,\n                                                   unsigned int           rayFlags,\n                                                   unsigned int           SBToffset,\n                                                   unsigned int           SBTstride,\n                                                   unsigned int           missSBTIndex,\n                                                   unsigned int&          p0,\n                                                   unsigned int&          p1,\n                                                   unsigned int&          p2,\n                                                   unsigned int&          p3,\n                                                   unsigned int&          p4,\n                                                   unsigned int&          p5,\n                                                   unsigned int&          p6 );\n\n/// Initiates a ray tracing query starting with the given traversable 
(overload with 8 payload registers).\n///\n/// \\see #optixTrace(OptixTraversableHandle,float3,float3,float,float,float,OptixVisibilityMask,unsigned int,unsigned int,unsigned int,unsigned int)\nstatic __forceinline__ __device__ void optixTrace( OptixTraversableHandle handle,\n                                                   float3                 rayOrigin,\n                                                   float3                 rayDirection,\n                                                   float                  tmin,\n                                                   float                  tmax,\n                                                   float                  rayTime,\n                                                   OptixVisibilityMask    visibilityMask,\n                                                   unsigned int           rayFlags,\n                                                   unsigned int           SBToffset,\n                                                   unsigned int           SBTstride,\n                                                   unsigned int           missSBTIndex,\n                                                   unsigned int&          p0,\n                                                   unsigned int&          p1,\n                                                   unsigned int&          p2,\n                                                   unsigned int&          p3,\n                                                   unsigned int&          p4,\n                                                   unsigned int&          p5,\n                                                   unsigned int&          p6,\n                                                   unsigned int&          p7 );\n\n\n/// Writes the 32-bit payload value at slot 0.\nstatic __forceinline__ __device__ void optixSetPayload_0( unsigned int p );\n/// Writes the 32-bit payload value at slot 1.\nstatic __forceinline__ __device__ void optixSetPayload_1( unsigned int p );\n/// Writes the 32-bit payload value at slot 2.\nstatic __forceinline__ __device__ void optixSetPayload_2( unsigned int p );\n/// Writes the 32-bit payload value at slot 3.\nstatic __forceinline__ __device__ void optixSetPayload_3( unsigned int p );\n/// Writes the 32-bit payload value at slot 4.\nstatic __forceinline__ __device__ void optixSetPayload_4( unsigned int p );\n/// Writes the 32-bit payload value at slot 5.\nstatic __forceinline__ __device__ void optixSetPayload_5( unsigned int p );\n/// Writes the 32-bit payload value at slot 6.\nstatic __forceinline__ __device__ void optixSetPayload_6( unsigned int p );\n/// Writes the 32-bit payload value at slot 7.\nstatic __forceinline__ __device__ void optixSetPayload_7( unsigned int p );\n\n\n/// Reads the 32-bit payload value at slot 0.\nstatic __forceinline__ __device__ unsigned int optixGetPayload_0();\n/// Reads the 32-bit payload value at slot 1.\nstatic __forceinline__ __device__ unsigned int optixGetPayload_1();\n/// Reads the 32-bit payload value at slot 2.\nstatic __forceinline__ __device__ unsigned int optixGetPayload_2();\n/// Reads the 32-bit payload value at slot 3.\nstatic __forceinline__ __device__ unsigned int optixGetPayload_3();\n/// Reads the 32-bit payload value at slot 4.\nstatic __forceinline__ __device__ unsigned int optixGetPayload_4();\n/// Reads the 32-bit payload value at slot 5.\nstatic __forceinline__ __device__ unsigned int optixGetPayload_5();\n/// Reads the 32-bit payload value at slot 6.\nstatic __forceinline__ __device__ 
unsigned int optixGetPayload_6();\n/// Reads the 32-bit payload value at slot 7.\nstatic __forceinline__ __device__ unsigned int optixGetPayload_7();\n\n\n/// Returns an undefined value.\nstatic __forceinline__ __device__ unsigned int optixUndefinedValue();\n\n/// Returns the rayOrigin passed into optixTrace.\n///\n/// May be more expensive to call in IS and AH than their object space counterparts,\n/// so effort should be made to use the object space ray in those programs.\n/// Only available in IS, AH, CH, MS\nstatic __forceinline__ __device__ float3 optixGetWorldRayOrigin();\n\n/// Returns the rayDirection passed into optixTrace.\n///\n/// May be more expensive to call in IS and AH than their object space counterparts,\n/// so effort should be made to use the object space ray in those programs.\n/// Only available in IS, AH, CH, MS\nstatic __forceinline__ __device__ float3 optixGetWorldRayDirection();\n\n/// Returns the current object space ray origin based on the current transform stack.\n///\n/// Only available in IS and AH.\nstatic __forceinline__ __device__ float3 optixGetObjectRayOrigin();\n\n/// Returns the current object space ray direction based on the current transform stack.\n///\n/// Only available in IS and AH.\nstatic __forceinline__ __device__ float3 optixGetObjectRayDirection();\n\n/// Returns the tmin passed into optixTrace.\n///\n/// Only available in IS, AH, CH, MS\nstatic __forceinline__ __device__ float optixGetRayTmin();\n\n/// In IS and CH returns the current smallest reported hitT or the tmax passed into optixTrace if no hit has been reported\n/// In AH returns the hitT value as passed in to optixReportIntersection\n/// In MS returns the tmax passed into optixTrace\n/// Only available in IS, AH, CH, MS\nstatic __forceinline__ __device__ float optixGetRayTmax();\n\n/// Returns the rayTime passed into optixTrace.\n///\n/// Will return 0 if motion is disabled.\n/// Only available in IS, AH, CH, MS\nstatic __forceinline__ __device__ float optixGetRayTime();\n\n/// Returns the rayFlags passed into optixTrace\n///\n/// Only available in IS, AH, CH, MS\nstatic __forceinline__ __device__ unsigned int optixGetRayFlags();\n\n/// Returns the visibilityMask passed into optixTrace\n///\n/// Only available in IS, AH, CH, MS\nstatic __forceinline__ __device__ unsigned int optixGetRayVisibilityMask();\n\n/// Return the traversable handle of a given instance in an Instance \n/// Acceleration Structure (IAS)\nstatic __forceinline__ __device__ OptixTraversableHandle optixGetInstanceTraversableFromIAS( OptixTraversableHandle ias, unsigned int instIdx );\n\n/// Return the object space triangle vertex positions of a given triangle in a Geometry\n/// Acceleration Structure (GAS) at a given motion time.\n/// To access vertex data, the GAS must be built using the flag OPTIX_BUILD_FLAG_ALLOW_RANDOM_VERTEX_ACCESS.\n///\n/// If motion is disabled via OptixPipelineCompileOptions::usesMotionBlur, or the GAS does not contain motion, the\n/// time parameter is ignored.\nstatic __forceinline__ __device__ void optixGetTriangleVertexData( OptixTraversableHandle gas, unsigned int primIdx, unsigned int sbtGASIndex, float time, float3 data[3]);\n\n/// Return the object space curve control vertex data of a linear curve in a Geometry\n/// Acceleration Structure (GAS) at a given motion time.\n/// To access vertex data, the GAS must be built using the flag OPTIX_BUILD_FLAG_ALLOW_RANDOM_VERTEX_ACCESS.\n///\n/// data[i] = {x,y,z,w} with {x,y,z} the position and w the radius of control vertex i.\n/// If 
motion is disabled via OptixPipelineCompileOptions::usesMotionBlur, or the GAS does not contain motion, the\n/// time parameter is ignored.\nstatic __forceinline__ __device__ void optixGetLinearCurveVertexData( OptixTraversableHandle gas, unsigned int primIdx, unsigned int sbtGASIndex, float time, float4 data[2] );\n\n/// Return the object space curve control vertex data of a quadratic BSpline curve in a Geometry\n/// Acceleration Structure (GAS) at a given motion time.\n/// To access vertex data, the GAS must be built using the flag OPTIX_BUILD_FLAG_ALLOW_RANDOM_VERTEX_ACCESS.\n///\n/// data[i] = {x,y,z,w} with {x,y,z} the position and w the radius of control vertex i.\n/// If motion is disabled via OptixPipelineCompileOptions::usesMotionBlur, or the GAS does not contain motion, the\n/// time parameter is ignored.\nstatic __forceinline__ __device__ void optixGetQuadraticBSplineVertexData( OptixTraversableHandle gas, unsigned int primIdx, unsigned int sbtGASIndex, float time, float4 data[3] );\n\n/// Return the object space curve control vertex data of a cubic BSpline curve in a Geometry\n/// Acceleration Structure (GAS) at a given motion time.\n/// To access vertex data, the GAS must be built using the flag OPTIX_BUILD_FLAG_ALLOW_RANDOM_VERTEX_ACCESS.\n///\n/// data[i] = {x,y,z,w} with {x,y,z} the position and w the radius of control vertex i.\n/// If motion is disabled via OptixPipelineCompileOptions::usesMotionBlur, or the GAS does not contain motion, the\n/// time parameter is ignored.\nstatic __forceinline__ __device__ void optixGetCubicBSplineVertexData( OptixTraversableHandle gas, unsigned int primIdx, unsigned int sbtGASIndex, float time, float4 data[4] );\n\n/// Returns the traversable handle for the Geometry Acceleration Structure (GAS) containing\n/// the current hit. 
May be called from IS, AH and CH.\nstatic __forceinline__ __device__ OptixTraversableHandle optixGetGASTraversableHandle();\n\n/// Returns the motion begin time of a GAS (see OptixMotionOptions)\nstatic __forceinline__ __device__ float optixGetGASMotionTimeBegin( OptixTraversableHandle gas );\n\n/// Returns the motion end time of a GAS (see OptixMotionOptions)\nstatic __forceinline__ __device__ float optixGetGASMotionTimeEnd( OptixTraversableHandle gas );\n\n/// Returns the number of motion steps of a GAS (see OptixMotionOptions)\nstatic __forceinline__ __device__ unsigned int optixGetGASMotionStepCount( OptixTraversableHandle gas );\n\n/// Returns the world-to-object transformation matrix resulting from the current active transformation list.\n///\n/// The cost of this function may be proportional to the size of the transformation list.\nstatic __forceinline__ __device__ void optixGetWorldToObjectTransformMatrix( float m[12] );\n\n/// Returns the object-to-world transformation matrix resulting from the current active transformation list.\n///\n/// The cost of this function may be proportional to the size of the transformation list.\nstatic __forceinline__ __device__ void optixGetObjectToWorldTransformMatrix( float m[12] );\n\n/// Transforms the point using world-to-object transformation matrix resulting from the current active transformation\n/// list.\n///\n/// The cost of this function may be proportional to the size of the transformation list.\nstatic __forceinline__ __device__ float3 optixTransformPointFromWorldToObjectSpace( float3 point );\n\n/// Transforms the vector using world-to-object transformation matrix resulting from the current active transformation\n/// list.\n///\n/// The cost of this function may be proportional to the size of the transformation list.\nstatic __forceinline__ __device__ float3 optixTransformVectorFromWorldToObjectSpace( float3 vec );\n\n/// Transforms the normal using world-to-object transformation matrix resulting from the current active transformation\n/// list.\n///\n/// The cost of this function may be proportional to the size of the transformation list.\nstatic __forceinline__ __device__ float3 optixTransformNormalFromWorldToObjectSpace( float3 normal );\n\n/// Transforms the point using object-to-world transformation matrix resulting from the current active transformation\n/// list.\n///\n/// The cost of this function may be proportional to the size of the transformation list.\nstatic __forceinline__ __device__ float3 optixTransformPointFromObjectToWorldSpace( float3 point );\n\n/// Transforms the vector using object-to-world transformation matrix resulting from the current active transformation\n/// list.\n///\n/// The cost of this function may be proportional to the size of the transformation list.\nstatic __forceinline__ __device__ float3 optixTransformVectorFromObjectToWorldSpace( float3 vec );\n\n/// Transforms the normal using object-to-world transformation matrix resulting from the current active transformation\n/// list.\n///\n/// The cost of this function may be proportional to the size of the transformation list.\nstatic __forceinline__ __device__ float3 optixTransformNormalFromObjectToWorldSpace( float3 normal );\n\n/// Returns the number of transforms on the current transform list.\n///\n/// Only available in IS, AH, CH, EX\nstatic __forceinline__ __device__ unsigned int optixGetTransformListSize();\n\n/// Returns the traversable handle for a transform on the current transform list.\n///\n/// Only available in IS, AH, CH, EX\nstatic 
__forceinline__ __device__ OptixTraversableHandle optixGetTransformListHandle( unsigned int index );\n\n\n/// Returns the transform type of a traversable handle from a transform list.\nstatic __forceinline__ __device__ OptixTransformType optixGetTransformTypeFromHandle( OptixTraversableHandle handle );\n\n/// Returns a pointer to an OptixStaticTransform from its traversable handle.\n///\n/// Returns 0 if the traversable is not of type OPTIX_TRANSFORM_TYPE_STATIC_TRANSFORM.\nstatic __forceinline__ __device__ const OptixStaticTransform* optixGetStaticTransformFromHandle( OptixTraversableHandle handle );\n\n/// Returns a pointer to an OptixSRTMotionTransform from its traversable handle.\n///\n/// Returns 0 if the traversable is not of type OPTIX_TRANSFORM_TYPE_SRT_MOTION_TRANSFORM.\nstatic __forceinline__ __device__ const OptixSRTMotionTransform* optixGetSRTMotionTransformFromHandle( OptixTraversableHandle handle );\n\n/// Returns a pointer to an OptixMatrixMotionTransform from its traversable handle.\n///\n/// Returns 0 if the traversable is not of type OPTIX_TRANSFORM_TYPE_MATRIX_MOTION_TRANSFORM.\nstatic __forceinline__ __device__ const OptixMatrixMotionTransform* optixGetMatrixMotionTransformFromHandle( OptixTraversableHandle handle );\n\n/// Returns instanceId from an OptixInstance traversable.\n///\n/// Returns 0 if the traversable handle does not reference an OptixInstance.\nstatic __forceinline__ __device__ unsigned int optixGetInstanceIdFromHandle( OptixTraversableHandle handle );\n\n/// Returns child traversable handle from an OptixInstance traversable.\n///\n/// Returns 0 if the traversable handle does not reference an OptixInstance.\nstatic __forceinline__ __device__ OptixTraversableHandle optixGetInstanceChildFromHandle( OptixTraversableHandle handle );\n\n/// Returns object-to-world transform from an OptixInstance traversable.\n///\n/// Returns 0 if the traversable handle does not reference an OptixInstance.\nstatic __forceinline__ __device__ const float4* optixGetInstanceTransformFromHandle( OptixTraversableHandle handle );\n\n/// Returns world-to-object transform from an OptixInstance traversable.\n///\n/// Returns 0 if the traversable handle does not reference an OptixInstance.\nstatic __forceinline__ __device__ const float4* optixGetInstanceInverseTransformFromHandle( OptixTraversableHandle handle );\n\n/// Reports an intersection (overload without attributes).\n///\n/// If optixGetRayTmin() <= hitT <= optixGetRayTmax(), the any hit program associated with this intersection program (via the SBT entry) is called.\n/// The AH program can do one of three things:\n/// 1. call optixIgnoreIntersection - no hit is recorded, optixReportIntersection returns false\n/// 2. call optixTerminateRay       -    hit is recorded, optixReportIntersection does not return, no further traversal occurs,\n///                                       and the associated closest hit program is called\n/// 3. neither                   -    hit is recorded, optixReportIntersection returns true\n/// hitKind - Only the 7 least significant bits should be written [0..127].  Any values above 127 are reserved for built-in intersection.  
The value can be queried with optixGetHitKind() in AH and CH.\n///\n/// The attributes specified with a0..a7 are available in the AH and CH programs.\n/// Note that the attributes available in the CH program correspond to the closest recorded intersection.\n/// The number of attributes in registers and memory can be configured in the pipeline.\n///\n/// \\param[in] hitT\n/// \\param[in] hitKind\nstatic __forceinline__ __device__ bool optixReportIntersection( float hitT, unsigned int hitKind );\n\n/// Reports an intersection (overload with 1 attribute register).\n///\n/// \\see #optixReportIntersection(float,unsigned int)\nstatic __forceinline__ __device__ bool optixReportIntersection( float hitT, unsigned int hitKind, unsigned int a0 );\n\n/// Reports an intersection (overload with 2 attribute registers).\n///\n/// \\see #optixReportIntersection(float,unsigned int)\nstatic __forceinline__ __device__ bool optixReportIntersection( float hitT, unsigned int hitKind, unsigned int a0, unsigned int a1 );\n\n/// Reports an intersection (overload with 3 attribute registers).\n///\n/// \\see #optixReportIntersection(float,unsigned int)\nstatic __forceinline__ __device__ bool optixReportIntersection( float hitT, unsigned int hitKind, unsigned int a0, unsigned int a1, unsigned int a2 );\n\n/// Reports an intersection (overload with 4 attribute registers).\n///\n/// \\see #optixReportIntersection(float,unsigned int)\nstatic __forceinline__ __device__ bool optixReportIntersection( float        hitT,\n                                                                unsigned int hitKind,\n                                                                unsigned int a0,\n                                                                unsigned int a1,\n                                                                unsigned int a2,\n                                                                unsigned int a3 );\n\n/// Reports an intersection (overload with 5 attribute registers).\n///\n/// \\see #optixReportIntersection(float,unsigned int)\nstatic __forceinline__ __device__ bool optixReportIntersection( float        hitT,\n                                                                unsigned int hitKind,\n                                                                unsigned int a0,\n                                                                unsigned int a1,\n                                                                unsigned int a2,\n                                                                unsigned int a3,\n                                                                unsigned int a4 );\n\n/// Reports an intersection (overload with 6 attribute registers).\n///\n/// \\see #optixReportIntersection(float,unsigned int)\nstatic __forceinline__ __device__ bool optixReportIntersection( float        hitT,\n                                                                unsigned int hitKind,\n                                                                unsigned int a0,\n                                                                unsigned int a1,\n                                                                unsigned int a2,\n                                                                unsigned int a3,\n                                                                unsigned int a4,\n                                                                unsigned int a5 );\n\n/// Reports an intersection (overload with 7 attribute registers).\n///\n/// \\see 
#optixReportIntersection(float,unsigned int)\nstatic __forceinline__ __device__ bool optixReportIntersection( float        hitT,\n                                                                unsigned int hitKind,\n                                                                unsigned int a0,\n                                                                unsigned int a1,\n                                                                unsigned int a2,\n                                                                unsigned int a3,\n                                                                unsigned int a4,\n                                                                unsigned int a5,\n                                                                unsigned int a6 );\n\n/// Reports an intersection (overload with 8 attribute registers).\n///\n/// \\see #optixReportIntersection(float,unsigned int)\nstatic __forceinline__ __device__ bool optixReportIntersection( float        hitT,\n                                                                unsigned int hitKind,\n                                                                unsigned int a0,\n                                                                unsigned int a1,\n                                                                unsigned int a2,\n                                                                unsigned int a3,\n                                                                unsigned int a4,\n                                                                unsigned int a5,\n                                                                unsigned int a6,\n                                                                unsigned int a7 );\n\n/// Returns the attribute at slot 0.\nstatic __forceinline__ __device__ unsigned int optixGetAttribute_0();\n/// Returns the attribute at slot 1.\nstatic __forceinline__ __device__ unsigned int optixGetAttribute_1();\n/// Returns the attribute at slot 2.\nstatic __forceinline__ __device__ unsigned int optixGetAttribute_2();\n/// Returns the attribute at slot 3.\nstatic __forceinline__ __device__ unsigned int optixGetAttribute_3();\n/// Returns the attribute at slot 4.\nstatic __forceinline__ __device__ unsigned int optixGetAttribute_4();\n/// Returns the attribute at slot 5.\nstatic __forceinline__ __device__ unsigned int optixGetAttribute_5();\n/// Returns the attribute at slot 6.\nstatic __forceinline__ __device__ unsigned int optixGetAttribute_6();\n/// Returns the attribute at slot 7.\nstatic __forceinline__ __device__ unsigned int optixGetAttribute_7();\n\n/// Record the hit, stops traversal, and proceeds to CH.\n///\n/// Available only in AH.\nstatic __forceinline__ __device__ void optixTerminateRay();\n\n/// Discards the hit, and returns control to the calling optixReportIntersection or built-in intersection routine.\n///\n/// Available only in AH.\nstatic __forceinline__ __device__ void optixIgnoreIntersection();\n\n\n/// For a given OptixBuildInputTriangleArray the number of primitives is defined as\n/// \"(OptixBuildInputTriangleArray::indexBuffer == 0) ? 
OptixBuildInputTriangleArray::numVertices/3 :\n///                                                     OptixBuildInputTriangleArray::numIndexTriplets;\".\n/// For a given OptixBuildInputCustomPrimitiveArray the number of primitives is defined as\n/// numAabbs.\n///\n/// The primitive index returns the index into the array of primitives\n/// plus the primitiveIndexOffset.\n///\n/// In IS and AH this corresponds to the currently intersected primitive.\n/// In CH this corresponds to the primitive index of the closest intersected primitive.\nstatic __forceinline__ __device__ unsigned int optixGetPrimitiveIndex();\n\n/// Returns the Sbt GAS index of the primitive associated with the current intersection.\n///\n/// In IS and AH this corresponds to the currently intersected primitive.\n/// In CH this corresponds to the Sbt GAS index of the closest intersected primitive.\n/// In EX with exception code OPTIX_EXCEPTION_CODE_TRAVERSAL_INVALID_HIT_SBT corresponds to the sbt index within the hit GAS. Returns zero for all other exceptions.\nstatic __forceinline__ __device__ unsigned int optixGetSbtGASIndex();\n\n\n/// Returns the OptixInstance::instanceId of the instance within the top level acceleration structure associated with the current intersection.\n///\n/// When building an acceleration structure using OptixBuildInputInstanceArray each OptixInstance has a user supplied instanceId.\n/// OptixInstance objects reference another acceleration structure.  During traversal the acceleration structures are visited top down.\n/// In the IS and AH programs the OptixInstance::instanceId corresponding to the most recently visited OptixInstance is returned when calling optixGetInstanceId().\n/// In CH optixGetInstanceId() returns the OptixInstance::instanceId when the hit was recorded with optixReportIntersection.\n/// In the case where there is no OptixInstance visited, optixGetInstanceId returns ~0u\nstatic __forceinline__ __device__ unsigned int optixGetInstanceId();\n\n/// Returns the zero-based index of the instance within its instance acceleration structure associated with the current intersection.\n///\n/// In the IS and AH programs the index corresponding to the most recently visited OptixInstance is returned when calling optixGetInstanceIndex().\n/// In CH optixGetInstanceIndex() returns the index when the hit was recorded with optixReportIntersection.\n/// In the case where there is no OptixInstance visited, optixGetInstanceIndex returns 0\nstatic __forceinline__ __device__ unsigned int optixGetInstanceIndex();\n\n/// Returns the 8 bit hit kind associated with the current hit.\n/// \n/// Use optixGetPrimitiveType() to interpret the hit kind.\n/// For custom intersections (primitive type OPTIX_PRIMITIVE_TYPE_CUSTOM),\n/// this is the 7-bit hitKind passed to optixReportIntersection(). 
\n/// Hit kinds greater than 127 are reserved for built-in primitives.\n///\n/// Available only in AH and CH.\nstatic __forceinline__ __device__ unsigned int optixGetHitKind();\n\n/// Function interpreting the result of #optixGetHitKind().\nstatic __forceinline__ __device__ OptixPrimitiveType optixGetPrimitiveType( unsigned int hitKind );\n\n/// Function interpreting the result of #optixGetHitKind().\nstatic __forceinline__ __device__ bool optixIsFrontFaceHit( unsigned int hitKind );\n\n/// Function interpreting the result of #optixGetHitKind().\nstatic __forceinline__ __device__ bool optixIsBackFaceHit( unsigned int hitKind );\n\n/// Function interpreting the hit kind associated with the current optixReportIntersection.\nstatic __forceinline__ __device__ OptixPrimitiveType optixGetPrimitiveType();\n\n/// Function interpreting the hit kind associated with the current optixReportIntersection.\nstatic __forceinline__ __device__ bool optixIsFrontFaceHit();\n\n/// Function interpreting the hit kind associated with the current optixReportIntersection.\nstatic __forceinline__ __device__ bool optixIsBackFaceHit();\n\n/// Convenience function interpreting the result of #optixGetHitKind().\nstatic __forceinline__ __device__ bool optixIsTriangleHit();\n\n/// Convenience function interpreting the result of #optixGetHitKind().\nstatic __forceinline__ __device__ bool optixIsTriangleFrontFaceHit();\n\n/// Convenience function interpreting the result of #optixGetHitKind().\nstatic __forceinline__ __device__ bool optixIsTriangleBackFaceHit();\n\n/// Convenience function that returns the first two attributes as floats.\n///\n/// When using OptixBuildInputTriangleArray objects, during intersection the barycentric\n/// coordinates are stored into the first two attribute registers.\nstatic __forceinline__ __device__ float2 optixGetTriangleBarycentrics();\n\n/// Convenience function that returns the curve parameter.\n///\n/// When using OptixBuildInputCurveArray objects, during intersection the curve parameter\n/// is stored into the first attribute register.\nstatic __forceinline__ __device__ float optixGetCurveParameter();\n\n/// Available in any program, it returns the current launch index within the launch dimensions specified by optixLaunch on the host.\n///\n/// The raygen program is typically only launched once per launch index.\nstatic __forceinline__ __device__ uint3 optixGetLaunchIndex();\n\n/// Available in any program, it returns the dimensions of the current launch specified by optixLaunch on the host.\nstatic __forceinline__ __device__ uint3 optixGetLaunchDimensions();\n\n/// Returns the generic memory space pointer to the data region (past the header) of the currently active SBT record corresponding to the current program.\nstatic __forceinline__ __device__ CUdeviceptr optixGetSbtDataPointer();\n\n/// Throws a user exception with the given exception code (overload without exception details).\n///\n/// The exception code must be in the range from 0 to 2^30 - 1. Up to 8 optional exception details can be passed. 
They\n/// can be queried in the EX program using optixGetExceptionDetail_0() to ..._8().\n///\n/// The exception details must not be used to encode pointers to the stack since the current stack is not preserved in\n/// the EX program.\n///\n/// Not available in EX.\n///\n/// \\param[in] exceptionCode The exception code to be thrown.\nstatic __forceinline__ __device__ void optixThrowException( int exceptionCode );\n\n/// Throws a user exception with the given exception code (overload with 1 exception detail).\n///\n/// \\see #optixThrowException(int)\nstatic __forceinline__ __device__ void optixThrowException( int exceptionCode, unsigned int exceptionDetail0 );\n\n/// Throws a user exception with the given exception code (overload with 2 exception details).\n///\n/// \\see #optixThrowException(int)\nstatic __forceinline__ __device__ void optixThrowException( int exceptionCode,\n                                                            unsigned int exceptionDetail0,\n                                                            unsigned int exceptionDetail1 );\n\n/// Throws a user exception with the given exception code (overload with 3 exception details).\n///\n/// \\see #optixThrowException(int)\nstatic __forceinline__ __device__ void optixThrowException( int exceptionCode,\n                                                            unsigned int exceptionDetail0,\n                                                            unsigned int exceptionDetail1,\n                                                            unsigned int exceptionDetail2 );\n\n/// Throws a user exception with the given exception code (overload with 4 exception details).\n///\n/// \\see #optixThrowException(int)\nstatic __forceinline__ __device__ void optixThrowException( int exceptionCode,\n                                                            unsigned int exceptionDetail0,\n                                                            unsigned int exceptionDetail1,\n                                                            unsigned int exceptionDetail2,\n                                                            unsigned int exceptionDetail3 );\n\n/// Throws a user exception with the given exception code (overload with 5 exception details).\n///\n/// \\see #optixThrowException(int)\nstatic __forceinline__ __device__ void optixThrowException( int exceptionCode,\n                                                            unsigned int exceptionDetail0,\n                                                            unsigned int exceptionDetail1,\n                                                            unsigned int exceptionDetail2,\n                                                            unsigned int exceptionDetail3,\n                                                            unsigned int exceptionDetail4 );\n\n/// Throws a user exception with the given exception code (overload with 6 exception details).\n///\n/// \\see #optixThrowException(int)\nstatic __forceinline__ __device__ void optixThrowException( int exceptionCode,\n                                                            unsigned int exceptionDetail0,\n                                                            unsigned int exceptionDetail1,\n                                                            unsigned int exceptionDetail2,\n                                                            unsigned int exceptionDetail3,\n                                                            unsigned int exceptionDetail4,\n                             
                               unsigned int exceptionDetail5 );\n\n/// Throws a user exception with the given exception code (overload with 7 exception details).\n///\n/// \\see #optixThrowException(int)\nstatic __forceinline__ __device__ void optixThrowException( int exceptionCode,\n                                                            unsigned int exceptionDetail0,\n                                                            unsigned int exceptionDetail1,\n                                                            unsigned int exceptionDetail2,\n                                                            unsigned int exceptionDetail3,\n                                                            unsigned int exceptionDetail4,\n                                                            unsigned int exceptionDetail5,\n                                                            unsigned int exceptionDetail6 );\n\n/// Throws a user exception with the given exception code (overload with 8 exception details).\n///\n/// \\see #optixThrowException(int)\nstatic __forceinline__ __device__ void optixThrowException( int exceptionCode,\n                                                            unsigned int exceptionDetail0,\n                                                            unsigned int exceptionDetail1,\n                                                            unsigned int exceptionDetail2,\n                                                            unsigned int exceptionDetail3,\n                                                            unsigned int exceptionDetail4,\n                                                            unsigned int exceptionDetail5,\n                                                            unsigned int exceptionDetail6,\n                                                            unsigned int exceptionDetail7 );\n\n/// Returns the exception code.\n///\n/// Only available in EX.\nstatic __forceinline__ __device__ int optixGetExceptionCode();\n\n/// Returns the 32-bit exception detail at slot 0.\n///\n/// The behavior is undefined if the exception is not a user exception, or the used overload #optixThrowException() did\n/// not provide the queried exception detail.\n///\n/// Only available in EX.\nstatic __forceinline__ __device__ unsigned int optixGetExceptionDetail_0();\n\n/// Returns the 32-bit exception detail at slot 1.\n///\n/// \\see #optixGetExceptionDetail_0()\nstatic __forceinline__ __device__ unsigned int optixGetExceptionDetail_1();\n\n/// Returns the 32-bit exception detail at slot 2.\n///\n/// \\see #optixGetExceptionDetail_0()\nstatic __forceinline__ __device__ unsigned int optixGetExceptionDetail_2();\n\n/// Returns the 32-bit exception detail at slot 3.\n///\n/// \\see #optixGetExceptionDetail_0()\nstatic __forceinline__ __device__ unsigned int optixGetExceptionDetail_3();\n\n/// Returns the 32-bit exception detail at slot 4.\n///\n/// \\see #optixGetExceptionDetail_0()\nstatic __forceinline__ __device__ unsigned int optixGetExceptionDetail_4();\n\n/// Returns the 32-bit exception detail at slot 5.\n///\n/// \\see #optixGetExceptionDetail_0()\nstatic __forceinline__ __device__ unsigned int optixGetExceptionDetail_5();\n\n/// Returns the 32-bit exception detail at slot 6.\n///\n/// \\see #optixGetExceptionDetail_0()\nstatic __forceinline__ __device__ unsigned int optixGetExceptionDetail_6();\n\n/// Returns the 32-bit exception detail at slot 7.\n///\n/// \\see #optixGetExceptionDetail_0()\nstatic __forceinline__ __device__ unsigned 
int optixGetExceptionDetail_7();\n\n/// Returns the invalid traversable handle for exceptions with exception code OPTIX_EXCEPTION_CODE_TRAVERSAL_INVALID_TRAVERSABLE. \n/// \n/// Returns zero for all other exception codes. \n/// \n/// Only available in EX.\nstatic __forceinline__ __device__ OptixTraversableHandle optixGetExceptionInvalidTraversable();\n\n/// Returns the invalid sbt offset for exceptions with exception code OPTIX_EXCEPTION_CODE_TRAVERSAL_INVALID_MISS_SBT and OPTIX_EXCEPTION_CODE_TRAVERSAL_INVALID_HIT_SBT. \n/// \n/// Returns zero for all other exception codes. \n/// \n/// Only available in EX.\nstatic __forceinline__ __device__ int optixGetExceptionInvalidSbtOffset();\n\n/// Returns the invalid ray for exceptions with exception code OPTIX_EXCEPTION_CODE_INVALID_RAY.\n/// Exceptions of type OPTIX_EXCEPTION_CODE_INVALID_RAY are thrown when one or more values that were\n/// passed into optixTrace are either inf or nan.\n///\n/// OptixInvalidRayExceptionDetails::rayTime will always be 0 if OptixPipelineCompileOptions::usesMotionBlur is 0.\n/// Values in the returned struct are all zero for all other exception codes.\n/// \n/// Only available in EX.\nstatic __forceinline__ __device__ OptixInvalidRayExceptionDetails optixGetExceptionInvalidRay();\n\n/// Returns information about an exception with code OPTIX_EXCEPTION_CODE_CALLABLE_PARAMETER_MISMATCH.\n/// \n/// Exceptions of type OPTIX_EXCEPTION_CODE_CALLABLE_PARAMETER_MISMATCH are called when the number of\n/// arguments that were passed into a call to optixDirectCall or optixContinuationCall does not match\n/// the number of parameters of the callable that is called.\n/// Note that the parameters are packed by OptiX into individual 32 bit values, so the number of\n/// expected and passed values may not correspond to the number of arguments passed into optixDirectCall\n/// or optixContinuationCall.\n/// \n/// Values in the returned struct are all zero for all other exception codes.\n/// \n/// Only available in EX.\nstatic __forceinline__ __device__ OptixParameterMismatchExceptionDetails optixGetExceptionParameterMismatch();\n\n/// Returns a string that includes information about the source location that caused the current exception.\n///\n/// The source location is only available for exceptions of type OPTIX_EXCEPTION_CODE_CALLABLE_PARAMETER_MISMATCH,\n/// OPTIX_EXCEPTION_CODE_UNSUPPORTED_PRIMITIVE_TYPE, OPTIX_EXCEPTION_CODE_INVALID_RAY, and for user exceptions.\n/// Line information needs to be present in the input PTX and OptixModuleCompileOptions::debugLevel\n/// may not be set to OPTIX_COMPILE_DEBUG_LEVEL_NONE.\n/// \n/// Returns a NULL pointer if no line information is available.\n/// \n/// Only available in EX.\nstatic __forceinline__ __device__ char* optixGetExceptionLineInfo();\n\n/// Creates a call to the direct callable program at the specified SBT entry.\n/// \n/// This will call the program that was specified in the OptixProgramGroupCallables::entryFunctionNameDC in the\n/// module specified by OptixProgramGroupCallables::moduleDC.\n/// The address of the SBT entry is calculated by OptixShaderBindingTable::callablesRecordBase + ( OptixShaderBindingTable::callablesRecordStrideInBytes * sbtIndex ).\n/// \n/// Behavior is undefined if there is no direct callable program at the specified SBT entry.\n/// \n/// Behavior is undefined if the number of arguments that are being passed in does not match the number of\n/// parameters expected by the program that is called. 
In that case an exception of type OPTIX_EXCEPTION_CODE_CALLABLE_PARAMETER_MISMATCH \n/// will be thrown if OPTIX_EXCEPTION_FLAG_DEBUG was specified for the OptixPipelineCompileOptions::exceptionFlags.\n///\n/// \\param[in] sbtIndex The offset of the SBT entry of the direct callable program to call relative to OptixShaderBindingTable::callablesRecordBase.\n/// \\param[in] args The arguments to pass to the direct callable program.\ntemplate <typename ReturnT, typename... ArgTypes>\nstatic __forceinline__ __device__ ReturnT optixDirectCall( unsigned int sbtIndex, ArgTypes... args );\n\n\n/// Creates a call to the continuation callable program at the specified SBT entry.\n/// \n/// This will call the program that was specified in the OptixProgramGroupCallables::entryFunctionNameCC in the\n/// module specified by OptixProgramGroupCallables::moduleCC.\n/// The address of the SBT entry is calculated by OptixShaderBindingTable::callablesRecordBase + ( OptixShaderBindingTable::callablesRecordStrideInBytes * sbtIndex ).\n/// As opposed to direct callable programs, continuation callable programs are allowed to call optixTrace recursively.\n/// \n/// Behavior is undefined if there is no continuation callable program at the specified SBT entry.\n/// \n/// Behavior is undefined if the number of arguments that are being passed in does not match the number of\n/// parameters expected by the program that is called. In that case an exception of type OPTIX_EXCEPTION_CODE_CALLABLE_PARAMETER_MISMATCH \n/// will be thrown if OPTIX_EXCEPTION_FLAG_DEBUG was specified for the OptixPipelineCompileOptions::exceptionFlags.\n///\n/// \\param[in] sbtIndex The offset of the SBT entry of the continuation callable program to call relative to OptixShaderBindingTable::callablesRecordBase.\n/// \\param[in] args The arguments to pass to the continuation callable program.\ntemplate <typename ReturnT, typename... ArgTypes>\nstatic __forceinline__ __device__ ReturnT optixContinuationCall( unsigned int sbtIndex, ArgTypes... 
args );\n\n\n/// optixTexFootprint2D calculates the footprint of a corresponding 2D texture fetch (non-mipmapped).\n///\n/// On Turing and subsequent architectures, a texture footprint instruction allows user programs to\n/// determine the set of texels that would be accessed by an equivalent filtered texture lookup.\n///\n/// \\param[in] tex      CUDA texture object (cast to 64-bit integer)\n/// \\param[in] texInfo  Texture info packed into 32-bit integer, described below.\n/// \\param[in] x        Texture coordinate\n/// \\param[in] y        Texture coordinate\n/// \\param[out] singleMipLevel  Result indicating whether the footprint spans only a single miplevel.\n///\n/// The texture info argument is a packed 32-bit integer with the following layout:\n///\n///   texInfo[31:29] = reserved (3 bits)\n///   texInfo[28:24] = miplevel count (5 bits)\n///   texInfo[23:20] = log2 of tile width (4 bits)\n///   texInfo[19:16] = log2 of tile height (4 bits)\n///   texInfo[15:10] = reserved (6 bits)\n///   texInfo[9:8]   = horizontal wrap mode (2 bits) (CUaddress_mode)\n///   texInfo[7:6]   = vertical wrap mode (2 bits) (CUaddress_mode)\n///   texInfo[5]     = mipmap filter mode (1 bit) (CUfilter_mode)\n///   texInfo[4:0]   = maximum anisotropy (5 bits)\n///\n/// Returns a 16-byte structure (as a uint4) that stores the footprint of a texture request at a\n/// particular \"granularity\", which has the following layout:\n///\n///    struct Texture2DFootprint\n///    {\n///        unsigned long long mask;\n///        unsigned int tileY : 12;\n///        unsigned int reserved1 : 4;\n///        unsigned int dx : 3;\n///        unsigned int dy : 3;\n///        unsigned int reserved2 : 2;\n///        unsigned int granularity : 4;\n///        unsigned int reserved3 : 4;\n///        unsigned int tileX : 12;\n///        unsigned int level : 4;\n///        unsigned int reserved4 : 16;\n///    };\n///\n/// The granularity indicates the size of texel groups that are represented by an 8x8 bitmask. For\n/// example, a granularity of 12 indicates texel groups that are 128x64 texels in size. In a\n/// footprint call, The returned granularity will either be the actual granularity of the result, or\n/// 0 if the footprint call was able to honor the requested granularity (the usual case).\n///\n/// level is the mip level of the returned footprint. Two footprint calls are needed to get the\n/// complete footprint when a texture call spans multiple mip levels.\n///\n/// mask is an 8x8 bitmask of texel groups that are covered, or partially covered, by the footprint.\n/// tileX and tileY give the starting position of the mask in 8x8 texel-group blocks.  For example,\n/// suppose a granularity of 12 (128x64 texels), and tileX=3 and tileY=4. In this case, bit 0 of the\n/// mask (the low order bit) corresponds to texel group coordinates (3*8, 4*8), and texel\n/// coordinates (3*8*128, 4*8*64), within the specified mip level.\n///\n/// If nonzero, dx and dy specify a \"toroidal rotation\" of the bitmask.  Toroidal rotation of a\n/// coordinate in the mask simply means that its value is reduced by 8.  
Continuing the example from\n/// above, if dx=0 and dy=0 the mask covers texel groups (3*8, 4*8) to (3*8+7, 4*8+7) inclusive.\n/// If, on the other hand, dx=2, the rightmost 2 columns in the mask have their x coordinates\n/// reduced by 8, and similarly for dy.\n///\n/// See the OptiX SDK for sample code that illustrates how to unpack the result.\nstatic __forceinline__ __device__ uint4 optixTexFootprint2D( unsigned long long tex, unsigned int texInfo, float x, float y, unsigned int* singleMipLevel );\n\n/// optixTexFootprint2DLod calculates the footprint of a corresponding 2D texture fetch (tex2DLod)\n/// \\param[in] tex      CUDA texture object (cast to 64-bit integer)\n/// \\param[in] texInfo  Texture info packed into 32-bit integer, described below.\n/// \\param[in] x        Texture coordinate\n/// \\param[in] y        Texture coordinate\n/// \\param[in] level    Level of detail (lod)\n/// \\param[in] coarse   Requests footprint from coarse miplevel, when the footprint spans two levels.\n/// \\param[out] singleMipLevel  Result indicating whether the footprint spans only a single miplevel.\n/// \\see #optixTexFootprint2D(unsigned long long,unsigned int,float,float,unsigned int*)\nstatic __forceinline__ __device__ uint4\noptixTexFootprint2DLod( unsigned long long tex, unsigned int texInfo, float x, float y, float level, bool coarse, unsigned int* singleMipLevel );\n\n/// optixTexFootprint2DGrad calculates the footprint of a corresponding 2D texture fetch (tex2DGrad)\n/// \\param[in] tex      CUDA texture object (cast to 64-bit integer)\n/// \\param[in] texInfo  Texture info packed into 32-bit integer, described below.\n/// \\param[in] x        Texture coordinate\n/// \\param[in] y        Texture coordinate\n/// \\param[in] dPdx_x   Derivative of x coordinate, which determines level of detail.\n/// \\param[in] dPdx_y   Derivative of x coordinate, which determines level of detail.\n/// \\param[in] dPdy_x   Derivative of y coordinate, which determines level of detail.\n/// \\param[in] dPdy_y   Derivative of y coordinate, which determines level of detail.\n/// \\param[in] coarse   Requests footprint from coarse miplevel, when the footprint spans two levels.\n/// \\param[out] singleMipLevel  Result indicating whether the footprint spans only a single miplevel.\n/// \\see #optixTexFootprint2D(unsigned long long,unsigned int,float,float,unsigned int*)\nstatic __forceinline__ __device__ uint4 optixTexFootprint2DGrad( unsigned long long tex,\n                                                                 unsigned int       texInfo,\n                                                                 float              x,\n                                                                 float              y,\n                                                                 float              dPdx_x,\n                                                                 float              dPdx_y,\n                                                                 float              dPdy_x,\n                                                                 float              dPdy_y,\n                                                                 bool               coarse,\n                                                                 unsigned int*      singleMipLevel );\n\n/*@}*/  // end group optix_device_api\n\n#include \"internal/optix_7_device_impl.h\"\n\n#endif  // __optix_optix_7_device_h__\n"
  },
  {
    "path": "render/optixutils/include/optix_7_host.h",
    "content": "/*\n * Copyright (c) 2021 NVIDIA Corporation.  All rights reserved.\n *\n * NVIDIA Corporation and its licensors retain all intellectual property and proprietary\n * rights in and to this software, related documentation and any modifications thereto.\n * Any use, reproduction, disclosure or distribution of this software and related\n * documentation without an express license agreement from NVIDIA Corporation is strictly\n * prohibited.\n *\n * TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, THIS SOFTWARE IS PROVIDED *AS IS*\n * AND NVIDIA AND ITS SUPPLIERS DISCLAIM ALL WARRANTIES, EITHER EXPRESS OR IMPLIED,\n * INCLUDING, BUT NOT LIMITED TO, IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A\n * PARTICULAR PURPOSE.  IN NO EVENT SHALL NVIDIA OR ITS SUPPLIERS BE LIABLE FOR ANY\n * SPECIAL, INCIDENTAL, INDIRECT, OR CONSEQUENTIAL DAMAGES WHATSOEVER (INCLUDING, WITHOUT\n * LIMITATION, DAMAGES FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF\n * BUSINESS INFORMATION, OR ANY OTHER PECUNIARY LOSS) ARISING OUT OF THE USE OF OR\n * INABILITY TO USE THIS SOFTWARE, EVEN IF NVIDIA HAS BEEN ADVISED OF THE POSSIBILITY OF\n * SUCH DAMAGES\n */\n\n/// @file\n/// @author NVIDIA Corporation\n/// @brief  OptiX public API header\n///\n/// OptiX host include file -- includes the host api if compiling host code.\n/// For the math library routines include optix_math.h\n\n#if !defined( __OPTIX_INCLUDE_INTERNAL_HEADERS__ )\n#error(\"optix_7_host.h is an internal header file and must not be used directly.  Please use optix_host.h or optix.h instead.\")\n#endif\n\n#ifndef __optix_optix_7_host_h__\n#define __optix_optix_7_host_h__\n\n#include \"optix_7_types.h\"\n#if !defined( OPTIX_DONT_INCLUDE_CUDA )\n// If OPTIX_DONT_INCLUDE_CUDA is defined, cuda driver types must be defined through other\n// means before including optix headers.\n#include <cuda.h>\n#endif\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\n/// \\defgroup optix_host_api Host API\n/// \\brief OptiX Host API\n\n/// \\defgroup optix_host_api_error_handling Error handling\n/// \\ingroup optix_host_api\n//@{\n\n/// Returns a string containing the name of an error code in the enum.\n///\n/// Output is a string representation of the enum.  For example \"OPTIX_SUCCESS\" for\n/// OPTIX_SUCCESS and \"OPTIX_ERROR_INVALID_VALUE\" for OPTIX_ERROR_INVALID_VALUE.\n///\n/// If the error code is not recognized, \"Unrecognized OptixResult code\" is returned.\n///\n/// \\param[in] result  OptixResult enum to generate string name for\n///\n/// \\see #optixGetErrorString\nconst char* optixGetErrorName( OptixResult result );\n\n/// Returns the description string for an error code.\n///\n/// Output is a string description of the enum.  For example \"Success\" for\n/// OPTIX_SUCCESS and \"Invalid value\" for OPTIX_ERROR_INVALID_VALUE.\n///\n/// If the error code is not recognized, \"Unrecognized OptixResult code\" is returned.\n///\n/// \\param[in] result  OptixResult enum to generate string description for\n///\n/// \\see #optixGetErrorName\nconst char* optixGetErrorString( OptixResult result );\n\n//@}\n/// \\defgroup optix_host_api_device_context Device context\n/// \\ingroup optix_host_api\n//@{\n\n/// Create a device context associated with the CUDA context specified with 'fromContext'.\n///\n/// If zero is specified for 'fromContext', OptiX will use the current CUDA context. 
The\n/// CUDA context should be initialized before calling optixDeviceContextCreate.\n///\n/// \\param[in] fromContext\n/// \\param[in] options\n/// \\param[out] context\n/// \\return\n/// - OPTIX_ERROR_CUDA_NOT_INITIALIZED\n///   If using zero for 'fromContext' and CUDA has not been initialized yet on the calling\n///   thread.\n/// - OPTIX_ERROR_CUDA_ERROR\n///   CUDA operation failed.\n/// - OPTIX_ERROR_HOST_OUT_OF_MEMORY\n///   Heap allocation failed.\n/// - OPTIX_ERROR_INTERNAL_ERROR\n///   Internal error\nOptixResult optixDeviceContextCreate( CUcontext fromContext, const OptixDeviceContextOptions* options, OptixDeviceContext* context );\n\n/// Destroys all CPU and GPU state associated with the device.\n///\n/// It will attempt to block on CUDA streams that have launch work outstanding.\n///\n/// Any API objects, such as OptixModule and OptixPipeline, not already destroyed will be\n/// destroyed.\n///\n/// Thread safety: A device context must not be destroyed while it is still in use by concurrent API calls in other threads.\nOptixResult optixDeviceContextDestroy( OptixDeviceContext context );\n\n/// Query properties of a device context.\n///\n/// \\param[in] context     the device context to query the property for\n/// \\param[in] property    the property to query\n/// \\param[out] value      pointer to the returned\n/// \\param[in] sizeInBytes size of output\nOptixResult optixDeviceContextGetProperty( OptixDeviceContext context, OptixDeviceProperty property, void* value, size_t sizeInBytes );\n\n/// Sets the current log callback method.\n///\n/// See #OptixLogCallback for more details.\n///\n/// Thread safety: It is guaranteed that the callback itself (callbackFunction and callbackData) are updated atomically.\n/// It is not guaranteed that the callback itself (callbackFunction and callbackData) and the callbackLevel are updated\n/// atomically. It is unspecified when concurrent API calls using the same context start to make use of the new\n/// callback method.\n///\n/// \\param[in] context          the device context\n/// \\param[in] callbackFunction the callback function to call\n/// \\param[in] callbackData     pointer to data passed to callback function while invoking it\n/// \\param[in] callbackLevel    callback level\nOptixResult optixDeviceContextSetLogCallback( OptixDeviceContext context,\n                                              OptixLogCallback   callbackFunction,\n                                              void*              callbackData,\n                                              unsigned int       callbackLevel );\n\n/// Enables or disables the disk cache.\n///\n/// If caching was previously disabled, enabling it will attempt to initialize\n/// the disk cache database using the currently configured cache location. 
An\n/// error will be returned if initialization fails.\n///\n/// Note that no in-memory cache is used, so no caching behavior will be observed if the disk cache\n/// is disabled.\n///\n/// The cache can be disabled by setting the environment variable OPTIX_CACHE_MAXSIZE=0.\n/// The environment variable takes precedence over this setting.\n/// See #optixDeviceContextSetCacheDatabaseSizes for additional information.\n///\n/// Note that the disk cache can be disabled by the environment variable, but it cannot be enabled\n/// via the environment if it is disabled via the API.\n///\n/// \\param[in] context the device context\n/// \\param[in] enabled 1 to enabled, 0 to disable\nOptixResult optixDeviceContextSetCacheEnabled( OptixDeviceContext context,\n                                               int                enabled );\n\n/// Sets the location of the disk cache.\n///\n/// The location is specified by a directory. This directory should not be used for other purposes\n/// and will be created if it does not exist. An error will be returned if is not possible to\n/// create the disk cache at the specified location for any reason (e.g., the path is invalid or\n/// the directory is not writable). Caching will be disabled if the disk cache cannot be\n/// initialized in the new location. If caching is disabled, no error will be returned until caching\n/// is enabled. If the disk cache is located on a network file share, behavior is undefined.\n///\n/// The location of the disk cache can be overridden with the environment variable OPTIX_CACHE_PATH.\n/// The environment variable takes precedence over this setting.\n///\n/// The default location depends on the operating system:\n/// - Windows: %LOCALAPPDATA%\\\\NVIDIA\\\\OptixCache\n/// - Linux:   /var/tmp/OptixCache_\\<username\\> (or /tmp/OptixCache_\\<username\\> if the first choice is not usable),\n///            the underscore and username suffix are omitted if the username cannot be obtained\n/// - MacOS X: /Library/Application Support/NVIDIA/OptixCache\n///\n/// \\param[in] context  the device context\n/// \\param[in] location directory of disk cache\nOptixResult optixDeviceContextSetCacheLocation( OptixDeviceContext context, const char* location );\n\n/// Sets the low and high water marks for disk cache garbage collection.\n///\n/// Garbage collection is triggered when a new entry is written to the cache and\n/// the current cache data size plus the size of the cache entry that is about\n/// to be inserted exceeds the high water mark. Garbage collection proceeds until\n/// the size reaches the low water mark. Garbage collection will always free enough\n/// space to insert the new entry without exceeding the low water mark. Setting\n/// either limit to zero will disable garbage collection. An error will be returned\n/// if both limits are non-zero and the high water mark is smaller than the low water mark.\n///\n/// Note that garbage collection is performed only on writes to the disk cache. 
No garbage\n/// collection is triggered on disk cache initialization or immediately when calling this function,\n/// but on subsequent inserting of data into the database.\n///\n/// If the size of a compiled module exceeds the value configured for the high water\n/// mark and garbage collection is enabled, the module will not be added to the cache\n/// and a warning will be added to the log.\n///\n/// The high water mark can be overridden with the environment variable OPTIX_CACHE_MAXSIZE.\n/// The environment variable takes precedence over the function parameters.  The low water mark\n/// will be set to half the value of OPTIX_CACHE_MAXSIZE.  Setting OPTIX_CACHE_MAXSIZE to 0 will\n/// disable the disk cache, but will not alter the contents of the cache.  Negative and non-integer\n/// values will be ignored.\n///\n/// \\param[in] context       the device context\n/// \\param[in] lowWaterMark  the low water mark\n/// \\param[in] highWaterMark the high water mark\nOptixResult optixDeviceContextSetCacheDatabaseSizes( OptixDeviceContext context, size_t lowWaterMark, size_t highWaterMark );\n\n/// Indicates whether the disk cache is enabled or disabled.\n///\n/// \\param[in] context   the device context\n/// \\param[out] enabled  1 if enabled, 0 if disabled\nOptixResult optixDeviceContextGetCacheEnabled( OptixDeviceContext context, int* enabled );\n/// Returns the location of the disk cache.  If the cache has been disabled by setting the environment\n/// variable OPTIX_CACHE_MAXSIZE=0, this function will return an empy string.\n///\n/// \\param[in] context      the device context\n/// \\param[out] location    directory of disk cache, null terminated if locationSize > 0\n/// \\param[in] locationSize locationSize\nOptixResult optixDeviceContextGetCacheLocation( OptixDeviceContext context, char* location, size_t locationSize );\n\n/// Returns the low and high water marks for disk cache garbage collection.  If the cache has been disabled by\n/// setting the environment variable OPTIX_CACHE_MAXSIZE=0, this function will return 0 for the low and high\n/// water marks.\n///\n/// \\param[in] context        the device context\n/// \\param[out] lowWaterMark  the low water mark\n/// \\param[out] highWaterMark the high water mark\nOptixResult optixDeviceContextGetCacheDatabaseSizes( OptixDeviceContext context, size_t* lowWaterMark, size_t* highWaterMark );\n\n//@}\n/// \\defgroup optix_host_api_pipelines Pipelines\n/// \\ingroup optix_host_api\n//@{\n\n/// logString is an optional buffer that contains compiler feedback and errors.  This\n/// information is also passed to the context logger (if enabled), however it may be\n/// difficult to correlate output to the logger to specific API invocations when using\n/// multiple threads.  The output to logString will only contain feedback for this specific\n/// invocation of this API call.\n///\n/// logStringSize as input should be a pointer to the number of bytes backing logString.\n/// Upon return it contains the length of the log message (including the null terminator)\n/// which may be greater than the input value.  In this case, the log message will be\n/// truncated to fit into logString.\n///\n/// If logString or logStringSize are NULL, no output is written to logString.  If\n/// logStringSize points to a value that is zero, no output is written.  
This does not\n/// affect output to the context logger if enabled.\n///\n/// \\param[in] context\n/// \\param[in] pipelineCompileOptions\n/// \\param[in] pipelineLinkOptions\n/// \\param[in] programGroups          array of ProgramGroup objects\n/// \\param[in] numProgramGroups       number of ProgramGroup objects\n/// \\param[out] logString             Information will be written to this string. If logStringSize > 0 logString will be null terminated.\n/// \\param[in,out] logStringSize\n/// \\param[out] pipeline\nOptixResult optixPipelineCreate( OptixDeviceContext                 context,\n                                 const OptixPipelineCompileOptions* pipelineCompileOptions,\n                                 const OptixPipelineLinkOptions*    pipelineLinkOptions,\n                                 const OptixProgramGroup*           programGroups,\n                                 unsigned int                       numProgramGroups,\n                                 char*                              logString,\n                                 size_t*                            logStringSize,\n                                 OptixPipeline*                     pipeline );\n\n/// Thread safety: A pipeline must not be destroyed while it is still in use by concurrent API calls in other threads.\nOptixResult optixPipelineDestroy( OptixPipeline pipeline );\n\n/// Sets the stack sizes for a pipeline.\n///\n/// Users are encouraged to see the programming guide and the implementations of the helper functions\n/// to understand how to construct the stack sizes based on their particular needs.\n///\n/// If this method is not used, an internal default implementation is used. The default implementation is correct (but\n/// not necessarily optimal) as long as the maximum depth of call trees of CC and DC programs is at most 2 and no motion transforms are used.\n///\n/// The maxTraversableGraphDepth responds to the maximal number of traversables visited when calling trace.\n/// Every acceleration structure and motion transform count as one level of traversal.\n/// E.g., for a simple IAS (instance acceleration structure) -> GAS (geometry acceleration structure)\n/// traversal graph, the maxTraversableGraphDepth is two.\n/// For IAS -> MT (motion transform) -> GAS, the maxTraversableGraphDepth is three.\n/// Note that it does not matter whether a IAS or GAS has motion or not, it always counts as one.\n/// Launching optix with exceptions turned on (see #OPTIX_EXCEPTION_FLAG_TRACE_DEPTH) will throw an exception\n/// if the specified maxTraversableGraphDepth is too small.\n///\n/// \\param[in] pipeline                             The pipeline to configure the stack size for.\n/// \\param[in] directCallableStackSizeFromTraversal The direct stack size requirement for direct callables invoked from IS or AH.\n/// \\param[in] directCallableStackSizeFromState     The direct stack size requirement for direct callables invoked from RG, MS, or CH.\n/// \\param[in] continuationStackSize                The continuation stack requirement.\n/// \\param[in] maxTraversableGraphDepth             The maximum depth of a traversable graph passed to trace.\nOptixResult optixPipelineSetStackSize( OptixPipeline pipeline,\n                                       unsigned int  directCallableStackSizeFromTraversal,\n                                       unsigned int  directCallableStackSizeFromState,\n                                       unsigned int  continuationStackSize,\n                                       unsigned int  
maxTraversableGraphDepth );\n\n//@}\n/// \\defgroup optix_host_api_modules Modules\n/// \\ingroup optix_host_api\n//@{\n\n/// logString is an optional buffer that contains compiler feedback and errors.  This\n/// information is also passed to the context logger (if enabled), however it may be\n/// difficult to correlate output to the logger to specific API invocations when using\n/// multiple threads.  The output to logString will only contain feedback for this specific\n/// invocation of this API call.\n///\n/// logStringSize as input should be a pointer to the number of bytes backing logString.\n/// Upon return it contains the length of the log message (including the null terminator)\n/// which may be greater than the input value.  In this case, the log message will be\n/// truncated to fit into logString.\n///\n/// If logString or logStringSize are NULL, no output is written to logString.  If\n/// logStringSize points to a value that is zero, no output is written.  This does not\n/// affect output to the context logger if enabled.\n///\n/// \\param[in] context\n/// \\param[in] moduleCompileOptions\n/// \\param[in] pipelineCompileOptions All modules in a pipeline need to use the same values for the pipeline compile options.\n/// \\param[in] PTX                    Pointer to the PTX input string.\n/// \\param[in] PTXsize                Parsing proceeds up to PTXsize characters, or the first NUL byte, whichever occurs first.\n/// \\param[out] logString             Information will be written to this string. If logStringSize > 0 logString will be null terminated.\n/// \\param[in,out] logStringSize\n/// \\param[out] module\n///\n/// \\return OPTIX_ERROR_INVALID_VALUE - context is 0, moduleCompileOptions is 0, pipelineCompileOptions is 0, PTX is 0, module is 0.\nOptixResult optixModuleCreateFromPTX( OptixDeviceContext                 context,\n                                      const OptixModuleCompileOptions*   moduleCompileOptions,\n                                      const OptixPipelineCompileOptions* pipelineCompileOptions,\n                                      const char*                        PTX,\n                                      size_t                             PTXsize,\n                                      char*                              logString,\n                                      size_t*                            logStringSize,\n                                      OptixModule*                       module );\n\n/// Call for OptixModule objects created with optixModuleCreateFromPTX and optixModuleDeserialize.\n///\n/// Modules must not be destroyed while they are still used by any program group.\n///\n/// Thread safety: A module must not be destroyed while it is still in use by concurrent API calls in other threads.\nOptixResult optixModuleDestroy( OptixModule module );\n\n/// Returns a module containing the intersection program for the built-in primitive type specified\n/// by the builtinISOptions.  This module must be used as the moduleIS for the OptixProgramGroupHitgroup\n/// in any SBT record for that primitive type.  
(The entryFunctionNameIS should be null.)\nOptixResult optixBuiltinISModuleGet( OptixDeviceContext                 context,\n                                     const OptixModuleCompileOptions*   moduleCompileOptions,\n                                     const OptixPipelineCompileOptions* pipelineCompileOptions,\n                                     const OptixBuiltinISOptions*       builtinISOptions,\n                                     OptixModule*                       builtinModule );\n\n//@}\n/// \\defgroup optix_host_api_program_groups Program groups\n/// \\ingroup optix_host_api\n//@{\n\n/// Returns the stack sizes for the given program group.\n///\n/// \\param[in] programGroup the program group\n/// \\param[out] stackSizes  the corresponding stack sizes\nOptixResult optixProgramGroupGetStackSize( OptixProgramGroup programGroup, OptixStackSizes* stackSizes );\n\n/// logString is an optional buffer that contains compiler feedback and errors.  This\n/// information is also passed to the context logger (if enabled), however it may be\n/// difficult to correlate output to the logger to specific API invocations when using\n/// multiple threads.  The output to logString will only contain feedback for this specific\n/// invocation of this API call.\n///\n/// logStringSize as input should be a pointer to the number of bytes backing logString.\n/// Upon return it contains the length of the log message (including the null terminator)\n/// which may be greater than the input value.  In this case, the log message will be\n/// truncated to fit into logString.\n///\n/// If logString or logStringSize are NULL, no output is written to logString.  If\n/// logStringSize points to a value that is zero, no output is written.  This does not\n/// affect output to the context logger if enabled.\n///\n/// Creates numProgramGroups OptiXProgramGroup objects from the specified\n/// OptixProgramGroupDesc array.  The size of the arrays must match.\n///\n/// \\param[in] context\n/// \\param[in] programDescriptions    N * OptixProgramGroupDesc\n/// \\param[in] numProgramGroups       N\n/// \\param[in] options\n/// \\param[out] logString             Information will be written to this string. If logStringSize > 0 logString will be null terminated.\n/// \\param[in,out] logStringSize\n/// \\param[out] programGroups\nOptixResult optixProgramGroupCreate( OptixDeviceContext              context,\n                                     const OptixProgramGroupDesc*    programDescriptions,\n                                     unsigned int                    numProgramGroups,\n                                     const OptixProgramGroupOptions* options,\n                                     char*                           logString,\n                                     size_t*                         logStringSize,\n                                     OptixProgramGroup*              programGroups );\n\n/// Thread safety: A program group must not be destroyed while it is still in use by concurrent API calls in other threads.\nOptixResult optixProgramGroupDestroy( OptixProgramGroup programGroup );\n\n//@}\n/// \\defgroup optix_host_api_launches Launches\n/// \\ingroup optix_host_api\n//@{\n\n/// Where the magic happens.\n///\n/// The stream and pipeline must belong to the same device context.  Multiple launches\n/// may be issues in parallel from multiple threads to different streams.\n///\n/// pipelineParamsSize number of bytes are copied from the device memory pointed to by\n/// pipelineParams before launch.  
It is an error if pipelineParamsSize is greater than the\n/// size of the variable declared in modules and identified by\n/// OptixPipelineCompileOptions::pipelineLaunchParamsVariableName. If the launch params\n/// variable was optimized out or not found in the modules linked to the pipeline then\n/// the pipelineParams and pipelineParamsSize parameters are ignored.\n///\n/// sbt points to the shader binding table, which defines shader\n/// groupings and their resources. See the SBT spec.\n///\n/// \\param[in] pipeline\n/// \\param[in] stream\n/// \\param[in] pipelineParams\n/// \\param[in] pipelineParamsSize\n/// \\param[in] sbt\n/// \\param[in] width              number of elements to compute\n/// \\param[in] height             number of elements to compute\n/// \\param[in] depth              number of elements to compute\n///\n/// Thread safety: In the current implementation concurrent launches to the same pipeline are not\n/// supported.  Concurrent launches require separate OptixPipeline objects.\nOptixResult optixLaunch( OptixPipeline                  pipeline,\n                         CUstream                       stream,\n                         CUdeviceptr                    pipelineParams,\n                         size_t                         pipelineParamsSize,\n                         const OptixShaderBindingTable* sbt,\n                         unsigned int                   width,\n                         unsigned int                   height,\n                         unsigned int                   depth );\n\n/// \\param[in]  programGroup               the program group containing the program(s)\n/// \\param[out] sbtRecordHeaderHostPointer  the result sbt record header\nOptixResult optixSbtRecordPackHeader( OptixProgramGroup programGroup, void* sbtRecordHeaderHostPointer );\n\n//@}\n/// \\defgroup optix_host_api_acceleration_structures Acceleration structures\n/// \\ingroup optix_host_api\n//@{\n\n/// \\param[in] context        device context of the pipeline\n/// \\param[in] accelOptions   accel options\n/// \\param[in] buildInputs    an array of OptixBuildInput objects\n/// \\param[in] numBuildInputs number of elements in buildInputs (must be at least 1)\n/// \\param[out] bufferSizes   fills in buffer sizes\nOptixResult optixAccelComputeMemoryUsage( OptixDeviceContext            context,\n                                          const OptixAccelBuildOptions* accelOptions,\n                                          const OptixBuildInput*        buildInputs,\n                                          unsigned int                  numBuildInputs,\n                                          OptixAccelBufferSizes*        bufferSizes );\n\n/// \\param[in] context\n/// \\param[in] stream\n/// \\param[in] accelOptions             accel options\n/// \\param[in] buildInputs              an array of OptixBuildInput objects\n/// \\param[in] numBuildInputs           must be >= 1 for GAS, and == 1 for IAS\n/// \\param[in] tempBuffer               must be a multiple of OPTIX_ACCEL_BUFFER_BYTE_ALIGNMENT\n/// \\param[in] tempBufferSizeInBytes\n/// \\param[in] outputBuffer             must be a multiple of OPTIX_ACCEL_BUFFER_BYTE_ALIGNMENT\n/// \\param[in] outputBufferSizeInBytes\n/// \\param[out] outputHandle\n/// \\param[out] emittedProperties        types of requested properties and output buffers\n/// \\param[in] numEmittedProperties      number of post-build properties to populate (may be zero)\nOptixResult optixAccelBuild( OptixDeviceContext            context,\n               
              CUstream                      stream,\n                             const OptixAccelBuildOptions* accelOptions,\n                             const OptixBuildInput*        buildInputs,\n                             unsigned int                  numBuildInputs,\n                             CUdeviceptr                   tempBuffer,\n                             size_t                        tempBufferSizeInBytes,\n                             CUdeviceptr                   outputBuffer,\n                             size_t                        outputBufferSizeInBytes,\n                             OptixTraversableHandle*       outputHandle,\n                             const OptixAccelEmitDesc*     emittedProperties,\n                             unsigned int                  numEmittedProperties );\n\n/// Obtain relocation information, stored in OptixAccelRelocationInfo, for a given context\n/// and acceleration structure's traversable handle.\n///\n/// The relocation information can be passed to optixAccelCheckRelocationCompatibility to\n/// determine if an acceleration structure, referenced by 'handle', can be relocated to a\n/// different device's memory space (see #optixAccelCheckRelocationCompatibility).\n///\n/// When used with optixAccelRelocate, it provides data necessary for doing the relocation.\n///\n/// If the acceleration structure data associated with 'handle' is copied multiple times,\n/// the same OptixAccelRelocationInfo can also be used on all copies.\n///\n/// \\param[in] context\n/// \\param[in] handle\n/// \\param[out] info\n/// \\return OPTIX_ERROR_INVALID_VALUE will be returned for traversable handles that are not from\n/// acceleration structure builds.\nOptixResult optixAccelGetRelocationInfo( OptixDeviceContext context, OptixTraversableHandle handle, OptixAccelRelocationInfo* info );\n\n/// Checks if an acceleration structure built using another OptixDeviceContext (that was\n/// used to fill in 'info') is compatible with the OptixDeviceContext specified in the\n/// 'context' parameter.\n///\n/// Any device is always compatible with itself.\n///\n/// \\param[in] context\n/// \\param[in] info\n/// \\param[out] compatible If OPTIX_SUCCESS is returned 'compatible' will have the value of either:\n/// - 0: This context is not compatible with acceleration structure data associated with 'info'.\n/// - 1: This context is compatible.\nOptixResult optixAccelCheckRelocationCompatibility( OptixDeviceContext context, const OptixAccelRelocationInfo* info, int* compatible );\n\n/// optixAccelRelocate is called to update the acceleration structure after it has been\n/// relocated.  Relocation is necessary when the acceleration structure's location in device\n/// memory has changed.  optixAccelRelocate does not copy the memory.  This function only\n/// operates on the relocated memory who's new location is specified by 'targetAccel'.\n/// optixAccelRelocate also returns the new OptixTraversableHandle associated with\n/// 'targetAccel'.  The original memory (source) is not required to be valid, only the\n/// OptixAccelRelocationInfo.\n///\n/// Before copying the data and calling optixAccelRelocate,\n/// optixAccelCheckRelocationCompatibility should be called to ensure the copy will be\n/// compatible with the destination device context.\n///\n/// The memory pointed to by 'targetAccel' should be allocated with the same size as the\n/// source acceleration.  
Similar to the 'outputBuffer' used in optixAccelBuild, this\n/// pointer must be a multiple of OPTIX_ACCEL_BUFFER_BYTE_ALIGNMENT.\n///\n/// The memory in 'targetAccel' must be allocated as long as the accel is in use.\n///\n/// When relocating an accel that contains instances, 'instanceTraversableHandles' and\n/// 'numInstanceTraversableHandles' should be supplied.  These are the traversable handles\n/// of the instances.  These can be used when also relocating the instances.  No updates to\n/// the bounds are performed.  Use optixAccelBuild to update the bounds.\n/// 'instanceTraversableHandles' and 'numInstanceTraversableHandles' may be zero when\n/// relocating bottom level accel (i.e. an accel with no instances).\n///\n/// \\param[in] context\n/// \\param[in] stream\n/// \\param[in] info\n/// \\param[in] instanceTraversableHandles\n/// \\param[in] numInstanceTraversableHandles\n/// \\param[in] targetAccel\n/// \\param[in] targetAccelSizeInBytes\n/// \\param[out] targetHandle\nOptixResult optixAccelRelocate( OptixDeviceContext              context,\n                                CUstream                        stream,\n                                const OptixAccelRelocationInfo* info,\n                                CUdeviceptr                     instanceTraversableHandles,\n                                size_t                          numInstanceTraversableHandles,\n                                CUdeviceptr                     targetAccel,\n                                size_t                          targetAccelSizeInBytes,\n                                OptixTraversableHandle*         targetHandle );\n\n/// After building an acceleration structure, it can be copied in a compacted form to reduce\n/// memory.  In order to be compacted, OPTIX_BUILD_FLAG_ALLOW_COMPACTION must be supplied in\n/// OptixAccelBuildOptions::buildFlags passed to optixAccelBuild.\n///\n/// 'outputBuffer' is the pointer to where the compacted acceleration structure will be\n/// written.  This pointer must be a multiple of OPTIX_ACCEL_BUFFER_BYTE_ALIGNMENT.\n///\n/// The size of the memory specified in 'outputBufferSizeInBytes' should be at least the\n/// value computed using the OPTIX_PROPERTY_TYPE_COMPACTED_SIZE that was reported during\n/// optixAccelBuild.\n///\n/// \\param[in] context\n/// \\param[in] stream\n/// \\param[in] inputHandle\n/// \\param[in] outputBuffer\n/// \\param[in] outputBufferSizeInBytes\n/// \\param[out] outputHandle\nOptixResult optixAccelCompact( OptixDeviceContext      context,\n                               CUstream                stream,\n                               OptixTraversableHandle  inputHandle,\n                               CUdeviceptr             outputBuffer,\n                               size_t                  outputBufferSizeInBytes,\n                               OptixTraversableHandle* outputHandle );\n\n/// \\param[in] onDevice\n/// \\param[in] pointer            pointer to traversable allocated in OptixDeviceContext. This pointer must be a multiple of OPTIX_TRANSFORM_BYTE_ALIGNMENT\n/// \\param[in] traversableType    Type of OptixTraversableHandle to create\n/// \\param[out] traversableHandle traversable handle. 
traversableHandle must be in host memory\nOptixResult optixConvertPointerToTraversableHandle( OptixDeviceContext      onDevice,\n                                                    CUdeviceptr             pointer,\n                                                    OptixTraversableType    traversableType,\n                                                    OptixTraversableHandle* traversableHandle );\n\n//@}\n/// \\defgroup optix_host_api_denoiser Denoiser\n/// \\ingroup optix_host_api\n//@{\n\n/// Creates a denoiser object with the given options, using built-in inference models\n///\n/// 'modelKind' selects the model used for inference.\n/// Inference for the built-in models can be guided (giving hints to improve image quality) with\n/// albedo and normal vector images in the guide layer (see 'optixDenoiserInvoke').\n/// Use of these images must be enabled in 'OptixDenoiserOptions'.\n///\n/// \\param[in] context\n/// \\param[in] modelKind\n/// \\param[in] options\n/// \\param[out] denoiser\nOptixResult optixDenoiserCreate( OptixDeviceContext context,\n                                 OptixDenoiserModelKind modelKind,\n                                 const OptixDenoiserOptions* options,\n                                 OptixDenoiser* denoiser );\n\n/// Creates a denoiser object with the given options, using a provided inference model\n///\n/// 'userData' and 'userDataSizeInBytes' provide a user model for inference.\n/// The memory passed in userData will be accessed only during the invocation of this function and\n/// can be freed after it returns.\n/// The user model must export only one weight set which determines both the model kind and the\n/// required set of guide images.\n///\n/// \\param[in] context\n/// \\param[in] userData\n/// \\param[in] userDataSizeInBytes\n/// \\param[out] denoiser\nOptixResult optixDenoiserCreateWithUserModel( OptixDeviceContext context,\n                                              const void* userData, size_t userDataSizeInBytes, OptixDenoiser* denoiser );\n\n/// Destroys the denoiser object and any associated host resources.\nOptixResult optixDenoiserDestroy( OptixDenoiser denoiser );\n\n/// Computes the GPU memory resources required to execute the denoiser.\n///\n/// Memory for state and scratch buffers must be allocated with the sizes in 'returnSizes' and scratch memory\n/// passed to optixDenoiserSetup, optixDenoiserInvoke,\n/// optixDenoiserComputeIntensity and optixDenoiserComputeAverageColor.\n/// For tiled denoising an overlap area must be added to each tile on all sides which increases the amount of\n/// memory needed to denoise a tile. 
In case of tiling use withOverlapScratchSizeInBytes.\n/// If only full resolution images are denoised, withoutOverlapScratchSizeInBytes can be used which is always\n/// smaller than withOverlapScratchSizeInBytes.\n///\n/// 'outputWidth' and 'outputHeight' is the dimension of the image to be denoised (without overlap in case tiling\n/// is being used).\n/// 'outputWidth' and 'outputHeight' must be greater than or equal to the dimensions passed to optixDenoiserSetup.\n///\n/// \\param[in] denoiser\n/// \\param[in] outputWidth\n/// \\param[in] outputHeight\n/// \\param[out] returnSizes\nOptixResult optixDenoiserComputeMemoryResources( const OptixDenoiser denoiser,\n                                                 unsigned int        outputWidth,\n                                                 unsigned int        outputHeight,\n                                                 OptixDenoiserSizes* returnSizes );\n\n/// Initializes the state required by the denoiser.\n///\n/// 'inputWidth' and 'inputHeight' must include overlap on both sides of the image if tiling is being used. The overlap is\n/// returned by #optixDenoiserComputeMemoryResources.\n/// For subsequent calls to #optixDenoiserInvoke 'inputWidth' and 'inputHeight' are the maximum dimensions\n/// of the input layers. Dimensions of the input layers passed to #optixDenoiserInvoke may be different in each\n/// invocation however they always must be smaller than 'inputWidth' and 'inputHeight' passed to #optixDenoiserSetup.\n///\n/// \\param[in] denoiser\n/// \\param[in] stream\n/// \\param[in] inputWidth\n/// \\param[in] inputHeight\n/// \\param[in] denoiserState\n/// \\param[in] denoiserStateSizeInBytes\n/// \\param[in] scratch\n/// \\param[in] scratchSizeInBytes\nOptixResult optixDenoiserSetup( OptixDenoiser denoiser,\n                                CUstream      stream,\n                                unsigned int  inputWidth,\n                                unsigned int  inputHeight,\n                                CUdeviceptr   denoiserState,\n                                size_t        denoiserStateSizeInBytes,\n                                CUdeviceptr   scratch,\n                                size_t        scratchSizeInBytes );\n\n/// Invokes denoiser on a set of input data and produces at least one output image.\n/// State memory must be available during the execution of the\n/// denoiser (or until optixDenoiserSetup is called with a new state memory pointer).\n/// Scratch memory passed is used only for the duration of this function.\n/// Scratch and state memory sizes must have a size greater than or equal to the sizes as returned by\n/// optixDenoiserComputeMemoryResources.\n///\n/// 'inputOffsetX' and 'inputOffsetY' are pixel offsets in the 'inputLayers' image\n/// specifying the beginning of the image without overlap. When denoising an entire image without tiling\n/// there is no overlap and 'inputOffsetX' and 'inputOffsetY' must be zero. When denoising a tile which is\n/// adjacent to one of the four sides of the entire image the corresponding offsets must also be zero since\n/// there is no overlap at the side adjacent to the image border.\n///\n/// 'guideLayer' provides additional information to the denoiser. When providing albedo and normal vector\n/// guide images, the corresponding fields in the 'OptixDenoiserOptions' must be\n/// enabled, see #optixDenoiserCreate.\n/// 'guideLayer' must not be null. 
If a guide image in 'OptixDenoiserOptions' is not enabled, the\n/// corresponding image in 'OptixDenoiserGuideLayer' is ignored.\n///\n/// If OPTIX_DENOISER_MODEL_KIND_TEMPORAL is selected, a 2d flow image must be given in 'OptixDenoiserGuideLayer'.\n/// It describes for each pixel the flow from the previous to the current frame (a 2d vector in pixel space).\n/// The denoised beauty/AOV of the previous frame must be given in 'previousOutput'.\n/// If this image is not available in the first frame of a sequence, the noisy beauty/AOV from the first frame\n/// and zero flow vectors could be given as a substitute.\n/// For non-temporal model kinds the flow image in 'OptixDenoiserGuideLayer' is ignored.\n/// 'previousOutput' and\n/// 'output' may refer to the same buffer, i.e. 'previousOutput' is first read by this function and later\n/// overwritten with the denoised result. 'output' can be passed as 'previousOutput' to the next frame.\n/// In other model kinds (not temporal) 'previousOutput' is ignored.\n///\n/// The beauty layer must be given as the first entry in 'layers'.\n/// In AOV type model kinds (OPTIX_DENOISER_MODEL_KIND_AOV or in user defined models implementing\n/// kernel-prediction) additional layers for the AOV images can be given.\n/// In each layer the noisy input image is given in 'input', the denoised output is written into the\n/// 'output' image. input and output images may refer to the same buffer, with the restriction that\n/// the pixel formats must be identical for input and output when the blend mode is selected (see\n/// #OptixDenoiserParams).\n///\n/// If OPTIX_DENOISER_MODEL_KIND_TEMPORAL is selected, the\n/// normal vector guide image must be given as 3d vectors in camera space. In the other models only\n/// the x and y channels are used and other channels are ignored.\n///\n/// \\param[in] denoiser\n/// \\param[in] stream\n/// \\param[in] params\n/// \\param[in] denoiserState\n/// \\param[in] denoiserStateSizeInBytes\n/// \\param[in] guideLayer\n/// \\param[in] layers\n/// \\param[in] numLayers\n/// \\param[in] inputOffsetX\n/// \\param[in] inputOffsetY\n/// \\param[in] outputLayer\n/// \\param[in] scratch\n/// \\param[in] scratchSizeInBytes\nOptixResult optixDenoiserInvoke( OptixDenoiser                   denoiser,\n                                 CUstream                        stream,\n                                 const OptixDenoiserParams*      params,\n                                 CUdeviceptr                     denoiserState,\n                                 size_t                          denoiserStateSizeInBytes,\n                                 const OptixDenoiserGuideLayer*  guideLayer,\n                                 const OptixDenoiserLayer*       layers,\n                                 unsigned int                    numLayers,\n                                 unsigned int                    inputOffsetX,\n                                 unsigned int                    inputOffsetY,\n                                 CUdeviceptr                     scratch,\n                                 size_t                          scratchSizeInBytes );\n\n/// Computes the logarithmic average intensity of the given image. The returned value 'outputIntensity'\n/// is multiplied with the RGB values of the input image/tile in optixDenoiserInvoke if given in the parameter\n/// OptixDenoiserParams::hdrIntensity (otherwise 'hdrIntensity' must be a null pointer). 
This is useful for\n/// denoising HDR images which are very dark or bright.\n/// When denoising tiles the intensity of the entire image should be computed, i.e. not per tile to get\n/// consistent results.\n///\n/// For each RGB pixel in the inputImage the intensity is calculated and summed if it is greater than 1e-8f:\n/// intensity = log(r * 0.212586f + g * 0.715170f + b * 0.072200f).\n/// The function returns 0.18 / exp(sum of intensities / number of summed pixels).\n/// More details could be found in the Reinhard tonemapping paper:\n/// http://www.cmap.polytechnique.fr/~peyre/cours/x2005signal/hdr_photographic.pdf\n///\n/// This function needs scratch memory with a size of at least\n/// sizeof( int ) * ( 2 + inputImage::width * inputImage::height ). When denoising entire images (no tiling)\n/// the same scratch memory as passed to optixDenoiserInvoke could be used.\n//\n/// data type unsigned char is not supported for 'inputImage', it must be 3 or 4 component half/float.\n///\n/// \\param[in] denoiser\n/// \\param[in] stream\n/// \\param[in] inputImage\n/// \\param[out] outputIntensity    single float\n/// \\param[in] scratch\n/// \\param[in] scratchSizeInBytes\nOptixResult optixDenoiserComputeIntensity( OptixDenoiser       denoiser,\n                                           CUstream            stream,\n                                           const OptixImage2D* inputImage,\n                                           CUdeviceptr         outputIntensity,\n                                           CUdeviceptr         scratch,\n                                           size_t              scratchSizeInBytes );\n\n/// Compute average logarithmic for each of the first three channels for the given image.\n/// When denoising tiles the intensity of the entire image should be computed, i.e. not per tile to get\n/// consistent results.\n/// This function needs scratch memory with a size of at least\n/// sizeof( int ) * ( 3 + 3 * inputImage::width * inputImage::height ). When denoising entire images (no tiling)\n/// the same scratch memory as passed to optixDenoiserInvoke could be used.\n///\n/// data type unsigned char is not supported for 'inputImage', it must be 3 or 4 component half/float.\n///\n/// \\param[in] denoiser\n/// \\param[in] stream\n/// \\param[in] inputImage\n/// \\param[out] outputAverageColor three floats\n/// \\param[in] scratch\n/// \\param[in] scratchSizeInBytes\nOptixResult optixDenoiserComputeAverageColor( OptixDenoiser       denoiser,\n                                              CUstream            stream,\n                                              const OptixImage2D* inputImage,\n                                              CUdeviceptr         outputAverageColor,\n                                              CUdeviceptr         scratch,\n                                              size_t              scratchSizeInBytes );\n\n//@}\n\n#ifdef __cplusplus\n}\n#endif\n\n#include \"optix_function_table.h\"\n\n#endif  // __optix_optix_7_host_h__\n"
  },
  {
    "path": "render/optixutils/include/optix_7_types.h",
    "content": "\n/*\n * Copyright (c) 2021 NVIDIA Corporation.  All rights reserved.\n *\n * NVIDIA Corporation and its licensors retain all intellectual property and proprietary\n * rights in and to this software, related documentation and any modifications thereto.\n * Any use, reproduction, disclosure or distribution of this software and related\n * documentation without an express license agreement from NVIDIA Corporation is strictly\n * prohibited.\n *\n * TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, THIS SOFTWARE IS PROVIDED *AS IS*\n * AND NVIDIA AND ITS SUPPLIERS DISCLAIM ALL WARRANTIES, EITHER EXPRESS OR IMPLIED,\n * INCLUDING, BUT NOT LIMITED TO, IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A\n * PARTICULAR PURPOSE.  IN NO EVENT SHALL NVIDIA OR ITS SUPPLIERS BE LIABLE FOR ANY\n * SPECIAL, INCIDENTAL, INDIRECT, OR CONSEQUENTIAL DAMAGES WHATSOEVER (INCLUDING, WITHOUT\n * LIMITATION, DAMAGES FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF\n * BUSINESS INFORMATION, OR ANY OTHER PECUNIARY LOSS) ARISING OUT OF THE USE OF OR\n * INABILITY TO USE THIS SOFTWARE, EVEN IF NVIDIA HAS BEEN ADVISED OF THE POSSIBILITY OF\n * SUCH DAMAGES\n */\n\n/// @file\n/// @author NVIDIA Corporation\n/// @brief  OptiX public API header\n///\n/// OptiX types include file -- defines types and enums used by the API.\n/// For the math library routines include optix_math.h\n\n#if !defined( __OPTIX_INCLUDE_INTERNAL_HEADERS__ )\n#error(\"optix_7_types.h is an internal header file and must not be used directly.  Please use optix_types.h, optix_host.h, optix_device.h or optix.h instead.\")\n#endif\n\n#ifndef __optix_optix_7_types_h__\n#define __optix_optix_7_types_h__\n\n#if !defined(__CUDACC_RTC__)\n#include <stddef.h> /* for size_t */\n#endif\n\n/// \\defgroup optix_types Types\n/// \\brief OptiX Types\n\n/** \\addtogroup optix_types\n@{\n*/\n\n// This typedef should match the one in cuda.h in order to avoid compilation errors.\n#if defined(__x86_64) || defined(AMD64) || defined(_M_AMD64) || defined(__powerpc64__) || defined(__EDG_IA64_ABI)/*=NVRTC*/ || defined(__aarch64__)\n/// CUDA device pointer\ntypedef unsigned long long CUdeviceptr;\n#else\n/// CUDA device pointer\ntypedef unsigned int CUdeviceptr;\n#endif\n\n/// Opaque type representing a device context\ntypedef struct OptixDeviceContext_t* OptixDeviceContext;\n\n/// Opaque type representing a module\ntypedef struct OptixModule_t* OptixModule;\n\n/// Opaque type representing a program group\ntypedef struct OptixProgramGroup_t* OptixProgramGroup;\n\n/// Opaque type representing a pipeline\ntypedef struct OptixPipeline_t* OptixPipeline;\n\n/// Opaque type representing a denoiser instance\ntypedef struct OptixDenoiser_t* OptixDenoiser;\n\n/// Traversable handle\ntypedef unsigned long long OptixTraversableHandle;\n\n/// Visibility mask\ntypedef unsigned int OptixVisibilityMask;\n\n/// Size of the SBT record headers.\n#define OPTIX_SBT_RECORD_HEADER_SIZE ( (size_t)32 )\n\n/// Alignment requirement for device pointers in OptixShaderBindingTable.\n#define OPTIX_SBT_RECORD_ALIGNMENT 16ull\n\n/// Alignment requirement for output and temporay buffers for acceleration structures.\n#define OPTIX_ACCEL_BUFFER_BYTE_ALIGNMENT 128ull\n\n/// Alignment requirement for OptixBuildInputInstanceArray::instances.\n#define OPTIX_INSTANCE_BYTE_ALIGNMENT 16ull\n\n/// Alignment requirement for OptixBuildInputCustomPrimitiveArray::aabbBuffers\n#define OPTIX_AABB_BUFFER_BYTE_ALIGNMENT 8ull\n\n/// Alignment requirement for 
OptixBuildInputTriangleArray::preTransform\n#define OPTIX_GEOMETRY_TRANSFORM_BYTE_ALIGNMENT 16ull\n\n/// Alignment requirement for OptixStaticTransform, OptixMatrixMotionTransform, OptixSRTMotionTransform.\n#define OPTIX_TRANSFORM_BYTE_ALIGNMENT 64ull\n\n/// Maximum number of registers allowed. Defaults to no explicit limit.\n#define OPTIX_COMPILE_DEFAULT_MAX_REGISTER_COUNT 0\n\n/// Maximum number of payload values allowed.\n#define OPTIX_COMPILE_DEFAULT_MAX_PAYLOAD_VALUE_COUNT 8\n\n\n/// Result codes returned from API functions\n///\n/// All host side API functions return OptixResult with the exception of optixGetErrorName\n/// and optixGetErrorString.  When successful OPTIX_SUCCESS is returned.  All return codes\n/// except for OPTIX_SUCCESS should be assumed to be errors as opposed to a warning.\n///\n/// \\see #optixGetErrorName(), #optixGetErrorString()\ntypedef enum OptixResult\n{\n    OPTIX_SUCCESS                               = 0,\n    OPTIX_ERROR_INVALID_VALUE                   = 7001,\n    OPTIX_ERROR_HOST_OUT_OF_MEMORY              = 7002,\n    OPTIX_ERROR_INVALID_OPERATION               = 7003,\n    OPTIX_ERROR_FILE_IO_ERROR                   = 7004,\n    OPTIX_ERROR_INVALID_FILE_FORMAT             = 7005,\n    OPTIX_ERROR_DISK_CACHE_INVALID_PATH         = 7010,\n    OPTIX_ERROR_DISK_CACHE_PERMISSION_ERROR     = 7011,\n    OPTIX_ERROR_DISK_CACHE_DATABASE_ERROR       = 7012,\n    OPTIX_ERROR_DISK_CACHE_INVALID_DATA         = 7013,\n    OPTIX_ERROR_LAUNCH_FAILURE                  = 7050,\n    OPTIX_ERROR_INVALID_DEVICE_CONTEXT          = 7051,\n    OPTIX_ERROR_CUDA_NOT_INITIALIZED            = 7052,\n    OPTIX_ERROR_VALIDATION_FAILURE              = 7053,\n    OPTIX_ERROR_INVALID_PTX                     = 7200,\n    OPTIX_ERROR_INVALID_LAUNCH_PARAMETER        = 7201,\n    OPTIX_ERROR_INVALID_PAYLOAD_ACCESS          = 7202,\n    OPTIX_ERROR_INVALID_ATTRIBUTE_ACCESS        = 7203,\n    OPTIX_ERROR_INVALID_FUNCTION_USE            = 7204,\n    OPTIX_ERROR_INVALID_FUNCTION_ARGUMENTS      = 7205,\n    OPTIX_ERROR_PIPELINE_OUT_OF_CONSTANT_MEMORY = 7250,\n    OPTIX_ERROR_PIPELINE_LINK_ERROR             = 7251,\n    OPTIX_ERROR_ILLEGAL_DURING_TASK_EXECUTE     = 7270,\n    OPTIX_ERROR_INTERNAL_COMPILER_ERROR         = 7299,\n    OPTIX_ERROR_DENOISER_MODEL_NOT_SET          = 7300,\n    OPTIX_ERROR_DENOISER_NOT_INITIALIZED        = 7301,\n    OPTIX_ERROR_ACCEL_NOT_COMPATIBLE            = 7400,\n    OPTIX_ERROR_NOT_SUPPORTED                   = 7800,\n    OPTIX_ERROR_UNSUPPORTED_ABI_VERSION         = 7801,\n    OPTIX_ERROR_FUNCTION_TABLE_SIZE_MISMATCH    = 7802,\n    OPTIX_ERROR_INVALID_ENTRY_FUNCTION_OPTIONS  = 7803,\n    OPTIX_ERROR_LIBRARY_NOT_FOUND               = 7804,\n    OPTIX_ERROR_ENTRY_SYMBOL_NOT_FOUND          = 7805,\n    OPTIX_ERROR_LIBRARY_UNLOAD_FAILURE          = 7806,\n    OPTIX_ERROR_CUDA_ERROR                      = 7900,\n    OPTIX_ERROR_INTERNAL_ERROR                  = 7990,\n    OPTIX_ERROR_UNKNOWN                         = 7999,\n} OptixResult;\n\n/// Parameters used for #optixDeviceContextGetProperty()\n///\n/// \\see #optixDeviceContextGetProperty()\ntypedef enum OptixDeviceProperty\n{\n    /// Maximum value for OptixPipelineLinkOptions::maxTraceDepth. 
sizeof( unsigned int )\n    OPTIX_DEVICE_PROPERTY_LIMIT_MAX_TRACE_DEPTH = 0x2001,\n\n    /// Maximum value to pass into optixPipelineSetStackSize for parameter\n    /// maxTraversableGraphDepth.v sizeof( unsigned int )\n    OPTIX_DEVICE_PROPERTY_LIMIT_MAX_TRAVERSABLE_GRAPH_DEPTH = 0x2002,\n\n    /// The maximum number of primitives (over all build inputs) as input to a single\n    /// Geometry Acceleration Structure (GAS). sizeof( unsigned int )\n    OPTIX_DEVICE_PROPERTY_LIMIT_MAX_PRIMITIVES_PER_GAS = 0x2003,\n\n    /// The maximum number of instances (over all build inputs) as input to a single\n    /// Instance Acceleration Structure (IAS). sizeof( unsigned int )\n    OPTIX_DEVICE_PROPERTY_LIMIT_MAX_INSTANCES_PER_IAS = 0x2004,\n\n    /// The RT core version supported by the device (0 for no support, 10 for version\n    /// 1.0). sizeof( unsigned int )\n    OPTIX_DEVICE_PROPERTY_RTCORE_VERSION = 0x2005,\n\n    /// The maximum value for #OptixInstance::instanceId. sizeof( unsigned int )\n    OPTIX_DEVICE_PROPERTY_LIMIT_MAX_INSTANCE_ID = 0x2006,\n\n    /// The number of bits available for the #OptixInstance::visibilityMask.\n    /// Higher bits must be set to zero. sizeof( unsigned int )\n    OPTIX_DEVICE_PROPERTY_LIMIT_NUM_BITS_INSTANCE_VISIBILITY_MASK = 0x2007,\n\n    /// The maximum number of instances that can be added to a single Instance\n    /// Acceleration Structure (IAS). sizeof( unsigned int )\n    OPTIX_DEVICE_PROPERTY_LIMIT_MAX_SBT_RECORDS_PER_GAS = 0x2008,\n\n    /// The maximum value for #OptixInstance::sbtOffset. sizeof( unsigned int )\n    OPTIX_DEVICE_PROPERTY_LIMIT_MAX_SBT_OFFSET = 0x2009,\n} OptixDeviceProperty;\n\n/// Type of the callback function used for log messages.\n///\n/// \\param[in] level      The log level indicates the severity of the message. See below for\n///                       possible values.\n/// \\param[in] tag        A terse message category description (e.g., 'SCENE STAT').\n/// \\param[in] message    Null terminated log message (without newline at the end).\n/// \\param[in] cbdata     Callback data that was provided with the callback pointer.\n///\n/// It is the users responsibility to ensure thread safety within this function.\n///\n/// The following log levels are defined.\n///\n///   0   disable   Setting the callback level will disable all messages.  The callback\n///                 function will not be called in this case.\n///   1   fatal     A non-recoverable error. 
The context and/or OptiX itself might no longer\n///                 be in a usable state.\n///   2   error     A recoverable error, e.g., when passing invalid call parameters.\n///   3   warning   Hints that OptiX might not behave exactly as requested by the user or\n///                 may perform slower than expected.\n///   4   print     Status or progress messages.\n///\n/// Higher levels might occur.\n///\n/// \\see #optixDeviceContextSetLogCallback(), #OptixDeviceContextOptions\ntypedef void ( *OptixLogCallback )( unsigned int level, const char* tag, const char* message, void* cbdata );\n\n/// Validation mode settings.\n///\n/// When enabled, certain device code utilities will be enabled to provide as good debug and\n/// error checking facilities as possible.\n///\n///\n/// \\see #optixDeviceContextCreate()\ntypedef enum OptixDeviceContextValidationMode\n{\n    OPTIX_DEVICE_CONTEXT_VALIDATION_MODE_OFF = 0,\n    OPTIX_DEVICE_CONTEXT_VALIDATION_MODE_ALL = 0xFFFFFFFF\n} OptixDeviceContextValidationMode;\n\n/// Parameters used for #optixDeviceContextCreate()\n///\n/// \\see #optixDeviceContextCreate()\ntypedef struct OptixDeviceContextOptions\n{\n    /// Function pointer used when OptiX wishes to generate messages\n    OptixLogCallback logCallbackFunction;\n    /// Pointer stored and passed to logCallbackFunction when a message is generated\n    void* logCallbackData;\n    /// Maximum callback level to generate message for (see #OptixLogCallback)\n    int logCallbackLevel;\n    /// Validation mode of context.\n    OptixDeviceContextValidationMode validationMode;\n} OptixDeviceContextOptions;\n\n/// Flags used by #OptixBuildInputTriangleArray::flags\n/// and #OptixBuildInputCurveArray::flag\n/// and #OptixBuildInputCustomPrimitiveArray::flags\ntypedef enum OptixGeometryFlags\n{\n    /// No flags set\n    OPTIX_GEOMETRY_FLAG_NONE = 0,\n\n    /// Disables the invocation of the anyhit program.\n    /// Can be overridden by OPTIX_INSTANCE_FLAG_ENFORCE_ANYHIT and OPTIX_RAY_FLAG_ENFORCE_ANYHIT.\n    OPTIX_GEOMETRY_FLAG_DISABLE_ANYHIT = 1u << 0,\n\n    /// If set, an intersection with the primitive will trigger one and only one\n    /// invocation of the anyhit program.  
Otherwise, the anyhit program may be invoked\n    /// more than once.\n    OPTIX_GEOMETRY_FLAG_REQUIRE_SINGLE_ANYHIT_CALL = 1u << 1\n} OptixGeometryFlags;\n\n/// Legacy type: A subset of the hit kinds for built-in primitive intersections.\n/// It is preferred to use optixGetPrimitiveType(), together with\n/// optixIsFrontFaceHit() or optixIsBackFaceHit().\n///\n/// \\see #optixGetHitKind()\ntypedef enum OptixHitKind\n{\n    /// Ray hit the triangle on the front face\n    OPTIX_HIT_KIND_TRIANGLE_FRONT_FACE = 0xFE,\n    /// Ray hit the triangle on the back face\n    OPTIX_HIT_KIND_TRIANGLE_BACK_FACE = 0xFF\n} OptixHitKind;\n\n/// Format of indices used int #OptixBuildInputTriangleArray::indexFormat.\ntypedef enum OptixIndicesFormat\n{\n    /// No indices, this format must only be used in combination with triangle soups, i.e., numIndexTriplets must be zero\n    OPTIX_INDICES_FORMAT_NONE = 0,\n    /// Three shorts\n    OPTIX_INDICES_FORMAT_UNSIGNED_SHORT3 = 0x2102,\n    /// Three ints\n    OPTIX_INDICES_FORMAT_UNSIGNED_INT3 = 0x2103\n} OptixIndicesFormat;\n\n/// Format of vertices used in #OptixBuildInputTriangleArray::vertexFormat.\ntypedef enum OptixVertexFormat\n{\n    OPTIX_VERTEX_FORMAT_NONE      = 0,       ///< No vertices\n    OPTIX_VERTEX_FORMAT_FLOAT3    = 0x2121,  ///< Vertices are represented by three floats\n    OPTIX_VERTEX_FORMAT_FLOAT2    = 0x2122,  ///< Vertices are represented by two floats\n    OPTIX_VERTEX_FORMAT_HALF3     = 0x2123,  ///< Vertices are represented by three halfs\n    OPTIX_VERTEX_FORMAT_HALF2     = 0x2124,  ///< Vertices are represented by two halfs\n    OPTIX_VERTEX_FORMAT_SNORM16_3 = 0x2125,\n    OPTIX_VERTEX_FORMAT_SNORM16_2 = 0x2126\n} OptixVertexFormat;\n\n/// Format of transform used in #OptixBuildInputTriangleArray::transformFormat.\ntypedef enum OptixTransformFormat\n{\n    OPTIX_TRANSFORM_FORMAT_NONE           = 0,       ///< no transform, default for zero initialization\n    OPTIX_TRANSFORM_FORMAT_MATRIX_FLOAT12 = 0x21E1,  ///< 3x4 row major affine matrix\n} OptixTransformFormat;\n\n/// Triangle inputs\n///\n/// \\see #OptixBuildInput::triangleArray\ntypedef struct OptixBuildInputTriangleArray\n{\n    /// Points to host array of device pointers, one per motion step. Host array size must match the number of\n    /// motion keys as set in #OptixMotionOptions (or an array of size 1 if OptixMotionOptions::numKeys is set\n    /// to 0 or 1). Each per motion key device pointer must point to an array of vertices of the\n    /// triangles in the format as described by vertexFormat. The minimum alignment must match the natural\n    /// alignment of the type as specified in the vertexFormat, i.e., for OPTIX_VERTEX_FORMAT_FLOATX 4-byte,\n    /// for all others a 2-byte alignment. However, an 16-byte stride (and buffer alignment) is recommended for\n    /// vertices of format OPTIX_VERTEX_FORMAT_FLOAT3 for GAS build performance.\n    const CUdeviceptr* vertexBuffers;\n\n    /// Number of vertices in each of buffer in OptixBuildInputTriangleArray::vertexBuffers.\n    unsigned int numVertices;\n\n    /// \\see #OptixVertexFormat\n    OptixVertexFormat vertexFormat;\n\n    /// Stride between vertices. 
If set to zero, vertices are assumed to be tightly\n    /// packed and stride is inferred from vertexFormat.\n    unsigned int vertexStrideInBytes;\n\n    /// Optional pointer to array of 16 or 32-bit int triplets, one triplet per triangle.\n    /// The minimum alignment must match the natural alignment of the type as specified in the indexFormat, i.e.,\n    /// for OPTIX_INDICES_FORMAT_UNSIGNED_INT3 4-byte and for OPTIX_INDICES_FORMAT_UNSIGNED_SHORT3 a 2-byte alignment.\n    CUdeviceptr indexBuffer;\n\n    /// Size of array in OptixBuildInputTriangleArray::indexBuffer. For build, needs to be zero if indexBuffer is \\c nullptr.\n    unsigned int numIndexTriplets;\n\n    /// \\see #OptixIndicesFormat\n    OptixIndicesFormat indexFormat;\n\n    /// Stride between triplets of indices. If set to zero, indices are assumed to be tightly\n    /// packed and stride is inferred from indexFormat.\n    unsigned int indexStrideInBytes;\n\n    /// Optional pointer to array of floats\n    /// representing a 3x4 row major affine\n    /// transformation matrix. This pointer must be a multiple of OPTIX_GEOMETRY_TRANSFORM_BYTE_ALIGNMENT\n    CUdeviceptr preTransform;\n\n    /// Array of flags, to specify flags per sbt record,\n    /// combinations of OptixGeometryFlags describing the\n    /// primitive behavior, size must match numSbtRecords\n    const unsigned int* flags;\n\n    /// Number of sbt records available to the sbt index offset override.\n    unsigned int numSbtRecords;\n\n    /// Device pointer to per-primitive local sbt index offset buffer. May be NULL.\n    /// Every entry must be in range [0,numSbtRecords-1].\n    /// Size needs to be the number of primitives.\n    CUdeviceptr sbtIndexOffsetBuffer;\n\n    /// Size of type of the sbt index offset. Needs to be 0, 1, 2 or 4 (8, 16 or 32 bit).\n    unsigned int sbtIndexOffsetSizeInBytes;\n\n    /// Stride between the index offsets. 
If set to zero, the offsets are assumed to be tightly\n    /// packed and the stride matches the size of the type (sbtIndexOffsetSizeInBytes).\n    unsigned int sbtIndexOffsetStrideInBytes;\n\n    /// Primitive index bias, applied in optixGetPrimitiveIndex().\n    /// Sum of primitiveIndexOffset and number of triangles must not overflow 32bits.\n    unsigned int primitiveIndexOffset;\n\n    /// \\see #OptixTransformFormat\n    OptixTransformFormat transformFormat;\n} OptixBuildInputTriangleArray;\n\n/// Builtin primitive types\n///\ntypedef enum OptixPrimitiveType\n{\n    /// Custom primitive.\n    OPTIX_PRIMITIVE_TYPE_CUSTOM                        = 0x2500,\n    /// B-spline curve of degree 2 with circular cross-section.\n    OPTIX_PRIMITIVE_TYPE_ROUND_QUADRATIC_BSPLINE       = 0x2501,\n    /// B-spline curve of degree 3 with circular cross-section.\n    OPTIX_PRIMITIVE_TYPE_ROUND_CUBIC_BSPLINE           = 0x2502,\n    /// Piecewise linear curve with circular cross-section.\n    OPTIX_PRIMITIVE_TYPE_ROUND_LINEAR                  = 0x2503,\n    /// Triangle.\n    OPTIX_PRIMITIVE_TYPE_TRIANGLE                      = 0x2531,\n} OptixPrimitiveType;\n\n/// Builtin flags may be bitwise combined.\n///\n/// \\see #OptixPipelineCompileOptions::usesPrimitiveTypeFlags\ntypedef enum OptixPrimitiveTypeFlags\n{\n    /// Custom primitive.\n    OPTIX_PRIMITIVE_TYPE_FLAGS_CUSTOM                  = 1 << 0,\n    /// B-spline curve of degree 2 with circular cross-section.\n    OPTIX_PRIMITIVE_TYPE_FLAGS_ROUND_QUADRATIC_BSPLINE = 1 << 1,\n    /// B-spline curve of degree 3 with circular cross-section.\n    OPTIX_PRIMITIVE_TYPE_FLAGS_ROUND_CUBIC_BSPLINE     = 1 << 2,\n    /// Piecewise linear curve with circular cross-section.\n    OPTIX_PRIMITIVE_TYPE_FLAGS_ROUND_LINEAR            = 1 << 3,\n    /// Triangle.\n    OPTIX_PRIMITIVE_TYPE_FLAGS_TRIANGLE                = 1 << 31,\n} OptixPrimitiveTypeFlags;\n\n/// Curve inputs\n///\n/// A curve is a swept surface defined by a 3D spline curve and a varying width (radius). A curve (or \"strand\") of\n/// degree d (3=cubic, 2=quadratic, 1=linear) is represented by N > d vertices and N width values, and comprises N - d segments.\n/// Each segment is defined by d+1 consecutive vertices. Each curve may have a different number of vertices.\n///\n/// OptiX describes the curve array as a list of curve segments. The primitive id is the segment number.\n/// It is the user's responsibility to maintain a mapping between curves and curve segments.\n/// Each index buffer entry i = indexBuffer[primid] specifies the start of a curve segment,\n/// represented by d+1 consecutive vertices in the vertex buffer,\n/// and d+1 consecutive widths in the width buffer. Width is interpolated the same\n/// way vertices are interpolated, that is, using the curve basis.\n///\n/// Each curves build input has only one SBT record.\n/// To create curves with different materials in the same BVH, use multiple build inputs.\n///\n/// \\see #OptixBuildInput::curveArray\ntypedef struct OptixBuildInputCurveArray\n{\n    /// Curve degree and basis\n    /// \\see #OptixPrimitiveType\n    OptixPrimitiveType curveType;\n    /// Number of primitives. Each primitive is a polynomial curve segment.\n    unsigned int numPrimitives;\n\n    /// Pointer to host array of device pointers, one per motion step. Host array size must match number of\n    /// motion keys as set in #OptixMotionOptions (or an array of size 1 if OptixMotionOptions::numKeys is set\n    /// to 1). 
Each per-motion-key device pointer must point to an array of floats (the vertices of the\n    /// curves).\n    const CUdeviceptr* vertexBuffers;\n    /// Number of vertices in each buffer in vertexBuffers.\n    unsigned int numVertices;\n    /// Stride between vertices. If set to zero, vertices are assumed to be tightly\n    /// packed and stride is sizeof( float3 ).\n    unsigned int vertexStrideInBytes;\n\n    /// Parallel to vertexBuffers: a device pointer per motion step, each with numVertices float values,\n    /// specifying the curve width (radius) corresponding to each vertex.\n    const CUdeviceptr* widthBuffers;\n    /// Stride between widths. If set to zero, widths are assumed to be tightly\n    /// packed and stride is sizeof( float ).\n    unsigned int widthStrideInBytes;\n\n    /// Reserved for future use.\n    const CUdeviceptr* normalBuffers;\n    /// Reserved for future use.\n    unsigned int normalStrideInBytes;\n\n    /// Device pointer to array of unsigned ints, one per curve segment.\n    /// This buffer is required (unlike for OptixBuildInputTriangleArray).\n    /// Each index is the start of degree+1 consecutive vertices in vertexBuffers,\n    /// and corresponding widths in widthBuffers and normals in normalBuffers.\n    /// These define a single segment. Size of array is numPrimitives.\n    CUdeviceptr indexBuffer;\n    /// Stride between indices. If set to zero, indices are assumed to be tightly\n    /// packed and stride is sizeof( unsigned int ).\n    unsigned int indexStrideInBytes;\n\n    /// Combination of OptixGeometryFlags describing the\n    /// primitive behavior.\n    unsigned int flag;\n\n    /// Primitive index bias, applied in optixGetPrimitiveIndex().\n    /// Sum of primitiveIndexOffset and number of primitives must not overflow 32bits.\n    unsigned int primitiveIndexOffset;\n} OptixBuildInputCurveArray;\n\n/// AABB inputs\ntypedef struct OptixAabb\n{\n    float minX;  ///< Lower extent in X direction.\n    float minY;  ///< Lower extent in Y direction.\n    float minZ;  ///< Lower extent in Z direction.\n    float maxX;  ///< Upper extent in X direction.\n    float maxY;  ///< Upper extent in Y direction.\n    float maxZ;  ///< Upper extent in Z direction.\n} OptixAabb;\n\n/// Custom primitive inputs\n///\n/// \\see #OptixBuildInput::customPrimitiveArray\ntypedef struct OptixBuildInputCustomPrimitiveArray\n{\n    /// Points to host array of device pointers to AABBs (type OptixAabb), one per motion step.\n    /// Host array size must match number of motion keys as set in OptixMotionOptions (or an array of size 1\n    /// if OptixMotionOptions::numKeys is set to 1).\n    /// Each device pointer must be a multiple of OPTIX_AABB_BUFFER_BYTE_ALIGNMENT.\n    const CUdeviceptr* aabbBuffers;\n\n    /// Number of primitives in each buffer (i.e., per motion step) in\n    /// #OptixBuildInputCustomPrimitiveArray::aabbBuffers.\n    unsigned int numPrimitives;\n\n    /// Stride between AABBs (per motion key). 
If set to zero, the aabbs are assumed to be tightly\n    /// packed and the stride is assumed to be sizeof( OptixAabb ).\n    /// If non-zero, the value must be a multiple of OPTIX_AABB_BUFFER_BYTE_ALIGNMENT.\n    unsigned int strideInBytes;\n\n    /// Array of flags, to specify flags per sbt record,\n    /// combinations of OptixGeometryFlags describing the\n    /// primitive behavior, size must match numSbtRecords\n    const unsigned int* flags;\n\n    /// Number of sbt records available to the sbt index offset override.\n    unsigned int numSbtRecords;\n\n    /// Device pointer to per-primitive local sbt index offset buffer. May be NULL.\n    /// Every entry must be in range [0,numSbtRecords-1].\n    /// Size needs to be the number of primitives.\n    CUdeviceptr sbtIndexOffsetBuffer;\n\n    /// Size of type of the sbt index offset. Needs to be 0, 1, 2 or 4 (8, 16 or 32 bit).\n    unsigned int sbtIndexOffsetSizeInBytes;\n\n    /// Stride between the index offsets. If set to zero, the offsets are assumed to be tightly\n    /// packed and the stride matches the size of the type (sbtIndexOffsetSizeInBytes).\n    unsigned int sbtIndexOffsetStrideInBytes;\n\n    /// Primitive index bias, applied in optixGetPrimitiveIndex().\n    /// Sum of primitiveIndexOffset and number of primitive must not overflow 32bits.\n    unsigned int primitiveIndexOffset;\n} OptixBuildInputCustomPrimitiveArray;\n\n/// Instance and instance pointer inputs\n///\n/// \\see #OptixBuildInput::instanceArray\ntypedef struct OptixBuildInputInstanceArray\n{\n    /// If OptixBuildInput::type is OPTIX_BUILD_INPUT_TYPE_INSTANCE_POINTERS instances and\n    /// aabbs should be interpreted as arrays of pointers instead of arrays of structs.\n    ///\n    /// This pointer must be a multiple of OPTIX_INSTANCE_BYTE_ALIGNMENT if\n    /// OptixBuildInput::type is OPTIX_BUILD_INPUT_TYPE_INSTANCES. The array elements must\n    /// be a multiple of OPTIX_INSTANCE_BYTE_ALIGNMENT if OptixBuildInput::type is\n    /// OPTIX_BUILD_INPUT_TYPE_INSTANCE_POINTERS.\n    CUdeviceptr instances;\n\n    /// Number of elements in #OptixBuildInputInstanceArray::instances.\n    unsigned int numInstances;\n} OptixBuildInputInstanceArray;\n\n/// Enum to distinguish the different build input types.\n///\n/// \\see #OptixBuildInput::type\ntypedef enum OptixBuildInputType\n{\n    /// Triangle inputs. \\see #OptixBuildInputTriangleArray\n    OPTIX_BUILD_INPUT_TYPE_TRIANGLES = 0x2141,\n    /// Custom primitive inputs. \\see #OptixBuildInputCustomPrimitiveArray\n    OPTIX_BUILD_INPUT_TYPE_CUSTOM_PRIMITIVES = 0x2142,\n    /// Instance inputs. \\see #OptixBuildInputInstanceArray\n    OPTIX_BUILD_INPUT_TYPE_INSTANCES = 0x2143,\n    /// Instance pointer inputs. \\see #OptixBuildInputInstanceArray\n    OPTIX_BUILD_INPUT_TYPE_INSTANCE_POINTERS = 0x2144,\n    /// Curve inputs. 
\\see #OptixBuildInputCurveArray\n    OPTIX_BUILD_INPUT_TYPE_CURVES = 0x2145\n} OptixBuildInputType;\n\n/// Build inputs.\n///\n/// All of them support motion and the size of the data arrays needs to match the number of motion steps\n///\n/// \\see #optixAccelComputeMemoryUsage(), #optixAccelBuild()\ntypedef struct OptixBuildInput\n{\n    /// The type of the build input.\n    OptixBuildInputType type;\n\n    union\n    {\n        /// Triangle inputs.\n        OptixBuildInputTriangleArray triangleArray;\n        /// Curve inputs.\n        OptixBuildInputCurveArray curveArray;\n        /// Custom primitive inputs.\n        OptixBuildInputCustomPrimitiveArray customPrimitiveArray;\n        /// Instance and instance pointer inputs.\n        OptixBuildInputInstanceArray instanceArray;\n        char pad[1024];\n    };\n} OptixBuildInput;\n\n// Some 32-bit tools use this header. This static_assert fails for them because\n// the default enum size is 4 bytes, rather than 8, under 32-bit compilers.\n// This #ifndef allows them to disable the static assert.\n\n// TODO Define a static assert for C/pre-C++-11\n#if defined( __cplusplus ) && __cplusplus >= 201103L\nstatic_assert( sizeof( OptixBuildInput ) == 8 + 1024, \"OptixBuildInput has wrong size\" );\n#endif\n\n/// Flags set on the #OptixInstance::flags.\n///\n/// These can be or'ed together to combine multiple flags.\ntypedef enum OptixInstanceFlags\n{\n    /// No special flag set\n    OPTIX_INSTANCE_FLAG_NONE = 0,\n\n    /// Prevent triangles from getting culled due to their orientation.\n    /// Effectively ignores ray flags\n    /// OPTIX_RAY_FLAG_CULL_BACK_FACING_TRIANGLES and OPTIX_RAY_FLAG_CULL_FRONT_FACING_TRIANGLES.\n    OPTIX_INSTANCE_FLAG_DISABLE_TRIANGLE_FACE_CULLING = 1u << 0,\n\n    /// Flip triangle orientation.\n    /// This affects front/backface culling as well as the reported face in case of a hit.\n    OPTIX_INSTANCE_FLAG_FLIP_TRIANGLE_FACING = 1u << 1,\n\n    /// Disable anyhit programs for all geometries of the instance.\n    /// Can be overridden by OPTIX_RAY_FLAG_ENFORCE_ANYHIT.\n    /// This flag is mutually exclusive with OPTIX_INSTANCE_FLAG_ENFORCE_ANYHIT.\n    OPTIX_INSTANCE_FLAG_DISABLE_ANYHIT = 1u << 2,\n\n    /// Enables anyhit programs for all geometries of the instance.\n    /// Overrides OPTIX_GEOMETRY_FLAG_DISABLE_ANYHIT\n    /// Can be overridden by OPTIX_RAY_FLAG_DISABLE_ANYHIT.\n    /// This flag is mutually exclusive with OPTIX_INSTANCE_FLAG_DISABLE_ANYHIT.\n    OPTIX_INSTANCE_FLAG_ENFORCE_ANYHIT = 1u << 3,\n\n    /// Disable the instance transformation\n    OPTIX_INSTANCE_FLAG_DISABLE_TRANSFORM = 1u << 6,\n} OptixInstanceFlags;\n\n/// Instances\n///\n/// \\see #OptixBuildInputInstanceArray::instances\ntypedef struct OptixInstance\n{\n    /// affine object-to-world transformation as 3x4 matrix in row-major layout\n    float transform[12];\n\n    /// Application supplied ID. The maximal ID can be queried using OPTIX_DEVICE_PROPERTY_LIMIT_MAX_INSTANCE_ID.\n    unsigned int instanceId;\n\n    /// SBT record offset.  Will only be used for instances of geometry acceleration structure (GAS) objects.\n    /// Needs to be set to 0 for instances of instance acceleration structure (IAS) objects. The maximal SBT offset\n    /// can be queried using OPTIX_DEVICE_PROPERTY_LIMIT_MAX_INSTANCE_SBT_OFFSET.\n    unsigned int sbtOffset;\n\n    /// Visibility mask. If rayMask & instanceMask == 0 the instance is culled. 
The number of available bits can be\n    /// queried using OPTIX_DEVICE_PROPERTY_LIMIT_NUM_BITS_INSTANCE_VISIBILITY_MASK.\n    unsigned int visibilityMask;\n\n    /// Any combination of OptixInstanceFlags is allowed.\n    unsigned int flags;\n\n    /// Set with an OptixTraversableHandle.\n    OptixTraversableHandle traversableHandle;\n\n    /// round up to 80-byte, to ensure 16-byte alignment\n    unsigned int pad[2];\n} OptixInstance;\n\n/// Builder Options\n///\n/// Used for #OptixAccelBuildOptions::buildFlags. Can be or'ed together.\ntypedef enum OptixBuildFlags\n{\n    /// No special flags set.\n    OPTIX_BUILD_FLAG_NONE = 0,\n\n    /// Allow updating the build with new vertex positions with subsequent calls to\n    /// optixAccelBuild.\n    OPTIX_BUILD_FLAG_ALLOW_UPDATE = 1u << 0,\n\n    OPTIX_BUILD_FLAG_ALLOW_COMPACTION = 1u << 1,\n\n    OPTIX_BUILD_FLAG_PREFER_FAST_TRACE = 1u << 2,\n\n    OPTIX_BUILD_FLAG_PREFER_FAST_BUILD = 1u << 3,\n\n    /// Allow random access to build input vertices\n    /// See optixGetTriangleVertexData\n    ///     optixGetLinearCurveVertexData\n    ///     optixGetQuadraticBSplineVertexData\n    ///     optixGetCubicBSplineVertexData\n    OPTIX_BUILD_FLAG_ALLOW_RANDOM_VERTEX_ACCESS = 1u << 4,\n\n    /// Allow random access to instances\n    /// See optixGetInstanceTraversableFromIAS\n    OPTIX_BUILD_FLAG_ALLOW_RANDOM_INSTANCE_ACCESS = 1u << 5,\n} OptixBuildFlags;\n\n/// Enum to specify the acceleration build operation.\n///\n/// Used in OptixAccelBuildOptions, which is then passed to optixAccelBuild and\n/// optixAccelComputeMemoryUsage, this enum indicates whether to do a build or an update\n/// of the acceleration structure.\n///\n/// Acceleration structure updates utilize the same acceleration structure, but with\n/// updated bounds.  Updates are typically much faster than builds, however, large\n/// perturbations can degrade the quality of the acceleration structure.\n///\n/// \\see #optixAccelComputeMemoryUsage(), #optixAccelBuild(), #OptixAccelBuildOptions\ntypedef enum OptixBuildOperation\n{\n    /// Perform a full build operation\n    OPTIX_BUILD_OPERATION_BUILD = 0x2161,\n    /// Perform an update using new bounds\n    OPTIX_BUILD_OPERATION_UPDATE = 0x2162,\n} OptixBuildOperation;\n\n/// Enum to specify motion flags.\n///\n/// \\see #OptixMotionOptions::flags.\ntypedef enum OptixMotionFlags\n{\n    OPTIX_MOTION_FLAG_NONE         = 0,\n    OPTIX_MOTION_FLAG_START_VANISH = 1u << 0,\n    OPTIX_MOTION_FLAG_END_VANISH   = 1u << 1\n} OptixMotionFlags;\n\n/// Motion options\n///\n/// \\see #OptixAccelBuildOptions::motionOptions, #OptixMatrixMotionTransform::motionOptions,\n///      #OptixSRTMotionTransform::motionOptions\ntypedef struct OptixMotionOptions\n{\n    /// If numKeys > 1, motion is enabled. 
timeBegin,\n    /// timeEnd and flags are all ignored when motion is disabled.\n    unsigned short numKeys;\n\n    /// Combinations of #OptixMotionFlags\n    unsigned short flags;\n\n    /// Point in time where motion starts.\n    float timeBegin;\n\n    /// Point in time where motion ends.\n    float timeEnd;\n} OptixMotionOptions;\n\n/// Build options for acceleration structures.\n///\n/// \\see #optixAccelComputeMemoryUsage(), #optixAccelBuild()\ntypedef struct OptixAccelBuildOptions\n{\n    /// Combinations of OptixBuildFlags\n    unsigned int buildFlags;\n\n    /// If OPTIX_BUILD_OPERATION_UPDATE the output buffer is assumed to contain the result\n    /// of a full build with OPTIX_BUILD_FLAG_ALLOW_UPDATE set and using the same number of\n    /// primitives.  It is updated incrementally to reflect the current position of the\n    /// primitives.\n    OptixBuildOperation operation;\n\n    /// Options for motion.\n    OptixMotionOptions motionOptions;\n} OptixAccelBuildOptions;\n\n/// Struct for querying builder allocation requirements.\n///\n/// Once queried the sizes should be used to allocate device memory of at least these sizes.\n///\n/// \\see #optixAccelComputeMemoryUsage()\ntypedef struct OptixAccelBufferSizes\n{\n    /// The size in bytes required for the outputBuffer parameter to optixAccelBuild when\n    /// doing a build (OPTIX_BUILD_OPERATION_BUILD).\n    size_t outputSizeInBytes;\n\n    /// The size in bytes required for the tempBuffer parameter to optixAccelBuild when\n    /// doing a build (OPTIX_BUILD_OPERATION_BUILD).\n    size_t tempSizeInBytes;\n\n    /// The size in bytes required for the tempBuffer parameter to optixAccelBuild\n    /// when doing an update (OPTIX_BUILD_OPERATION_UPDATE).  This value can be different\n    /// than tempSizeInBytes used for a full build.  Only non-zero if\n    /// OPTIX_BUILD_FLAG_ALLOW_UPDATE flag is set in OptixAccelBuildOptions.\n    size_t tempUpdateSizeInBytes;\n} OptixAccelBufferSizes;\n\n/// Properties which can be emitted during acceleration structure build.\n///\n/// \\see #OptixAccelEmitDesc::type.\ntypedef enum OptixAccelPropertyType\n{\n    /// Size of a compacted acceleration structure. 
The device pointer points to a uint64.\n    OPTIX_PROPERTY_TYPE_COMPACTED_SIZE = 0x2181,\n\n    /// OptixAabb * numMotionSteps\n    OPTIX_PROPERTY_TYPE_AABBS = 0x2182,\n} OptixAccelPropertyType;\n\n/// Specifies a type and output destination for emitted post-build properties.\n///\n/// \\see #optixAccelBuild()\ntypedef struct OptixAccelEmitDesc\n{\n    /// Output buffer for the properties\n    CUdeviceptr result;\n\n    /// Requested property\n    OptixAccelPropertyType type;\n} OptixAccelEmitDesc;\n\n/// Used to store information related to relocation of acceleration structures.\n///\n/// \\see #optixAccelGetRelocationInfo(), #optixAccelCheckRelocationCompatibility(), #optixAccelRelocate()\ntypedef struct OptixAccelRelocationInfo\n{\n    /// Opaque data, used internally, should not be modified\n    unsigned long long info[4];\n} OptixAccelRelocationInfo;\n\n/// Static transform\n///\n/// The device address of instances of this type must be a multiple of OPTIX_TRANSFORM_BYTE_ALIGNMENT.\n///\n/// \\see #optixConvertPointerToTraversableHandle()\ntypedef struct OptixStaticTransform\n{\n    /// The traversable transformed by this transformation\n    OptixTraversableHandle child;\n\n    /// Padding to make the transformations 16 byte aligned\n    unsigned int pad[2];\n\n    /// Affine object-to-world transformation as 3x4 matrix in row-major layout\n    float transform[12];\n\n    /// Affine world-to-object transformation as 3x4 matrix in row-major layout\n    /// Must be the inverse of the transform matrix\n    float invTransform[12];\n} OptixStaticTransform;\n\n/// Represents a matrix motion transformation.\n///\n/// The device address of instances of this type must be a multiple of OPTIX_TRANSFORM_BYTE_ALIGNMENT.\n///\n/// This struct, as defined here, handles only N=2 motion keys due to the fixed array length of its transform member.\n/// The following example shows how to create instances for an arbitrary number N of motion keys:\n///\n/// \\code\n/// float matrixData[N][12];\n/// ... // setup matrixData\n///\n/// size_t transformSizeInBytes = sizeof( OptixMatrixMotionTransform ) + ( N-2 ) * 12 * sizeof( float );\n/// OptixMatrixMotionTransform* matrixMotionTransform = (OptixMatrixMotionTransform*) malloc( transformSizeInBytes );\n/// memset( matrixMotionTransform, 0, transformSizeInBytes );\n///\n/// ... // setup other members of matrixMotionTransform\n/// matrixMotionTransform->motionOptions.numKeys = N;\n/// memcpy( matrixMotionTransform->transform, matrixData, N * 12 * sizeof( float ) );\n///\n/// ... // copy matrixMotionTransform to device memory\n/// free( matrixMotionTransform )\n/// \\endcode\n///\n/// \\see #optixConvertPointerToTraversableHandle()\ntypedef struct OptixMatrixMotionTransform\n{\n    /// The traversable that is transformed by this transformation\n    OptixTraversableHandle child;\n\n    /// The motion options for this transformation\n    OptixMotionOptions motionOptions;\n\n    /// Padding to make the transformation 16 byte aligned\n    unsigned int pad[3];\n\n    /// Affine object-to-world transformation as 3x4 matrix in row-major layout\n    float transform[2][12];\n} OptixMatrixMotionTransform;\n\n/// Represents an SRT transformation.\n///\n/// An SRT transformation can represent a smooth rotation with fewer motion keys than a matrix transformation. 
Each\n/// motion key is constructed from elements taken from a matrix S, a quaternion R, and a translation T.\n///\n/// The scaling matrix\n/// \\f$S = \\begin{bmatrix} sx & a & b & pvx \\\\ 0 & sy & c & pvy \\\\ 0 & 0  & sz & pvz \\end{bmatrix}\\f$\n//      [ sx   a   b  pvx ]\n//  S = [  0  sy   c  pvy ]\n//      [  0   0  sz  pvz ]\n/// defines an affine transformation that can include scale, shear, and a translation.\n/// The translation allows to define the pivot point for the subsequent rotation.\n///\n/// The quaternion R = [ qx, qy, qz, qw ] describes a rotation  with angular component qw = cos(theta/2) and other\n/// components [ qx, qy, qz ] = sin(theta/2) * [ ax, ay, az ] where the axis [ ax, ay, az ] is normalized.\n///\n/// The translation matrix\n/// \\f$T = \\begin{bmatrix} 1 & 0 & 0 & tx \\\\ 0 & 1 & 0 & ty \\\\ 0 & 0 & 1 & tz \\end{bmatrix}\\f$\n//      [  1  0  0 tx ]\n//  T = [  0  1  0 ty ]\n//      [  0  0  1 tz ]\n/// defines another translation that is applied after the rotation. Typically, this translation includes\n/// the inverse translation from the matrix S to reverse the translation for the pivot point for R.\n///\n/// To obtain the effective transformation at time t, the elements of the components of S, R, and T will be interpolated\n/// linearly. The components are then multiplied to obtain the combined transformation C = T * R * S. The transformation\n/// C is the effective object-to-world transformations at time t, and C^(-1) is the effective world-to-object\n/// transformation at time t.\n///\n/// \\see #OptixSRTMotionTransform::srtData, #optixConvertPointerToTraversableHandle()\ntypedef struct OptixSRTData\n{\n    /// \\name Parameters describing the SRT transformation\n    /// @{\n    float sx, a, b, pvx, sy, c, pvy, sz, pvz, qx, qy, qz, qw, tx, ty, tz;\n    /// @}\n} OptixSRTData;\n\n// TODO Define a static assert for C/pre-C++-11\n#if defined( __cplusplus ) && __cplusplus >= 201103L\nstatic_assert( sizeof( OptixSRTData ) == 16 * 4, \"OptixSRTData has wrong size\" );\n#endif\n\n/// Represents an SRT motion transformation.\n///\n/// The device address of instances of this type must be a multiple of OPTIX_TRANSFORM_BYTE_ALIGNMENT.\n///\n/// This struct, as defined here, handles only N=2 motion keys due to the fixed array length of its srtData member.\n/// The following example shows how to create instances for an arbitrary number N of motion keys:\n///\n/// \\code\n/// OptixSRTData srtData[N];\n/// ... // setup srtData\n///\n/// size_t transformSizeInBytes = sizeof( OptixSRTMotionTransform ) + ( N-2 ) * sizeof( OptixSRTData );\n/// OptixSRTMotionTransform* srtMotionTransform = (OptixSRTMotionTransform*) malloc( transformSizeInBytes );\n/// memset( srtMotionTransform, 0, transformSizeInBytes );\n///\n/// ... // setup other members of srtMotionTransform\n/// srtMotionTransform->motionOptions.numKeys   = N;\n/// memcpy( srtMotionTransform->srtData, srtData, N * sizeof( OptixSRTData ) );\n///\n/// ... 
// copy srtMotionTransform to device memory\n/// free( srtMotionTransform )\n/// \\endcode\n///\n/// \\see #optixConvertPointerToTraversableHandle()\ntypedef struct OptixSRTMotionTransform\n{\n    /// The traversable transformed by this transformation\n    OptixTraversableHandle child;\n\n    /// The motion options for this transformation\n    OptixMotionOptions motionOptions;\n\n    /// Padding to make the SRT data 16 byte aligned\n    unsigned int pad[3];\n\n    /// The actual SRT data describing the transformation\n    OptixSRTData srtData[2];\n} OptixSRTMotionTransform;\n\n// TODO Define a static assert for C/pre-C++-11\n#if defined( __cplusplus ) && __cplusplus >= 201103L\nstatic_assert( sizeof( OptixSRTMotionTransform ) == 8 + 12 + 12 + 2 * 16 * 4, \"OptixSRTMotionTransform has wrong size\" );\n#endif\n\n/// Traversable Handles\n///\n/// \\see #optixConvertPointerToTraversableHandle()\ntypedef enum OptixTraversableType\n{\n    /// Static transforms. \\see #OptixStaticTransform\n    OPTIX_TRAVERSABLE_TYPE_STATIC_TRANSFORM = 0x21C1,\n    /// Matrix motion transform. \\see #OptixMatrixMotionTransform\n    OPTIX_TRAVERSABLE_TYPE_MATRIX_MOTION_TRANSFORM = 0x21C2,\n    /// SRT motion transform. \\see #OptixSRTMotionTransform\n    OPTIX_TRAVERSABLE_TYPE_SRT_MOTION_TRANSFORM = 0x21C3,\n} OptixTraversableType;\n\n/// Pixel formats used by the denoiser.\n///\n/// \\see #OptixImage2D::format\ntypedef enum OptixPixelFormat\n{\n    OPTIX_PIXEL_FORMAT_HALF2  = 0x2207,  ///< two halfs, XY\n    OPTIX_PIXEL_FORMAT_HALF3  = 0x2201,  ///< three halfs, RGB\n    OPTIX_PIXEL_FORMAT_HALF4  = 0x2202,  ///< four halfs, RGBA\n    OPTIX_PIXEL_FORMAT_FLOAT2 = 0x2208,  ///< two floats, XY\n    OPTIX_PIXEL_FORMAT_FLOAT3 = 0x2203,  ///< three floats, RGB\n    OPTIX_PIXEL_FORMAT_FLOAT4 = 0x2204,  ///< four floats, RGBA\n    OPTIX_PIXEL_FORMAT_UCHAR3 = 0x2205,  ///< three unsigned chars, RGB\n    OPTIX_PIXEL_FORMAT_UCHAR4 = 0x2206   ///< four unsigned chars, RGBA\n} OptixPixelFormat;\n\n/// Image descriptor used by the denoiser.\n///\n/// \\see #optixDenoiserInvoke(), #optixDenoiserComputeIntensity()\ntypedef struct OptixImage2D\n{\n    /// Pointer to the actual pixel data.\n    CUdeviceptr data;\n    /// Width of the image (in pixels)\n    unsigned int width;\n    /// Height of the image (in pixels)\n    unsigned int height;\n    /// Stride between subsequent rows of the image (in bytes).\n    unsigned int rowStrideInBytes;\n    /// Stride between subsequent pixels of the image (in bytes).\n    /// For now, only 0 or the value that corresponds to a dense packing of pixels (no gaps) is supported.\n    unsigned int pixelStrideInBytes;\n    /// Pixel format.\n    OptixPixelFormat format;\n} OptixImage2D;\n\n/// Model kind used by the denoiser.\n///\n/// \\see #optixDenoiserCreate\ntypedef enum OptixDenoiserModelKind\n{\n    /// Use the built-in model appropriate for low dynamic range input.\n    OPTIX_DENOISER_MODEL_KIND_LDR = 0x2322,\n\n    /// Use the built-in model appropriate for high dynamic range input.\n    OPTIX_DENOISER_MODEL_KIND_HDR = 0x2323,\n\n    /// Use the built-in model appropriate for high dynamic range input and support for AOVs\n    OPTIX_DENOISER_MODEL_KIND_AOV = 0x2324,\n\n    /// Use the built-in model appropriate for high dynamic range input, temporally stable\n    OPTIX_DENOISER_MODEL_KIND_TEMPORAL = 0x2325,\n\n} OptixDenoiserModelKind;\n\n/// Options used by the denoiser\n///\n/// \\see #optixDenoiserCreate()\ntypedef struct OptixDenoiserOptions\n{\n    // if nonzero, albedo image must be 
given in OptixDenoiserGuideLayer\n    unsigned int guideAlbedo;\n\n    // if nonzero, normal image must be given in OptixDenoiserGuideLayer\n    unsigned int guideNormal;\n} OptixDenoiserOptions;\n\n/// Guide layer for the denoiser\n///\n/// \\see #optixDenoiserInvoke()\ntypedef struct OptixDenoiserGuideLayer\n{\n    // albedo/bsdf image\n    OptixImage2D  albedo;\n\n    // normal vector image (2d or 3d pixel format)\n    OptixImage2D  normal;\n\n    // 2d flow image, pixel flow from previous to current frame for each pixel\n    OptixImage2D  flow;\n} OptixDenoiserGuideLayer;\n\n/// Input/Output layers for the denoiser\n///\n/// \\see #optixDenoiserInvoke()\ntypedef struct OptixDenoiserLayer\n{\n    // input image (beauty or AOV)\n    OptixImage2D  input;\n\n    // denoised output image from previous frame if temporal model kind selected\n    OptixImage2D  previousOutput;\n\n    // denoised output for given input\n    OptixImage2D  output;\n} OptixDenoiserLayer;\n\n/// Various parameters used by the denoiser\n///\n/// \\see #optixDenoiserInvoke()\n/// \\see #optixDenoiserComputeIntensity()\n/// \\see #optixDenoiserComputeAverageColor()\ntypedef struct OptixDenoiserParams\n{\n    /// if set to nonzero value, denoise alpha channel (if present) in first inputLayer image\n    unsigned int denoiseAlpha;\n\n    /// average log intensity of input image (default null pointer). points to a single float.\n    /// with the default (null pointer) denoised results will not be optimal for very dark or\n    /// bright input images.\n    CUdeviceptr  hdrIntensity;\n\n    /// blend factor.\n    /// If set to 0 the output is 100% of the denoised input. If set to 1, the output is 100% of\n    /// the unmodified input. Values between 0 and 1 will linearly interpolate between the denoised\n    /// and unmodified input.\n    float        blendFactor;\n\n    /// this parameter is used when the OPTIX_DENOISER_MODEL_KIND_AOV model kind is set.\n    /// average log color of input image, separate for RGB channels (default null pointer).\n    /// points to three floats. with the default (null pointer) denoised results will not be\n    /// optimal.\n    CUdeviceptr  hdrAverageColor;\n} OptixDenoiserParams;\n\n/// Various sizes related to the denoiser.\n///\n/// \\see #optixDenoiserComputeMemoryResources()\ntypedef struct OptixDenoiserSizes\n{\n    size_t       stateSizeInBytes;\n    size_t       withOverlapScratchSizeInBytes;\n    size_t       withoutOverlapScratchSizeInBytes;\n    unsigned int overlapWindowSizeInPixels;\n} OptixDenoiserSizes;\n\n/// Ray flags passed to the device function #optixTrace().  
These affect the behavior of\n/// traversal per invocation.\n///\n/// \\see #optixTrace()\ntypedef enum OptixRayFlags\n{\n    /// No change from the behavior configured for the individual AS.\n    OPTIX_RAY_FLAG_NONE = 0u,\n\n    /// Disables anyhit programs for the ray.\n    /// Overrides OPTIX_INSTANCE_FLAG_ENFORCE_ANYHIT.\n    /// This flag is mutually exclusive with OPTIX_RAY_FLAG_ENFORCE_ANYHIT,\n    /// OPTIX_RAY_FLAG_CULL_DISABLED_ANYHIT, OPTIX_RAY_FLAG_CULL_ENFORCED_ANYHIT.\n    OPTIX_RAY_FLAG_DISABLE_ANYHIT = 1u << 0,\n\n    /// Forces anyhit program execution for the ray.\n    /// Overrides OPTIX_GEOMETRY_FLAG_DISABLE_ANYHIT as well as OPTIX_INSTANCE_FLAG_DISABLE_ANYHIT.\n    /// This flag is mutually exclusive with OPTIX_RAY_FLAG_DISABLE_ANYHIT,\n    /// OPTIX_RAY_FLAG_CULL_DISABLED_ANYHIT, OPTIX_RAY_FLAG_CULL_ENFORCED_ANYHIT.\n    OPTIX_RAY_FLAG_ENFORCE_ANYHIT = 1u << 1,\n\n    /// Terminates the ray after the first hit and executes\n    /// the closesthit program of that hit.\n    OPTIX_RAY_FLAG_TERMINATE_ON_FIRST_HIT = 1u << 2,\n\n    /// Disables closesthit programs for the ray, but still executes miss program in case of a miss.\n    OPTIX_RAY_FLAG_DISABLE_CLOSESTHIT = 1u << 3,\n\n    /// Do not intersect triangle back faces\n    /// (respects a possible face change due to instance flag\n    /// OPTIX_INSTANCE_FLAG_FLIP_TRIANGLE_FACING).\n    /// This flag is mutually exclusive with OPTIX_RAY_FLAG_CULL_FRONT_FACING_TRIANGLES.\n    OPTIX_RAY_FLAG_CULL_BACK_FACING_TRIANGLES = 1u << 4,\n\n    /// Do not intersect triangle front faces\n    /// (respects a possible face change due to instance flag\n    /// OPTIX_INSTANCE_FLAG_FLIP_TRIANGLE_FACING).\n    /// This flag is mutually exclusive with OPTIX_RAY_FLAG_CULL_BACK_FACING_TRIANGLES.\n    OPTIX_RAY_FLAG_CULL_FRONT_FACING_TRIANGLES = 1u << 5,\n\n    /// Do not intersect geometry which disables anyhit programs\n    /// (due to setting geometry flag OPTIX_GEOMETRY_FLAG_DISABLE_ANYHIT or\n    /// instance flag OPTIX_INSTANCE_FLAG_DISABLE_ANYHIT).\n    /// This flag is mutually exclusive with OPTIX_RAY_FLAG_CULL_ENFORCED_ANYHIT,\n    /// OPTIX_RAY_FLAG_ENFORCE_ANYHIT, OPTIX_RAY_FLAG_DISABLE_ANYHIT.\n    OPTIX_RAY_FLAG_CULL_DISABLED_ANYHIT = 1u << 6,\n\n    /// Do not intersect geometry which have an enabled anyhit program\n    /// (due to not setting geometry flag OPTIX_GEOMETRY_FLAG_DISABLE_ANYHIT or\n    /// setting instance flag OPTIX_INSTANCE_FLAG_ENFORCE_ANYHIT).\n    /// This flag is mutually exclusive with OPTIX_RAY_FLAG_CULL_DISABLED_ANYHIT,\n    /// OPTIX_RAY_FLAG_ENFORCE_ANYHIT, OPTIX_RAY_FLAG_DISABLE_ANYHIT.\n    OPTIX_RAY_FLAG_CULL_ENFORCED_ANYHIT = 1u << 7\n} OptixRayFlags;\n\n/// Transform\n///\n/// OptixTransformType is used by the device function #optixGetTransformTypeFromHandle() to\n/// determine the type of the OptixTraversableHandle returned from\n/// optixGetTransformListHandle().\ntypedef enum OptixTransformType\n{\n    OPTIX_TRANSFORM_TYPE_NONE                    = 0, ///< Not a transformation\n    OPTIX_TRANSFORM_TYPE_STATIC_TRANSFORM        = 1, ///< \\see #OptixStaticTransform\n    OPTIX_TRANSFORM_TYPE_MATRIX_MOTION_TRANSFORM = 2, ///< \\see #OptixMatrixMotionTransform\n    OPTIX_TRANSFORM_TYPE_SRT_MOTION_TRANSFORM    = 3, ///< \\see #OptixSRTMotionTransform\n    OPTIX_TRANSFORM_TYPE_INSTANCE                = 4, ///< \\see #OptixInstance\n} OptixTransformType;\n\n/// Specifies the set of valid traversable graphs that may be\n/// passed to invocation of #optixTrace(). 
Flags may be bitwise combined.\ntypedef enum OptixTraversableGraphFlags\n{\n    ///  Used to signal that any traversable graph is valid.\n    ///  This flag is mutually exclusive with all other flags.\n    OPTIX_TRAVERSABLE_GRAPH_FLAG_ALLOW_ANY = 0,\n\n    ///  Used to signal that a traversable graph of a single Geometry Acceleration\n    ///  Structure (GAS) without any transforms is valid. This flag may be combined with\n    ///  other flags except for OPTIX_TRAVERSABLE_GRAPH_FLAG_ALLOW_ANY.\n    OPTIX_TRAVERSABLE_GRAPH_FLAG_ALLOW_SINGLE_GAS = 1u << 0,\n\n    ///  Used to signal that a traversable graph of a single Instance Acceleration\n    ///  Structure (IAS) directly connected to Geometry Acceleration Structure (GAS)\n    ///  traversables without transform traversables in between is valid.  This flag may\n    ///  be combined with other flags except for OPTIX_TRAVERSABLE_GRAPH_FLAG_ALLOW_ANY.\n    OPTIX_TRAVERSABLE_GRAPH_FLAG_ALLOW_SINGLE_LEVEL_INSTANCING = 1u << 1,\n} OptixTraversableGraphFlags;\n\n/// Optimization levels\n///\n/// \\see #OptixModuleCompileOptions::optLevel\ntypedef enum OptixCompileOptimizationLevel\n{\n    /// Default is to run all optimizations\n    OPTIX_COMPILE_OPTIMIZATION_DEFAULT = 0,\n    /// No optimizations\n    OPTIX_COMPILE_OPTIMIZATION_LEVEL_0 = 0x2340,\n    /// Some optimizations\n    OPTIX_COMPILE_OPTIMIZATION_LEVEL_1 = 0x2341,\n    /// Most optimizations\n    OPTIX_COMPILE_OPTIMIZATION_LEVEL_2 = 0x2342,\n    /// All optimizations\n    OPTIX_COMPILE_OPTIMIZATION_LEVEL_3 = 0x2343,\n} OptixCompileOptimizationLevel;\n\n/// Debug levels\n///\n/// \\see #OptixModuleCompileOptions::debugLevel\ntypedef enum OptixCompileDebugLevel\n{\n    /// Default currently is to add line info\n    OPTIX_COMPILE_DEBUG_LEVEL_DEFAULT  = 0,\n    /// No debug information\n    OPTIX_COMPILE_DEBUG_LEVEL_NONE     = 0x2350,\n    /// Generate lineinfo only\n    OPTIX_COMPILE_DEBUG_LEVEL_LINEINFO = 0x2351,\n    /// Generate dwarf debug information.\n    OPTIX_COMPILE_DEBUG_LEVEL_FULL     = 0x2352,\n} OptixCompileDebugLevel;\n\n\n\n\n/// Struct for specifying specializations for pipelineParams as specified in\n/// OptixPipelineCompileOptions::pipelineLaunchParamsVariableName.\n///\n/// The bound values are supposed to represent a constant value in the\n/// pipelineParams. OptiX will attempt to locate all loads from the pipelineParams and\n/// correlate them to the appropriate bound value, but there are cases where OptiX cannot\n/// safely or reliably do this. For example if the pointer to the pipelineParams is passed\n/// as an argument to a non-inline function or the offset of the load to the\n/// pipelineParams cannot be statically determined (e.g. accessed in a loop). No module\n/// should rely on the value being specialized in order to work correctly.  The values in\n/// the pipelineParams specified on optixLaunch should match the bound value. If\n/// validation mode is enabled on the context, OptiX will verify that the bound values\n/// specified match the values in pipelineParams specified to optixLaunch.\n///\n/// These values are compiled into the module as constants. Once the constants are\n/// inserted into the code, an optimization pass will be run that will attempt to\n/// propagate the constants and remove unreachable code.\n///\n/// If caching is enabled, changes in these values will result in newly compiled modules.\n///\n/// The pipelineParamOffset and sizeInBytes must be within the bounds of the\n/// pipelineParams variable. 
OPTIX_ERROR_INVALID_VALUE will be returned from\n/// optixModuleCreateFromPTX otherwise.\n///\n/// If more than one bound value overlaps or the size of a bound value is equal to 0,\n/// an OPTIX_ERROR_INVALID_VALUE will be returned from optixModuleCreateFromPTX.\n///\n/// The same set of bound values does not need to be used for all modules in a pipeline, but\n/// overlapping values between modules must have the same value.\n/// OPTIX_ERROR_INVALID_VALUE will be returned from optixPipelineCreate otherwise.\n///\n/// \\see #OptixModuleCompileOptions\ntypedef struct OptixModuleCompileBoundValueEntry {\n    size_t pipelineParamOffsetInBytes;\n    size_t sizeInBytes;\n    const void* boundValuePtr;\n    const char* annotation; // optional string to display, set to 0 if unused.  If unused,\n                            // OptiX will report the annotation as \"No annotation\"\n} OptixModuleCompileBoundValueEntry;\n\n\n/// Compilation options for module\n///\n/// \\see #optixModuleCreateFromPTX()\ntypedef struct OptixModuleCompileOptions\n{\n    /// Maximum number of registers allowed when compiling to SASS.\n    /// Set to 0 for no explicit limit. May vary within a pipeline.\n    int maxRegisterCount;\n\n    /// Optimization level. May vary within a pipeline.\n    OptixCompileOptimizationLevel optLevel;\n\n    /// Generate debug information.\n    OptixCompileDebugLevel debugLevel;\n\n    /// Ignored if numBoundValues is set to 0\n    const OptixModuleCompileBoundValueEntry* boundValues;\n\n    /// set to 0 if unused\n    unsigned int numBoundValues;\n\n} OptixModuleCompileOptions;\n\n\n/// Distinguishes different kinds of program groups.\ntypedef enum OptixProgramGroupKind\n{\n    /// Program group containing a raygen (RG) program\n    /// \\see #OptixProgramGroupSingleModule, #OptixProgramGroupDesc::raygen\n    OPTIX_PROGRAM_GROUP_KIND_RAYGEN = 0x2421,\n\n    /// Program group containing a miss (MS) program\n    /// \\see #OptixProgramGroupSingleModule, #OptixProgramGroupDesc::miss\n    OPTIX_PROGRAM_GROUP_KIND_MISS = 0x2422,\n\n    /// Program group containing an exception (EX) program\n    /// \\see OptixProgramGroupHitgroup, #OptixProgramGroupDesc::exception\n    OPTIX_PROGRAM_GROUP_KIND_EXCEPTION = 0x2423,\n\n    /// Program group containing an intersection (IS), any hit (AH), and/or closest hit (CH) program\n    /// \\see #OptixProgramGroupSingleModule, #OptixProgramGroupDesc::hitgroup\n    OPTIX_PROGRAM_GROUP_KIND_HITGROUP = 0x2424,\n\n    /// Program group containing a direct (DC) or continuation (CC) callable program\n    /// \\see OptixProgramGroupCallables, #OptixProgramGroupDesc::callables\n    OPTIX_PROGRAM_GROUP_KIND_CALLABLES = 0x2425\n} OptixProgramGroupKind;\n\n/// Flags for program groups\ntypedef enum OptixProgramGroupFlags\n{\n    /// Currently there are no flags\n    OPTIX_PROGRAM_GROUP_FLAGS_NONE = 0\n} OptixProgramGroupFlags;\n\n/// Program group representing a single module.\n///\n/// Used for raygen, miss, and exception programs. In case of raygen and exception programs, module and entry\n/// function name need to be valid. 
For miss programs, module and entry function name might both be \\c nullptr.\n///\n/// \\see #OptixProgramGroupDesc::raygen, #OptixProgramGroupDesc::miss, #OptixProgramGroupDesc::exception\ntypedef struct OptixProgramGroupSingleModule\n{\n    /// Module holding single program.\n    OptixModule module;\n    /// Entry function name of the single program.\n    const char* entryFunctionName;\n} OptixProgramGroupSingleModule;\n\n/// Program group representing the hitgroup.\n///\n/// For each of the three program types, module and entry function name might both be \\c nullptr.\n///\n/// \\see #OptixProgramGroupDesc::hitgroup\ntypedef struct OptixProgramGroupHitgroup\n{\n    /// Module holding the closest hit (CH) program.\n    OptixModule moduleCH;\n    /// Entry function name of the closest hit (CH) program.\n    const char* entryFunctionNameCH;\n    /// Module holding the any hit (AH) program.\n    OptixModule moduleAH;\n    /// Entry function name of the any hit (AH) program.\n    const char* entryFunctionNameAH;\n    /// Module holding the intersection (IS) program.\n    OptixModule moduleIS;\n    /// Entry function name of the intersection (IS) program.\n    const char* entryFunctionNameIS;\n} OptixProgramGroupHitgroup;\n\n/// Program group representing callables.\n///\n/// Module and entry function name need to be valid for at least one of the two callables.\n///\n/// \\see #OptixProgramGroupDesc::callables\ntypedef struct OptixProgramGroupCallables\n{\n    /// Module holding the direct callable (DC) program.\n    OptixModule moduleDC;\n    /// Entry function name of the direct callable (DC) program.\n    const char* entryFunctionNameDC;\n    /// Module holding the continuation callable (CC) program.\n    OptixModule moduleCC;\n    /// Entry function name of the continuation callable (CC) program.\n    const char* entryFunctionNameCC;\n} OptixProgramGroupCallables;\n\n/// Descriptor for program groups.\ntypedef struct OptixProgramGroupDesc\n{\n    /// The kind of program group.\n    OptixProgramGroupKind kind;\n\n    /// See #OptixProgramGroupFlags\n    unsigned int flags;\n\n    union\n    {\n        /// \\see #OPTIX_PROGRAM_GROUP_KIND_RAYGEN\n        OptixProgramGroupSingleModule raygen;\n        /// \\see #OPTIX_PROGRAM_GROUP_KIND_MISS\n        OptixProgramGroupSingleModule miss;\n        /// \\see #OPTIX_PROGRAM_GROUP_KIND_EXCEPTION\n        OptixProgramGroupSingleModule exception;\n        /// \\see #OPTIX_PROGRAM_GROUP_KIND_CALLABLES\n        OptixProgramGroupCallables callables;\n        /// \\see #OPTIX_PROGRAM_GROUP_KIND_HITGROUP\n        OptixProgramGroupHitgroup hitgroup;\n    };\n} OptixProgramGroupDesc;\n\n/// Program group options\n///\n/// \\see #optixProgramGroupCreate()\ntypedef struct OptixProgramGroupOptions\n{\n    /// reserved value for future use. 
must be 0.\n    int reserved;\n} OptixProgramGroupOptions;\n\n/// The following values are used to indicate which exception was thrown.\ntypedef enum OptixExceptionCodes\n{\n    /// Stack overflow of the continuation stack.\n    /// no exception details.\n    OPTIX_EXCEPTION_CODE_STACK_OVERFLOW = -1,\n\n    /// The trace depth is exceeded.\n    /// no exception details.\n    OPTIX_EXCEPTION_CODE_TRACE_DEPTH_EXCEEDED = -2,\n\n    /// The traversal depth is exceeded.\n    /// Exception details:\n    ///     optixGetTransformListSize()\n    ///     optixGetTransformListHandle()\n    OPTIX_EXCEPTION_CODE_TRAVERSAL_DEPTH_EXCEEDED = -3,\n\n    /// Traversal encountered an invalid traversable type.\n    /// Exception details:\n    ///     optixGetTransformListSize()\n    ///     optixGetTransformListHandle()\n    ///     optixGetExceptionInvalidTraversable()\n    OPTIX_EXCEPTION_CODE_TRAVERSAL_INVALID_TRAVERSABLE = -5,\n\n    /// The miss SBT record index is out of bounds\n    /// A miss SBT record index is valid within the range [0, OptixShaderBindingTable::missRecordCount) (See optixLaunch)\n    /// Exception details:\n    ///     optixGetExceptionInvalidSbtOffset()\n    OPTIX_EXCEPTION_CODE_TRAVERSAL_INVALID_MISS_SBT = -6,\n\n    /// The traversal hit SBT record index out of bounds.\n    ///\n    /// A traversal hit SBT record index is valid within the range [0, OptixShaderBindingTable::hitgroupRecordCount) (See optixLaunch)\n    /// The following formula relates the\n    //      sbt-index (See optixGetExceptionInvalidSbtOffset),\n    //      sbt-instance-offset (See OptixInstance::sbtOffset),\n    ///     sbt-geometry-acceleration-structure-index (See optixGetSbtGASIndex),\n    ///     sbt-stride-from-trace-call and sbt-offset-from-trace-call (See optixTrace)\n    ///\n    /// sbt-index = sbt-instance-offset + (sbt-geometry-acceleration-structure-index * sbt-stride-from-trace-call) + sbt-offset-from-trace-call\n    ///\n    /// Exception details:\n    ///     optixGetTransformListSize()\n    ///     optixGetTransformListHandle()\n    ///     optixGetExceptionInvalidSbtOffset()\n    ///     optixGetSbtGASIndex()\n    OPTIX_EXCEPTION_CODE_TRAVERSAL_INVALID_HIT_SBT = -7,\n\n    /// The shader encountered an unsupported primitive type (See OptixPipelineCompileOptions::usesPrimitiveTypeFlags).\n    /// no exception details.\n    OPTIX_EXCEPTION_CODE_UNSUPPORTED_PRIMITIVE_TYPE = -8,\n\n    /// The shader encountered a call to optixTrace with at least\n    /// one of the float arguments being inf or nan.\n    /// Exception details:\n    ///     optixGetExceptionInvalidRay()\n    OPTIX_EXCEPTION_CODE_INVALID_RAY = -9,\n\n    /// The shader encountered a call to either optixDirectCall or optixCallableCall\n    /// where the argument count does not match the parameter count of the callable\n    /// program which is called.\n    /// Exception details:\n    ///     optixGetExceptionParameterMismatch\n    OPTIX_EXCEPTION_CODE_CALLABLE_PARAMETER_MISMATCH = -10,\n\n    /// The invoked builtin IS does not match the current GAS\n    OPTIX_EXCEPTION_CODE_BUILTIN_IS_MISMATCH = -11,\n\n    /// Tried to call a callable program using an SBT offset that is larger\n    /// than the number of passed in callable SBT records.\n    /// Exception details:\n    ///     optixGetExceptionInvalidSbtOffset()\n    OPTIX_EXCEPTION_CODE_CALLABLE_INVALID_SBT = -12,\n\n    /// Tried to call a direct callable using an SBT offset of a record that\n    /// was built from a program group that did not include a direct callable.\n    
OPTIX_EXCEPTION_CODE_CALLABLE_NO_DC_SBT_RECORD = -13,\n\n    /// Tried to call a continuation callable using an SBT offset of a record\n    /// that was built from a program group that did not include a continuation callable.\n    OPTIX_EXCEPTION_CODE_CALLABLE_NO_CC_SBT_RECORD = -14,\n\n    /// Tried to directly traverse a single gas while single gas traversable graphs are not enabled\n    ///   (see OptixTraversableGraphFlags::OPTIX_TRAVERSABLE_GRAPH_FLAG_ALLOW_SINGLE_GAS).\n    /// Exception details:\n    ///     optixGetTransformListSize()\n    ///     optixGetTransformListHandle()\n    ///     optixGetExceptionInvalidTraversable()\n    OPTIX_EXCEPTION_CODE_UNSUPPORTED_SINGLE_LEVEL_GAS = -15,\n\n    /// argument passed to an optix call is\n    /// not within an acceptable range of values.\n    OPTIX_EXCEPTION_CODE_INVALID_VALUE_ARGUMENT_0 = -16,\n    OPTIX_EXCEPTION_CODE_INVALID_VALUE_ARGUMENT_1 = -17,\n    OPTIX_EXCEPTION_CODE_INVALID_VALUE_ARGUMENT_2 = -18,\n\n    /// Tried to access data on an AS without random data access support (See OptixBuildFlags).\n    OPTIX_EXCEPTION_CODE_UNSUPPORTED_DATA_ACCESS = -32,\n\n} OptixExceptionCodes;\n\n/// Exception flags.\n///\n/// \\see #OptixPipelineCompileOptions::exceptionFlags, #OptixExceptionCodes\ntypedef enum OptixExceptionFlags\n{\n    /// No exception are enabled.\n    OPTIX_EXCEPTION_FLAG_NONE = 0,\n\n    /// Enables exceptions check related to the continuation stack.\n    OPTIX_EXCEPTION_FLAG_STACK_OVERFLOW = 1u << 0,\n\n    /// Enables exceptions check related to trace depth.\n    OPTIX_EXCEPTION_FLAG_TRACE_DEPTH = 1u << 1,\n\n    /// Enables user exceptions via optixThrowException(). This flag must be specified for all modules in a pipeline\n    /// if any module calls optixThrowException().\n    OPTIX_EXCEPTION_FLAG_USER = 1u << 2,\n\n    /// Enables various exceptions check related to traversal.\n    OPTIX_EXCEPTION_FLAG_DEBUG = 1u << 3\n} OptixExceptionFlags;\n\n/// Compilation options for all modules of a pipeline.\n///\n/// Similar to #OptixModuleCompileOptions, but these options here need to be equal for all modules of a pipeline.\n///\n/// \\see #optixModuleCreateFromPTX(), #optixPipelineCreate()\ntypedef struct OptixPipelineCompileOptions\n{\n    /// Boolean value indicating whether motion blur could be used\n    int usesMotionBlur;\n\n    /// Traversable graph bitfield. See OptixTraversableGraphFlags\n    unsigned int traversableGraphFlags;\n\n    /// How much storage, in 32b words, to make available for the payload, [0..32]\n    int numPayloadValues;\n\n    /// How much storage, in 32b words, to make available for the attributes. The\n    /// minimum number is 2. Values below that will automatically be changed to 2. [2..8]\n    int numAttributeValues;\n\n    /// A bitmask of OptixExceptionFlags indicating which exceptions are enabled.\n    unsigned int exceptionFlags;\n\n    /// The name of the pipeline parameter variable.  If 0, no pipeline parameter\n    /// will be available. This will be ignored if the launch param variable was\n    /// optimized out or was not found in the modules linked to the pipeline.\n    const char* pipelineLaunchParamsVariableName;\n\n    /// Bit field enabling primitive types. 
See OptixPrimitiveTypeFlags.\n    /// Setting to zero corresponds to enabling OPTIX_PRIMITIVE_TYPE_FLAGS_CUSTOM and OPTIX_PRIMITIVE_TYPE_FLAGS_TRIANGLE.\n    unsigned int usesPrimitiveTypeFlags;\n\n    // Reserved for future use. These values must be set to zero.\n    unsigned int reserved;\n    size_t reserved2;\n\n} OptixPipelineCompileOptions;\n\n/// Link options for a pipeline\n///\n/// \\see #optixPipelineCreate()\ntypedef struct OptixPipelineLinkOptions\n{\n    /// Maximum trace recursion depth. 0 means a ray generation program can be\n    /// launched, but can't trace any rays. The maximum allowed value is 31.\n    unsigned int maxTraceDepth;\n\n    /// Generate debug information.\n    OptixCompileDebugLevel debugLevel;\n} OptixPipelineLinkOptions;\n\n/// Describes the shader binding table (SBT)\n///\n/// \\see #optixLaunch()\ntypedef struct OptixShaderBindingTable\n{\n    /// Device address of the SBT record of the ray gen program to start launch at. The address must be a multiple of\n    /// OPTIX_SBT_RECORD_ALIGNMENT.\n    CUdeviceptr raygenRecord;\n\n    /// Device address of the SBT record of the exception program. The address must be a multiple of\n    /// OPTIX_SBT_RECORD_ALIGNMENT.\n    CUdeviceptr exceptionRecord;\n\n    /// Arrays of SBT records for miss programs. The base address and the stride must be a multiple of\n    /// OPTIX_SBT_RECORD_ALIGNMENT.\n    /// @{\n    CUdeviceptr  missRecordBase;\n    unsigned int missRecordStrideInBytes;\n    unsigned int missRecordCount;\n    /// @}\n\n    /// Arrays of SBT records for hit groups. The base address and the stride must be a multiple of\n    /// OPTIX_SBT_RECORD_ALIGNMENT.\n    /// @{\n    CUdeviceptr  hitgroupRecordBase;\n    unsigned int hitgroupRecordStrideInBytes;\n    unsigned int hitgroupRecordCount;\n    /// @}\n\n    /// Arrays of SBT records for callable programs. If the base address is not null, the stride and count must not be\n    /// zero. If the base address is null, then the count needs to be zero. 
The base address and the stride must be a\n    /// multiple of OPTIX_SBT_RECORD_ALIGNMENT.\n    /// @{\n    CUdeviceptr  callablesRecordBase;\n    unsigned int callablesRecordStrideInBytes;\n    unsigned int callablesRecordCount;\n    /// @}\n\n} OptixShaderBindingTable;\n\n/// Describes the stack size requirements of a program group.\n///\n/// \\see optixProgramGroupGetStackSize()\ntypedef struct OptixStackSizes\n{\n    /// Continuation stack size of RG programs in bytes\n    unsigned int cssRG;\n    /// Continuation stack size of MS programs in bytes\n    unsigned int cssMS;\n    /// Continuation stack size of CH programs in bytes\n    unsigned int cssCH;\n    /// Continuation stack size of AH programs in bytes\n    unsigned int cssAH;\n    /// Continuation stack size of IS programs in bytes\n    unsigned int cssIS;\n    /// Continuation stack size of CC programs in bytes\n    unsigned int cssCC;\n    /// Direct stack size of DC programs in bytes\n    unsigned int dssDC;\n\n} OptixStackSizes;\n\n/// Options that can be passed to \\c optixQueryFunctionTable()\ntypedef enum OptixQueryFunctionTableOptions\n{\n    /// Placeholder (there are no options yet)\n    OPTIX_QUERY_FUNCTION_TABLE_OPTION_DUMMY = 0\n\n} OptixQueryFunctionTableOptions;\n\n/// Type of the function \\c optixQueryFunctionTable()\ntypedef OptixResult( OptixQueryFunctionTable_t )( int          abiId,\n                                                  unsigned int numOptions,\n                                                  OptixQueryFunctionTableOptions* /*optionKeys*/,\n                                                  const void** /*optionValues*/,\n                                                  void*  functionTable,\n                                                  size_t sizeOfTable );\n\n/// Specifies the options for retrieving an intersection program for a built-in primitive type.\n/// The primitive type must not be OPTIX_PRIMITIVE_TYPE_CUSTOM.\n///\n/// \\see #optixBuiltinISModuleGet()\ntypedef struct OptixBuiltinISOptions\n{\n    OptixPrimitiveType        builtinISModuleType;\n    /// Boolean value indicating whether vertex motion blur is used (but not motion transform blur).\n    int                       usesMotionBlur;\n} OptixBuiltinISOptions;\n\n#if defined( __CUDACC__ )\n/// Describes the ray that was passed into \\c optixTrace() which caused an exception with\n/// exception code OPTIX_EXCEPTION_CODE_INVALID_RAY.\n///\n/// \\see #optixGetExceptionInvalidRay()\ntypedef struct OptixInvalidRayExceptionDetails\n{\n    float3 origin;\n    float3 direction;\n    float  tmin;\n    float  tmax;\n    float  time;\n} OptixInvalidRayExceptionDetails;\n\n/// Describes the details of a call to a callable program which caused an exception with\n/// exception code OPTIX_EXCEPTION_CODE_CALLABLE_PARAMETER_MISMATCH,\n/// Note that OptiX packs the parameters into individual 32 bit values, so the number of\n/// expected and passed values may not correspond to the number of arguments passed into\n/// optixDirectCall or optixContinuationCall, or the number parameters in the definition\n/// of the function that is called.\ntypedef struct OptixParameterMismatchExceptionDetails\n{\n    /// Number of 32 bit values expected by the callable program\n    unsigned int expectedParameterCount;\n    /// Number of 32 bit values that were passed to the callable program\n    unsigned int passedArgumentCount;\n    /// The offset of the SBT entry of the callable program relative to OptixShaderBindingTable::callablesRecordBase\n    unsigned 
int sbtIndex;\n    /// Pointer to a string that holds the name of the callable program that was called\n    char*        callableName;\n} OptixParameterMismatchExceptionDetails;\n#endif\n\n\n/*@}*/  // end group optix_types\n\n#endif  // __optix_optix_7_types_h__\n"
  },
  {
    "path": "render/optixutils/include/optix_denoiser_tiling.h",
    "content": "/*\n * Copyright (c) 2021 NVIDIA Corporation.  All rights reserved.\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions\n * are met:\n *  * Redistributions of source code must retain the above copyright\n *    notice, this list of conditions and the following disclaimer.\n *  * Redistributions in binary form must reproduce the above copyright\n *    notice, this list of conditions and the following disclaimer in the\n *    documentation and/or other materials provided with the distribution.\n *  * Neither the name of NVIDIA CORPORATION nor the names of its\n *    contributors may be used to endorse or promote products derived\n *    from this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY\n * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR\n * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR\n * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,\n * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,\n * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR\n * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY\n * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n */\n\n/// @file\n/// @author NVIDIA Corporation\n/// @brief  OptiX public API header\n\n#ifndef optix_denoiser_tiling_h\n#define optix_denoiser_tiling_h\n\n\n#include <optix.h>\n\n#include <algorithm>\n#include <vector>\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\n/** \\addtogroup optix_utilities\n@{\n*/\n\n/// Tile definition\n///\n/// see #optixUtilDenoiserSplitImage\n///\nstruct OptixUtilDenoiserImageTile\n{\n    // input tile image\n    OptixImage2D input;\n\n    // output tile image\n    OptixImage2D output;\n\n    // overlap offsets, parameters for #optixUtilDenoiserInvoke\n    unsigned int inputOffsetX;\n    unsigned int inputOffsetY;\n};\n\n/// Return pixel stride in bytes for the given pixel format\n/// if the pixelStrideInBytes member of the image is zero.\n/// Otherwise return pixelStrideInBytes from the image.\n///\n/// \\param[in]                  image Image containing the pixel stride\n///\ninline unsigned int optixUtilGetPixelStride( const OptixImage2D& image )\n{\n    unsigned int pixelStrideInBytes = image.pixelStrideInBytes;\n    if( pixelStrideInBytes == 0 )\n    {\n        switch( image.format )\n        {\n            case OPTIX_PIXEL_FORMAT_HALF2:\n                pixelStrideInBytes = 2 * sizeof( short );\n                break;\n            case OPTIX_PIXEL_FORMAT_HALF3:\n                pixelStrideInBytes = 3 * sizeof( short );\n                break;\n            case OPTIX_PIXEL_FORMAT_HALF4:\n                pixelStrideInBytes = 4 * sizeof( short );\n                break;\n            case OPTIX_PIXEL_FORMAT_FLOAT2:\n                pixelStrideInBytes = 2 * sizeof( float );\n                break;\n            case OPTIX_PIXEL_FORMAT_FLOAT3:\n                pixelStrideInBytes = 3 * sizeof( float );\n                break;\n            case OPTIX_PIXEL_FORMAT_FLOAT4:\n                pixelStrideInBytes = 4 * sizeof( float );\n                break;\n            case 
OPTIX_PIXEL_FORMAT_UCHAR3:\n                pixelStrideInBytes = 3 * sizeof( char );\n                break;\n            case OPTIX_PIXEL_FORMAT_UCHAR4:\n                pixelStrideInBytes = 4 * sizeof( char );\n                break;\n        }\n    }\n    return pixelStrideInBytes;\n}\n\n/// Split image into 2D tiles given horizontal and vertical tile size\n///\n/// \\param[in]  input            full resolution input image to be split\n/// \\param[in]  output           full resolution output image\n/// \\param[in]  overlapWindowSizeInPixels    see #OptixDenoiserSizes, #optixDenoiserComputeMemoryResources\n/// \\param[in]  tileWidth        maximum width of tiles\n/// \\param[in]  tileHeight       maximum height of tiles\n/// \\param[out] tiles            list of tiles covering the input image\n///\ninline OptixResult optixUtilDenoiserSplitImage(\n                                               const OptixImage2D&                     input,\n                                               const OptixImage2D&                     output,\n                                               unsigned int                            overlapWindowSizeInPixels,\n                                               unsigned int                            tileWidth,\n                                               unsigned int                            tileHeight,\n                                               std::vector<OptixUtilDenoiserImageTile>&    tiles )\n{\n    if( tileWidth == 0 || tileHeight == 0 )\n        return OPTIX_ERROR_INVALID_VALUE;\n\n    unsigned int inPixelStride  = optixUtilGetPixelStride( input );\n    unsigned int outPixelStride = optixUtilGetPixelStride( output );\n\n    int inp_w = std::min( tileWidth + 2 * overlapWindowSizeInPixels, input.width );\n    int inp_h = std::min( tileHeight + 2 * overlapWindowSizeInPixels, input.height );\n    int inp_y = 0, copied_y = 0;\n\n    do\n    {\n        int inputOffsetY = inp_y == 0 ? 0 : std::max( (int)overlapWindowSizeInPixels, inp_h - ( (int)input.height - inp_y ) );\n        int copy_y       = inp_y == 0 ? std::min( input.height, tileHeight + overlapWindowSizeInPixels ) :\n                                  std::min( tileHeight, input.height - copied_y );\n\n        int inp_x = 0, copied_x = 0;\n        do\n        {\n            int inputOffsetX = inp_x == 0 ? 0 : std::max( (int)overlapWindowSizeInPixels, inp_w - ( (int)input.width - inp_x ) );\n            int copy_x = inp_x == 0 ? 
std::min( input.width, tileWidth + overlapWindowSizeInPixels ) :\n                                      std::min( tileWidth, input.width - copied_x );\n\n            OptixUtilDenoiserImageTile tile;\n            tile.input.data               = input.data + ( inp_y - inputOffsetY ) * input.rowStrideInBytes\n                                            + ( inp_x - inputOffsetX ) * inPixelStride;\n            tile.input.width              = inp_w;\n            tile.input.height             = inp_h;\n            tile.input.rowStrideInBytes   = input.rowStrideInBytes;\n            tile.input.pixelStrideInBytes = input.pixelStrideInBytes;\n            tile.input.format             = input.format;\n\n            tile.output.data               = output.data + inp_y * output.rowStrideInBytes + inp_x * outPixelStride;\n            tile.output.width              = copy_x;\n            tile.output.height             = copy_y;\n            tile.output.rowStrideInBytes   = output.rowStrideInBytes;\n            tile.output.pixelStrideInBytes = output.pixelStrideInBytes;\n            tile.output.format             = output.format;\n\n            tile.inputOffsetX = inputOffsetX;\n            tile.inputOffsetY = inputOffsetY;\n            tiles.push_back( tile );\n\n            inp_x += inp_x == 0 ? tileWidth + overlapWindowSizeInPixels : tileWidth;\n            copied_x += copy_x;\n        } while( inp_x < static_cast<int>( input.width ) );\n\n        inp_y += inp_y == 0 ? tileHeight + overlapWindowSizeInPixels : tileHeight;\n        copied_y += copy_y;\n    } while( inp_y < static_cast<int>( input.height ) );\n\n    return OPTIX_SUCCESS;\n}\n\n/// Run denoiser on input layers\n/// see #optixDenoiserInvoke\n/// additional parameters:\n\n/// Runs the denoiser on the input layers on a single GPU and stream using #optixDenoiserInvoke.\n/// If the input layers' dimensions are larger than the specified tile size, the image is divided into\n/// tiles using #optixUtilDenoiserSplitImage, and multiple back-to-back invocations are performed in\n/// order to reuse the scratch space.  
Multiple tiles can be invoked concurrently if\n/// #optixUtilDenoiserSplitImage is used directly and multiple scratch allocations for each concurrent\n/// invocation are used.\n\n/// The input parameters are the same as #optixDenoiserInvoke except for the addition of the maximum tile size.\n///\n/// \\param[in] denoiser\n/// \\param[in] stream\n/// \\param[in] params\n/// \\param[in] denoiserState\n/// \\param[in] denoiserStateSizeInBytes\n/// \\param[in] guideLayer\n/// \\param[in] layers\n/// \\param[in] numLayers\n/// \\param[in] scratch\n/// \\param[in] scratchSizeInBytes\n/// \\param[in] overlapWindowSizeInPixels\n/// \\param[in] tileWidth\n/// \\param[in] tileHeight\ninline OptixResult optixUtilDenoiserInvokeTiled(\n                                                OptixDenoiser                   denoiser,\n                                                CUstream                        stream,\n                                                const OptixDenoiserParams*      params,\n                                                CUdeviceptr                     denoiserState,\n                                                size_t                          denoiserStateSizeInBytes,\n                                                const OptixDenoiserGuideLayer*  guideLayer,\n                                                const OptixDenoiserLayer*       layers,\n                                                unsigned int                    numLayers,\n                                                CUdeviceptr                     scratch,\n                                                size_t                          scratchSizeInBytes,\n                                                unsigned int                    overlapWindowSizeInPixels,\n                                                unsigned int                    tileWidth,\n                                                unsigned int                    tileHeight )\n{\n    if( !guideLayer || !layers )\n        return OPTIX_ERROR_INVALID_VALUE;\n\n    std::vector<std::vector<OptixUtilDenoiserImageTile>> tiles( numLayers );\n    std::vector<std::vector<OptixUtilDenoiserImageTile>> prevTiles( numLayers );\n    for( unsigned int l = 0; l < numLayers; l++ )\n    {\n        if( const OptixResult res = optixUtilDenoiserSplitImage( layers[l].input, layers[l].output,\n                                                                 overlapWindowSizeInPixels,\n                                                                 tileWidth, tileHeight, tiles[l] ) )\n            return res;\n\n        if( layers[l].previousOutput.data )\n        {\n            OptixImage2D dummyOutput = layers[l].previousOutput;\n            if( const OptixResult res = optixUtilDenoiserSplitImage( layers[l].previousOutput, dummyOutput,\n                                                                 overlapWindowSizeInPixels,\n                                                                 tileWidth, tileHeight, prevTiles[l] ) )\n                return res;\n        }\n    }\n\n    std::vector<OptixUtilDenoiserImageTile> albedoTiles;\n    if( guideLayer->albedo.data )\n    {\n        OptixImage2D dummyOutput = guideLayer->albedo;\n        if( const OptixResult res = optixUtilDenoiserSplitImage( guideLayer->albedo, dummyOutput,\n                                                                 overlapWindowSizeInPixels,\n                                                                 tileWidth, tileHeight, albedoTiles ) )\n            return res;\n    }\n\n   
 std::vector<OptixUtilDenoiserImageTile> normalTiles;\n    if( guideLayer->normal.data )\n    {\n        OptixImage2D dummyOutput = guideLayer->normal;\n        if( const OptixResult res = optixUtilDenoiserSplitImage( guideLayer->normal, dummyOutput,\n                                                                 overlapWindowSizeInPixels,\n                                                                 tileWidth, tileHeight, normalTiles ) )\n            return res;\n    }\n    std::vector<OptixUtilDenoiserImageTile> flowTiles;\n    if( guideLayer->flow.data )\n    {\n        OptixImage2D dummyOutput = guideLayer->flow;\n        if( const OptixResult res = optixUtilDenoiserSplitImage( guideLayer->flow, dummyOutput,\n                                                                 overlapWindowSizeInPixels,\n                                                                 tileWidth, tileHeight, flowTiles ) )\n            return res;\n    }\n\n    for( size_t t = 0; t < tiles[0].size(); t++ )\n    {\n        std::vector<OptixDenoiserLayer> tlayers;\n        for( unsigned int l = 0; l < numLayers; l++ )\n        {\n            OptixDenoiserLayer layer = {};\n            layer.input  = ( tiles[l] )[t].input;\n            layer.output = ( tiles[l] )[t].output;\n            if( layers[l].previousOutput.data )\n                layer.previousOutput = ( prevTiles[l] )[t].input;\n            tlayers.push_back( layer );\n        }\n\n        OptixDenoiserGuideLayer gl = {};\n        if( guideLayer->albedo.data )\n            gl.albedo = albedoTiles[t].input;\n\n        if( guideLayer->normal.data )\n            gl.normal = normalTiles[t].input;\n\n        if( guideLayer->flow.data )\n            gl.flow = flowTiles[t].input;\n\n        if( const OptixResult res =\n                optixDenoiserInvoke( denoiser, stream, params, denoiserState, denoiserStateSizeInBytes,\n                                     &gl, &tlayers[0], numLayers,\n                                     ( tiles[0] )[t].inputOffsetX, ( tiles[0] )[t].inputOffsetY,\n                                     scratch, scratchSizeInBytes ) )\n            return res;\n    }\n    return OPTIX_SUCCESS;\n}\n\n/*@}*/  // end group optix_utilities\n\n#ifdef __cplusplus\n}\n#endif\n\n#endif  // __optix_optix_stack_size_h__\n"
  },
  {
    "path": "render/optixutils/include/optix_device.h",
    "content": "\n/*\n * Copyright (c) 2021 NVIDIA Corporation.  All rights reserved.\n *\n * NVIDIA Corporation and its licensors retain all intellectual property and proprietary\n * rights in and to this software, related documentation and any modifications thereto.\n * Any use, reproduction, disclosure or distribution of this software and related\n * documentation without an express license agreement from NVIDIA Corporation is strictly\n * prohibited.\n *\n * TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, THIS SOFTWARE IS PROVIDED *AS IS*\n * AND NVIDIA AND ITS SUPPLIERS DISCLAIM ALL WARRANTIES, EITHER EXPRESS OR IMPLIED,\n * INCLUDING, BUT NOT LIMITED TO, IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A\n * PARTICULAR PURPOSE.  IN NO EVENT SHALL NVIDIA OR ITS SUPPLIERS BE LIABLE FOR ANY\n * SPECIAL, INCIDENTAL, INDIRECT, OR CONSEQUENTIAL DAMAGES WHATSOEVER (INCLUDING, WITHOUT\n * LIMITATION, DAMAGES FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF\n * BUSINESS INFORMATION, OR ANY OTHER PECUNIARY LOSS) ARISING OUT OF THE USE OF OR\n * INABILITY TO USE THIS SOFTWARE, EVEN IF NVIDIA HAS BEEN ADVISED OF THE POSSIBILITY OF\n * SUCH DAMAGES\n */\n\n /**\n * @file   optix_device.h\n * @author NVIDIA Corporation\n * @brief  OptiX public API\n *\n * OptiX public API Reference - Host/Device side\n */\n\n/******************************************************************************\\\n * optix_cuda.h\n *\n * This file provides the nvcc interface for generating PTX that the OptiX is\n * capable of parsing and weaving into the final kernel.  This is included by\n * optix.h automatically if compiling device code.  It can be included explicitly\n * in host code if desired.\n *\n\\******************************************************************************/\n#if !defined(__OPTIX_INCLUDE_INTERNAL_HEADERS__)\n#  define __OPTIX_INCLUDE_INTERNAL_HEADERS__\n#  define __UNDEF_OPTIX_INCLUDE_INTERNAL_HEADERS_OPTIX_DEVICE_H__\n#endif\n#include \"optix_7_device.h\"\n#if defined( __UNDEF_OPTIX_INCLUDE_INTERNAL_HEADERS_OPTIX_DEVICE_H__ )\n#  undef __OPTIX_INCLUDE_INTERNAL_HEADERS__\n#  undef __UNDEF_OPTIX_INCLUDE_INTERNAL_HEADERS_OPTIX_DEVICE_H__\n#endif\n"
  },
  {
    "path": "render/optixutils/include/optix_function_table.h",
    "content": "/*\n * Copyright (c) 2021 NVIDIA Corporation.  All rights reserved.\n *\n * NVIDIA Corporation and its licensors retain all intellectual property and proprietary\n * rights in and to this software, related documentation and any modifications thereto.\n * Any use, reproduction, disclosure or distribution of this software and related\n * documentation without an express license agreement from NVIDIA Corporation is strictly\n * prohibited.\n *\n * TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, THIS SOFTWARE IS PROVIDED *AS IS*\n * AND NVIDIA AND ITS SUPPLIERS DISCLAIM ALL WARRANTIES, EITHER EXPRESS OR IMPLIED,\n * INCLUDING, BUT NOT LIMITED TO, IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A\n * PARTICULAR PURPOSE.  IN NO EVENT SHALL NVIDIA OR ITS SUPPLIERS BE LIABLE FOR ANY\n * SPECIAL, INCIDENTAL, INDIRECT, OR CONSEQUENTIAL DAMAGES WHATSOEVER (INCLUDING, WITHOUT\n * LIMITATION, DAMAGES FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF\n * BUSINESS INFORMATION, OR ANY OTHER PECUNIARY LOSS) ARISING OUT OF THE USE OF OR\n * INABILITY TO USE THIS SOFTWARE, EVEN IF NVIDIA HAS BEEN ADVISED OF THE POSSIBILITY OF\n * SUCH DAMAGES\n */\n\n/// @file\n/// @author NVIDIA Corporation\n/// @brief  OptiX public API header\n\n#ifndef __optix_optix_function_table_h__\n#define __optix_optix_function_table_h__\n\n/// The OptiX ABI version.\n#define OPTIX_ABI_VERSION 47\n\n#ifndef OPTIX_DEFINE_ABI_VERSION_ONLY\n\n#include \"optix_types.h\"\n\n#if !defined( OPTIX_DONT_INCLUDE_CUDA )\n// If OPTIX_DONT_INCLUDE_CUDA is defined, cuda driver types must be defined through other\n// means before including optix headers.\n#include <cuda.h>\n#endif\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\n/// \\defgroup optix_function_table Function Table\n/// \\brief OptiX Function Table\n\n/** \\addtogroup optix_function_table\n@{\n*/\n\n/// The function table containing all API functions.\n///\n/// See #optixInit() and #optixInitWithHandle().\ntypedef struct OptixFunctionTable\n{\n    /// \\name Error handling\n    //@ {\n\n    /// See ::optixGetErrorName().\n    const char* ( *optixGetErrorName )( OptixResult result );\n\n    /// See ::optixGetErrorString().\n    const char* ( *optixGetErrorString )( OptixResult result );\n\n    //@ }\n    /// \\name Device context\n    //@ {\n\n    /// See ::optixDeviceContextCreate().\n    OptixResult ( *optixDeviceContextCreate )( CUcontext fromContext, const OptixDeviceContextOptions* options, OptixDeviceContext* context );\n\n    /// See ::optixDeviceContextDestroy().\n    OptixResult ( *optixDeviceContextDestroy )( OptixDeviceContext context );\n\n    /// See ::optixDeviceContextGetProperty().\n    OptixResult ( *optixDeviceContextGetProperty )( OptixDeviceContext context, OptixDeviceProperty property, void* value, size_t sizeInBytes );\n\n    /// See ::optixDeviceContextSetLogCallback().\n    OptixResult ( *optixDeviceContextSetLogCallback )( OptixDeviceContext context,\n                                                       OptixLogCallback   callbackFunction,\n                                                       void*              callbackData,\n                                                       unsigned int       callbackLevel );\n\n    /// See ::optixDeviceContextSetCacheEnabled().\n    OptixResult ( *optixDeviceContextSetCacheEnabled )( OptixDeviceContext context, int enabled );\n\n    /// See ::optixDeviceContextSetCacheLocation().\n    OptixResult ( *optixDeviceContextSetCacheLocation )( OptixDeviceContext context, const char* 
location );\n\n    /// See ::optixDeviceContextSetCacheDatabaseSizes().\n    OptixResult ( *optixDeviceContextSetCacheDatabaseSizes )( OptixDeviceContext context, size_t lowWaterMark, size_t highWaterMark );\n\n    /// See ::optixDeviceContextGetCacheEnabled().\n    OptixResult ( *optixDeviceContextGetCacheEnabled )( OptixDeviceContext context, int* enabled );\n\n    /// See ::optixDeviceContextGetCacheLocation().\n    OptixResult ( *optixDeviceContextGetCacheLocation )( OptixDeviceContext context, char* location, size_t locationSize );\n\n    /// See ::optixDeviceContextGetCacheDatabaseSizes().\n    OptixResult ( *optixDeviceContextGetCacheDatabaseSizes )( OptixDeviceContext context, size_t* lowWaterMark, size_t* highWaterMark );\n\n    //@ }\n    /// \\name Modules\n    //@ {\n\n    /// See ::optixModuleCreateFromPTX().\n    OptixResult ( *optixModuleCreateFromPTX )( OptixDeviceContext                 context,\n                                               const OptixModuleCompileOptions*   moduleCompileOptions,\n                                               const OptixPipelineCompileOptions* pipelineCompileOptions,\n                                               const char*                        PTX,\n                                               size_t                             PTXsize,\n                                               char*                              logString,\n                                               size_t*                            logStringSize,\n                                               OptixModule*                       module );\n\n    /// See ::optixModuleDestroy().\n    OptixResult ( *optixModuleDestroy )( OptixModule module );\n\n    /// See ::optixBuiltinISModuleGet().\n    OptixResult( *optixBuiltinISModuleGet )( OptixDeviceContext                 context,\n                                             const OptixModuleCompileOptions*   moduleCompileOptions,\n                                             const OptixPipelineCompileOptions* pipelineCompileOptions,\n                                             const OptixBuiltinISOptions*       builtinISOptions,\n                                             OptixModule*                       builtinModule);\n\n    //@ }\n    /// \\name Program groups\n    //@ {\n\n    /// See ::optixProgramGroupCreate().\n    OptixResult ( *optixProgramGroupCreate )( OptixDeviceContext              context,\n                                              const OptixProgramGroupDesc*    programDescriptions,\n                                              unsigned int                    numProgramGroups,\n                                              const OptixProgramGroupOptions* options,\n                                              char*                           logString,\n                                              size_t*                         logStringSize,\n                                              OptixProgramGroup*              programGroups );\n\n    /// See ::optixProgramGroupDestroy().\n    OptixResult ( *optixProgramGroupDestroy )( OptixProgramGroup programGroup );\n\n    /// See ::optixProgramGroupGetStackSize().\n    OptixResult ( *optixProgramGroupGetStackSize )( OptixProgramGroup programGroup, OptixStackSizes* stackSizes );\n\n    //@ }\n    /// \\name Pipeline\n    //@ {\n\n    /// See ::optixPipelineCreate().\n    OptixResult ( *optixPipelineCreate )( OptixDeviceContext                 context,\n                                          const OptixPipelineCompileOptions* 
pipelineCompileOptions,\n                                          const OptixPipelineLinkOptions*    pipelineLinkOptions,\n                                          const OptixProgramGroup*           programGroups,\n                                          unsigned int                       numProgramGroups,\n                                          char*                              logString,\n                                          size_t*                            logStringSize,\n                                          OptixPipeline*                     pipeline );\n\n    /// See ::optixPipelineDestroy().\n    OptixResult ( *optixPipelineDestroy )( OptixPipeline pipeline );\n\n    /// See ::optixPipelineSetStackSize().\n    OptixResult ( *optixPipelineSetStackSize )( OptixPipeline pipeline,\n                                                unsigned int  directCallableStackSizeFromTraversal,\n                                                unsigned int  directCallableStackSizeFromState,\n                                                unsigned int  continuationStackSize,\n                                                unsigned int  maxTraversableGraphDepth );\n\n    //@ }\n    /// \\name Acceleration structures\n    //@ {\n\n    /// See ::optixAccelComputeMemoryUsage().\n    OptixResult ( *optixAccelComputeMemoryUsage )( OptixDeviceContext            context,\n                                                   const OptixAccelBuildOptions* accelOptions,\n                                                   const OptixBuildInput*        buildInputs,\n                                                   unsigned int                  numBuildInputs,\n                                                   OptixAccelBufferSizes*        bufferSizes );\n\n    /// See ::optixAccelBuild().\n    OptixResult ( *optixAccelBuild )( OptixDeviceContext            context,\n                                      CUstream                      stream,\n                                      const OptixAccelBuildOptions* accelOptions,\n                                      const OptixBuildInput*        buildInputs,\n                                      unsigned int                  numBuildInputs,\n                                      CUdeviceptr                   tempBuffer,\n                                      size_t                        tempBufferSizeInBytes,\n                                      CUdeviceptr                   outputBuffer,\n                                      size_t                        outputBufferSizeInBytes,\n                                      OptixTraversableHandle*       outputHandle,\n                                      const OptixAccelEmitDesc*     emittedProperties,\n                                      unsigned int                  numEmittedProperties );\n\n    /// See ::optixAccelGetRelocationInfo().\n    OptixResult ( *optixAccelGetRelocationInfo )( OptixDeviceContext context, OptixTraversableHandle handle, OptixAccelRelocationInfo* info );\n\n\n    /// See ::optixAccelCheckRelocationCompatibility().\n    OptixResult ( *optixAccelCheckRelocationCompatibility )( OptixDeviceContext              context,\n                                                             const OptixAccelRelocationInfo* info,\n                                                             int*                            compatible );\n\n    /// See ::optixAccelRelocate().\n    OptixResult ( *optixAccelRelocate )( OptixDeviceContext              context,\n                                  
       CUstream                        stream,\n                                         const OptixAccelRelocationInfo* info,\n                                         CUdeviceptr                     instanceTraversableHandles,\n                                         size_t                          numInstanceTraversableHandles,\n                                         CUdeviceptr                     targetAccel,\n                                         size_t                          targetAccelSizeInBytes,\n                                         OptixTraversableHandle*         targetHandle );\n\n\n    /// See ::optixAccelCompact().\n    OptixResult ( *optixAccelCompact )( OptixDeviceContext      context,\n                                        CUstream                stream,\n                                        OptixTraversableHandle  inputHandle,\n                                        CUdeviceptr             outputBuffer,\n                                        size_t                  outputBufferSizeInBytes,\n                                        OptixTraversableHandle* outputHandle );\n\n    /// See ::optixConvertPointerToTraversableHandle().\n    OptixResult ( *optixConvertPointerToTraversableHandle )( OptixDeviceContext      onDevice,\n                                                             CUdeviceptr             pointer,\n                                                             OptixTraversableType    traversableType,\n                                                             OptixTraversableHandle* traversableHandle );\n\n    //@ }\n    /// \\name Launch\n    //@ {\n\n    /// See ::optixConvertPointerToTraversableHandle().\n    OptixResult ( *optixSbtRecordPackHeader )( OptixProgramGroup programGroup, void* sbtRecordHeaderHostPointer );\n\n    /// See ::optixConvertPointerToTraversableHandle().\n    OptixResult ( *optixLaunch )( OptixPipeline                  pipeline,\n                                  CUstream                       stream,\n                                  CUdeviceptr                    pipelineParams,\n                                  size_t                         pipelineParamsSize,\n                                  const OptixShaderBindingTable* sbt,\n                                  unsigned int                   width,\n                                  unsigned int                   height,\n                                  unsigned int                   depth );\n\n    //@ }\n    /// \\name Denoiser\n    //@ {\n\n    /// See ::optixDenoiserCreate().\n    OptixResult ( *optixDenoiserCreate )( OptixDeviceContext context, OptixDenoiserModelKind modelKind, const OptixDenoiserOptions* options, OptixDenoiser* returnHandle );\n\n    /// See ::optixDenoiserDestroy().\n    OptixResult ( *optixDenoiserDestroy )( OptixDenoiser handle );\n\n    /// See ::optixDenoiserComputeMemoryResources().\n    OptixResult ( *optixDenoiserComputeMemoryResources )( const OptixDenoiser handle,\n                                                          unsigned int        maximumInputWidth,\n                                                          unsigned int        maximumInputHeight,\n                                                          OptixDenoiserSizes* returnSizes );\n\n    /// See ::optixDenoiserSetup().\n    OptixResult ( *optixDenoiserSetup )( OptixDenoiser denoiser,\n                                         CUstream      stream,\n                                         unsigned int  inputWidth,\n                                    
     unsigned int  inputHeight,\n                                         CUdeviceptr   state,\n                                         size_t        stateSizeInBytes,\n                                         CUdeviceptr   scratch,\n                                         size_t        scratchSizeInBytes );\n\n    /// See ::optixDenoiserInvoke().\n    OptixResult ( *optixDenoiserInvoke )( OptixDenoiser                   denoiser,\n                                          CUstream                        stream,\n                                          const OptixDenoiserParams*      params,\n                                          CUdeviceptr                     denoiserState,\n                                          size_t                          denoiserStateSizeInBytes,\n                                          const OptixDenoiserGuideLayer * guideLayer,\n                                          const OptixDenoiserLayer *      layers,\n                                          unsigned int                    numLayers,\n                                          unsigned int                    inputOffsetX,\n                                          unsigned int                    inputOffsetY,\n                                          CUdeviceptr                     scratch,\n                                          size_t                          scratchSizeInBytes );\n\n    /// See ::optixDenoiserComputeIntensity().\n    OptixResult ( *optixDenoiserComputeIntensity )( OptixDenoiser       handle,\n                                                    CUstream            stream,\n                                                    const OptixImage2D* inputImage,\n                                                    CUdeviceptr         outputIntensity,\n                                                    CUdeviceptr         scratch,\n                                                    size_t              scratchSizeInBytes );\n\n    /// See ::optixDenoiserComputeAverageColor().\n    OptixResult ( *optixDenoiserComputeAverageColor )( OptixDenoiser       handle,\n                                                       CUstream            stream,\n                                                       const OptixImage2D* inputImage,\n                                                       CUdeviceptr         outputAverageColor,\n                                                       CUdeviceptr         scratch,\n                                                       size_t              scratchSizeInBytes );\n\n    /// See ::optixDenoiserCreateWithUserModel().\n    OptixResult ( *optixDenoiserCreateWithUserModel )( OptixDeviceContext context, const void * data, size_t dataSizeInBytes, OptixDenoiser* returnHandle );\n    //@ }\n\n} OptixFunctionTable;\n\n/*@}*/  // end group optix_function_table\n\n#ifdef __cplusplus\n}\n#endif\n\n#endif /* OPTIX_DEFINE_ABI_VERSION_ONLY */\n\n#endif /* __optix_optix_function_table_h__ */\n"
  },
  {
    "path": "render/optixutils/include/optix_function_table_definition.h",
    "content": "/*\n * Copyright (c) 2021 NVIDIA Corporation.  All rights reserved.\n *\n * NVIDIA Corporation and its licensors retain all intellectual property and proprietary\n * rights in and to this software, related documentation and any modifications thereto.\n * Any use, reproduction, disclosure or distribution of this software and related\n * documentation without an express license agreement from NVIDIA Corporation is strictly\n * prohibited.\n *\n * TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, THIS SOFTWARE IS PROVIDED *AS IS*\n * AND NVIDIA AND ITS SUPPLIERS DISCLAIM ALL WARRANTIES, EITHER EXPRESS OR IMPLIED,\n * INCLUDING, BUT NOT LIMITED TO, IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A\n * PARTICULAR PURPOSE.  IN NO EVENT SHALL NVIDIA OR ITS SUPPLIERS BE LIABLE FOR ANY\n * SPECIAL, INCIDENTAL, INDIRECT, OR CONSEQUENTIAL DAMAGES WHATSOEVER (INCLUDING, WITHOUT\n * LIMITATION, DAMAGES FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF\n * BUSINESS INFORMATION, OR ANY OTHER PECUNIARY LOSS) ARISING OUT OF THE USE OF OR\n * INABILITY TO USE THIS SOFTWARE, EVEN IF NVIDIA HAS BEEN ADVISED OF THE POSSIBILITY OF\n * SUCH DAMAGES\n */\n\n/// @file\n/// @author NVIDIA Corporation\n/// @brief  OptiX public API header\n\n#ifndef __optix_optix_function_table_definition_h__\n#define __optix_optix_function_table_definition_h__\n\n#include \"optix_function_table.h\"\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\n/** \\addtogroup optix_function_table\n@{\n*/\n\n/// If the stubs in optix_stubs.h are used, then the function table needs to be defined in exactly\n/// one translation unit. This can be achieved by including this header file in that translation\n/// unit.\nOptixFunctionTable g_optixFunctionTable;\n\n/*@}*/  // end group optix_function_table\n\n#ifdef __cplusplus\n}\n#endif\n\n#endif  // __optix_optix_function_table_definition_h__\n"
  },
  {
    "path": "render/optixutils/include/optix_host.h",
    "content": "\n/*\n * Copyright (c) 2021 NVIDIA Corporation.  All rights reserved.\n *\n * NVIDIA Corporation and its licensors retain all intellectual property and proprietary\n * rights in and to this software, related documentation and any modifications thereto.\n * Any use, reproduction, disclosure or distribution of this software and related\n * documentation without an express license agreement from NVIDIA Corporation is strictly\n * prohibited.\n *\n * TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, THIS SOFTWARE IS PROVIDED *AS IS*\n * AND NVIDIA AND ITS SUPPLIERS DISCLAIM ALL WARRANTIES, EITHER EXPRESS OR IMPLIED,\n * INCLUDING, BUT NOT LIMITED TO, IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A\n * PARTICULAR PURPOSE.  IN NO EVENT SHALL NVIDIA OR ITS SUPPLIERS BE LIABLE FOR ANY\n * SPECIAL, INCIDENTAL, INDIRECT, OR CONSEQUENTIAL DAMAGES WHATSOEVER (INCLUDING, WITHOUT\n * LIMITATION, DAMAGES FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF\n * BUSINESS INFORMATION, OR ANY OTHER PECUNIARY LOSS) ARISING OUT OF THE USE OF OR\n * INABILITY TO USE THIS SOFTWARE, EVEN IF NVIDIA HAS BEEN ADVISED OF THE POSSIBILITY OF\n * SUCH DAMAGES\n */\n\n/**\n * @file   optix_host.h\n * @author NVIDIA Corporation\n * @brief  OptiX public API\n *\n * OptiX public API Reference - Host side\n */\n\n#if !defined(__OPTIX_INCLUDE_INTERNAL_HEADERS__)\n#  define __OPTIX_INCLUDE_INTERNAL_HEADERS__\n#  define __UNDEF_OPTIX_INCLUDE_INTERNAL_HEADERS_OPTIX_HOST_H__\n#endif\n#include \"optix_7_host.h\"\n#if defined( __UNDEF_OPTIX_INCLUDE_INTERNAL_HEADERS_OPTIX_HOST_H__ )\n#  undef __OPTIX_INCLUDE_INTERNAL_HEADERS__\n#  undef __UNDEF_OPTIX_INCLUDE_INTERNAL_HEADERS_OPTIX_HOST_H__\n#endif\n"
  },
  {
    "path": "render/optixutils/include/optix_stack_size.h",
    "content": "/*\n * Copyright (c) 2021 NVIDIA Corporation.  All rights reserved.\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions\n * are met:\n *  * Redistributions of source code must retain the above copyright\n *    notice, this list of conditions and the following disclaimer.\n *  * Redistributions in binary form must reproduce the above copyright\n *    notice, this list of conditions and the following disclaimer in the\n *    documentation and/or other materials provided with the distribution.\n *  * Neither the name of NVIDIA CORPORATION nor the names of its\n *    contributors may be used to endorse or promote products derived\n *    from this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY\n * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR\n * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR\n * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,\n * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,\n * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR\n * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY\n * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n */\n\n/// @file\n/// @author NVIDIA Corporation\n/// @brief  OptiX public API header\n\n#ifndef __optix_optix_stack_size_h__\n#define __optix_optix_stack_size_h__\n\n#include \"optix.h\"\n\n#include <algorithm>\n#include <cstring>\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\n/** \\addtogroup optix_utilities\n@{\n*/\n\n/// Retrieves direct and continuation stack sizes for each program in the program group and accumulates the upper bounds\n/// in the correponding output variables based on the semantic type of the program. 
Before the first invocation of this\n/// function with a given instance of #OptixStackSizes, the members of that instance should be set to 0.\ninline OptixResult optixUtilAccumulateStackSizes( OptixProgramGroup programGroup, OptixStackSizes* stackSizes )\n{\n    if( !stackSizes )\n        return OPTIX_ERROR_INVALID_VALUE;\n\n    OptixStackSizes localStackSizes;\n    OptixResult     result = optixProgramGroupGetStackSize( programGroup, &localStackSizes );\n    if( result != OPTIX_SUCCESS )\n        return result;\n\n    stackSizes->cssRG = std::max( stackSizes->cssRG, localStackSizes.cssRG );\n    stackSizes->cssMS = std::max( stackSizes->cssMS, localStackSizes.cssMS );\n    stackSizes->cssCH = std::max( stackSizes->cssCH, localStackSizes.cssCH );\n    stackSizes->cssAH = std::max( stackSizes->cssAH, localStackSizes.cssAH );\n    stackSizes->cssIS = std::max( stackSizes->cssIS, localStackSizes.cssIS );\n    stackSizes->cssCC = std::max( stackSizes->cssCC, localStackSizes.cssCC );\n    stackSizes->dssDC = std::max( stackSizes->dssDC, localStackSizes.dssDC );\n\n    return OPTIX_SUCCESS;\n}\n\n/// Computes the stack size values needed to configure a pipeline.\n///\n/// See the programming guide for an explanation of the formula.\n///\n/// \\param[in] stackSizes                              Accumulated stack sizes of all programs in the call graph.\n/// \\param[in] maxTraceDepth                           Maximum depth of #optixTrace() calls.\n/// \\param[in] maxCCDepth                              Maximum depth of calls trees of continuation callables.\n/// \\param[in] maxDCDepth                              Maximum depth of calls trees of direct callables.\n/// \\param[out] directCallableStackSizeFromTraversal   Direct stack size requirement for direct callables invoked from\n///                                                    IS or AH.\n/// \\param[out] directCallableStackSizeFromState       Direct stack size requirement for direct callables invoked from\n///                                                    RG, MS, or CH.\n/// \\param[out] continuationStackSize                  Continuation stack requirement.\ninline OptixResult optixUtilComputeStackSizes( const OptixStackSizes* stackSizes,\n                                               unsigned int           maxTraceDepth,\n                                               unsigned int           maxCCDepth,\n                                               unsigned int           maxDCDepth,\n                                               unsigned int*          directCallableStackSizeFromTraversal,\n                                               unsigned int*          directCallableStackSizeFromState,\n                                               unsigned int*          continuationStackSize )\n{\n    if( !stackSizes )\n        return OPTIX_ERROR_INVALID_VALUE;\n\n    const unsigned int cssRG = stackSizes->cssRG;\n    const unsigned int cssMS = stackSizes->cssMS;\n    const unsigned int cssCH = stackSizes->cssCH;\n    const unsigned int cssAH = stackSizes->cssAH;\n    const unsigned int cssIS = stackSizes->cssIS;\n    const unsigned int cssCC = stackSizes->cssCC;\n    const unsigned int dssDC = stackSizes->dssDC;\n\n    if( directCallableStackSizeFromTraversal )\n        *directCallableStackSizeFromTraversal = maxDCDepth * dssDC;\n    if( directCallableStackSizeFromState )\n        *directCallableStackSizeFromState = maxDCDepth * dssDC;\n\n    // upper bound on continuation stack used by call trees of continuation callables\n    unsigned 
int cssCCTree = maxCCDepth * cssCC;\n\n    // upper bound on continuation stack used by CH or MS programs including the call tree of\n    // continuation callables\n    unsigned int cssCHOrMSPlusCCTree = std::max( cssCH, cssMS ) + cssCCTree;\n\n    // clang-format off\n    if( continuationStackSize )\n        *continuationStackSize\n            = cssRG + cssCCTree\n            + ( std::max( maxTraceDepth, 1u ) - 1 ) * cssCHOrMSPlusCCTree\n            + std::min( maxTraceDepth, 1u ) * std::max( cssCHOrMSPlusCCTree, cssIS + cssAH );\n    // clang-format on\n\n    return OPTIX_SUCCESS;\n}\n\n/// Computes the stack size values needed to configure a pipeline.\n///\n/// This variant is similar to #optixUtilComputeStackSizes(), except that it expects the values dssDC and\n/// maxDCDepth split by call site semantic.\n///\n/// See programming guide for an explanation of the formula.\n///\n/// \\param[in] stackSizes                              Accumulated stack sizes of all programs in the call graph.\n/// \\param[in] dssDCFromTraversal                      Accumulated direct stack size of all DC programs invoked from IS\n///                                                    or AH.\n/// \\param[in] dssDCFromState                          Accumulated direct stack size of all DC programs invoked from RG,\n///                                                    MS, or CH.\n/// \\param[in] maxTraceDepth                           Maximum depth of #optixTrace() calls.\n/// \\param[in] maxCCDepth                              Maximum depth of calls trees of continuation callables.\n/// \\param[in] maxDCDepthFromTraversal                 Maximum depth of calls trees of direct callables invoked from IS\n///                                                    or AH.\n/// \\param[in] maxDCDepthFromState                     Maximum depth of calls trees of direct callables invoked from RG,\n///                                                    MS, or CH.\n/// \\param[out] directCallableStackSizeFromTraversal   Direct stack size requirement for direct callables invoked from\n///                                                    IS or AH.\n/// \\param[out] directCallableStackSizeFromState       Direct stack size requirement for direct callables invoked from\n///                                                    RG, MS, or CH.\n/// \\param[out] continuationStackSize                  Continuation stack requirement.\ninline OptixResult optixUtilComputeStackSizesDCSplit( const OptixStackSizes* stackSizes,\n                                                      unsigned int           dssDCFromTraversal,\n                                                      unsigned int           dssDCFromState,\n                                                      unsigned int           maxTraceDepth,\n                                                      unsigned int           maxCCDepth,\n                                                      unsigned int           maxDCDepthFromTraversal,\n                                                      unsigned int           maxDCDepthFromState,\n                                                      unsigned int*          directCallableStackSizeFromTraversal,\n                                                      unsigned int*          directCallableStackSizeFromState,\n                                                      unsigned int*          continuationStackSize )\n{\n    if( !stackSizes )\n        return OPTIX_ERROR_INVALID_VALUE;\n\n    const unsigned int cssRG = stackSizes->cssRG;\n 
   const unsigned int cssMS = stackSizes->cssMS;\n    const unsigned int cssCH = stackSizes->cssCH;\n    const unsigned int cssAH = stackSizes->cssAH;\n    const unsigned int cssIS = stackSizes->cssIS;\n    const unsigned int cssCC = stackSizes->cssCC;\n    // use dssDCFromTraversal and dssDCFromState instead of stackSizes->dssDC\n\n    if( directCallableStackSizeFromTraversal )\n        *directCallableStackSizeFromTraversal = maxDCDepthFromTraversal * dssDCFromTraversal;\n    if( directCallableStackSizeFromState )\n        *directCallableStackSizeFromState = maxDCDepthFromState * dssDCFromState;\n\n    // upper bound on continuation stack used by call trees of continuation callables\n    unsigned int cssCCTree = maxCCDepth * cssCC;\n\n    // upper bound on continuation stack used by CH or MS programs including the call tree of\n    // continuation callables\n    unsigned int cssCHOrMSPlusCCTree = std::max( cssCH, cssMS ) + cssCCTree;\n\n    // clang-format off\n    if( continuationStackSize )\n        *continuationStackSize\n            = cssRG + cssCCTree\n            + ( std::max( maxTraceDepth, 1u ) - 1 ) * cssCHOrMSPlusCCTree\n            + std::min( maxTraceDepth, 1u ) * std::max( cssCHOrMSPlusCCTree, cssIS + cssAH );\n    // clang-format on\n\n    return OPTIX_SUCCESS;\n}\n\n/// Computes the stack size values needed to configure a pipeline.\n///\n/// This variant is similar to #optixUtilComputeStackSizes(), except that it expects the value cssCCTree\n/// instead of cssCC and maxCCDepth.\n///\n/// See programming guide for an explanation of the formula.\n///\n/// \\param[in] stackSizes                              Accumulated stack sizes of all programs in the call graph.\n/// \\param[in] cssCCTree                               Maximum stack size used by calls trees of continuation callables.\n/// \\param[in] maxTraceDepth                           Maximum depth of #optixTrace() calls.\n/// \\param[in] maxDCDepth                              Maximum depth of calls trees of direct callables.\n/// \\param[out] directCallableStackSizeFromTraversal   Direct stack size requirement for direct callables invoked from\n///                                                    IS or AH.\n/// \\param[out] directCallableStackSizeFromState       Direct stack size requirement for direct callables invoked from\n///                                                    RG, MS, or CH.\n/// \\param[out] continuationStackSize                  Continuation stack requirement.\ninline OptixResult optixUtilComputeStackSizesCssCCTree( const OptixStackSizes* stackSizes,\n                                                        unsigned int           cssCCTree,\n                                                        unsigned int           maxTraceDepth,\n                                                        unsigned int           maxDCDepth,\n                                                        unsigned int*          directCallableStackSizeFromTraversal,\n                                                        unsigned int*          directCallableStackSizeFromState,\n                                                        unsigned int*          continuationStackSize )\n{\n    if( !stackSizes )\n        return OPTIX_ERROR_INVALID_VALUE;\n\n    const unsigned int cssRG = stackSizes->cssRG;\n    const unsigned int cssMS = stackSizes->cssMS;\n    const unsigned int cssCH = stackSizes->cssCH;\n    const unsigned int cssAH = stackSizes->cssAH;\n    const unsigned int cssIS = stackSizes->cssIS;\n    // use cssCCTree 
instead of stackSizes->cssCC and maxCCDepth\n    const unsigned int dssDC = stackSizes->dssDC;\n\n    if( directCallableStackSizeFromTraversal )\n        *directCallableStackSizeFromTraversal = maxDCDepth * dssDC;\n    if( directCallableStackSizeFromState )\n        *directCallableStackSizeFromState = maxDCDepth * dssDC;\n\n    // upper bound on continuation stack used by CH or MS programs including the call tree of\n    // continuation callables\n    unsigned int cssCHOrMSPlusCCTree = std::max( cssCH, cssMS ) + cssCCTree;\n\n    // clang-format off\n    if( continuationStackSize )\n        *continuationStackSize\n            = cssRG + cssCCTree\n            + ( std::max( maxTraceDepth, 1u ) - 1 ) * cssCHOrMSPlusCCTree\n            + std::min( maxTraceDepth, 1u ) * std::max( cssCHOrMSPlusCCTree, cssIS + cssAH );\n    // clang-format on\n\n    return OPTIX_SUCCESS;\n}\n\n/// Computes the stack size values needed to configure a pipeline.\n///\n/// This variant is a specialization of #optixUtilComputeStackSizes() for a simple path tracer with the following\n/// assumptions: There are only two ray types, camera rays and shadow rays. There are only RG, MS, and CH programs, and\n/// no AH, IS, CC, or DC programs. The camera rays invoke only the miss and closest hit programs MS1 and CH1,\n/// respectively. The CH1 program might trace shadow rays, which invoke only the miss and closest hit programs MS2 and\n/// CH2, respectively.\n///\n/// For flexibility, we allow for each of CH1 and CH2 not just one single program group, but an array of programs\n/// groups, and compute the maximas of the stack size requirements per array.\n///\n/// See programming guide for an explanation of the formula.\ninline OptixResult optixUtilComputeStackSizesSimplePathTracer( OptixProgramGroup        programGroupRG,\n                                                               OptixProgramGroup        programGroupMS1,\n                                                               const OptixProgramGroup* programGroupCH1,\n                                                               unsigned int             programGroupCH1Count,\n                                                               OptixProgramGroup        programGroupMS2,\n                                                               const OptixProgramGroup* programGroupCH2,\n                                                               unsigned int             programGroupCH2Count,\n                                                               unsigned int* directCallableStackSizeFromTraversal,\n                                                               unsigned int* directCallableStackSizeFromState,\n                                                               unsigned int* continuationStackSize )\n{\n    if( !programGroupCH1 && ( programGroupCH1Count > 0 ) )\n        return OPTIX_ERROR_INVALID_VALUE;\n    if( !programGroupCH2 && ( programGroupCH2Count > 0 ) )\n        return OPTIX_ERROR_INVALID_VALUE;\n\n    OptixResult result;\n\n    OptixStackSizes stackSizesRG = {};\n    result                       = optixProgramGroupGetStackSize( programGroupRG, &stackSizesRG );\n    if( result != OPTIX_SUCCESS )\n        return result;\n\n    OptixStackSizes stackSizesMS1 = {};\n    result                        = optixProgramGroupGetStackSize( programGroupMS1, &stackSizesMS1 );\n    if( result != OPTIX_SUCCESS )\n        return result;\n\n    OptixStackSizes stackSizesCH1 = {};\n    for( unsigned int i = 0; i < programGroupCH1Count; ++i )\n    {\n       
 result = optixUtilAccumulateStackSizes( programGroupCH1[i], &stackSizesCH1 );\n        if( result != OPTIX_SUCCESS )\n            return result;\n    }\n\n    OptixStackSizes stackSizesMS2 = {};\n    result                        = optixProgramGroupGetStackSize( programGroupMS2, &stackSizesMS2 );\n    if( result != OPTIX_SUCCESS )\n        return result;\n\n    OptixStackSizes stackSizesCH2 = {};\n    memset( &stackSizesCH2, 0, sizeof( OptixStackSizes ) );\n    for( unsigned int i = 0; i < programGroupCH2Count; ++i )\n    {\n        result = optixUtilAccumulateStackSizes( programGroupCH2[i], &stackSizesCH2 );\n        if( result != OPTIX_SUCCESS )\n            return result;\n    }\n\n    const unsigned int cssRG  = stackSizesRG.cssRG;\n    const unsigned int cssMS1 = stackSizesMS1.cssMS;\n    const unsigned int cssCH1 = stackSizesCH1.cssCH;\n    const unsigned int cssMS2 = stackSizesMS2.cssMS;\n    const unsigned int cssCH2 = stackSizesCH2.cssCH;\n    // no AH, IS, CC, or DC programs\n\n    if( directCallableStackSizeFromTraversal )\n        *directCallableStackSizeFromTraversal = 0;\n    if( directCallableStackSizeFromState )\n        *directCallableStackSizeFromState = 0;\n\n    if( continuationStackSize )\n        *continuationStackSize = cssRG + std::max( cssMS1, cssCH1 + std::max( cssMS2, cssCH2 ) );\n\n    return OPTIX_SUCCESS;\n}\n\n/*@}*/  // end group optix_utilities\n\n#ifdef __cplusplus\n}\n#endif\n\n#endif  // __optix_optix_stack_size_h__\n"
  },
  {
    "path": "render/optixutils/include/optix_stubs.h",
    "content": "/*\n * Copyright (c) 2021 NVIDIA Corporation.  All rights reserved.\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions\n * are met:\n *  * Redistributions of source code must retain the above copyright\n *    notice, this list of conditions and the following disclaimer.\n *  * Redistributions in binary form must reproduce the above copyright\n *    notice, this list of conditions and the following disclaimer in the\n *    documentation and/or other materials provided with the distribution.\n *  * Neither the name of NVIDIA CORPORATION nor the names of its\n *    contributors may be used to endorse or promote products derived\n *    from this software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY\n * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR\n * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR\n * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,\n * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,\n * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR\n * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY\n * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n */\n\n/// @file\n/// @author NVIDIA Corporation\n/// @brief  OptiX public API header\n\n#ifndef __optix_optix_stubs_h__\n#define __optix_optix_stubs_h__\n\n#include \"optix_function_table.h\"\n\n#ifdef _WIN32\n#ifndef WIN32_LEAN_AND_MEAN\n#define WIN32_LEAN_AND_MEAN 1\n#endif\n#include <windows.h>\n// The cfgmgr32 header is necessary for interrogating driver information in the registry.\n// For convenience the library is also linked in automatically using the #pragma command.\n#include <cfgmgr32.h>\n#pragma comment( lib, \"Cfgmgr32.lib\" )\n#include <string.h>\n#else\n#include <dlfcn.h>\n#endif\n\n#ifdef __cplusplus\nextern \"C\" {\n#endif\n\n// The function table needs to be defined in exactly one translation unit. This can be\n// achieved by including optix_function_table_definition.h in that translation unit.\nextern OptixFunctionTable g_optixFunctionTable;\n\n#ifdef _WIN32\nstatic void* optixLoadWindowsDllFromName( const char* optixDllName )\n{\n    void* handle = NULL;\n\n\n    // Get the size of the path first, then allocate\n    unsigned int size = GetSystemDirectoryA( NULL, 0 );\n    if( size == 0 )\n    {\n        // Couldn't get the system path size, so bail\n        return NULL;\n    }\n    size_t pathSize   = size + 1 + strlen( optixDllName );\n    char*  systemPath = (char*)malloc( pathSize );\n    if( systemPath == NULL )\n        return NULL;\n    if( GetSystemDirectoryA( systemPath, size ) != size - 1 )\n    {\n        // Something went wrong\n        free( systemPath );\n        return NULL;\n    }\n    strcat( systemPath, \"\\\\\" );\n    strcat( systemPath, optixDllName );\n    handle = LoadLibraryA( systemPath );\n    free( systemPath );\n    if( handle )\n        return handle;\n\n    // If we didn't find it, go looking in the register store.  Since nvoptix.dll doesn't\n    // have its own registry entry, we are going to look for the opengl driver which lives\n    // next to nvoptix.dll.  
0 (null) will be returned if any errors occured.\n\n    static const char* deviceInstanceIdentifiersGUID = \"{4d36e968-e325-11ce-bfc1-08002be10318}\";\n    const ULONG        flags                         = CM_GETIDLIST_FILTER_CLASS | CM_GETIDLIST_FILTER_PRESENT;\n    ULONG              deviceListSize                = 0;\n    if( CM_Get_Device_ID_List_SizeA( &deviceListSize, deviceInstanceIdentifiersGUID, flags ) != CR_SUCCESS )\n    {\n        return NULL;\n    }\n    char* deviceNames = (char*)malloc( deviceListSize );\n    if( deviceNames == NULL )\n        return NULL;\n    if( CM_Get_Device_ID_ListA( deviceInstanceIdentifiersGUID, deviceNames, deviceListSize, flags ) )\n    {\n        free( deviceNames );\n        return NULL;\n    }\n    DEVINST devID   = 0;\n    char*   dllPath = NULL;\n\n    // Continue to the next device if errors are encountered.\n    for( char* deviceName = deviceNames; *deviceName; deviceName += strlen( deviceName ) + 1 )\n    {\n        if( CM_Locate_DevNodeA( &devID, deviceName, CM_LOCATE_DEVNODE_NORMAL ) != CR_SUCCESS )\n        {\n            continue;\n        }\n        HKEY regKey = 0;\n        if( CM_Open_DevNode_Key( devID, KEY_QUERY_VALUE, 0, RegDisposition_OpenExisting, &regKey, CM_REGISTRY_SOFTWARE ) != CR_SUCCESS )\n        {\n            continue;\n        }\n        const char* valueName = \"OpenGLDriverName\";\n        DWORD       valueSize = 0;\n        LSTATUS     ret       = RegQueryValueExA( regKey, valueName, NULL, NULL, NULL, &valueSize );\n        if( ret != ERROR_SUCCESS )\n        {\n            RegCloseKey( regKey );\n            continue;\n        }\n        char* regValue = (char*)malloc( valueSize );\n        if( regValue == NULL )\n        {\n            RegCloseKey( regKey );\n            continue;\n        }\n        ret            = RegQueryValueExA( regKey, valueName, NULL, NULL, (LPBYTE)regValue, &valueSize );\n        if( ret != ERROR_SUCCESS )\n        {\n            free( regValue );\n            RegCloseKey( regKey );\n            continue;\n        }\n        // Strip the opengl driver dll name from the string then create a new string with\n        // the path and the nvoptix.dll name\n        for( int i = (int) valueSize - 1; i >= 0 && regValue[i] != '\\\\'; --i )\n            regValue[i] = '\\0';\n        size_t newPathSize = strlen( regValue ) + strlen( optixDllName ) + 1;\n        dllPath            = (char*)malloc( newPathSize );\n        if( dllPath == NULL )\n        {\n            free( regValue );\n            RegCloseKey( regKey );\n            continue;\n        }\n        strcpy( dllPath, regValue );\n        strcat( dllPath, optixDllName );\n        free( regValue );\n        RegCloseKey( regKey );\n        handle = LoadLibraryA( (LPCSTR)dllPath );\n        free( dllPath );\n        if( handle )\n            break;\n    }\n    free( deviceNames );\n    return handle;\n}\n\nstatic void* optixLoadWindowsDll( )\n{\n    return optixLoadWindowsDllFromName( \"nvoptix.dll\" );\n}\n#endif\n\n/// \\defgroup optix_utilities Utilities\n/// \\brief OptiX Utilities\n\n/** \\addtogroup optix_utilities\n@{\n*/\n\n/// Loads the OptiX library and initializes the function table used by the stubs below.\n///\n/// If handlePtr is not nullptr, an OS-specific handle to the library will be returned in *handlePtr.\n///\n/// \\see #optixUninitWithHandle\ninline OptixResult optixInitWithHandle( void** handlePtr )\n{\n    // Make sure these functions get initialized to zero in case the DLL and function\n    // table can't be loaded\n   
 g_optixFunctionTable.optixGetErrorName   = 0;\n    g_optixFunctionTable.optixGetErrorString = 0;\n\n    if( !handlePtr )\n        return OPTIX_ERROR_INVALID_VALUE;\n\n#ifdef _WIN32\n    *handlePtr = optixLoadWindowsDll();\n    if( !*handlePtr )\n        return OPTIX_ERROR_LIBRARY_NOT_FOUND;\n\n    void* symbol = GetProcAddress( (HMODULE)*handlePtr, \"optixQueryFunctionTable\" );\n    if( !symbol )\n        return OPTIX_ERROR_ENTRY_SYMBOL_NOT_FOUND;\n#else\n    *handlePtr = dlopen( \"libnvoptix.so.1\", RTLD_NOW );\n    if( !*handlePtr )\n        return OPTIX_ERROR_LIBRARY_NOT_FOUND;\n\n    void* symbol = dlsym( *handlePtr, \"optixQueryFunctionTable\" );\n    if( !symbol )\n        return OPTIX_ERROR_ENTRY_SYMBOL_NOT_FOUND;\n#endif\n\n    OptixQueryFunctionTable_t* optixQueryFunctionTable = (OptixQueryFunctionTable_t*)symbol;\n\n    return optixQueryFunctionTable( OPTIX_ABI_VERSION, 0, 0, 0, &g_optixFunctionTable, sizeof( g_optixFunctionTable ) );\n}\n\n/// Loads the OptiX library and initializes the function table used by the stubs below.\n///\n/// A variant of #optixInitWithHandle() that does not make the handle to the loaded library available.\ninline OptixResult optixInit( void )\n{\n    void* handle;\n    return optixInitWithHandle( &handle );\n}\n\n/// Unloads the OptiX library and zeros the function table used by the stubs below.  Takes the\n/// handle returned by optixInitWithHandle.  All OptixDeviceContext objects must be destroyed\n/// before calling this function, or the behavior is undefined.\n///\n/// \\see #optixInitWithHandle\ninline OptixResult optixUninitWithHandle( void* handle )\n{\n    if( !handle )\n      return OPTIX_ERROR_INVALID_VALUE;\n#ifdef _WIN32\n    if( !FreeLibrary( (HMODULE)handle ) )\n        return OPTIX_ERROR_LIBRARY_UNLOAD_FAILURE;\n#else\n    if( dlclose( handle ) )\n        return OPTIX_ERROR_LIBRARY_UNLOAD_FAILURE;\n#endif\n    OptixFunctionTable empty = { 0 };\n    g_optixFunctionTable = empty;\n    return OPTIX_SUCCESS;\n}\n\n\n/*@}*/  // end group optix_utilities\n\n#ifndef OPTIX_DOXYGEN_SHOULD_SKIP_THIS\n\n// Stub functions that forward calls to the corresponding function pointer in the function table.\n\ninline const char* optixGetErrorName( OptixResult result )\n{\n    if( g_optixFunctionTable.optixGetErrorName )\n        return g_optixFunctionTable.optixGetErrorName( result );\n\n    // If the DLL and symbol table couldn't be loaded, provide a set of error strings\n    // suitable for processing errors related to the DLL loading.\n    switch( result )\n    {\n        case OPTIX_SUCCESS:\n            return \"OPTIX_SUCCESS\";\n        case OPTIX_ERROR_INVALID_VALUE:\n            return \"OPTIX_ERROR_INVALID_VALUE\";\n        case OPTIX_ERROR_UNSUPPORTED_ABI_VERSION:\n            return \"OPTIX_ERROR_UNSUPPORTED_ABI_VERSION\";\n        case OPTIX_ERROR_FUNCTION_TABLE_SIZE_MISMATCH:\n            return \"OPTIX_ERROR_FUNCTION_TABLE_SIZE_MISMATCH\";\n        case OPTIX_ERROR_INVALID_ENTRY_FUNCTION_OPTIONS:\n            return \"OPTIX_ERROR_INVALID_ENTRY_FUNCTION_OPTIONS\";\n        case OPTIX_ERROR_LIBRARY_NOT_FOUND:\n            return \"OPTIX_ERROR_LIBRARY_NOT_FOUND\";\n        case OPTIX_ERROR_ENTRY_SYMBOL_NOT_FOUND:\n            return \"OPTIX_ERROR_ENTRY_SYMBOL_NOT_FOUND\";\n        case OPTIX_ERROR_LIBRARY_UNLOAD_FAILURE:\n            return \"OPTIX_ERROR_LIBRARY_UNLOAD_FAILURE\";\n        default:\n            return \"Unknown OptixResult code\";\n    }\n}\n\ninline const char* optixGetErrorString( OptixResult result )\n{\n    if( 
g_optixFunctionTable.optixGetErrorString )\n        return g_optixFunctionTable.optixGetErrorString( result );\n\n    // If the DLL and symbol table couldn't be loaded, provide a set of error strings\n    // suitable for processing errors related to the DLL loading.\n    switch( result )\n    {\n        case OPTIX_SUCCESS:\n            return \"Success\";\n        case OPTIX_ERROR_INVALID_VALUE:\n            return \"Invalid value\";\n        case OPTIX_ERROR_UNSUPPORTED_ABI_VERSION:\n            return \"Unsupported ABI version\";\n        case OPTIX_ERROR_FUNCTION_TABLE_SIZE_MISMATCH:\n            return \"Function table size mismatch\";\n        case OPTIX_ERROR_INVALID_ENTRY_FUNCTION_OPTIONS:\n            return \"Invalid options to entry function\";\n        case OPTIX_ERROR_LIBRARY_NOT_FOUND:\n            return \"Library not found\";\n        case OPTIX_ERROR_ENTRY_SYMBOL_NOT_FOUND:\n            return \"Entry symbol not found\";\n        case OPTIX_ERROR_LIBRARY_UNLOAD_FAILURE:\n            return \"Library could not be unloaded\";\n        default:\n            return \"Unknown OptixResult code\";\n    }\n}\n\ninline OptixResult optixDeviceContextCreate( CUcontext fromContext, const OptixDeviceContextOptions* options, OptixDeviceContext* context )\n{\n    return g_optixFunctionTable.optixDeviceContextCreate( fromContext, options, context );\n}\n\ninline OptixResult optixDeviceContextDestroy( OptixDeviceContext context )\n{\n    return g_optixFunctionTable.optixDeviceContextDestroy( context );\n}\n\ninline OptixResult optixDeviceContextGetProperty( OptixDeviceContext context, OptixDeviceProperty property, void* value, size_t sizeInBytes )\n{\n    return g_optixFunctionTable.optixDeviceContextGetProperty( context, property, value, sizeInBytes );\n}\n\ninline OptixResult optixDeviceContextSetLogCallback( OptixDeviceContext context,\n                                                     OptixLogCallback   callbackFunction,\n                                                     void*              callbackData,\n                                                     unsigned int       callbackLevel )\n{\n    return g_optixFunctionTable.optixDeviceContextSetLogCallback( context, callbackFunction, callbackData, callbackLevel );\n}\n\ninline OptixResult optixDeviceContextSetCacheEnabled( OptixDeviceContext context, int enabled )\n{\n    return g_optixFunctionTable.optixDeviceContextSetCacheEnabled( context, enabled );\n}\n\ninline OptixResult optixDeviceContextSetCacheLocation( OptixDeviceContext context, const char* location )\n{\n    return g_optixFunctionTable.optixDeviceContextSetCacheLocation( context, location );\n}\n\ninline OptixResult optixDeviceContextSetCacheDatabaseSizes( OptixDeviceContext context, size_t lowWaterMark, size_t highWaterMark )\n{\n    return g_optixFunctionTable.optixDeviceContextSetCacheDatabaseSizes( context, lowWaterMark, highWaterMark );\n}\n\ninline OptixResult optixDeviceContextGetCacheEnabled( OptixDeviceContext context, int* enabled )\n{\n    return g_optixFunctionTable.optixDeviceContextGetCacheEnabled( context, enabled );\n}\n\ninline OptixResult optixDeviceContextGetCacheLocation( OptixDeviceContext context, char* location, size_t locationSize )\n{\n    return g_optixFunctionTable.optixDeviceContextGetCacheLocation( context, location, locationSize );\n}\n\ninline OptixResult optixDeviceContextGetCacheDatabaseSizes( OptixDeviceContext context, size_t* lowWaterMark, size_t* highWaterMark )\n{\n    return 
g_optixFunctionTable.optixDeviceContextGetCacheDatabaseSizes( context, lowWaterMark, highWaterMark );\n}\n\ninline OptixResult optixModuleCreateFromPTX( OptixDeviceContext                 context,\n                                             const OptixModuleCompileOptions*   moduleCompileOptions,\n                                             const OptixPipelineCompileOptions* pipelineCompileOptions,\n                                             const char*                        PTX,\n                                             size_t                             PTXsize,\n                                             char*                              logString,\n                                             size_t*                            logStringSize,\n                                             OptixModule*                       module )\n{\n    return g_optixFunctionTable.optixModuleCreateFromPTX( context, moduleCompileOptions, pipelineCompileOptions, PTX,\n                                                          PTXsize, logString, logStringSize, module );\n}\n\ninline OptixResult optixModuleDestroy( OptixModule module )\n{\n    return g_optixFunctionTable.optixModuleDestroy( module );\n}\n\ninline OptixResult optixBuiltinISModuleGet( OptixDeviceContext                 context,\n                                            const OptixModuleCompileOptions*   moduleCompileOptions,\n                                            const OptixPipelineCompileOptions* pipelineCompileOptions,\n                                            const OptixBuiltinISOptions*       builtinISOptions,\n                                            OptixModule*                       builtinModule )\n{\n    return g_optixFunctionTable.optixBuiltinISModuleGet( context, moduleCompileOptions, pipelineCompileOptions, \n                                                         builtinISOptions, builtinModule );\n}\n\ninline OptixResult optixProgramGroupCreate( OptixDeviceContext              context,\n                                            const OptixProgramGroupDesc*    programDescriptions,\n                                            unsigned int                    numProgramGroups,\n                                            const OptixProgramGroupOptions* options,\n                                            char*                           logString,\n                                            size_t*                         logStringSize,\n                                            OptixProgramGroup*              programGroups )\n{\n    return g_optixFunctionTable.optixProgramGroupCreate( context, programDescriptions, numProgramGroups, options,\n                                                         logString, logStringSize, programGroups );\n}\n\ninline OptixResult optixProgramGroupDestroy( OptixProgramGroup programGroup )\n{\n    return g_optixFunctionTable.optixProgramGroupDestroy( programGroup );\n}\n\ninline OptixResult optixProgramGroupGetStackSize( OptixProgramGroup programGroup, OptixStackSizes* stackSizes )\n{\n    return g_optixFunctionTable.optixProgramGroupGetStackSize( programGroup, stackSizes );\n}\n\ninline OptixResult optixPipelineCreate( OptixDeviceContext                 context,\n                                        const OptixPipelineCompileOptions* pipelineCompileOptions,\n                                        const OptixPipelineLinkOptions*    pipelineLinkOptions,\n                                        const OptixProgramGroup*           programGroups,\n                  
                      unsigned int                       numProgramGroups,\n                                        char*                              logString,\n                                        size_t*                            logStringSize,\n                                        OptixPipeline*                     pipeline )\n{\n    return g_optixFunctionTable.optixPipelineCreate( context, pipelineCompileOptions, pipelineLinkOptions, programGroups,\n                                                     numProgramGroups, logString, logStringSize, pipeline );\n}\n\ninline OptixResult optixPipelineDestroy( OptixPipeline pipeline )\n{\n    return g_optixFunctionTable.optixPipelineDestroy( pipeline );\n}\n\ninline OptixResult optixPipelineSetStackSize( OptixPipeline pipeline,\n                                              unsigned int  directCallableStackSizeFromTraversal,\n                                              unsigned int  directCallableStackSizeFromState,\n                                              unsigned int  continuationStackSize,\n                                              unsigned int  maxTraversableGraphDepth )\n{\n    return g_optixFunctionTable.optixPipelineSetStackSize( pipeline, directCallableStackSizeFromTraversal, directCallableStackSizeFromState,\n                                                           continuationStackSize, maxTraversableGraphDepth );\n}\n\ninline OptixResult optixAccelComputeMemoryUsage( OptixDeviceContext            context,\n                                                 const OptixAccelBuildOptions* accelOptions,\n                                                 const OptixBuildInput*        buildInputs,\n                                                 unsigned int                  numBuildInputs,\n                                                 OptixAccelBufferSizes*        bufferSizes )\n{\n    return g_optixFunctionTable.optixAccelComputeMemoryUsage( context, accelOptions, buildInputs, numBuildInputs, bufferSizes );\n}\n\ninline OptixResult optixAccelBuild( OptixDeviceContext            context,\n                                    CUstream                      stream,\n                                    const OptixAccelBuildOptions* accelOptions,\n                                    const OptixBuildInput*        buildInputs,\n                                    unsigned int                  numBuildInputs,\n                                    CUdeviceptr                   tempBuffer,\n                                    size_t                        tempBufferSizeInBytes,\n                                    CUdeviceptr                   outputBuffer,\n                                    size_t                        outputBufferSizeInBytes,\n                                    OptixTraversableHandle*       outputHandle,\n                                    const OptixAccelEmitDesc*     emittedProperties,\n                                    unsigned int                  numEmittedProperties )\n{\n    return g_optixFunctionTable.optixAccelBuild( context, stream, accelOptions, buildInputs, numBuildInputs, tempBuffer,\n                                                 tempBufferSizeInBytes, outputBuffer, outputBufferSizeInBytes,\n                                                 outputHandle, emittedProperties, numEmittedProperties );\n}\n\n\ninline OptixResult optixAccelGetRelocationInfo( OptixDeviceContext context, OptixTraversableHandle handle, OptixAccelRelocationInfo* info )\n{\n    return 
g_optixFunctionTable.optixAccelGetRelocationInfo( context, handle, info );\n}\n\n\ninline OptixResult optixAccelCheckRelocationCompatibility( OptixDeviceContext context, const OptixAccelRelocationInfo* info, int* compatible )\n{\n    return g_optixFunctionTable.optixAccelCheckRelocationCompatibility( context, info, compatible );\n}\n\ninline OptixResult optixAccelRelocate( OptixDeviceContext              context,\n                                       CUstream                        stream,\n                                       const OptixAccelRelocationInfo* info,\n                                       CUdeviceptr                     instanceTraversableHandles,\n                                       size_t                          numInstanceTraversableHandles,\n                                       CUdeviceptr                     targetAccel,\n                                       size_t                          targetAccelSizeInBytes,\n                                       OptixTraversableHandle*         targetHandle )\n{\n    return g_optixFunctionTable.optixAccelRelocate( context, stream, info, instanceTraversableHandles, numInstanceTraversableHandles,\n                                                    targetAccel, targetAccelSizeInBytes, targetHandle );\n}\n\ninline OptixResult optixAccelCompact( OptixDeviceContext      context,\n                                      CUstream                stream,\n                                      OptixTraversableHandle  inputHandle,\n                                      CUdeviceptr             outputBuffer,\n                                      size_t                  outputBufferSizeInBytes,\n                                      OptixTraversableHandle* outputHandle )\n{\n    return g_optixFunctionTable.optixAccelCompact( context, stream, inputHandle, outputBuffer, outputBufferSizeInBytes, outputHandle );\n}\n\ninline OptixResult optixConvertPointerToTraversableHandle( OptixDeviceContext      onDevice,\n                                                           CUdeviceptr             pointer,\n                                                           OptixTraversableType    traversableType,\n                                                           OptixTraversableHandle* traversableHandle )\n{\n    return g_optixFunctionTable.optixConvertPointerToTraversableHandle( onDevice, pointer, traversableType, traversableHandle );\n}\n\ninline OptixResult optixSbtRecordPackHeader( OptixProgramGroup programGroup, void* sbtRecordHeaderHostPointer )\n{\n    return g_optixFunctionTable.optixSbtRecordPackHeader( programGroup, sbtRecordHeaderHostPointer );\n}\n\ninline OptixResult optixLaunch( OptixPipeline                  pipeline,\n                                CUstream                       stream,\n                                CUdeviceptr                    pipelineParams,\n                                size_t                         pipelineParamsSize,\n                                const OptixShaderBindingTable* sbt,\n                                unsigned int                   width,\n                                unsigned int                   height,\n                                unsigned int                   depth )\n{\n    return g_optixFunctionTable.optixLaunch( pipeline, stream, pipelineParams, pipelineParamsSize, sbt, width, height, depth );\n}\n\ninline OptixResult optixDenoiserCreate( OptixDeviceContext context, OptixDenoiserModelKind modelKind, const OptixDenoiserOptions* options, OptixDenoiser* returnHandle 
)\n{\n    return g_optixFunctionTable.optixDenoiserCreate( context, modelKind, options, returnHandle );\n}\n\ninline OptixResult optixDenoiserCreateWithUserModel( OptixDeviceContext context, const void* data, size_t dataSizeInBytes, OptixDenoiser* returnHandle )\n{\n    return g_optixFunctionTable.optixDenoiserCreateWithUserModel( context, data, dataSizeInBytes, returnHandle );\n}\n\ninline OptixResult optixDenoiserDestroy( OptixDenoiser handle )\n{\n    return g_optixFunctionTable.optixDenoiserDestroy( handle );\n}\n\ninline OptixResult optixDenoiserComputeMemoryResources( const OptixDenoiser handle,\n                                                        unsigned int        maximumInputWidth,\n                                                        unsigned int        maximumInputHeight,\n                                                        OptixDenoiserSizes* returnSizes )\n{\n    return g_optixFunctionTable.optixDenoiserComputeMemoryResources( handle, maximumInputWidth, maximumInputHeight, returnSizes );\n}\n\ninline OptixResult optixDenoiserSetup( OptixDenoiser denoiser,\n                                       CUstream      stream,\n                                       unsigned int  inputWidth,\n                                       unsigned int  inputHeight,\n                                       CUdeviceptr   denoiserState,\n                                       size_t        denoiserStateSizeInBytes,\n                                       CUdeviceptr   scratch,\n                                       size_t        scratchSizeInBytes )\n{\n    return g_optixFunctionTable.optixDenoiserSetup( denoiser, stream, inputWidth, inputHeight, denoiserState,\n                                                    denoiserStateSizeInBytes, scratch, scratchSizeInBytes );\n}\n\ninline OptixResult optixDenoiserInvoke( OptixDenoiser                   handle,\n                                        CUstream                        stream,\n                                        const OptixDenoiserParams*      params,\n                                        CUdeviceptr                     denoiserData,\n                                        size_t                          denoiserDataSize,\n                                        const OptixDenoiserGuideLayer*  guideLayer,\n                                        const OptixDenoiserLayer*       layers,\n                                        unsigned int                    numLayers,\n                                        unsigned int                    inputOffsetX,\n                                        unsigned int                    inputOffsetY,\n                                        CUdeviceptr                     scratch,\n                                        size_t                          scratchSizeInBytes )\n{\n    return g_optixFunctionTable.optixDenoiserInvoke( handle, stream, params, denoiserData, denoiserDataSize,\n                                                     guideLayer, layers, numLayers,\n                                                     inputOffsetX, inputOffsetY, scratch, scratchSizeInBytes );\n}\n\ninline OptixResult optixDenoiserComputeIntensity( OptixDenoiser       handle,\n                                                  CUstream            stream,\n                                                  const OptixImage2D* inputImage,\n                                                  CUdeviceptr         outputIntensity,\n                                                  CUdeviceptr         
scratch,\n                                                  size_t              scratchSizeInBytes )\n{\n    return g_optixFunctionTable.optixDenoiserComputeIntensity( handle, stream, inputImage, outputIntensity, scratch, scratchSizeInBytes );\n}\n\ninline OptixResult optixDenoiserComputeAverageColor( OptixDenoiser       handle,\n                                                     CUstream            stream,\n                                                     const OptixImage2D* inputImage,\n                                                     CUdeviceptr         outputAverageColor,\n                                                     CUdeviceptr         scratch,\n                                                     size_t              scratchSizeInBytes )\n{\n    return g_optixFunctionTable.optixDenoiserComputeAverageColor( handle, stream, inputImage, outputAverageColor, scratch, scratchSizeInBytes );\n}\n\n#endif  // OPTIX_DOXYGEN_SHOULD_SKIP_THIS\n\n#ifdef __cplusplus\n}\n#endif\n\n#endif  // __optix_optix_stubs_h__\n"
  },
  {
    "path": "render/optixutils/include/optix_types.h",
    "content": "/*\n * Copyright (c) 2021 NVIDIA Corporation.  All rights reserved.\n *\n * NVIDIA Corporation and its licensors retain all intellectual property and proprietary\n * rights in and to this software, related documentation and any modifications thereto.\n * Any use, reproduction, disclosure or distribution of this software and related\n * documentation without an express license agreement from NVIDIA Corporation is strictly\n * prohibited.\n *\n * TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, THIS SOFTWARE IS PROVIDED *AS IS*\n * AND NVIDIA AND ITS SUPPLIERS DISCLAIM ALL WARRANTIES, EITHER EXPRESS OR IMPLIED,\n * INCLUDING, BUT NOT LIMITED TO, IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A\n * PARTICULAR PURPOSE.  IN NO EVENT SHALL NVIDIA OR ITS SUPPLIERS BE LIABLE FOR ANY\n * SPECIAL, INCIDENTAL, INDIRECT, OR CONSEQUENTIAL DAMAGES WHATSOEVER (INCLUDING, WITHOUT\n * LIMITATION, DAMAGES FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF\n * BUSINESS INFORMATION, OR ANY OTHER PECUNIARY LOSS) ARISING OUT OF THE USE OF OR\n * INABILITY TO USE THIS SOFTWARE, EVEN IF NVIDIA HAS BEEN ADVISED OF THE POSSIBILITY OF\n * SUCH DAMAGES\n */\n\n/**\n * @file   optix_types.h\n * @author NVIDIA Corporation\n * @brief  OptiX public API header\n *\n */\n\n#ifndef __optix_optix_types_h__\n#define __optix_optix_types_h__\n\n// clang-format off\n#if !defined(__OPTIX_INCLUDE_INTERNAL_HEADERS__)\n#  define __OPTIX_INCLUDE_INTERNAL_HEADERS__\n#  define __UNDEF_OPTIX_INCLUDE_INTERNAL_HEADERS_OPTIX_TYPES_H__\n#endif\n#include \"optix_7_types.h\"\n#if defined( __UNDEF_OPTIX_INCLUDE_INTERNAL_HEADERS_OPTIX_TYPES_H__ )\n#  undef __OPTIX_INCLUDE_INTERNAL_HEADERS__\n#  undef __UNDEF_OPTIX_INCLUDE_INTERNAL_HEADERS_OPTIX_TYPES_H__\n#endif\n// clang-format on\n\n#endif // #ifndef __optix_optix_types_h__\n"
  },
  {
    "path": "render/optixutils/ops.py",
    "content": "# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto. Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\nimport numpy as np\nimport os\nimport sys\nimport torch\nimport torch.utils.cpp_extension\n\n#----------------------------------------------------------------------------\n# C++/Cuda plugin compiler/loader.\n\n_plugin = None\nif _plugin is None:\n\n    # Make sure we can find the necessary compiler and libary binaries.\n    if os.name == 'nt':\n        optix_include_dir = os.path.dirname(__file__) + r\"\\include\"\n\n        def find_cl_path():\n            import glob\n            for edition in ['Enterprise', 'Professional', 'BuildTools', 'Community']:\n                vs_editions = glob.glob(r\"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64\" % edition) \\\n                    + glob.glob(r\"C:\\Program Files\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64\" % edition)\n                paths = sorted(vs_editions, reverse=True)\n                if paths:\n                    return paths[0]\n\n        # If cl.exe is not on path, try to find it.\n        if os.system(\"where cl.exe >nul 2>nul\") != 0:\n            cl_path = find_cl_path()\n            if cl_path is None:\n                raise RuntimeError(\"Could not locate a supported Microsoft Visual C++ installation\")\n            os.environ['PATH'] += ';' + cl_path\n\n    elif os.name == 'posix':\n        optix_include_dir = os.path.dirname(__file__) + r\"/include\"\n\n    include_paths = [optix_include_dir]\n\n    # Compiler options.\n    opts = ['-DNVDR_TORCH']\n\n    # Linker options.\n    if os.name == 'posix':\n        ldflags = ['-lcuda', '-lnvrtc']\n    elif os.name == 'nt':\n        ldflags = ['cuda.lib', 'advapi32.lib', 'nvrtc.lib']\n\n    # List of sources.\n    source_files = [\n        'c_src/denoising.cu',\n        'c_src/optix_wrapper.cpp',\n        'c_src/torch_bindings.cpp'\n    ]\n\n    # Some containers set this to contain old architectures that won't compile. We only need the one installed in the machine.\n    os.environ['TORCH_CUDA_ARCH_LIST'] = ''\n\n    # Compile and load.\n    build_dir = os.path.join(os. path. 
dirname(__file__), 'build')\n    os.makedirs(build_dir, exist_ok=True)\n    source_paths = [os.path.join(os.path.dirname(__file__), fn) for fn in source_files]\n    torch.utils.cpp_extension.load(name='optixutils_plugin', sources=source_paths, extra_cflags=opts,\n         build_directory=build_dir,\n         extra_cuda_cflags=opts, extra_ldflags=ldflags, extra_include_paths=include_paths, with_cuda=True, verbose=True)\n\n    # Import, cache, and return the compiled module.\n    import optixutils_plugin\n    _plugin = optixutils_plugin\n\n#----------------------------------------------------------------------------\n# OptiX autograd func\n#----------------------------------------------------------------------------\n\nclass _optix_env_shade_func(torch.autograd.Function):\n    _random_perm = {}\n\n    @staticmethod\n    def forward(ctx, optix_ctx, mask, ro, gb_pos, gb_normal, gb_view_pos, gb_kd, gb_ks, light, pdf, rows, cols, BSDF, n_samples_x, rnd_seed, shadow_scale):\n        _rnd_seed = np.random.randint(2**31) if rnd_seed is None else rnd_seed\n        if n_samples_x not in _optix_env_shade_func._random_perm:\n            # Generate (32k) tables with random permutations to decorrelate the BSDF and light stratified samples\n            _optix_env_shade_func._random_perm[n_samples_x] = torch.argsort(torch.rand(32768, n_samples_x * n_samples_x, device=\"cuda\"), dim=-1).int()\n\n        diff, spec = _plugin.env_shade_fwd(optix_ctx.cpp_wrapper, mask, ro, gb_pos, gb_normal, gb_view_pos, gb_kd, gb_ks, light, pdf, rows, cols, _optix_env_shade_func._random_perm[n_samples_x], BSDF, n_samples_x, _rnd_seed, shadow_scale)\n        ctx.save_for_backward(mask, ro, gb_pos, gb_normal, gb_view_pos, gb_kd, gb_ks, light, pdf, rows, cols)\n        ctx.optix_ctx = optix_ctx\n        ctx.BSDF = BSDF\n        ctx.n_samples_x = n_samples_x\n        ctx.rnd_seed = rnd_seed\n        ctx.shadow_scale = shadow_scale\n        return diff, spec\n    \n    @staticmethod\n    def backward(ctx, diff_grad, spec_grad):\n        optix_ctx = ctx.optix_ctx\n        _rnd_seed = np.random.randint(2**31) if ctx.rnd_seed is None else ctx.rnd_seed\n        mask, ro, gb_pos, gb_normal, gb_view_pos, gb_kd, gb_ks, light, pdf, rows, cols = ctx.saved_variables\n        gb_pos_grad, gb_normal_grad, gb_kd_grad, gb_ks_grad, light_grad = _plugin.env_shade_bwd(\n            optix_ctx.cpp_wrapper, mask, ro, gb_pos, gb_normal, gb_view_pos, gb_kd, gb_ks, light, pdf, rows, cols, _optix_env_shade_func._random_perm[ctx.n_samples_x], \n            ctx.BSDF, ctx.n_samples_x, _rnd_seed, ctx.shadow_scale, diff_grad, spec_grad)\n        return None, None, None, gb_pos_grad, gb_normal_grad, None, gb_kd_grad, gb_ks_grad, light_grad, None, None, None, None, None, None, None\n\nclass _bilateral_denoiser_func(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, col, nrm, zdz, sigma):\n        ctx.save_for_backward(col, nrm, zdz)\n        ctx.sigma = sigma\n        out = _plugin.bilateral_denoiser_fwd(col, nrm, zdz, sigma)\n        return out\n    \n    @staticmethod\n    def backward(ctx, out_grad):\n        col, nrm, zdz = ctx.saved_variables\n        col_grad = _plugin.bilateral_denoiser_bwd(col, nrm, zdz, ctx.sigma, out_grad)\n        return col_grad, None, None, None\n\n#----------------------------------------------------------------------------\n# OptiX ray tracing utils\n#----------------------------------------------------------------------------\n\nclass OptiXContext:\n    def __init__(self):\n        print(\"Cuda path\", 
torch.utils.cpp_extension.CUDA_HOME)\n        self.cpp_wrapper = _plugin.OptiXStateWrapper(os.path.dirname(__file__), torch.utils.cpp_extension.CUDA_HOME)\n\ndef optix_build_bvh(optix_ctx, verts, tris, rebuild):\n    '''\n        Choose not to raise an error here, since we may have MSDF supervision; this code should be cleaned up later.\n    '''\n    # assert tris.shape[0] > 0, \"Got empty training triangle mesh (unrecoverable discontinuity)\"\n    # assert verts.shape[0] > 0, \"Got empty training triangle mesh (unrecoverable discontinuity)\"\n    _plugin.optix_build_bvh(optix_ctx.cpp_wrapper, verts.view(-1, 3), tris.view(-1, 3), rebuild)\n\ndef optix_env_shade(optix_ctx, mask, ro, gb_pos, gb_normal, gb_view_pos, gb_kd, gb_ks, light, pdf, rows, cols, BSDF='pbr', n_samples_x=8, rnd_seed=None, shadow_scale=1.0):\n    iBSDF = ['pbr', 'diffuse', 'white'].index(BSDF) # Ordering important, must match the order of the fwd/bwdPbrBSDF kernel.\n    return _optix_env_shade_func.apply(optix_ctx, mask, ro, gb_pos, gb_normal, gb_view_pos, gb_kd, gb_ks, light, pdf, rows, cols, iBSDF, n_samples_x, rnd_seed, shadow_scale)\n\ndef bilateral_denoiser(col, nrm, zdz, sigma):\n    col_w = _bilateral_denoiser_func.apply(col, nrm, zdz, sigma)\n    return col_w[..., 0:3] / col_w[..., 3:4]\n"
  },
  {
    "path": "render/optixutils/tests/filter_test.py",
    "content": "# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto. Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\nfrom pickletools import read_float8\nimport torch\n\nimport os\nimport sys\nimport math\nsys.path.insert(0, os.path.join(sys.path[0], '../..'))\nimport optixutils as ou\nimport numpy as np\n\nRES = 1024\nDTYPE = torch.float32\n\ndef length(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor:\n\treturn torch.sqrt(torch.clamp(dot(x,x), min=eps)) # Clamp to avoid nan gradients because grad(sqrt(0)) = NaN\n\ndef safe_normalize(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor:\n\treturn x / length(x, eps)\n\ndef dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n\treturn torch.sum(x*y, -1, keepdim=True)\n\nclass BilateralDenoiser(torch.nn.Module):\n\tdef __init__(self, sigma=1.0):\n\t\tsuper(BilateralDenoiser, self).__init__()\n\t\tself.set_sigma(sigma)\n\n\tdef set_sigma(self, sigma):\n\t\tself.sigma = max(sigma, 0.0001)\n\t\tself.variance = self.sigma**2.\n\t\tself.N = 2 * math.ceil(self.sigma * 2.5) + 1\n\n\tdef forward(self, input):\n\t\teps    = 0.0001\n\t\tcol    = input[..., 0:3]\n\t\tnrm    = input[..., 3:6]\n\t\tkd     = input[..., 6:9]\n\t\tzdz    = input[..., 9:11]\n\n\t\taccum_col = torch.zeros_like(col)\n\t\taccum_w = torch.zeros_like(col[..., 0:1])\n\t\tfor y in range(-self.N, self.N+1):\n\t\t\tfor x in range(-self.N, self.N+1):\n\n\t\t\t\tty, tx = torch.meshgrid(torch.arange(0, input.shape[1], dtype=torch.float32, device=\"cuda\"), torch.arange(0, input.shape[2], dtype=torch.float32, device=\"cuda\"))\n\t\t\t\ttx = tx[None, ..., None] + x\n\t\t\t\tty = ty[None, ..., None] + y\n\n\t\t\t\tdist_sqr = (x**2 + y**2)\n\t\t\t\tdist = np.sqrt(dist_sqr)\n\t\t\t\tw_xy = np.exp(-dist_sqr / (2 * self.variance))\n\n\t\t\t\twith torch.no_grad():\n\t\t\t\t\tnrm_tap = torch.roll(nrm, (-y, -x), (1, 2))\n\t\t\t\t\tw_normal = torch.pow(torch.clamp(dot(nrm_tap, nrm), min=eps, max=1.0), 128.0)           # From SVGF\n\n\t\t\t\t\tzdz_tap = torch.roll(zdz, (-y, -x), (1, 2))\n\t\t\t\t\tw_depth = torch.exp(-(torch.abs(zdz_tap[..., 0:1] - zdz[..., 0:1]) / torch.clamp(zdz[..., 1:2] * dist, min=eps)) ) # From SVGF\t\n\n\t\t\t\t\tw = w_xy * w_normal * w_depth\n\t\t\t\t\tw = torch.where((tx >= 0) & (tx < input.shape[2]) & (ty >= 0) & (ty < input.shape[1]), w, torch.zeros_like(w))\n\n\t\t\t\tcol_tap = torch.roll(col, (-y, -x), (1, 2))\n\t\t\t\taccum_col += col_tap * w\n\t\t\t\taccum_w += w\n\t\treturn accum_col / torch.clamp(accum_w, min=eps)\n\ndef relative_loss(name, ref, cuda):\n\tref = ref.float()\n\tcuda = cuda.float()\n\tdenom = torch.where(ref > 1e-7, ref, torch.ones_like(ref))\n\trelative = torch.abs(ref - cuda) / denom\n\tprint(name, torch.max(relative).item())\n\n\ndef test_filter():\n\timg_cuda = torch.rand(1, RES, RES, 11, dtype=DTYPE, device='cuda')\n\timg_cuda[..., 3:6] = safe_normalize(img_cuda[..., 3:6])\n\timg_ref = img_cuda.clone().detach().requires_grad_(True)\n\timg_cuda = img_cuda.clone().detach().requires_grad_(True)\n\ttarget_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)\n\ttarget_ref = target_cuda.clone().detach().requires_grad_(True)\n\t\n\tSIGMA = 2.0\n\n\tstart = 
torch.cuda.Event(enable_timing=True)\n\tend = torch.cuda.Event(enable_timing=True)\n\n\tstart.record()\n\tdenoiser = BilateralDenoiser(sigma=SIGMA)\n\tdenoised_ref = denoiser.forward(img_ref)\n\tref_loss = torch.nn.MSELoss()(denoised_ref, target_ref)\n\tref_loss.backward()\n\tend.record()\n\ttorch.cuda.synchronize()\n\tprint(\"Python:\", start.elapsed_time(end))\n\n\tstart.record()\n\t# Use the bilateral_denoiser(col, nrm, zdz, sigma) API exposed by ops.py; the kd channels are not used by the reference filter.\n\tdenoised_cuda = ou.bilateral_denoiser(img_cuda[..., 0:3], img_cuda[..., 3:6], img_cuda[..., 9:11], SIGMA)\n\tcuda_loss = torch.nn.MSELoss()(denoised_cuda, target_cuda)\n\tcuda_loss.backward()\n\tend.record()\n\ttorch.cuda.synchronize()\n\tprint(\"CUDA:\", start.elapsed_time(end))\n\n\tprint(\"-------------------------------------------------------------\")\n\tprint(\"    Filter loss:\")\n\tprint(\"-------------------------------------------------------------\")\n\n\trelative_loss(\"denoised:\", denoised_ref[..., 0:3], denoised_cuda[..., 0:3])\n\trelative_loss(\"grad:\", img_ref.grad[..., 0:3], img_cuda.grad[..., 0:3])\n\ntest_filter()"
  },
  {
    "path": "render/regularizer.py",
    "content": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n# property and proprietary rights in and to this material, related\n# documentation and any modifications thereto. Any use, reproduction, \n# disclosure or distribution of this material and related documentation \n# without an express license agreement from NVIDIA CORPORATION or \n# its affiliates is strictly prohibited.\n\nimport torch\nimport nvdiffrast.torch as dr\n\nfrom render import util\nfrom . import mesh\n\ndef luma(x):\n    return ((x[..., 0:1] + x[..., 1:2] + x[..., 2:3]) / 3).repeat(1, 1, 1, 3)\ndef value(x):\n    return torch.max(x[..., 0:3], dim=-1, keepdim=True)[0].repeat(1, 1, 1, 3)\n\ndef chroma_loss(kd, color_ref, lambda_chroma):\n    eps = 0.001\n    ref_chroma = color_ref[..., 0:3] / torch.clip(value(color_ref), min=eps)\n    opt_chroma = kd[..., 0:3] / torch.clip(value(kd), min=eps)\n    return torch.mean(torch.abs((opt_chroma - ref_chroma) * color_ref[..., 3:])) * lambda_chroma\n\n# Diffuse luma regularizer + specular \ndef shading_loss(diffuse_light, specular_light, color_ref, lambda_diffuse, lambda_specular):\n    diffuse_luma  = luma(diffuse_light)\n    specular_luma = luma(specular_light)\n    ref_luma      = value(color_ref)\n    \n    eps = 0.001\n    img    = util.rgb_to_srgb(torch.log(torch.clamp((diffuse_luma + specular_luma) * color_ref[..., 3:], min=0, max=65535) + 1))\n    target = util.rgb_to_srgb(torch.log(torch.clamp(ref_luma * color_ref[..., 3:], min=0, max=65535) + 1))\n    # error  = torch.abs(img - target) * diffuse_luma / torch.clamp(diffuse_luma + specular_luma, min=eps) ### encourage specular component to take control\n    error  = torch.abs(img - target) ### the original version in the paper\n    loss   = torch.mean(error) * lambda_diffuse\n    loss  += torch.mean(specular_luma) / torch.clamp(torch.mean(diffuse_luma), min=eps) * lambda_specular\n    return loss\n\n######################################################################################\n# Material smoothness loss\n######################################################################################\n\ndef material_smoothness_grad(kd_grad, ks_grad, nrm_grad, lambda_kd=0.25, lambda_ks=0.1, lambda_nrm=0.0):\n    kd_luma_grad = (kd_grad[..., 0] + kd_grad[..., 1] + kd_grad[..., 2]) / 3\n    loss  = torch.mean(kd_luma_grad * kd_grad[..., -1]) * lambda_kd\n    loss += torch.mean(ks_grad[..., :-1] * ks_grad[..., -1:]) * lambda_ks\n    loss += torch.mean(nrm_grad[..., :-1] * nrm_grad[..., -1:]) * lambda_nrm\n    return loss\n\n######################################################################################\n# Computes the image gradient, useful for kd/ks smoothness losses\n######################################################################################\ndef image_grad(buf, std=0.01):\n    t, s = torch.meshgrid(torch.linspace(-1.0 + 1.0 / buf.shape[1], 1.0 - 1.0 / buf.shape[1], buf.shape[1], device=\"cuda\"), \n                          torch.linspace(-1.0 + 1.0 / buf.shape[2], 1.0 - 1.0 / buf.shape[2], buf.shape[2], device=\"cuda\"),\n                          indexing='ij')\n    tc   = torch.normal(mean=0, std=std, size=(buf.shape[0], buf.shape[1], buf.shape[2], 2), device=\"cuda\") + torch.stack((s, t), dim=-1)[None, ...]\n    tap  = dr.texture(buf, tc, filter_mode='linear', boundary_mode='clamp')\n    return torch.abs(tap[..., :-1] - buf[..., :-1]) * tap[..., -1:] * buf[..., 
-1:]\n\n######################################################################################\n# Computes the average edge length of a mesh. \n# Rough estimate of the tessellation of a mesh. Can be used e.g. to clamp gradients\n######################################################################################\ndef avg_edge_length(v_pos, t_pos_idx):\n    e_pos_idx = mesh.compute_edges(t_pos_idx)\n    edge_len  = util.length(v_pos[e_pos_idx[:, 0]] - v_pos[e_pos_idx[:, 1]])\n    return torch.mean(edge_len)\n\n######################################################################################\n# Laplacian regularization using umbrella operator (Fujiwara / Desbrun).\n# https://mgarland.org/class/geom04/material/smoothing.pdf\n######################################################################################\ndef laplace_regularizer_const(v_pos, t_pos_idx):\n    term = torch.zeros_like(v_pos)\n    norm = torch.zeros_like(v_pos[..., 0:1])\n\n    v0 = v_pos[t_pos_idx[:, 0], :]\n    v1 = v_pos[t_pos_idx[:, 1], :]\n    v2 = v_pos[t_pos_idx[:, 2], :]\n\n    term.scatter_add_(0, t_pos_idx[:, 0:1].repeat(1,3), (v1 - v0) + (v2 - v0))\n    term.scatter_add_(0, t_pos_idx[:, 1:2].repeat(1,3), (v0 - v1) + (v2 - v1))\n    term.scatter_add_(0, t_pos_idx[:, 2:3].repeat(1,3), (v0 - v2) + (v1 - v2))\n\n    two = torch.ones_like(v0) * 2.0\n    norm.scatter_add_(0, t_pos_idx[:, 0:1], two)\n    norm.scatter_add_(0, t_pos_idx[:, 1:2], two)\n    norm.scatter_add_(0, t_pos_idx[:, 2:3], two)\n\n    term = term / torch.clamp(norm, min=1.0)\n\n    return torch.mean(term**2)\n\n######################################################################################\n# Smooth vertex normals\n######################################################################################\ndef normal_consistency(v_pos, t_pos_idx):\n    # Compute face normals\n    v0 = v_pos[t_pos_idx[:, 0], :]\n    v1 = v_pos[t_pos_idx[:, 1], :]\n    v2 = v_pos[t_pos_idx[:, 2], :]\n\n    face_normals = util.safe_normalize(torch.cross(v1 - v0, v2 - v0))\n\n    tris_per_edge = mesh.compute_edge_to_face_mapping(t_pos_idx)\n\n    # Fetch normals for both faces sharing an edge\n    n0 = face_normals[tris_per_edge[:, 0], :]\n    n1 = face_normals[tris_per_edge[:, 1], :]\n\n    # Compute error metric based on normal difference\n    term = torch.clamp(util.dot(n0, n1), min=-1.0, max=1.0)\n    term = (1.0 - term) * 0.5\n\n    return torch.mean(torch.abs(term))\n"
  },
  {
    "path": "render/render.py",
    "content": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n#\n# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n# property and proprietary rights in and to this material, related\n# documentation and any modifications thereto. Any use, reproduction,\n# disclosure or distribution of this material and related documentation\n# without an express license agreement from NVIDIA CORPORATION or\n# its affiliates is strictly prohibited.\n\nfrom threading import local\nimport numpy as np\nimport torch\nimport nvdiffrast.torch as dr\n\nfrom . import util\nfrom . import renderutils as ru\nfrom . import optixutils as ou\nfrom . import light\n\nrnd_seed = 0\n\n# ==============================================================================================\n#  Helper functions\n# ==============================================================================================\ndef interpolate(attr, rast, attr_idx, rast_db=None):\n    return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=None if rast_db is None else 'all')\n\n# ==============================================================================================\n#  pixel shader\n# ==============================================================================================\ndef shade(\n        FLAGS,\n        rast,\n        gb_depth,\n        gb_pos,\n        gb_geometric_normal,\n        gb_normal,\n        gb_tangent,\n        gb_texc,\n        gb_texc_deriv,\n        view_pos,\n        lgt,\n        material,\n        optix_ctx,\n        mesh,\n        bsdf,\n        denoiser,\n        shadow_scale,\n        use_uv=True,\n        finetune_normal=True,\n        xfm_lgt=None,\n        shade_data=False\n    ):\n\n    offset = torch.normal(mean=0, std=0.005, size=(gb_depth.shape[0], gb_depth.shape[1], gb_depth.shape[2], 2), device=\"cuda\")\n    jitter = (util.pixel_grid(gb_depth.shape[2], gb_depth.shape[1])[None, ...] 
+ offset).contiguous()\n\n    mask = (rast[..., -1:] > 0).float()\n    mask_tap = dr.texture(mask.contiguous(), jitter, filter_mode='linear', boundary_mode='clamp')\n    grad_weight = mask * mask_tap\n\n    ################################################################################\n    # Texture lookups\n    ################################################################################\n    perturbed_nrm = None\n    if 'kd_ks' in material:\n        # Combined texture, used for MLPs because lookups are expensive\n        all_tex_jitter = material['kd_ks'].sample(gb_pos + torch.normal(mean=0, std=0.01, size=gb_pos.shape, device=\"cuda\"))\n        # all_tex_jitter = material['kd_ks'].sample(gb_pos + torch.normal(mean=0, std=0.002, size=gb_pos.shape, device=\"cuda\"))\n        all_tex = material['kd_ks'].sample(gb_pos)\n        assert all_tex.shape[-1] == 6, \"Combined kd_ks must be 6 channels\"\n        kd, ks = all_tex[..., 0:3], all_tex[..., 3:6]\n        kd_grad  = torch.abs(all_tex_jitter[..., 0:3] - kd)\n        ks_grad  = torch.abs(all_tex_jitter[..., 3:6] - ks) * torch.tensor([0, 1, 1], dtype=torch.float32, device='cuda')[None, None, None, :] # Omit o-component\n    elif 'kd_ks_normal' in material:\n        raise NotImplementedError\n    else:\n        if shade_data:\n            kd = material['kd'].sample(gb_texc, gb_texc_deriv)\n            ks = material['ks'].sample(gb_texc, gb_texc_deriv)[..., 0:3] # skip alpha\n            if 'normal' in material:\n                perturbed_nrm = material['normal'].sample(gb_texc, gb_texc_deriv)\n\n            kd_jitter = dr.texture(kd.contiguous(), jitter, filter_mode='linear', boundary_mode='clamp')\n            ks_jitter = dr.texture(ks.contiguous(), jitter, filter_mode='linear', boundary_mode='clamp')\n            kd_grad = torch.abs(kd_jitter - kd) * grad_weight\n            ks_grad  = torch.abs(ks_jitter - ks) * torch.tensor([0, 1, 1], dtype=torch.float32, device='cuda')[None, None, None, :] * grad_weight # Omit o-component\n        else:\n            kd = material['kd'].sample(gb_texc, gb_texc_deriv)\n            ks = material['ks'].sample(gb_texc, gb_texc_deriv)[..., 0:3] # skip alpha\n            if 'normal' in material:\n                perturbed_nrm = material['normal'].sample(gb_texc, gb_texc_deriv)\n\n            kd_jitter = dr.texture(kd.contiguous(), jitter, filter_mode='linear', boundary_mode='clamp')\n            ks_jitter = dr.texture(ks.contiguous(), jitter, filter_mode='linear', boundary_mode='clamp')\n            kd_grad = torch.abs(kd_jitter - kd) * grad_weight\n            ks_grad  = torch.abs(ks_jitter - ks) * torch.tensor([0, 1, 1], dtype=torch.float32, device='cuda')[None, None, None, :] * grad_weight # Omit o-component\n\n    # Separate kd into alpha and color, default alpha = 1\n    alpha = kd[..., 3:4] if kd.shape[-1] == 4 else torch.ones_like(kd[..., 0:1])\n    kd = kd[..., 0:3]\n\n    ################################################################################\n    # Normal perturbation & normal bend\n    ################################################################################\n    if (not finetune_normal) or ('no_perturbed_nrm' in material and material['no_perturbed_nrm']):\n        perturbed_nrm = None\n\n    # Geometric smoothed normal regularizer\n    nrm_jitter = dr.texture(gb_normal.contiguous(), jitter, filter_mode='linear', boundary_mode='clamp')\n    nrm_grad = torch.abs(nrm_jitter - gb_normal) * grad_weight\n\n    if perturbed_nrm is not None:\n        perturbed_nrm_jitter = 
dr.texture(perturbed_nrm.contiguous(), jitter, filter_mode='linear', boundary_mode='clamp')\n        perturbed_nrm_grad = 1.0 - util.safe_normalize(util.safe_normalize(perturbed_nrm_jitter) + util.safe_normalize(perturbed_nrm))[..., 2:3]\n        perturbed_nrm_grad = perturbed_nrm_grad.repeat(1,1,1,3) * grad_weight\n\n    gb_normal = ru.prepare_shading_normal(gb_pos, view_pos, perturbed_nrm, gb_normal, gb_tangent, gb_geometric_normal, two_sided_shading=True, opengl=True)\n\n\n    ################################################################################\n    # Evaluate BSDF\n    ################################################################################\n    assert 'bsdf' in material or bsdf is not None, \"Material must specify a BSDF type\"\n    bsdf = material['bsdf'] if bsdf is None else bsdf\n    \n    if bsdf == 'pbr' or bsdf == 'diffuse' or bsdf == 'white':\n        kd = torch.ones_like(kd) if bsdf == 'white' else kd\n\n        assert isinstance(lgt, light.EnvironmentLight) and optix_ctx is not None\n        ro = gb_pos + gb_normal*0.001\n\n        global rnd_seed\n        diffuse_accum, specular_accum = ou.optix_env_shade(optix_ctx, rast[..., -1], ro, gb_pos, gb_normal, view_pos, kd, ks, \n                            lgt.base, lgt._pdf, lgt.rows[:,0], lgt.cols, BSDF=bsdf, n_samples_x=FLAGS.n_samples, \n                            rnd_seed=None if FLAGS.decorrelated else rnd_seed, shadow_scale=shadow_scale)\n        rnd_seed += 1\n\n        # denoise demodulated shaded values if possible\n        if denoiser is not None and FLAGS.denoiser_demodulate:\n            diffuse_accum  = denoiser.forward(torch.cat((diffuse_accum, gb_normal, gb_depth), dim=-1))\n            specular_accum = denoiser.forward(torch.cat((specular_accum, gb_normal, gb_depth), dim=-1))\n\n        if bsdf == 'white' or bsdf == 'diffuse':\n            shaded_col = diffuse_accum * kd\n        else:\n            kd = kd * (1.0 - ks[..., 2:3]) # kd * (1.0 - metalness)\n            shaded_col = diffuse_accum * kd + specular_accum\n\n        # denoise combined shaded values if possible\n        if denoiser is not None and not FLAGS.denoiser_demodulate:\n            shaded_col = denoiser.forward(torch.cat((shaded_col, gb_normal, gb_depth), dim=-1))\n    elif bsdf == 'normal':\n        shaded_col = (gb_normal + 1.0)*0.5\n    elif bsdf == 'tangent':\n        shaded_col = (gb_tangent + 1.0)*0.5\n    elif bsdf == 'kd':\n        shaded_col = kd\n    elif bsdf == 'ks':\n        shaded_col = ks\n    else:\n        assert False, \"Invalid BSDF '%s'\" % bsdf\n\n    eps = 1e-8\n    allone_map = torch.ones_like(alpha)\n    # Return multiple buffers\n    # Setting the `alphas` of depth and invdepth to 1 to avoid double blending\n    # (one with background, the other in antialiasing)\n    buffers = {\n        'shaded'            : torch.cat((shaded_col, alpha), dim=-1),\n        'z_grad'            : torch.cat((gb_depth, torch.zeros_like(alpha), alpha), dim=-1),\n        'normal'            : torch.cat((gb_normal, alpha), dim=-1),\n        'geometric_normal'  : torch.cat((gb_geometric_normal, alpha), dim=-1),\n        'kd'                : torch.cat((kd, alpha), dim=-1),\n        'ks'                : torch.cat((ks, alpha), dim=-1),\n        'kd_grad'           : torch.cat((kd_grad, alpha), dim=-1),\n        'ks_grad'           : torch.cat((ks_grad, alpha), dim=-1),\n        'normal_grad'       : torch.cat((nrm_grad, alpha), dim=-1),\n        # 'depth'             : torch.cat(((gb_pos - view_pos).pow(2).sum(dim=-1, 
keepdim=True).sqrt(), allone_map), dim=-1),\n        # 'invdepth'          : torch.cat((1.0 / ((gb_pos - view_pos).pow(2) + eps).sum(dim=-1, keepdim=True).sqrt(), allone_map), dim=-1),\n    }\n\n    if 'diffuse_accum' in locals():\n        buffers['diffuse_light'] = torch.cat((diffuse_accum, alpha), dim=-1)\n    if 'specular_accum' in locals():\n        buffers['specular_light'] = torch.cat((specular_accum, alpha), dim=-1)\n\n    if perturbed_nrm is not None: \n        buffers['perturbed_nrm'] = torch.cat((perturbed_nrm, alpha), dim=-1)\n        buffers['perturbed_nrm_grad'] = torch.cat((perturbed_nrm_grad, alpha), dim=-1)\n    return buffers\n\n# ==============================================================================================\n#  Render a depth slice of the mesh (scene), some limitations:\n#  - Single mesh\n#  - Single light\n#  - Single material\n# ==============================================================================================\ndef render_layer(\n        FLAGS,\n        v_pos_clip,\n        rast,\n        rast_deriv,\n        mesh,\n        view_pos,\n        lgt,\n        resolution,\n        spp,\n        msaa,\n        optix_ctx,\n        bsdf,\n        denoiser,\n        shadow_scale,\n        use_uv=True,\n        finetune_normal=True,\n        extra_dict=None,\n        xfm_lgt = None,\n        shade_data = False\n    ):\n\n    full_res = [resolution[0]*spp, resolution[1]*spp]\n\n    ################################################################################\n    # Rasterize\n    ################################################################################\n\n    # Scale down to shading resolution when MSAA is enabled, otherwise shade at full resolution\n    if spp > 1 and msaa:\n        rast_out_s = util.scale_img_nhwc(rast, resolution, mag='nearest', min='nearest')\n        rast_out_deriv_s = util.scale_img_nhwc(rast_deriv, resolution, mag='nearest', min='nearest') * spp\n    else:\n        rast_out_s = rast\n        rast_out_deriv_s = rast_deriv\n\n    ################################################################################\n    # Interpolate attributes\n    ################################################################################\n\n    # Interpolate world space position\n    gb_pos, _ = interpolate(mesh.v_pos[None, ...], rast_out_s, mesh.t_pos_idx.int())\n\n    # Compute geometric normals. 
We need those because of bent normals trick (for bump mapping)\n    v0 = mesh.v_pos[mesh.t_pos_idx[:, 0], :]\n    v1 = mesh.v_pos[mesh.t_pos_idx[:, 1], :]\n    v2 = mesh.v_pos[mesh.t_pos_idx[:, 2], :]\n    face_normals = util.safe_normalize(torch.cross(v1 - v0, v2 - v0))\n    face_normal_indices = (torch.arange(0, face_normals.shape[0], dtype=torch.int64, device='cuda')[:, None]).repeat(1, 3)\n    gb_geometric_normal, _ = interpolate(face_normals[None, ...], rast_out_s, face_normal_indices.int())\n\n    if use_uv:\n        # Compute tangent space\n        assert mesh.v_nrm is not None and mesh.v_tng is not None\n        gb_normal, _ = interpolate(mesh.v_nrm[None, ...], rast_out_s, mesh.t_nrm_idx.int())\n        gb_tangent, _ = interpolate(mesh.v_tng[None, ...], rast_out_s, mesh.t_tng_idx.int()) # Interpolate tangents\n\n        # Texture coordinate\n        assert mesh.v_tex is not None\n        gb_texc, gb_texc_deriv = interpolate(mesh.v_tex[None, ...], rast_out_s, mesh.t_tex_idx.int(), rast_db=rast_out_deriv_s)\n\n    else:\n        # Compute tangent space\n        assert mesh.v_nrm is not None\n        gb_normal, _ = interpolate(mesh.v_nrm[None, ...], rast_out_s, mesh.t_nrm_idx.int())\n        with torch.no_grad():\n            noise = torch.randn_like(gb_normal)\n            noise = noise / noise.norm(dim=-1, keepdim=True)\n        gb_tangent = torch.cross(noise, gb_normal) ### since we only use tangent for adding isotropic noises but not for uv maps\n\n        # # Texture coordinate\n        gb_texc, gb_texc_deriv = None, None\n\n    # Interpolate z and z-gradient\n    with torch.no_grad():\n        eps = 0.00001\n        clip_pos, clip_pos_deriv = interpolate(v_pos_clip, rast_out_s, mesh.t_pos_idx.int(), rast_db=rast_out_deriv_s)\n        z0 = torch.clamp(clip_pos[..., 2:3], min=eps) / torch.clamp(clip_pos[..., 3:4], min=eps)\n        z1 = torch.clamp(clip_pos[..., 2:3] + torch.abs(clip_pos_deriv[..., 2:3]), min=eps) / torch.clamp(clip_pos[..., 3:4] + torch.abs(clip_pos_deriv[..., 3:4]), min=eps)\n        z_grad = torch.abs(z1 - z0)\n        gb_depth = torch.cat((z0, z_grad), dim=-1)\n    ################################################################################\n    # Shade\n    ################################################################################\n\n    buffers = shade(\n        FLAGS, rast_out_s, gb_depth,\n        gb_pos, gb_geometric_normal, gb_normal,\n        gb_tangent, gb_texc, gb_texc_deriv,\n        view_pos, lgt, mesh.material, optix_ctx, \n        mesh, bsdf,\n        denoiser, shadow_scale,\n        use_uv=use_uv,\n        finetune_normal=finetune_normal,\n        xfm_lgt=xfm_lgt,\n        shade_data=shade_data\n    )\n\n    ################################################################################\n    # Prepare output\n    ################################################################################\n\n\n    if extra_dict is not None:\n        for key in extra_dict:\n            if key == 'msdf' and extra_dict[key] is not None:\n                assert extra_dict[key].dim() == 1 or (extra_dict[key].dim() == 2 and extra_dict[key].size(1) == 1)\n                buffers['msdf_image'], _ = interpolate(extra_dict[key].squeeze()[None, :, None], rast_out_s, mesh.t_pos_idx.int())\n            elif key == 'msdf_watertight' and extra_dict[key] is not None:\n                assert extra_dict[key].dim() == 1 or (extra_dict[key].dim() == 2 and extra_dict[key].size(1) == 1)\n                buffers['msdf_watertight_image'], _ = 
interpolate(extra_dict['msdf_watertight'].squeeze()[None, :, None], rast_out_s.detach(), mesh.t_pos_idx.int()) ## maybe better to stop all gradients to vpos\n\n    # Scale back up to visibility resolution if using MSAA\n    if spp > 1 and msaa:\n        for key in buffers.keys():\n            buffers[key] = util.scale_img_nhwc(buffers[key], full_res, mag='nearest', min='nearest')\n\n    # Return buffers\n    return buffers\n\n# ==============================================================================================\n#  Render a depth peeled mesh (scene), some limitations:\n#  - Single mesh\n#  - Single light\n#  - Single material\n# ==============================================================================================\ndef render_mesh(\n        FLAGS,\n        ctx,\n        mesh,\n        mtx_in,\n        view_pos,\n        lgt,\n        resolution,\n        spp        = 1,\n        num_layers = 1,\n        msaa       = False,\n        background = None,\n        optix_ctx  = None,\n        bsdf       = None,\n        denoiser   = None,\n        shadow_scale = 1.0,\n        use_uv      = True,\n        finetune_normal = True,\n        extra_dict  = None,\n        xfm_lgt     = None,\n        shade_data  = False,\n    ):\n\n    def prepare_input_vector(x):\n        x = torch.tensor(x, dtype=torch.float32, device='cuda') if not torch.is_tensor(x) else x\n        return x[:, None, None, :] if len(x.shape) == 2 else x\n\n    def composite_buffer(key, layers, background, antialias):\n        accum = background\n        for buffers, rast in reversed(layers):\n            alpha = (rast[..., -1:] > 0).float() * buffers[key][..., -1:]\n            accum = torch.lerp(accum, torch.cat((buffers[key][..., :-1], torch.ones_like(buffers[key][..., -1:])), dim=-1), alpha)\n            if antialias:\n                accum = dr.antialias(accum.contiguous(), rast, v_pos_clip, mesh.t_pos_idx.int())\n        return accum\n\n    '''\n        choose not to raise error since it is possible that we have msdf supervision. 
should clean the code later\n    '''\n    # assert mesh.t_pos_idx.shape[0] > 0, \"Got empty training triangle mesh (unrecoverable discontinuity)\"\n    # assert background is None or (background.shape[1] == resolution[0] and background.shape[2] == resolution[1])\n\n    full_res = [resolution[0]*spp, resolution[1]*spp]\n\n    # Convert numpy arrays to torch tensors\n    mtx_in      = torch.tensor(mtx_in, dtype=torch.float32, device='cuda') if not torch.is_tensor(mtx_in) else mtx_in\n    view_pos    = prepare_input_vector(view_pos)\n\n    # clip space transform\n    v_pos_clip = ru.xfm_points(mesh.v_pos[None, ...], mtx_in)\n\n    # Render all layers front-to-back\n    with dr.DepthPeeler(ctx, v_pos_clip, mesh.t_pos_idx.int(), full_res) as peeler:\n        assert num_layers == 1\n        rast, db = peeler.rasterize_next_layer()\n        visible_triangles = rast[:,:,:,-1].long().unique()\n        if visible_triangles[0] == 0:\n            visible_triangles = visible_triangles[1:]\n        visible_triangles = visible_triangles - 1\n        layers = [\n            (render_layer(\n                FLAGS, v_pos_clip,\n                rast, db, mesh, view_pos, lgt, resolution, spp, msaa,\n                optix_ctx, bsdf, denoiser, shadow_scale,\n                use_uv=use_uv, finetune_normal=finetune_normal,\n                extra_dict=extra_dict,\n                xfm_lgt=xfm_lgt,\n                shade_data=shade_data),\n            rast)]\n        # rast, db = peeler.rasterize_next_layer()\n        # layer_second = [\n        #     (render_layer(\n        #         FLAGS, v_pos_clip,\n        #         rast, db, mesh, view_pos, lgt, resolution, spp, msaa,\n        #         optix_ctx, bsdf, denoiser, shadow_scale,\n        #         use_uv=use_uv, finetune_normal=finetune_normal,\n        #         extra_dict=extra_dict,\n        #         xfm_lgt=xfm_lgt,\n        #         shade_data=shade_data),\n        #     rast)]\n\n    # Setup background\n    if background is not None:\n        if spp > 1:\n            background = util.scale_img_nhwc(background, full_res, mag='nearest', min='nearest')\n        background = torch.cat((background, torch.zeros_like(background[..., 0:1])), dim=-1)\n    else:\n        background = torch.zeros(1, full_res[0], full_res[1], 4, dtype=torch.float32, device='cuda')\n\n    # Composite layers front-to-back\n    out_buffers = {}\n    out_buffers['visible_triangles'] = visible_triangles\n    for key in layers[0][0].keys():\n        if layers[0][0][key] is None:\n            out_buffers[key] = None\n            continue\n        if key == 'shaded':\n            accum = composite_buffer(key, layers, background, True)\n        elif key == 'depth':\n            continue\n            default_depth = 20.0\n            accum = composite_buffer(key, layers, torch.ones_like(layers[0][0][key]) * default_depth, True)\n        elif key == 'invdepth':\n            accum = composite_buffer(key, layers, torch.zeros_like(layers[0][0][key]), True)\n        else:\n            accum = composite_buffer(key, layers, torch.zeros_like(layers[0][0][key]), True)\n\n        # Downscale to framebuffer resolution. 
Use avg pooling\n        out_buffers[key] = util.avg_pool_nhwc(accum, spp) if spp > 1 else accum\n\n    # accum = composite_buffer('shaded', layer_second, background.clone(), True)\n    # out_buffers['shaded_second'] = util.avg_pool_nhwc(accum, spp) if spp > 1 else accum\n\n    # accum = composite_buffer('invdepth', layer_second, torch.zeros_like(layers[0][0]['invdepth']), True)\n    # out_buffers['invdepth_second'] = util.avg_pool_nhwc(accum, spp) if spp > 1 else accum\n\n    # accum = composite_buffer('depth', layer_second, torch.ones_like(layers[0][0]['depth']) * default_depth, True)\n    # out_buffers['depth_second'] = util.avg_pool_nhwc(accum, spp) if spp > 1 else accum\n\n    return out_buffers\n\n# ==============================================================================================\n#  Render UVs\n# ==============================================================================================\ndef render_uv(ctx, mesh, resolution, mlp_texture):\n\n    # clip space transform\n    uv_clip = mesh.v_tex[None, ...]*2.0 - 1.0\n\n    # pad to four component coordinate\n    uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[...,0:1]), torch.ones_like(uv_clip[...,0:1])), dim = -1)\n\n    # rasterize\n    rast, _ = dr.rasterize(ctx, uv_clip4, mesh.t_tex_idx.int(), resolution)\n\n    # Interpolate world space position\n    gb_pos, _ = interpolate(mesh.v_pos[None, ...], rast, mesh.t_pos_idx.int())\n\n    # Sample out textures from MLP\n    all_tex = mlp_texture.sample(gb_pos)\n    assert all_tex.shape[-1] == 6, \"Combined kd_ks must be 6 channels\"\n    return (rast[..., -1:] > 0).float(), all_tex[..., 0:3], all_tex[..., 3:6]"
  },
  {
    "path": "render/renderutils/__init__.py",
    "content": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n# property and proprietary rights in and to this material, related\n# documentation and any modifications thereto. Any use, reproduction, \n# disclosure or distribution of this material and related documentation \n# without an express license agreement from NVIDIA CORPORATION or \n# its affiliates is strictly prohibited.\n\nfrom .ops import xfm_points, xfm_vectors, image_loss, diffuse_cubemap, specular_cubemap, prepare_shading_normal, lambert, frostbite_diffuse, pbr_specular, pbr_bsdf, _fresnel_shlick, _ndf_ggx, _lambda_ggx, _masking_smith\n__all__ = [\"xfm_vectors\", \"xfm_points\", \"image_loss\", \"diffuse_cubemap\",\"specular_cubemap\", \"prepare_shading_normal\", \"lambert\", \"frostbite_diffuse\", \"pbr_specular\", \"pbr_bsdf\", \"_fresnel_shlick\", \"_ndf_ggx\", \"_lambda_ggx\", \"_masking_smith\", ]\n"
  },
  {
    "path": "render/renderutils/bsdf.py",
    "content": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n# property and proprietary rights in and to this material, related\n# documentation and any modifications thereto. Any use, reproduction, \n# disclosure or distribution of this material and related documentation \n# without an express license agreement from NVIDIA CORPORATION or \n# its affiliates is strictly prohibited.\n\nimport math\nimport torch\n\nNORMAL_THRESHOLD = 0.1\n\n################################################################################\n# Vector utility functions\n################################################################################\n\ndef _dot(x, y):\n    return torch.sum(x*y, -1, keepdim=True)\n\ndef _reflect(x, n):\n    return 2*_dot(x, n)*n - x\n\ndef _safe_normalize(x):\n    return torch.nn.functional.normalize(x, dim = -1)\n\ndef _bend_normal(view_vec, smooth_nrm, geom_nrm, two_sided_shading):\n    # Swap normal direction for backfacing surfaces\n    if two_sided_shading:\n        smooth_nrm = torch.where(_dot(geom_nrm, view_vec) > 0, smooth_nrm, -smooth_nrm)\n        geom_nrm   = torch.where(_dot(geom_nrm, view_vec) > 0, geom_nrm, -geom_nrm)\n\n    t = torch.clamp(_dot(view_vec, smooth_nrm) / NORMAL_THRESHOLD, min=0, max=1)\n    return torch.lerp(geom_nrm, smooth_nrm, t)\n\n\ndef _perturb_normal(perturbed_nrm, smooth_nrm, smooth_tng, opengl):\n    smooth_bitang = _safe_normalize(torch.cross(smooth_tng, smooth_nrm))\n    if opengl:\n        shading_nrm = smooth_tng * perturbed_nrm[..., 0:1] - smooth_bitang * perturbed_nrm[..., 1:2] + smooth_nrm * torch.clamp(perturbed_nrm[..., 2:3], min=0.0)\n    else:\n        shading_nrm = smooth_tng * perturbed_nrm[..., 0:1] + smooth_bitang * perturbed_nrm[..., 1:2] + smooth_nrm * torch.clamp(perturbed_nrm[..., 2:3], min=0.0)\n    return _safe_normalize(shading_nrm)\n\ndef bsdf_prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl):\n    smooth_nrm = _safe_normalize(smooth_nrm)\n    smooth_tng = _safe_normalize(smooth_tng)\n    view_vec   = _safe_normalize(view_pos - pos)\n    shading_nrm = _perturb_normal(perturbed_nrm, smooth_nrm, smooth_tng, opengl)\n    return _bend_normal(view_vec, shading_nrm, geom_nrm, two_sided_shading)\n\n################################################################################\n# Simple lambertian diffuse BSDF\n################################################################################\n\ndef bsdf_lambert(nrm, wi):\n    return torch.clamp(_dot(nrm, wi), min=0.0) / math.pi\n\n################################################################################\n# Frostbite diffuse\n################################################################################\n\ndef bsdf_frostbite(nrm, wi, wo, linearRoughness):\n    wiDotN = _dot(wi, nrm)\n    woDotN = _dot(wo, nrm)\n\n    h = _safe_normalize(wo + wi)\n    wiDotH = _dot(wi, h)\n\n    energyBias = 0.5 * linearRoughness\n    energyFactor = 1.0 - (0.51 / 1.51) * linearRoughness\n    f90 = energyBias + 2.0 * wiDotH * wiDotH * linearRoughness\n    f0 = 1.0\n\n    wiScatter = bsdf_fresnel_shlick(f0, f90, wiDotN)\n    woScatter = bsdf_fresnel_shlick(f0, f90, woDotN)\n    res = wiScatter * woScatter * energyFactor\n    return torch.where((wiDotN > 0.0) & (woDotN > 0.0), res, torch.zeros_like(res))\n\n################################################################################\n# Phong specular, loosely 
based on mitsuba implementation\n################################################################################\n\ndef bsdf_phong(nrm, wo, wi, N):\n    dp_r = torch.clamp(_dot(_reflect(wo, nrm), wi), min=0.0, max=1.0)\n    dp_l = torch.clamp(_dot(nrm, wi), min=0.0, max=1.0)\n    return (dp_r ** N) * dp_l * (N + 2) / (2 * math.pi)\n\n################################################################################\n# PBR's implementation of GGX specular\n################################################################################\n\nspecular_epsilon = 1e-4\n\ndef bsdf_fresnel_shlick(f0, f90, cosTheta):\n    _cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon)\n    return f0 + (f90 - f0) * (1.0 - _cosTheta) ** 5.0\n\ndef bsdf_ndf_ggx(alphaSqr, cosTheta):\n    _cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon)\n    d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1\n    return alphaSqr / (d * d * math.pi)\n\ndef bsdf_lambda_ggx(alphaSqr, cosTheta):\n    _cosTheta = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon)\n    cosThetaSqr = _cosTheta * _cosTheta\n    tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr\n    res = 0.5 * (torch.sqrt(1 + alphaSqr * tanThetaSqr) - 1.0)\n    return res\n\ndef bsdf_masking_smith_ggx_correlated(alphaSqr, cosThetaI, cosThetaO):\n    lambdaI = bsdf_lambda_ggx(alphaSqr, cosThetaI)\n    lambdaO = bsdf_lambda_ggx(alphaSqr, cosThetaO)\n    return 1 / (1 + lambdaI + lambdaO)\n\ndef bsdf_pbr_specular(col, nrm, wo, wi, alpha, min_roughness=0.08):\n    _alpha = torch.clamp(alpha, min=min_roughness*min_roughness, max=1.0)\n    alphaSqr = _alpha * _alpha\n\n    h = _safe_normalize(wo + wi)\n    woDotN = _dot(wo, nrm)\n    wiDotN = _dot(wi, nrm)\n    woDotH = _dot(wo, h)\n    nDotH  = _dot(nrm, h)\n\n    D = bsdf_ndf_ggx(alphaSqr, nDotH)\n    G = bsdf_masking_smith_ggx_correlated(alphaSqr, woDotN, wiDotN)\n    F = bsdf_fresnel_shlick(col, 1, woDotH)\n\n    w = F * D * G * 0.25 / torch.clamp(woDotN, min=specular_epsilon)\n\n    frontfacing = (woDotN > specular_epsilon) & (wiDotN > specular_epsilon)\n    return torch.where(frontfacing, w, torch.zeros_like(w))\n\ndef bsdf_pbr(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF):\n    wo = _safe_normalize(view_pos - pos)\n    wi = _safe_normalize(light_pos - pos)\n\n    spec_str  = arm[..., 0:1] # x component\n    roughness = arm[..., 1:2] # y component\n    metallic  = arm[..., 2:3] # z component\n    ks = (0.04 * (1.0 - metallic) + kd * metallic) * (1 - spec_str)\n    kd = kd * (1.0 - metallic)\n\n    if BSDF == 0:\n        diffuse = kd * bsdf_lambert(nrm, wi)\n    else:\n        diffuse = kd * bsdf_frostbite(nrm, wi, wo, roughness)\n    specular = bsdf_pbr_specular(ks, nrm, wo, wi, roughness*roughness, min_roughness=min_roughness)\n    return diffuse + specular\n"
  },
  {
    "path": "render/renderutils/c_src/bsdf.cu",
    "content": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n * property and proprietary rights in and to this material, related \n * documentation and any modifications thereto. Any use, reproduction, \n * disclosure or distribution of this material and related documentation\n * without an express license agreement from NVIDIA CORPORATION or \n * its affiliates is strictly prohibited.\n */\n\n#include \"common.h\"\n#include \"bsdf.h\"\n\n#define SPECULAR_EPSILON 1e-4f\n\n//------------------------------------------------------------------------\n// Lambert functions\n\n__device__ inline float fwdLambert(const vec3f nrm, const vec3f wi)\n{\n    return max(dot(nrm, wi) / M_PI, 0.0f);\n}\n\n__device__ inline void bwdLambert(const vec3f nrm, const vec3f wi, vec3f& d_nrm, vec3f& d_wi, const float d_out)\n{\n    if (dot(nrm, wi) > 0.0f)\n        bwdDot(nrm, wi, d_nrm, d_wi, d_out / M_PI);\n}\n\n//------------------------------------------------------------------------\n// Fresnel Schlick \n\n__device__ inline float fwdFresnelSchlick(const float f0, const float f90, const float cosTheta)\n{\n    float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);\n    float scale = powf(1.0f - _cosTheta, 5.0f);\n    return f0 * (1.0f - scale) + f90 * scale;\n}\n\n__device__ inline void bwdFresnelSchlick(const float f0, const float f90, const float cosTheta, float& d_f0, float& d_f90, float& d_cosTheta, const float d_out)\n{\n    float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);\n    float scale = pow(max(1.0f - _cosTheta, 0.0f), 5.0f);\n    d_f0 += d_out * (1.0 - scale);\n    d_f90 += d_out * scale;\n    if (cosTheta >= SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)\n    {\n        d_cosTheta += d_out * (f90 - f0) * -5.0f * powf(1.0f - cosTheta, 4.0f);\n    }\n}\n\n__device__ inline vec3f fwdFresnelSchlick(const vec3f f0, const vec3f f90, const float cosTheta)\n{\n    float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);\n    float scale = powf(1.0f - _cosTheta, 5.0f);\n    return f0 * (1.0f - scale) + f90 * scale;\n}\n\n__device__ inline void bwdFresnelSchlick(const vec3f f0, const vec3f f90, const float cosTheta, vec3f& d_f0, vec3f& d_f90, float& d_cosTheta, const vec3f d_out)\n{\n    float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);\n    float scale = pow(max(1.0f - _cosTheta, 0.0f), 5.0f);\n    d_f0 += d_out * (1.0 - scale);\n    d_f90 += d_out * scale;\n    if (cosTheta >= SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)\n    {\n        d_cosTheta += sum(d_out * (f90 - f0) * -5.0f * powf(1.0f - cosTheta, 4.0f));\n    }\n}\n\n//------------------------------------------------------------------------\n// Frostbite diffuse\n\n__device__ inline float fwdFrostbiteDiffuse(const vec3f nrm, const vec3f wi, const vec3f wo, float linearRoughness)\n{\n    float wiDotN = dot(wi, nrm);\n    float woDotN = dot(wo, nrm);\n    if (wiDotN > 0.0f && woDotN > 0.0f)\n    {\n        vec3f h = safeNormalize(wo + wi);\n        float wiDotH = dot(wi, h);\n\n        float energyBias = 0.5f * linearRoughness;\n        float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness;\n        float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness;\n        float f0 = 1.f;\n        \n        float wiScatter = fwdFresnelSchlick(f0, f90, wiDotN);\n        float woScatter = 
fwdFresnelSchlick(f0, f90, woDotN);\n        \n        return wiScatter * woScatter * energyFactor;\n    }\n    else return 0.0f;\n}\n\n__device__ inline void bwdFrostbiteDiffuse(const vec3f nrm, const vec3f wi, const vec3f wo, float linearRoughness, vec3f& d_nrm, vec3f& d_wi, vec3f& d_wo, float &d_linearRoughness, const float d_out)\n{\n    float wiDotN = dot(wi, nrm);\n    float woDotN = dot(wo, nrm);\n\n    if (wiDotN > 0.0f && woDotN > 0.0f)\n    {\n        vec3f h = safeNormalize(wo + wi);\n        float wiDotH = dot(wi, h);\n\n        float energyBias = 0.5f * linearRoughness;\n        float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness;\n        float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness;\n        float f0 = 1.f;\n        \n        float wiScatter = fwdFresnelSchlick(f0, f90, wiDotN);\n        float woScatter = fwdFresnelSchlick(f0, f90, woDotN);\n\n        // -------------- BWD --------------\n        // Backprop: return wiScatter * woScatter * energyFactor;\n        float d_wiScatter = d_out * woScatter * energyFactor;\n        float d_woScatter = d_out * wiScatter * energyFactor;\n        float d_energyFactor = d_out * wiScatter * woScatter; \n\n        // Backprop: float woScatter = fwdFresnelSchlick(f0, f90, woDotN);\n        float d_woDotN = 0.0f, d_f0 = 0.0, d_f90 = 0.0f;\n        bwdFresnelSchlick(f0, f90, woDotN, d_f0, d_f90, d_woDotN, d_woScatter);\n\n        // Backprop: float wiScatter = fwdFresnelSchlick(fd0, fd90, wiDotN);\n        float d_wiDotN = 0.0f;\n        bwdFresnelSchlick(f0, f90, wiDotN, d_f0, d_f90, d_wiDotN, d_wiScatter);\n\n        // Backprop: float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness;\n        float d_energyBias = d_f90;\n        float d_wiDotH = d_f90 * 4 * wiDotH * linearRoughness;\n        d_linearRoughness += d_f90 * 2 * wiDotH * wiDotH;\n\n        // Backprop: float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness;\n        d_linearRoughness -= (0.51f / 1.51f) * d_energyFactor;\n\n        // Backprop: float energyBias = 0.5f * linearRoughness;\n        d_linearRoughness += 0.5 * d_energyBias;\n\n        // Backprop: float wiDotH = dot(wi, h);\n        vec3f d_h(0);\n        bwdDot(wi, h, d_wi, d_h, d_wiDotH);\n\n        // Backprop: vec3f h = safeNormalize(wo + wi);     \n        vec3f d_wo_wi(0);\n        bwdSafeNormalize(wo + wi, d_wo_wi, d_h);\n        d_wi += d_wo_wi; d_wo += d_wo_wi;\n\n        bwdDot(wo, nrm, d_wo, d_nrm, d_woDotN);\n        bwdDot(wi, nrm, d_wi, d_nrm, d_wiDotN);\n    }\n}\n\n//------------------------------------------------------------------------\n// Ndf GGX\n\n__device__ inline float fwdNdfGGX(const float alphaSqr, const float cosTheta)\n{\n    float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);\n    float d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1.0f;\n    return alphaSqr / (d * d * M_PI);\n}\n\n__device__ inline void bwdNdfGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out)\n{\n    // Torch only back propagates if clamp doesn't trigger\n    float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);\n    float cosThetaSqr = _cosTheta * _cosTheta;\n    d_alphaSqr += d_out * (1.0f - (alphaSqr + 1.0f) * cosThetaSqr) / (M_PI * powf((alphaSqr - 1.0) * cosThetaSqr + 1.0f, 3.0f));\n    if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)\n    {\n        d_cosTheta += d_out * -(4.0f * (alphaSqr - 1.0f) * alphaSqr * cosTheta) / (M_PI * 
powf((alphaSqr - 1.0) * cosThetaSqr + 1.0f, 3.0f));\n    }\n}\n\n//------------------------------------------------------------------------\n// Lambda GGX\n\n__device__ inline float fwdLambdaGGX(const float alphaSqr, const float cosTheta)\n{\n    float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);\n    float cosThetaSqr = _cosTheta * _cosTheta;\n    float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr;\n    float res = 0.5f * (sqrtf(1.0f + alphaSqr * tanThetaSqr) - 1.0f);\n    return res;\n}\n\n__device__ inline void bwdLambdaGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out)\n{\n    float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON);\n    float cosThetaSqr = _cosTheta * _cosTheta;\n    float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr;\n    float res = 0.5f * (sqrtf(1.0f + alphaSqr * tanThetaSqr) - 1.0f);\n\n    d_alphaSqr += d_out * (0.25 * tanThetaSqr) / sqrtf(alphaSqr * tanThetaSqr + 1.0f);\n    if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON)\n        d_cosTheta += d_out * -(0.5 * alphaSqr) / (powf(_cosTheta, 3.0f) * sqrtf(alphaSqr / cosThetaSqr - alphaSqr + 1.0f));\n}\n\n//------------------------------------------------------------------------\n// Masking GGX\n\n__device__ inline float fwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO)\n{\n    float lambdaI = fwdLambdaGGX(alphaSqr, cosThetaI);\n    float lambdaO = fwdLambdaGGX(alphaSqr, cosThetaO);\n    return 1.0f / (1.0f + lambdaI + lambdaO);\n}\n\n__device__ inline void bwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO, float& d_alphaSqr, float& d_cosThetaI, float& d_cosThetaO, const float d_out)\n{\n    // FWD eval\n    float lambdaI = fwdLambdaGGX(alphaSqr, cosThetaI);\n    float lambdaO = fwdLambdaGGX(alphaSqr, cosThetaO);\n\n    // BWD eval\n    float d_lambdaIO = -d_out / powf(1.0f + lambdaI + lambdaO, 2.0f);\n    bwdLambdaGGX(alphaSqr, cosThetaI, d_alphaSqr, d_cosThetaI, d_lambdaIO);\n    bwdLambdaGGX(alphaSqr, cosThetaO, d_alphaSqr, d_cosThetaO, d_lambdaIO);\n}\n\n//------------------------------------------------------------------------\n// GGX specular\n\n__device__ vec3f fwdPbrSpecular(const vec3f col, const vec3f nrm, const vec3f wo, const vec3f wi, const float alpha, const float min_roughness)\n{\n    float _alpha = clamp(alpha, min_roughness * min_roughness, 1.0f);\n    float alphaSqr = _alpha * _alpha;\n\n    vec3f h = safeNormalize(wo + wi);\n    float woDotN = dot(wo, nrm);\n    float wiDotN = dot(wi, nrm);\n    float woDotH = dot(wo, h);\n    float nDotH = dot(nrm, h);\n\n    float D = fwdNdfGGX(alphaSqr, nDotH);\n    float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN);\n    vec3f F = fwdFresnelSchlick(col, 1.0f, woDotH);\n    vec3f w = F * D * G * 0.25 / woDotN;\n\n    bool frontfacing = (woDotN > SPECULAR_EPSILON) & (wiDotN > SPECULAR_EPSILON);\n    return frontfacing ? 
w : 0.0f;\n}\n\n__device__ void bwdPbrSpecular(\n    const vec3f col, const vec3f nrm, const vec3f wo, const vec3f wi, const float alpha, const float min_roughness,\n    vec3f& d_col, vec3f& d_nrm, vec3f& d_wo, vec3f& d_wi, float& d_alpha, const vec3f d_out)\n{\n    ///////////////////////////////////////////////////////////////////////\n    // FWD eval\n\n    float _alpha = clamp(alpha, min_roughness * min_roughness, 1.0f);\n    float alphaSqr = _alpha * _alpha;\n\n    vec3f h = safeNormalize(wo + wi);\n    float woDotN = dot(wo, nrm);\n    float wiDotN = dot(wi, nrm);\n    float woDotH = dot(wo, h);\n    float nDotH = dot(nrm, h);\n\n    float D = fwdNdfGGX(alphaSqr, nDotH);\n    float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN);\n    vec3f F = fwdFresnelSchlick(col, 1.0f, woDotH);\n    vec3f w = F * D * G * 0.25 / woDotN;\n    bool frontfacing = (woDotN > SPECULAR_EPSILON) & (wiDotN > SPECULAR_EPSILON);\n\n    if (frontfacing)\n    {\n        ///////////////////////////////////////////////////////////////////////\n        // BWD eval\n\n        vec3f d_F = d_out * D * G * 0.25f / woDotN;\n        float d_D = sum(d_out * F * G * 0.25f / woDotN);\n        float d_G = sum(d_out * F * D * 0.25f / woDotN);\n\n        float d_woDotN = -sum(d_out * F * D * G * 0.25f / (woDotN * woDotN));\n\n        vec3f d_f90(0);\n        float d_woDotH(0), d_wiDotN(0), d_nDotH(0), d_alphaSqr(0);\n        bwdFresnelSchlick(col, 1.0f, woDotH, d_col, d_f90, d_woDotH, d_F);\n        bwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN, d_alphaSqr, d_woDotN, d_wiDotN, d_G);\n        bwdNdfGGX(alphaSqr, nDotH, d_alphaSqr, d_nDotH, d_D);\n\n        vec3f d_h(0);\n        bwdDot(nrm, h, d_nrm, d_h, d_nDotH);\n        bwdDot(wo, h, d_wo, d_h, d_woDotH);\n        bwdDot(wi, nrm, d_wi, d_nrm, d_wiDotN);\n        bwdDot(wo, nrm, d_wo, d_nrm, d_woDotN);\n\n        vec3f d_h_unnorm(0);\n        bwdSafeNormalize(wo + wi, d_h_unnorm, d_h);\n        d_wo += d_h_unnorm;\n        d_wi += d_h_unnorm;\n\n        if (alpha > min_roughness * min_roughness)\n            d_alpha += d_alphaSqr * 2 * alpha;\n    }\n}\n\n//------------------------------------------------------------------------\n// Full PBR BSDF\n\n__device__ vec3f fwdPbrBSDF(const vec3f kd, const vec3f arm, const vec3f pos, const vec3f nrm, const vec3f view_pos, const vec3f light_pos, const float min_roughness, int BSDF)\n{\n    vec3f wo = safeNormalize(view_pos - pos);\n    vec3f wi = safeNormalize(light_pos - pos);\n\n    float alpha = arm.y * arm.y;\n    vec3f spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x);\n    vec3f diff_col = kd * (1.0f - arm.z);\n\n    float diff = 0.0f;\n    if (BSDF == 0)\n        diff = fwdLambert(nrm, wi);\n    else\n        diff = fwdFrostbiteDiffuse(nrm, wi, wo, arm.y);    \n    vec3f diffuse = diff_col * diff;\n    vec3f specular = fwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness);\n\n    return diffuse + specular;\n}\n\n__device__ void bwdPbrBSDF(\n    const vec3f kd, const vec3f arm, const vec3f pos, const vec3f nrm, const vec3f view_pos, const vec3f light_pos, const float min_roughness, int BSDF,\n    vec3f& d_kd, vec3f& d_arm, vec3f& d_pos, vec3f& d_nrm, vec3f& d_view_pos, vec3f& d_light_pos, const vec3f d_out)\n{\n    ////////////////////////////////////////////////////////////////////////\n    // FWD\n    vec3f _wi = light_pos - pos;\n    vec3f _wo = view_pos - pos;\n    vec3f wi = safeNormalize(_wi);\n    vec3f wo = safeNormalize(_wo);\n\n    float alpha = arm.y * arm.y;\n    vec3f 
spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x);\n    vec3f diff_col = kd * (1.0f - arm.z);\n    float diff = 0.0f;\n    if (BSDF == 0)\n        diff = fwdLambert(nrm, wi);\n    else\n        diff = fwdFrostbiteDiffuse(nrm, wi, wo, arm.y);    \n\n    ////////////////////////////////////////////////////////////////////////\n    // BWD\n\n    float d_alpha(0);\n    vec3f d_spec_col(0), d_wi(0), d_wo(0);\n    bwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness, d_spec_col, d_nrm, d_wo, d_wi, d_alpha, d_out);\n\n    float d_diff = sum(diff_col * d_out);\n    if (BSDF == 0)\n        bwdLambert(nrm, wi, d_nrm, d_wi, d_diff);\n    else\n        bwdFrostbiteDiffuse(nrm, wi, wo, arm.y, d_nrm, d_wi, d_wo, d_arm.y, d_diff);    \n\n    // Backprop: diff_col = kd * (1.0f - arm.z)\n    vec3f d_diff_col = d_out * diff;\n    d_kd += d_diff_col * (1.0f - arm.z);\n    d_arm.z -= sum(d_diff_col * kd);\n\n    // Backprop: spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x)\n    d_kd -= d_spec_col * (arm.x - 1.0f) * arm.z;\n    d_arm.x += sum(d_spec_col * (arm.z * (0.04f - kd) - 0.04f));\n    d_arm.z -= sum(d_spec_col * (kd - 0.04f) * (arm.x - 1.0f));\n\n    // Backprop: alpha = arm.y * arm.y\n    d_arm.y += d_alpha * 2 * arm.y;\n\n    // Backprop: vec3f wi = safeNormalize(light_pos - pos);\n    vec3f d__wi(0);\n    bwdSafeNormalize(_wi, d__wi, d_wi);\n    d_light_pos += d__wi;\n    d_pos -= d__wi;\n\n    // Backprop: vec3f wo = safeNormalize(view_pos - pos);\n    vec3f d__wo(0);\n    bwdSafeNormalize(_wo, d__wo, d_wo);\n    d_view_pos += d__wo;\n    d_pos -= d__wo;\n}\n\n//------------------------------------------------------------------------\n// Kernels\n\n__global__ void LambertFwdKernel(LambertKernelParams p)\n{\n    // Calculate pixel position.\n    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;\n    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;\n    unsigned int pz = blockIdx.z;\n    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)\n        return;\n\n    vec3f nrm = p.nrm.fetch3(px, py, pz);\n    vec3f wi = p.wi.fetch3(px, py, pz);\n\n    float res = fwdLambert(nrm, wi);\n\n    p.out.store(px, py, pz, res);\n}\n\n__global__ void LambertBwdKernel(LambertKernelParams p)\n{\n    // Calculate pixel position.\n    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;\n    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;\n    unsigned int pz = blockIdx.z;\n    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)\n        return;\n\n    vec3f nrm = p.nrm.fetch3(px, py, pz);\n    vec3f wi = p.wi.fetch3(px, py, pz);\n    float d_out = p.out.fetch1(px, py, pz);\n\n    vec3f d_nrm(0), d_wi(0);\n    bwdLambert(nrm, wi, d_nrm, d_wi, d_out);\n\n    p.nrm.store_grad(px, py, pz, d_nrm);\n    p.wi.store_grad(px, py, pz, d_wi);\n}\n\n__global__ void FrostbiteDiffuseFwdKernel(FrostbiteDiffuseKernelParams p)\n{\n    // Calculate pixel position.\n    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;\n    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;\n    unsigned int pz = blockIdx.z;\n    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)\n        return;\n\n    vec3f nrm = p.nrm.fetch3(px, py, pz);\n    vec3f wi = p.wi.fetch3(px, py, pz);\n    vec3f wo = p.wo.fetch3(px, py, pz);\n    float linearRoughness = p.linearRoughness.fetch1(px, py, pz);\n\n    float res = fwdFrostbiteDiffuse(nrm, wi, wo, linearRoughness);\n\n    p.out.store(px, py, pz, res);\n}\n\n__global__ void 
FrostbiteDiffuseBwdKernel(FrostbiteDiffuseKernelParams p)\n{\n    // Calculate pixel position.\n    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;\n    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;\n    unsigned int pz = blockIdx.z;\n    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)\n        return;\n\n    vec3f nrm = p.nrm.fetch3(px, py, pz);\n    vec3f wi = p.wi.fetch3(px, py, pz);\n    vec3f wo = p.wo.fetch3(px, py, pz);\n    float linearRoughness = p.linearRoughness.fetch1(px, py, pz);\n    float d_out = p.out.fetch1(px, py, pz);\n\n    float d_linearRoughness = 0.0f;\n    vec3f d_nrm(0), d_wi(0), d_wo(0);\n    bwdFrostbiteDiffuse(nrm, wi, wo, linearRoughness, d_nrm, d_wi, d_wo, d_linearRoughness, d_out);\n\n    p.nrm.store_grad(px, py, pz, d_nrm);\n    p.wi.store_grad(px, py, pz, d_wi);\n    p.wo.store_grad(px, py, pz, d_wo);\n    p.linearRoughness.store_grad(px, py, pz, d_linearRoughness);\n}\n\n__global__ void FresnelShlickFwdKernel(FresnelShlickKernelParams p)\n{\n    // Calculate pixel position.\n    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;\n    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;\n    unsigned int pz = blockIdx.z;\n    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)\n        return;\n\n    vec3f f0 = p.f0.fetch3(px, py, pz);\n    vec3f f90 = p.f90.fetch3(px, py, pz);\n    float cosTheta = p.cosTheta.fetch1(px, py, pz);\n\n    vec3f res = fwdFresnelSchlick(f0, f90, cosTheta);\n    p.out.store(px, py, pz, res);\n}\n\n__global__ void FresnelShlickBwdKernel(FresnelShlickKernelParams p)\n{\n    // Calculate pixel position.\n    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;\n    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;\n    unsigned int pz = blockIdx.z;\n    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)\n        return;\n\n    vec3f f0 = p.f0.fetch3(px, py, pz);\n    vec3f f90 = p.f90.fetch3(px, py, pz);\n    float cosTheta = p.cosTheta.fetch1(px, py, pz);\n    vec3f d_out = p.out.fetch3(px, py, pz);\n\n    vec3f d_f0(0), d_f90(0);\n    float d_cosTheta(0);\n    bwdFresnelSchlick(f0, f90, cosTheta, d_f0, d_f90, d_cosTheta, d_out);\n\n    p.f0.store_grad(px, py, pz, d_f0);\n    p.f90.store_grad(px, py, pz, d_f90);\n    p.cosTheta.store_grad(px, py, pz, d_cosTheta);\n}\n\n__global__ void ndfGGXFwdKernel(NdfGGXParams p)\n{\n    // Calculate pixel position.\n    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;\n    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;\n    unsigned int pz = blockIdx.z;\n    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)\n        return;\n\n    float alphaSqr = p.alphaSqr.fetch1(px, py, pz);\n    float cosTheta = p.cosTheta.fetch1(px, py, pz);\n    float res = fwdNdfGGX(alphaSqr, cosTheta);\n    \n    p.out.store(px, py, pz, res);\n}\n\n__global__ void ndfGGXBwdKernel(NdfGGXParams p)\n{\n    // Calculate pixel position.\n    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;\n    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;\n    unsigned int pz = blockIdx.z;\n    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)\n        return;\n\n    float alphaSqr = p.alphaSqr.fetch1(px, py, pz);\n    float cosTheta = p.cosTheta.fetch1(px, py, pz);\n    float d_out = p.out.fetch1(px, py, pz);\n\n    float d_alphaSqr(0), d_cosTheta(0);\n    bwdNdfGGX(alphaSqr, cosTheta, d_alphaSqr, d_cosTheta, d_out);\n\n    p.alphaSqr.store_grad(px, py, pz, d_alphaSqr);\n    
p.cosTheta.store_grad(px, py, pz, d_cosTheta);\n}\n\n__global__ void lambdaGGXFwdKernel(NdfGGXParams p)\n{\n    // Calculate pixel position.\n    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;\n    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;\n    unsigned int pz = blockIdx.z;\n    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)\n        return;\n\n    float alphaSqr = p.alphaSqr.fetch1(px, py, pz);\n    float cosTheta = p.cosTheta.fetch1(px, py, pz);\n    float res = fwdLambdaGGX(alphaSqr, cosTheta);\n\n    p.out.store(px, py, pz, res);\n}\n\n__global__ void lambdaGGXBwdKernel(NdfGGXParams p)\n{\n    // Calculate pixel position.\n    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;\n    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;\n    unsigned int pz = blockIdx.z;\n    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)\n        return;\n\n    float alphaSqr = p.alphaSqr.fetch1(px, py, pz);\n    float cosTheta = p.cosTheta.fetch1(px, py, pz);\n    float d_out = p.out.fetch1(px, py, pz);\n\n    float d_alphaSqr(0), d_cosTheta(0);\n    bwdLambdaGGX(alphaSqr, cosTheta, d_alphaSqr, d_cosTheta, d_out);\n\n    p.alphaSqr.store_grad(px, py, pz, d_alphaSqr);\n    p.cosTheta.store_grad(px, py, pz, d_cosTheta);\n}\n\n__global__ void maskingSmithFwdKernel(MaskingSmithParams p)\n{\n    // Calculate pixel position.\n    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;\n    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;\n    unsigned int pz = blockIdx.z;\n    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)\n        return;\n\n    float alphaSqr = p.alphaSqr.fetch1(px, py, pz);\n    float cosThetaI = p.cosThetaI.fetch1(px, py, pz);\n    float cosThetaO = p.cosThetaO.fetch1(px, py, pz);\n    float res = fwdMaskingSmithGGXCorrelated(alphaSqr, cosThetaI, cosThetaO);\n    \n    p.out.store(px, py, pz, res);\n}\n\n__global__ void maskingSmithBwdKernel(MaskingSmithParams p)\n{\n    // Calculate pixel position.\n    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;\n    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;\n    unsigned int pz = blockIdx.z;\n    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)\n        return;\n\n    float alphaSqr = p.alphaSqr.fetch1(px, py, pz);\n    float cosThetaI = p.cosThetaI.fetch1(px, py, pz);\n    float cosThetaO = p.cosThetaO.fetch1(px, py, pz);\n    float d_out = p.out.fetch1(px, py, pz);\n\n    float d_alphaSqr(0), d_cosThetaI(0), d_cosThetaO(0);\n    bwdMaskingSmithGGXCorrelated(alphaSqr, cosThetaI, cosThetaO, d_alphaSqr, d_cosThetaI, d_cosThetaO, d_out);\n\n    p.alphaSqr.store_grad(px, py, pz, d_alphaSqr);\n    p.cosThetaI.store_grad(px, py, pz, d_cosThetaI);\n    p.cosThetaO.store_grad(px, py, pz, d_cosThetaO);\n}\n\n__global__ void pbrSpecularFwdKernel(PbrSpecular p)\n{\n    // Calculate pixel position.\n    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;\n    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;\n    unsigned int pz = blockIdx.z;\n    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)\n        return;\n\n    vec3f col = p.col.fetch3(px, py, pz);\n    vec3f nrm = p.nrm.fetch3(px, py, pz);\n    vec3f wo = p.wo.fetch3(px, py, pz);\n    vec3f wi = p.wi.fetch3(px, py, pz);\n    float alpha = p.alpha.fetch1(px, py, pz);\n\n    vec3f res = fwdPbrSpecular(col, nrm, wo, wi, alpha, p.min_roughness);\n\n    p.out.store(px, py, pz, res);\n}\n\n__global__ void 
pbrSpecularBwdKernel(PbrSpecular p)\n{\n    // Calculate pixel position.\n    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;\n    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;\n    unsigned int pz = blockIdx.z;\n    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)\n        return;\n\n    vec3f col = p.col.fetch3(px, py, pz);\n    vec3f nrm = p.nrm.fetch3(px, py, pz);\n    vec3f wo = p.wo.fetch3(px, py, pz);\n    vec3f wi = p.wi.fetch3(px, py, pz);\n    float alpha = p.alpha.fetch1(px, py, pz);\n    vec3f d_out = p.out.fetch3(px, py, pz);\n\n    float d_alpha(0);\n    vec3f d_col(0), d_nrm(0), d_wo(0), d_wi(0);\n    bwdPbrSpecular(col, nrm, wo, wi, alpha, p.min_roughness, d_col, d_nrm, d_wo, d_wi, d_alpha, d_out);\n\n    p.col.store_grad(px, py, pz, d_col);\n    p.nrm.store_grad(px, py, pz, d_nrm);\n    p.wo.store_grad(px, py, pz, d_wo);\n    p.wi.store_grad(px, py, pz, d_wi);\n    p.alpha.store_grad(px, py, pz, d_alpha);\n}\n\n__global__ void pbrBSDFFwdKernel(PbrBSDF p)\n{\n    // Calculate pixel position.\n    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;\n    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;\n    unsigned int pz = blockIdx.z;\n    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)\n        return;\n\n    vec3f kd = p.kd.fetch3(px, py, pz);\n    vec3f arm = p.arm.fetch3(px, py, pz);\n    vec3f pos = p.pos.fetch3(px, py, pz);\n    vec3f nrm = p.nrm.fetch3(px, py, pz);\n    vec3f view_pos = p.view_pos.fetch3(px, py, pz);\n    vec3f light_pos = p.light_pos.fetch3(px, py, pz);\n\n    vec3f res = fwdPbrBSDF(kd, arm, pos, nrm, view_pos, light_pos, p.min_roughness, p.BSDF);\n\n    p.out.store(px, py, pz, res);\n}\n__global__ void pbrBSDFBwdKernel(PbrBSDF p)\n{\n    // Calculate pixel position.\n    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;\n    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;\n    unsigned int pz = blockIdx.z;\n    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)\n        return;\n\n    vec3f kd = p.kd.fetch3(px, py, pz);\n    vec3f arm = p.arm.fetch3(px, py, pz);\n    vec3f pos = p.pos.fetch3(px, py, pz);\n    vec3f nrm = p.nrm.fetch3(px, py, pz);\n    vec3f view_pos = p.view_pos.fetch3(px, py, pz);\n    vec3f light_pos = p.light_pos.fetch3(px, py, pz);\n    vec3f d_out = p.out.fetch3(px, py, pz);\n\n    vec3f d_kd(0), d_arm(0), d_pos(0), d_nrm(0), d_view_pos(0), d_light_pos(0);\n    bwdPbrBSDF(kd, arm, pos, nrm, view_pos, light_pos, p.min_roughness, p.BSDF, d_kd, d_arm, d_pos, d_nrm, d_view_pos, d_light_pos, d_out);\n\n    p.kd.store_grad(px, py, pz, d_kd);\n    p.arm.store_grad(px, py, pz, d_arm);\n    p.pos.store_grad(px, py, pz, d_pos);\n    p.nrm.store_grad(px, py, pz, d_nrm);\n    p.view_pos.store_grad(px, py, pz, d_view_pos);\n    p.light_pos.store_grad(px, py, pz, d_light_pos);\n}\n"
  },
  {
    "path": "render/renderutils/c_src/bsdf.h",
    "content": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n * property and proprietary rights in and to this material, related \n * documentation and any modifications thereto. Any use, reproduction, \n * disclosure or distribution of this material and related documentation\n * without an express license agreement from NVIDIA CORPORATION or \n * its affiliates is strictly prohibited.\n */\n\n#pragma once\n\n#include \"common.h\"\n\nstruct LambertKernelParams\n{\n    Tensor  nrm;\n    Tensor  wi;\n    Tensor  out;\n    dim3    gridSize;\n};\n\nstruct FrostbiteDiffuseKernelParams\n{\n    Tensor  nrm;\n    Tensor  wi;\n    Tensor  wo;\n    Tensor  linearRoughness;\n    Tensor  out;\n    dim3    gridSize;\n};\n\nstruct FresnelShlickKernelParams\n{\n    Tensor  f0;\n    Tensor  f90;\n    Tensor  cosTheta;\n    Tensor  out;\n    dim3    gridSize;\n};\n\nstruct NdfGGXParams\n{\n    Tensor  alphaSqr;\n    Tensor  cosTheta;\n    Tensor  out;\n    dim3    gridSize;\n};\n\nstruct MaskingSmithParams\n{\n    Tensor  alphaSqr;\n    Tensor  cosThetaI;\n    Tensor  cosThetaO;\n    Tensor  out;\n    dim3    gridSize;\n};\n\nstruct PbrSpecular\n{\n    Tensor  col;\n    Tensor  nrm;\n    Tensor  wo;\n    Tensor  wi;\n    Tensor  alpha;\n    Tensor  out;\n    dim3    gridSize;\n    float   min_roughness;\n};\n\nstruct PbrBSDF\n{\n    Tensor  kd;\n    Tensor  arm;\n    Tensor  pos;\n    Tensor  nrm;\n    Tensor  view_pos;\n    Tensor  light_pos;\n    Tensor  out;\n    dim3    gridSize;\n    float   min_roughness;\n    int     BSDF;\n};\n"
  },
  {
    "path": "render/renderutils/c_src/common.cpp",
    "content": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n * property and proprietary rights in and to this material, related \n * documentation and any modifications thereto. Any use, reproduction, \n * disclosure or distribution of this material and related documentation\n * without an express license agreement from NVIDIA CORPORATION or \n * its affiliates is strictly prohibited.\n */\n\n#include <cuda_runtime.h>\n#include <algorithm>\n\n//------------------------------------------------------------------------\n// Block and grid size calculators for kernel launches.\n\ndim3 getLaunchBlockSize(int maxWidth, int maxHeight, dim3 dims)\n{\n    int maxThreads = maxWidth * maxHeight;\n    if (maxThreads <= 1 || (dims.x * dims.y) <= 1)\n        return dim3(1, 1, 1); // Degenerate.\n\n    // Start from max size.\n    int bw = maxWidth;\n    int bh = maxHeight;\n\n    // Optimizations for weirdly sized buffers.\n    if (dims.x < bw)\n    {\n        // Decrease block width to smallest power of two that covers the buffer width.\n        while ((bw >> 1) >= dims.x)\n            bw >>= 1;\n\n        // Maximize height.\n        bh = maxThreads / bw;\n        if (bh > dims.y)\n            bh = dims.y;\n    }\n    else if (dims.y < bh)\n    {\n        // Halve height and double width until fits completely inside buffer vertically.\n        while (bh > dims.y)\n        {\n            bh >>= 1;\n            if (bw < dims.x)\n                bw <<= 1;\n        }\n    }\n\n    // Done.\n    return dim3(bw, bh, 1);\n}\n\n// returns the size of a block that can be reduced using horizontal SIMD operations (e.g. __shfl_xor_sync)\ndim3 getWarpSize(dim3 blockSize)\n{\n    return dim3(\n        std::min(blockSize.x, 32u), \n        std::min(std::max(32u / blockSize.x, 1u), std::min(32u, blockSize.y)), \n        std::min(std::max(32u / (blockSize.x * blockSize.y), 1u), std::min(32u, blockSize.z))\n    );\n}\n\ndim3 getLaunchGridSize(dim3 blockSize, dim3 dims)\n{\n    dim3 gridSize;\n    gridSize.x = (dims.x  - 1) / blockSize.x + 1;\n    gridSize.y = (dims.y - 1) / blockSize.y + 1;\n    gridSize.z = (dims.z  - 1) / blockSize.z + 1;\n    return gridSize;\n}\n\n//------------------------------------------------------------------------\n"
  },
  {
    "path": "render/renderutils/c_src/common.h",
    "content": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n * property and proprietary rights in and to this material, related \n * documentation and any modifications thereto. Any use, reproduction, \n * disclosure or distribution of this material and related documentation\n * without an express license agreement from NVIDIA CORPORATION or \n * its affiliates is strictly prohibited.\n */\n\n#pragma once\n#include <cuda.h>\n#include <stdint.h>\n\n#include \"vec3f.h\"\n#include \"vec4f.h\"\n#include \"tensor.h\"\n\ndim3 getLaunchBlockSize(int maxWidth, int maxHeight, dim3 dims);\ndim3 getLaunchGridSize(dim3 blockSize, dim3 dims);\n\n#ifdef __CUDACC__\n\n#ifdef _MSC_VER\n#define M_PI 3.14159265358979323846f\n#endif\n\n__host__ __device__ static inline dim3 getWarpSize(dim3 blockSize)\n{\n    return dim3(\n        min(blockSize.x, 32u),\n        min(max(32u / blockSize.x, 1u), min(32u, blockSize.y)),\n        min(max(32u / (blockSize.x * blockSize.y), 1u), min(32u, blockSize.z))\n    );\n}\n\n__device__ static inline float clamp(float val, float mn, float mx) { return min(max(val, mn), mx); }\n#else\ndim3 getWarpSize(dim3 blockSize);\n#endif"
  },
  {
    "path": "render/renderutils/c_src/cubemap.cu",
    "content": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n * property and proprietary rights in and to this material, related \n * documentation and any modifications thereto. Any use, reproduction, \n * disclosure or distribution of this material and related documentation\n * without an express license agreement from NVIDIA CORPORATION or \n * its affiliates is strictly prohibited.\n */\n\n#include \"common.h\"\n#include \"cubemap.h\"\n#include <float.h>\n\n// https://cgvr.cs.uni-bremen.de/teaching/cg_literatur/Spherical,%20Cubic,%20and%20Parabolic%20Environment%20Mappings.pdf\n__device__ float pixel_area(int x, int y, int N)\n{\n    if (N > 1)\n    {\n        int H = N / 2;\n        x = abs(x - H);\n        y = abs(y - H);\n        float dx = atan((float)(x + 1) / (float)H) - atan((float)x / (float)H);\n        float dy = atan((float)(y + 1) / (float)H) - atan((float)y / (float)H);\n        return dx * dy;\n    }\n    else\n        return 1;\n}\n\n__device__ vec3f cube_to_dir(int x, int y, int side, int N)\n{\n    float fx = 2.0f * (((float)x + 0.5f) / (float)N) - 1.0f;\n    float fy = 2.0f * (((float)y + 0.5f) / (float)N) - 1.0f;\n    switch (side)\n    {\n        case 0: return safeNormalize(vec3f(1, -fy, -fx));\n        case 1: return safeNormalize(vec3f(-1, -fy, fx));\n        case 2: return safeNormalize(vec3f(fx, 1, fy));\n        case 3: return safeNormalize(vec3f(fx, -1, -fy));\n        case 4: return safeNormalize(vec3f(fx, -fy, 1));\n        case 5: return safeNormalize(vec3f(-fx, -fy, -1));\n    }\n    return vec3f(0,0,0); // Unreachable\n}\n\n__device__ vec3f dir_to_side(int side, vec3f v)\n{\n    switch (side)\n    {\n    case 0: return vec3f(-v.z, -v.y,  v.x);\n    case 1: return vec3f( v.z, -v.y, -v.x);\n    case 2: return vec3f( v.x,  v.z,  v.y);\n    case 3: return vec3f( v.x, -v.z, -v.y);\n    case 4: return vec3f( v.x, -v.y,  v.z);\n    case 5: return vec3f(-v.x, -v.y, -v.z);\n    }\n    return vec3f(0,0,0); // Unreachable\n}\n\n__device__ void extents_1d(float x, float z, float theta, float& _min, float& _max)\n{\n    float l = sqrtf(x * x + z * z);\n    float pxr = x + z * tan(theta) * l, pzr = z - x * tan(theta) * l;\n    float pxl = x - z * tan(theta) * l, pzl = z + x * tan(theta) * l;\n    if (pzl <= 0.00001f)\n        _min = pxl > 0.0f ? FLT_MAX : -FLT_MAX;\n    else\n        _min = pxl / pzl;\n    if (pzr <= 0.00001f)\n        _max = pxr > 0.0f ? 
FLT_MAX : -FLT_MAX;\n    else\n        _max = pxr / pzr;\n}\n\n__device__ void dir_extents(int side, int N, vec3f v, float theta, int &_xmin, int& _xmax, int& _ymin, int& _ymax)\n{\n    vec3f c = dir_to_side(side, v); // remap to (x,y,z) where side is at z = 1\n\n    if (theta < 0.785398f) // PI/4\n    {\n        float xmin, xmax, ymin, ymax;\n        extents_1d(c.x, c.z, theta, xmin, xmax);\n        extents_1d(c.y, c.z, theta, ymin, ymax);\n\n        if (xmin > 1.0f || xmax < -1.0f || ymin > 1.0f || ymax < -1.0f)\n        {\n            _xmin = -1; _xmax = -1; _ymin = -1; _ymax = -1; // Bad aabb\n        }\n        else\n        {\n            _xmin = (int)min(max((xmin + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1));\n            _xmax = (int)min(max((xmax + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1));\n            _ymin = (int)min(max((ymin + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1));\n            _ymax = (int)min(max((ymax + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1));\n        }\n    }\n    else\n    {\n            _xmin = 0.0f;\n            _xmax = (float)(N-1);\n            _ymin = 0.0f;\n            _ymax = (float)(N-1);\n    }\n}\n\n///////////////////////////////////////////////////////////////////////////////////////////////////////////\n// Diffuse kernel\n__global__ void DiffuseCubemapFwdKernel(DiffuseCubemapKernelParams p)\n{\n    // Calculate pixel position.\n    int px = blockIdx.x * blockDim.x + threadIdx.x;\n    int py = blockIdx.y * blockDim.y + threadIdx.y;\n    int pz = blockIdx.z;\n    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)\n        return;\n\n    int Npx = p.cubemap.dims[1];\n    vec3f N = cube_to_dir(px, py, pz, Npx);\n\n    vec3f col(0);\n\n    for (int s = 0; s < p.cubemap.dims[0]; ++s)\n    {\n        for (int y = 0; y < Npx; ++y)\n        {\n            for (int x = 0; x < Npx; ++x)\n            {\n                vec3f L = cube_to_dir(x, y, s, Npx);\n                float costheta = min(max(dot(N, L), 0.0f), 0.999f);\n                float w = costheta * pixel_area(x, y, Npx) / 3.141592f; // pi = area of positive hemisphere\n                col += p.cubemap.fetch3(x, y, s) * w;\n            }\n        }\n    }\n\n    p.out.store(px, py, pz, col);\n}\n\n__global__ void DiffuseCubemapBwdKernel(DiffuseCubemapKernelParams p)\n{\n    // Calculate pixel position.\n    int px = blockIdx.x * blockDim.x + threadIdx.x;\n    int py = blockIdx.y * blockDim.y + threadIdx.y;\n    int pz = blockIdx.z;\n    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)\n        return;\n\n    int Npx = p.cubemap.dims[1];\n    vec3f N = cube_to_dir(px, py, pz, Npx);\n    vec3f grad = p.out.fetch3(px, py, pz);\n\n    for (int s = 0; s < p.cubemap.dims[0]; ++s)\n    {\n        for (int y = 0; y < Npx; ++y)\n        {\n            for (int x = 0; x < Npx; ++x)\n            {\n                vec3f L = cube_to_dir(x, y, s, Npx);\n                float costheta = min(max(dot(N, L), 0.0f), 0.999f);\n                float w = costheta * pixel_area(x, y, Npx) / 3.141592f; // pi = area of positive hemisphere\n                atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 0), grad.x * w);\n                atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 1), grad.y * w);\n                atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 2), grad.z * w);\n            }\n        }\n    
}\n}\n\n///////////////////////////////////////////////////////////////////////////////////////////////////////////\n// GGX splitsum kernel \n\n__device__ inline float ndfGGX(const float alphaSqr, const float cosTheta)\n{\n    float _cosTheta = clamp(cosTheta, 0.0, 1.0f);\n    float d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1.0f;\n    return alphaSqr / (d * d * M_PI);\n}\n\n__global__ void SpecularBoundsKernel(SpecularBoundsKernelParams p)\n{\n    int px = blockIdx.x * blockDim.x + threadIdx.x;\n    int py = blockIdx.y * blockDim.y + threadIdx.y;\n    int pz = blockIdx.z;\n    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)\n        return;\n\n    int Npx = p.gridSize.x;\n    vec3f VNR = cube_to_dir(px, py, pz, Npx);\n\n    const int TILE_SIZE = 16;\n\n    // Brute force entire cubemap and compute bounds for the cone\n    for (int s = 0; s < p.gridSize.z; ++s)\n    {\n        // Assume empty BBox \n        int _min_x = p.gridSize.x - 1, _max_x = 0;\n        int _min_y = p.gridSize.y - 1, _max_y = 0;\n        \n        // For each (8x8) tile\n        for (int tx = 0; tx < (p.gridSize.x + TILE_SIZE - 1) / TILE_SIZE; tx++)\n        {\n            for (int ty = 0; ty < (p.gridSize.y + TILE_SIZE - 1) / TILE_SIZE; ty++)\n            {\n                // Compute tile extents\n                int tsx = tx * TILE_SIZE, tsy = ty * TILE_SIZE;\n                int tex = min((tx + 1) * TILE_SIZE, p.gridSize.x), tey = min((ty + 1) * TILE_SIZE, p.gridSize.y);\n\n                // Use some blunt interval arithmetics to cull tiles\n                vec3f L0 = cube_to_dir(tsx, tsy, s, Npx), L1 = cube_to_dir(tex, tsy, s, Npx);\n                vec3f L2 = cube_to_dir(tsx, tey, s, Npx), L3 = cube_to_dir(tex, tey, s, Npx);\n                \n                float minx = min(min(L0.x, L1.x), min(L2.x, L3.x)), maxx = max(max(L0.x, L1.x), max(L2.x, L3.x));\n                float miny = min(min(L0.y, L1.y), min(L2.y, L3.y)), maxy = max(max(L0.y, L1.y), max(L2.y, L3.y));\n                float minz = min(min(L0.z, L1.z), min(L2.z, L3.z)), maxz = max(max(L0.z, L1.z), max(L2.z, L3.z));\n\n                float maxdp = max(minx * VNR.x, maxx * VNR.x) + max(miny * VNR.y, maxy * VNR.y) + max(minz * VNR.z, maxz * VNR.z);\n                if (maxdp >= p.costheta_cutoff)\n                {\n                    // Test all pixels in tile.\n                    for (int y = tsy; y < tey; ++y)\n                    {\n                        for (int x = tsx; x < tex; ++x)\n                        {\n                            vec3f L = cube_to_dir(x, y, s, Npx);\n                            if (dot(L, VNR) >= p.costheta_cutoff)\n                            {\n                                _min_x = min(_min_x, x);\n                                _max_x = max(_max_x, x);\n                                _min_y = min(_min_y, y);\n                                _max_y = max(_max_y, y);\n                            }\n                        }\n                    }\n                }\n            }\n        }\n        p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 0), _min_x);\n        p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 1), _max_x);\n        p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 2), _min_y);\n        p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 3), _max_y);\n    }\n}\n\n__global__ void SpecularCubemapFwdKernel(SpecularCubemapKernelParams p)\n{\n    // Calculate pixel position.\n    int px = blockIdx.x * blockDim.x + threadIdx.x;\n    int py = blockIdx.y * 
blockDim.y + threadIdx.y;\n    int pz = blockIdx.z;\n    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)\n        return;\n\n    int Npx = p.cubemap.dims[1];\n    vec3f VNR = cube_to_dir(px, py, pz, Npx);\n\n    float alpha = p.roughness * p.roughness;\n    float alphaSqr = alpha * alpha;\n\n    float wsum = 0.0f;\n    vec3f col(0);\n    for (int s = 0; s < p.cubemap.dims[0]; ++s)\n    {\n        int xmin, xmax, ymin, ymax;\n        xmin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 0));\n        xmax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 1));\n        ymin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 2));\n        ymax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 3));\n\n        if (xmin <= xmax)\n        {\n            for (int y = ymin; y <= ymax; ++y)\n            {\n                for (int x = xmin; x <= xmax; ++x)\n                {\n                    vec3f L = cube_to_dir(x, y, s, Npx);\n                    if (dot(L, VNR) >= p.costheta_cutoff)\n                    {\n                        vec3f H = safeNormalize(L + VNR);\n\n                        float wiDotN = max(dot(L, VNR), 0.0f);\n                        float VNRDotH = max(dot(VNR, H), 0.0f);\n\n                        float w = wiDotN * ndfGGX(alphaSqr, VNRDotH) * pixel_area(x, y, Npx) / 4.0f;\n                        col += p.cubemap.fetch3(x, y, s) * w;\n                        wsum += w;\n                    }\n                }\n            }\n        }\n    }\n\n    p.out.store(p.out._nhwcIndex(pz, py, px, 0), col.x);\n    p.out.store(p.out._nhwcIndex(pz, py, px, 1), col.y);\n    p.out.store(p.out._nhwcIndex(pz, py, px, 2), col.z);\n    p.out.store(p.out._nhwcIndex(pz, py, px, 3), wsum);\n}\n\n__global__ void SpecularCubemapBwdKernel(SpecularCubemapKernelParams p)\n{\n    // Calculate pixel position.\n    int px = blockIdx.x * blockDim.x + threadIdx.x;\n    int py = blockIdx.y * blockDim.y + threadIdx.y;\n    int pz = blockIdx.z;\n    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)\n        return;\n\n    int Npx = p.cubemap.dims[1];\n    vec3f VNR = cube_to_dir(px, py, pz, Npx);\n\n    vec3f grad = p.out.fetch3(px, py, pz);\n\n    float alpha = p.roughness * p.roughness;\n    float alphaSqr = alpha * alpha;\n\n    vec3f col(0);\n    for (int s = 0; s < p.cubemap.dims[0]; ++s)\n    {\n        int xmin, xmax, ymin, ymax;\n        xmin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 0));\n        xmax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 1));\n        ymin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 2));\n        ymax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 3));\n\n        if (xmin <= xmax)\n        {\n            for (int y = ymin; y <= ymax; ++y)\n            {\n                for (int x = xmin; x <= xmax; ++x)\n                {\n                    vec3f L = cube_to_dir(x, y, s, Npx);\n                    if (dot(L, VNR) >= p.costheta_cutoff)\n                    {\n                        vec3f H = safeNormalize(L + VNR);\n\n                        float wiDotN = max(dot(L, VNR), 0.0f);\n                        float VNRDotH = max(dot(VNR, H), 0.0f);\n\n                        float w = wiDotN * ndfGGX(alphaSqr, VNRDotH) * pixel_area(x, y, Npx) / 4.0f;\n\n                        atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 0), grad.x * w);\n                        
atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 1), grad.y * w);\n                        atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 2), grad.z * w);\n                    }\n                }\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "render/renderutils/c_src/cubemap.h",
    "content": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n * property and proprietary rights in and to this material, related \n * documentation and any modifications thereto. Any use, reproduction, \n * disclosure or distribution of this material and related documentation\n * without an express license agreement from NVIDIA CORPORATION or \n * its affiliates is strictly prohibited.\n */\n\n#pragma once\n\n#include \"common.h\"\n\nstruct DiffuseCubemapKernelParams\n{\n    Tensor  cubemap;\n    Tensor  out;\n    dim3    gridSize;\n};\n\nstruct SpecularCubemapKernelParams\n{\n    Tensor  cubemap;\n    Tensor  bounds;\n    Tensor  out;\n    dim3    gridSize;\n    float   costheta_cutoff;\n    float   roughness;\n};\n\nstruct SpecularBoundsKernelParams\n{\n    float   costheta_cutoff;\n    Tensor  out;\n    dim3    gridSize;\n};\n"
  },
  {
    "path": "render/renderutils/c_src/loss.cu",
    "content": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n * property and proprietary rights in and to this material, related \n * documentation and any modifications thereto. Any use, reproduction, \n * disclosure or distribution of this material and related documentation\n * without an express license agreement from NVIDIA CORPORATION or \n * its affiliates is strictly prohibited.\n */\n\n#include <cuda.h>\n\n#include \"common.h\"\n#include \"loss.h\"\n\n//------------------------------------------------------------------------\n// Utils\n\n__device__ inline float bwdAbs(float x) { return x == 0.0f ? 0.0f : x < 0.0f ? -1.0f : 1.0f; }\n\n__device__ float warpSum(float val) {\n    for (int i = 1; i < 32; i *= 2)\n        val += __shfl_xor_sync(0xFFFFFFFF, val, i);\n    return val;\n}\n\n//------------------------------------------------------------------------\n// Tonemapping\n\n__device__ inline float fwdSRGB(float x)\n{\n    return x > 0.0031308f ? powf(max(x, 0.0031308f), 1.0f / 2.4f) * 1.055f - 0.055f : 12.92f * max(x, 0.0f);\n}\n\n__device__ inline void bwdSRGB(float x, float &d_x, float d_out)\n{\n    if (x > 0.0031308f)\n        d_x += d_out * 0.439583f / powf(x, 0.583333f);\n    else if (x > 0.0f)\n        d_x += d_out * 12.92f;\n}\n\n__device__ inline vec3f fwdTonemapLogSRGB(vec3f x)\n{\n    return vec3f(fwdSRGB(logf(x.x + 1.0f)), fwdSRGB(logf(x.y + 1.0f)), fwdSRGB(logf(x.z + 1.0f)));\n}\n\n__device__ inline void bwdTonemapLogSRGB(vec3f x, vec3f& d_x, vec3f d_out)\n{\n    if (x.x > 0.0f && x.x < 65535.0f)\n    {\n        bwdSRGB(logf(x.x + 1.0f), d_x.x, d_out.x);\n        d_x.x *= 1 / (x.x + 1.0f);\n    }\n    if (x.y > 0.0f && x.y < 65535.0f)\n    {\n        bwdSRGB(logf(x.y + 1.0f), d_x.y, d_out.y);\n        d_x.y *= 1 / (x.y + 1.0f);\n    }\n    if (x.z > 0.0f && x.z < 65535.0f)\n    {\n        bwdSRGB(logf(x.z + 1.0f), d_x.z, d_out.z);\n        d_x.z *= 1 / (x.z + 1.0f);\n    }\n}\n\n__device__ inline float fwdRELMSE(float img, float target, float eps = 0.1f)\n{\n    return (img - target) * (img - target) / (img * img + target * target + eps);\n}\n\n__device__ inline void bwdRELMSE(float img, float target, float &d_img, float &d_target, float d_out, float eps = 0.1f)\n{\n    float denom  = (target * target + img * img + eps);\n    d_img    += d_out * 2 * (img - target) * (target * (target + img) + eps) / (denom * denom);\n    d_target -= d_out * 2 * (img - target) * (img * (target + img) + eps) / (denom * denom);\n}\n\n__device__ inline float fwdSMAPE(float img, float target, float eps=0.01f)\n{\n    return abs(img - target) / (img + target + eps);\n}\n\n__device__ inline void bwdSMAPE(float img, float target, float& d_img, float& d_target, float d_out, float eps = 0.01f)\n{\n    float denom = (target + img + eps);\n    d_img    += d_out * bwdAbs(img - target) * (2 * target + eps) / (denom * denom);\n    d_target -= d_out * bwdAbs(img - target) * (2 * img + eps) / (denom * denom);\n}\n\n//------------------------------------------------------------------------\n// Kernels\n\n__global__ void imgLossFwdKernel(LossKernelParams p)\n{\n    // Calculate pixel position.\n    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;\n    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;\n    unsigned int pz = blockIdx.z;\n\n    float floss = 0.0f;\n    if (px < p.gridSize.x && py < p.gridSize.y && pz < p.gridSize.z)\n    {\n        vec3f img = 
p.img.fetch3(px, py, pz);\n        vec3f target = p.target.fetch3(px, py, pz);\n\n        img = vec3f(clamp(img.x, 0.0f, 65535.0f), clamp(img.y, 0.0f, 65535.0f), clamp(img.z, 0.0f, 65535.0f));\n        target = vec3f(clamp(target.x, 0.0f, 65535.0f), clamp(target.y, 0.0f, 65535.0f), clamp(target.z, 0.0f, 65535.0f));\n\n        if (p.tonemapper == TONEMAPPER_LOG_SRGB)\n        {\n            img = fwdTonemapLogSRGB(img);\n            target = fwdTonemapLogSRGB(target);\n        }\n\n        vec3f vloss(0);\n        if (p.loss == LOSS_MSE)\n            vloss = (img - target) * (img - target);\n        else if (p.loss == LOSS_RELMSE)\n            vloss = vec3f(fwdRELMSE(img.x, target.x), fwdRELMSE(img.y, target.y), fwdRELMSE(img.z, target.z));\n        else if (p.loss == LOSS_SMAPE)\n            vloss = vec3f(fwdSMAPE(img.x, target.x), fwdSMAPE(img.y, target.y), fwdSMAPE(img.z, target.z));\n        else\n            vloss = vec3f(abs(img.x - target.x), abs(img.y - target.y), abs(img.z - target.z));\n        \n        floss = sum(vloss) / 3.0f;\n    }\n\n    floss = warpSum(floss);\n\n    dim3 warpSize = getWarpSize(blockDim);\n    if (px < p.gridSize.x && py < p.gridSize.y && pz < p.gridSize.z && threadIdx.x % warpSize.x == 0 && threadIdx.y % warpSize.y == 0 && threadIdx.z % warpSize.z == 0)\n        p.out.store(px / warpSize.x, py / warpSize.y, pz / warpSize.z, floss);\n}\n\n__global__ void imgLossBwdKernel(LossKernelParams p)\n{ \n    // Calculate pixel position.\n    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;\n    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;\n    unsigned int pz = blockIdx.z;\n\n    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)\n        return;\n\n    dim3 warpSize = getWarpSize(blockDim);\n\n    vec3f _img = p.img.fetch3(px, py, pz);\n    vec3f _target = p.target.fetch3(px, py, pz);\n    float d_out = p.out.fetch1(px / warpSize.x, py / warpSize.y, pz / warpSize.z);\n\n    /////////////////////////////////////////////////////////////////////\n    // FWD\n\n    vec3f img = _img, target = _target;\n    if (p.tonemapper == TONEMAPPER_LOG_SRGB)\n    {\n        img = fwdTonemapLogSRGB(img);\n        target = fwdTonemapLogSRGB(target);\n    }\n\n    /////////////////////////////////////////////////////////////////////\n    // BWD\n\n    vec3f d_vloss = vec3f(d_out, d_out, d_out) / 3.0f;\n\n    vec3f d_img(0), d_target(0);\n    if (p.loss == LOSS_MSE)\n    {\n        d_img = vec3f(d_vloss.x * 2 * (img.x - target.x), d_vloss.y * 2 * (img.y - target.y), d_vloss.z * 2 * (img.z - target.z));\n        d_target = -d_img;\n    }\n    else if (p.loss == LOSS_RELMSE)\n    {\n        bwdRELMSE(img.x, target.x, d_img.x, d_target.x, d_vloss.x);\n        bwdRELMSE(img.y, target.y, d_img.y, d_target.y, d_vloss.y);\n        bwdRELMSE(img.z, target.z, d_img.z, d_target.z, d_vloss.z);\n    }\n    else if (p.loss == LOSS_SMAPE)\n    {\n        bwdSMAPE(img.x, target.x, d_img.x, d_target.x, d_vloss.x);\n        bwdSMAPE(img.y, target.y, d_img.y, d_target.y, d_vloss.y);\n        bwdSMAPE(img.z, target.z, d_img.z, d_target.z, d_vloss.z);\n    }\n    else\n    {\n        d_img = d_vloss * vec3f(bwdAbs(img.x - target.x), bwdAbs(img.y - target.y), bwdAbs(img.z - target.z));\n        d_target = -d_img;\n    }\n\n\n    if (p.tonemapper == TONEMAPPER_LOG_SRGB)\n    {\n        vec3f d__img(0), d__target(0);\n        bwdTonemapLogSRGB(_img, d__img, d_img);\n        bwdTonemapLogSRGB(_target, d__target, d_target);\n        d_img = d__img; d_target = d__target;\n  
  }\n\n    if (_img.x <= 0.0f || _img.x >= 65535.0f) d_img.x = 0;\n    if (_img.y <= 0.0f || _img.y >= 65535.0f) d_img.y = 0;\n    if (_img.z <= 0.0f || _img.z >= 65535.0f) d_img.z = 0;\n    if (_target.x <= 0.0f || _target.x >= 65535.0f) d_target.x = 0;\n    if (_target.y <= 0.0f || _target.y >= 65535.0f) d_target.y = 0;\n    if (_target.z <= 0.0f || _target.z >= 65535.0f) d_target.z = 0;\n\n    p.img.store_grad(px, py, pz, d_img);\n    p.target.store_grad(px, py, pz, d_target);\n}"
  },
  {
    "path": "render/renderutils/c_src/loss.h",
    "content": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n * property and proprietary rights in and to this material, related \n * documentation and any modifications thereto. Any use, reproduction, \n * disclosure or distribution of this material and related documentation\n * without an express license agreement from NVIDIA CORPORATION or \n * its affiliates is strictly prohibited.\n */\n\n#pragma once\n\n#include \"common.h\"\n\nenum TonemapperType\n{\n    TONEMAPPER_NONE = 0,\n    TONEMAPPER_LOG_SRGB = 1\n};\n\nenum LossType\n{\n    LOSS_L1 = 0,\n    LOSS_MSE = 1,\n    LOSS_RELMSE = 2,\n    LOSS_SMAPE = 3\n};\n\nstruct LossKernelParams\n{\n    Tensor          img;\n    Tensor          target;\n    Tensor          out;\n    dim3            gridSize;\n    TonemapperType  tonemapper;\n    LossType        loss;\n};\n"
  },
  {
    "path": "render/renderutils/c_src/mesh.cu",
    "content": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n * property and proprietary rights in and to this material, related \n * documentation and any modifications thereto. Any use, reproduction, \n * disclosure or distribution of this material and related documentation\n * without an express license agreement from NVIDIA CORPORATION or \n * its affiliates is strictly prohibited.\n */\n\n#include <cuda.h>\n#include <stdio.h>\n\n#include \"common.h\"\n#include \"mesh.h\"\n\n\n//------------------------------------------------------------------------\n// Kernels\n\n__global__ void xfmPointsFwdKernel(XfmKernelParams p)\n{\n    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;\n    unsigned int pz = blockIdx.z * blockDim.z + threadIdx.z;\n\n    __shared__ float mtx[4][4];\n    if (threadIdx.x < 16)\n        mtx[threadIdx.x % 4][threadIdx.x / 4] = p.matrix.fetch(p.matrix.nhwcIndex(pz, threadIdx.x / 4, threadIdx.x % 4, 0));\n    __syncthreads();\n    \n    if (px >= p.gridSize.x)\n        return;\n\n    vec3f pos(\n        p.points.fetch(p.points.nhwcIndex(pz, px, 0, 0)),\n        p.points.fetch(p.points.nhwcIndex(pz, px, 1, 0)),\n        p.points.fetch(p.points.nhwcIndex(pz, px, 2, 0))\n    );\n\n    if (p.isPoints)\n    {\n        p.out.store(p.out.nhwcIndex(pz, px, 0, 0), pos.x * mtx[0][0] + pos.y * mtx[1][0] + pos.z * mtx[2][0] + mtx[3][0]);\n        p.out.store(p.out.nhwcIndex(pz, px, 1, 0), pos.x * mtx[0][1] + pos.y * mtx[1][1] + pos.z * mtx[2][1] + mtx[3][1]);\n        p.out.store(p.out.nhwcIndex(pz, px, 2, 0), pos.x * mtx[0][2] + pos.y * mtx[1][2] + pos.z * mtx[2][2] + mtx[3][2]);\n        p.out.store(p.out.nhwcIndex(pz, px, 3, 0), pos.x * mtx[0][3] + pos.y * mtx[1][3] + pos.z * mtx[2][3] + mtx[3][3]);\n    }\n    else\n    {\n        p.out.store(p.out.nhwcIndex(pz, px, 0, 0), pos.x * mtx[0][0] + pos.y * mtx[1][0] + pos.z * mtx[2][0]);\n        p.out.store(p.out.nhwcIndex(pz, px, 1, 0), pos.x * mtx[0][1] + pos.y * mtx[1][1] + pos.z * mtx[2][1]);\n        p.out.store(p.out.nhwcIndex(pz, px, 2, 0), pos.x * mtx[0][2] + pos.y * mtx[1][2] + pos.z * mtx[2][2]);\n    }\n}\n\n__global__ void xfmPointsBwdKernel(XfmKernelParams p)\n{ \n    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;\n    unsigned int pz = blockIdx.z * blockDim.z + threadIdx.z;\n\n    __shared__ float mtx[4][4];\n    if (threadIdx.x < 16)\n        mtx[threadIdx.x % 4][threadIdx.x / 4] = p.matrix.fetch(p.matrix.nhwcIndex(pz, threadIdx.x / 4, threadIdx.x % 4, 0));\n    __syncthreads();\n\n    if (px >= p.gridSize.x)\n        return;\n\n    vec3f pos(\n        p.points.fetch(p.points.nhwcIndex(pz, px, 0, 0)),\n        p.points.fetch(p.points.nhwcIndex(pz, px, 1, 0)),\n        p.points.fetch(p.points.nhwcIndex(pz, px, 2, 0))\n    );\n\n    vec4f d_out(\n        p.out.fetch(p.out.nhwcIndex(pz, px, 0, 0)),\n        p.out.fetch(p.out.nhwcIndex(pz, px, 1, 0)),\n        p.out.fetch(p.out.nhwcIndex(pz, px, 2, 0)),\n        p.out.fetch(p.out.nhwcIndex(pz, px, 3, 0))\n    );\n\n    if (p.isPoints)\n    {\n        p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 0, 0), d_out.x * mtx[0][0] + d_out.y * mtx[0][1] + d_out.z * mtx[0][2] + d_out.w * mtx[0][3]);\n        p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 1, 0), d_out.x * mtx[1][0] + d_out.y * mtx[1][1] + d_out.z * mtx[1][2] + d_out.w * mtx[1][3]);\n        p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 2, 0), 
d_out.x * mtx[2][0] + d_out.y * mtx[2][1] + d_out.z * mtx[2][2] + d_out.w * mtx[2][3]);\n    }\n    else\n    {\n        p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 0, 0), d_out.x * mtx[0][0] + d_out.y * mtx[0][1] + d_out.z * mtx[0][2]);\n        p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 1, 0), d_out.x * mtx[1][0] + d_out.y * mtx[1][1] + d_out.z * mtx[1][2]);\n        p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 2, 0), d_out.x * mtx[2][0] + d_out.y * mtx[2][1] + d_out.z * mtx[2][2]);\n    }\n}"
  },
  {
    "path": "render/renderutils/c_src/mesh.h",
    "content": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n * property and proprietary rights in and to this material, related \n * documentation and any modifications thereto. Any use, reproduction, \n * disclosure or distribution of this material and related documentation\n * without an express license agreement from NVIDIA CORPORATION or \n * its affiliates is strictly prohibited.\n */\n\n#pragma once\n\n#include \"common.h\"\n\nstruct XfmKernelParams\n{\n    bool            isPoints;\n    Tensor          points;\n    Tensor          matrix;\n    Tensor          out;\n    dim3            gridSize;\n};\n"
  },
  {
    "path": "render/renderutils/c_src/normal.cu",
    "content": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n * property and proprietary rights in and to this material, related \n * documentation and any modifications thereto. Any use, reproduction, \n * disclosure or distribution of this material and related documentation\n * without an express license agreement from NVIDIA CORPORATION or \n * its affiliates is strictly prohibited.\n */\n\n#include \"common.h\"\n#include \"normal.h\"\n\n#define NORMAL_THRESHOLD 0.1f\n\n//------------------------------------------------------------------------\n// Perturb shading normal by tangent frame\n\n__device__ vec3f fwdPerturbNormal(const vec3f perturbed_nrm, const vec3f smooth_nrm, const vec3f smooth_tng, bool opengl)\n{\n    vec3f _smooth_bitng = cross(smooth_tng, smooth_nrm);\n    vec3f smooth_bitng = safeNormalize(_smooth_bitng);\n    vec3f _shading_nrm = smooth_tng * perturbed_nrm.x + (opengl ? -1 : 1) * smooth_bitng * perturbed_nrm.y + smooth_nrm * max(perturbed_nrm.z, 0.0f);\n    return safeNormalize(_shading_nrm);\n}\n\n__device__ void bwdPerturbNormal(const vec3f perturbed_nrm, const vec3f smooth_nrm, const vec3f smooth_tng, vec3f &d_perturbed_nrm, vec3f &d_smooth_nrm, vec3f &d_smooth_tng, const vec3f d_out, bool opengl)\n{\n    ////////////////////////////////////////////////////////////////////////\n    // FWD\n    vec3f _smooth_bitng = cross(smooth_tng, smooth_nrm);\n    vec3f smooth_bitng = safeNormalize(_smooth_bitng);\n    vec3f _shading_nrm = smooth_tng * perturbed_nrm.x + (opengl ? -1 : 1) * smooth_bitng * perturbed_nrm.y + smooth_nrm * max(perturbed_nrm.z, 0.0f);\n        \n    ////////////////////////////////////////////////////////////////////////\n    // BWD\n    vec3f d_shading_nrm(0);\n    bwdSafeNormalize(_shading_nrm, d_shading_nrm, d_out);\n\n    vec3f d_smooth_bitng(0);\n    \n    if (perturbed_nrm.z > 0.0f)\n    {\n        d_smooth_nrm += d_shading_nrm * perturbed_nrm.z;\n        d_perturbed_nrm.z += sum(d_shading_nrm * smooth_nrm);\n    }\n\n    d_smooth_bitng += (opengl ? -1 : 1) * d_shading_nrm * perturbed_nrm.y;\n    d_perturbed_nrm.y += (opengl ? 
-1 : 1) * sum(d_shading_nrm * smooth_bitng);\n\n    d_smooth_tng += d_shading_nrm * perturbed_nrm.x;\n    d_perturbed_nrm.x += sum(d_shading_nrm * smooth_tng);\n\n    vec3f d__smooth_bitng(0);\n    bwdSafeNormalize(_smooth_bitng, d__smooth_bitng, d_smooth_bitng);\n\n    bwdCross(smooth_tng, smooth_nrm, d_smooth_tng, d_smooth_nrm, d__smooth_bitng);\n}\n\n//------------------------------------------------------------------------\n#define bent_nrm_eps 0.001f\n\n__device__ vec3f fwdBendNormal(const vec3f view_vec, const vec3f smooth_nrm, const vec3f geom_nrm)\n{\n    float dp = dot(view_vec, smooth_nrm);\n    float t = clamp(dp / NORMAL_THRESHOLD, 0.0f, 1.0f);\n    return geom_nrm * (1.0f - t) + smooth_nrm * t;\n}\n\n__device__ void bwdBendNormal(const vec3f view_vec, const vec3f smooth_nrm, const vec3f geom_nrm, vec3f& d_view_vec, vec3f& d_smooth_nrm, vec3f& d_geom_nrm, const vec3f d_out)\n{\n    ////////////////////////////////////////////////////////////////////////\n    // FWD\n    float dp = dot(view_vec, smooth_nrm);\n    float t = clamp(dp / NORMAL_THRESHOLD, 0.0f, 1.0f);\n\n    ////////////////////////////////////////////////////////////////////////\n    // BWD\n    if (dp > NORMAL_THRESHOLD)\n        d_smooth_nrm += d_out;\n    else\n    {\n        // geom_nrm * (1.0f - t) + smooth_nrm * t;\n        d_geom_nrm   += d_out * (1.0f - t);\n        d_smooth_nrm += d_out * t;\n        float d_t = sum(d_out * (smooth_nrm - geom_nrm));\n\n        float d_dp = dp < 0.0f || dp > NORMAL_THRESHOLD ? 0.0f : d_t / NORMAL_THRESHOLD;\n\n        bwdDot(view_vec, smooth_nrm, d_view_vec, d_smooth_nrm, d_dp);\n    }\n}\n\n//------------------------------------------------------------------------\n// Kernels\n\n__global__ void PrepareShadingNormalFwdKernel(PrepareShadingNormalKernelParams p) \n{\n    // Calculate pixel position.\n    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;\n    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;\n    unsigned int pz = blockIdx.z;\n    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)\n        return;\n\n    vec3f pos = p.pos.fetch3(px, py, pz);\n    vec3f view_pos = p.view_pos.fetch3(px, py, pz);\n    vec3f perturbed_nrm = p.perturbed_nrm.fetch3(px, py, pz);\n    vec3f _smooth_nrm = p.smooth_nrm.fetch3(px, py, pz);\n    vec3f _smooth_tng = p.smooth_tng.fetch3(px, py, pz);\n    vec3f geom_nrm = p.geom_nrm.fetch3(px, py, pz);\n\n    vec3f smooth_nrm = safeNormalize(_smooth_nrm);\n    vec3f smooth_tng = safeNormalize(_smooth_tng);\n    vec3f view_vec = safeNormalize(view_pos - pos);\n    vec3f shading_nrm = fwdPerturbNormal(perturbed_nrm, smooth_nrm, smooth_tng, p.opengl);\n\n    vec3f res;\n    if (p.two_sided_shading && dot(view_vec, geom_nrm) < 0.0f)\n        res = fwdBendNormal(view_vec, -shading_nrm, -geom_nrm);\n    else\n        res = fwdBendNormal(view_vec, shading_nrm, geom_nrm);\n\n    p.out.store(px, py, pz, res);\n}\n\n__global__ void PrepareShadingNormalBwdKernel(PrepareShadingNormalKernelParams p) \n{ \n    // Calculate pixel position.\n    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;\n    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;\n    unsigned int pz = blockIdx.z;\n    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)\n        return;\n\n    vec3f pos = p.pos.fetch3(px, py, pz);\n    vec3f view_pos = p.view_pos.fetch3(px, py, pz);\n    vec3f perturbed_nrm = p.perturbed_nrm.fetch3(px, py, pz);\n    vec3f _smooth_nrm = p.smooth_nrm.fetch3(px, py, pz);\n    vec3f _smooth_tng = 
p.smooth_tng.fetch3(px, py, pz);\n    vec3f geom_nrm = p.geom_nrm.fetch3(px, py, pz);\n    vec3f d_out = p.out.fetch3(px, py, pz);\n\n    ///////////////////////////////////////////////////////////////////////////////////////////////////\n    // FWD\n\n    vec3f smooth_nrm = safeNormalize(_smooth_nrm);\n    vec3f smooth_tng = safeNormalize(_smooth_tng);\n    vec3f _view_vec = view_pos - pos;\n    vec3f view_vec = safeNormalize(view_pos - pos);\n\n    vec3f shading_nrm = fwdPerturbNormal(perturbed_nrm, smooth_nrm, smooth_tng, p.opengl);\n\n    ///////////////////////////////////////////////////////////////////////////////////////////////////\n    // BWD\n\n    vec3f d_view_vec(0), d_shading_nrm(0), d_geom_nrm(0);\n    if (p.two_sided_shading && dot(view_vec, geom_nrm) < 0.0f)\n    {\n        bwdBendNormal(view_vec, -shading_nrm, -geom_nrm, d_view_vec, d_shading_nrm, d_geom_nrm, d_out);\n        d_shading_nrm = -d_shading_nrm;\n        d_geom_nrm = -d_geom_nrm;\n    }\n    else\n        bwdBendNormal(view_vec, shading_nrm, geom_nrm, d_view_vec, d_shading_nrm, d_geom_nrm, d_out);\n\n    vec3f d_perturbed_nrm(0), d_smooth_nrm(0), d_smooth_tng(0);\n    bwdPerturbNormal(perturbed_nrm, smooth_nrm, smooth_tng, d_perturbed_nrm, d_smooth_nrm, d_smooth_tng, d_shading_nrm, p.opengl);\n\n    vec3f d__view_vec(0), d__smooth_nrm(0), d__smooth_tng(0);\n    bwdSafeNormalize(_view_vec, d__view_vec, d_view_vec);\n    bwdSafeNormalize(_smooth_nrm, d__smooth_nrm, d_smooth_nrm);\n    bwdSafeNormalize(_smooth_tng, d__smooth_tng, d_smooth_tng);\n\n    p.pos.store_grad(px, py, pz, -d__view_vec);\n    p.view_pos.store_grad(px, py, pz, d__view_vec);\n    p.perturbed_nrm.store_grad(px, py, pz, d_perturbed_nrm);\n    p.smooth_nrm.store_grad(px, py, pz, d__smooth_nrm);\n    p.smooth_tng.store_grad(px, py, pz, d__smooth_tng);\n    p.geom_nrm.store_grad(px, py, pz, d_geom_nrm);\n}"
  },
  {
    "path": "render/renderutils/c_src/normal.h",
    "content": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n * property and proprietary rights in and to this material, related \n * documentation and any modifications thereto. Any use, reproduction, \n * disclosure or distribution of this material and related documentation\n * without an express license agreement from NVIDIA CORPORATION or \n * its affiliates is strictly prohibited.\n */\n\n#pragma once\n\n#include \"common.h\"\n\nstruct PrepareShadingNormalKernelParams\n{\n    Tensor  pos;\n    Tensor  view_pos;\n    Tensor  perturbed_nrm;\n    Tensor  smooth_nrm;\n    Tensor  smooth_tng;\n    Tensor  geom_nrm;\n    Tensor  out;\n    dim3    gridSize;\n    bool    two_sided_shading, opengl;\n};\n"
  },
  {
    "path": "render/renderutils/c_src/tensor.h",
    "content": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n * property and proprietary rights in and to this material, related \n * documentation and any modifications thereto. Any use, reproduction, \n * disclosure or distribution of this material and related documentation\n * without an express license agreement from NVIDIA CORPORATION or \n * its affiliates is strictly prohibited.\n */\n\n#pragma once\n#if defined(__CUDACC__) && defined(BFLOAT16)\n#include <cuda_bf16.h> // bfloat16 is float32 compatible with less mantissa bits\n#endif\n\n//---------------------------------------------------------------------------------\n// CUDA-side Tensor class for in/out parameter parsing. Can be float32 or bfloat16\n\nstruct Tensor\n{\n    void*   val;\n    void*   d_val;\n    int     dims[4], _dims[4];\n    int     strides[4];\n    bool    fp16;\n\n#if defined(__CUDA__) && !defined(__CUDA_ARCH__)\n    Tensor() : val(nullptr), d_val(nullptr), fp16(true), dims{ 0, 0, 0, 0 }, _dims{ 0, 0, 0, 0 }, strides{ 0, 0, 0, 0 } {}\n#endif\n\n#ifdef __CUDACC__\n    // Helpers to index and read/write a single element\n    __device__ inline int   _nhwcIndex(int n, int h, int w, int c) const { return n * strides[0] + h * strides[1] + w * strides[2] + c * strides[3]; }\n    __device__ inline int   nhwcIndex(int n, int h, int w, int c) const { return (dims[0] == 1 ? 0 : n * strides[0]) + (dims[1] == 1 ? 0 : h * strides[1]) + (dims[2] == 1 ? 0 : w * strides[2]) + (dims[3] == 1 ? 0 : c * strides[3]); }\n    __device__ inline int   nhwcIndexContinuous(int n, int h, int w, int c) const { return ((n * _dims[1] + h) * _dims[2] + w) * _dims[3] + c; }\n#ifdef BFLOAT16\n    __device__ inline float fetch(unsigned int idx) const { return fp16 ? __bfloat162float(((__nv_bfloat16*)val)[idx]) : ((float*)val)[idx]; }\n    __device__ inline void  store(unsigned int idx, float _val) { if (fp16) ((__nv_bfloat16*)val)[idx] = __float2bfloat16(_val); else ((float*)val)[idx] = _val; }\n    __device__ inline void  store_grad(unsigned int idx, float _val) { if (fp16) ((__nv_bfloat16*)d_val)[idx] = __float2bfloat16(_val); else ((float*)d_val)[idx] = _val; }\n#else\n    __device__ inline float fetch(unsigned int idx) const { return ((float*)val)[idx]; }\n    __device__ inline void  store(unsigned int idx, float _val) { ((float*)val)[idx] = _val; }\n    __device__ inline void  store_grad(unsigned int idx, float _val) { ((float*)d_val)[idx] = _val; }\n#endif\n\n    //////////////////////////////////////////////////////////////////////////////////////////\n    // Fetch, use broadcasting for tensor dimensions of size 1\n    __device__ inline float fetch1(unsigned int x, unsigned int y, unsigned int z) const\n    {\n        return fetch(nhwcIndex(z, y, x, 0));\n    }\n\n    __device__ inline vec3f fetch3(unsigned int x, unsigned int y, unsigned int z) const\n    {\n        return vec3f(\n            fetch(nhwcIndex(z, y, x, 0)),\n            fetch(nhwcIndex(z, y, x, 1)),\n            fetch(nhwcIndex(z, y, x, 2))\n        );\n    }\n\n    /////////////////////////////////////////////////////////////////////////////////////////////////////////////\n    // Store, no broadcasting here. 
Assume we output full res gradient and then reduce using torch.sum outside\n    __device__ inline void store(unsigned int x, unsigned int y, unsigned int z, float _val)\n    {\n        store(_nhwcIndex(z, y, x, 0), _val);\n    }\n\n    __device__ inline void store(unsigned int x, unsigned int y, unsigned int z, vec3f _val)\n    {\n        store(_nhwcIndex(z, y, x, 0), _val.x);\n        store(_nhwcIndex(z, y, x, 1), _val.y);\n        store(_nhwcIndex(z, y, x, 2), _val.z);\n    }\n\n    /////////////////////////////////////////////////////////////////////////////////////////////////////////////\n    // Store gradient , no broadcasting here. Assume we output full res gradient and then reduce using torch.sum outside\n    __device__ inline void store_grad(unsigned int x, unsigned int y, unsigned int z, float _val)\n    {\n        store_grad(nhwcIndexContinuous(z, y, x, 0), _val);\n    }\n\n    __device__ inline void store_grad(unsigned int x, unsigned int y, unsigned int z, vec3f _val)\n    {\n        store_grad(nhwcIndexContinuous(z, y, x, 0), _val.x);\n        store_grad(nhwcIndexContinuous(z, y, x, 1), _val.y);\n        store_grad(nhwcIndexContinuous(z, y, x, 2), _val.z);\n    }\n#endif\n\n};\n"
  },
  {
    "path": "render/renderutils/c_src/torch_bindings.cpp",
    "content": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n * property and proprietary rights in and to this material, related \n * documentation and any modifications thereto. Any use, reproduction, \n * disclosure or distribution of this material and related documentation\n * without an express license agreement from NVIDIA CORPORATION or \n * its affiliates is strictly prohibited.\n */\n\n#ifdef _MSC_VER \n#pragma warning(push, 0)\n#include <torch/extension.h>\n#pragma warning(pop)\n#else\n#include <torch/extension.h>\n#endif\n\n#include <ATen/cuda/CUDAContext.h>\n#include <ATen/cuda/CUDAUtils.h>\n#include <algorithm>\n#include <string>\n\n#define NVDR_CHECK_CUDA_ERROR(CUDA_CALL) { cudaError_t err = CUDA_CALL; AT_CUDA_CHECK(cudaGetLastError()); }\n#define NVDR_CHECK_GL_ERROR(GL_CALL) { GL_CALL; GLenum err = glGetError(); TORCH_CHECK(err == GL_NO_ERROR, \"OpenGL error: \", getGLErrorString(err), \"[\", #GL_CALL, \";]\"); }\n#define CHECK_TENSOR(X, DIMS, CHANNELS) \\\n    TORCH_CHECK(X.is_cuda(), #X \" must be a cuda tensor\") \\\n    TORCH_CHECK(X.scalar_type() == torch::kFloat || X.scalar_type() == torch::kBFloat16, #X \" must be fp32 or bf16\") \\\n    TORCH_CHECK(X.dim() == DIMS, #X \" must have \" #DIMS \" dimensions\") \\\n    TORCH_CHECK(X.size(DIMS - 1) == CHANNELS, #X \" must have \" #CHANNELS \" channels\")\n\n#include \"common.h\"\n#include \"loss.h\"\n#include \"normal.h\"\n#include \"cubemap.h\"\n#include \"bsdf.h\"\n#include \"mesh.h\"\n\n#define BLOCK_X 8\n#define BLOCK_Y 8\n\n//------------------------------------------------------------------------\n// mesh.cu\n\nvoid xfmPointsFwdKernel(XfmKernelParams p);\nvoid xfmPointsBwdKernel(XfmKernelParams p);\n\n//------------------------------------------------------------------------\n// loss.cu\n\nvoid imgLossFwdKernel(LossKernelParams p);\nvoid imgLossBwdKernel(LossKernelParams p);\n\n//------------------------------------------------------------------------\n// normal.cu\n\nvoid PrepareShadingNormalFwdKernel(PrepareShadingNormalKernelParams p);\nvoid PrepareShadingNormalBwdKernel(PrepareShadingNormalKernelParams p);\n\n//------------------------------------------------------------------------\n// cubemap.cu\n\nvoid DiffuseCubemapFwdKernel(DiffuseCubemapKernelParams p);\nvoid DiffuseCubemapBwdKernel(DiffuseCubemapKernelParams p);\nvoid SpecularBoundsKernel(SpecularBoundsKernelParams p);\nvoid SpecularCubemapFwdKernel(SpecularCubemapKernelParams p);\nvoid SpecularCubemapBwdKernel(SpecularCubemapKernelParams p);\n\n//------------------------------------------------------------------------\n// bsdf.cu\n\nvoid LambertFwdKernel(LambertKernelParams p);\nvoid LambertBwdKernel(LambertKernelParams p);\n\nvoid FrostbiteDiffuseFwdKernel(FrostbiteDiffuseKernelParams p);\nvoid FrostbiteDiffuseBwdKernel(FrostbiteDiffuseKernelParams p);\n\nvoid FresnelShlickFwdKernel(FresnelShlickKernelParams p);\nvoid FresnelShlickBwdKernel(FresnelShlickKernelParams p);\n\nvoid ndfGGXFwdKernel(NdfGGXParams p);\nvoid ndfGGXBwdKernel(NdfGGXParams p);\n\nvoid lambdaGGXFwdKernel(NdfGGXParams p);\nvoid lambdaGGXBwdKernel(NdfGGXParams p);\n\nvoid maskingSmithFwdKernel(MaskingSmithParams p);\nvoid maskingSmithBwdKernel(MaskingSmithParams p);\n\nvoid pbrSpecularFwdKernel(PbrSpecular p);\nvoid pbrSpecularBwdKernel(PbrSpecular p);\n\nvoid pbrBSDFFwdKernel(PbrBSDF p);\nvoid pbrBSDFBwdKernel(PbrBSDF 
p);\n\n//------------------------------------------------------------------------\n// Tensor helpers\n\nvoid update_grid(dim3 &gridSize, torch::Tensor x)\n{\n    gridSize.x = std::max(gridSize.x, (uint32_t)x.size(2));\n    gridSize.y = std::max(gridSize.y, (uint32_t)x.size(1));\n    gridSize.z = std::max(gridSize.z, (uint32_t)x.size(0));\n}\n\ntemplate<typename... Ts>\nvoid update_grid(dim3& gridSize, torch::Tensor x, Ts&&... vs)\n{\n    gridSize.x = std::max(gridSize.x, (uint32_t)x.size(2));\n    gridSize.y = std::max(gridSize.y, (uint32_t)x.size(1));\n    gridSize.z = std::max(gridSize.z, (uint32_t)x.size(0));\n    update_grid(gridSize, std::forward<Ts>(vs)...);\n}\n\nTensor make_cuda_tensor(torch::Tensor val)\n{\n    Tensor res;\n    for (int i = 0; i < val.dim(); ++i)\n    {\n        res.dims[i] = val.size(i);\n        res.strides[i] = val.stride(i);\n    }\n    res.fp16 = val.scalar_type() == torch::kBFloat16;\n    res.val = res.fp16 ? (void*)val.data_ptr<torch::BFloat16>() : (void*)val.data_ptr<float>();\n    res.d_val = nullptr;\n    return res;\n}\n\nTensor make_cuda_tensor(torch::Tensor val, dim3 outDims, torch::Tensor* grad = nullptr)\n{\n    Tensor res;\n    for (int i = 0; i < val.dim(); ++i)\n    {\n        res.dims[i] = val.size(i);\n        res.strides[i] = val.stride(i);\n    }\n    if (val.dim() == 4)\n        res._dims[0] = outDims.z, res._dims[1] = outDims.y, res._dims[2] = outDims.x, res._dims[3] = val.size(3);\n    else\n        res._dims[0] = outDims.z, res._dims[1] = outDims.x, res._dims[2] = val.size(2), res._dims[3] = 1; // Add a trailing one for indexing math to work out\n\n    res.fp16 = val.scalar_type() == torch::kBFloat16;\n    res.val = res.fp16 ? (void*)val.data_ptr<torch::BFloat16>() : (void*)val.data_ptr<float>();\n    res.d_val = nullptr;\n    if (grad != nullptr)\n    {\n        if (val.dim() == 4)\n            *grad = torch::empty({ outDims.z, outDims.y, outDims.x, val.size(3) }, torch::TensorOptions().dtype(res.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA));\n        else // 3\n            *grad = torch::empty({ outDims.z, outDims.x, val.size(2) }, torch::TensorOptions().dtype(res.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA));\n\n        res.d_val = res.fp16 ? (void*)grad->data_ptr<torch::BFloat16>() : (void*)grad->data_ptr<float>();\n    }\n    return res;\n}\n\n//------------------------------------------------------------------------\n// prepare_shading_normal\n\ntorch::Tensor prepare_shading_normal_fwd(torch::Tensor pos, torch::Tensor view_pos, torch::Tensor perturbed_nrm, torch::Tensor smooth_nrm, torch::Tensor smooth_tng, torch::Tensor geom_nrm, bool two_sided_shading, bool opengl, bool fp16)\n{\n    CHECK_TENSOR(pos, 4, 3);\n    CHECK_TENSOR(view_pos, 4, 3);\n    CHECK_TENSOR(perturbed_nrm, 4, 3);\n    CHECK_TENSOR(smooth_nrm, 4, 3);\n    CHECK_TENSOR(smooth_tng, 4, 3);\n    CHECK_TENSOR(geom_nrm, 4, 3);\n\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    // Extract input parameters.\n    PrepareShadingNormalKernelParams p;\n    p.two_sided_shading = two_sided_shading;\n    p.opengl = opengl;\n    p.out.fp16 = fp16;\n    update_grid(p.gridSize, pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm);\n\n    // Allocate output tensors.\n    torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? 
torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);\n    torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts);\n\n    // Choose launch parameters.\n    dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);\n    dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);\n\n    // Setup tensors\n    p.pos = make_cuda_tensor(pos, p.gridSize);\n    p.view_pos = make_cuda_tensor(view_pos, p.gridSize);\n    p.perturbed_nrm = make_cuda_tensor(perturbed_nrm, p.gridSize);\n    p.smooth_nrm = make_cuda_tensor(smooth_nrm, p.gridSize);\n    p.smooth_tng = make_cuda_tensor(smooth_tng, p.gridSize);\n    p.geom_nrm = make_cuda_tensor(geom_nrm, p.gridSize);\n    p.out = make_cuda_tensor(out, p.gridSize);\n\n    // Launch CUDA kernel.\n    void* args[] = { &p };\n    NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)PrepareShadingNormalFwdKernel, gridSize, blockSize, args, 0, stream));\n\n    return out;\n}\n\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> prepare_shading_normal_bwd(torch::Tensor pos, torch::Tensor view_pos, torch::Tensor perturbed_nrm, torch::Tensor smooth_nrm, torch::Tensor smooth_tng, torch::Tensor geom_nrm, torch::Tensor grad, bool two_sided_shading, bool opengl)\n{\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    // Extract input parameters.\n    PrepareShadingNormalKernelParams p;\n    p.two_sided_shading = two_sided_shading;\n    p.opengl = opengl;\n    update_grid(p.gridSize, pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm);\n\n    // Choose launch parameters.\n    dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);\n    dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);\n\n    // Setup tensors\n    torch::Tensor pos_grad, view_pos_grad, perturbed_nrm_grad, smooth_nrm_grad, smooth_tng_grad, geom_nrm_grad;\n    p.pos = make_cuda_tensor(pos, p.gridSize, &pos_grad);\n    p.view_pos = make_cuda_tensor(view_pos, p.gridSize, &view_pos_grad);\n    p.perturbed_nrm = make_cuda_tensor(perturbed_nrm, p.gridSize, &perturbed_nrm_grad);\n    p.smooth_nrm = make_cuda_tensor(smooth_nrm, p.gridSize, &smooth_nrm_grad);\n    p.smooth_tng = make_cuda_tensor(smooth_tng, p.gridSize, &smooth_tng_grad);\n    p.geom_nrm = make_cuda_tensor(geom_nrm, p.gridSize, &geom_nrm_grad);\n    p.out = make_cuda_tensor(grad, p.gridSize);\n\n    // Launch CUDA kernel.\n    void* args[] = { &p };\n    NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)PrepareShadingNormalBwdKernel, gridSize, blockSize, args, 0, stream));\n\n    return std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>(pos_grad, view_pos_grad, perturbed_nrm_grad, smooth_nrm_grad, smooth_tng_grad, geom_nrm_grad);\n}\n\n//------------------------------------------------------------------------\n// lambert\n\ntorch::Tensor lambert_fwd(torch::Tensor nrm, torch::Tensor wi, bool fp16)\n{\n    CHECK_TENSOR(nrm, 4, 3);\n    CHECK_TENSOR(wi, 4, 3);\n\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    // Extract input parameters.\n    LambertKernelParams p;\n    p.out.fp16 = fp16;\n    update_grid(p.gridSize, nrm, wi);\n\n    // Allocate output tensors.\n    torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? 
torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);\n    torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts);\n\n    // Choose launch parameters.\n    dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);\n    dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);\n\n    p.nrm = make_cuda_tensor(nrm, p.gridSize);\n    p.wi = make_cuda_tensor(wi, p.gridSize);\n    p.out = make_cuda_tensor(out, p.gridSize);\n\n    // Launch CUDA kernel.\n    void* args[] = { &p };\n    NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)LambertFwdKernel, gridSize, blockSize, args, 0, stream));\n\n    return out;\n}\n\nstd::tuple<torch::Tensor, torch::Tensor> lambert_bwd(torch::Tensor nrm, torch::Tensor wi, torch::Tensor grad)\n{\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    // Extract input parameters.\n    LambertKernelParams p;\n    update_grid(p.gridSize, nrm, wi);\n\n    // Choose launch parameters.\n    dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);\n    dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);\n\n    torch::Tensor nrm_grad, wi_grad;\n    p.nrm = make_cuda_tensor(nrm, p.gridSize, &nrm_grad);\n    p.wi = make_cuda_tensor(wi, p.gridSize, &wi_grad);\n    p.out = make_cuda_tensor(grad, p.gridSize);\n\n    // Launch CUDA kernel.\n    void* args[] = { &p };\n    NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)LambertBwdKernel, gridSize, blockSize, args, 0, stream));\n\n    return std::tuple<torch::Tensor, torch::Tensor>(nrm_grad, wi_grad);\n}\n\n//------------------------------------------------------------------------\n// frostbite diffuse\n\ntorch::Tensor frostbite_fwd(torch::Tensor nrm, torch::Tensor wi, torch::Tensor wo, torch::Tensor linearRoughness, bool fp16)\n{\n    CHECK_TENSOR(nrm, 4, 3);\n    CHECK_TENSOR(wi, 4, 3);\n    CHECK_TENSOR(wo, 4, 3);\n    CHECK_TENSOR(linearRoughness, 4, 1);\n\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    // Extract input parameters.\n    FrostbiteDiffuseKernelParams p;\n    p.out.fp16 = fp16;\n    update_grid(p.gridSize, nrm, wi, wo, linearRoughness);\n\n    // Allocate output tensors.\n    torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? 
torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);\n    torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts);\n\n    // Choose launch parameters.\n    dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);\n    dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);\n\n    p.nrm = make_cuda_tensor(nrm, p.gridSize);\n    p.wi = make_cuda_tensor(wi, p.gridSize);\n    p.wo = make_cuda_tensor(wo, p.gridSize);\n    p.linearRoughness = make_cuda_tensor(linearRoughness, p.gridSize);\n    p.out = make_cuda_tensor(out, p.gridSize);\n\n    // Launch CUDA kernel.\n    void* args[] = { &p };\n    NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)FrostbiteDiffuseFwdKernel, gridSize, blockSize, args, 0, stream));\n\n    return out;\n}\n\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> frostbite_bwd(torch::Tensor nrm, torch::Tensor wi, torch::Tensor wo, torch::Tensor linearRoughness, torch::Tensor grad)\n{\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    // Extract input parameters.\n    FrostbiteDiffuseKernelParams p;\n    update_grid(p.gridSize, nrm, wi, wo, linearRoughness);\n\n    // Choose launch parameters.\n    dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);\n    dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);\n\n    torch::Tensor nrm_grad, wi_grad, wo_grad, linearRoughness_grad;\n    p.nrm = make_cuda_tensor(nrm, p.gridSize, &nrm_grad);\n    p.wi = make_cuda_tensor(wi, p.gridSize, &wi_grad);\n    p.wo = make_cuda_tensor(wo, p.gridSize, &wo_grad);\n    p.linearRoughness = make_cuda_tensor(linearRoughness, p.gridSize, &linearRoughness_grad);\n    p.out = make_cuda_tensor(grad, p.gridSize);\n\n    // Launch CUDA kernel.\n    void* args[] = { &p };\n    NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)FrostbiteDiffuseBwdKernel, gridSize, blockSize, args, 0, stream));\n\n    return std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>(nrm_grad, wi_grad, wo_grad, linearRoughness_grad);\n}\n\n//------------------------------------------------------------------------\n// fresnel_shlick\n\ntorch::Tensor fresnel_shlick_fwd(torch::Tensor f0, torch::Tensor f90, torch::Tensor cosTheta, bool fp16)\n{\n    CHECK_TENSOR(f0, 4, 3);\n    CHECK_TENSOR(f90, 4, 3);\n    CHECK_TENSOR(cosTheta, 4, 1);\n\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    // Extract input parameters.\n    FresnelShlickKernelParams p;\n    p.out.fp16 = fp16;\n    update_grid(p.gridSize, f0, f90, cosTheta);\n\n    // Allocate output tensors.\n    torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? 
torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);\n    torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts);\n\n    // Choose launch parameters.\n    dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);\n    dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);\n\n    p.f0 = make_cuda_tensor(f0, p.gridSize);\n    p.f90 = make_cuda_tensor(f90, p.gridSize);\n    p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize);\n    p.out = make_cuda_tensor(out, p.gridSize);\n\n    // Launch CUDA kernel.\n    void* args[] = { &p };\n    NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)FresnelShlickFwdKernel, gridSize, blockSize, args, 0, stream));\n\n    return out;\n}\n\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor> fresnel_shlick_bwd(torch::Tensor f0, torch::Tensor f90, torch::Tensor cosTheta, torch::Tensor grad)\n{\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    // Extract input parameters.\n    FresnelShlickKernelParams p;\n    update_grid(p.gridSize, f0, f90, cosTheta);\n\n    // Choose launch parameters.\n    dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);\n    dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);\n\n    torch::Tensor f0_grad, f90_grad, cosT_grad;\n    p.f0 = make_cuda_tensor(f0, p.gridSize, &f0_grad);\n    p.f90 = make_cuda_tensor(f90, p.gridSize, &f90_grad);\n    p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize, &cosT_grad);\n    p.out = make_cuda_tensor(grad, p.gridSize);\n\n    // Launch CUDA kernel.\n    void* args[] = { &p };\n    NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)FresnelShlickBwdKernel, gridSize, blockSize, args, 0, stream));\n\n    return std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>(f0_grad, f90_grad, cosT_grad);\n}\n\n//------------------------------------------------------------------------\n// ndf_ggd\n\ntorch::Tensor ndf_ggx_fwd(torch::Tensor alphaSqr, torch::Tensor cosTheta, bool fp16)\n{\n    CHECK_TENSOR(alphaSqr, 4, 1);\n    CHECK_TENSOR(cosTheta, 4, 1);\n\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    // Extract input parameters.\n    NdfGGXParams p;\n    p.out.fp16 = fp16;\n    update_grid(p.gridSize, alphaSqr, cosTheta);\n\n    // Allocate output tensors.\n    torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? 
torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);\n    torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts);\n\n    // Choose launch parameters.\n    dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);\n    dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);\n\n    p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize);\n    p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize);\n    p.out = make_cuda_tensor(out, p.gridSize);\n\n    // Launch CUDA kernel.\n    void* args[] = { &p };\n    NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)ndfGGXFwdKernel, gridSize, blockSize, args, 0, stream));\n\n    return out;\n}\n\nstd::tuple<torch::Tensor, torch::Tensor> ndf_ggx_bwd(torch::Tensor alphaSqr, torch::Tensor cosTheta, torch::Tensor grad)\n{\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    // Extract input parameters.\n    NdfGGXParams p;\n    update_grid(p.gridSize, alphaSqr, cosTheta);\n\n    // Choose launch parameters.\n    dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);\n    dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);\n\n    torch::Tensor alphaSqr_grad, cosTheta_grad;\n    p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize, &alphaSqr_grad);\n    p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize, &cosTheta_grad);\n    p.out = make_cuda_tensor(grad, p.gridSize);\n\n    // Launch CUDA kernel.\n    void* args[] = { &p };\n    NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)ndfGGXBwdKernel, gridSize, blockSize, args, 0, stream));\n\n    return std::tuple<torch::Tensor, torch::Tensor>(alphaSqr_grad, cosTheta_grad);\n}\n\n//------------------------------------------------------------------------\n// lambda_ggx\n\ntorch::Tensor lambda_ggx_fwd(torch::Tensor alphaSqr, torch::Tensor cosTheta, bool fp16)\n{\n    CHECK_TENSOR(alphaSqr, 4, 1);\n    CHECK_TENSOR(cosTheta, 4, 1);\n\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    // Extract input parameters.\n    NdfGGXParams p;\n    p.out.fp16 = fp16;\n    update_grid(p.gridSize, alphaSqr, cosTheta);\n\n    // Allocate output tensors.\n    torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? 
torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);\n    torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts);\n\n    // Choose launch parameters.\n    dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);\n    dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);\n\n    p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize);\n    p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize);\n    p.out = make_cuda_tensor(out, p.gridSize);\n\n    // Launch CUDA kernel.\n    void* args[] = { &p };\n    NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)lambdaGGXFwdKernel, gridSize, blockSize, args, 0, stream));\n\n    return out;\n}\n\nstd::tuple<torch::Tensor, torch::Tensor> lambda_ggx_bwd(torch::Tensor alphaSqr, torch::Tensor cosTheta, torch::Tensor grad)\n{\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    // Extract input parameters.\n    NdfGGXParams p;\n    update_grid(p.gridSize, alphaSqr, cosTheta);\n\n    // Choose launch parameters.\n    dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);\n    dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);\n\n    torch::Tensor alphaSqr_grad, cosTheta_grad;\n    p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize, &alphaSqr_grad);\n    p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize, &cosTheta_grad);\n    p.out = make_cuda_tensor(grad, p.gridSize);\n\n    // Launch CUDA kernel.\n    void* args[] = { &p };\n    NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)lambdaGGXBwdKernel, gridSize, blockSize, args, 0, stream));\n\n    return std::tuple<torch::Tensor, torch::Tensor>(alphaSqr_grad, cosTheta_grad);\n}\n\n//------------------------------------------------------------------------\n// masking_smith\n\ntorch::Tensor masking_smith_fwd(torch::Tensor alphaSqr, torch::Tensor cosThetaI, torch::Tensor cosThetaO, bool fp16)\n{\n    CHECK_TENSOR(alphaSqr, 4, 1);\n    CHECK_TENSOR(cosThetaI, 4, 1);\n    CHECK_TENSOR(cosThetaO, 4, 1);\n\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    // Extract input parameters.\n    MaskingSmithParams p;\n    p.out.fp16 = fp16;\n    update_grid(p.gridSize, alphaSqr, cosThetaI, cosThetaO);\n\n    // Allocate output tensors.\n    torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? 
torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);\n    torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts);\n\n    // Choose launch parameters.\n    dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);\n    dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);\n\n    p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize);\n    p.cosThetaI = make_cuda_tensor(cosThetaI, p.gridSize);\n    p.cosThetaO = make_cuda_tensor(cosThetaO, p.gridSize);\n    p.out = make_cuda_tensor(out, p.gridSize);\n\n    // Launch CUDA kernel.\n    void* args[] = { &p };\n    NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)maskingSmithFwdKernel, gridSize, blockSize, args, 0, stream));\n\n    return out;\n}\n\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor> masking_smith_bwd(torch::Tensor alphaSqr, torch::Tensor cosThetaI, torch::Tensor cosThetaO, torch::Tensor grad)\n{\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    // Extract input parameters.\n    MaskingSmithParams p;\n    update_grid(p.gridSize, alphaSqr, cosThetaI, cosThetaO);\n\n    // Choose launch parameters.\n    dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);\n    dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);\n\n    torch::Tensor alphaSqr_grad, cosThetaI_grad, cosThetaO_grad;\n    p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize, &alphaSqr_grad);\n    p.cosThetaI = make_cuda_tensor(cosThetaI, p.gridSize, &cosThetaI_grad);\n    p.cosThetaO = make_cuda_tensor(cosThetaO, p.gridSize, &cosThetaO_grad);\n    p.out = make_cuda_tensor(grad, p.gridSize);\n\n    // Launch CUDA kernel.\n    void* args[] = { &p };\n    NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)maskingSmithBwdKernel, gridSize, blockSize, args, 0, stream));\n\n    return std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>(alphaSqr_grad, cosThetaI_grad, cosThetaO_grad);\n}\n\n//------------------------------------------------------------------------\n// pbr_specular\n\ntorch::Tensor pbr_specular_fwd(torch::Tensor col, torch::Tensor nrm, torch::Tensor wo, torch::Tensor wi, torch::Tensor alpha, float min_roughness, bool fp16)\n{\n    CHECK_TENSOR(col, 4, 3);\n    CHECK_TENSOR(nrm, 4, 3);\n    CHECK_TENSOR(wo, 4, 3);\n    CHECK_TENSOR(wi, 4, 3);\n    CHECK_TENSOR(alpha, 4, 1);\n\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    // Extract input parameters.\n    PbrSpecular p;\n    p.out.fp16 = fp16;\n    p.min_roughness = min_roughness;\n    update_grid(p.gridSize, col, nrm, wo, wi, alpha);\n\n    // Allocate output tensors.\n    torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? 
torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);\n    torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts);\n\n    // Choose launch parameters.\n    dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);\n    dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);\n\n    p.col = make_cuda_tensor(col, p.gridSize);\n    p.nrm = make_cuda_tensor(nrm, p.gridSize);\n    p.wo = make_cuda_tensor(wo, p.gridSize);\n    p.wi = make_cuda_tensor(wi, p.gridSize);\n    p.alpha = make_cuda_tensor(alpha, p.gridSize);\n    p.out = make_cuda_tensor(out, p.gridSize);\n\n    // Launch CUDA kernel.\n    void* args[] = { &p };\n    NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)pbrSpecularFwdKernel, gridSize, blockSize, args, 0, stream));\n\n    return out;\n}\n\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> pbr_specular_bwd(torch::Tensor col, torch::Tensor nrm, torch::Tensor wo, torch::Tensor wi, torch::Tensor alpha, float min_roughness, torch::Tensor grad)\n{\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    // Extract input parameters.\n    PbrSpecular p;\n    update_grid(p.gridSize, col, nrm, wo, wi, alpha);\n    p.min_roughness = min_roughness;\n\n    // Choose launch parameters.\n    dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);\n    dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);\n\n    torch::Tensor col_grad, nrm_grad, wo_grad, wi_grad, alpha_grad;\n    p.col = make_cuda_tensor(col, p.gridSize, &col_grad);\n    p.nrm = make_cuda_tensor(nrm, p.gridSize, &nrm_grad);\n    p.wo = make_cuda_tensor(wo, p.gridSize, &wo_grad);\n    p.wi = make_cuda_tensor(wi, p.gridSize, &wi_grad);\n    p.alpha = make_cuda_tensor(alpha, p.gridSize, &alpha_grad);\n    p.out = make_cuda_tensor(grad, p.gridSize);\n\n    // Launch CUDA kernel.\n    void* args[] = { &p };\n    NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)pbrSpecularBwdKernel, gridSize, blockSize, args, 0, stream));\n\n    return std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>(col_grad, nrm_grad, wo_grad, wi_grad, alpha_grad);\n}\n\n//------------------------------------------------------------------------\n// pbr_bsdf\n\ntorch::Tensor pbr_bsdf_fwd(torch::Tensor kd, torch::Tensor arm, torch::Tensor pos, torch::Tensor nrm, torch::Tensor view_pos, torch::Tensor light_pos, float min_roughness, int BSDF, bool fp16)\n{\n    CHECK_TENSOR(kd, 4, 3);\n    CHECK_TENSOR(arm, 4, 3);\n    CHECK_TENSOR(pos, 4, 3);\n    CHECK_TENSOR(nrm, 4, 3);\n    CHECK_TENSOR(view_pos, 4, 3);\n    CHECK_TENSOR(light_pos, 4, 3);\n\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    // Extract input parameters.\n    PbrBSDF p;\n    p.out.fp16 = fp16;\n    p.min_roughness = min_roughness;\n    p.BSDF = BSDF;\n    update_grid(p.gridSize, kd, arm, pos, nrm, view_pos, light_pos);\n\n    // Allocate output tensors.\n    torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? 
torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);\n    torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts);\n\n    // Choose launch parameters.\n    dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);\n    dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);\n\n    p.kd = make_cuda_tensor(kd, p.gridSize);\n    p.arm = make_cuda_tensor(arm, p.gridSize);\n    p.pos = make_cuda_tensor(pos, p.gridSize);\n    p.nrm = make_cuda_tensor(nrm, p.gridSize);\n    p.view_pos = make_cuda_tensor(view_pos, p.gridSize);\n    p.light_pos = make_cuda_tensor(light_pos, p.gridSize);\n    p.out = make_cuda_tensor(out, p.gridSize);\n\n    // Launch CUDA kernel.\n    void* args[] = { &p };\n    NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)pbrBSDFFwdKernel, gridSize, blockSize, args, 0, stream));\n\n    return out;\n}\n\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> pbr_bsdf_bwd(torch::Tensor kd, torch::Tensor arm, torch::Tensor pos, torch::Tensor nrm, torch::Tensor view_pos, torch::Tensor light_pos, float min_roughness, int BSDF, torch::Tensor grad)\n{\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    // Extract input parameters.\n    PbrBSDF p;\n    update_grid(p.gridSize, kd, arm, pos, nrm, view_pos, light_pos);\n    p.min_roughness = min_roughness;\n    p.BSDF = BSDF;\n\n    // Choose launch parameters.\n    dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);\n    dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);\n\n    torch::Tensor kd_grad, arm_grad, pos_grad, nrm_grad, view_pos_grad, light_pos_grad;\n    p.kd = make_cuda_tensor(kd, p.gridSize, &kd_grad);\n    p.arm = make_cuda_tensor(arm, p.gridSize, &arm_grad);\n    p.pos = make_cuda_tensor(pos, p.gridSize, &pos_grad);\n    p.nrm = make_cuda_tensor(nrm, p.gridSize, &nrm_grad);\n    p.view_pos = make_cuda_tensor(view_pos, p.gridSize, &view_pos_grad);\n    p.light_pos = make_cuda_tensor(light_pos, p.gridSize, &light_pos_grad);\n    p.out = make_cuda_tensor(grad, p.gridSize);\n\n    // Launch CUDA kernel.\n    void* args[] = { &p };\n    NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)pbrBSDFBwdKernel, gridSize, blockSize, args, 0, stream));\n\n    return std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>(kd_grad, arm_grad, pos_grad, nrm_grad, view_pos_grad, light_pos_grad);\n}\n\n//------------------------------------------------------------------------\n// filter_cubemap\n\ntorch::Tensor diffuse_cubemap_fwd(torch::Tensor cubemap)\n{\n    CHECK_TENSOR(cubemap, 4, 3);\n\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    // Extract input parameters.\n    DiffuseCubemapKernelParams p;\n    update_grid(p.gridSize, cubemap);\n\n    // Allocate output tensors.\n    torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);\n    torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts);\n\n    // Choose launch parameters.\n    dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);\n    dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);\n\n    // Setup tensors\n    p.cubemap = make_cuda_tensor(cubemap, p.gridSize);\n    p.out = make_cuda_tensor(out, p.gridSize);\n\n    // Launch CUDA kernel.\n    void* args[] = { &p };\n    NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)DiffuseCubemapFwdKernel, gridSize, blockSize, args, 0, stream));\n\n    
return out;\n}\n\ntorch::Tensor diffuse_cubemap_bwd(torch::Tensor cubemap, torch::Tensor grad)\n{\n    CHECK_TENSOR(cubemap, 4, 3);\n    CHECK_TENSOR(grad, 4, 3);\n\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    // Extract input parameters.\n    DiffuseCubemapKernelParams p;\n    update_grid(p.gridSize, cubemap);\n\n    // Choose launch parameters.\n    dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);\n    dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);\n\n    // Setup tensors\n    torch::Tensor cubemap_grad;\n    p.cubemap = make_cuda_tensor(cubemap, p.gridSize);\n    p.out = make_cuda_tensor(grad, p.gridSize);\n\n    cubemap_grad = torch::zeros({ p.gridSize.z, p.gridSize.y, p.gridSize.x, cubemap.size(3) }, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA));\n    p.cubemap.d_val = (void*)cubemap_grad.data_ptr<float>();\n\n    // Launch CUDA kernel.\n    void* args[] = { &p };\n    NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)DiffuseCubemapBwdKernel, gridSize, blockSize, args, 0, stream));\n\n    return cubemap_grad;\n}\n\ntorch::Tensor specular_bounds(int resolution, float costheta_cutoff)\n{\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    // Extract input parameters.\n    SpecularBoundsKernelParams p;\n    p.costheta_cutoff = costheta_cutoff;\n    p.gridSize = dim3(resolution, resolution, 6);\n\n    // Allocate output tensors.\n    torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);\n    torch::Tensor out = torch::zeros({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 6*4 }, opts);\n\n    // Choose launch parameters.\n    dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);\n    dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);\n\n    // Setup tensors\n    p.out = make_cuda_tensor(out, p.gridSize);\n\n    // Launch CUDA kernel.\n    void* args[] = { &p };\n    NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)SpecularBoundsKernel, gridSize, blockSize, args, 0, stream));\n\n    return out;\n}\n\ntorch::Tensor specular_cubemap_fwd(torch::Tensor cubemap, torch::Tensor bounds, float roughness, float costheta_cutoff)\n{\n    CHECK_TENSOR(cubemap, 4, 3);\n    CHECK_TENSOR(bounds, 4, 6*4);\n\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    // Extract input parameters.\n    SpecularCubemapKernelParams p;\n    p.roughness = roughness;\n    p.costheta_cutoff = costheta_cutoff;\n    update_grid(p.gridSize, cubemap);\n\n    // Allocate output tensors.\n    torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);\n    torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 4 }, opts);\n\n    // Choose launch parameters.\n    dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);\n    dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);\n\n    // Setup tensors\n    p.cubemap = make_cuda_tensor(cubemap, p.gridSize);\n    p.bounds = make_cuda_tensor(bounds, p.gridSize);\n    p.out = make_cuda_tensor(out, p.gridSize);\n\n    // Launch CUDA kernel.\n    void* args[] = { &p };\n    NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)SpecularCubemapFwdKernel, gridSize, blockSize, args, 0, stream));\n\n    return out;\n}\n\ntorch::Tensor specular_cubemap_bwd(torch::Tensor cubemap, torch::Tensor bounds, torch::Tensor grad, float roughness, float costheta_cutoff)\n{\n    CHECK_TENSOR(cubemap, 4, 3);\n    CHECK_TENSOR(bounds, 4, 6*4);\n\n    cudaStream_t stream = 
at::cuda::getCurrentCUDAStream();\n\n    // Extract input parameters.\n    SpecularCubemapKernelParams p;\n    p.roughness = roughness;\n    p.costheta_cutoff = costheta_cutoff;\n    update_grid(p.gridSize, cubemap);\n\n    // Choose launch parameters.\n    dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);\n    dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);\n\n    // Setup tensors\n    torch::Tensor cubemap_grad;\n    p.cubemap = make_cuda_tensor(cubemap, p.gridSize);\n    p.bounds = make_cuda_tensor(bounds, p.gridSize);\n    p.out = make_cuda_tensor(grad, p.gridSize);\n\n    cubemap_grad = torch::zeros({ p.gridSize.z, p.gridSize.y, p.gridSize.x, cubemap.size(3) }, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA));\n    p.cubemap.d_val = (void*)cubemap_grad.data_ptr<float>();\n\n    // Launch CUDA kernel.\n    void* args[] = { &p };\n    NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)SpecularCubemapBwdKernel, gridSize, blockSize, args, 0, stream));\n\n    return cubemap_grad;\n}\n\n//------------------------------------------------------------------------\n// loss function\n\nLossType strToLoss(std::string str)\n{\n    if (str == \"mse\")\n        return LOSS_MSE;\n    else if (str == \"relmse\")\n        return LOSS_RELMSE;\n    else if (str == \"smape\")\n        return LOSS_SMAPE;\n    else\n        return LOSS_L1;\n}\n\ntorch::Tensor image_loss_fwd(torch::Tensor img, torch::Tensor target, std::string loss, std::string tonemapper, bool fp16)\n{\n    CHECK_TENSOR(img, 4, 3);\n    CHECK_TENSOR(target, 4, 3);\n\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    // Extract input parameters.\n    LossKernelParams p;\n    p.out.fp16 = fp16;\n    p.loss = strToLoss(loss);\n    p.tonemapper = tonemapper == \"log_srgb\" ? TONEMAPPER_LOG_SRGB : TONEMAPPER_NONE;\n    update_grid(p.gridSize, img, target);\n\n    // Choose launch parameters.\n    dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);\n    dim3 warpSize = getWarpSize(blockSize);\n    dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);\n\n    // Allocate output tensors.\n    torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);\n    torch::Tensor out = torch::empty({ (p.gridSize.z - 1)/ warpSize.z + 1, (p.gridSize.y - 1) / warpSize.y + 1, (p.gridSize.x - 1) / warpSize.x + 1, 1 }, opts);\n\n    p.img = make_cuda_tensor(img, p.gridSize);\n    p.target = make_cuda_tensor(target, p.gridSize);\n    p.out = make_cuda_tensor(out, p.gridSize);\n\n    // Launch CUDA kernel.\n    void* args[] = { &p };\n    NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)imgLossFwdKernel, gridSize, blockSize, args, 0, stream));\n\n    return out;\n}\n\nstd::tuple<torch::Tensor, torch::Tensor> image_loss_bwd(torch::Tensor img, torch::Tensor target, torch::Tensor grad, std::string loss, std::string tonemapper)\n{\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    // Extract input parameters.\n    LossKernelParams p;\n    p.loss = strToLoss(loss);\n    p.tonemapper = tonemapper == \"log_srgb\" ? 
TONEMAPPER_LOG_SRGB : TONEMAPPER_NONE;\n    update_grid(p.gridSize, img, target);\n\n    // Choose launch parameters.\n    dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize);\n    dim3 warpSize = getWarpSize(blockSize);\n    dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);\n\n    torch::Tensor img_grad, target_grad;\n    p.img = make_cuda_tensor(img, p.gridSize, &img_grad);\n    p.target = make_cuda_tensor(target, p.gridSize, &target_grad);\n    p.out = make_cuda_tensor(grad, p.gridSize);\n\n    // Launch CUDA kernel.\n    void* args[] = { &p };\n    NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)imgLossBwdKernel, gridSize, blockSize, args, 0, stream));\n\n    return std::tuple<torch::Tensor, torch::Tensor>(img_grad, target_grad);\n}\n\n//------------------------------------------------------------------------\n// transform function\n\ntorch::Tensor xfm_fwd(torch::Tensor points, torch::Tensor matrix, bool isPoints, bool fp16)\n{\n    CHECK_TENSOR(points, 3, 3);\n    CHECK_TENSOR(matrix, 3, 4);\n\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    // Extract input parameters.\n    XfmKernelParams p;\n    p.out.fp16 = fp16;\n    p.isPoints = isPoints;\n    p.gridSize.x = points.size(1);\n    p.gridSize.y = 1;\n    p.gridSize.z = std::max(matrix.size(0), points.size(0));\n\n    // Choose launch parameters.\n    dim3 blockSize(BLOCK_X * BLOCK_Y, 1, 1);\n    dim3 warpSize = getWarpSize(blockSize);\n    dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);\n\n    // Allocate output tensors.\n    torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA);\n    torch::Tensor out = isPoints ? torch::empty({ matrix.size(0), points.size(1), 4 }, opts) : torch::empty({ matrix.size(0), points.size(1), 3 }, opts);\n\n    p.points = make_cuda_tensor(points, p.gridSize);\n    p.matrix = make_cuda_tensor(matrix, p.gridSize);\n    p.out = make_cuda_tensor(out, p.gridSize);\n\n    // Launch CUDA kernel.\n    void* args[] = { &p };\n    NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)xfmPointsFwdKernel, gridSize, blockSize, args, 0, stream));\n\n    return out;\n}\n\ntorch::Tensor xfm_bwd(torch::Tensor points, torch::Tensor matrix, torch::Tensor grad, bool isPoints)\n{\n    cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n\n    // Extract input parameters.\n    XfmKernelParams p;\n    p.isPoints = isPoints;\n    p.gridSize.x = points.size(1);\n    p.gridSize.y = 1;\n    p.gridSize.z = std::max(matrix.size(0), points.size(0));\n\n    // Choose launch parameters.\n    dim3 blockSize(BLOCK_X * BLOCK_Y, 1, 1);\n    dim3 warpSize = getWarpSize(blockSize);\n    dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize);\n\n    torch::Tensor points_grad;\n    p.points = make_cuda_tensor(points, p.gridSize, &points_grad);\n    p.matrix = make_cuda_tensor(matrix, p.gridSize);\n    p.out = make_cuda_tensor(grad, p.gridSize);\n\n    // Launch CUDA kernel.\n    void* args[] = { &p };\n    NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)xfmPointsBwdKernel, gridSize, blockSize, args, 0, stream));\n\n    return points_grad;\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n    m.def(\"prepare_shading_normal_fwd\", &prepare_shading_normal_fwd, \"prepare_shading_normal_fwd\");\n    m.def(\"prepare_shading_normal_bwd\", &prepare_shading_normal_bwd, \"prepare_shading_normal_bwd\");\n    m.def(\"lambert_fwd\", &lambert_fwd, \"lambert_fwd\");\n    m.def(\"lambert_bwd\", &lambert_bwd, 
\"lambert_bwd\");\n    m.def(\"frostbite_fwd\", &frostbite_fwd, \"frostbite_fwd\");\n    m.def(\"frostbite_bwd\", &frostbite_bwd, \"frostbite_bwd\");\n    m.def(\"fresnel_shlick_fwd\", &fresnel_shlick_fwd, \"fresnel_shlick_fwd\");\n    m.def(\"fresnel_shlick_bwd\", &fresnel_shlick_bwd, \"fresnel_shlick_bwd\");\n    m.def(\"ndf_ggx_fwd\", &ndf_ggx_fwd, \"ndf_ggx_fwd\");\n    m.def(\"ndf_ggx_bwd\", &ndf_ggx_bwd, \"ndf_ggx_bwd\");\n    m.def(\"lambda_ggx_fwd\", &lambda_ggx_fwd, \"lambda_ggx_fwd\");\n    m.def(\"lambda_ggx_bwd\", &lambda_ggx_bwd, \"lambda_ggx_bwd\");\n    m.def(\"masking_smith_fwd\", &masking_smith_fwd, \"masking_smith_fwd\");\n    m.def(\"masking_smith_bwd\", &masking_smith_bwd, \"masking_smith_bwd\");\n    m.def(\"pbr_specular_fwd\", &pbr_specular_fwd, \"pbr_specular_fwd\");\n    m.def(\"pbr_specular_bwd\", &pbr_specular_bwd, \"pbr_specular_bwd\");\n    m.def(\"pbr_bsdf_fwd\", &pbr_bsdf_fwd, \"pbr_bsdf_fwd\");\n    m.def(\"pbr_bsdf_bwd\", &pbr_bsdf_bwd, \"pbr_bsdf_bwd\");\n    m.def(\"diffuse_cubemap_fwd\", &diffuse_cubemap_fwd, \"diffuse_cubemap_fwd\");\n    m.def(\"diffuse_cubemap_bwd\", &diffuse_cubemap_bwd, \"diffuse_cubemap_bwd\");\n    m.def(\"specular_bounds\", &specular_bounds, \"specular_bounds\");\n    m.def(\"specular_cubemap_fwd\", &specular_cubemap_fwd, \"specular_cubemap_fwd\");\n    m.def(\"specular_cubemap_bwd\", &specular_cubemap_bwd, \"specular_cubemap_bwd\");\n    m.def(\"image_loss_fwd\", &image_loss_fwd, \"image_loss_fwd\");\n    m.def(\"image_loss_bwd\", &image_loss_bwd, \"image_loss_bwd\");\n    m.def(\"xfm_fwd\", &xfm_fwd, \"xfm_fwd\");\n    m.def(\"xfm_bwd\", &xfm_bwd, \"xfm_bwd\");\n}"
  },
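The loss-string and tonemapper parsing in the bindings above has a silent fallback that is easy to miss. A minimal pure-Python mirror of that behaviour (illustrative only, not part of the extension):

# Hedged sketch: mirrors strToLoss and the tonemapper ternary in torch_bindings.cpp.
# Any loss string other than "mse", "relmse" or "smape" falls back to L1, and any
# tonemapper string other than "log_srgb" means no tonemapping.
def str_to_loss(name: str) -> str:
    return name if name in ("mse", "relmse", "smape") else "l1"

def str_to_tonemapper(name: str) -> str:
    return "log_srgb" if name == "log_srgb" else "none"

assert str_to_loss("huber") == "l1"          # unknown loss names are treated as L1
assert str_to_tonemapper("srgb") == "none"   # only "log_srgb" enables tonemapping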
  {
    "path": "render/renderutils/c_src/vec3f.h",
    "content": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n * property and proprietary rights in and to this material, related \n * documentation and any modifications thereto. Any use, reproduction, \n * disclosure or distribution of this material and related documentation\n * without an express license agreement from NVIDIA CORPORATION or \n * its affiliates is strictly prohibited.\n */\n\n#pragma once \n\nstruct vec3f\n{\n    float x, y, z;\n\n#ifdef __CUDACC__\n    __device__ vec3f() { }\n    __device__ vec3f(float v) { x = v; y = v; z = v; }\n    __device__ vec3f(float _x, float _y, float _z) { x = _x; y = _y; z = _z; }\n    __device__ vec3f(float3 v) { x = v.x; y = v.y; z = v.z; }\n\n    __device__ inline vec3f& operator+=(const vec3f& b) { x += b.x; y += b.y; z += b.z; return *this; }\n    __device__ inline vec3f& operator-=(const vec3f& b) { x -= b.x; y -= b.y; z -= b.z; return *this; }\n    __device__ inline vec3f& operator*=(const vec3f& b) { x *= b.x; y *= b.y; z *= b.z; return *this; }\n    __device__ inline vec3f& operator/=(const vec3f& b) { x /= b.x; y /= b.y; z /= b.z; return *this; }\n#endif\n};\n\n#ifdef __CUDACC__\n__device__ static inline vec3f operator+(const vec3f& a, const vec3f& b) { return vec3f(a.x + b.x, a.y + b.y, a.z + b.z); }\n__device__ static inline vec3f operator-(const vec3f& a, const vec3f& b) { return vec3f(a.x - b.x, a.y - b.y, a.z - b.z); }\n__device__ static inline vec3f operator*(const vec3f& a, const vec3f& b) { return vec3f(a.x * b.x, a.y * b.y, a.z * b.z); }\n__device__ static inline vec3f operator/(const vec3f& a, const vec3f& b) { return vec3f(a.x / b.x, a.y / b.y, a.z / b.z); }\n__device__ static inline vec3f operator-(const vec3f& a) { return vec3f(-a.x, -a.y, -a.z); }\n\n__device__ static inline float sum(vec3f a)\n{\n    return a.x + a.y + a.z;\n}\n\n__device__ static inline vec3f cross(vec3f a, vec3f b)\n{\n    vec3f out;\n    out.x = a.y * b.z - a.z * b.y;\n    out.y = a.z * b.x - a.x * b.z;\n    out.z = a.x * b.y - a.y * b.x;\n    return out;\n}\n\n__device__ static inline void bwdCross(vec3f a, vec3f b, vec3f &d_a, vec3f &d_b, vec3f d_out)\n{\n    d_a.x += d_out.z * b.y - d_out.y * b.z;\n    d_a.y += d_out.x * b.z - d_out.z * b.x;\n    d_a.z += d_out.y * b.x - d_out.x * b.y;\n\n    d_b.x += d_out.y * a.z - d_out.z * a.y;\n    d_b.y += d_out.z * a.x - d_out.x * a.z;\n    d_b.z += d_out.x * a.y - d_out.y * a.x;\n}\n\n__device__ static inline float dot(vec3f a, vec3f b)\n{\n    return a.x * b.x + a.y * b.y + a.z * b.z;\n}\n\n__device__ static inline void bwdDot(vec3f a, vec3f b, vec3f& d_a, vec3f& d_b, float d_out)\n{\n    d_a.x += d_out * b.x; d_a.y += d_out * b.y; d_a.z += d_out * b.z;\n    d_b.x += d_out * a.x; d_b.y += d_out * a.y; d_b.z += d_out * a.z;\n}\n\n__device__ static inline vec3f reflect(vec3f x, vec3f n)\n{\n    return n * 2.0f * dot(n, x) - x;\n}\n\n__device__ static inline void bwdReflect(vec3f x, vec3f n, vec3f& d_x, vec3f& d_n, const vec3f d_out)\n{\n    d_x.x += d_out.x * (2 * n.x * n.x - 1) + d_out.y * (2 * n.x * n.y) + d_out.z * (2 * n.x * n.z);\n    d_x.y += d_out.x * (2 * n.x * n.y) + d_out.y * (2 * n.y * n.y - 1) + d_out.z * (2 * n.y * n.z);\n    d_x.z += d_out.x * (2 * n.x * n.z) + d_out.y * (2 * n.y * n.z) + d_out.z * (2 * n.z * n.z - 1);\n\n    d_n.x += d_out.x * (2 * (2 * n.x * x.x + n.y * x.y + n.z * x.z)) + d_out.y * (2 * n.y * x.x) + d_out.z * (2 * n.z * x.x);\n    d_n.y += d_out.x * (2 
* n.x * x.y) + d_out.y * (2 * (n.x * x.x + 2 * n.y * x.y + n.z * x.z)) + d_out.z * (2 * n.z * x.y);\n    d_n.z += d_out.x * (2 * n.x * x.z) + d_out.y * (2 * n.y * x.z) + d_out.z * (2 * (n.x * x.x + n.y * x.y + 2 * n.z * x.z));\n}\n\n__device__ static inline vec3f safeNormalize(vec3f v)\n{\n    float l = sqrtf(v.x * v.x + v.y * v.y + v.z * v.z);\n    return l > 0.0f ? (v / l) : vec3f(0.0f);\n}\n\n__device__ static inline void bwdSafeNormalize(const vec3f v, vec3f& d_v, const vec3f d_out)\n{\n\n    float l = sqrtf(v.x * v.x + v.y * v.y + v.z * v.z);\n    if (l > 0.0f)\n    {\n        float fac = 1.0 / powf(v.x * v.x + v.y * v.y + v.z * v.z, 1.5f);\n        d_v.x += (d_out.x * (v.y * v.y + v.z * v.z) - d_out.y * (v.x * v.y) - d_out.z * (v.x * v.z)) * fac;\n        d_v.y += (d_out.y * (v.x * v.x + v.z * v.z) - d_out.x * (v.y * v.x) - d_out.z * (v.y * v.z)) * fac;\n        d_v.z += (d_out.z * (v.x * v.x + v.y * v.y) - d_out.x * (v.z * v.x) - d_out.y * (v.z * v.y)) * fac;\n    }\n}\n\n#endif"
  },
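bwdSafeNormalize above hard-codes the Jacobian of v/||v||. A minimal PyTorch sketch (my own check, not part of the source tree) comparing that formula against autograd:

import torch

# Hedged sketch: verify the gradient formula used by bwdSafeNormalize in vec3f.h.
# d(v/||v||)/dv has the (||v||^2 * I - v v^T) / ||v||^3 structure accumulated there.
def safe_normalize(v: torch.Tensor) -> torch.Tensor:
    l = v.norm()
    return v / l if l > 0 else torch.zeros_like(v)

v = torch.tensor([0.3, -1.2, 2.0], requires_grad=True)
d_out = torch.tensor([0.5, 0.25, -1.0])

safe_normalize(v).backward(d_out)

# Manual gradient, mirroring the per-component accumulation in bwdSafeNormalize.
x, y, z = v.detach()
fac = 1.0 / (x*x + y*y + z*z) ** 1.5
manual = torch.stack([
    (d_out[0]*(y*y + z*z) - d_out[1]*(x*y) - d_out[2]*(x*z)) * fac,
    (d_out[1]*(x*x + z*z) - d_out[0]*(y*x) - d_out[2]*(y*z)) * fac,
    (d_out[2]*(x*x + y*y) - d_out[0]*(z*x) - d_out[1]*(z*y)) * fac,
])
assert torch.allclose(v.grad, manual, atol=1e-6)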
  {
    "path": "render/renderutils/c_src/vec4f.h",
    "content": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n * property and proprietary rights in and to this material, related \n * documentation and any modifications thereto. Any use, reproduction, \n * disclosure or distribution of this material and related documentation\n * without an express license agreement from NVIDIA CORPORATION or \n * its affiliates is strictly prohibited.\n */\n\n#pragma once \n\nstruct vec4f\n{\n    float x, y, z, w;\n\n#ifdef __CUDACC__\n    __device__ vec4f() { }\n    __device__ vec4f(float v) { x = v; y = v; z = v; w = v; }\n    __device__ vec4f(float _x, float _y, float _z, float _w) { x = _x; y = _y; z = _z; w = _w; }\n    __device__ vec4f(float4 v) { x = v.x; y = v.y; z = v.z; w = v.w; }\n#endif\n};\n\n"
  },
  {
    "path": "render/renderutils/loss.py",
    "content": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n# property and proprietary rights in and to this material, related\n# documentation and any modifications thereto. Any use, reproduction, \n# disclosure or distribution of this material and related documentation \n# without an express license agreement from NVIDIA CORPORATION or \n# its affiliates is strictly prohibited.\n\nimport torch\n\n#----------------------------------------------------------------------------\n# HDR image losses\n#----------------------------------------------------------------------------\n\ndef _tonemap_srgb(f, exposure=5):\n    f = f * exposure\n    return torch.where(f > 0.0031308, torch.pow(torch.clamp(f, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*f)\n\ndef _SMAPE(img, target, eps=0.01):\n    nom = torch.abs(img - target)\n    denom = torch.abs(img) + torch.abs(target) + 0.01\n    return torch.mean(nom / denom)\n\ndef _RELMSE(img, target, eps=0.1):\n    nom = (img - target) * (img - target)\n    denom = img * img + target * target + 0.1 \n    return torch.mean(nom / denom)\n\ndef image_loss_fn(img, target, loss, tonemapper):\n    if tonemapper == 'log_srgb':\n        img    = _tonemap_srgb(torch.log(torch.clamp(img, min=0, max=65535) + 1))\n        target = _tonemap_srgb(torch.log(torch.clamp(target, min=0, max=65535) + 1))\n\n    if loss == 'mse':\n        return torch.nn.functional.mse_loss(img, target)\n    elif loss == 'smape':\n        return _SMAPE(img, target)\n    elif loss == 'relmse':\n        return _RELMSE(img, target)\n    else:\n        return torch.nn.functional.l1_loss(img, target)\n"
  },
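A short, hedged usage sketch for the reference losses above (the shapes and the import path are assumptions about this repo layout):

import torch
from render.renderutils.loss import image_loss_fn   # assumption: repo root on sys.path

# Random stand-ins for an HDR render and its reference, in NHWC layout.
img    = torch.rand(1, 64, 64, 3) * 10.0
target = torch.rand(1, 64, 64, 3) * 10.0

# 'log_srgb' first maps x -> srgb(log(clamp(x, 0, 65535) + 1)), then takes the pixel loss.
for loss in ('l1', 'mse', 'smape', 'relmse'):
    print(loss, float(image_loss_fn(img, target, loss, 'log_srgb')))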
  {
    "path": "render/renderutils/ops.py",
    "content": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n# property and proprietary rights in and to this material, related\n# documentation and any modifications thereto. Any use, reproduction, \n# disclosure or distribution of this material and related documentation \n# without an express license agreement from NVIDIA CORPORATION or \n# its affiliates is strictly prohibited.\n\nimport numpy as np\nimport os\nimport sys\nimport torch\nimport torch.utils.cpp_extension\n\nfrom .bsdf import *\nfrom .loss import *\n\n#----------------------------------------------------------------------------\n# C++/Cuda plugin compiler/loader.\n\n_cached_plugin = None\ndef _get_plugin():\n    # Return cached plugin if already loaded.\n    global _cached_plugin\n    if _cached_plugin is not None:\n        return _cached_plugin\n\n    # Make sure we can find the necessary compiler and libary binaries.\n    if os.name == 'nt':\n        def find_cl_path():\n            import glob\n            for edition in ['Enterprise', 'Professional', 'BuildTools', 'Community']:\n                paths = sorted(glob.glob(r\"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64\" % edition), reverse=True)\n                if paths:\n                    return paths[0]\n\n        # If cl.exe is not on path, try to find it.\n        if os.system(\"where cl.exe >nul 2>nul\") != 0:\n            cl_path = find_cl_path()\n            if cl_path is None:\n                raise RuntimeError(\"Could not locate a supported Microsoft Visual C++ installation\")\n            os.environ['PATH'] += ';' + cl_path\n\n    # Compiler options.\n    opts = ['-DNVDR_TORCH']\n\n    # Linker options.\n    if os.name == 'posix':\n        ldflags = ['-lcuda', '-lnvrtc']\n    elif os.name == 'nt':\n        ldflags = ['cuda.lib', 'advapi32.lib', 'nvrtc.lib']\n\n    # List of sources.\n    source_files = [\n        'c_src/mesh.cu',\n        'c_src/loss.cu',\n        'c_src/bsdf.cu',\n        'c_src/normal.cu',\n        'c_src/cubemap.cu',\n        'c_src/common.cpp',\n        'c_src/torch_bindings.cpp'\n    ]\n\n    # Some containers set this to contain old architectures that won't compile. We only need the one installed in the machine.\n    os.environ['TORCH_CUDA_ARCH_LIST'] = ''\n\n    # Try to detect if a stray lock file is left in cache directory and show a warning. This sometimes happens on Windows if the build is interrupted at just the right moment.\n    try:\n        lock_fn = os.path.join(torch.utils.cpp_extension._get_build_directory('renderutils_plugin', False), 'lock')\n        if os.path.exists(lock_fn):\n            print(\"Warning: Lock file exists in build directory: '%s'\" % lock_fn)\n    except:\n        pass\n\n    # Compile and load.\n    build_dir = os.path.join(os. path. 
dirname(__file__), 'build')\n    os.makedirs(build_dir, exist_ok=True)\n    source_paths = [os.path.join(os.path.dirname(__file__), fn) for fn in source_files]\n    torch.utils.cpp_extension.load(name='renderutils_plugin', sources=source_paths, extra_cflags=opts,\n        build_directory=build_dir,\n        extra_cuda_cflags=opts, extra_ldflags=ldflags, with_cuda=True, verbose=True)\n\n    # Import, cache, and return the compiled module.\n    import renderutils_plugin\n    _cached_plugin = renderutils_plugin\n    return _cached_plugin\n\n#----------------------------------------------------------------------------\n# Internal kernels, just used for testing functionality\n\nclass _fresnel_shlick_func(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, f0, f90, cosTheta):\n        out = _get_plugin().fresnel_shlick_fwd(f0, f90, cosTheta, False)\n        ctx.save_for_backward(f0, f90, cosTheta)\n        return out\n\n    @staticmethod\n    def backward(ctx, dout):\n        f0, f90, cosTheta = ctx.saved_variables\n        return _get_plugin().fresnel_shlick_bwd(f0, f90, cosTheta, dout) + (None,)\n\ndef _fresnel_shlick(f0, f90, cosTheta, use_python=False):\n    if use_python:\n        out = bsdf_fresnel_shlick(f0, f90, cosTheta)\n    else:\n        out = _fresnel_shlick_func.apply(f0, f90, cosTheta)\n\n    if torch.is_anomaly_enabled():\n        assert torch.all(torch.isfinite(out)), \"Output of _fresnel_shlick contains inf or NaN\"\n    return out\n\n\nclass _ndf_ggx_func(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, alphaSqr, cosTheta):\n        out = _get_plugin().ndf_ggx_fwd(alphaSqr, cosTheta, False)\n        ctx.save_for_backward(alphaSqr, cosTheta)\n        return out\n\n    @staticmethod\n    def backward(ctx, dout):\n        alphaSqr, cosTheta = ctx.saved_variables\n        return _get_plugin().ndf_ggx_bwd(alphaSqr, cosTheta, dout) + (None,)\n\ndef _ndf_ggx(alphaSqr, cosTheta, use_python=False):\n    if use_python:\n        out = bsdf_ndf_ggx(alphaSqr, cosTheta)\n    else:\n        out = _ndf_ggx_func.apply(alphaSqr, cosTheta)\n\n    if torch.is_anomaly_enabled():\n        assert torch.all(torch.isfinite(out)), \"Output of _ndf_ggx contains inf or NaN\"\n    return out\n\nclass _lambda_ggx_func(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, alphaSqr, cosTheta):\n        out = _get_plugin().lambda_ggx_fwd(alphaSqr, cosTheta, False)\n        ctx.save_for_backward(alphaSqr, cosTheta)\n        return out\n\n    @staticmethod\n    def backward(ctx, dout):\n        alphaSqr, cosTheta = ctx.saved_variables\n        return _get_plugin().lambda_ggx_bwd(alphaSqr, cosTheta, dout) + (None,)\n\ndef _lambda_ggx(alphaSqr, cosTheta, use_python=False):\n    if use_python:\n        out = bsdf_lambda_ggx(alphaSqr, cosTheta)\n    else:\n        out = _lambda_ggx_func.apply(alphaSqr, cosTheta)\n\n    if torch.is_anomaly_enabled():\n        assert torch.all(torch.isfinite(out)), \"Output of _lambda_ggx contains inf or NaN\"\n    return out\n\nclass _masking_smith_func(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, alphaSqr, cosThetaI, cosThetaO):\n        ctx.save_for_backward(alphaSqr, cosThetaI, cosThetaO)\n        out = _get_plugin().masking_smith_fwd(alphaSqr, cosThetaI, cosThetaO, False)\n        return out\n\n    @staticmethod\n    def backward(ctx, dout):\n        alphaSqr, cosThetaI, cosThetaO = ctx.saved_variables\n        return _get_plugin().masking_smith_bwd(alphaSqr, cosThetaI, cosThetaO, dout) + (None,)\n\ndef 
_masking_smith(alphaSqr, cosThetaI, cosThetaO, use_python=False):\n    if use_python:\n        out = bsdf_masking_smith_ggx_correlated(alphaSqr, cosThetaI, cosThetaO)\n    else:\n        out = _masking_smith_func.apply(alphaSqr, cosThetaI, cosThetaO)\n\n    if torch.is_anomaly_enabled():\n        assert torch.all(torch.isfinite(out)), \"Output of _masking_smith contains inf or NaN\"\n    return out\n\n#----------------------------------------------------------------------------\n# Shading normal setup (bump mapping + bent normals)\n\nclass _prepare_shading_normal_func(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl):\n        ctx.two_sided_shading, ctx.opengl = two_sided_shading, opengl\n        out = _get_plugin().prepare_shading_normal_fwd(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl, False)\n        ctx.save_for_backward(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm)\n        return out\n\n    @staticmethod\n    def backward(ctx, dout):\n        pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm = ctx.saved_variables\n        return _get_plugin().prepare_shading_normal_bwd(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, dout, ctx.two_sided_shading, ctx.opengl) + (None, None, None)\n\ndef prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading=True, opengl=True, use_python=False):\n    '''Takes care of all corner cases and produces a final normal used for shading:\n        - Constructs tangent space\n        - Flips normal direction based on geometric normal for two sided Shading\n        - Perturbs shading normal by normal map\n        - Bends backfacing normals towards the camera to avoid shading artifacts\n\n        All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent.\n\n    Args:\n        pos: World space g-buffer position.\n        view_pos: Camera position in world space (typically using broadcasting).\n        perturbed_nrm: Trangent-space normal perturbation from normal map lookup.\n        smooth_nrm: Interpolated vertex normals.\n        smooth_tng: Interpolated vertex tangents.\n        geom_nrm: Geometric (face) normals.\n        two_sided_shading: Use one/two sided shading\n        opengl: Use OpenGL/DirectX normal map conventions \n        use_python: Use PyTorch implementation (for validation)\n    Returns:\n        Final shading normal\n    '''    \n\n    if perturbed_nrm is None:\n        perturbed_nrm = torch.tensor([0, 0, 1], dtype=torch.float32, device='cuda', requires_grad=False)[None, None, None, ...]\n    \n    if use_python:\n        out = bsdf_prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl)\n    else:\n        out = _prepare_shading_normal_func.apply(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl)\n    \n    if torch.is_anomaly_enabled():\n        assert torch.all(torch.isfinite(out)), \"Output of prepare_shading_normal contains inf or NaN\"\n    return out\n\n#----------------------------------------------------------------------------\n# BSDF functions\n\nclass _lambert_func(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, nrm, wi):\n        out = _get_plugin().lambert_fwd(nrm, wi, False)\n        ctx.save_for_backward(nrm, wi)\n        return out\n\n   
 @staticmethod\n    def backward(ctx, dout):\n        nrm, wi = ctx.saved_variables\n        return _get_plugin().lambert_bwd(nrm, wi, dout) + (None,)\n\ndef lambert(nrm, wi, use_python=False):\n    '''Lambertian bsdf. \n    All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent.\n\n    Args:\n        nrm: World space shading normal.\n        wi: World space light vector.\n        use_python: Use PyTorch implementation (for validation)\n\n    Returns:\n        Shaded diffuse value with shape [minibatch_size, height, width, 1]\n    '''\n\n    if use_python:\n        out = bsdf_lambert(nrm, wi)\n    else:\n        out = _lambert_func.apply(nrm, wi)\n \n    if torch.is_anomaly_enabled():\n        assert torch.all(torch.isfinite(out)), \"Output of lambert contains inf or NaN\"\n    return out\n\nclass _frostbite_diffuse_func(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, nrm, wi, wo, linearRoughness):\n        out = _get_plugin().frostbite_fwd(nrm, wi, wo, linearRoughness, False)\n        ctx.save_for_backward(nrm, wi, wo, linearRoughness)\n        return out\n\n    @staticmethod\n    def backward(ctx, dout):\n        nrm, wi, wo, linearRoughness = ctx.saved_variables\n        return _get_plugin().frostbite_bwd(nrm, wi, wo, linearRoughness, dout) + (None,)\n\ndef frostbite_diffuse(nrm, wi, wo, linearRoughness, use_python=False):\n    '''Frostbite, normalized Disney Diffuse bsdf. \n    All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent.\n\n    Args:\n        nrm: World space shading normal.\n        wi: World space light vector.\n        wo: World space camera vector.\n        linearRoughness: Material roughness\n        use_python: Use PyTorch implementation (for validation)\n\n    Returns:\n        Shaded diffuse value with shape [minibatch_size, height, width, 1]\n    '''\n\n    if use_python:\n        out = bsdf_frostbite(nrm, wi, wo, linearRoughness)\n    else:\n        out = _frostbite_diffuse_func.apply(nrm, wi, wo, linearRoughness)\n \n    if torch.is_anomaly_enabled():\n        assert torch.all(torch.isfinite(out)), \"Output of lambert contains inf or NaN\"\n    return out\n\nclass _pbr_specular_func(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, col, nrm, wo, wi, alpha, min_roughness):\n        ctx.save_for_backward(col, nrm, wo, wi, alpha)\n        ctx.min_roughness = min_roughness\n        out = _get_plugin().pbr_specular_fwd(col, nrm, wo, wi, alpha, min_roughness, False)\n        return out\n\n    @staticmethod\n    def backward(ctx, dout):\n        col, nrm, wo, wi, alpha = ctx.saved_variables\n        return _get_plugin().pbr_specular_bwd(col, nrm, wo, wi, alpha, ctx.min_roughness, dout) + (None, None)\n\ndef pbr_specular(col, nrm, wo, wi, alpha, min_roughness=0.08, use_python=False):\n    '''Physically-based specular bsdf.\n    All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent unless otherwise noted.\n\n    Args:\n        col: Specular lobe color\n        nrm: World space shading normal.\n        wo: World space camera vector.\n        wi: World space light vector\n        alpha: Specular roughness parameter with shape [minibatch_size, height, width, 1]\n        min_roughness: Scalar roughness clamping threshold\n\n        use_python: Use PyTorch implementation (for validation)\n    Returns:\n        Shaded specular color\n    '''\n\n    if use_python:\n        out = bsdf_pbr_specular(col, nrm, wo, wi, alpha, 
min_roughness=min_roughness)\n    else:\n        out = _pbr_specular_func.apply(col, nrm, wo, wi, alpha, min_roughness)\n    \n    if torch.is_anomaly_enabled():\n        assert torch.all(torch.isfinite(out)), \"Output of pbr_specular contains inf or NaN\"\n    return out\n\nclass _pbr_bsdf_func(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF):\n        ctx.save_for_backward(kd, arm, pos, nrm, view_pos, light_pos)\n        ctx.min_roughness = min_roughness\n        ctx.BSDF = BSDF\n        out = _get_plugin().pbr_bsdf_fwd(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF, False)\n        return out\n\n    @staticmethod\n    def backward(ctx, dout):\n        kd, arm, pos, nrm, view_pos, light_pos = ctx.saved_variables\n        return _get_plugin().pbr_bsdf_bwd(kd, arm, pos, nrm, view_pos, light_pos, ctx.min_roughness, ctx.BSDF, dout) + (None, None, None)\n\ndef pbr_bsdf(kd, arm, pos, nrm, view_pos, light_pos, min_roughness=0.08, bsdf=\"lambert\", use_python=False):\n    '''Physically-based bsdf, both diffuse & specular lobes\n    All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent unless otherwise noted.\n\n    Args:\n        kd: Diffuse albedo.\n        arm: Specular parameters (attenuation, linear roughness, metalness).\n        pos: World space position.\n        nrm: World space shading normal.\n        view_pos: Camera position in world space, typically using broadcasting.\n        light_pos: Light position in world space, typically using broadcasting.\n        min_roughness: Scalar roughness clamping threshold\n        bsdf: Controls diffuse BSDF, can be either 'lambert' or 'frostbite'\n\n        use_python: Use PyTorch implementation (for validation)\n\n    Returns:\n        Shaded color.\n    '''    \n\n    BSDF = 0 \n    if bsdf == 'frostbite':\n        BSDF = 1\n\n    if use_python:\n        out = bsdf_pbr(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF)\n    else:\n        out = _pbr_bsdf_func.apply(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF)\n    \n    if torch.is_anomaly_enabled():\n        assert torch.all(torch.isfinite(out)), \"Output of pbr_bsdf contains inf or NaN\"\n    return out\n\n#----------------------------------------------------------------------------\n# cubemap filter with filtering across edges\n\nclass _diffuse_cubemap_func(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, cubemap):\n        out = _get_plugin().diffuse_cubemap_fwd(cubemap)\n        ctx.save_for_backward(cubemap)\n        return out\n\n    @staticmethod\n    def backward(ctx, dout):\n        cubemap, = ctx.saved_variables\n        cubemap_grad = _get_plugin().diffuse_cubemap_bwd(cubemap, dout)\n        return cubemap_grad, None\n\ndef diffuse_cubemap(cubemap, use_python=False):\n    if use_python:\n        assert False\n    else:\n        out = _diffuse_cubemap_func.apply(cubemap)\n    if torch.is_anomaly_enabled():\n        assert torch.all(torch.isfinite(out)), \"Output of diffuse_cubemap contains inf or NaN\"\n    return out\n\nclass _specular_cubemap(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, cubemap, roughness, costheta_cutoff, bounds):\n        out = _get_plugin().specular_cubemap_fwd(cubemap, bounds, roughness, costheta_cutoff)\n        ctx.save_for_backward(cubemap, bounds)\n        ctx.roughness, ctx.theta_cutoff = roughness, costheta_cutoff\n        return out\n\n    @staticmethod\n    
def backward(ctx, dout):\n        cubemap, bounds = ctx.saved_variables\n        cubemap_grad = _get_plugin().specular_cubemap_bwd(cubemap, bounds, dout, ctx.roughness, ctx.theta_cutoff)\n        return cubemap_grad, None, None, None\n\n# Compute the bounds of the GGX NDF lobe to retain \"cutoff\" percent of the energy\ndef __ndfBounds(res, roughness, cutoff):\n    def ndfGGX(alphaSqr, costheta):\n        costheta = np.clip(costheta, 0.0, 1.0)\n        d = (costheta * alphaSqr - costheta) * costheta + 1.0\n        return alphaSqr / (d * d * np.pi)\n\n    # Sample out cutoff angle\n    nSamples = 1000000\n    costheta = np.cos(np.linspace(0, np.pi/2.0, nSamples))\n    D = np.cumsum(ndfGGX(roughness**4, costheta))\n    idx = np.argmax(D >= D[..., -1] * cutoff)\n\n    # Brute force compute lookup table with bounds\n    bounds = _get_plugin().specular_bounds(res, costheta[idx])\n\n    return costheta[idx], bounds\n__ndfBoundsDict = {}\n\ndef specular_cubemap(cubemap, roughness, cutoff=0.99, use_python=False):\n    assert cubemap.shape[0] == 6 and cubemap.shape[1] == cubemap.shape[2], \"Bad shape for cubemap tensor: %s\" % str(cubemap.shape)\n\n    if use_python:\n        assert False\n    else:\n        key = (cubemap.shape[1], roughness, cutoff)\n        if key not in __ndfBoundsDict:\n            __ndfBoundsDict[key] = __ndfBounds(*key)\n        out = _specular_cubemap.apply(cubemap, roughness, *__ndfBoundsDict[key])\n    if torch.is_anomaly_enabled():\n        assert torch.all(torch.isfinite(out)), \"Output of specular_cubemap contains inf or NaN\"\n    return out[..., 0:3] / out[..., 3:]\n\n#----------------------------------------------------------------------------\n# Fast image loss function\n\nclass _image_loss_func(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, img, target, loss, tonemapper):\n        ctx.loss, ctx.tonemapper = loss, tonemapper\n        ctx.save_for_backward(img, target)\n        out = _get_plugin().image_loss_fwd(img, target, loss, tonemapper, False)\n        return out\n\n    @staticmethod\n    def backward(ctx, dout):\n        img, target = ctx.saved_variables\n        return _get_plugin().image_loss_bwd(img, target, dout, ctx.loss, ctx.tonemapper) + (None, None, None)\n\ndef image_loss(img, target, loss='l1', tonemapper='none', use_python=False):\n    '''Compute HDR image loss. Combines tonemapping and loss into a single kernel for better perf.\n    All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent unless otherwise noted.\n\n    Args:\n        img: Input image.\n        target: Target (reference) image. \n        loss: Type of loss. Valid options are ['l1', 'mse', 'smape', 'relmse']\n        tonemapper: Tonemapping operations. 
Valid options are ['none', 'log_srgb']\n        use_python: Use PyTorch implementation (for validation)\n\n    Returns:\n        Image space loss (scalar value).\n    '''\n    if use_python:\n        out = image_loss_fn(img, target, loss, tonemapper)\n    else:\n        out = _image_loss_func.apply(img, target, loss, tonemapper)\n        out = torch.sum(out) / (img.shape[0]*img.shape[1]*img.shape[2])\n\n    if torch.is_anomaly_enabled():\n        assert torch.all(torch.isfinite(out)), \"Output of image_loss contains inf or NaN\"\n    return out\n\n#----------------------------------------------------------------------------\n# Transform points function\n\nclass _xfm_func(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, points, matrix, isPoints):\n        ctx.save_for_backward(points, matrix)\n        ctx.isPoints = isPoints\n        return _get_plugin().xfm_fwd(points, matrix, isPoints, False)\n\n    @staticmethod\n    def backward(ctx, dout):\n        points, matrix = ctx.saved_variables\n        return (_get_plugin().xfm_bwd(points, matrix, dout, ctx.isPoints),) + (None, None, None)\n\ndef xfm_points(points, matrix, use_python=False):\n    '''Transform points.\n    Args:\n        points: Tensor containing 3D points with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3]\n        matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4]\n        use_python: Use PyTorch's torch.matmul (for validation)\n    Returns:\n        Transformed points in homogeneous 4D with shape [minibatch_size, num_vertices, 4].\n    '''    \n    if use_python:\n        out = torch.matmul(torch.nn.functional.pad(points, pad=(0,1), mode='constant', value=1.0), torch.transpose(matrix, 1, 2))\n    else:\n        out = _xfm_func.apply(points, matrix, True)\n\n    if torch.is_anomaly_enabled():\n        assert torch.all(torch.isfinite(out)), \"Output of xfm_points contains inf or NaN\"\n    return out\n\ndef xfm_vectors(vectors, matrix, use_python=False):\n    '''Transform vectors.\n    Args:\n        vectors: Tensor containing 3D vectors with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3]\n        matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4]\n        use_python: Use PyTorch's torch.matmul (for validation)\n\n    Returns:\n        Transformed vectors in homogeneous 4D with shape [minibatch_size, num_vertices, 4].\n    '''    \n\n    if use_python:\n        out = torch.matmul(torch.nn.functional.pad(vectors, pad=(0,1), mode='constant', value=0.0), torch.transpose(matrix, 1, 2))[..., 0:3].contiguous()\n    else:\n        out = _xfm_func.apply(vectors, matrix, False)\n\n    if torch.is_anomaly_enabled():\n        assert torch.all(torch.isfinite(out)), \"Output of xfm_vectors contains inf or NaN\"\n    return out\n\n\n\n"
  },
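The docstrings above spell out the tensor shape conventions; the sketch below (illustrative values, set up the same way the scripts in render/renderutils/tests do) exercises the use_python reference paths of xfm_points and image_loss:

import sys
import torch

sys.path.insert(0, 'render')   # assumption: executed from the repo root, mirroring the tests
import renderutils as ru

# xfm_points: [minibatch, num_vertices, 3] points and [minibatch, 4, 4] matrices in,
# homogeneous [minibatch, num_vertices, 4] points out.
points = torch.rand(1, 128, 3, device='cuda')
mtx    = torch.eye(4, device='cuda')[None, ...]           # identity transform
clip_pos = ru.xfm_points(points, mtx, use_python=True)
assert clip_pos.shape == (1, 128, 4)

# image_loss: NHWC images in, scalar loss out.
img    = torch.rand(1, 32, 32, 3, device='cuda')
target = torch.rand(1, 32, 32, 3, device='cuda')
print(float(ru.image_loss(img, target, loss='smape', tonemapper='none', use_python=True)))

Passing use_python=True keeps everything in plain PyTorch, so this sketch does not require the compiled renderutils_plugin extension.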
  {
    "path": "render/renderutils/tests/test_bsdf.py",
    "content": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n# property and proprietary rights in and to this material, related\n# documentation and any modifications thereto. Any use, reproduction, \n# disclosure or distribution of this material and related documentation \n# without an express license agreement from NVIDIA CORPORATION or \n# its affiliates is strictly prohibited.\n\nimport torch\n\nimport os\nimport sys\nsys.path.insert(0, os.path.join(sys.path[0], '../..'))\nimport renderutils as ru\n\nRES = 4\nDTYPE = torch.float32\n\ndef relative_loss(name, ref, cuda):\n\tref = ref.float()\n\tcuda = cuda.float()\n\tprint(name, torch.max(torch.abs(ref - cuda) / torch.abs(ref + 1e-7)).item())\n\ndef test_normal():\n\tpos_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)\n\tpos_ref = pos_cuda.clone().detach().requires_grad_(True)\n\tview_pos_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)\n\tview_pos_ref = view_pos_cuda.clone().detach().requires_grad_(True)\n\tperturbed_nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)\n\tperturbed_nrm_ref = perturbed_nrm_cuda.clone().detach().requires_grad_(True)\n\tsmooth_nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)\n\tsmooth_nrm_ref = smooth_nrm_cuda.clone().detach().requires_grad_(True)\n\tsmooth_tng_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)\n\tsmooth_tng_ref = smooth_tng_cuda.clone().detach().requires_grad_(True)\n\tgeom_nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)\n\tgeom_nrm_ref = geom_nrm_cuda.clone().detach().requires_grad_(True)\n\ttarget = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda')\n\n\tref = ru.prepare_shading_normal(pos_ref, view_pos_ref, perturbed_nrm_ref, smooth_nrm_ref, smooth_tng_ref, geom_nrm_ref, True, use_python=True)\n\tref_loss = torch.nn.MSELoss()(ref, target)\n\tref_loss.backward()\n\n\tcuda = ru.prepare_shading_normal(pos_cuda, view_pos_cuda, perturbed_nrm_cuda, smooth_nrm_cuda, smooth_tng_cuda, geom_nrm_cuda, True)\n\tcuda_loss = torch.nn.MSELoss()(cuda, target)\n\tcuda_loss.backward()\n\n\tprint(\"-------------------------------------------------------------\")\n\tprint(\"    bent normal\")\n\tprint(\"-------------------------------------------------------------\")\n\trelative_loss(\"res:\", ref, cuda)\n\trelative_loss(\"pos:\", pos_ref.grad, pos_cuda.grad)\n\trelative_loss(\"view_pos:\", view_pos_ref.grad, view_pos_cuda.grad)\n\trelative_loss(\"perturbed_nrm:\", perturbed_nrm_ref.grad, perturbed_nrm_cuda.grad)\n\trelative_loss(\"smooth_nrm:\", smooth_nrm_ref.grad, smooth_nrm_cuda.grad)\n\trelative_loss(\"smooth_tng:\", smooth_tng_ref.grad, smooth_tng_cuda.grad)\n\trelative_loss(\"geom_nrm:\", geom_nrm_ref.grad, geom_nrm_cuda.grad)\n\ndef test_schlick():\n\tf0_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)\n\tf0_ref = f0_cuda.clone().detach().requires_grad_(True)\n\tf90_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)\n\tf90_ref = f90_cuda.clone().detach().requires_grad_(True)\n\tcosT_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) * 2.0\n\tcosT_cuda = cosT_cuda.clone().detach().requires_grad_(True)\n\tcosT_ref = cosT_cuda.clone().detach().requires_grad_(True)\n\ttarget = 
torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda')\n\n\tref = ru._fresnel_shlick(f0_ref, f90_ref, cosT_ref, use_python=True)\n\tref_loss = torch.nn.MSELoss()(ref, target)\n\tref_loss.backward()\n\n\tcuda = ru._fresnel_shlick(f0_cuda, f90_cuda, cosT_cuda)\n\tcuda_loss = torch.nn.MSELoss()(cuda, target)\n\tcuda_loss.backward()\n\n\tprint(\"-------------------------------------------------------------\")\n\tprint(\"    Fresnel shlick\")\n\tprint(\"-------------------------------------------------------------\")\n\trelative_loss(\"res:\", ref, cuda)\n\trelative_loss(\"f0:\", f0_ref.grad, f0_cuda.grad)\n\trelative_loss(\"f90:\", f90_ref.grad, f90_cuda.grad)\n\trelative_loss(\"cosT:\", cosT_ref.grad, cosT_cuda.grad)\n\ndef test_ndf_ggx():\n\talphaSqr_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)\n\talphaSqr_cuda = alphaSqr_cuda.clone().detach().requires_grad_(True)\n\talphaSqr_ref = alphaSqr_cuda.clone().detach().requires_grad_(True)\n\tcosT_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) * 3.0 - 1\n\tcosT_cuda = cosT_cuda.clone().detach().requires_grad_(True)\n\tcosT_ref = cosT_cuda.clone().detach().requires_grad_(True)\n\ttarget = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda')\n\n\tref = ru._ndf_ggx(alphaSqr_ref, cosT_ref, use_python=True)\n\tref_loss = torch.nn.MSELoss()(ref, target)\n\tref_loss.backward()\n\n\tcuda = ru._ndf_ggx(alphaSqr_cuda, cosT_cuda)\n\tcuda_loss = torch.nn.MSELoss()(cuda, target)\n\tcuda_loss.backward()\n\n\tprint(\"-------------------------------------------------------------\")\n\tprint(\"    Ndf GGX\")\n\tprint(\"-------------------------------------------------------------\")\n\trelative_loss(\"res:\", ref, cuda)\n\trelative_loss(\"alpha:\", alphaSqr_ref.grad, alphaSqr_cuda.grad)\n\trelative_loss(\"cosT:\", cosT_ref.grad, cosT_cuda.grad)\n\ndef test_lambda_ggx():\n\talphaSqr_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)\n\talphaSqr_ref = alphaSqr_cuda.clone().detach().requires_grad_(True)\n\tcosT_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) * 3.0 - 1\n\tcosT_cuda = cosT_cuda.clone().detach().requires_grad_(True)\n\tcosT_ref = cosT_cuda.clone().detach().requires_grad_(True)\n\ttarget = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda')\n\n\tref = ru._lambda_ggx(alphaSqr_ref, cosT_ref, use_python=True)\n\tref_loss = torch.nn.MSELoss()(ref, target)\n\tref_loss.backward()\n\n\tcuda = ru._lambda_ggx(alphaSqr_cuda, cosT_cuda)\n\tcuda_loss = torch.nn.MSELoss()(cuda, target)\n\tcuda_loss.backward()\n\n\tprint(\"-------------------------------------------------------------\")\n\tprint(\"    Lambda GGX\")\n\tprint(\"-------------------------------------------------------------\")\n\trelative_loss(\"res:\", ref, cuda)\n\trelative_loss(\"alpha:\", alphaSqr_ref.grad, alphaSqr_cuda.grad)\n\trelative_loss(\"cosT:\", cosT_ref.grad, cosT_cuda.grad)\n\ndef test_masking_smith():\n\talphaSqr_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)\n\talphaSqr_ref = alphaSqr_cuda.clone().detach().requires_grad_(True)\n\tcosThetaI_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)\n\tcosThetaI_ref = cosThetaI_cuda.clone().detach().requires_grad_(True)\n\tcosThetaO_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)\n\tcosThetaO_ref = cosThetaO_cuda.clone().detach().requires_grad_(True)\n\ttarget = torch.rand(1, RES, RES, 1, dtype=DTYPE, 
device='cuda')\n\n\tref = ru._masking_smith(alphaSqr_ref, cosThetaI_ref, cosThetaO_ref, use_python=True)\n\tref_loss = torch.nn.MSELoss()(ref, target)\n\tref_loss.backward()\n\n\tcuda = ru._masking_smith(alphaSqr_cuda, cosThetaI_cuda, cosThetaO_cuda)\n\tcuda_loss = torch.nn.MSELoss()(cuda, target)\n\tcuda_loss.backward()\n\n\tprint(\"-------------------------------------------------------------\")\n\tprint(\"    Smith masking term\")\n\tprint(\"-------------------------------------------------------------\")\n\trelative_loss(\"res:\", ref, cuda)\n\trelative_loss(\"alpha:\", alphaSqr_ref.grad, alphaSqr_cuda.grad)\n\trelative_loss(\"cosThetaI:\", cosThetaI_ref.grad, cosThetaI_cuda.grad)\n\trelative_loss(\"cosThetaO:\", cosThetaO_ref.grad, cosThetaO_cuda.grad)\n\ndef test_lambert():\n\tnormals_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)\n\tnormals_ref = normals_cuda.clone().detach().requires_grad_(True)\n\twi_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)\n\twi_ref = wi_cuda.clone().detach().requires_grad_(True)\n\ttarget = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda')\n\n\tref = ru.lambert(normals_ref, wi_ref, use_python=True)\n\tref_loss = torch.nn.MSELoss()(ref, target)\n\tref_loss.backward()\n\n\tcuda = ru.lambert(normals_cuda, wi_cuda)\n\tcuda_loss = torch.nn.MSELoss()(cuda, target)\n\tcuda_loss.backward()\n\n\tprint(\"-------------------------------------------------------------\")\n\tprint(\"    Lambert\")\n\tprint(\"-------------------------------------------------------------\")\n\trelative_loss(\"res:\", ref, cuda)\n\trelative_loss(\"nrm:\", normals_ref.grad, normals_cuda.grad)\n\trelative_loss(\"wi:\", wi_ref.grad, wi_cuda.grad)\n\ndef test_frostbite():\n\tnormals_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)\n\tnormals_ref = normals_cuda.clone().detach().requires_grad_(True)\n\twi_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)\n\twi_ref = wi_cuda.clone().detach().requires_grad_(True)\n\two_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)\n\two_ref = wo_cuda.clone().detach().requires_grad_(True)\n\trough_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)\n\trough_ref = rough_cuda.clone().detach().requires_grad_(True)\n\ttarget = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda')\n\n\tref = ru.frostbite_diffuse(normals_ref, wi_ref, wo_ref, rough_ref, use_python=True)\n\tref_loss = torch.nn.MSELoss()(ref, target)\n\tref_loss.backward()\n\n\tcuda = ru.frostbite_diffuse(normals_cuda, wi_cuda, wo_cuda, rough_cuda)\n\tcuda_loss = torch.nn.MSELoss()(cuda, target)\n\tcuda_loss.backward()\n\n\tprint(\"-------------------------------------------------------------\")\n\tprint(\"    Frostbite\")\n\tprint(\"-------------------------------------------------------------\")\n\trelative_loss(\"res:\", ref, cuda)\n\trelative_loss(\"nrm:\", normals_ref.grad, normals_cuda.grad)\n\trelative_loss(\"wo:\", wo_ref.grad, wo_cuda.grad)\n\trelative_loss(\"wi:\", wi_ref.grad, wi_cuda.grad)\n\trelative_loss(\"rough:\", rough_ref.grad, rough_cuda.grad)\n\ndef test_pbr_specular():\n\tcol_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)\n\tcol_ref = col_cuda.clone().detach().requires_grad_(True)\n\tnrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)\n\tnrm_ref = nrm_cuda.clone().detach().requires_grad_(True)\n\twi_cuda = 
torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)\n\twi_ref = wi_cuda.clone().detach().requires_grad_(True)\n\two_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)\n\two_ref = wo_cuda.clone().detach().requires_grad_(True)\n\talpha_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)\n\talpha_ref = alpha_cuda.clone().detach().requires_grad_(True)\n\ttarget = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda')\n\n\tref = ru.pbr_specular(col_ref, nrm_ref, wo_ref, wi_ref, alpha_ref, use_python=True)\n\tref_loss = torch.nn.MSELoss()(ref, target)\n\tref_loss.backward()\n\n\tcuda = ru.pbr_specular(col_cuda, nrm_cuda, wo_cuda, wi_cuda, alpha_cuda)\n\tcuda_loss = torch.nn.MSELoss()(cuda, target)\n\tcuda_loss.backward()\n\n\tprint(\"-------------------------------------------------------------\")\n\tprint(\"    Pbr specular\")\n\tprint(\"-------------------------------------------------------------\")\n\n\trelative_loss(\"res:\", ref, cuda)\n\tif col_ref.grad is not None:\n\t\trelative_loss(\"col:\", col_ref.grad, col_cuda.grad)\n\tif nrm_ref.grad is not None:\n\t\trelative_loss(\"nrm:\", nrm_ref.grad, nrm_cuda.grad)\n\tif wi_ref.grad is not None:\n\t\trelative_loss(\"wi:\", wi_ref.grad, wi_cuda.grad)\n\tif wo_ref.grad is not None:\n\t\trelative_loss(\"wo:\", wo_ref.grad, wo_cuda.grad)\n\tif alpha_ref.grad is not None:\n\t\trelative_loss(\"alpha:\", alpha_ref.grad, alpha_cuda.grad)\n\ndef test_pbr_bsdf(bsdf):\n\tkd_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)\n\tkd_ref = kd_cuda.clone().detach().requires_grad_(True)\n\tarm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)\n\tarm_ref = arm_cuda.clone().detach().requires_grad_(True)\n\tpos_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)\n\tpos_ref = pos_cuda.clone().detach().requires_grad_(True)\n\tnrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)\n\tnrm_ref = nrm_cuda.clone().detach().requires_grad_(True)\n\tview_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)\n\tview_ref = view_cuda.clone().detach().requires_grad_(True)\n\tlight_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)\n\tlight_ref = light_cuda.clone().detach().requires_grad_(True)\n\ttarget = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda')\n\n\tref = ru.pbr_bsdf(kd_ref, arm_ref, pos_ref, nrm_ref, view_ref, light_ref, use_python=True, bsdf=bsdf)\n\tref_loss = torch.nn.MSELoss()(ref, target)\n\tref_loss.backward()\n\n\tcuda = ru.pbr_bsdf(kd_cuda, arm_cuda, pos_cuda, nrm_cuda, view_cuda, light_cuda, bsdf=bsdf)\n\tcuda_loss = torch.nn.MSELoss()(cuda, target)\n\tcuda_loss.backward()\n\n\tprint(\"-------------------------------------------------------------\")\n\tprint(\"    Pbr BSDF\")\n\tprint(\"-------------------------------------------------------------\")\n\n\trelative_loss(\"res:\", ref, cuda)\n\tif kd_ref.grad is not None:\n\t\trelative_loss(\"kd:\", kd_ref.grad, kd_cuda.grad)\n\tif arm_ref.grad is not None:\n\t\trelative_loss(\"arm:\", arm_ref.grad, arm_cuda.grad)\n\tif pos_ref.grad is not None:\n\t\trelative_loss(\"pos:\", pos_ref.grad, pos_cuda.grad)\n\tif nrm_ref.grad is not None:\n\t\trelative_loss(\"nrm:\", nrm_ref.grad, nrm_cuda.grad)\n\tif view_ref.grad is not None:\n\t\trelative_loss(\"view:\", view_ref.grad, view_cuda.grad)\n\tif light_ref.grad is not 
None:\n\t\trelative_loss(\"light:\", light_ref.grad, light_cuda.grad)\n\ntest_normal()\n\ntest_schlick()\ntest_ndf_ggx()\ntest_lambda_ggx()\ntest_masking_smith()\n\ntest_lambert()\ntest_frostbite()\ntest_pbr_specular()\ntest_pbr_bsdf('lambert')\ntest_pbr_bsdf('frostbite')\n"
  },
  {
    "path": "render/renderutils/tests/test_loss.py",
    "content": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n# property and proprietary rights in and to this material, related\n# documentation and any modifications thereto. Any use, reproduction, \n# disclosure or distribution of this material and related documentation \n# without an express license agreement from NVIDIA CORPORATION or \n# its affiliates is strictly prohibited.\n\nimport torch\n\nimport os\nimport sys\nsys.path.insert(0, os.path.join(sys.path[0], '../..'))\nimport renderutils as ru\n\nRES = 8\nDTYPE = torch.float32\n\ndef tonemap_srgb(f):\n    return torch.where(f > 0.0031308, torch.pow(torch.clamp(f, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*f)\n\ndef l1(output, target):\n    x = torch.clamp(output, min=0, max=65535)\n    r = torch.clamp(target, min=0, max=65535)\n    x = tonemap_srgb(torch.log(x + 1))\n    r = tonemap_srgb(torch.log(r + 1))\n    return torch.nn.functional.l1_loss(x,r)\n\ndef relative_loss(name, ref, cuda):\n\tref = ref.float()\n\tcuda = cuda.float()\n\tprint(name, torch.max(torch.abs(ref - cuda) / torch.abs(ref + 1e-7)).item())\n\ndef test_loss(loss, tonemapper):\n\timg_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)\n\timg_ref = img_cuda.clone().detach().requires_grad_(True)\n\ttarget_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)\n\ttarget_ref = target_cuda.clone().detach().requires_grad_(True)\n\n\tref_loss = ru.image_loss(img_ref, target_ref, loss=loss, tonemapper=tonemapper, use_python=True)\n\tref_loss.backward()\n\n\tcuda_loss = ru.image_loss(img_cuda, target_cuda, loss=loss, tonemapper=tonemapper)\n\tcuda_loss.backward()\n\n\tprint(\"-------------------------------------------------------------\")\n\tprint(\"    Loss: %s, %s\" % (loss, tonemapper))\n\tprint(\"-------------------------------------------------------------\")\n\n\trelative_loss(\"res:\", ref_loss, cuda_loss)\n\trelative_loss(\"img:\", img_ref.grad, img_cuda.grad)\n\trelative_loss(\"target:\", target_ref.grad, target_cuda.grad)\n\n\ntest_loss('l1', 'none')\ntest_loss('l1', 'log_srgb')\ntest_loss('mse', 'log_srgb')\ntest_loss('smape', 'none')\ntest_loss('relmse', 'none')\ntest_loss('mse', 'none')"
  },
  {
    "path": "render/renderutils/tests/test_mesh.py",
    "content": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n# property and proprietary rights in and to this material, related\n# documentation and any modifications thereto. Any use, reproduction, \n# disclosure or distribution of this material and related documentation \n# without an express license agreement from NVIDIA CORPORATION or \n# its affiliates is strictly prohibited.\n\nimport torch\n\nimport os\nimport sys\nsys.path.insert(0, os.path.join(sys.path[0], '../..'))\nimport renderutils as ru\n\nBATCH = 8\nRES = 1024\nDTYPE = torch.float32\n\ntorch.manual_seed(0)\n\ndef tonemap_srgb(f):\n    return torch.where(f > 0.0031308, torch.pow(torch.clamp(f, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*f)\n\ndef l1(output, target):\n    x = torch.clamp(output, min=0, max=65535)\n    r = torch.clamp(target, min=0, max=65535)\n    x = tonemap_srgb(torch.log(x + 1))\n    r = tonemap_srgb(torch.log(r + 1))\n    return torch.nn.functional.l1_loss(x,r)\n\ndef relative_loss(name, ref, cuda):\n\tref = ref.float()\n\tcuda = cuda.float()\n\tprint(name, torch.max(torch.abs(ref - cuda) / torch.abs(ref)).item())\n\ndef test_xfm_points():\n\tpoints_cuda = torch.rand(1, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)\n\tpoints_ref = points_cuda.clone().detach().requires_grad_(True)\n\tmtx_cuda = torch.rand(BATCH, 4, 4, dtype=DTYPE, device='cuda', requires_grad=False)\n\tmtx_ref = mtx_cuda.clone().detach().requires_grad_(True)\n\ttarget = torch.rand(BATCH, RES, 4, dtype=DTYPE, device='cuda', requires_grad=True)\n\n\tref_out = ru.xfm_points(points_ref, mtx_ref, use_python=True)\n\tref_loss = torch.nn.MSELoss()(ref_out, target)\n\tref_loss.backward()\n\n\tcuda_out = ru.xfm_points(points_cuda, mtx_cuda)\n\tcuda_loss = torch.nn.MSELoss()(cuda_out, target)\n\tcuda_loss.backward()\n\n\tprint(\"-------------------------------------------------------------\")\n\n\trelative_loss(\"res:\", ref_out, cuda_out)\n\trelative_loss(\"points:\", points_ref.grad, points_cuda.grad)\n\ndef test_xfm_vectors():\n\tpoints_cuda = torch.rand(1, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)\n\tpoints_ref = points_cuda.clone().detach().requires_grad_(True)\n\tpoints_cuda_p = points_cuda.clone().detach().requires_grad_(True)\n\tpoints_ref_p = points_cuda.clone().detach().requires_grad_(True)\n\tmtx_cuda = torch.rand(BATCH, 4, 4, dtype=DTYPE, device='cuda', requires_grad=False)\n\tmtx_ref = mtx_cuda.clone().detach().requires_grad_(True)\n\ttarget = torch.rand(BATCH, RES, 4, dtype=DTYPE, device='cuda', requires_grad=True)\n\n\tref_out = ru.xfm_vectors(points_ref.contiguous(), mtx_ref, use_python=True)\n\tref_loss = torch.nn.MSELoss()(ref_out, target[..., 0:3])\n\tref_loss.backward()\n\n\tcuda_out = ru.xfm_vectors(points_cuda.contiguous(), mtx_cuda)\n\tcuda_loss = torch.nn.MSELoss()(cuda_out, target[..., 0:3])\n\tcuda_loss.backward()\n\n\tref_out_p = ru.xfm_points(points_ref_p.contiguous(), mtx_ref, use_python=True)\n\tref_loss_p = torch.nn.MSELoss()(ref_out_p, target)\n\tref_loss_p.backward()\n\t\n\tcuda_out_p = ru.xfm_points(points_cuda_p.contiguous(), mtx_cuda)\n\tcuda_loss_p = torch.nn.MSELoss()(cuda_out_p, target)\n\tcuda_loss_p.backward()\n\n\tprint(\"-------------------------------------------------------------\")\n\n\trelative_loss(\"res:\", ref_out, cuda_out)\n\trelative_loss(\"points:\", points_ref.grad, points_cuda.grad)\n\trelative_loss(\"points_p:\", points_ref_p.grad, 
points_cuda_p.grad)\n\ntest_xfm_points()\ntest_xfm_vectors()\n"
  },
  {
    "path": "render/renderutils/tests/test_perf.py",
    "content": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n# property and proprietary rights in and to this material, related\n# documentation and any modifications thereto. Any use, reproduction, \n# disclosure or distribution of this material and related documentation \n# without an express license agreement from NVIDIA CORPORATION or \n# its affiliates is strictly prohibited.\n\nimport torch\n\nimport os\nimport sys\nsys.path.insert(0, os.path.join(sys.path[0], '../..'))\nimport renderutils as ru\n\nDTYPE=torch.float32\n\ndef test_bsdf(BATCH, RES, ITR):\n\tkd_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)\n\tkd_ref = kd_cuda.clone().detach().requires_grad_(True)\n\tarm_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)\n\tarm_ref = arm_cuda.clone().detach().requires_grad_(True)\n\tpos_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)\n\tpos_ref = pos_cuda.clone().detach().requires_grad_(True)\n\tnrm_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)\n\tnrm_ref = nrm_cuda.clone().detach().requires_grad_(True)\n\tview_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)\n\tview_ref = view_cuda.clone().detach().requires_grad_(True)\n\tlight_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)\n\tlight_ref = light_cuda.clone().detach().requires_grad_(True)\n\ttarget = torch.rand(BATCH, RES, RES, 3, device='cuda')\n\n\tstart = torch.cuda.Event(enable_timing=True)\n\tend = torch.cuda.Event(enable_timing=True)\n\n\tru.pbr_bsdf(kd_cuda, arm_cuda, pos_cuda, nrm_cuda, view_cuda, light_cuda)\n\n\tprint(\"--- Testing: [%d, %d, %d] ---\" % (BATCH, RES, RES))\n\n\tstart.record()\n\tfor i in range(ITR):\n\t\tref = ru.pbr_bsdf(kd_ref, arm_ref, pos_ref, nrm_ref, view_ref, light_ref, use_python=True)\n\tend.record()\n\ttorch.cuda.synchronize()\n\tprint(\"Pbr BSDF python:\", start.elapsed_time(end))\n\n\tstart.record()\n\tfor i in range(ITR):\n\t\tcuda = ru.pbr_bsdf(kd_cuda, arm_cuda, pos_cuda, nrm_cuda, view_cuda, light_cuda)\n\tend.record()\n\ttorch.cuda.synchronize()\n\tprint(\"Pbr BSDF cuda:\", start.elapsed_time(end))\n\ntest_bsdf(1, 512, 1000)\ntest_bsdf(16, 512, 1000)\ntest_bsdf(1, 2048, 1000)\n"
  },
  {
    "path": "render/texture.py",
    "content": "# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto. Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\nimport os\nimport numpy as np\nimport torch\nimport nvdiffrast.torch as dr\n\nfrom . import util\n\n######################################################################################\n# Smooth pooling / mip computation with linear gradient upscaling\n######################################################################################\n\nclass texture2d_mip(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, texture):\n        return util.avg_pool_nhwc(texture, (2,2))\n\n    @staticmethod\n    def backward(ctx, dout):\n        gy, gx = torch.meshgrid(torch.linspace(0.0 + 0.25 / dout.shape[1], 1.0 - 0.25 / dout.shape[1], dout.shape[1]*2, device=\"cuda\"), \n                                torch.linspace(0.0 + 0.25 / dout.shape[2], 1.0 - 0.25 / dout.shape[2], dout.shape[2]*2, device=\"cuda\"))\n        uv = torch.stack((gx, gy), dim=-1)\n        return dr.texture(dout * 0.25, uv[None, ...].contiguous(), filter_mode='linear', boundary_mode='clamp')\n\n########################################################################################################\n# Simple texture class. A texture can be either \n# - A 3D tensor (using auto mipmaps)\n# - A list of 3D tensors (full custom mip hierarchy)\n########################################################################################################\n\nclass Texture2D:\n     # Initializes a texture from image data.\n     # Input can be constant value (1D array) or texture (3D array) or mip hierarchy (list of 3d arrays)\n    def __init__(self, init, min_max=None):\n        if isinstance(init, np.ndarray):\n            init = torch.tensor(init, dtype=torch.float32, device='cuda')\n        elif isinstance(init, list) and len(init) == 1:\n            init = init[0]\n\n        if isinstance(init, list) or len(init.shape) == 4:\n            self.data = init\n        elif len(init.shape) == 3:\n            self.data = init[None, ...]\n        else:\n            self.data = init[None, None, None, :] # Convert constant to 1x1 tensor\n\n        self.min_max = min_max\n\n    # Filtered (trilinear) sample texture at a given location\n    def sample(self, texc, texc_deriv, filter_mode='linear-mipmap-linear'):\n        if isinstance(self.data, list):\n            out = dr.texture(self.data[0], texc, texc_deriv, mip=self.data[1:], filter_mode=filter_mode)\n        else:\n            if self.data.shape[1] > 1 and self.data.shape[2] > 1:\n                mips = [self.data]\n                while mips[-1].shape[1] > 1 and mips[-1].shape[2] > 1:\n                    mips += [texture2d_mip.apply(mips[-1])]\n                out = dr.texture(mips[0], texc, texc_deriv, mip=mips[1:], filter_mode=filter_mode)\n            else:\n                out = dr.texture(self.data, texc, texc_deriv, filter_mode=filter_mode)\n        return out\n\n    def getRes(self):\n        return self.getMips()[0].shape[1:3]\n\n    def getChannels(self):\n        return self.getMips()[0].shape[3]\n\n    def getMips(self):\n        if isinstance(self.data, list):\n            return self.data\n        else:\n            return 
[self.data]\n\n    def parameters(self):\n        return self.getMips()\n\n    # In-place clamp with no derivative to make sure values are in valid range after training\n    def clamp_(self):\n        if self.min_max is not None:\n            for mip in self.getMips():\n                for i in range(mip.shape[-1]):\n                    mip[..., i].clamp_(min=self.min_max[0][i], max=self.min_max[1][i])\n\n    # In-place clamp with no derivative to make sure values are in valid range after training\n    def normalize_(self):\n        with torch.no_grad():\n            for mip in self.getMips():\n                mip.copy_(util.safe_normalize(mip))\n\n########################################################################################################\n# Helper function to create a trainable texture from a regular texture. The trainable weights are \n# initialized with texture data as an initial guess\n########################################################################################################\n\ndef create_trainable(init, res=None, auto_mipmaps=True, min_max=None):\n    with torch.no_grad():\n        if isinstance(init, Texture2D):\n            assert isinstance(init.data, torch.Tensor)\n            min_max = init.min_max if min_max is None else min_max\n            init = init.data\n        elif isinstance(init, np.ndarray):\n            init = torch.tensor(init, dtype=torch.float32, device='cuda')\n\n        # Pad to NHWC if needed\n        if len(init.shape) == 1: # Extend constant to NHWC tensor\n            init = init[None, None, None, :]\n        elif len(init.shape) == 3:\n            init = init[None, ...]\n\n        # Scale input to desired resolution.\n        if res is not None:\n            init = util.scale_img_nhwc(init, res)\n\n        # Generate custom mipchain\n        if not auto_mipmaps:\n            mip_chain = [init.clone().detach().requires_grad_(True)]\n            while mip_chain[-1].shape[1] > 1 or mip_chain[-1].shape[2] > 1:\n                new_size = [max(mip_chain[-1].shape[1] // 2, 1), max(mip_chain[-1].shape[2] // 2, 1)]\n                init = util.scale_img_nhwc(mip_chain[-1], new_size)\n                mip_chain += [init.clone().detach().requires_grad_(True)]\n            return Texture2D(mip_chain, min_max=min_max)\n        else:\n            return Texture2D(init.clone().detach().requires_grad_(True), min_max=min_max)\n\n########################################################################################################\n# Convert texture to and from SRGB\n########################################################################################################\n\ndef srgb_to_rgb(texture):\n    return Texture2D(list(util.srgb_to_rgb(mip) for mip in texture.getMips()))\n\ndef rgb_to_srgb(texture):\n    return Texture2D(list(util.rgb_to_srgb(mip) for mip in texture.getMips()))\n\n########################################################################################################\n# Utility functions for loading / storing a texture\n########################################################################################################\n\ndef _load_mip2D(fn, lambda_fn=None, channels=None):\n    imgdata = torch.tensor(util.load_image(fn), dtype=torch.float32, device='cuda')\n    if channels is not None:\n        imgdata = imgdata[..., 0:channels]\n    if lambda_fn is not None:\n        imgdata = lambda_fn(imgdata)\n    return imgdata.detach().clone()\n\ndef load_texture2D(fn, lambda_fn=None, channels=None):\n    base, ext = 
os.path.splitext(fn)\n    if os.path.exists(base + \"_0\" + ext):\n        mips = []\n        while os.path.exists(base + (\"_%d\" % len(mips)) + ext):\n            mips += [_load_mip2D(base + (\"_%d\" % len(mips)) + ext, lambda_fn, channels)]\n        return Texture2D(mips)\n    else:\n        return Texture2D(_load_mip2D(fn, lambda_fn, channels))\n\ndef _save_mip2D(fn, mip, mipidx, lambda_fn):\n    if lambda_fn is not None:\n        data = lambda_fn(mip).detach().cpu().numpy()\n    else:\n        data = mip.detach().cpu().numpy()\n\n    if mipidx is None:\n        util.save_image(fn, data)\n    else:\n        base, ext = os.path.splitext(fn)\n        util.save_image(base + (\"_%d\" % mipidx) + ext, data)\n\ndef save_texture2D(fn, tex, lambda_fn=None):\n    if isinstance(tex.data, list):\n        for i, mip in enumerate(tex.data):\n            _save_mip2D(fn, mip[0,...], i, lambda_fn)\n    else:\n        _save_mip2D(fn, tex.data[0,...], None, lambda_fn)"
  },
  {
    "path": "render/util.py",
    "content": "# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto. Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\nimport os\nimport numpy as np\nimport torch\nimport nvdiffrast.torch as dr\nimport imageio\n\n#----------------------------------------------------------------------------\n# Vector operations\n#----------------------------------------------------------------------------\n\ndef dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n    return torch.sum(x*y, -1, keepdim=True)\n\ndef reflect(x: torch.Tensor, n: torch.Tensor) -> torch.Tensor:\n    return 2*dot(x, n)*n - x\n\ndef length(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor:\n    return torch.sqrt(torch.clamp(dot(x,x), min=eps)) # Clamp to avoid nan gradients because grad(sqrt(0)) = NaN\n\ndef safe_normalize(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor:\n    return x / length(x, eps)\n\ndef to_hvec(x: torch.Tensor, w: float) -> torch.Tensor:\n    return torch.nn.functional.pad(x, pad=(0,1), mode='constant', value=w)\n\ndef ycocg2rgb(ycocg):\n    return torch.stack((\n        ycocg[..., 0] + ycocg[..., 1] - ycocg[..., 2],\n        ycocg[..., 0] + ycocg[..., 2],\n        ycocg[..., 0] - ycocg[..., 1] - ycocg[..., 2]\n        ), dim=-1)\n    \ndef hsv2rgb(image): # Based on https://kornia.readthedocs.io/en/latest/_modules/kornia/color/hsv.html\n    h, s, v = image[..., 0], image[..., 1], image[..., 2]\n\n    hi = torch.floor(h * 6) % 6\n    f  = ((h * 6) % 6) - hi\n    p = v * (1 - s)\n    q = v * (1 - f * s)\n    t = v * (1 - (1 - f) * s)\n\n    hi = hi.long()\n    indices = torch.stack([hi, hi + 6, hi + 12], dim=-1)\n    out = torch.stack((v, q, p, p, t, v, t, v, v, q, p, p, p, p, t, v, v, q), dim=-1)\n    out = torch.gather(out, -1, indices)\n\n    return out\n\n#----------------------------------------------------------------------------\n# Pixel grid from resolution\n#----------------------------------------------------------------------------\n\ndef pixel_grid(width, height, center_x = 0.5, center_y = 0.5):\n    y, x = torch.meshgrid(\n            (torch.arange(0, height, dtype=torch.float32, device=\"cuda\") + center_y) / height, \n            (torch.arange(0, width, dtype=torch.float32, device=\"cuda\") + center_x) / width)\n    return torch.stack((x, y), dim=-1)\n\n#----------------------------------------------------------------------------\n# Dilation filter\n#----------------------------------------------------------------------------\ndef dilate(x, x_avg, mask, N):\n    def _gaussian():\n        variance = (1.0 / 2.5)**2.\n\n        grid_y, grid_x = torch.meshgrid(torch.linspace(-1, 1, N, dtype=torch.float32, device=\"cuda\"), torch.linspace(-1, 1, N, dtype=torch.float32, device=\"cuda\"))\n        xy_grid = torch.stack([grid_x, grid_y], dim=-1)\n\n        gaussian_kernel = (.5*np.pi*variance) * torch.exp(-torch.sum(xy_grid**2., dim=-1) / (2*variance))\n        return gaussian_kernel / torch.sum(gaussian_kernel)\n\n    def _w(c, cN):\n        return torch.stack(list(_gaussian() if i == c else torch.zeros(N, N, dtype=torch.float32, device=\"cuda\") for i in range(cN)), dim=0)\n\n    epsilon = 1e-6\n    weights = torch.stack(list(_w(i, x.shape[3]) for i in 
range(x.shape[3])), dim=0)\n    mask_flt = torch.nn.functional.conv2d(mask.permute(0, 3, 1, 2), weights[0:1, 0:1, ...], padding=N//2).permute(0, 2, 3, 1)\n    x_flt = torch.nn.functional.conv2d((x * mask).permute(0, 3, 1, 2), weights, padding=N//2).permute(0, 2, 3, 1)\n    x_flt = torch.where(mask_flt > epsilon, x_flt / torch.clamp(mask_flt, min=epsilon), x_avg)\n    return x_flt * (1 - mask) + x * mask\n\n#----------------------------------------------------------------------------\n# sRGB color transforms\n#----------------------------------------------------------------------------\n\ndef _rgb_to_srgb(f: torch.Tensor) -> torch.Tensor:\n    return torch.where(f <= 0.0031308, f * 12.92, torch.pow(torch.clamp(f, 0.0031308), 1.0/2.4)*1.055 - 0.055)\n\ndef rgb_to_srgb(f: torch.Tensor) -> torch.Tensor:\n    assert f.shape[-1] == 3 or f.shape[-1] == 4\n    out = torch.cat((_rgb_to_srgb(f[..., 0:3]), f[..., 3:4]), dim=-1) if f.shape[-1] == 4 else _rgb_to_srgb(f)\n    assert out.shape[0] == f.shape[0] and out.shape[1] == f.shape[1] and out.shape[2] == f.shape[2]\n    return out\n\ndef _srgb_to_rgb(f: torch.Tensor) -> torch.Tensor:\n    return torch.where(f <= 0.04045, f / 12.92, torch.pow((torch.clamp(f, 0.04045) + 0.055) / 1.055, 2.4))\n\ndef srgb_to_rgb(f: torch.Tensor) -> torch.Tensor:\n    assert f.shape[-1] == 3 or f.shape[-1] == 4\n    out = torch.cat((_srgb_to_rgb(f[..., 0:3]), f[..., 3:4]), dim=-1) if f.shape[-1] == 4 else _srgb_to_rgb(f)\n    assert out.shape[0] == f.shape[0] and out.shape[1] == f.shape[1] and out.shape[2] == f.shape[2]\n    return out\n\ndef reinhard(f: torch.Tensor) -> torch.Tensor:\n    return f/(1+f)\n\n#-----------------------------------------------------------------------------------\n# Metrics (taken from jaxNerf source code, in order to replicate their measurements)\n#\n# https://github.com/google-research/google-research/blob/301451a62102b046bbeebff49a760ebeec9707b8/jaxnerf/nerf/utils.py#L266\n#\n#-----------------------------------------------------------------------------------\n\ndef mse_to_psnr(mse):\n  \"\"\"Compute PSNR given an MSE (we assume the maximum pixel value is 1).\"\"\"\n  return -10. / np.log(10.) * np.log(mse)\n\ndef psnr_to_mse(psnr):\n  \"\"\"Compute MSE given a PSNR (we assume the maximum pixel value is 1).\"\"\"\n  return np.exp(-0.1 * np.log(10.) * psnr)\n\n#----------------------------------------------------------------------------\n# Displacement texture lookup\n#----------------------------------------------------------------------------\n\ndef get_miplevels(texture: np.ndarray) -> float:\n    minDim = min(texture.shape[0], texture.shape[1])\n    return np.floor(np.log2(minDim))\n\ndef tex_2d(tex_map : torch.Tensor, coords : torch.Tensor, filter='nearest') -> torch.Tensor:\n    tex_map = tex_map[None, ...]    # Add batch dimension\n    tex_map = tex_map.permute(0, 3, 1, 2) # NHWC -> NCHW\n    tex = torch.nn.functional.grid_sample(tex_map, coords[None, None, ...] 
* 2 - 1, mode=filter, align_corners=False)\n    tex = tex.permute(0, 2, 3, 1) # NCHW -> NHWC\n    return tex[0, 0, ...]\n\n#----------------------------------------------------------------------------\n# Cubemap utility functions\n#----------------------------------------------------------------------------\n\ndef cube_to_dir(s, x, y):\n    if s == 0:   rx, ry, rz = torch.ones_like(x), -y, -x\n    elif s == 1: rx, ry, rz = -torch.ones_like(x), -y, x\n    elif s == 2: rx, ry, rz = x, torch.ones_like(x), y\n    elif s == 3: rx, ry, rz = x, -torch.ones_like(x), -y\n    elif s == 4: rx, ry, rz = x, -y, torch.ones_like(x)\n    elif s == 5: rx, ry, rz = -x, -y, -torch.ones_like(x)\n    return torch.stack((rx, ry, rz), dim=-1)\n\ndef latlong_to_cubemap(latlong_map, res):\n    cubemap = torch.zeros(6, res[0], res[1], latlong_map.shape[-1], dtype=torch.float32, device='cuda')\n    for s in range(6):\n        gy, gx = torch.meshgrid(torch.linspace(-1.0 + 1.0 / res[0], 1.0 - 1.0 / res[0], res[0], device='cuda'), \n                                torch.linspace(-1.0 + 1.0 / res[1], 1.0 - 1.0 / res[1], res[1], device='cuda'),\n                                indexing='ij')\n        v = safe_normalize(cube_to_dir(s, gx, gy))\n\n        tu = torch.atan2(v[..., 0:1], -v[..., 2:3]) / (2 * np.pi) + 0.5\n        tv = torch.acos(torch.clamp(v[..., 1:2], min=-1, max=1)) / np.pi\n        texcoord = torch.cat((tu, tv), dim=-1)\n\n        cubemap[s, ...] = dr.texture(latlong_map[None, ...], texcoord[None, ...], filter_mode='linear')[0]\n    return cubemap\n\ndef cubemap_to_latlong(cubemap, res):\n    gy, gx = torch.meshgrid(torch.linspace( 0.0 + 1.0 / res[0], 1.0 - 1.0 / res[0], res[0], device='cuda'), \n                            torch.linspace(-1.0 + 1.0 / res[1], 1.0 - 1.0 / res[1], res[1], device='cuda'),\n                            indexing='ij')\n    \n    sintheta, costheta = torch.sin(gy*np.pi), torch.cos(gy*np.pi)\n    sinphi, cosphi     = torch.sin(gx*np.pi), torch.cos(gx*np.pi)\n    \n    reflvec = torch.stack((\n        sintheta*sinphi, \n        costheta, \n        -sintheta*cosphi\n        ), dim=-1)\n    return dr.texture(cubemap[None, ...], reflvec[None, ...].contiguous(), filter_mode='linear', boundary_mode='cube')[0]\n\n#----------------------------------------------------------------------------\n# Image scaling\n#----------------------------------------------------------------------------\n\ndef scale_img_hwc(x : torch.Tensor, size, mag='bilinear', min='area') -> torch.Tensor:\n    return scale_img_nhwc(x[None, ...], size, mag, min)[0]\n\ndef scale_img_nhwc(x  : torch.Tensor, size, mag='bilinear', min='area') -> torch.Tensor:\n    assert (x.shape[1] >= size[0] and x.shape[2] >= size[1]) or (x.shape[1] < size[0] and x.shape[2] < size[1]), \"Trying to magnify image in one dimension and minify in the other\"\n    y = x.permute(0, 3, 1, 2) # NHWC -> NCHW\n    if x.shape[1] > size[0] and x.shape[2] > size[1]: # Minification, previous size was bigger\n        y = torch.nn.functional.interpolate(y, size, mode=min)\n    else: # Magnification\n        if mag == 'bilinear' or mag == 'bicubic':\n            y = torch.nn.functional.interpolate(y, size, mode=mag, align_corners=True)\n        else:\n            y = torch.nn.functional.interpolate(y, size, mode=mag)\n    return y.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC\n\ndef avg_pool_nhwc(x  : torch.Tensor, size) -> torch.Tensor:\n    y = x.permute(0, 3, 1, 2) # NHWC -> NCHW\n    y = torch.nn.functional.avg_pool2d(y, size)\n    return y.permute(0, 
2, 3, 1).contiguous() # NCHW -> NHWC\n\n#----------------------------------------------------------------------------\n# Behaves similar to tf.segment_sum\n#----------------------------------------------------------------------------\n\ndef segment_sum(data: torch.Tensor, segment_ids: torch.Tensor) -> torch.Tensor:\n    num_segments = torch.unique_consecutive(segment_ids).shape[0]\n\n    # Repeats ids until same dimension as data\n    if len(segment_ids.shape) == 1:\n        s = torch.prod(torch.tensor(data.shape[1:], dtype=torch.int64, device='cuda')).long()\n        segment_ids = segment_ids.repeat_interleave(s).view(segment_ids.shape[0], *data.shape[1:])\n\n    assert data.shape == segment_ids.shape, \"data.shape and segment_ids.shape should be equal\"\n\n    shape = [num_segments] + list(data.shape[1:])\n    result = torch.zeros(*shape, dtype=torch.float32, device='cuda')\n    result = result.scatter_add(0, segment_ids, data)\n    return result\n\n#----------------------------------------------------------------------------\n# Matrix helpers.\n#----------------------------------------------------------------------------\n\ndef fovx_to_fovy(fovx, aspect):\n    return np.arctan(np.tan(fovx / 2) / aspect) * 2.0\n\ndef focal_length_to_fovy(focal_length, sensor_height):\n    return 2 * np.arctan(0.5 * sensor_height / focal_length)\n\n# Reworked so this matches gluPerspective / glm::perspective, using fovy\ndef perspective(fovy=0.7854, aspect=1.0, n=0.1, f=1000.0, device=None):\n    y = np.tan(fovy / 2)\n    return torch.tensor([[1/(y*aspect),    0,            0,              0], \n                         [           0, 1/-y,            0,              0], \n                         [           0,    0, -(f+n)/(f-n), -(2*f*n)/(f-n)], \n                         [           0,    0,           -1,              0]], dtype=torch.float32, device=device)\n\n# Reworked so this matches gluPerspective / glm::perspective, using fovy\ndef perspective_offcenter(fovy, fraction, rx, ry, aspect=1.0, n=0.1, f=1000.0, device=None):\n    y = np.tan(fovy / 2)\n\n    # Full frustum\n    R, L = aspect*y, -aspect*y\n    T, B = y, -y\n\n    # Create a randomized sub-frustum\n    width  = (R-L)*fraction\n    height = (T-B)*fraction\n    xstart = (R-L)*rx\n    ystart = (T-B)*ry\n\n    l = L + xstart\n    r = l + width\n    b = B + ystart\n    t = b + height\n    \n    # https://www.scratchapixel.com/lessons/3d-basic-rendering/perspective-and-orthographic-projection-matrix/opengl-perspective-projection-matrix\n    return torch.tensor([[2/(r-l),        0,  (r+l)/(r-l),              0], \n                         [      0, -2/(t-b),  (t+b)/(t-b),              0], \n                         [      0,        0, -(f+n)/(f-n), -(2*f*n)/(f-n)], \n                         [      0,        0,           -1,              0]], dtype=torch.float32, device=device)\n\ndef translate(x, y, z, device=None):\n    return torch.tensor([[1, 0, 0, x], \n                         [0, 1, 0, y], \n                         [0, 0, 1, z], \n                         [0, 0, 0, 1]], dtype=torch.float32, device=device)\n\ndef rotate_x(a, device=None):\n    s, c = np.sin(a), np.cos(a)\n    return torch.tensor([[1,  0, 0, 0], \n                         [0,  c, s, 0], \n                         [0, -s, c, 0], \n                         [0,  0, 0, 1]], dtype=torch.float32, device=device)\n\ndef rotate_y(a, device=None):\n    s, c = np.sin(a), np.cos(a)\n    return torch.tensor([[ c, 0, s, 0], \n                         [ 0, 1, 0, 0], \n                   
      [-s, 0, c, 0], \n                         [ 0, 0, 0, 1]], dtype=torch.float32, device=device)\n\ndef rotate_z(a, device=None):\n    s, c = np.sin(a), np.cos(a)\n    return torch.tensor([[ c, s, 0, 0], \n                         [-s, c, 0, 0], \n                         [ 0, 0, 1, 0], \n                         [ 0, 0, 0, 1]], dtype=torch.float32, device=device)\n\ndef scale(s, device=None):\n    return torch.tensor([[ s, 0, 0, 0], \n                         [ 0, s, 0, 0], \n                         [ 0, 0, s, 0], \n                         [ 0, 0, 0, 1]], dtype=torch.float32, device=device)\n\ndef lookAt(eye, at, up):\n    a = eye - at\n    w = a / torch.linalg.norm(a)\n    u = torch.cross(up, w)\n    u = u / torch.linalg.norm(u)\n    v = torch.cross(w, u)\n    translate = torch.tensor([[1, 0, 0, -eye[0]], \n                              [0, 1, 0, -eye[1]], \n                              [0, 0, 1, -eye[2]], \n                              [0, 0, 0, 1]], dtype=eye.dtype, device=eye.device)\n    rotate = torch.tensor([[u[0], u[1], u[2], 0], \n                           [v[0], v[1], v[2], 0], \n                           [w[0], w[1], w[2], 0], \n                           [0, 0, 0, 1]], dtype=eye.dtype, device=eye.device)\n    return rotate @ translate\n\n@torch.no_grad()\ndef random_rotation_translation(t, device=None):\n    m = np.random.normal(size=[3, 3])\n    m[1] = np.cross(m[0], m[2])\n    m[2] = np.cross(m[0], m[1])\n    m = m / np.linalg.norm(m, axis=1, keepdims=True)\n    m = np.pad(m, [[0, 1], [0, 1]], mode='constant')\n    m[3, 3] = 1.0\n    m[:3, 3] = np.random.uniform(-t, t, size=[3])\n    return torch.tensor(m, dtype=torch.float32, device=device)\n\n@torch.no_grad()\ndef random_rotation(device=None):\n    m = np.random.normal(size=[3, 3])\n    m[1] = np.cross(m[0], m[2])\n    m[2] = np.cross(m[0], m[1])\n    m = m / np.linalg.norm(m, axis=1, keepdims=True)\n    m = np.pad(m, [[0, 1], [0, 1]], mode='constant')\n    m[3, 3] = 1.0\n    m[:3, 3] = np.array([0,0,0]).astype(np.float32)\n    return torch.tensor(m, dtype=torch.float32, device=device)\n\n#----------------------------------------------------------------------------\n# Compute focal points of a set of lines using least squares. 
\n# handy for poorly centered datasets\n#----------------------------------------------------------------------------\n\ndef lines_focal(o, d):\n    d = safe_normalize(d)\n    I = torch.eye(3, dtype=o.dtype, device=o.device)\n    S = torch.sum(d[..., None] @ torch.transpose(d[..., None], 1, 2) - I[None, ...], dim=0)\n    C = torch.sum((d[..., None] @ torch.transpose(d[..., None], 1, 2) - I[None, ...]) @ o[..., None], dim=0).squeeze(1)\n    return torch.linalg.pinv(S) @ C\n\n#----------------------------------------------------------------------------\n# Cosine sample around a vector N\n#----------------------------------------------------------------------------\n@torch.no_grad()\ndef cosine_sample(N, size=None):\n    # construct local frame\n    N = N/torch.linalg.norm(N)\n\n    dx0 = torch.tensor([0, N[2], -N[1]], dtype=N.dtype, device=N.device)\n    dx1 = torch.tensor([-N[2], 0, N[0]], dtype=N.dtype, device=N.device)\n\n    dx = torch.where(dot(dx0, dx0) > dot(dx1, dx1), dx0, dx1)\n    #dx = dx0 if np.dot(dx0,dx0) > np.dot(dx1,dx1) else dx1\n    dx = dx / torch.linalg.norm(dx)\n    dy = torch.cross(N,dx)\n    dy = dy / torch.linalg.norm(dy)\n\n    # cosine sampling in local frame\n    if size is None:\n        phi = 2.0 * np.pi * np.random.uniform()\n        s = np.random.uniform()\n    else:\n        phi = 2.0 * np.pi * torch.rand(*size, 1, dtype=N.dtype, device=N.device)\n        s = torch.rand(*size, 1, dtype=N.dtype, device=N.device)\n    costheta = np.sqrt(s)\n    sintheta = np.sqrt(1.0 - s)\n\n    # cartesian vector in local space\n    x = np.cos(phi)*sintheta\n    y = np.sin(phi)*sintheta\n    z = costheta\n\n    # local to world\n    return dx*x + dy*y + N*z\n\n#----------------------------------------------------------------------------\n# Bilinear downsample by 2x.\n#----------------------------------------------------------------------------\n\ndef bilinear_downsample(x : torch.tensor) -> torch.Tensor:\n    w = torch.tensor([[1, 3, 3, 1], [3, 9, 9, 3], [3, 9, 9, 3], [1, 3, 3, 1]], dtype=torch.float32, device=x.device) / 64.0\n    w = w.expand(x.shape[-1], 1, 4, 4) \n    x = torch.nn.functional.conv2d(x.permute(0, 3, 1, 2), w, padding=1, stride=2, groups=x.shape[-1])\n    return x.permute(0, 2, 3, 1)\n\n#----------------------------------------------------------------------------\n# Bilinear downsample log(spp) steps\n#----------------------------------------------------------------------------\n\ndef bilinear_downsample(x : torch.tensor, spp) -> torch.Tensor:\n    w = torch.tensor([[1, 3, 3, 1], [3, 9, 9, 3], [3, 9, 9, 3], [1, 3, 3, 1]], dtype=torch.float32, device=x.device) / 64.0\n    g = x.shape[-1]\n    w = w.expand(g, 1, 4, 4) \n    x = x.permute(0, 3, 1, 2) # NHWC -> NCHW\n    steps = int(np.log2(spp))\n    for _ in range(steps):\n        xp = torch.nn.functional.pad(x, (1,1,1,1), mode='replicate')\n        x = torch.nn.functional.conv2d(xp, w, padding=0, stride=2, groups=g)\n    return x.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC\n\n#----------------------------------------------------------------------------\n# Singleton initialize GLFW\n#----------------------------------------------------------------------------\n\n_glfw_initialized = False\ndef init_glfw():\n    global _glfw_initialized\n    try:\n        import glfw\n        glfw.ERROR_REPORTING = 'raise'\n        glfw.default_window_hints()\n        glfw.window_hint(glfw.VISIBLE, glfw.FALSE)\n        test = glfw.create_window(8, 8, \"Test\", None, None) # Create a window and see if not initialized yet\n    
except glfw.GLFWError as e:\n        if e.error_code == glfw.NOT_INITIALIZED:\n            glfw.init()\n            _glfw_initialized = True\n\n#----------------------------------------------------------------------------\n# Image display function using OpenGL.\n#----------------------------------------------------------------------------\n\n_glfw_window = None\ndef display_image(image, title=None):\n    # Import OpenGL\n    import OpenGL.GL as gl\n    import glfw\n\n    # Zoom image if requested.\n    image = np.asarray(image[..., 0:3]) if image.shape[-1] == 4 else np.asarray(image)\n    height, width, channels = image.shape\n\n    # Initialize window.\n    init_glfw()\n    if title is None:\n        title = 'Debug window'\n    global _glfw_window\n    if _glfw_window is None:\n        glfw.default_window_hints()\n        _glfw_window = glfw.create_window(width, height, title, None, None)\n        glfw.make_context_current(_glfw_window)\n        glfw.show_window(_glfw_window)\n        glfw.swap_interval(0)\n    else:\n        glfw.make_context_current(_glfw_window)\n        glfw.set_window_title(_glfw_window, title)\n        glfw.set_window_size(_glfw_window, width, height)\n\n    # Update window.\n    glfw.poll_events()\n    gl.glClearColor(0, 0, 0, 1)\n    gl.glClear(gl.GL_COLOR_BUFFER_BIT)\n    gl.glWindowPos2f(0, 0)\n    gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1)\n    gl_format = {3: gl.GL_RGB, 2: gl.GL_RG, 1: gl.GL_LUMINANCE}[channels]\n    gl_dtype = {'uint8': gl.GL_UNSIGNED_BYTE, 'float32': gl.GL_FLOAT}[image.dtype.name]\n    gl.glDrawPixels(width, height, gl_format, gl_dtype, image[::-1])\n    glfw.swap_buffers(_glfw_window)\n    if glfw.window_should_close(_glfw_window):\n        return False\n    return True\n\n#----------------------------------------------------------------------------\n# Image save/load helper.\n#----------------------------------------------------------------------------\n\ndef save_image(fn, x : np.ndarray) -> np.ndarray:\n    try:\n        if os.path.splitext(fn)[1] == \".png\":\n            imageio.imwrite(fn, np.clip(np.rint(x * 255.0), 0, 255).astype(np.uint8), compress_level=3) # Low compression for faster saving\n        else:\n            imageio.imwrite(fn, np.clip(np.rint(x * 255.0), 0, 255).astype(np.uint8))\n    except:\n        print(\"WARNING: FAILED to save image %s\" % fn)\n\ndef save_image_raw(fn, x : np.ndarray):\n    try:\n        imageio.imwrite(fn, x)\n    except:\n        print(\"WARNING: FAILED to save image %s\" % fn)\n\n\ndef load_image_raw(fn) -> np.ndarray:\n    return imageio.imread(fn)\n\ndef load_image(fn) -> np.ndarray:\n    img = load_image_raw(fn)\n    if img.dtype == np.float32: # HDR image\n        return img\n    else: # LDR image\n        return img.astype(np.float32) / 255\n\n#----------------------------------------------------------------------------\n\ndef time_to_text(x):\n    if x > 3600:\n        return \"%.2f h\" % (x / 3600)\n    elif x > 60:\n        return \"%.2f m\" % (x / 60)\n    else:\n        return \"%.2f s\" % x\n\n#----------------------------------------------------------------------------\n\ndef checkerboard(res, checker_size) -> np.ndarray:\n    tiles_y = (res[0] + (checker_size*2) - 1) // (checker_size*2)\n    tiles_x = (res[1] + (checker_size*2) - 1) // (checker_size*2)\n    check = np.kron([[1, 0] * tiles_x, [0, 1] * tiles_x] * tiles_y, np.ones((checker_size, checker_size)))*0.33 + 0.33\n    check = check[:res[0], :res[1]]\n    return np.stack((check, check, check), axis=-1)"
  },
  {
    "path": "train_gflexicubes_deepfashion.py",
    "content": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n# property and proprietary rights in and to this material, related\n# documentation and any modifications thereto. Any use, reproduction, \n# disclosure or distribution of this material and related documentation \n# without an express license agreement from NVIDIA CORPORATION or \n# its affiliates is strictly prohibited.\n\nimport os\nimport sys\nimport time\nimport argparse\nimport json\n\nimport numpy as np\nimport torch\nimport nvdiffrast.torch as dr\nimport xatlas\n\n# Import data readers / generators\nfrom dataset.dataset_deepfashion import DatasetDeepFashion\nfrom dataset.dataset_deepfashion_testset import DatasetDeepFashionTestset\n\n# Import topology / geometry trainers\nfrom geometry.gshell_flexicubes_geometry import GShellFlexiCubesGeometry\n\nimport render.renderutils as ru\nfrom render import obj\nfrom render import material\nfrom render import util\nfrom render import mesh\nfrom render import texture\nfrom render import mlptexture\nfrom render import light\nfrom render import render\n\n\nfrom denoiser.denoiser import BilateralDenoiser\n\n\nRADIUS = 3.0\n\n# Enable to debug back-prop anomalies\n# torch.autograd.set_detect_anomaly(True)\n\n###############################################################################\n# Loss setup\n###############################################################################\n\n@torch.no_grad()\ndef createLoss(FLAGS):\n    if FLAGS.loss == \"smape\":\n        return lambda img, ref: ru.image_loss(img, ref, loss='smape', tonemapper='none')\n    elif FLAGS.loss == \"mse\":\n        return lambda img, ref: ru.image_loss(img, ref, loss='mse', tonemapper='none')\n    elif FLAGS.loss == \"logl1\":\n        return lambda img, ref: ru.image_loss(img, ref, loss='l1', tonemapper='log_srgb')\n    elif FLAGS.loss == \"logl2\":\n        return lambda img, ref: ru.image_loss(img, ref, loss='mse', tonemapper='log_srgb')\n    elif FLAGS.loss == \"relmse\":\n        return lambda img, ref: ru.image_loss(img, ref, loss='relmse', tonemapper='none')\n    else:\n        assert False\n\n###############################################################################\n# Mix background into a dataset image\n###############################################################################\n\n@torch.no_grad()\ndef prepare_batch(target, bg_type='black'):\n    assert len(target['img'].shape) == 4, \"Image shape should be [n, h, w, c]\"\n    if bg_type == 'checker':\n        background = torch.tensor(util.checkerboard(target['img'].shape[1:3], 8), dtype=torch.float32, device='cuda')[None, ...]\n    elif bg_type == 'black':\n        background = torch.zeros(target['img'].shape[0:3] + (3,), dtype=torch.float32, device='cuda')\n    elif bg_type == 'white':\n        background = torch.ones(target['img'].shape[0:3] + (3,), dtype=torch.float32, device='cuda')\n    elif bg_type == 'reference':\n        background = target['img'][..., 0:3]\n    elif bg_type == 'random':\n        background = torch.rand(target['img'].shape[0:3] + (3,), dtype=torch.float32, device='cuda')\n    else:\n        assert False, \"Unknown background type %s\" % bg_type\n\n    target['mv'] = target['mv'].cuda()\n    target['mvp'] = target['mvp'].cuda()\n    target['campos'] = target['campos'].cuda()\n    target['img'] = target['img'].cuda()\n    target['background'] = background\n\n    target['img'] = torch.cat((torch.lerp(background, 
target['img'][..., 0:3], target['img'][..., 3:4]), target['img'][..., 3:4]), dim=-1)\n\n    return target\n\n###############################################################################\n# UV - map geometry & convert to a mesh\n###############################################################################\n\n@torch.no_grad()\ndef xatlas_uvmap(glctx, geometry, mat, FLAGS):\n    eval_mesh = geometry.getMesh(mat)\n    try:\n        eval_mesh = eval_mesh['imesh']\n    except:\n        pass\n    \n    # Create uvs with xatlas\n    v_pos = eval_mesh.v_pos.detach().cpu().numpy()\n    t_pos_idx = eval_mesh.t_pos_idx.detach().cpu().numpy()\n    vmapping, indices, uvs = xatlas.parametrize(v_pos, t_pos_idx)\n\n    # Convert to tensors\n    indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)\n    \n    uvs = torch.tensor(uvs, dtype=torch.float32, device='cuda')\n    faces = torch.tensor(indices_int64, dtype=torch.int64, device='cuda')\n\n    new_mesh = mesh.Mesh(v_tex=uvs, t_tex_idx=faces, base=eval_mesh)\n\n    mask, kd, ks = render.render_uv(glctx, new_mesh, FLAGS.texture_res, eval_mesh.material['kd_ks'])\n\n    # Dilate all textures & use average color for background\n    kd_avg = torch.sum(torch.sum(torch.sum(kd * mask, dim=0), dim=0), dim=0) / torch.sum(torch.sum(torch.sum(mask, dim=0), dim=0), dim=0)\n    kd = util.dilate(kd, kd_avg[None, None, None, :], mask, 7)\n\n    ks_avg = torch.sum(torch.sum(torch.sum(ks * mask, dim=0), dim=0), dim=0) / torch.sum(torch.sum(torch.sum(mask, dim=0), dim=0), dim=0)\n    ks = util.dilate(ks, ks_avg[None, None, None, :], mask, 7)\n\n    nrm_avg = torch.tensor([0, 0, 1], dtype=torch.float32, device=\"cuda\")\n    normal = nrm_avg[None, None, None, :].repeat(kd.shape[0], kd.shape[1], kd.shape[2], 1)\n    \n    new_mesh.material = mat.copy()\n    del new_mesh.material['kd_ks']\n\n    if FLAGS.transparency:\n        kd = torch.cat((kd, torch.rand_like(kd[...,0:1])), dim=-1)\n        print(\"kd shape\", kd.shape)\n\n    kd_min, kd_max = torch.tensor(FLAGS.kd_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.kd_max, dtype=torch.float32, device='cuda')\n    ks_min, ks_max = torch.tensor(FLAGS.ks_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.ks_max, dtype=torch.float32, device='cuda')\n    nrm_min, nrm_max = torch.tensor(FLAGS.nrm_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.nrm_max, dtype=torch.float32, device='cuda')\n    new_mesh.material.update({\n        'kd'     : texture.Texture2D(kd.clone().detach().requires_grad_(True), min_max=[kd_min, kd_max]),\n        'ks'     : texture.Texture2D(ks.clone().detach().requires_grad_(True), min_max=[ks_min, ks_max]),\n        'normal' : texture.Texture2D(normal.clone().detach().requires_grad_(True), min_max=[nrm_min, nrm_max]),\n    })\n\n    return new_mesh\n\n###############################################################################\n# Utility functions for material\n###############################################################################\n\ndef initial_guess_material(geometry, mlp, FLAGS, init_mat=None):\n    kd_min, kd_max = torch.tensor(FLAGS.kd_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.kd_max, dtype=torch.float32, device='cuda')\n    ks_min, ks_max = torch.tensor(FLAGS.ks_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.ks_max, dtype=torch.float32, device='cuda')\n    if mlp:\n        mlp_min = torch.cat((kd_min[0:3], ks_min), dim=0)\n        mlp_max = torch.cat((kd_max[0:3], ks_max), dim=0)\n        
mlp_map_opt = mlptexture.MLPTexture3D(geometry.getAABB(), channels=6, min_max=[mlp_min, mlp_max], use_float16=FLAGS.use_float16)\n        mat =  {'kd_ks' : mlp_map_opt}\n    else:\n        raise NotImplementedError\n\n    mat['bsdf'] = FLAGS.bsdf\n\n    mat['no_perturbed_nrm'] = FLAGS.no_perturbed_nrm\n\n    return mat\n\ndef initial_guess_material_knownkskd(geometry, mlp, FLAGS, init_mat=None):\n    mat =  {\n        'kd'     : init_mat['kd'],\n        'ks'     : init_mat['ks']\n    }\n\n    if init_mat is not None:\n        mat['bsdf'] = init_mat['bsdf']\n    else:\n        mat['bsdf'] = 'pbr'\n\n    return mat\n\n###############################################################################\n# Validation & testing\n###############################################################################\n\n@torch.no_grad()\ndef validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS, denoiser=None):\n    result_dict = {}\n    with torch.no_grad():\n        buffers = geometry.render(glctx, target, lgt, opt_material, use_uv=False, denoiser=denoiser)['buffers']\n\n        result_dict['ref'] = util.rgb_to_srgb(target['img'][...,0:3])[0]\n        result_dict['opt'] = util.rgb_to_srgb(buffers['shaded'][...,0:3])[0]\n        result_dict['mask_opt'] = buffers['shaded'][...,3:][0].expand(-1, -1, 3)\n        result_dict['mask_ref'] = target['img'][...,3:][0].expand(-1, -1, 3)\n        result_dict['msdf_image'] = buffers['msdf_image'][...,:][0].expand(-1, -1, 3).clamp(min=0, max=1)\n        # result_dict['invdepth_2nd'] = buffers['invdepth_second'][...,0:1].expand(-1, -1, -1, 3).clamp(max=1.0)[0]\n        # result_dict['invdepth_2nd_ref'] = target['invdepth_second'][...,0:1].expand(-1, -1, -1, 3).clamp(max=1.0)[0]\n        result_image = torch.cat([result_dict['opt'], result_dict['ref'], result_dict['mask_opt'], result_dict['mask_ref'], result_dict['msdf_image']], axis=1)\n\n        if FLAGS.display is not None:\n            white_bg = torch.ones_like(target['background'])\n            for layer in FLAGS.display:\n                if 'latlong' in layer and layer['latlong']:\n                    result_dict['light_image'] = lgt.generate_image(FLAGS.display_res)\n                    result_dict['light_image'] = util.rgb_to_srgb(result_dict['light_image'] / (1 + result_dict['light_image']))\n                    result_image = torch.cat([result_image, result_dict['light_image']], axis=1)\n                elif 'bsdf' in layer:\n                    img = render.render_mesh(FLAGS, glctx, opt_mesh, target['mvp'], target['campos'], target['light'] if lgt is None else lgt, target['resolution'],\n                                                spp=target['spp'], num_layers=FLAGS.layers, background=white_bg, bsdf=layer['bsdf'], optix_ctx=geometry.optix_ctx)['shaded']\n                    if layer['bsdf'] == 'kd':\n                        result_dict[layer['bsdf']] = util.rgb_to_srgb(img[..., 0:3])[0]\n                    else:\n                        result_dict[layer['bsdf']] = img[0, ..., 0:3]\n                    result_image = torch.cat([result_image, result_dict[layer['bsdf']]], axis=1)\n                    if ref_mesh is not None:\n                        img = render.render_mesh(FLAGS, glctx, ref_mesh, target['mvp'], target['campos'], target['light'], target['resolution'],\n                                                    spp=target['spp'], num_layers=FLAGS.layers, background=white_bg, bsdf=layer['bsdf'], optix_ctx=geometry.optix_ctx)['shaded']\n                        if layer['bsdf'] == 'kd':\n      
                      result_dict[layer['bsdf'] + \"_ref\"] = util.rgb_to_srgb(img[..., 0:3])[0]\n                        else:\n                            result_dict[layer['bsdf'] + \"_ref\"] = img[0, ..., 0:3]\n                        result_image = torch.cat([result_image, result_dict[layer['bsdf'] + \"_ref\"]], axis=1)\n                elif 'normals' in layer and not FLAGS.no_perturbed_nrm:\n                    result_image = torch.cat([result_image, (buffers['perturbed_nrm'][0, ...,0:3] + 1.0) * 0.5], axis=1)\n                elif 'diffuse_light' in layer:\n                    result_image = torch.cat([result_image, util.rgb_to_srgb(buffers['diffuse_light'][..., 0:3])[0]], axis=1)\n                elif 'specular_light' in layer:\n                    result_image = torch.cat([result_image, util.rgb_to_srgb(buffers['specular_light'][..., 0:3])[0]], axis=1)\n\n        return result_image, result_dict\n\n@torch.no_grad()\ndef validate(glctx, geometry, opt_material, lgt, dataset_validate, out_dir, FLAGS, denoiser=None, save_viz=False):\n\n    # ==============================================================================================\n    #  Validation loop\n    # ==============================================================================================\n    mse_values = []\n    psnr_values = []\n\n    dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_validate.collate)\n\n    os.makedirs(out_dir, exist_ok=True)\n    with open(os.path.join(out_dir, 'metrics.txt'), 'w') as fout:\n        fout.write('ID, MSE, PSNR\\n')\n\n        print(\"Running validation\")\n        for it, target in enumerate(dataloader_validate):\n\n            # Mix validation background\n            target = prepare_batch(target, FLAGS.background)\n\n            result_image, result_dict = validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS, denoiser=denoiser)\n\n            # Compute metrics\n            opt = torch.clamp(result_dict['opt'], 0.0, 1.0) \n            ref = torch.clamp(result_dict['ref'], 0.0, 1.0)\n\n            mse = torch.nn.functional.mse_loss(opt, ref, size_average=None, reduce=None, reduction='mean').item()\n            mse_values.append(float(mse))\n            psnr = util.mse_to_psnr(mse)\n            psnr_values.append(float(psnr))\n\n            line = \"%d, %1.8f, %1.8f\\n\" % (it, mse, psnr)\n            fout.write(str(line))\n\n            if save_viz:\n                for k in result_dict.keys():\n                    np_img = result_dict[k].detach().cpu().numpy()\n                    util.save_image(out_dir + '/' + ('val_%06d_%s.png' % (it, k)), np_img)\n\n        avg_mse = np.mean(np.array(mse_values))\n        avg_psnr = np.mean(np.array(psnr_values))\n        line = \"AVERAGES: %1.4f, %2.3f\\n\" % (avg_mse, avg_psnr)\n        fout.write(str(line))\n        print(\"MSE,      PSNR\")\n        print(\"%1.8f, %2.3f\" % (avg_mse, avg_psnr))\n    return avg_psnr\n\n###############################################################################\n# Main shape fitter function / optimization loop\n###############################################################################\n\ndef optimize_mesh(\n        denoiser,\n        glctx,\n        geometry,\n        opt_material,\n        lgt,\n        dataset_train,\n        dataset_validate,\n        FLAGS,\n        warmup_iter=0,\n        log_interval=10,\n        pass_idx=0,\n        pass_name=\"\",\n        optimize_light=True,\n        optimize_geometry=True,\n        
visualize=True,\n        save_path=None\n    ):\n\n    # ==============================================================================================\n    #  Setup torch optimizer\n    # ==============================================================================================\n\n    learning_rate = FLAGS.learning_rate[pass_idx] if isinstance(FLAGS.learning_rate, list) or isinstance(FLAGS.learning_rate, tuple) else FLAGS.learning_rate\n    learning_rate_pos = learning_rate[0] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate\n    learning_rate_mat = learning_rate[1] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate\n    learning_rate_lgt = learning_rate[2] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate * 6.0\n\n    def lr_schedule(iter, fraction):\n        if iter < warmup_iter:\n            return iter / warmup_iter \n        return max(0.0, 10**(-(iter - warmup_iter)*0.0002)) # Exponential falloff from [1.0, 0.1] over 5k epochs.    \n\n    # ==============================================================================================\n    #  Image loss\n    # ==============================================================================================\n    image_loss_fn = createLoss(FLAGS)\n\n\n\n    params = list(material.get_parameters(opt_material))\n\n    if optimize_light:\n        optimizer_light = torch.optim.Adam((lgt.parameters() if lgt is not None else []), lr=learning_rate_lgt)\n        scheduler_light = torch.optim.lr_scheduler.LambdaLR(optimizer_light, lr_lambda=lambda x: lr_schedule(x, 0.9)) \n\n    if optimize_geometry:\n        if FLAGS.use_sdf_mlp:\n            lr_msdf = learning_rate_pos * 1e-2 if FLAGS.use_msdf_mlp else learning_rate_pos\n            deform_params = list(v[1] for v in geometry.named_parameters() if 'deform' in v[0]) if optimize_geometry else []\n            msdf_params = list(v[1] for v in geometry.named_parameters() if 'msdf' in v[0]) if optimize_geometry else []\n            sdf_params = list(v[1] for v in geometry.named_parameters() if 'sdf' in v[0] and 'msdf' not in v[0]) if optimize_geometry else []\n            other_params = list(v[1] for v in geometry.named_parameters() if 'sdf' not in v[0] and 'msdf' not in v[0] and 'deform' not in v[0]) if optimize_geometry else []\n            optimizer_mesh = torch.optim.Adam([\n                    {'params': deform_params, 'lr': learning_rate_pos},\n                    {'params': msdf_params, 'lr': lr_msdf},\n                    {'params': sdf_params, 'lr': learning_rate_pos * 1e-2},\n                    {'params': other_params, 'lr': learning_rate_pos * 1e-2},\n                ], eps=1e-8)\n        else:\n            optimizer_mesh = torch.optim.Adam(geometry.parameters(), lr=learning_rate_pos)\n        scheduler_mesh = torch.optim.lr_scheduler.LambdaLR(optimizer_mesh, lr_lambda=lambda x: lr_schedule(x, 0.9)) \n\n    optimizer = torch.optim.Adam(params, lr=learning_rate_mat)\n    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: lr_schedule(x, 0.9))\n\n    # ==============================================================================================\n    #  Training loop\n    # ==============================================================================================\n    img_cnt = 0\n    img_loss_vec = []\n    depth_loss_vec = []\n    reg_loss_vec = []\n    iter_dur_vec = []\n\n    dataloader_train    = 
torch.utils.data.DataLoader(dataset_train, batch_size=FLAGS.batch, collate_fn=dataset_train.collate, shuffle=True)\n    if visualize:\n        dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_train.collate)\n\n        def cycle(iterable):\n            iterator = iter(iterable)\n            while True:\n                try:\n                    yield next(iterator)\n                except StopIteration:\n                    iterator = iter(iterable)\n\n        v_it = cycle(dataloader_validate)\n\n    for it, target in enumerate(dataloader_train):\n\n        # Mix randomized background into dataset image\n        target = prepare_batch(target, 'random')\n\n        # ==============================================================================================\n        #  Display / save outputs. Do it before training so we get initial meshes\n        # ==============================================================================================\n\n        # Show/save image before training step (want to get correct rendering of input)\n        if visualize and FLAGS.local_rank == 0 and it != 0:\n            with torch.no_grad():\n                display_image = FLAGS.display_interval and (it % FLAGS.display_interval == 0)\n                save_image = FLAGS.save_interval and (it % FLAGS.save_interval == 0)\n                if display_image or save_image:\n                    save_mesh = True\n                    if save_mesh:\n                        os.makedirs(os.path.join(save_path, pass_name), exist_ok=True)\n                        obj.write_obj(os.path.join(save_path, pass_name), geometry.getMesh(opt_material)['imesh'], save_material=False)\n                    result_image, result_dict = validate_itr(glctx, prepare_batch(next(v_it), FLAGS.background), geometry, opt_material, lgt, FLAGS, denoiser=denoiser)\n            \n                    np_result_image = result_image.detach().cpu().numpy()\n                    if display_image:\n                        util.display_image(np_result_image, title='%d / %d' % (it, FLAGS.iter))\n                    if save_image:\n                        util.save_image(os.path.join(save_path, ('img_%s_%06d.png' % (pass_name, img_cnt))), np_result_image)\n                        img_cnt = img_cnt + 1\n\n        iter_start_time = time.time()\n\n        # ==============================================================================================\n        #  Zero gradients\n        # ==============================================================================================\n        optimizer.zero_grad()\n        if optimize_geometry:\n            optimizer_mesh.zero_grad()\n        if optimize_light:\n            optimizer_light.zero_grad()\n\n        # ==============================================================================================\n        #  Training\n        # ==============================================================================================\n\n        xfm_lgt = None\n        if optimize_light:\n            lgt.update_pdf()\n            \n\n        img_loss, depth_loss, reg_loss = geometry.tick(\n            glctx, target, lgt, opt_material, image_loss_fn, it, \n            denoiser=denoiser)\n\n        # ==============================================================================================\n        #  Final loss\n        # ==============================================================================================\n        total_loss = img_loss + reg_loss\n\n 
       img_loss_vec.append(img_loss.item())\n        depth_loss_vec.append(depth_loss.item())\n        reg_loss_vec.append(reg_loss.item())\n\n        # ==============================================================================================\n        #  Backpropagate\n        # ==============================================================================================\n        total_loss.backward()\n        if hasattr(lgt, 'base') and lgt.base.grad is not None and optimize_light:\n            lgt.base.grad *= 64\n        if 'kd_ks' in opt_material:\n            opt_material['kd_ks'].encoder.params.grad /= 8.0\n        if 'kd_ks_back' in opt_material:\n            opt_material['kd_ks_back'].encoder.params.grad /= 8.0\n\n        # Optionally clip gradients\n        if FLAGS.clip_max_norm > 0.0:\n            if optimize_geometry:\n                torch.nn.utils.clip_grad_norm_(geometry.parameters() + params, FLAGS.clip_max_norm)\n            else:\n                torch.nn.utils.clip_grad_norm_(params, FLAGS.clip_max_norm)\n\n        optimizer.step()\n        scheduler.step()\n\n        if optimize_geometry:\n            optimizer_mesh.step()\n            scheduler_mesh.step()\n\n        if optimize_light:\n            optimizer_light.step()\n            scheduler_light.step()\n\n        # ==============================================================================================\n        #  Clamp trainables to reasonable range\n        # ==============================================================================================\n        with torch.no_grad():\n            if 'kd' in opt_material:\n                opt_material['kd'].clamp_()\n            if 'ks' in opt_material:\n                opt_material['ks'].clamp_()\n            if 'kd_back' in opt_material:\n                opt_material['kd_back'].clamp_()\n            if 'ks_back' in opt_material:\n                opt_material['ks_back'].clamp_()\n            if 'normal' in opt_material and not FLAGS.normal_only:\n                opt_material['normal'].clamp_()\n                opt_material['normal'].normalize_()\n            if lgt is not None:\n                # lgt.clamp_(min=0.01) # For some reason gradient dissapears if light becomes 0\n                lgt.clamp_(min=1e-4) # For some reason gradient dissapears if light becomes 0\n\n            geometry.clamp_deform()\n        torch.cuda.current_stream().synchronize()\n        iter_dur_vec.append(time.time() - iter_start_time)\n\n        # ==============================================================================================\n        #  Logging\n        # ==============================================================================================\n        if it % log_interval == 0 and FLAGS.local_rank == 0:\n            img_loss_avg = np.mean(np.asarray(img_loss_vec[-log_interval:]))\n            depth_loss_avg = np.mean(np.asarray(depth_loss_vec[-log_interval:]))\n            reg_loss_avg = np.mean(np.asarray(reg_loss_vec[-log_interval:]))\n            iter_dur_avg = np.mean(np.asarray(iter_dur_vec[-log_interval:]))\n            \n            remaining_time = (FLAGS.iter-it)*iter_dur_avg\n            print(\"iter=%5d, img_loss=%.6f, depth_loss=%.6f, reg_loss=%.6f, lr=%.5f, time=%.1f ms, rem=%s\" % \n                (it, img_loss_avg, depth_loss_avg, reg_loss_avg, optimizer.param_groups[0]['lr'], iter_dur_avg*1000, util.time_to_text(remaining_time)))\n            sys.stdout.flush()\n\n        if it == FLAGS.iter:\n            break\n\n    return geometry, 
opt_material\n\n#----------------------------------------------------------------------------\n# Main function.\n#----------------------------------------------------------------------------\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description='nvdiffrec')\n    parser.add_argument('--config', type=str, default=None, help='Config file')\n    parser.add_argument('-i', '--iter', type=int, default=5000)\n    parser.add_argument('-b', '--batch', type=int, default=1)\n    parser.add_argument('-s', '--spp', type=int, default=1)\n    parser.add_argument('-l', '--layers', type=int, default=1)\n    parser.add_argument('-r', '--train-res', nargs=2, type=int, default=[512, 512])\n    parser.add_argument('-dr', '--display-res', type=int, default=None)\n    parser.add_argument('-tr', '--texture-res', nargs=2, type=int, default=[1024, 1024])\n    parser.add_argument('-di', '--display-interval', type=int, default=0)\n    parser.add_argument('-si', '--save-interval', type=int, default=1000)\n    parser.add_argument('-lr', '--learning-rate', type=float, default=0.01)\n    parser.add_argument('-mr', '--min-roughness', type=float, default=0.08)\n    parser.add_argument('-mip', '--custom-mip', action='store_true', default=False)\n    parser.add_argument('-rt', '--random-textures', action='store_true', default=False)\n    parser.add_argument('-bg', '--background', default='checker', choices=['black', 'white', 'checker', 'reference'])\n    parser.add_argument('--loss', default='logl1', choices=['logl1', 'logl2', 'mse', 'smape', 'relmse'])\n    parser.add_argument('-o', '--out-dir', type=str, default=None)\n    parser.add_argument('-rm', '--ref_mesh', type=str)\n    parser.add_argument('-bm', '--base-mesh', type=str, default=None)\n    parser.add_argument('--validate', type=bool, default=True)\n    # Render specific arguments\n    parser.add_argument('--n_samples', type=int, default=4)\n    parser.add_argument('--bsdf', type=str, default='pbr', choices=['pbr', 'diffuse', 'white'])\n    # Denoiser specific arguments\n    parser.add_argument('--denoiser', default='bilateral', choices=['none', 'bilateral'])\n    parser.add_argument('--denoiser_demodulate', type=bool, default=True)\n    parser.add_argument('--index',type=int)\n    parser.add_argument('--trainset_path', type=str)\n    parser.add_argument('--testset_path', type=str, default='')\n    parser.add_argument('--msdf_reg_open_scale', type=float, default=1e-6)\n    parser.add_argument('--msdf_reg_close_scale', type=float, default=3e-6)\n    parser.add_argument('--eikonal_scale', type=float)\n\n    FLAGS = parser.parse_args()\n    FLAGS.mtl_override        = None        # Override material of model\n    FLAGS.gshell_grid          = 64          # Resolution of initial tet grid. We provide 64 and 128 resolution grids. \n                                            #    Other resolutions can be generated with https://github.com/crawforddoran/quartet\n                                            #    We include examples in data/tets/generate_tets.py\n    FLAGS.mesh_scale          = 1.4         # Scale of tet grid box. Adjust to cover the model\n    FLAGS.envlight            = None        # HDR environment probe\n    FLAGS.env_scale           = 1.0         # Env map intensity multiplier\n    FLAGS.probe_res           = 256         # Env map probe resolution\n    FLAGS.learn_lighting      = True        # Enable optimization of env lighting\n    FLAGS.display             = None        # Configure validation window/display. E.g. 
[{\"bsdf\" : \"kd\"}, {\"bsdf\" : \"ks\"}]\n    FLAGS.transparency        = False       # Enabled transparency through depth peeling\n    FLAGS.lock_light          = False       # Disable light optimization in the second pass\n    FLAGS.lock_pos            = False       # Disable vertex position optimization in the second pass\n    FLAGS.sdf_regularizer     = 0.2         # Weight for sdf regularizer.\n    FLAGS.laplace             = \"relative\"  # Mesh Laplacian [\"absolute\", \"relative\"]\n    FLAGS.laplace_scale       = 3000.0      # Weight for Laplace regularizer. Default is relative with large weight\n    FLAGS.pre_load            = True        # Pre-load entire dataset into memory for faster training\n    FLAGS.no_perturbed_nrm    = False       # Disable normal map\n    FLAGS.decorrelated        = False       # Use decorrelated sampling in forward and backward passes\n    FLAGS.kd_min              = [ 0.0,  0.0,  0.0,  0.0]\n    FLAGS.kd_max              = [ 1.0,  1.0,  1.0,  1.0]\n    # FLAGS.ks_min              = [ 0.0,  0.08, 0.0]\n    FLAGS.ks_min              = [ 0.0,  0.001, 0.0]\n    FLAGS.ks_max              = [ 0.0,  1.0,  1.0]\n    FLAGS.nrm_min             = [-1.0, -1.0,  0.0]\n    FLAGS.nrm_max             = [ 1.0,  1.0,  1.0]\n    FLAGS.clip_max_norm       = 0.0\n    FLAGS.cam_near_far        = [0.1, 1000.0]\n    FLAGS.lambda_kd           = 0.1\n    FLAGS.lambda_ks           = 0.05\n    FLAGS.lambda_nrm          = 0.025\n    FLAGS.lambda_nrm2         = 0.25\n    FLAGS.lambda_chroma       = 0.0\n    FLAGS.lambda_diffuse      = 0.15\n    FLAGS.lambda_specular     = 0.0025\n\n    FLAGS.random_lgt                  = False\n    FLAGS.normal_only                 = False\n    FLAGS.use_img_2nd_layer           = False\n    FLAGS.use_depth                   = False\n    FLAGS.use_depth_2nd_layer         = False\n    FLAGS.use_tanh_deform             = False\n    FLAGS.use_sdf_mlp                 = True\n    FLAGS.use_msdf_mlp                = False\n    FLAGS.use_eikonal                 = True\n    FLAGS.sdf_mlp_pretrain_steps      = 10000\n    FLAGS.use_mesh_msdf_reg           = True\n    FLAGS.sphere_init                 = False\n    FLAGS.sphere_init_norm            = 0.5\n    FLAGS.pretrained_sdf_mlp_path     = f'./data/pretrained_mlp_{FLAGS.gshell_grid}_deeper.pt'\n    FLAGS.n_hidden                    = 6\n    FLAGS.d_hidden                    = 256\n    FLAGS.n_freq                      = 6\n    FLAGS.skip_in                     = [3]\n    FLAGS.use_float16                 = False\n    FLAGS.visualize_watertight        = False\n\n    FLAGS.local_rank = 0\n    FLAGS.multi_gpu  = \"WORLD_SIZE\" in os.environ and int(os.environ[\"WORLD_SIZE\"]) > 1\n    if FLAGS.multi_gpu:\n        if \"MASTER_ADDR\" not in os.environ:\n            os.environ[\"MASTER_ADDR\"] = 'localhost'\n        if \"MASTER_PORT\" not in os.environ:\n            os.environ[\"MASTER_PORT\"] = '23456'\n\n        FLAGS.local_rank = int(os.environ[\"LOCAL_RANK\"])\n        torch.cuda.set_device(FLAGS.local_rank)\n        torch.distributed.init_process_group(backend=\"nccl\", init_method=\"env://\")\n\n    if FLAGS.config is not None:\n        data = json.load(open(FLAGS.config, 'r'))\n        for key in data:\n            FLAGS.__dict__[key] = data[key]\n\n    if FLAGS.display_res is None:\n        FLAGS.display_res = FLAGS.train_res\n\n    if FLAGS.local_rank == 0:\n        print(\"Config / Flags:\")\n        print(\"---------\")\n        for key in FLAGS.__dict__.keys():\n            print(key, 
FLAGS.__dict__[key])\n        print(\"---------\")\n\n    os.makedirs(FLAGS.out_dir, exist_ok=True)\n\n    glctx = dr.RasterizeGLContext()\n    glctx_display = glctx if FLAGS.batch < 16 else dr.RasterizeGLContext() # Context for display\n\n    mtl_default = None\n\n    # ==============================================================================================\n    #  Create data pipeline\n    # ==============================================================================================\n    dataset_path = FLAGS.trainset_path\n    testset_path = FLAGS.testset_path\n\n    folder_name_list = [30, 92, 117, 133, 164, 320, 448, 522, 591]\n    folder_name = folder_name_list[FLAGS.index]\n\n    folder_name = str(folder_name)\n    data_root = os.path.join(dataset_path, folder_name)\n    dataset_train    = DatasetDeepFashion(data_root, FLAGS, examples=int(1e6))\n    dataset_validate = DatasetDeepFashion(data_root, FLAGS)\n\n    if FLAGS.testset_path is not None and FLAGS.testset_path != '':\n        testdata_root = os.path.join(testset_path, folder_name)\n        dataset_test = DatasetDeepFashionTestset(testdata_root, FLAGS)\n\n\n\n    # ==============================================================================================\n    #  Create env light with trainable parameters\n    # ==============================================================================================\n    \n    lgt = None\n    if FLAGS.learn_lighting:\n        lgt = light.create_trainable_env_rnd(FLAGS.probe_res, scale=0.0, bias=0.5)\n        # lgt = light.create_trainable_env_rnd(FLAGS.probe_res, scale=0.0, bias=0.1)\n    else:\n        lgt = light.load_env(FLAGS.envlight, scale=FLAGS.env_scale, res=[FLAGS.probe_res, FLAGS.probe_res])\n\n    # ==============================================================================================\n    #  Setup denoiser\n    # ==============================================================================================\n\n    denoiser = None\n    if FLAGS.denoiser == 'bilateral':\n        denoiser = BilateralDenoiser().cuda()\n    else:\n        assert FLAGS.denoiser == 'none', \"Invalid denoiser %s\" % FLAGS.denoiser\n\n    # Setup geometry for optimization\n    geometry = GShellFlexiCubesGeometry(FLAGS.gshell_grid, FLAGS.mesh_scale, FLAGS)\n\n    # Setup textures, make initial guess from reference if possible\n    if not FLAGS.normal_only:\n        mat = initial_guess_material(geometry, True, FLAGS, mtl_default)\n    else:\n        mat = initial_guess_material_knownkskd(geometry, True, FLAGS, mtl_default)\n    mat['no_perturbed_nrm'] = True\n\n    # Run optimization\n    geometry, mat = optimize_mesh(denoiser, glctx, geometry, mat, lgt, dataset_train, dataset_validate, \n                    FLAGS, pass_idx=0, pass_name=\"pass1\", optimize_light=FLAGS.learn_lighting, save_path=os.path.join(FLAGS.out_dir, folder_name))\n\n    validate(glctx, geometry, mat, lgt, dataset_validate, os.path.join(FLAGS.out_dir, folder_name, \"validate\"), FLAGS, denoiser=denoiser, save_viz=True)\n    if FLAGS.testset_path is not None and FLAGS.testset_path != '':\n        validate(glctx, geometry, mat, lgt, dataset_test, os.path.join(FLAGS.out_dir, folder_name, \"test\"), FLAGS, denoiser=denoiser, save_viz=False)\n\n    with torch.no_grad():\n        os.makedirs(os.path.join(FLAGS.out_dir, folder_name, \"mesh\"), exist_ok=True)\n        torch.save(geometry.state_dict(), os.path.join(FLAGS.out_dir, folder_name, \"mesh/model.pt\"))\n        torch.save(mat['kd_ks'].state_dict(), 
os.path.join(FLAGS.out_dir, folder_name, \"mesh/mtl.pt\"))\n        light.save_env_map(os.path.join(FLAGS.out_dir, folder_name, \"mesh/probe.hdr\"), lgt)\n\n        # Create textured mesh from result\n        base_mesh = geometry.getMesh(mat)['imesh']\n\n        # Dump mesh for debugging.\n        os.makedirs(os.path.join(FLAGS.out_dir, folder_name, \"mesh\"), exist_ok=True)\n        obj.write_obj(os.path.join(FLAGS.out_dir, folder_name, \"mesh/\"), base_mesh, save_material=False)\n\n        # Free temporaries / cached memory\n        torch.cuda.empty_cache()\n        mat['kd_ks'].cleanup()\n        del mat['kd_ks']\n        if 'kd_ks_back' in mat:\n            mat['kd_ks_back'].cleanup()\n            del mat['kd_ks_back']\n\n        # Free temporaries / cached memory\n        torch.cuda.empty_cache()\n        del mat"
  },
  {
    "path": "train_gflexicubes_polycam.py",
    "content": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n# property and proprietary rights in and to this material, related\n# documentation and any modifications thereto. Any use, reproduction, \n# disclosure or distribution of this material and related documentation \n# without an express license agreement from NVIDIA CORPORATION or \n# its affiliates is strictly prohibited.\n\nimport os\nimport sys\nimport time\nimport argparse\nimport json\n\nimport numpy as np\nimport torch\nimport nvdiffrast.torch as dr\nimport xatlas\n\n# Import data readers / generators\nfrom dataset.dataset_nerf_colmap import DatasetNERF\n\n# Import topology / geometry trainers\nfrom geometry.gshell_flexicubes_geometry import GShellFlexiCubesGeometry\n\nimport render.renderutils as ru\nfrom render import obj\nfrom render import material\nfrom render import util\nfrom render import mesh\nfrom render import texture\nfrom render import mlptexture\nfrom render import light\nfrom render import render\n\n\nfrom denoiser.denoiser import BilateralDenoiser\n\nimport tqdm\n\nRADIUS = 3.0\n\n# Enable to debug back-prop anomalies\n# torch.autograd.set_detect_anomaly(True)\n\n###############################################################################\n# Loss setup\n###############################################################################\n\n@torch.no_grad()\ndef createLoss(FLAGS):\n    if FLAGS.loss == \"smape\":\n        return lambda img, ref: ru.image_loss(img, ref, loss='smape', tonemapper='none')\n    elif FLAGS.loss == \"mse\":\n        return lambda img, ref: ru.image_loss(img, ref, loss='mse', tonemapper='none')\n    elif FLAGS.loss == \"logl1\":\n        return lambda img, ref: ru.image_loss(img, ref, loss='l1', tonemapper='log_srgb')\n    elif FLAGS.loss == \"logl2\":\n        return lambda img, ref: ru.image_loss(img, ref, loss='mse', tonemapper='log_srgb')\n    elif FLAGS.loss == \"relmse\":\n        return lambda img, ref: ru.image_loss(img, ref, loss='relmse', tonemapper='none')\n    else:\n        assert False\n\n###############################################################################\n# Mix background into a dataset image\n###############################################################################\n\n@torch.no_grad()\ndef prepare_batch(target, bg_type='black'):\n    assert len(target['img'].shape) == 4, \"Image shape should be [n, h, w, c]\"\n    if bg_type == 'checker':\n        background = torch.tensor(util.checkerboard(target['img'].shape[1:3], 8), dtype=torch.float32, device='cuda')[None, ...]\n    elif bg_type == 'black':\n        background = torch.zeros(target['img'].shape[0:3] + (3,), dtype=torch.float32, device='cuda')\n    elif bg_type == 'white':\n        background = torch.ones(target['img'].shape[0:3] + (3,), dtype=torch.float32, device='cuda')\n    elif bg_type == 'reference':\n        background = target['img'][..., 0:3]\n    elif bg_type == 'random':\n        background = torch.rand(target['img'].shape[0:3] + (3,), dtype=torch.float32, device='cuda')\n    else:\n        assert False, \"Unknown background type %s\" % bg_type\n\n    target['mv'] = target['mv'].cuda()\n    target['mvp'] = target['mvp'].cuda()\n    target['campos'] = target['campos'].cuda()\n    target['img'] = target['img'].cuda()\n    target['background'] = background\n\n    target['img'] = torch.cat((torch.lerp(background, target['img'][..., 0:3], target['img'][..., 3:4]), 
target['img'][..., 3:4]), dim=-1)\n\n    return target\n\n###############################################################################\n# UV - map geometry & convert to a mesh\n###############################################################################\n\n@torch.no_grad()\ndef xatlas_uvmap(glctx, geometry, mat, FLAGS):\n    eval_mesh = geometry.getMesh(mat)\n    try:\n        eval_mesh = eval_mesh['imesh']\n    except:\n        pass\n    \n    # Create uvs with xatlas\n    v_pos = eval_mesh.v_pos.detach().cpu().numpy()\n    t_pos_idx = eval_mesh.t_pos_idx.detach().cpu().numpy()\n    vmapping, indices, uvs = xatlas.parametrize(v_pos, t_pos_idx)\n\n    # Convert to tensors\n    indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)\n    \n    uvs = torch.tensor(uvs, dtype=torch.float32, device='cuda')\n    faces = torch.tensor(indices_int64, dtype=torch.int64, device='cuda')\n\n    new_mesh = mesh.Mesh(v_tex=uvs, t_tex_idx=faces, base=eval_mesh)\n\n    mask, kd, ks = render.render_uv(glctx, new_mesh, FLAGS.texture_res, eval_mesh.material['kd_ks'])\n\n    # Dilate all textures & use average color for background\n    kd_avg = torch.sum(torch.sum(torch.sum(kd * mask, dim=0), dim=0), dim=0) / torch.sum(torch.sum(torch.sum(mask, dim=0), dim=0), dim=0)\n    kd = util.dilate(kd, kd_avg[None, None, None, :], mask, 7)\n\n    ks_avg = torch.sum(torch.sum(torch.sum(ks * mask, dim=0), dim=0), dim=0) / torch.sum(torch.sum(torch.sum(mask, dim=0), dim=0), dim=0)\n    ks = util.dilate(ks, ks_avg[None, None, None, :], mask, 7)\n\n    nrm_avg = torch.tensor([0, 0, 1], dtype=torch.float32, device=\"cuda\")\n    normal = nrm_avg[None, None, None, :].repeat(kd.shape[0], kd.shape[1], kd.shape[2], 1)\n    \n    new_mesh.material = mat.copy()\n    del new_mesh.material['kd_ks']\n\n    if FLAGS.transparency:\n        kd = torch.cat((kd, torch.rand_like(kd[...,0:1])), dim=-1)\n        print(\"kd shape\", kd.shape)\n\n    kd_min, kd_max = torch.tensor(FLAGS.kd_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.kd_max, dtype=torch.float32, device='cuda')\n    ks_min, ks_max = torch.tensor(FLAGS.ks_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.ks_max, dtype=torch.float32, device='cuda')\n    nrm_min, nrm_max = torch.tensor(FLAGS.nrm_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.nrm_max, dtype=torch.float32, device='cuda')\n    new_mesh.material.update({\n        'kd'     : texture.Texture2D(kd.clone().detach().requires_grad_(True), min_max=[kd_min, kd_max]),\n        'ks'     : texture.Texture2D(ks.clone().detach().requires_grad_(True), min_max=[ks_min, ks_max]),\n        'normal' : texture.Texture2D(normal.clone().detach().requires_grad_(True), min_max=[nrm_min, nrm_max]),\n    })\n\n    return new_mesh\n\n###############################################################################\n# Utility functions for material\n###############################################################################\n\ndef initial_guess_material(geometry, mlp, FLAGS, init_mat=None):\n    kd_min, kd_max = torch.tensor(FLAGS.kd_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.kd_max, dtype=torch.float32, device='cuda')\n    ks_min, ks_max = torch.tensor(FLAGS.ks_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.ks_max, dtype=torch.float32, device='cuda')\n    if mlp:\n        mlp_min = torch.cat((kd_min[0:3], ks_min), dim=0)\n        mlp_max = torch.cat((kd_max[0:3], ks_max), dim=0)\n        mlp_map_opt = mlptexture.MLPTexture3D(geometry.getAABB(), 
channels=6, min_max=[mlp_min, mlp_max], use_float16=FLAGS.use_float16)\n        mat =  {'kd_ks' : mlp_map_opt}\n    else:\n        raise NotImplementedError\n\n    mat['bsdf'] = FLAGS.bsdf\n\n    mat['no_perturbed_nrm'] = FLAGS.no_perturbed_nrm\n\n    return mat\n\ndef initial_guess_material_knownkskd(geometry, mlp, FLAGS, init_mat=None):\n    mat =  {\n        'kd'     : init_mat['kd'],\n        'ks'     : init_mat['ks']\n    }\n\n    if init_mat is not None:\n        mat['bsdf'] = init_mat['bsdf']\n    else:\n        mat['bsdf'] = 'pbr'\n\n    return mat\n\n###############################################################################\n# Validation & testing\n###############################################################################\n\n@torch.no_grad()\ndef validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS, denoiser=None):\n    result_dict = {}\n    with torch.no_grad():\n        buffers = geometry.render(glctx, target, lgt, opt_material, use_uv=False, denoiser=denoiser)['buffers']\n\n        result_dict['ref'] = util.rgb_to_srgb(target['img'][...,0:3])[0]\n        result_dict['opt'] = util.rgb_to_srgb(buffers['shaded'][...,0:3])[0]\n        result_dict['mask_opt'] = buffers['shaded'][...,3:][0].expand(-1, -1, 3)\n        result_dict['mask_ref'] = target['img'][...,3:][0].expand(-1, -1, 3)\n        result_dict['msdf_image'] = buffers['msdf_image'][...,:][0].expand(-1, -1, 3).clamp(min=0, max=1)\n        result_image = torch.cat([result_dict['opt'], result_dict['ref'], result_dict['mask_opt'], result_dict['mask_ref'], result_dict['msdf_image']], axis=1)\n\n        result_dict = {}\n        result_dict['ref'] = util.rgb_to_srgb(target['img'][...,0:3])[0]\n        result_dict['opt'] = util.rgb_to_srgb(buffers['shaded'][...,0:3])[0]\n\n        return result_image, result_dict\n\n@torch.no_grad()\ndef validate(glctx, geometry, opt_material, lgt, dataset_validate, out_dir, FLAGS, denoiser=None, save_viz=False):\n\n    # ==============================================================================================\n    #  Validation loop\n    # ==============================================================================================\n    mse_values = []\n    psnr_values = []\n\n    dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_validate.collate)\n\n    os.makedirs(out_dir, exist_ok=True)\n    with open(os.path.join(out_dir, 'metrics.txt'), 'w') as fout:\n        fout.write('ID, MSE, PSNR\\n')\n\n        print(\"Running validation\")\n        for it, target in enumerate(tqdm.tqdm(dataloader_validate)):\n\n            # Mix validation background\n            target = prepare_batch(target, FLAGS.background)\n\n            result_image, result_dict = validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS, denoiser=denoiser)\n\n            # Compute metrics\n            opt = torch.clamp(result_dict['opt'], 0.0, 1.0) \n            ref = torch.clamp(result_dict['ref'], 0.0, 1.0)\n\n            mse = torch.nn.functional.mse_loss(opt, ref, size_average=None, reduce=None, reduction='mean').item()\n            mse_values.append(float(mse))\n            psnr = util.mse_to_psnr(mse)\n            psnr_values.append(float(psnr))\n\n            line = \"%d, %1.8f, %1.8f\\n\" % (it, mse, psnr)\n            fout.write(str(line))\n\n            if save_viz:\n                for k in result_dict.keys():\n                    np_img = result_dict[k].detach().cpu().numpy()\n                    util.save_image(out_dir + '/' + 
('val_%06d_%s.png' % (it, k)), np_img)\n\n        avg_mse = np.mean(np.array(mse_values))\n        avg_psnr = np.mean(np.array(psnr_values))\n        line = \"AVERAGES: %1.4f, %2.3f\\n\" % (avg_mse, avg_psnr)\n        fout.write(str(line))\n        print(\"MSE,      PSNR\")\n        print(\"%1.8f, %2.3f\" % (avg_mse, avg_psnr))\n    return avg_psnr\n\n###############################################################################\n# Main shape fitter function / optimization loop\n###############################################################################\n\ndef optimize_mesh(\n        denoiser,\n        glctx,\n        geometry,\n        opt_material,\n        lgt,\n        dataset_train,\n        dataset_validate,\n        FLAGS,\n        warmup_iter=0,\n        log_interval=10,\n        pass_idx=0,\n        pass_name=\"\",\n        optimize_light=True,\n        optimize_geometry=True,\n        visualize=True,\n        save_path=None\n    ):\n\n    # ==============================================================================================\n    #  Setup torch optimizer\n    # ==============================================================================================\n\n    learning_rate = FLAGS.learning_rate[pass_idx] if isinstance(FLAGS.learning_rate, list) or isinstance(FLAGS.learning_rate, tuple) else FLAGS.learning_rate\n    learning_rate_pos = learning_rate[0] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate\n    learning_rate_mat = learning_rate[1] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate\n    learning_rate_lgt = learning_rate[2] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate * 6.0\n\n    def lr_schedule(iter, fraction):\n        if iter < warmup_iter:\n            return iter / warmup_iter \n        return max(0.0, 10**(-(iter - warmup_iter)*0.0002)) # Exponential falloff from [1.0, 0.1] over 5k epochs.    
\n\n    # ==============================================================================================\n    #  Image loss\n    # ==============================================================================================\n    image_loss_fn = createLoss(FLAGS)\n\n\n\n    params = list(material.get_parameters(opt_material))\n\n    if optimize_light:\n        optimizer_light = torch.optim.Adam((lgt.parameters() if lgt is not None else []), lr=learning_rate_lgt)\n        scheduler_light = torch.optim.lr_scheduler.LambdaLR(optimizer_light, lr_lambda=lambda x: lr_schedule(x, 0.9)) \n\n    if optimize_geometry:\n        if FLAGS.use_sdf_mlp:\n            deform_params = list(v[1] for v in geometry.named_parameters() if 'deform' in v[0]) if optimize_geometry else []\n            msdf_params = list(v[1] for v in geometry.named_parameters() if 'msdf' in v[0]) if optimize_geometry else []\n            sdf_params = list(v[1] for v in geometry.named_parameters() if 'sdf' in v[0] and 'msdf' not in v[0]) if optimize_geometry else []\n            other_params = list(v[1] for v in geometry.named_parameters() if 'sdf' not in v[0] and 'msdf' not in v[0] and 'deform' not in v[0]) if optimize_geometry else []\n            optimizer_mesh = torch.optim.Adam([\n                    {'params': deform_params, 'lr': learning_rate_pos},\n                    {'params': msdf_params, 'lr': learning_rate_pos},\n                    {'params': sdf_params, 'lr': learning_rate_pos * 1e-2},\n                    {'params': other_params, 'lr': learning_rate_pos * 1e-2},\n                ], eps=1e-8)\n        else:\n            optimizer_mesh = torch.optim.Adam(geometry.parameters(), lr=learning_rate_pos)\n        scheduler_mesh = torch.optim.lr_scheduler.LambdaLR(optimizer_mesh, lr_lambda=lambda x: lr_schedule(x, 0.9)) \n\n    optimizer = torch.optim.Adam(params, lr=learning_rate_mat)\n    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: lr_schedule(x, 0.9))\n\n    # ==============================================================================================\n    #  Training loop\n    # ==============================================================================================\n    img_cnt = 0\n    img_loss_vec = []\n    depth_loss_vec = []\n    reg_loss_vec = []\n    iter_dur_vec = []\n\n    dataloader_train    = torch.utils.data.DataLoader(dataset_train, batch_size=FLAGS.batch, collate_fn=dataset_train.collate, shuffle=True)\n    if visualize:\n        dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_train.collate)\n\n        def cycle(iterable):\n            iterator = iter(iterable)\n            while True:\n                try:\n                    yield next(iterator)\n                except StopIteration:\n                    iterator = iter(iterable)\n\n        v_it = cycle(dataloader_validate)\n\n    for it, target in enumerate(dataloader_train):\n\n        # Mix randomized background into dataset image\n        target = prepare_batch(target, 'random')\n\n        # ==============================================================================================\n        #  Display / save outputs. 
Do it before training so we get initial meshes\n        # ==============================================================================================\n\n        # Show/save image before training step (want to get correct rendering of input)\n        if visualize and FLAGS.local_rank == 0 and it != 0:\n            with torch.no_grad():\n                display_image = FLAGS.display_interval and (it % FLAGS.display_interval == 0)\n                save_image = FLAGS.save_interval and (it % FLAGS.save_interval == 0)\n                if display_image or save_image:\n                    save_mesh = True\n                    if save_mesh:\n                        os.makedirs(os.path.join(save_path, pass_name), exist_ok=True)\n                        obj.write_obj(os.path.join(save_path, pass_name), geometry.getMesh(opt_material)['imesh'], save_material=False)\n                    result_image, result_dict = validate_itr(glctx, prepare_batch(next(v_it), FLAGS.background), geometry, opt_material, lgt, FLAGS, denoiser=denoiser)\n            \n                    np_result_image = result_image.detach().cpu().numpy()\n                    if display_image:\n                        util.display_image(np_result_image, title='%d / %d' % (it, FLAGS.iter))\n                    if save_image:\n                        util.save_image(os.path.join(save_path, ('img_%s_%06d.png' % (pass_name, img_cnt))), np_result_image)\n                        img_cnt = img_cnt + 1\n\n        iter_start_time = time.time()\n\n        # ==============================================================================================\n        #  Zero gradients\n        # ==============================================================================================\n        optimizer.zero_grad()\n        if optimize_geometry:\n            optimizer_mesh.zero_grad()\n        if optimize_light:\n            optimizer_light.zero_grad()\n\n        # ==============================================================================================\n        #  Training\n        # ==============================================================================================\n\n        xfm_lgt = None\n        if optimize_light:\n            if False and FLAGS.camera_space_light:\n                lgt.xfm(target['mv'])\n            elif False and ('envlight_transform' in target and target['envlight_transform'] is not None):\n                xfm_lgt = target['envlight_transform']\n                lgt.xfm(xfm_lgt)\n            lgt.update_pdf()\n            \n\n        img_loss, depth_loss, reg_loss = geometry.tick(\n            glctx, target, lgt, opt_material, image_loss_fn, it, \n            denoiser=denoiser)\n\n        # ==============================================================================================\n        #  Final loss\n        # ==============================================================================================\n        total_loss = img_loss + reg_loss\n\n        img_loss_vec.append(img_loss.item())\n        depth_loss_vec.append(depth_loss.item())\n        reg_loss_vec.append(reg_loss.item())\n\n        # ==============================================================================================\n        #  Backpropagate\n        # ==============================================================================================\n        total_loss.backward()\n        if hasattr(lgt, 'base') and lgt.base.grad is not None and optimize_light:\n            lgt.base.grad *= 64\n        if 'kd_ks' in opt_material:\n   
         opt_material['kd_ks'].encoder.params.grad /= 8.0\n        if 'kd_ks_back' in opt_material:\n            opt_material['kd_ks_back'].encoder.params.grad /= 8.0\n\n        # Optionally clip gradients\n        if FLAGS.clip_max_norm > 0.0:\n            if optimize_geometry:\n                torch.nn.utils.clip_grad_norm_(geometry.parameters() + params, FLAGS.clip_max_norm)\n            else:\n                torch.nn.utils.clip_grad_norm_(params, FLAGS.clip_max_norm)\n\n        optimizer.step()\n        scheduler.step()\n\n        if optimize_geometry:\n            optimizer_mesh.step()\n            scheduler_mesh.step()\n\n        if optimize_light:\n            optimizer_light.step()\n            scheduler_light.step()\n\n        # ==============================================================================================\n        #  Clamp trainables to reasonable range\n        # ==============================================================================================\n        with torch.no_grad():\n            if 'kd' in opt_material:\n                opt_material['kd'].clamp_()\n            if 'ks' in opt_material:\n                opt_material['ks'].clamp_()\n            if 'kd_back' in opt_material:\n                opt_material['kd_back'].clamp_()\n            if 'ks_back' in opt_material:\n                opt_material['ks_back'].clamp_()\n            if 'normal' in opt_material and not FLAGS.normal_only:\n                opt_material['normal'].clamp_()\n                opt_material['normal'].normalize_()\n            if lgt is not None:\n                # lgt.clamp_(min=0.01) # For some reason gradient dissapears if light becomes 0\n                lgt.clamp_(min=1e-4) # For some reason gradient dissapears if light becomes 0\n\n            geometry.clamp_deform()\n        torch.cuda.current_stream().synchronize()\n        iter_dur_vec.append(time.time() - iter_start_time)\n\n        # ==============================================================================================\n        #  Logging\n        # ==============================================================================================\n        if it % log_interval == 0 and FLAGS.local_rank == 0:\n            img_loss_avg = np.mean(np.asarray(img_loss_vec[-log_interval:]))\n            depth_loss_avg = np.mean(np.asarray(depth_loss_vec[-log_interval:]))\n            reg_loss_avg = np.mean(np.asarray(reg_loss_vec[-log_interval:]))\n            iter_dur_avg = np.mean(np.asarray(iter_dur_vec[-log_interval:]))\n            \n            remaining_time = (FLAGS.iter-it)*iter_dur_avg\n            print(\"iter=%5d, img_loss=%.6f, depth_loss=%.6f, reg_loss=%.6f, lr=%.5f, time=%.1f ms, rem=%s\" % \n                (it, img_loss_avg, depth_loss_avg, reg_loss_avg, optimizer.param_groups[0]['lr'], iter_dur_avg*1000, util.time_to_text(remaining_time)))\n            sys.stdout.flush()\n\n        if it == FLAGS.iter:\n            break\n\n    return geometry, opt_material\n\n#----------------------------------------------------------------------------\n# Main function.\n#----------------------------------------------------------------------------\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description='nvdiffrec')\n    parser.add_argument('--config', type=str, default=None, help='Config file')\n    parser.add_argument('-i', '--iter', type=int, default=5000)\n    parser.add_argument('-b', '--batch', type=int, default=1)\n    parser.add_argument('-s', '--spp', type=int, default=1)\n    
parser.add_argument('-l', '--layers', type=int, default=1)\n    parser.add_argument('-r', '--train-res', nargs=2, type=int, default=[512, 512])\n    parser.add_argument('-dr', '--display-res', type=int, default=None)\n    parser.add_argument('-tr', '--texture-res', nargs=2, type=int, default=[1024, 1024])\n    parser.add_argument('-di', '--display-interval', type=int, default=0)\n    parser.add_argument('-si', '--save-interval', type=int, default=1000)\n    parser.add_argument('-lr', '--learning-rate', type=float, default=0.01)\n    parser.add_argument('-mr', '--min-roughness', type=float, default=0.08)\n    parser.add_argument('-mip', '--custom-mip', action='store_true', default=False)\n    parser.add_argument('-rt', '--random-textures', action='store_true', default=False)\n    parser.add_argument('-bg', '--background', default='checker', choices=['black', 'white', 'checker', 'reference'])\n    parser.add_argument('--loss', default='logl1', choices=['logl1', 'logl2', 'mse', 'smape', 'relmse'])\n    parser.add_argument('-o', '--out-dir', type=str, default=None)\n    parser.add_argument('-rm', '--ref_mesh', type=str)\n    parser.add_argument('-bm', '--base-mesh', type=str, default=None)\n    parser.add_argument('--validate', type=bool, default=True)\n    # Render specific arguments\n    parser.add_argument('--n_samples', type=int, default=4)\n    parser.add_argument('--bsdf', type=str, default='pbr', choices=['pbr', 'diffuse', 'white'])\n    # Denoiser specific arguments\n    parser.add_argument('--denoiser', default='bilateral', choices=['none', 'bilateral'])\n    parser.add_argument('--denoiser_demodulate', type=bool, default=True)\n    parser.add_argument('--msdf_reg_open_scale', type=float, default=1e-6)\n    parser.add_argument('--msdf_reg_close_scale', type=float, default=3e-4)\n    parser.add_argument('--eikonal_scale', type=float, default=5e-2)\n    parser.add_argument('--trainset_path', type=str)\n    parser.add_argument('--testset_path', type=str, default='')\n\n    FLAGS = parser.parse_args()\n    FLAGS.mtl_override        = None        # Override material of model\n    FLAGS.gshell_grid          = 64          # Resolution of initial tet grid. We provide 64 and 128 resolution grids. \n                                            #    Other resolutions can be generated with https://github.com/crawforddoran/quartet\n                                            #    We include examples in data/tets/generate_tets.py\n    FLAGS.mesh_scale          = 3.6         # Scale of tet grid box. Adjust to cover the model\n    FLAGS.envlight            = None        # HDR environment probe\n    FLAGS.env_scale           = 1.0         # Env map intensity multiplier\n    FLAGS.probe_res           = 256         # Env map probe resolution\n    FLAGS.learn_lighting      = True        # Enable optimization of env lighting\n    FLAGS.display             = None        # Configure validation window/display. E.g. [{\"bsdf\" : \"kd\"}, {\"bsdf\" : \"ks\"}]\n    FLAGS.transparency        = False       # Enabled transparency through depth peeling\n    FLAGS.lock_light          = False       # Disable light optimization in the second pass\n    FLAGS.lock_pos            = False       # Disable vertex position optimization in the second pass\n    FLAGS.sdf_regularizer     = 0.2         # Weight for sdf regularizer.\n    FLAGS.laplace             = \"relative\"  # Mesh Laplacian [\"absolute\", \"relative\"]\n    FLAGS.laplace_scale       = 3000.0      # Weight for Laplace regularizer. 
Default is relative with large weight\n    FLAGS.pre_load            = True        # Pre-load entire dataset into memory for faster training\n    FLAGS.no_perturbed_nrm    = False       # Disable normal map\n    FLAGS.decorrelated        = False       # Use decorrelated sampling in forward and backward passes\n    FLAGS.kd_min              = [ 0.0,  0.0,  0.0,  0.0]\n    FLAGS.kd_max              = [ 1.0,  1.0,  1.0,  1.0]\n    # FLAGS.ks_min              = [ 0.0,  0.08, 0.0]\n    FLAGS.ks_min              = [ 0.0,  0.001, 0.0]\n    FLAGS.ks_max              = [ 0.0,  1.0,  1.0]\n    FLAGS.nrm_min             = [-1.0, -1.0,  0.0]\n    FLAGS.nrm_max             = [ 1.0,  1.0,  1.0]\n    FLAGS.clip_max_norm       = 0.0\n    FLAGS.cam_near_far        = [0.1, 1000.0]\n    FLAGS.lambda_kd           = 0.1\n    FLAGS.lambda_ks           = 0.05\n    FLAGS.lambda_nrm          = 0.025\n    FLAGS.lambda_nrm2         = 0.25\n    FLAGS.lambda_chroma       = 0.0\n    FLAGS.lambda_diffuse      = 0.15\n    FLAGS.lambda_specular     = 0.0025\n\n    FLAGS.random_lgt                  = False\n    FLAGS.normal_only                 = False\n    FLAGS.use_img_2nd_layer           = False\n    FLAGS.use_depth                   = False\n    FLAGS.use_depth_2nd_layer         = False\n    FLAGS.use_tanh_deform             = False\n    FLAGS.use_sdf_mlp                 = True\n    FLAGS.use_msdf_mlp                = False\n    FLAGS.use_eikonal                 = True\n    FLAGS.sdf_mlp_pretrain_steps      = 10000\n    FLAGS.use_mesh_msdf_reg           = True\n    FLAGS.sphere_init                 = False\n    FLAGS.sphere_init_norm            = 1.5\n    FLAGS.pretrained_sdf_mlp_path     = f'./data/pretrained_mlp_{FLAGS.gshell_grid}_polycam.pt'\n    FLAGS.n_hidden                    = 6\n    FLAGS.d_hidden                    = 256\n    FLAGS.n_freq                      = 6\n    FLAGS.skip_in                     = [3]\n    FLAGS.use_float16                 = False\n    FLAGS.visualize_watertight        = False\n\n    FLAGS.local_rank = 0\n    FLAGS.multi_gpu  = \"WORLD_SIZE\" in os.environ and int(os.environ[\"WORLD_SIZE\"]) > 1\n    if FLAGS.multi_gpu:\n        if \"MASTER_ADDR\" not in os.environ:\n            os.environ[\"MASTER_ADDR\"] = 'localhost'\n        if \"MASTER_PORT\" not in os.environ:\n            os.environ[\"MASTER_PORT\"] = '23456'\n\n        FLAGS.local_rank = int(os.environ[\"LOCAL_RANK\"])\n        torch.cuda.set_device(FLAGS.local_rank)\n        torch.distributed.init_process_group(backend=\"nccl\", init_method=\"env://\")\n\n    if FLAGS.config is not None:\n        data = json.load(open(FLAGS.config, 'r'))\n        for key in data:\n            FLAGS.__dict__[key] = data[key]\n\n    if FLAGS.display_res is None:\n        FLAGS.display_res = FLAGS.train_res\n\n    if FLAGS.local_rank == 0:\n        print(\"Config / Flags:\")\n        print(\"---------\")\n        for key in FLAGS.__dict__.keys():\n            print(key, FLAGS.__dict__[key])\n        print(\"---------\")\n\n    os.makedirs(FLAGS.out_dir, exist_ok=True)\n\n    glctx = dr.RasterizeGLContext()\n    glctx_display = glctx if FLAGS.batch < 16 else dr.RasterizeGLContext() # Context for display\n\n    mtl_default = None\n\n    # ==============================================================================================\n    #  Create data pipeline\n    # ==============================================================================================\n    data_root = FLAGS.trainset_path\n\n\n    dataset_train    = 
DatasetNERF(os.path.join(data_root, 'transforms.json'), FLAGS, examples=int(1e6))\n    dataset_validate = DatasetNERF(os.path.join(data_root, 'transforms.json'), FLAGS)\n\n\n\n    # ==============================================================================================\n    #  Create env light with trainable parameters\n    # ==============================================================================================\n    \n    lgt = None\n    if FLAGS.learn_lighting:\n        lgt = light.create_trainable_env_rnd(FLAGS.probe_res, scale=0.0, bias=0.5)\n        # lgt = light.create_trainable_env_rnd(FLAGS.probe_res, scale=0.0, bias=0.1)\n    else:\n        lgt = light.load_env(FLAGS.envlight, scale=FLAGS.env_scale, res=[FLAGS.probe_res, FLAGS.probe_res])\n\n    # ==============================================================================================\n    #  Setup denoiser\n    # ==============================================================================================\n\n    denoiser = None\n    if FLAGS.denoiser == 'bilateral':\n        denoiser = BilateralDenoiser().cuda()\n    else:\n        assert FLAGS.denoiser == 'none', \"Invalid denoiser %s\" % FLAGS.denoiser\n\n    # Setup geometry for optimization\n    geometry = GShellFlexiCubesGeometry(FLAGS.gshell_grid, FLAGS.mesh_scale, FLAGS)\n\n    # Setup textures, make initial guess from reference if possible\n    if not FLAGS.normal_only:\n        mat = initial_guess_material(geometry, True, FLAGS, mtl_default)\n    else:\n        mat = initial_guess_material_knownkskd(geometry, True, FLAGS, mtl_default)\n    mat['no_perturbed_nrm'] = True\n\n    # Run optimization\n    geometry, mat = optimize_mesh(denoiser, glctx, geometry, mat, lgt, dataset_train, dataset_validate, \n                    FLAGS, pass_idx=0, pass_name=\"pass1\", optimize_light=FLAGS.learn_lighting, save_path=FLAGS.out_dir)\n\n    validate(glctx, geometry, mat, lgt, dataset_validate, os.path.join(FLAGS.out_dir, \"validate\"), FLAGS, denoiser=denoiser, save_viz=True)\n\n    with torch.no_grad():\n        os.makedirs(os.path.join(FLAGS.out_dir, \"mesh\"), exist_ok=True)\n        torch.save(geometry.state_dict(), os.path.join(FLAGS.out_dir, \"mesh/model.pt\"))\n        torch.save(mat['kd_ks'].state_dict(), os.path.join(FLAGS.out_dir, \"mesh/mtl.pt\"))\n        light.save_env_map(os.path.join(FLAGS.out_dir, \"mesh/probe.hdr\"), lgt)\n\n        # Create textured mesh from result\n        base_mesh = geometry.getMesh(mat)['imesh']\n\n\n        # Dump mesh for debugging.\n        os.makedirs(os.path.join(FLAGS.out_dir, \"mesh\"), exist_ok=True)\n        obj.write_obj(os.path.join(FLAGS.out_dir, \"mesh/\"), base_mesh, save_material=False)\n\n        # Free temporaries / cached memory\n        torch.cuda.empty_cache()\n        mat['kd_ks'].cleanup()\n        del mat['kd_ks']\n        if 'kd_ks_back' in mat:\n            mat['kd_ks_back'].cleanup()\n            del mat['kd_ks_back']\n\n        # Free temporaries / cached memory\n        torch.cuda.empty_cache()\n        del mat"
  },
  {
    "path": "train_gshelltet_deepfashion.py",
    "content": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n# property and proprietary rights in and to this material, related\n# documentation and any modifications thereto. Any use, reproduction, \n# disclosure or distribution of this material and related documentation \n# without an express license agreement from NVIDIA CORPORATION or \n# its affiliates is strictly prohibited.\n\nimport os\nimport sys\nimport time\nimport argparse\nimport json\n\nimport numpy as np\nimport torch\nimport nvdiffrast.torch as dr\nimport xatlas\n\n# Import data readers / generators\nfrom dataset.dataset_deepfashion import DatasetDeepFashion\nfrom dataset.dataset_deepfashion_testset import DatasetDeepFashionTestset\n\n# Import topology / geometry trainers\nfrom geometry.gshell_tets_geometry import GShellTetsGeometry\n\nimport render.renderutils as ru\nfrom render import obj\nfrom render import material\nfrom render import util\nfrom render import mesh\nfrom render import texture\nfrom render import mlptexture\nfrom render import light\nfrom render import render\n\n\nfrom denoiser.denoiser import BilateralDenoiser\n\n\nRADIUS = 3.0\n\n# Enable to debug back-prop anomalies\n# torch.autograd.set_detect_anomaly(True)\n\n###############################################################################\n# Loss setup\n###############################################################################\n\n@torch.no_grad()\ndef createLoss(FLAGS):\n    if FLAGS.loss == \"smape\":\n        return lambda img, ref: ru.image_loss(img, ref, loss='smape', tonemapper='none')\n    elif FLAGS.loss == \"mse\":\n        return lambda img, ref: ru.image_loss(img, ref, loss='mse', tonemapper='none')\n    elif FLAGS.loss == \"logl1\":\n        return lambda img, ref: ru.image_loss(img, ref, loss='l1', tonemapper='log_srgb')\n    elif FLAGS.loss == \"logl2\":\n        return lambda img, ref: ru.image_loss(img, ref, loss='mse', tonemapper='log_srgb')\n    elif FLAGS.loss == \"relmse\":\n        return lambda img, ref: ru.image_loss(img, ref, loss='relmse', tonemapper='none')\n    else:\n        assert False\n\n###############################################################################\n# Mix background into a dataset image\n###############################################################################\n\n@torch.no_grad()\ndef prepare_batch(target, bg_type='black'):\n    assert len(target['img'].shape) == 4, \"Image shape should be [n, h, w, c]\"\n    if bg_type == 'checker':\n        background = torch.tensor(util.checkerboard(target['img'].shape[1:3], 8), dtype=torch.float32, device='cuda')[None, ...]\n    elif bg_type == 'black':\n        background = torch.zeros(target['img'].shape[0:3] + (3,), dtype=torch.float32, device='cuda')\n    elif bg_type == 'white':\n        background = torch.ones(target['img'].shape[0:3] + (3,), dtype=torch.float32, device='cuda')\n    elif bg_type == 'reference':\n        background = target['img'][..., 0:3]\n    elif bg_type == 'random':\n        background = torch.rand(target['img'].shape[0:3] + (3,), dtype=torch.float32, device='cuda')\n    else:\n        assert False, \"Unknown background type %s\" % bg_type\n\n    target['mv'] = target['mv'].cuda()\n    target['mvp'] = target['mvp'].cuda()\n    target['campos'] = target['campos'].cuda()\n    target['img'] = target['img'].cuda()\n    target['background'] = background\n\n    target['img'] = torch.cat((torch.lerp(background, 
target['img'][..., 0:3], target['img'][..., 3:4]), target['img'][..., 3:4]), dim=-1)\n\n    return target\n\n###############################################################################\n# UV - map geometry & convert to a mesh\n###############################################################################\n\n@torch.no_grad()\ndef xatlas_uvmap(glctx, geometry, mat, FLAGS):\n    eval_mesh = geometry.getMesh(mat)\n    try:\n        eval_mesh = eval_mesh['imesh']\n    except:\n        pass\n    \n    # Create uvs with xatlas\n    v_pos = eval_mesh.v_pos.detach().cpu().numpy()\n    t_pos_idx = eval_mesh.t_pos_idx.detach().cpu().numpy()\n    vmapping, indices, uvs = xatlas.parametrize(v_pos, t_pos_idx)\n\n    # Convert to tensors\n    indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)\n    \n    uvs = torch.tensor(uvs, dtype=torch.float32, device='cuda')\n    faces = torch.tensor(indices_int64, dtype=torch.int64, device='cuda')\n\n    new_mesh = mesh.Mesh(v_tex=uvs, t_tex_idx=faces, base=eval_mesh)\n\n    mask, kd, ks = render.render_uv(glctx, new_mesh, FLAGS.texture_res, eval_mesh.material['kd_ks'])\n\n    # Dilate all textures & use average color for background\n    kd_avg = torch.sum(torch.sum(torch.sum(kd * mask, dim=0), dim=0), dim=0) / torch.sum(torch.sum(torch.sum(mask, dim=0), dim=0), dim=0)\n    kd = util.dilate(kd, kd_avg[None, None, None, :], mask, 7)\n\n    ks_avg = torch.sum(torch.sum(torch.sum(ks * mask, dim=0), dim=0), dim=0) / torch.sum(torch.sum(torch.sum(mask, dim=0), dim=0), dim=0)\n    ks = util.dilate(ks, ks_avg[None, None, None, :], mask, 7)\n\n    nrm_avg = torch.tensor([0, 0, 1], dtype=torch.float32, device=\"cuda\")\n    normal = nrm_avg[None, None, None, :].repeat(kd.shape[0], kd.shape[1], kd.shape[2], 1)\n    \n    new_mesh.material = mat.copy()\n    del new_mesh.material['kd_ks']\n\n    if FLAGS.transparency:\n        kd = torch.cat((kd, torch.rand_like(kd[...,0:1])), dim=-1)\n        print(\"kd shape\", kd.shape)\n\n    kd_min, kd_max = torch.tensor(FLAGS.kd_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.kd_max, dtype=torch.float32, device='cuda')\n    ks_min, ks_max = torch.tensor(FLAGS.ks_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.ks_max, dtype=torch.float32, device='cuda')\n    nrm_min, nrm_max = torch.tensor(FLAGS.nrm_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.nrm_max, dtype=torch.float32, device='cuda')\n    new_mesh.material.update({\n        'kd'     : texture.Texture2D(kd.clone().detach().requires_grad_(True), min_max=[kd_min, kd_max]),\n        'ks'     : texture.Texture2D(ks.clone().detach().requires_grad_(True), min_max=[ks_min, ks_max]),\n        'normal' : texture.Texture2D(normal.clone().detach().requires_grad_(True), min_max=[nrm_min, nrm_max]),\n    })\n\n    return new_mesh\n\n###############################################################################\n# Utility functions for material\n###############################################################################\n\ndef initial_guess_material(geometry, mlp, FLAGS, init_mat=None):\n    kd_min, kd_max = torch.tensor(FLAGS.kd_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.kd_max, dtype=torch.float32, device='cuda')\n    ks_min, ks_max = torch.tensor(FLAGS.ks_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.ks_max, dtype=torch.float32, device='cuda')\n    if mlp:\n        mlp_min = torch.cat((kd_min[0:3], ks_min), dim=0)\n        mlp_max = torch.cat((kd_max[0:3], ks_max), dim=0)\n        
mlp_map_opt = mlptexture.MLPTexture3D(geometry.getAABB(), channels=6, min_max=[mlp_min, mlp_max], use_float16=FLAGS.use_float16)\n        mat =  {'kd_ks' : mlp_map_opt}\n    else:\n        raise NotImplementedError\n\n    mat['bsdf'] = FLAGS.bsdf\n\n    mat['no_perturbed_nrm'] = FLAGS.no_perturbed_nrm\n\n    return mat\n\ndef initial_guess_material_knownkskd(geometry, mlp, FLAGS, init_mat=None):\n    mat =  {\n        'kd'     : init_mat['kd'],\n        'ks'     : init_mat['ks']\n    }\n\n    if init_mat is not None:\n        mat['bsdf'] = init_mat['bsdf']\n    else:\n        mat['bsdf'] = 'pbr'\n\n    return mat\n\n###############################################################################\n# Validation & testing\n###############################################################################\n\n@torch.no_grad()\ndef validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS, denoiser=None):\n    result_dict = {}\n    with torch.no_grad():\n        buffers = geometry.render(glctx, target, lgt, opt_material, use_uv=False, denoiser=denoiser)['buffers']\n\n        result_dict['ref'] = util.rgb_to_srgb(target['img'][...,0:3])[0]\n        result_dict['opt'] = util.rgb_to_srgb(buffers['shaded'][...,0:3])[0]\n        result_dict['mask_opt'] = buffers['shaded'][...,3:][0].expand(-1, -1, 3)\n        result_dict['mask_ref'] = target['img'][...,3:][0].expand(-1, -1, 3)\n        result_dict['msdf_image'] = buffers['msdf_image'][...,:][0].expand(-1, -1, 3).clamp(min=0, max=1)\n        result_image = torch.cat([result_dict['opt'], result_dict['ref'], result_dict['mask_opt'], result_dict['mask_ref'], result_dict['msdf_image']], axis=1)\n\n        if FLAGS.display is not None:\n            white_bg = torch.ones_like(target['background'])\n            for layer in FLAGS.display:\n                if 'latlong' in layer and layer['latlong']:\n                    result_dict['light_image'] = lgt.generate_image(FLAGS.display_res)\n                    result_dict['light_image'] = util.rgb_to_srgb(result_dict['light_image'] / (1 + result_dict['light_image']))\n                    result_image = torch.cat([result_image, result_dict['light_image']], axis=1)\n                elif 'bsdf' in layer:\n                    img = render.render_mesh(FLAGS, glctx, opt_mesh, target['mvp'], target['campos'], target['light'] if lgt is None else lgt, target['resolution'],\n                                                spp=target['spp'], num_layers=FLAGS.layers, background=white_bg, bsdf=layer['bsdf'], optix_ctx=geometry.optix_ctx)['shaded']\n                    if layer['bsdf'] == 'kd':\n                        result_dict[layer['bsdf']] = util.rgb_to_srgb(img[..., 0:3])[0]\n                    else:\n                        result_dict[layer['bsdf']] = img[0, ..., 0:3]\n                    result_image = torch.cat([result_image, result_dict[layer['bsdf']]], axis=1)\n                elif 'normals' in layer and not FLAGS.no_perturbed_nrm:\n                    result_image = torch.cat([result_image, (buffers['perturbed_nrm'][0, ...,0:3] + 1.0) * 0.5], axis=1)\n                elif 'diffuse_light' in layer:\n                    result_image = torch.cat([result_image, util.rgb_to_srgb(buffers['diffuse_light'][..., 0:3])[0]], axis=1)\n                elif 'specular_light' in layer:\n                    result_image = torch.cat([result_image, util.rgb_to_srgb(buffers['specular_light'][..., 0:3])[0]], axis=1)\n\n        return result_image, result_dict\n\n@torch.no_grad()\ndef validate(glctx, geometry, opt_material, lgt, 
dataset_validate, out_dir, FLAGS, denoiser=None, save_viz=False):\n\n    # ==============================================================================================\n    #  Validation loop\n    # ==============================================================================================\n    mse_values = []\n    psnr_values = []\n\n    dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_validate.collate)\n\n    os.makedirs(out_dir, exist_ok=True)\n    with open(os.path.join(out_dir, 'metrics.txt'), 'w') as fout:\n        fout.write('ID, MSE, PSNR\\n')\n\n        print(\"Running validation\")\n        for it, target in enumerate(dataloader_validate):\n\n            # Mix validation background\n            target = prepare_batch(target, FLAGS.background)\n\n            result_image, result_dict = validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS, denoiser=denoiser)\n\n            # Compute metrics\n            opt = torch.clamp(result_dict['opt'], 0.0, 1.0) \n            ref = torch.clamp(result_dict['ref'], 0.0, 1.0)\n\n            mse = torch.nn.functional.mse_loss(opt, ref, size_average=None, reduce=None, reduction='mean').item()\n            mse_values.append(float(mse))\n            psnr = util.mse_to_psnr(mse)\n            psnr_values.append(float(psnr))\n\n            line = \"%d, %1.8f, %1.8f\\n\" % (it, mse, psnr)\n            fout.write(str(line))\n\n            if save_viz:\n                for k in result_dict.keys():\n                    np_img = result_dict[k].detach().cpu().numpy()\n                    util.save_image(out_dir + '/' + ('val_%06d_%s.png' % (it, k)), np_img)\n\n        avg_mse = np.mean(np.array(mse_values))\n        avg_psnr = np.mean(np.array(psnr_values))\n        line = \"AVERAGES: %1.4f, %2.3f\\n\" % (avg_mse, avg_psnr)\n        fout.write(str(line))\n        print(\"MSE,      PSNR\")\n        print(\"%1.8f, %2.3f\" % (avg_mse, avg_psnr))\n    return avg_psnr\n\n###############################################################################\n# Main shape fitter function / optimization loop\n###############################################################################\n\ndef optimize_mesh(\n        denoiser,\n        glctx,\n        geometry,\n        opt_material,\n        lgt,\n        dataset_train,\n        dataset_validate,\n        FLAGS,\n        warmup_iter=0,\n        log_interval=10,\n        pass_idx=0,\n        pass_name=\"\",\n        optimize_light=True,\n        optimize_geometry=True,\n        visualize=True,\n        save_path=None\n    ):\n\n    # ==============================================================================================\n    #  Setup torch optimizer\n    # ==============================================================================================\n\n    learning_rate = FLAGS.learning_rate[pass_idx] if isinstance(FLAGS.learning_rate, list) or isinstance(FLAGS.learning_rate, tuple) else FLAGS.learning_rate\n    learning_rate_pos = learning_rate[0] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate\n    learning_rate_mat = learning_rate[1] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate\n    learning_rate_lgt = learning_rate[2] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate * 6.0\n\n    def lr_schedule(iter, fraction):\n        if iter < warmup_iter:\n            return iter / warmup_iter \n        return max(0.0, 10**(-(iter - 
warmup_iter)*0.0002)) # Exponential falloff from [1.0, 0.1] over 5k epochs.    \n\n    # ==============================================================================================\n    #  Image loss\n    # ==============================================================================================\n    image_loss_fn = createLoss(FLAGS)\n\n\n\n    params = list(material.get_parameters(opt_material))\n\n    if optimize_light:\n        optimizer_light = torch.optim.Adam((lgt.parameters() if lgt is not None else []), lr=learning_rate_lgt)\n        scheduler_light = torch.optim.lr_scheduler.LambdaLR(optimizer_light, lr_lambda=lambda x: lr_schedule(x, 0.9)) \n\n    if optimize_geometry:\n        if FLAGS.use_sdf_mlp:\n            lr_msdf = learning_rate_pos * 1e-2 if FLAGS.use_msdf_mlp else learning_rate_pos\n            deform_params = list(v[1] for v in geometry.named_parameters() if 'deform' in v[0]) if optimize_geometry else []\n            msdf_params = list(v[1] for v in geometry.named_parameters() if 'msdf' in v[0]) if optimize_geometry else []\n            sdf_params = list(v[1] for v in geometry.named_parameters() if 'sdf' in v[0] and 'msdf' not in v[0]) if optimize_geometry else []\n            other_params = list(v[1] for v in geometry.named_parameters() if 'sdf' not in v[0] and 'msdf' not in v[0] and 'deform' not in v[0]) if optimize_geometry else []\n            optimizer_mesh = torch.optim.Adam([\n                    {'params': deform_params, 'lr': learning_rate_pos},\n                    {'params': msdf_params, 'lr': lr_msdf},\n                    {'params': sdf_params, 'lr': learning_rate_pos * 1e-2},\n                    {'params': other_params, 'lr': learning_rate_pos * 1e-2},\n                ], eps=1e-8)\n        else:\n            optimizer_mesh = torch.optim.Adam(geometry.parameters(), lr=learning_rate_pos)\n        scheduler_mesh = torch.optim.lr_scheduler.LambdaLR(optimizer_mesh, lr_lambda=lambda x: lr_schedule(x, 0.9)) \n\n    optimizer = torch.optim.Adam(params, lr=learning_rate_mat)\n    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: lr_schedule(x, 0.9))\n\n    # ==============================================================================================\n    #  Training loop\n    # ==============================================================================================\n    img_cnt = 0\n    img_loss_vec = []\n    depth_loss_vec = []\n    reg_loss_vec = []\n    iter_dur_vec = []\n\n    dataloader_train    = torch.utils.data.DataLoader(dataset_train, batch_size=FLAGS.batch, collate_fn=dataset_train.collate, shuffle=True)\n    if visualize:\n        dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_train.collate)\n\n        def cycle(iterable):\n            iterator = iter(iterable)\n            while True:\n                try:\n                    yield next(iterator)\n                except StopIteration:\n                    iterator = iter(iterable)\n\n        v_it = cycle(dataloader_validate)\n\n    for it, target in enumerate(dataloader_train):\n\n        # Mix randomized background into dataset image\n        target = prepare_batch(target, 'random')\n\n        # ==============================================================================================\n        #  Display / save outputs. 
Do it before training so we get initial meshes\n        # ==============================================================================================\n\n        # Show/save image before training step (want to get correct rendering of input)\n        if visualize and FLAGS.local_rank == 0 and it != 0:\n            with torch.no_grad():\n                display_image = FLAGS.display_interval and (it % FLAGS.display_interval == 0)\n                save_image = FLAGS.save_interval and (it % FLAGS.save_interval == 0)\n                if display_image or save_image:\n                    save_mesh = True\n                    if save_mesh:\n                        os.makedirs(os.path.join(save_path, pass_name), exist_ok=True)\n                        obj.write_obj(os.path.join(save_path, pass_name), geometry.getMesh(opt_material)['imesh'], save_material=False)\n                    result_image, result_dict = validate_itr(glctx, prepare_batch(next(v_it), FLAGS.background), geometry, opt_material, lgt, FLAGS, denoiser=denoiser)\n            \n                    np_result_image = result_image.detach().cpu().numpy()\n                    if display_image:\n                        util.display_image(np_result_image, title='%d / %d' % (it, FLAGS.iter))\n                    if save_image:\n                        util.save_image(os.path.join(save_path, ('img_%s_%06d.png' % (pass_name, img_cnt))), np_result_image)\n                        img_cnt = img_cnt + 1\n\n        iter_start_time = time.time()\n\n        # ==============================================================================================\n        #  Zero gradients\n        # ==============================================================================================\n        optimizer.zero_grad()\n        if optimize_geometry:\n            optimizer_mesh.zero_grad()\n        if optimize_light:\n            optimizer_light.zero_grad()\n\n        # ==============================================================================================\n        #  Training\n        # ==============================================================================================\n\n        xfm_lgt = None\n        if optimize_light:\n            lgt.update_pdf()\n            \n\n        img_loss, depth_loss, reg_loss = geometry.tick(\n            glctx, target, lgt, opt_material, image_loss_fn, it, \n            denoiser=denoiser)\n\n        # ==============================================================================================\n        #  Final loss\n        # ==============================================================================================\n        total_loss = img_loss + reg_loss\n\n        img_loss_vec.append(img_loss.item())\n        depth_loss_vec.append(depth_loss.item())\n        reg_loss_vec.append(reg_loss.item())\n\n        # ==============================================================================================\n        #  Backpropagate\n        # ==============================================================================================\n        total_loss.backward()\n        if hasattr(lgt, 'base') and lgt.base.grad is not None and optimize_light:\n            lgt.base.grad *= 64\n        if 'kd_ks' in opt_material:\n            opt_material['kd_ks'].encoder.params.grad /= 8.0\n        if 'kd_ks_back' in opt_material:\n            opt_material['kd_ks_back'].encoder.params.grad /= 8.0\n\n        # Optionally clip gradients\n        if FLAGS.clip_max_norm > 0.0:\n            if optimize_geometry:\n       
         torch.nn.utils.clip_grad_norm_(geometry.parameters() + params, FLAGS.clip_max_norm)\n            else:\n                torch.nn.utils.clip_grad_norm_(params, FLAGS.clip_max_norm)\n\n        optimizer.step()\n        scheduler.step()\n\n        if optimize_geometry:\n            optimizer_mesh.step()\n            scheduler_mesh.step()\n\n        if optimize_light:\n            optimizer_light.step()\n            scheduler_light.step()\n\n        # ==============================================================================================\n        #  Clamp trainables to reasonable range\n        # ==============================================================================================\n        with torch.no_grad():\n            if 'kd' in opt_material:\n                opt_material['kd'].clamp_()\n            if 'ks' in opt_material:\n                opt_material['ks'].clamp_()\n            if 'kd_back' in opt_material:\n                opt_material['kd_back'].clamp_()\n            if 'ks_back' in opt_material:\n                opt_material['ks_back'].clamp_()\n            if 'normal' in opt_material and not FLAGS.normal_only:\n                opt_material['normal'].clamp_()\n                opt_material['normal'].normalize_()\n            if lgt is not None:\n                # lgt.clamp_(min=0.01) # For some reason gradient dissapears if light becomes 0\n                lgt.clamp_(min=1e-4) # For some reason gradient dissapears if light becomes 0\n\n            geometry.clamp_deform()\n        torch.cuda.current_stream().synchronize()\n        iter_dur_vec.append(time.time() - iter_start_time)\n\n        # ==============================================================================================\n        #  Logging\n        # ==============================================================================================\n        if it % log_interval == 0 and FLAGS.local_rank == 0:\n            img_loss_avg = np.mean(np.asarray(img_loss_vec[-log_interval:]))\n            depth_loss_avg = np.mean(np.asarray(depth_loss_vec[-log_interval:]))\n            reg_loss_avg = np.mean(np.asarray(reg_loss_vec[-log_interval:]))\n            iter_dur_avg = np.mean(np.asarray(iter_dur_vec[-log_interval:]))\n            \n            remaining_time = (FLAGS.iter-it)*iter_dur_avg\n            print(\"iter=%5d, img_loss=%.6f, depth_loss=%.6f, reg_loss=%.6f, lr=%.5f, time=%.1f ms, rem=%s\" % \n                (it, img_loss_avg, depth_loss_avg, reg_loss_avg, optimizer.param_groups[0]['lr'], iter_dur_avg*1000, util.time_to_text(remaining_time)))\n            sys.stdout.flush()\n\n        if it == FLAGS.iter:\n            break\n\n    return geometry, opt_material\n\n#----------------------------------------------------------------------------\n# Main function.\n#----------------------------------------------------------------------------\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description='nvdiffrec')\n    parser.add_argument('--config', type=str, default=None, help='Config file')\n    parser.add_argument('-i', '--iter', type=int, default=5000)\n    parser.add_argument('-b', '--batch', type=int, default=1)\n    parser.add_argument('-s', '--spp', type=int, default=1)\n    parser.add_argument('-l', '--layers', type=int, default=1)\n    parser.add_argument('-r', '--train-res', nargs=2, type=int, default=[512, 512])\n    parser.add_argument('-dr', '--display-res', type=int, default=None)\n    parser.add_argument('-tr', '--texture-res', nargs=2, type=int, 
default=[1024, 1024])\n    parser.add_argument('-di', '--display-interval', type=int, default=0)\n    parser.add_argument('-si', '--save-interval', type=int, default=1000)\n    parser.add_argument('-lr', '--learning-rate', type=float, default=0.01)\n    parser.add_argument('-mr', '--min-roughness', type=float, default=0.08)\n    parser.add_argument('-mip', '--custom-mip', action='store_true', default=False)\n    parser.add_argument('-rt', '--random-textures', action='store_true', default=False)\n    parser.add_argument('-bg', '--background', default='checker', choices=['black', 'white', 'checker', 'reference'])\n    parser.add_argument('--loss', default='logl1', choices=['logl1', 'logl2', 'mse', 'smape', 'relmse'])\n    parser.add_argument('-o', '--out-dir', type=str, default=None)\n    parser.add_argument('-rm', '--ref_mesh', type=str)\n    parser.add_argument('-bm', '--base-mesh', type=str, default=None)\n    parser.add_argument('--validate', type=bool, default=True)\n    # Render specific arguments\n    parser.add_argument('--n_samples', type=int, default=4)\n    parser.add_argument('--bsdf', type=str, default='pbr', choices=['pbr', 'diffuse', 'white'])\n    # Denoiser specific arguments\n    parser.add_argument('--denoiser', default='bilateral', choices=['none', 'bilateral'])\n    parser.add_argument('--denoiser_demodulate', type=bool, default=True)\n    parser.add_argument('--index',type=int)\n    parser.add_argument('--msdf_reg_open_scale', type=float, default=1e-6)\n    parser.add_argument('--msdf_reg_close_scale', type=float, default=3e-6)\n    parser.add_argument('--eikonal_scale', type=float)\n    parser.add_argument('--sdf_regularizer', type=float, default=0.2)\n    parser.add_argument('--trainset_path', type=str)\n    parser.add_argument('--testset_path', type=str, default='')\n\n    FLAGS = parser.parse_args()\n    FLAGS.mtl_override        = None        # Override material of model\n    FLAGS.gshell_grid          = 64          # Resolution of initial tet grid. We provide 64 and 128 resolution grids. \n                                            #    Other resolutions can be generated with https://github.com/crawforddoran/quartet\n                                            #    We include examples in data/tets/generate_tets.py\n    FLAGS.mesh_scale          = 1.4         # Scale of tet grid box. Adjust to cover the model\n    FLAGS.envlight            = None        # HDR environment probe\n    FLAGS.env_scale           = 1.0         # Env map intensity multiplier\n    FLAGS.probe_res           = 256         # Env map probe resolution\n    FLAGS.learn_lighting      = True        # Enable optimization of env lighting\n    FLAGS.display             = None        # Configure validation window/display. E.g. [{\"bsdf\" : \"kd\"}, {\"bsdf\" : \"ks\"}]\n    FLAGS.transparency        = False       # Enabled transparency through depth peeling\n    FLAGS.lock_light          = False       # Disable light optimization in the second pass\n    FLAGS.lock_pos            = False       # Disable vertex position optimization in the second pass\n    # FLAGS.sdf_regularizer     = 0.2         # Weight for sdf regularizer.\n    FLAGS.laplace             = \"relative\"  # Mesh Laplacian [\"absolute\", \"relative\"]\n    FLAGS.laplace_scale       = 3000.0      # Weight for Laplace regularizer. 
Default is relative with large weight\n    FLAGS.pre_load            = True        # Pre-load entire dataset into memory for faster training\n    FLAGS.no_perturbed_nrm    = False       # Disable normal map\n    FLAGS.decorrelated        = False       # Use decorrelated sampling in forward and backward passes\n    FLAGS.kd_min              = [ 0.0,  0.0,  0.0,  0.0]\n    FLAGS.kd_max              = [ 1.0,  1.0,  1.0,  1.0]\n    FLAGS.ks_min              = [ 0.0,  0.001, 0.0]\n    FLAGS.ks_max              = [ 0.0,  1.0,  1.0]\n    FLAGS.nrm_min             = [-1.0, -1.0,  0.0]\n    FLAGS.nrm_max             = [ 1.0,  1.0,  1.0]\n    FLAGS.clip_max_norm       = 0.0\n    FLAGS.cam_near_far        = [0.1, 1000.0]\n    FLAGS.lambda_kd           = 0.1\n    FLAGS.lambda_ks           = 0.05\n    FLAGS.lambda_nrm          = 0.025\n    FLAGS.lambda_nrm2         = 0.25\n    FLAGS.lambda_chroma       = 0.0\n    FLAGS.lambda_diffuse      = 0.15\n    FLAGS.lambda_specular     = 0.0025\n\n    FLAGS.random_lgt                  = False\n    FLAGS.normal_only                 = False\n    FLAGS.use_img_2nd_layer           = False\n    FLAGS.use_depth                   = False\n    FLAGS.use_depth_2nd_layer         = False\n    FLAGS.use_tanh_deform             = False\n    FLAGS.use_sdf_mlp                 = True\n    FLAGS.use_msdf_mlp                = False\n    FLAGS.use_eikonal                 = True\n    FLAGS.sdf_mlp_pretrain_steps      = 1000\n    FLAGS.use_mesh_msdf_reg           = True\n    FLAGS.sphere_init                 = False\n    FLAGS.sphere_init_norm            = 0.5\n    FLAGS.pretrained_sdf_mlp_path     = f'./data/pretrained_mlp_{FLAGS.gshell_grid}_deeper.pt'\n    FLAGS.n_hidden                    = 6\n    FLAGS.d_hidden                    = 256\n    FLAGS.n_freq                      = 6\n    FLAGS.skip_in                     = [3]\n    FLAGS.use_float16                 = False\n    FLAGS.visualize_watertight        = False\n\n    FLAGS.local_rank = 0\n    FLAGS.multi_gpu  = \"WORLD_SIZE\" in os.environ and int(os.environ[\"WORLD_SIZE\"]) > 1\n    if FLAGS.multi_gpu:\n        if \"MASTER_ADDR\" not in os.environ:\n            os.environ[\"MASTER_ADDR\"] = 'localhost'\n        if \"MASTER_PORT\" not in os.environ:\n            os.environ[\"MASTER_PORT\"] = '23456'\n\n        FLAGS.local_rank = int(os.environ[\"LOCAL_RANK\"])\n        torch.cuda.set_device(FLAGS.local_rank)\n        torch.distributed.init_process_group(backend=\"nccl\", init_method=\"env://\")\n\n    if FLAGS.config is not None:\n        data = json.load(open(FLAGS.config, 'r'))\n        for key in data:\n            FLAGS.__dict__[key] = data[key]\n\n    if FLAGS.display_res is None:\n        FLAGS.display_res = FLAGS.train_res\n\n    if FLAGS.local_rank == 0:\n        print(\"Config / Flags:\")\n        print(\"---------\")\n        for key in FLAGS.__dict__.keys():\n            print(key, FLAGS.__dict__[key])\n        print(\"---------\")\n\n    os.makedirs(FLAGS.out_dir, exist_ok=True)\n\n    glctx = dr.RasterizeGLContext()\n    glctx_display = glctx if FLAGS.batch < 16 else dr.RasterizeGLContext() # Context for display\n\n    mtl_default = None\n\n    # ==============================================================================================\n    #  Create data pipeline\n    # ==============================================================================================\n    dataset_path = FLAGS.trainset_path\n    testset_path = FLAGS.testset_path\n\n    folder_name_list = [30, 92, 117, 133, 164, 320, 448, 522, 
591]\n    folder_name = folder_name_list[FLAGS.index]\n\n    folder_name = str(folder_name)\n    data_root = os.path.join(dataset_path, folder_name)\n    dataset_train    = DatasetDeepFashion(data_root, FLAGS, examples=int(1e6))\n    dataset_validate = DatasetDeepFashion(data_root, FLAGS)\n\n    if FLAGS.testset_path is not None and FLAGS.testset_path != '':\n        testdata_root = os.path.join(testset_path, folder_name)\n        dataset_test = DatasetDeepFashionTestset(testdata_root, FLAGS)\n\n\n\n    # ==============================================================================================\n    #  Create env light with trainable parameters\n    # ==============================================================================================\n    \n    lgt = None\n    if FLAGS.learn_lighting:\n        lgt = light.create_trainable_env_rnd(FLAGS.probe_res, scale=0.0, bias=0.5)\n        # lgt = light.create_trainable_env_rnd(FLAGS.probe_res, scale=0.0, bias=0.1)\n    else:\n        lgt = light.load_env(FLAGS.envlight, scale=FLAGS.env_scale, res=[FLAGS.probe_res, FLAGS.probe_res])\n\n    # ==============================================================================================\n    #  Setup denoiser\n    # ==============================================================================================\n\n    denoiser = None\n    if FLAGS.denoiser == 'bilateral':\n        denoiser = BilateralDenoiser().cuda()\n    else:\n        assert FLAGS.denoiser == 'none', \"Invalid denoiser %s\" % FLAGS.denoiser\n\n    # Setup geometry for optimization\n    geometry = GShellTetsGeometry(FLAGS.gshell_grid, FLAGS.mesh_scale, FLAGS)\n\n    # Setup textures, make initial guess from reference if possible\n    if not FLAGS.normal_only:\n        mat = initial_guess_material(geometry, True, FLAGS, mtl_default)\n    else:\n        mat = initial_guess_material_knownkskd(geometry, True, FLAGS, mtl_default)\n    mat['no_perturbed_nrm'] = True\n\n    # Run optimization\n    geometry, mat = optimize_mesh(denoiser, glctx, geometry, mat, lgt, dataset_train, dataset_validate, \n                    FLAGS, pass_idx=0, pass_name=\"pass1\", optimize_light=FLAGS.learn_lighting, save_path=os.path.join(FLAGS.out_dir, folder_name))\n\n    validate(glctx, geometry, mat, lgt, dataset_validate, os.path.join(FLAGS.out_dir, folder_name, \"validate\"), FLAGS, denoiser=denoiser, save_viz=True)\n    if FLAGS.testset_path is not None and FLAGS.testset_path != '':\n        validate(glctx, geometry, mat, lgt, dataset_test, os.path.join(FLAGS.out_dir, folder_name, \"test\"), FLAGS, denoiser=denoiser, save_viz=False)\n\n    with torch.no_grad():\n        os.makedirs(os.path.join(FLAGS.out_dir, folder_name, \"mesh\"), exist_ok=True)\n        torch.save(geometry.state_dict(), os.path.join(FLAGS.out_dir, folder_name, \"mesh/model.pt\"))\n        torch.save(mat['kd_ks'].state_dict(), os.path.join(FLAGS.out_dir, folder_name, \"mesh/mtl.pt\"))\n        light.save_env_map(os.path.join(FLAGS.out_dir, folder_name, \"mesh/probe.hdr\"), lgt)\n\n        # Create textured mesh from result\n        base_mesh = geometry.getMesh(mat)['imesh']\n\n        # Dump mesh for debugging.\n        os.makedirs(os.path.join(FLAGS.out_dir, folder_name, \"mesh\"), exist_ok=True)\n        obj.write_obj(os.path.join(FLAGS.out_dir, folder_name, \"mesh/\"), base_mesh, save_material=False)\n\n        # Free temporaries / cached memory\n        torch.cuda.empty_cache()\n        mat['kd_ks'].cleanup()\n        del mat['kd_ks']\n        if 'kd_ks_back' in mat:\n       
     mat['kd_ks_back'].cleanup()\n            del mat['kd_ks_back']\n\n        # Free temporaries / cached memory\n        torch.cuda.empty_cache()\n        del mat"
  },
  {
    "path": "train_gshelltet_polycam.py",
    "content": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n# property and proprietary rights in and to this material, related\n# documentation and any modifications thereto. Any use, reproduction, \n# disclosure or distribution of this material and related documentation \n# without an express license agreement from NVIDIA CORPORATION or \n# its affiliates is strictly prohibited.\n\nimport os\nimport sys\nimport time\nimport argparse\nimport json\n\nimport numpy as np\nimport torch\nimport nvdiffrast.torch as dr\nimport xatlas\n\n# Import data readers / generators\nfrom dataset.dataset_nerf_colmap import DatasetNERF\n\n# Import topology / geometry trainers\nfrom geometry.gshell_tets_geometry import GShellTetsGeometry\n\nimport render.renderutils as ru\nfrom render import obj\nfrom render import material\nfrom render import util\nfrom render import mesh\nfrom render import texture\nfrom render import mlptexture\nfrom render import light\nfrom render import render\n\n\nfrom denoiser.denoiser import BilateralDenoiser\n\nimport tqdm\n\nRADIUS = 3.0\n\n# Enable to debug back-prop anomalies\n# torch.autograd.set_detect_anomaly(True)\n\n###############################################################################\n# Loss setup\n###############################################################################\n\n@torch.no_grad()\ndef createLoss(FLAGS):\n    if FLAGS.loss == \"smape\":\n        return lambda img, ref: ru.image_loss(img, ref, loss='smape', tonemapper='none')\n    elif FLAGS.loss == \"mse\":\n        return lambda img, ref: ru.image_loss(img, ref, loss='mse', tonemapper='none')\n    elif FLAGS.loss == \"logl1\":\n        return lambda img, ref: ru.image_loss(img, ref, loss='l1', tonemapper='log_srgb')\n    elif FLAGS.loss == \"logl2\":\n        return lambda img, ref: ru.image_loss(img, ref, loss='mse', tonemapper='log_srgb')\n    elif FLAGS.loss == \"relmse\":\n        return lambda img, ref: ru.image_loss(img, ref, loss='relmse', tonemapper='none')\n    else:\n        assert False\n\n###############################################################################\n# Mix background into a dataset image\n###############################################################################\n\n@torch.no_grad()\ndef prepare_batch(target, bg_type='black'):\n    assert len(target['img'].shape) == 4, \"Image shape should be [n, h, w, c]\"\n    if bg_type == 'checker':\n        background = torch.tensor(util.checkerboard(target['img'].shape[1:3], 8), dtype=torch.float32, device='cuda')[None, ...]\n    elif bg_type == 'black':\n        background = torch.zeros(target['img'].shape[0:3] + (3,), dtype=torch.float32, device='cuda')\n    elif bg_type == 'white':\n        background = torch.ones(target['img'].shape[0:3] + (3,), dtype=torch.float32, device='cuda')\n    elif bg_type == 'reference':\n        background = target['img'][..., 0:3]\n    elif bg_type == 'random':\n        background = torch.rand(target['img'].shape[0:3] + (3,), dtype=torch.float32, device='cuda')\n    else:\n        assert False, \"Unknown background type %s\" % bg_type\n\n    target['mv'] = target['mv'].cuda()\n    target['mvp'] = target['mvp'].cuda()\n    target['campos'] = target['campos'].cuda()\n    target['img'] = target['img'].cuda()\n    target['background'] = background\n\n    target['img'] = torch.cat((torch.lerp(background, target['img'][..., 0:3], target['img'][..., 3:4]), target['img'][..., 3:4]), 
dim=-1)\n\n    return target\n\n###############################################################################\n# UV - map geometry & convert to a mesh\n###############################################################################\n\n@torch.no_grad()\ndef xatlas_uvmap(glctx, geometry, mat, FLAGS):\n    eval_mesh = geometry.getMesh(mat)\n    try:\n        eval_mesh = eval_mesh['imesh']\n    except:\n        pass\n    \n    # Create uvs with xatlas\n    v_pos = eval_mesh.v_pos.detach().cpu().numpy()\n    t_pos_idx = eval_mesh.t_pos_idx.detach().cpu().numpy()\n    vmapping, indices, uvs = xatlas.parametrize(v_pos, t_pos_idx)\n\n    # Convert to tensors\n    indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)\n    \n    uvs = torch.tensor(uvs, dtype=torch.float32, device='cuda')\n    faces = torch.tensor(indices_int64, dtype=torch.int64, device='cuda')\n\n    new_mesh = mesh.Mesh(v_tex=uvs, t_tex_idx=faces, base=eval_mesh)\n\n    mask, kd, ks = render.render_uv(glctx, new_mesh, FLAGS.texture_res, eval_mesh.material['kd_ks'])\n\n    # Dilate all textures & use average color for background\n    kd_avg = torch.sum(torch.sum(torch.sum(kd * mask, dim=0), dim=0), dim=0) / torch.sum(torch.sum(torch.sum(mask, dim=0), dim=0), dim=0)\n    kd = util.dilate(kd, kd_avg[None, None, None, :], mask, 7)\n\n    ks_avg = torch.sum(torch.sum(torch.sum(ks * mask, dim=0), dim=0), dim=0) / torch.sum(torch.sum(torch.sum(mask, dim=0), dim=0), dim=0)\n    ks = util.dilate(ks, ks_avg[None, None, None, :], mask, 7)\n\n    nrm_avg = torch.tensor([0, 0, 1], dtype=torch.float32, device=\"cuda\")\n    normal = nrm_avg[None, None, None, :].repeat(kd.shape[0], kd.shape[1], kd.shape[2], 1)\n    \n    new_mesh.material = mat.copy()\n    del new_mesh.material['kd_ks']\n\n    if FLAGS.transparency:\n        kd = torch.cat((kd, torch.rand_like(kd[...,0:1])), dim=-1)\n        print(\"kd shape\", kd.shape)\n\n    kd_min, kd_max = torch.tensor(FLAGS.kd_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.kd_max, dtype=torch.float32, device='cuda')\n    ks_min, ks_max = torch.tensor(FLAGS.ks_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.ks_max, dtype=torch.float32, device='cuda')\n    nrm_min, nrm_max = torch.tensor(FLAGS.nrm_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.nrm_max, dtype=torch.float32, device='cuda')\n    new_mesh.material.update({\n        'kd'     : texture.Texture2D(kd.clone().detach().requires_grad_(True), min_max=[kd_min, kd_max]),\n        'ks'     : texture.Texture2D(ks.clone().detach().requires_grad_(True), min_max=[ks_min, ks_max]),\n        'normal' : texture.Texture2D(normal.clone().detach().requires_grad_(True), min_max=[nrm_min, nrm_max]),\n    })\n\n    return new_mesh\n\n###############################################################################\n# Utility functions for material\n###############################################################################\n\ndef initial_guess_material(geometry, mlp, FLAGS, init_mat=None):\n    kd_min, kd_max = torch.tensor(FLAGS.kd_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.kd_max, dtype=torch.float32, device='cuda')\n    ks_min, ks_max = torch.tensor(FLAGS.ks_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.ks_max, dtype=torch.float32, device='cuda')\n    if mlp:\n        mlp_min = torch.cat((kd_min[0:3], ks_min), dim=0)\n        mlp_max = torch.cat((kd_max[0:3], ks_max), dim=0)\n        mlp_map_opt = mlptexture.MLPTexture3D(geometry.getAABB(), channels=6, 
min_max=[mlp_min, mlp_max], use_float16=FLAGS.use_float16)\n        mat =  {'kd_ks' : mlp_map_opt}\n    else:\n        raise NotImplementedError\n\n    mat['bsdf'] = FLAGS.bsdf\n\n    mat['no_perturbed_nrm'] = FLAGS.no_perturbed_nrm\n\n    return mat\n\ndef initial_guess_material_knownkskd(geometry, mlp, FLAGS, init_mat=None):\n    mat =  {\n        'kd'     : init_mat['kd'],\n        'ks'     : init_mat['ks']\n    }\n\n    if init_mat is not None:\n        mat['bsdf'] = init_mat['bsdf']\n    else:\n        mat['bsdf'] = 'pbr'\n\n    return mat\n\n###############################################################################\n# Validation & testing\n###############################################################################\n\n@torch.no_grad()\ndef validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS, denoiser=None):\n    result_dict = {}\n    with torch.no_grad():\n        buffers = geometry.render(glctx, target, lgt, opt_material, use_uv=False, denoiser=denoiser)['buffers']\n\n        result_dict['ref'] = util.rgb_to_srgb(target['img'][...,0:3])[0]\n        result_dict['opt'] = util.rgb_to_srgb(buffers['shaded'][...,0:3])[0]\n        result_dict['mask_opt'] = buffers['shaded'][...,3:][0].expand(-1, -1, 3)\n        result_dict['mask_ref'] = target['img'][...,3:][0].expand(-1, -1, 3)\n        result_dict['msdf_image'] = buffers['msdf_image'][...,:][0].expand(-1, -1, 3).clamp(min=0, max=1)\n        result_image = torch.cat([result_dict['opt'], result_dict['ref'], result_dict['mask_opt'], result_dict['mask_ref'], result_dict['msdf_image']], axis=1)\n\n        result_dict = {}\n        result_dict['ref'] = util.rgb_to_srgb(target['img'][...,0:3])[0]\n        result_dict['opt'] = util.rgb_to_srgb(buffers['shaded'][...,0:3])[0]\n\n        return result_image, result_dict\n\n@torch.no_grad()\ndef validate(glctx, geometry, opt_material, lgt, dataset_validate, out_dir, FLAGS, denoiser=None, save_viz=False):\n\n    # ==============================================================================================\n    #  Validation loop\n    # ==============================================================================================\n    mse_values = []\n    psnr_values = []\n\n    dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_validate.collate)\n\n    os.makedirs(out_dir, exist_ok=True)\n    with open(os.path.join(out_dir, 'metrics.txt'), 'w') as fout:\n        fout.write('ID, MSE, PSNR\\n')\n\n        print(\"Running validation\")\n        for it, target in enumerate(tqdm.tqdm(dataloader_validate)):\n\n            # Mix validation background\n            target = prepare_batch(target, FLAGS.background)\n\n            result_image, result_dict = validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS, denoiser=denoiser)\n\n            # Compute metrics\n            opt = torch.clamp(result_dict['opt'], 0.0, 1.0) \n            ref = torch.clamp(result_dict['ref'], 0.0, 1.0)\n\n            mse = torch.nn.functional.mse_loss(opt, ref, size_average=None, reduce=None, reduction='mean').item()\n            mse_values.append(float(mse))\n            psnr = util.mse_to_psnr(mse)\n            psnr_values.append(float(psnr))\n\n            line = \"%d, %1.8f, %1.8f\\n\" % (it, mse, psnr)\n            fout.write(str(line))\n\n            if save_viz:\n                for k in result_dict.keys():\n                    np_img = result_dict[k].detach().cpu().numpy()\n                    util.save_image(out_dir + '/' + ('val_%06d_%s.png' 
% (it, k)), np_img)\n\n        avg_mse = np.mean(np.array(mse_values))\n        avg_psnr = np.mean(np.array(psnr_values))\n        line = \"AVERAGES: %1.4f, %2.3f\\n\" % (avg_mse, avg_psnr)\n        fout.write(str(line))\n        print(\"MSE,      PSNR\")\n        print(\"%1.8f, %2.3f\" % (avg_mse, avg_psnr))\n    return avg_psnr\n\n###############################################################################\n# Main shape fitter function / optimization loop\n###############################################################################\n\ndef optimize_mesh(\n        denoiser,\n        glctx,\n        geometry,\n        opt_material,\n        lgt,\n        dataset_train,\n        dataset_validate,\n        FLAGS,\n        warmup_iter=0,\n        log_interval=10,\n        pass_idx=0,\n        pass_name=\"\",\n        optimize_light=True,\n        optimize_geometry=True,\n        visualize=True,\n        save_path=None\n    ):\n\n    # ==============================================================================================\n    #  Setup torch optimizer\n    # ==============================================================================================\n\n    learning_rate = FLAGS.learning_rate[pass_idx] if isinstance(FLAGS.learning_rate, list) or isinstance(FLAGS.learning_rate, tuple) else FLAGS.learning_rate\n    learning_rate_pos = learning_rate[0] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate\n    learning_rate_mat = learning_rate[1] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate\n    # learning_rate_lgt = learning_rate[2] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate * 3.0\n    learning_rate_lgt = learning_rate[2] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate * 6.0\n    # learning_rate_lgt = learning_rate[2] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate * 0.5\n\n    def lr_schedule(iter, fraction):\n        if iter < warmup_iter:\n            return iter / warmup_iter \n        return max(0.0, 10**(-(iter - warmup_iter)*0.0002)) # Exponential falloff from [1.0, 0.1] over 5k epochs.    
\n\n    # ==============================================================================================\n    #  Image loss\n    # ==============================================================================================\n    image_loss_fn = createLoss(FLAGS)\n\n\n\n    params = list(material.get_parameters(opt_material))\n\n    if optimize_light:\n        optimizer_light = torch.optim.Adam((lgt.parameters() if lgt is not None else []), lr=learning_rate_lgt)\n        scheduler_light = torch.optim.lr_scheduler.LambdaLR(optimizer_light, lr_lambda=lambda x: lr_schedule(x, 0.9)) \n\n    if optimize_geometry:\n        if FLAGS.use_sdf_mlp:\n            lr_msdf = learning_rate_pos * 1e-2 if FLAGS.use_msdf_mlp else learning_rate_pos\n            deform_params = list(v[1] for v in geometry.named_parameters() if 'deform' in v[0]) if optimize_geometry else []\n            msdf_params = list(v[1] for v in geometry.named_parameters() if 'msdf' in v[0]) if optimize_geometry else []\n            sdf_params = list(v[1] for v in geometry.named_parameters() if 'sdf' in v[0] and 'msdf' not in v[0]) if optimize_geometry else []\n            other_params = list(v[1] for v in geometry.named_parameters() if 'sdf' not in v[0] and 'msdf' not in v[0] and 'deform' not in v[0]) if optimize_geometry else []\n            optimizer_mesh = torch.optim.Adam([\n                    {'params': deform_params, 'lr': learning_rate_pos},\n                    {'params': msdf_params, 'lr': lr_msdf},\n                    {'params': sdf_params, 'lr': learning_rate_pos * 1e-2},\n                    {'params': other_params, 'lr': learning_rate_pos * 1e-2},\n                ], eps=1e-8)\n        else:\n            optimizer_mesh = torch.optim.Adam(geometry.parameters(), lr=learning_rate_pos)\n        scheduler_mesh = torch.optim.lr_scheduler.LambdaLR(optimizer_mesh, lr_lambda=lambda x: lr_schedule(x, 0.9)) \n\n    optimizer = torch.optim.Adam(params, lr=learning_rate_mat)\n    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: lr_schedule(x, 0.9))\n\n    # ==============================================================================================\n    #  Training loop\n    # ==============================================================================================\n    img_cnt = 0\n    img_loss_vec = []\n    depth_loss_vec = []\n    reg_loss_vec = []\n    iter_dur_vec = []\n\n    dataloader_train    = torch.utils.data.DataLoader(dataset_train, batch_size=FLAGS.batch, collate_fn=dataset_train.collate, shuffle=True)\n    if visualize:\n        dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_train.collate)\n\n        def cycle(iterable):\n            iterator = iter(iterable)\n            while True:\n                try:\n                    yield next(iterator)\n                except StopIteration:\n                    iterator = iter(iterable)\n\n        v_it = cycle(dataloader_validate)\n\n    for it, target in enumerate(dataloader_train):\n\n        # Mix randomized background into dataset image\n        target = prepare_batch(target, 'random')\n\n        # ==============================================================================================\n        #  Display / save outputs. 
Do it before training so we get initial meshes\n        # ==============================================================================================\n\n        # Show/save image before training step (want to get correct rendering of input)\n        if visualize and FLAGS.local_rank == 0 and it != 0:\n            with torch.no_grad():\n                display_image = FLAGS.display_interval and (it % FLAGS.display_interval == 0)\n                save_image = FLAGS.save_interval and (it % FLAGS.save_interval == 0)\n                if display_image or save_image:\n                    save_mesh = True\n                    if save_mesh:\n                        os.makedirs(os.path.join(save_path, pass_name), exist_ok=True)\n                        obj.write_obj(os.path.join(save_path, pass_name), geometry.getMesh(opt_material)['imesh'], save_material=False)\n                    result_image, result_dict = validate_itr(glctx, prepare_batch(next(v_it), FLAGS.background), geometry, opt_material, lgt, FLAGS, denoiser=denoiser)\n            \n                    np_result_image = result_image.detach().cpu().numpy()\n                    if display_image:\n                        util.display_image(np_result_image, title='%d / %d' % (it, FLAGS.iter))\n                    if save_image:\n                        util.save_image(os.path.join(save_path, ('img_%s_%06d.png' % (pass_name, img_cnt))), np_result_image)\n                        img_cnt = img_cnt + 1\n\n        iter_start_time = time.time()\n\n        # ==============================================================================================\n        #  Zero gradients\n        # ==============================================================================================\n        optimizer.zero_grad()\n        if optimize_geometry:\n            optimizer_mesh.zero_grad()\n        if optimize_light:\n            optimizer_light.zero_grad()\n\n        # ==============================================================================================\n        #  Training\n        # ==============================================================================================\n\n        xfm_lgt = None\n        if optimize_light:\n            if False and FLAGS.camera_space_light:\n                lgt.xfm(target['mv'])\n            elif False and ('envlight_transform' in target and target['envlight_transform'] is not None):\n                xfm_lgt = target['envlight_transform']\n                lgt.xfm(xfm_lgt)\n            lgt.update_pdf()\n            \n\n        img_loss, depth_loss, reg_loss = geometry.tick(\n            glctx, target, lgt, opt_material, image_loss_fn, it, \n            denoiser=denoiser)\n\n        # ==============================================================================================\n        #  Final loss\n        # ==============================================================================================\n        total_loss = img_loss + reg_loss\n\n        img_loss_vec.append(img_loss.item())\n        depth_loss_vec.append(depth_loss.item())\n        reg_loss_vec.append(reg_loss.item())\n\n        # ==============================================================================================\n        #  Backpropagate\n        # ==============================================================================================\n        total_loss.backward()\n        if hasattr(lgt, 'base') and lgt.base.grad is not None and optimize_light:\n            lgt.base.grad *= 64\n        if 'kd_ks' in opt_material:\n   
         opt_material['kd_ks'].encoder.params.grad /= 8.0\n        if 'kd_ks_back' in opt_material:\n            opt_material['kd_ks_back'].encoder.params.grad /= 8.0\n\n        # Optionally clip gradients\n        if FLAGS.clip_max_norm > 0.0:\n            if optimize_geometry:\n                torch.nn.utils.clip_grad_norm_(geometry.parameters() + params, FLAGS.clip_max_norm)\n            else:\n                torch.nn.utils.clip_grad_norm_(params, FLAGS.clip_max_norm)\n\n        optimizer.step()\n        scheduler.step()\n\n        if optimize_geometry:\n            optimizer_mesh.step()\n            scheduler_mesh.step()\n\n        if optimize_light:\n            optimizer_light.step()\n            scheduler_light.step()\n\n        # ==============================================================================================\n        #  Clamp trainables to reasonable range\n        # ==============================================================================================\n        with torch.no_grad():\n            if 'kd' in opt_material:\n                opt_material['kd'].clamp_()\n            if 'ks' in opt_material:\n                opt_material['ks'].clamp_()\n            if 'kd_back' in opt_material:\n                opt_material['kd_back'].clamp_()\n            if 'ks_back' in opt_material:\n                opt_material['ks_back'].clamp_()\n            if 'normal' in opt_material and not FLAGS.normal_only:\n                opt_material['normal'].clamp_()\n                opt_material['normal'].normalize_()\n            if lgt is not None:\n                # lgt.clamp_(min=0.01) # For some reason gradient dissapears if light becomes 0\n                lgt.clamp_(min=1e-4) # For some reason gradient dissapears if light becomes 0\n\n            geometry.clamp_deform()\n        torch.cuda.current_stream().synchronize()\n        iter_dur_vec.append(time.time() - iter_start_time)\n\n        # ==============================================================================================\n        #  Logging\n        # ==============================================================================================\n        if it % log_interval == 0 and FLAGS.local_rank == 0:\n            img_loss_avg = np.mean(np.asarray(img_loss_vec[-log_interval:]))\n            depth_loss_avg = np.mean(np.asarray(depth_loss_vec[-log_interval:]))\n            reg_loss_avg = np.mean(np.asarray(reg_loss_vec[-log_interval:]))\n            iter_dur_avg = np.mean(np.asarray(iter_dur_vec[-log_interval:]))\n            \n            remaining_time = (FLAGS.iter-it)*iter_dur_avg\n            print(\"iter=%5d, img_loss=%.6f, depth_loss=%.6f, reg_loss=%.6f, lr=%.5f, time=%.1f ms, rem=%s\" % \n                (it, img_loss_avg, depth_loss_avg, reg_loss_avg, optimizer.param_groups[0]['lr'], iter_dur_avg*1000, util.time_to_text(remaining_time)))\n            sys.stdout.flush()\n\n        if it == FLAGS.iter:\n            break\n\n    return geometry, opt_material\n\n#----------------------------------------------------------------------------\n# Main function.\n#----------------------------------------------------------------------------\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description='nvdiffrec')\n    parser.add_argument('--config', type=str, default=None, help='Config file')\n    parser.add_argument('-i', '--iter', type=int, default=5000)\n    parser.add_argument('-b', '--batch', type=int, default=1)\n    parser.add_argument('-s', '--spp', type=int, default=1)\n    
parser.add_argument('-l', '--layers', type=int, default=1)\n    parser.add_argument('-r', '--train-res', nargs=2, type=int, default=[512, 512])\n    parser.add_argument('-dr', '--display-res', type=int, default=None)\n    parser.add_argument('-tr', '--texture-res', nargs=2, type=int, default=[1024, 1024])\n    parser.add_argument('-di', '--display-interval', type=int, default=0)\n    parser.add_argument('-si', '--save-interval', type=int, default=1000)\n    parser.add_argument('-lr', '--learning-rate', type=float, default=0.01)\n    parser.add_argument('-mr', '--min-roughness', type=float, default=0.08)\n    parser.add_argument('-mip', '--custom-mip', action='store_true', default=False)\n    parser.add_argument('-rt', '--random-textures', action='store_true', default=False)\n    parser.add_argument('-bg', '--background', default='checker', choices=['black', 'white', 'checker', 'reference'])\n    parser.add_argument('--loss', default='logl1', choices=['logl1', 'logl2', 'mse', 'smape', 'relmse'])\n    parser.add_argument('-o', '--out-dir', type=str, default=None)\n    parser.add_argument('-rm', '--ref_mesh', type=str)\n    parser.add_argument('-bm', '--base-mesh', type=str, default=None)\n    parser.add_argument('--validate', type=bool, default=True)\n    # Render specific arguments\n    parser.add_argument('--n_samples', type=int, default=4)\n    parser.add_argument('--bsdf', type=str, default='pbr', choices=['pbr', 'diffuse', 'white'])\n    # Denoiser specific arguments\n    parser.add_argument('--denoiser', default='bilateral', choices=['none', 'bilateral'])\n    parser.add_argument('--denoiser_demodulate', type=bool, default=True)\n    parser.add_argument('--msdf_reg_open_scale', type=float, default=1e-6)\n    parser.add_argument('--msdf_reg_close_scale', type=float, default=3e-4)\n    parser.add_argument('--eikonal_scale', type=float, default=5e-3)\n    parser.add_argument('--sdf_regularizer', type=float, default=0.2)\n    parser.add_argument('--trainset_path', type=str)\n    parser.add_argument('--testset_path', type=str, default='')\n\n    FLAGS = parser.parse_args()\n    FLAGS.mtl_override        = None        # Override material of model\n    FLAGS.gshell_grid          = 64          # Resolution of initial tet grid. We provide 64 and 128 resolution grids. \n                                            #    Other resolutions can be generated with https://github.com/crawforddoran/quartet\n                                            #    We include examples in data/tets/generate_tets.py\n    FLAGS.mesh_scale          = 3.6         # Scale of tet grid box. Adjust to cover the model\n    FLAGS.envlight            = None        # HDR environment probe\n    FLAGS.env_scale           = 1.0         # Env map intensity multiplier\n    FLAGS.probe_res           = 256         # Env map probe resolution\n    FLAGS.learn_lighting      = True        # Enable optimization of env lighting\n    FLAGS.display             = None        # Configure validation window/display. E.g. 
[{\"bsdf\" : \"kd\"}, {\"bsdf\" : \"ks\"}]\n    FLAGS.transparency        = False       # Enabled transparency through depth peeling\n    FLAGS.lock_light          = False       # Disable light optimization in the second pass\n    FLAGS.lock_pos            = False       # Disable vertex position optimization in the second pass\n    # FLAGS.sdf_regularizer     = 0.2         # Weight for sdf regularizer.\n    FLAGS.laplace             = \"relative\"  # Mesh Laplacian [\"absolute\", \"relative\"]\n    FLAGS.laplace_scale       = 3000.0      # Weight for Laplace regularizer. Default is relative with large weight\n    FLAGS.pre_load            = True        # Pre-load entire dataset into memory for faster training\n    FLAGS.no_perturbed_nrm    = False       # Disable normal map\n    FLAGS.decorrelated        = False       # Use decorrelated sampling in forward and backward passes\n    FLAGS.kd_min              = [ 0.0,  0.0,  0.0,  0.0]\n    FLAGS.kd_max              = [ 1.0,  1.0,  1.0,  1.0]\n    FLAGS.ks_min              = [ 0.0,  0.001, 0.0]\n    FLAGS.ks_max              = [ 0.0,  1.0,  1.0]\n    FLAGS.nrm_min             = [-1.0, -1.0,  0.0]\n    FLAGS.nrm_max             = [ 1.0,  1.0,  1.0]\n    FLAGS.clip_max_norm       = 0.0\n    FLAGS.cam_near_far        = [0.1, 1000.0]\n    FLAGS.lambda_kd           = 0.1\n    FLAGS.lambda_ks           = 0.05\n    FLAGS.lambda_nrm          = 0.025\n    FLAGS.lambda_nrm2         = 0.25\n    FLAGS.lambda_chroma       = 0.0\n    FLAGS.lambda_diffuse      = 0.15\n    FLAGS.lambda_specular     = 0.0025\n\n    FLAGS.random_lgt                  = False\n    FLAGS.normal_only                 = False\n    FLAGS.use_img_2nd_layer           = False\n    FLAGS.use_depth                   = False\n    FLAGS.use_depth_2nd_layer         = False\n    FLAGS.use_tanh_deform             = False\n    FLAGS.use_sdf_mlp                 = True\n    FLAGS.use_msdf_mlp                = False\n    FLAGS.use_eikonal                 = True\n    FLAGS.sdf_mlp_pretrain_steps      = 10000\n    FLAGS.use_mesh_msdf_reg           = True\n    FLAGS.sphere_init                 = False\n    FLAGS.sphere_init_norm            = 2.0\n    FLAGS.pretrained_sdf_mlp_path     = f'./data/pretrained_mlp_{FLAGS.gshell_grid}_polycam.pt'\n    FLAGS.n_hidden                    = 6\n    FLAGS.d_hidden                    = 256\n    FLAGS.n_freq                      = 6\n    FLAGS.skip_in                     = [3]\n    FLAGS.use_float16                 = False\n    FLAGS.visualize_watertight        = False\n\n    FLAGS.local_rank = 0\n    FLAGS.multi_gpu  = \"WORLD_SIZE\" in os.environ and int(os.environ[\"WORLD_SIZE\"]) > 1\n    if FLAGS.multi_gpu:\n        if \"MASTER_ADDR\" not in os.environ:\n            os.environ[\"MASTER_ADDR\"] = 'localhost'\n        if \"MASTER_PORT\" not in os.environ:\n            os.environ[\"MASTER_PORT\"] = '23456'\n\n        FLAGS.local_rank = int(os.environ[\"LOCAL_RANK\"])\n        torch.cuda.set_device(FLAGS.local_rank)\n        torch.distributed.init_process_group(backend=\"nccl\", init_method=\"env://\")\n\n    if FLAGS.config is not None:\n        data = json.load(open(FLAGS.config, 'r'))\n        for key in data:\n            FLAGS.__dict__[key] = data[key]\n\n    if FLAGS.display_res is None:\n        FLAGS.display_res = FLAGS.train_res\n\n    if FLAGS.local_rank == 0:\n        print(\"Config / Flags:\")\n        print(\"---------\")\n        for key in FLAGS.__dict__.keys():\n            print(key, FLAGS.__dict__[key])\n        print(\"---------\")\n\n    
os.makedirs(FLAGS.out_dir, exist_ok=True)\n\n    glctx = dr.RasterizeGLContext()\n    glctx_display = glctx if FLAGS.batch < 16 else dr.RasterizeGLContext() # Context for display\n\n    mtl_default = None\n\n    # ==============================================================================================\n    #  Create data pipeline\n    # ==============================================================================================\n    data_root = FLAGS.trainset_path\n    dataset_train    = DatasetNERF(os.path.join(data_root, 'transforms.json'), FLAGS, examples=int(1e6))\n    dataset_validate = DatasetNERF(os.path.join(data_root, 'transforms.json'), FLAGS)\n\n\n    # ==============================================================================================\n    #  Create env light with trainable parameters\n    # ==============================================================================================\n    \n    lgt = None\n    if FLAGS.learn_lighting:\n        lgt = light.create_trainable_env_rnd(FLAGS.probe_res, scale=0.0, bias=0.5)\n        # lgt = light.create_trainable_env_rnd(FLAGS.probe_res, scale=0.0, bias=0.1)\n    else:\n        lgt = light.load_env(FLAGS.envlight, scale=FLAGS.env_scale, res=[FLAGS.probe_res, FLAGS.probe_res])\n\n    # ==============================================================================================\n    #  Setup denoiser\n    # ==============================================================================================\n\n    denoiser = None\n    if FLAGS.denoiser == 'bilateral':\n        denoiser = BilateralDenoiser().cuda()\n    else:\n        assert FLAGS.denoiser == 'none', \"Invalid denoiser %s\" % FLAGS.denoiser\n\n    # Setup geometry for optimization\n    geometry = GShellTetsGeometry(FLAGS.gshell_grid, FLAGS.mesh_scale, FLAGS)\n\n    # Setup textures, make initial guess from reference if possible\n    if not FLAGS.normal_only:\n        mat = initial_guess_material(geometry, True, FLAGS, mtl_default)\n    else:\n        mat = initial_guess_material_knownkskd(geometry, True, FLAGS, mtl_default)\n    mat['no_perturbed_nrm'] = True\n\n    # Run optimization\n    geometry, mat = optimize_mesh(denoiser, glctx, geometry, mat, lgt, dataset_train, dataset_validate, \n                    FLAGS, pass_idx=0, pass_name=\"pass1\", optimize_light=FLAGS.learn_lighting, save_path=FLAGS.out_dir)\n\n    validate(glctx, geometry, mat, lgt, dataset_validate, os.path.join(FLAGS.out_dir, \"validate\"), FLAGS, denoiser=denoiser, save_viz=True)\n\n    with torch.no_grad():\n        os.makedirs(os.path.join(FLAGS.out_dir, \"mesh\"), exist_ok=True)\n        torch.save(geometry.state_dict(), os.path.join(FLAGS.out_dir, \"mesh/model.pt\"))\n        torch.save(mat['kd_ks'].state_dict(), os.path.join(FLAGS.out_dir, \"mesh/mtl.pt\"))\n        light.save_env_map(os.path.join(FLAGS.out_dir, \"mesh/probe.hdr\"), lgt)\n\n        # Create textured mesh from result\n        base_mesh = geometry.getMesh(mat)['imesh']\n\n        # Dump mesh for debugging.\n        os.makedirs(os.path.join(FLAGS.out_dir, \"mesh\"), exist_ok=True)\n        obj.write_obj(os.path.join(FLAGS.out_dir, \"mesh/\"), base_mesh, save_material=False)\n\n        # Free temporaries / cached memory\n        torch.cuda.empty_cache()\n        mat['kd_ks'].cleanup()\n        del mat['kd_ks']\n        if 'kd_ks_back' in mat:\n            mat['kd_ks_back'].cleanup()\n            del mat['kd_ks_back']\n\n        # Free temporaries / cached memory\n        torch.cuda.empty_cache()\n        del mat"
  },
  {
    "path": "train_gshelltet_synthetic.py",
    "content": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual\n# property and proprietary rights in and to this material, related\n# documentation and any modifications thereto. Any use, reproduction, \n# disclosure or distribution of this material and related documentation \n# without an express license agreement from NVIDIA CORPORATION or \n# its affiliates is strictly prohibited.\n\nimport os\nimport sys\nimport time\nimport argparse\nimport json\n\nimport numpy as np\nimport torch\nimport nvdiffrast.torch as dr\nimport xatlas\n\n# Import data readers / generators\nfrom dataset import DatasetMesh, DatasetNERF, DatasetLLFF\n\n# Import topology / geometry trainers\nfrom geometry.gshell_tets_geometry import GShellTetsGeometry\n\nimport render.renderutils as ru\nfrom render import obj\nfrom render import material\nfrom render import util\nfrom render import mesh\nfrom render import texture\nfrom render import mlptexture\nfrom render import light\nfrom render import render\n\n\nfrom denoiser.denoiser import BilateralDenoiser\n\n\nRADIUS = 3.0\n\n# Enable to debug back-prop anomalies\n# torch.autograd.set_detect_anomaly(True)\n\n###############################################################################\n# Loss setup\n###############################################################################\n\n@torch.no_grad()\ndef createLoss(FLAGS):\n    if FLAGS.loss == \"smape\":\n        return lambda img, ref: ru.image_loss(img, ref, loss='smape', tonemapper='none')\n    elif FLAGS.loss == \"mse\":\n        return lambda img, ref: ru.image_loss(img, ref, loss='mse', tonemapper='none')\n    elif FLAGS.loss == \"logl1\":\n        return lambda img, ref: ru.image_loss(img, ref, loss='l1', tonemapper='log_srgb')\n    elif FLAGS.loss == \"logl2\":\n        return lambda img, ref: ru.image_loss(img, ref, loss='mse', tonemapper='log_srgb')\n    elif FLAGS.loss == \"relmse\":\n        return lambda img, ref: ru.image_loss(img, ref, loss='relmse', tonemapper='none')\n    else:\n        assert False\n\n###############################################################################\n# Mix background into a dataset image\n###############################################################################\n\n@torch.no_grad()\ndef prepare_batch(target, bg_type='black'):\n    assert len(target['img'].shape) == 4, \"Image shape should be [n, h, w, c]\"\n    if bg_type == 'checker':\n        background = torch.tensor(util.checkerboard(target['img'].shape[1:3], 8), dtype=torch.float32, device='cuda')[None, ...]\n    elif bg_type == 'black':\n        background = torch.zeros(target['img'].shape[0:3] + (3,), dtype=torch.float32, device='cuda')\n    elif bg_type == 'white':\n        background = torch.ones(target['img'].shape[0:3] + (3,), dtype=torch.float32, device='cuda')\n    elif bg_type == 'reference':\n        background = target['img'][..., 0:3]\n    elif bg_type == 'random':\n        background = torch.rand(target['img'].shape[0:3] + (3,), dtype=torch.float32, device='cuda')\n    else:\n        assert False, \"Unknown background type %s\" % bg_type\n\n    target['mv'] = target['mv'].cuda()\n    target['mvp'] = target['mvp'].cuda()\n    target['campos'] = target['campos'].cuda()\n    target['img'] = target['img'].cuda()\n    target['background'] = background\n\n    target['img'] = torch.cat((torch.lerp(background, target['img'][..., 0:3], target['img'][..., 3:4]), target['img'][..., 3:4]), 
dim=-1)\n\n    return target\n\n###############################################################################\n# UV - map geometry & convert to a mesh\n###############################################################################\n\n@torch.no_grad()\ndef xatlas_uvmap(glctx, geometry, mat, FLAGS):\n    eval_mesh = geometry.getMesh(mat)\n    try:\n        eval_mesh = eval_mesh['imesh']\n    except:\n        pass\n    \n    # Create uvs with xatlas\n    v_pos = eval_mesh.v_pos.detach().cpu().numpy()\n    t_pos_idx = eval_mesh.t_pos_idx.detach().cpu().numpy()\n    vmapping, indices, uvs = xatlas.parametrize(v_pos, t_pos_idx)\n\n    # Convert to tensors\n    indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)\n    \n    uvs = torch.tensor(uvs, dtype=torch.float32, device='cuda')\n    faces = torch.tensor(indices_int64, dtype=torch.int64, device='cuda')\n\n    new_mesh = mesh.Mesh(v_tex=uvs, t_tex_idx=faces, base=eval_mesh)\n\n    mask, kd, ks = render.render_uv(glctx, new_mesh, FLAGS.texture_res, eval_mesh.material['kd_ks'])\n\n    # Dilate all textures & use average color for background\n    kd_avg = torch.sum(torch.sum(torch.sum(kd * mask, dim=0), dim=0), dim=0) / torch.sum(torch.sum(torch.sum(mask, dim=0), dim=0), dim=0)\n    kd = util.dilate(kd, kd_avg[None, None, None, :], mask, 7)\n\n    ks_avg = torch.sum(torch.sum(torch.sum(ks * mask, dim=0), dim=0), dim=0) / torch.sum(torch.sum(torch.sum(mask, dim=0), dim=0), dim=0)\n    ks = util.dilate(ks, ks_avg[None, None, None, :], mask, 7)\n\n    nrm_avg = torch.tensor([0, 0, 1], dtype=torch.float32, device=\"cuda\")\n    normal = nrm_avg[None, None, None, :].repeat(kd.shape[0], kd.shape[1], kd.shape[2], 1)\n    \n    new_mesh.material = mat.copy()\n    del new_mesh.material['kd_ks']\n\n    if FLAGS.transparency:\n        kd = torch.cat((kd, torch.rand_like(kd[...,0:1])), dim=-1)\n        print(\"kd shape\", kd.shape)\n\n    kd_min, kd_max = torch.tensor(FLAGS.kd_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.kd_max, dtype=torch.float32, device='cuda')\n    ks_min, ks_max = torch.tensor(FLAGS.ks_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.ks_max, dtype=torch.float32, device='cuda')\n    nrm_min, nrm_max = torch.tensor(FLAGS.nrm_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.nrm_max, dtype=torch.float32, device='cuda')\n    new_mesh.material.update({\n        'kd'     : texture.Texture2D(kd.clone().detach().requires_grad_(True), min_max=[kd_min, kd_max]),\n        'ks'     : texture.Texture2D(ks.clone().detach().requires_grad_(True), min_max=[ks_min, ks_max]),\n        'normal' : texture.Texture2D(normal.clone().detach().requires_grad_(True), min_max=[nrm_min, nrm_max]),\n    })\n\n    return new_mesh\n\n###############################################################################\n# Utility functions for material\n###############################################################################\n\ndef initial_guess_material(geometry, mlp, FLAGS, init_mat=None):\n    kd_min, kd_max = torch.tensor(FLAGS.kd_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.kd_max, dtype=torch.float32, device='cuda')\n    ks_min, ks_max = torch.tensor(FLAGS.ks_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.ks_max, dtype=torch.float32, device='cuda')\n    if mlp:\n        mlp_min = torch.cat((kd_min[0:3], ks_min), dim=0)\n        mlp_max = torch.cat((kd_max[0:3], ks_max), dim=0)\n        mlp_map_opt = mlptexture.MLPTexture3D(geometry.getAABB(), channels=6, 
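# 6 channels: 3 for kd (albedo) plus the 3 ks components (presumably occlusion, roughness, metalness); mlp_min/mlp_max set the per-channel output range\n            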
min_max=[mlp_min, mlp_max], use_float16=FLAGS.use_float16)\n        mat =  {'kd_ks' : mlp_map_opt}\n    else:\n        raise NotImplementedError\n\n    mat['bsdf'] = FLAGS.bsdf\n\n    mat['no_perturbed_nrm'] = FLAGS.no_perturbed_nrm\n\n    return mat\n\ndef initial_guess_material_knownkskd(geometry, mlp, FLAGS, init_mat=None):\n    mat =  {\n        'kd'     : init_mat['kd'],\n        'ks'     : init_mat['ks']\n    }\n\n    if init_mat is not None:\n        mat['bsdf'] = init_mat['bsdf']\n    else:\n        mat['bsdf'] = 'pbr'\n\n    return mat\n\n###############################################################################\n# Validation & testing\n###############################################################################\n\n@torch.no_grad()\ndef validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS, denoiser=None):\n    result_dict = {}\n    with torch.no_grad():\n        buffers = geometry.render(glctx, target, lgt, opt_material, use_uv=False, denoiser=denoiser)['buffers']\n\n        result_dict['ref'] = util.rgb_to_srgb(target['img'][...,0:3])[0]\n        result_dict['opt'] = util.rgb_to_srgb(buffers['shaded'][...,0:3])[0]\n        result_dict['mask_opt'] = buffers['shaded'][...,3:][0].expand(-1, -1, 3)\n        result_dict['mask_ref'] = target['img'][...,3:][0].expand(-1, -1, 3)\n        result_dict['msdf_image'] = buffers['msdf_image'][...,:][0].expand(-1, -1, 3).clamp(min=0, max=1)\n        result_image = torch.cat([result_dict['opt'], result_dict['ref'], result_dict['mask_opt'], result_dict['mask_ref'], result_dict['msdf_image']], axis=1)\n\n        if FLAGS.display is not None:\n            white_bg = torch.ones_like(target['background'])\n            for layer in FLAGS.display:\n                if 'latlong' in layer and layer['latlong']:\n                    result_dict['light_image'] = lgt.generate_image(FLAGS.display_res)\n                    result_dict['light_image'] = util.rgb_to_srgb(result_dict['light_image'] / (1 + result_dict['light_image']))\n                    result_image = torch.cat([result_image, result_dict['light_image']], axis=1)\n\n        return result_image, result_dict\n\n@torch.no_grad()\ndef validate(glctx, geometry, opt_material, lgt, dataset_validate, out_dir, FLAGS, denoiser=None, save_viz=False):\n\n    # ==============================================================================================\n    #  Validation loop\n    # ==============================================================================================\n    mse_values = []\n    psnr_values = []\n\n    dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_validate.collate)\n\n    os.makedirs(out_dir, exist_ok=True)\n    with open(os.path.join(out_dir, 'metrics.txt'), 'w') as fout:\n        fout.write('ID, MSE, PSNR\\n')\n\n        print(\"Running validation\")\n        for it, target in enumerate(dataloader_validate):\n\n            # Mix validation background\n            target = prepare_batch(target, FLAGS.background)\n\n            result_image, result_dict = validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS, denoiser=denoiser)\n\n            # Compute metrics\n            opt = torch.clamp(result_dict['opt'], 0.0, 1.0) \n            ref = torch.clamp(result_dict['ref'], 0.0, 1.0)\n\n            mse = torch.nn.functional.mse_loss(opt, ref, size_average=None, reduce=None, reduction='mean').item()\n            mse_values.append(float(mse))\n            psnr = util.mse_to_psnr(mse)\n            
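# mse_to_psnr presumably maps MSE over [0, 1] images to decibels via -10 * log10(mse)\n            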
psnr_values.append(float(psnr))\n\n            line = \"%d, %1.8f, %1.8f\\n\" % (it, mse, psnr)\n            fout.write(str(line))\n\n            if save_viz:\n                for k in result_dict.keys():\n                    np_img = result_dict[k].detach().cpu().numpy()\n                    util.save_image(out_dir + '/' + ('val_%06d_%s.png' % (it, k)), np_img)\n\n        avg_mse = np.mean(np.array(mse_values))\n        avg_psnr = np.mean(np.array(psnr_values))\n        line = \"AVERAGES: %1.4f, %2.3f\\n\" % (avg_mse, avg_psnr)\n        fout.write(str(line))\n        print(\"MSE,      PSNR\")\n        print(\"%1.8f, %2.3f\" % (avg_mse, avg_psnr))\n    return avg_psnr\n\n###############################################################################\n# Main shape fitter function / optimization loop\n###############################################################################\n\ndef optimize_mesh(\n        denoiser,\n        glctx,\n        geometry,\n        opt_material,\n        lgt,\n        dataset_train,\n        dataset_validate,\n        FLAGS,\n        warmup_iter=0,\n        log_interval=10,\n        pass_idx=0,\n        pass_name=\"\",\n        optimize_light=True,\n        optimize_geometry=True,\n        visualize=True,\n        save_path=None\n    ):\n\n    # ==============================================================================================\n    #  Setup torch optimizer\n    # ==============================================================================================\n\n    learning_rate = FLAGS.learning_rate[pass_idx] if isinstance(FLAGS.learning_rate, list) or isinstance(FLAGS.learning_rate, tuple) else FLAGS.learning_rate\n    learning_rate_pos = learning_rate[0] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate\n    learning_rate_mat = learning_rate[1] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate\n    learning_rate_lgt = learning_rate[2] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate * 6.0\n\n    def lr_schedule(iter, fraction):\n        if iter < warmup_iter:\n            return iter / warmup_iter \n        return max(0.0, 10**(-(iter - warmup_iter)*0.0002)) # Exponential falloff from [1.0, 0.1] over 5k epochs.    
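(10**(-5000*0.0002) = 0.1, i.e. the learning-rate multiplier decays to one tenth 5000 iterations after warmup)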
\n\n    # ==============================================================================================\n    #  Image loss\n    # ==============================================================================================\n    image_loss_fn = createLoss(FLAGS)\n\n\n\n    params = list(material.get_parameters(opt_material))\n\n    if optimize_light:\n        optimizer_light = torch.optim.Adam((lgt.parameters() if lgt is not None else []), lr=learning_rate_lgt)\n        scheduler_light = torch.optim.lr_scheduler.LambdaLR(optimizer_light, lr_lambda=lambda x: lr_schedule(x, 0.9)) \n\n    if optimize_geometry:\n        if FLAGS.use_sdf_mlp:\n            lr_msdf = learning_rate_pos * 1e-2 if FLAGS.use_msdf_mlp else learning_rate_pos\n            deform_params = list(v[1] for v in geometry.named_parameters() if 'deform' in v[0]) if optimize_geometry else []\n            msdf_params = list(v[1] for v in geometry.named_parameters() if 'msdf' in v[0]) if optimize_geometry else []\n            sdf_params = list(v[1] for v in geometry.named_parameters() if 'sdf' in v[0] and 'msdf' not in v[0]) if optimize_geometry else []\n            other_params = list(v[1] for v in geometry.named_parameters() if 'sdf' not in v[0] and 'msdf' not in v[0] and 'deform' not in v[0]) if optimize_geometry else []\n            optimizer_mesh = torch.optim.Adam([\n                    {'params': deform_params, 'lr': learning_rate_pos},\n                    {'params': msdf_params, 'lr': lr_msdf},\n                    {'params': sdf_params, 'lr': learning_rate_pos * 1e-2},\n                    {'params': other_params, 'lr': learning_rate_pos * 1e-2},\n                ], eps=1e-8)\n        else:\n            optimizer_mesh = torch.optim.Adam(geometry.parameters(), lr=learning_rate_pos)\n        scheduler_mesh = torch.optim.lr_scheduler.LambdaLR(optimizer_mesh, lr_lambda=lambda x: lr_schedule(x, 0.9)) \n\n    optimizer = torch.optim.Adam(params, lr=learning_rate_mat)\n    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: lr_schedule(x, 0.9))\n\n    # ==============================================================================================\n    #  Training loop\n    # ==============================================================================================\n    img_cnt = 0\n    img_loss_vec = []\n    depth_loss_vec = []\n    reg_loss_vec = []\n    iter_dur_vec = []\n\n    dataloader_train    = torch.utils.data.DataLoader(dataset_train, batch_size=FLAGS.batch, collate_fn=dataset_train.collate, shuffle=True)\n    if visualize:\n        dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_train.collate)\n\n        def cycle(iterable):\n            iterator = iter(iterable)\n            while True:\n                try:\n                    yield next(iterator)\n                except StopIteration:\n                    iterator = iter(iterable)\n\n        v_it = cycle(dataloader_validate)\n\n    for it, target in enumerate(dataloader_train):\n\n        # Mix randomized background into dataset image\n        target = prepare_batch(target, 'random')\n\n        # ==============================================================================================\n        #  Display / save outputs. 
Do it before training so we get initial meshes\n        # ==============================================================================================\n\n        # Show/save image before training step (want to get correct rendering of input)\n        if visualize and FLAGS.local_rank == 0 and it != 0:\n            with torch.no_grad():\n                display_image = FLAGS.display_interval and (it % FLAGS.display_interval == 0)\n                save_image = FLAGS.save_interval and (it % FLAGS.save_interval == 0)\n                if display_image or save_image:\n                    save_mesh = True\n                    if save_mesh:\n                        os.makedirs(os.path.join(save_path, pass_name), exist_ok=True)\n                        obj.write_obj(os.path.join(save_path, pass_name), geometry.getMesh(opt_material)['imesh'], save_material=False)\n                    result_image, result_dict = validate_itr(glctx, prepare_batch(next(v_it), FLAGS.background), geometry, opt_material, lgt, FLAGS, denoiser=denoiser)\n            \n                    np_result_image = result_image.detach().cpu().numpy()\n                    if display_image:\n                        util.display_image(np_result_image, title='%d / %d' % (it, FLAGS.iter))\n                    if save_image:\n                        util.save_image(os.path.join(save_path, ('img_%s_%06d.png' % (pass_name, img_cnt))), np_result_image)\n                        img_cnt = img_cnt + 1\n\n        iter_start_time = time.time()\n\n        # ==============================================================================================\n        #  Zero gradients\n        # ==============================================================================================\n        optimizer.zero_grad()\n        if optimize_geometry:\n            optimizer_mesh.zero_grad()\n        if optimize_light:\n            optimizer_light.zero_grad()\n\n        # ==============================================================================================\n        #  Training\n        # ==============================================================================================\n\n        xfm_lgt = None\n        if optimize_light:\n            lgt.update_pdf()\n            \n\n        img_loss, depth_loss, reg_loss = geometry.tick(\n            glctx, target, lgt, opt_material, image_loss_fn, it, \n            denoiser=denoiser)\n\n        # ==============================================================================================\n        #  Final loss\n        # ==============================================================================================\n        total_loss = img_loss + reg_loss\n\n        img_loss_vec.append(img_loss.item())\n        depth_loss_vec.append(depth_loss.item())\n        reg_loss_vec.append(reg_loss.item())\n\n        # ==============================================================================================\n        #  Backpropagate\n        # ==============================================================================================\n        total_loss.backward()\n        if hasattr(lgt, 'base') and lgt.base.grad is not None and optimize_light:\n            lgt.base.grad *= 64\n        if 'kd_ks' in opt_material:\n            opt_material['kd_ks'].encoder.params.grad /= 8.0\n        if 'kd_ks_back' in opt_material:\n            opt_material['kd_ks_back'].encoder.params.grad /= 8.0\n\n        # Optionally clip gradients\n        if FLAGS.clip_max_norm > 0.0:\n            if optimize_geometry:\n       
         torch.nn.utils.clip_grad_norm_(list(geometry.parameters()) + params, FLAGS.clip_max_norm)\n            else:\n                torch.nn.utils.clip_grad_norm_(params, FLAGS.clip_max_norm)\n\n        optimizer.step()\n        scheduler.step()\n\n        if optimize_geometry:\n            optimizer_mesh.step()\n            scheduler_mesh.step()\n\n        if optimize_light:\n            optimizer_light.step()\n            scheduler_light.step()\n\n        # ==============================================================================================\n        #  Clamp trainables to reasonable range\n        # ==============================================================================================\n        with torch.no_grad():\n            if 'kd' in opt_material:\n                opt_material['kd'].clamp_()\n            if 'ks' in opt_material:\n                opt_material['ks'].clamp_()\n            if 'kd_back' in opt_material:\n                opt_material['kd_back'].clamp_()\n            if 'ks_back' in opt_material:\n                opt_material['ks_back'].clamp_()\n            if 'normal' in opt_material and not FLAGS.normal_only:\n                opt_material['normal'].clamp_()\n                opt_material['normal'].normalize_()\n            if lgt is not None:\n                # lgt.clamp_(min=0.01) # For some reason gradient disappears if light becomes 0\n                lgt.clamp_(min=1e-4) # For some reason gradient disappears if light becomes 0\n\n            geometry.clamp_deform()\n        torch.cuda.current_stream().synchronize()\n        iter_dur_vec.append(time.time() - iter_start_time)\n\n        # ==============================================================================================\n        #  Logging\n        # ==============================================================================================\n        if it % log_interval == 0 and FLAGS.local_rank == 0:\n            img_loss_avg = np.mean(np.asarray(img_loss_vec[-log_interval:]))\n            depth_loss_avg = np.mean(np.asarray(depth_loss_vec[-log_interval:]))\n            reg_loss_avg = np.mean(np.asarray(reg_loss_vec[-log_interval:]))\n            iter_dur_avg = np.mean(np.asarray(iter_dur_vec[-log_interval:]))\n            \n            remaining_time = (FLAGS.iter-it)*iter_dur_avg\n            print(\"iter=%5d, img_loss=%.6f, depth_loss=%.6f, reg_loss=%.6f, lr=%.5f, time=%.1f ms, rem=%s\" % \n                (it, img_loss_avg, depth_loss_avg, reg_loss_avg, optimizer.param_groups[0]['lr'], iter_dur_avg*1000, util.time_to_text(remaining_time)))\n            sys.stdout.flush()\n\n        if it == FLAGS.iter:\n            break\n\n    return geometry, opt_material\n\n#----------------------------------------------------------------------------\n# Main function.\n#----------------------------------------------------------------------------\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description='nvdiffrec')\n    parser.add_argument('--config', type=str, default=None, help='Config file')\n    parser.add_argument('-i', '--iter', type=int, default=5000)\n    parser.add_argument('-b', '--batch', type=int, default=1)\n    parser.add_argument('-s', '--spp', type=int, default=1)\n    parser.add_argument('-l', '--layers', type=int, default=1)\n    parser.add_argument('-r', '--train-res', nargs=2, type=int, default=[512, 512])\n    parser.add_argument('-dr', '--display-res', type=int, default=None)\n    parser.add_argument('-tr', '--texture-res', nargs=2, type=int, 
default=[1024, 1024])\n    parser.add_argument('-di', '--display-interval', type=int, default=0)\n    parser.add_argument('-si', '--save-interval', type=int, default=1000)\n    parser.add_argument('-lr', '--learning-rate', type=float, default=0.01)\n    parser.add_argument('-mr', '--min-roughness', type=float, default=0.08)\n    parser.add_argument('-mip', '--custom-mip', action='store_true', default=False)\n    parser.add_argument('-rt', '--random-textures', action='store_true', default=False)\n    parser.add_argument('-bg', '--background', default='checker', choices=['black', 'white', 'checker', 'reference'])\n    parser.add_argument('--loss', default='logl1', choices=['logl1', 'logl2', 'mse', 'smape', 'relmse'])\n    parser.add_argument('-o', '--out-dir', type=str, default=None)\n    parser.add_argument('-rm', '--ref_mesh', type=str)\n    parser.add_argument('-bm', '--base-mesh', type=str, default=None)\n    parser.add_argument('--validate', type=bool, default=True)\n    # Render specific arguments\n    parser.add_argument('--n_samples', type=int, default=4)\n    parser.add_argument('--bsdf', type=str, default='pbr', choices=['pbr', 'diffuse', 'white'])\n    # Denoiser specific arguments\n    parser.add_argument('--denoiser', default='bilateral', choices=['none', 'bilateral'])\n    parser.add_argument('--denoiser_demodulate', type=bool, default=True)\n    parser.add_argument('--index',type=int)\n    parser.add_argument('--msdf_reg_open_scale', type=float, default=1e-6)\n    parser.add_argument('--msdf_reg_close_scale', type=float, default=3e-6)\n    parser.add_argument('--eikonal_scale', type=float)\n    parser.add_argument('--sdf_regularizer', type=float, default=0.2)\n\n    FLAGS = parser.parse_args()\n    FLAGS.mtl_override        = None        # Override material of model\n    FLAGS.gshell_grid          = 64          # Resolution of initial tet grid. We provide 64 and 128 resolution grids. \n                                            #    Other resolutions can be generated with https://github.com/crawforddoran/quartet\n                                            #    We include examples in data/tets/generate_tets.py\n    FLAGS.mesh_scale          = 2.1         # Scale of tet grid box. Adjust to cover the model\n    FLAGS.envlight            = None        # HDR environment probe\n    FLAGS.env_scale           = 1.0         # Env map intensity multiplier\n    FLAGS.probe_res           = 256         # Env map probe resolution\n    FLAGS.learn_lighting      = True        # Enable optimization of env lighting\n    FLAGS.display             = None        # Configure validation window/display. E.g. [{\"bsdf\" : \"kd\"}, {\"bsdf\" : \"ks\"}]\n    FLAGS.transparency        = False       # Enabled transparency through depth peeling\n    FLAGS.lock_light          = False       # Disable light optimization in the second pass\n    FLAGS.lock_pos            = False       # Disable vertex position optimization in the second pass\n    # FLAGS.sdf_regularizer     = 0.2         # Weight for sdf regularizer.\n    FLAGS.laplace             = \"relative\"  # Mesh Laplacian [\"absolute\", \"relative\"]\n    FLAGS.laplace_scale       = 3000.0      # Weight for Laplace regularizer. 
Default is relative with large weight\n    FLAGS.pre_load            = True        # Pre-load entire dataset into memory for faster training\n    FLAGS.no_perturbed_nrm    = False       # Disable normal map\n    FLAGS.decorrelated        = False       # Use decorrelated sampling in forward and backward passes\n    FLAGS.kd_min              = [ 0.0,  0.0,  0.0,  0.0]\n    FLAGS.kd_max              = [ 1.0,  1.0,  1.0,  1.0]\n    FLAGS.ks_min              = [ 0.0,  0.001, 0.0]\n    FLAGS.ks_max              = [ 0.0,  1.0,  1.0]\n    FLAGS.nrm_min             = [-1.0, -1.0,  0.0]\n    FLAGS.nrm_max             = [ 1.0,  1.0,  1.0]\n    FLAGS.clip_max_norm       = 0.0\n    FLAGS.cam_near_far        = [0.1, 1000.0]\n    FLAGS.lambda_kd           = 0.1\n    FLAGS.lambda_ks           = 0.05\n    FLAGS.lambda_nrm          = 0.025\n    FLAGS.lambda_nrm2         = 0.25\n    FLAGS.lambda_chroma       = 0.0\n    FLAGS.lambda_diffuse      = 0.15\n    FLAGS.lambda_specular     = 0.0025\n\n    FLAGS.random_lgt                  = False\n    FLAGS.normal_only                 = False\n    FLAGS.use_img_2nd_layer           = False\n    FLAGS.use_depth                   = False\n    FLAGS.use_depth_2nd_layer         = False\n    FLAGS.use_tanh_deform             = False\n    FLAGS.use_sdf_mlp                 = True\n    FLAGS.use_msdf_mlp                = False\n    FLAGS.use_eikonal                 = True\n    FLAGS.use_mesh_msdf_reg           = True\n    FLAGS.sphere_init                 = False\n    FLAGS.sphere_init_norm            = 1.0\n    FLAGS.pretrained_sdf_mlp_path     = f'./data/pretrained_mlp_{FLAGS.gshell_grid}_synthetic.pt'\n    FLAGS.n_hidden                    = 6\n    FLAGS.d_hidden                    = 256\n    FLAGS.n_freq                      = 6\n    FLAGS.skip_in                     = [3]\n    FLAGS.use_float16                 = False\n    FLAGS.visualize_watertight        = False\n\n    FLAGS.local_rank = 0\n    FLAGS.multi_gpu  = \"WORLD_SIZE\" in os.environ and int(os.environ[\"WORLD_SIZE\"]) > 1\n    if FLAGS.multi_gpu:\n        if \"MASTER_ADDR\" not in os.environ:\n            os.environ[\"MASTER_ADDR\"] = 'localhost'\n        if \"MASTER_PORT\" not in os.environ:\n            os.environ[\"MASTER_PORT\"] = '23456'\n\n        FLAGS.local_rank = int(os.environ[\"LOCAL_RANK\"])\n        torch.cuda.set_device(FLAGS.local_rank)\n        torch.distributed.init_process_group(backend=\"nccl\", init_method=\"env://\")\n\n    if FLAGS.config is not None:\n        data = json.load(open(FLAGS.config, 'r'))\n        for key in data:\n            FLAGS.__dict__[key] = data[key]\n\n    if FLAGS.display_res is None:\n        FLAGS.display_res = FLAGS.train_res\n\n    if FLAGS.local_rank == 0:\n        print(\"Config / Flags:\")\n        print(\"---------\")\n        for key in FLAGS.__dict__.keys():\n            print(key, FLAGS.__dict__[key])\n        print(\"---------\")\n\n    os.makedirs(FLAGS.out_dir, exist_ok=True)\n\n    glctx = dr.RasterizeGLContext()\n    glctx_display = glctx if FLAGS.batch < 16 else dr.RasterizeGLContext() # Context for display\n\n    mtl_default = None\n\n    # ==============================================================================================\n    #  Create data pipeline\n    # ==============================================================================================\n\n    print(FLAGS.ref_mesh)\n    if os.path.splitext(FLAGS.ref_mesh)[1] == '.obj':\n        ref_mesh         = mesh.load_mesh(FLAGS.ref_mesh, FLAGS.mtl_override)\n        dataset_train    
= DatasetMesh(ref_mesh, glctx, RADIUS, FLAGS, validate=False)\n        dataset_validate = DatasetMesh(ref_mesh, glctx_display, RADIUS, FLAGS, validate=True)\n    elif os.path.isdir(FLAGS.ref_mesh):\n        if os.path.isfile(os.path.join(FLAGS.ref_mesh, 'poses_bounds.npy')):\n            dataset_train    = DatasetLLFF(FLAGS.ref_mesh, FLAGS, examples=(FLAGS.iter+1)*FLAGS.batch)\n            dataset_validate = DatasetLLFF(FLAGS.ref_mesh, FLAGS)\n        elif os.path.isfile(os.path.join(FLAGS.ref_mesh, 'transforms_train.json'))  and not os.path.isfile(os.path.join(FLAGS.ref_mesh, 'intrinsics.txt')):\n            dataset_train    = DatasetNERF(os.path.join(FLAGS.ref_mesh, 'transforms_train.json'), FLAGS, examples=(FLAGS.iter+1)*FLAGS.batch)\n            dataset_validate = DatasetNERF(os.path.join(FLAGS.ref_mesh, 'transforms_test.json'), FLAGS)\n        else:\n            assert False, \"Invalid dataset format\"\n    else:\n        print(\"Invalid dataset format\", FLAGS.ref_mesh)\n        assert False, \"Invalid dataset format\"\n\n\n\n    # ==============================================================================================\n    #  Create env light with trainable parameters\n    # ==============================================================================================\n    \n    lgt = None\n    if FLAGS.learn_lighting:\n        lgt = light.create_trainable_env_rnd(FLAGS.probe_res, scale=0.0, bias=0.5)\n        # lgt = light.create_trainable_env_rnd(FLAGS.probe_res, scale=0.0, bias=0.1)\n    else:\n        lgt = light.load_env(FLAGS.envlight, scale=FLAGS.env_scale, res=[FLAGS.probe_res, FLAGS.probe_res])\n\n    # ==============================================================================================\n    #  Setup denoiser\n    # ==============================================================================================\n\n    denoiser = None\n    if FLAGS.denoiser == 'bilateral':\n        denoiser = BilateralDenoiser().cuda()\n    else:\n        assert FLAGS.denoiser == 'none', \"Invalid denoiser %s\" % FLAGS.denoiser\n\n    # Setup geometry for optimization\n    geometry = GShellTetsGeometry(FLAGS.gshell_grid, FLAGS.mesh_scale, FLAGS)\n\n    # Setup textures, make initial guess from reference if possible\n    if not FLAGS.normal_only:\n        mat = initial_guess_material(geometry, True, FLAGS, mtl_default)\n    else:\n        mat = initial_guess_material_knownkskd(geometry, True, FLAGS, mtl_default)\n    mat['no_perturbed_nrm'] = True\n\n    # Run optimization\n    geometry, mat = optimize_mesh(denoiser, glctx, geometry, mat, lgt, dataset_train, dataset_validate, \n                    FLAGS, pass_idx=0, pass_name=\"pass1\", optimize_light=FLAGS.learn_lighting, save_path=FLAGS.out_dir)\n\n    validate(glctx, geometry, mat, lgt, dataset_validate, os.path.join(FLAGS.out_dir, \"validate\"), FLAGS, denoiser=denoiser, save_viz=True)\n\n    with torch.no_grad():\n        os.makedirs(os.path.join(FLAGS.out_dir, \"mesh\"), exist_ok=True)\n        torch.save(geometry.state_dict(), os.path.join(FLAGS.out_dir, \"mesh/model.pt\"))\n        torch.save(mat['kd_ks'].state_dict(), os.path.join(FLAGS.out_dir, \"mesh/mtl.pt\"))\n        light.save_env_map(os.path.join(FLAGS.out_dir, \"mesh/probe.hdr\"), lgt)\n\n        # Create textured mesh from result\n        base_mesh = geometry.getMesh(mat)['imesh']\n\n        # Dump mesh for debugging.\n        os.makedirs(os.path.join(FLAGS.out_dir, \"mesh\"), exist_ok=True)\n        obj.write_obj(os.path.join(FLAGS.out_dir, \"mesh/\"), 
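# geometry only; the MLP material and environment light were saved above as mesh/mtl.pt and mesh/probe.hdr\n            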
base_mesh, save_material=False)\n\n        # Free temporaries / cached memory\n        torch.cuda.empty_cache()\n        mat['kd_ks'].cleanup()\n        del mat['kd_ks']\n        if 'kd_ks_back' in mat:\n            mat['kd_ks_back'].cleanup()\n            del mat['kd_ks_back']\n\n        # Free temporaries / cached memory\n        torch.cuda.empty_cache()\n        del mat"
  }
]