Repository: lzzcd001/GShell Branch: main Commit: c2f0ba9ea01a Files: 124 Total size: 1.2 MB Directory structure: gitextract_zb4dyoqx/ ├── .gitignore ├── GMeshDiffusion/ │ ├── diffusion_configs/ │ │ ├── config_lower_occgrid_normalized.py │ │ └── config_upper_occgrid_normalized.py │ ├── lib/ │ │ ├── dataset/ │ │ │ ├── gshell_dataset.py │ │ │ └── gshell_dataset_aug.py │ │ └── diffusion/ │ │ ├── evaler.py │ │ ├── likelihood.py │ │ ├── losses.py │ │ ├── models/ │ │ │ ├── __init__.py │ │ │ ├── ema.py │ │ │ ├── functional.py │ │ │ ├── layers.py │ │ │ ├── normalization.py │ │ │ ├── unet3d_occgrid.py │ │ │ └── utils.py │ │ ├── sampling.py │ │ ├── sde_lib.py │ │ ├── trainer.py │ │ ├── trainer_ddp.py │ │ └── utils.py │ ├── main_diffusion.py │ ├── main_diffusion_ddp.py │ ├── metadata/ │ │ ├── get_splits_lower.py │ │ ├── get_splits_upper.py │ │ ├── save_tet_info.py │ │ └── tet_to_cubic_grid_dataset.py │ └── scripts/ │ ├── run_eval_lower_occgrid_normalized.sh │ ├── run_eval_upper_occgrid_normalized.sh │ ├── run_lower_occgrid_normalized_ddp.sh │ └── run_upper_occgrid_normalized_ddp.sh ├── README.md ├── configs/ │ ├── deepfashion_mc.json │ ├── deepfashion_mc_256.json │ ├── deepfashion_mc_512.json │ ├── deepfashion_mc_80.json │ ├── nerf_chair.json │ ├── polycam_mc.json │ ├── polycam_mc_128.json │ └── polycam_mc_16samples.json ├── data/ │ └── tets/ │ └── generate_tets.py ├── dataset/ │ ├── __init__.py │ ├── dataset.py │ ├── dataset_deepfashion.py │ ├── dataset_deepfashion_testset.py │ ├── dataset_llff.py │ ├── dataset_mesh.py │ ├── dataset_nerf.py │ └── dataset_nerf_colmap.py ├── denoiser/ │ └── denoiser.py ├── eval_gmeshdiffusion_generated_samples.py ├── geometry/ │ ├── embedding.py │ ├── flexicubes_table.py │ ├── gshell_flexicubes.py │ ├── gshell_flexicubes_geometry.py │ ├── gshell_tets.py │ ├── gshell_tets_geometry.py │ └── mlp.py ├── render/ │ ├── light.py │ ├── material.py │ ├── mesh.py │ ├── mlptexture.py │ ├── obj.py │ ├── optixutils/ │ │ ├── __init__.py │ │ ├── c_src/ │ │ │ 
├── accessor.h │ │ │ ├── bsdf.h │ │ │ ├── common.h │ │ │ ├── denoising.cu │ │ │ ├── denoising.h │ │ │ ├── envsampling/ │ │ │ │ ├── kernel.cu │ │ │ │ └── params.h │ │ │ ├── math_utils.h │ │ │ ├── optix_wrapper.cpp │ │ │ ├── optix_wrapper.h │ │ │ └── torch_bindings.cpp │ │ ├── include/ │ │ │ ├── internal/ │ │ │ │ ├── optix_7_device_impl.h │ │ │ │ ├── optix_7_device_impl_exception.h │ │ │ │ └── optix_7_device_impl_transformations.h │ │ │ ├── optix.h │ │ │ ├── optix_7_device.h │ │ │ ├── optix_7_host.h │ │ │ ├── optix_7_types.h │ │ │ ├── optix_denoiser_tiling.h │ │ │ ├── optix_device.h │ │ │ ├── optix_function_table.h │ │ │ ├── optix_function_table_definition.h │ │ │ ├── optix_host.h │ │ │ ├── optix_stack_size.h │ │ │ ├── optix_stubs.h │ │ │ └── optix_types.h │ │ ├── ops.py │ │ └── tests/ │ │ └── filter_test.py │ ├── regularizer.py │ ├── render.py │ ├── renderutils/ │ │ ├── __init__.py │ │ ├── bsdf.py │ │ ├── c_src/ │ │ │ ├── bsdf.cu │ │ │ ├── bsdf.h │ │ │ ├── common.cpp │ │ │ ├── common.h │ │ │ ├── cubemap.cu │ │ │ ├── cubemap.h │ │ │ ├── loss.cu │ │ │ ├── loss.h │ │ │ ├── mesh.cu │ │ │ ├── mesh.h │ │ │ ├── normal.cu │ │ │ ├── normal.h │ │ │ ├── tensor.h │ │ │ ├── torch_bindings.cpp │ │ │ ├── vec3f.h │ │ │ └── vec4f.h │ │ ├── loss.py │ │ ├── ops.py │ │ └── tests/ │ │ ├── test_bsdf.py │ │ ├── test_loss.py │ │ ├── test_mesh.py │ │ └── test_perf.py │ ├── texture.py │ └── util.py ├── train_gflexicubes_deepfashion.py ├── train_gflexicubes_polycam.py ├── train_gshelltet_deepfashion.py ├── train_gshelltet_polycam.py └── train_gshelltet_synthetic.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ # lib/ lib64/ parts/ sdist/ 
var/ wheels/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ cover/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder .pybuilder/ target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: # .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # poetry # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. # This is especially recommended for binary packages to ensure reproducibility, and is more # commonly ignored for libraries. # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control #poetry.lock # pdm # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. #pdm.lock # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it # in version control. # https://pdm.fming.dev/#use-with-ide .pdm.toml # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # pytype static type analyzer .pytype/ # Cython debug symbols cython_debug/ # PyCharm # JetBrains specific template is maintained in a separate JetBrains.gitignore that can # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ .DS_STORE ================================================ FILE: GMeshDiffusion/diffusion_configs/config_lower_occgrid_normalized.py ================================================ import ml_collections import torch import os def get_config(): config = ml_collections.ConfigDict() # data data = config.data = ml_collections.ConfigDict() data.root_dir = 'PLACEHOLDER' # data.dataset_metapath = os.path.join(data.root_dir, 'metadata/lower_res64_train.txt') data.num_workers = 4 data.grid_size = 128 data.tet_resolution = 64 data.num_channels = 4 data.use_occ_grid = True data.grid_metafile = os.path.join(data.root_dir, 'metadata/lower_res64_grid_train.txt') data.occgrid_metafile = os.path.join(data.root_dir, 'metadata/lower_res64_occgrid_train.txt') data.occ_mask_path = os.path.join(data.root_dir, 'metadata/occ_mask_res64.pt') data.tet_info_path = os.path.join(data.root_dir, 'metadata/tet_info.pt') data.filter_meta_path = None data.aug = True # training training = config.training = ml_collections.ConfigDict() training.sde = 'vpsde' training.continuous = False training.reduce_mean = True training.batch_size = 1 ### for DDP, 
global_batch_size = nproc * local_batch_size training.num_grad_acc_steps = 4 training.n_iters = 2400001 training.snapshot_freq = 1000 training.log_freq = 50 ## produce samples at each snapshot. training.snapshot_sampling = True training.likelihood_weighting = False training.loss_type = 'l2' training.train_dir = "PLACEHOLDER" training.snapshot_freq_for_preemption = 1000 training.gradscaler_growth_interval = 1000 training.use_aux_loss = False training.compile = True # PyTorch 2.0, torch.compile training.enable_xformers_memory_efficient_attention = True # sampling sampling = config.sampling = ml_collections.ConfigDict() sampling.method = 'pc' sampling.predictor = 'ancestral_sampling' sampling.corrector = 'none' sampling.n_steps_each = 1 sampling.noise_removal = True sampling.probability_flow = False sampling.snr = 0.075 # model model = config.model = ml_collections.ConfigDict() model.name = 'unet3d_occgrid' model.use_occ_grid = True model.num_res_blocks = 2 model.num_res_blocks_1st_layer = 2 model.base_channels = 128 model.ch_mult = (1, 2, 2, 4, 4, 4) model.down_block_types = ( "ResBlock", "ResBlock", "ResBlock", "AttnResBlock", "ResBlock", "ResBlock" ) model.up_block_types = ( "ResBlock", "ResBlock", "AttnResBlock", "ResBlock", "ResBlock", "ResBlock" ) model.scale_by_sigma = False model.num_scales = 1000 model.ema_rate = 0.9999 model.normalization = 'GroupNorm' model.act_fn = 'swish' model.attn_resolutions = (16,) model.resamp_with_conv = True model.dropout = 0.1 model.sigma_max = 378 model.sigma_min = 0.01 model.beta_min = 0.1 model.beta_max = 20. 
model.embedding_type = 'fourier' model.pred_type = 'noise' model.conditional = True model.feature_mask_path = os.path.join(data.root_dir, 'metadata/global_mask_res64.pt') model.pixcat_mask_path = os.path.join(data.root_dir, 'metadata/cat_mask_res64.pt') # optimization config.optim = optim = ml_collections.ConfigDict() optim.weight_decay = 1e-5 optim.optimizer = 'AdamW' optim.lr = 1e-5 optim.beta1 = 0.9 optim.eps = 1e-8 optim.warmup = 5000 optim.grad_clip = 1. # eval config.eval = eval_config = ml_collections.ConfigDict() eval_config.batch_size = 2 eval_config.idx = 0 eval_config.bin_size = 30 eval_config.eval_dir = "PLACEHOLDER" eval_config.ckpt_path = "PLACEHOLDER" config.seed = 42 config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') return config ================================================ FILE: GMeshDiffusion/diffusion_configs/config_upper_occgrid_normalized.py ================================================ import ml_collections import torch import os def get_config(): config = ml_collections.ConfigDict() # data data = config.data = ml_collections.ConfigDict() data.root_dir = 'PLACEHOLDER' # data.dataset_metapath = os.path.join(data.root_dir, 'metadata/upper_res64_train.txt') data.num_workers = 4 data.grid_size = 128 data.tet_resolution = 64 data.num_channels = 4 data.use_occ_grid = True data.grid_metafile = os.path.join(data.root_dir, 'metadata/upper_res64_grid_train.txt') data.occgrid_metafile = os.path.join(data.root_dir, 'metadata/upper_res64_occgrid_train.txt') data.occ_mask_path = os.path.join(data.root_dir, 'metadata/occ_mask_res64.pt') data.tet_info_path = os.path.join(data.root_dir, 'metadata/tet_info.pt') data.filter_meta_path = None data.aug = True # training training = config.training = ml_collections.ConfigDict() training.sde = 'vpsde' training.continuous = False training.reduce_mean = True training.batch_size = 1 ### for DDP, global_batch_size = nproc * local_batch_size training.num_grad_acc_steps = 4 
training.n_iters = 2400001 training.snapshot_freq = 1000 training.log_freq = 50 ## produce samples at each snapshot. training.snapshot_sampling = True training.likelihood_weighting = False training.loss_type = 'l2' training.train_dir = "PLACEHOLDER" training.snapshot_freq_for_preemption = 1000 training.gradscaler_growth_interval = 1000 training.use_aux_loss = False training.compile = True # PyTorch 2.0, torch.compile training.enable_xformers_memory_efficient_attention = True # sampling sampling = config.sampling = ml_collections.ConfigDict() sampling.method = 'pc' sampling.predictor = 'ancestral_sampling' sampling.corrector = 'none' sampling.n_steps_each = 1 sampling.noise_removal = True sampling.probability_flow = False sampling.snr = 0.075 # model model = config.model = ml_collections.ConfigDict() model.name = 'unet3d_occgrid' model.use_occ_grid = True model.num_res_blocks = 2 model.num_res_blocks_1st_layer = 2 model.base_channels = 128 model.ch_mult = (1, 2, 2, 4, 4, 4) model.down_block_types = ( "ResBlock", "ResBlock", "ResBlock", "AttnResBlock", "ResBlock", "ResBlock" ) model.up_block_types = ( "ResBlock", "ResBlock", "AttnResBlock", "ResBlock", "ResBlock", "ResBlock" ) model.scale_by_sigma = False model.num_scales = 1000 model.ema_rate = 0.9999 model.normalization = 'GroupNorm' model.act_fn = 'swish' model.attn_resolutions = (16,) model.resamp_with_conv = True model.dropout = 0.1 model.sigma_max = 378 model.sigma_min = 0.01 model.beta_min = 0.1 model.beta_max = 20. 
model.embedding_type = 'fourier' model.pred_type = 'noise' model.conditional = True model.feature_mask_path = os.path.join(data.root_dir, 'metadata/global_mask_res64_occaug_normalized_v1.pt') model.pixcat_mask_path = os.path.join(data.root_dir, 'metadata/cat_mask_res64_occaug_normalized_v1.pt') # optimization config.optim = optim = ml_collections.ConfigDict() optim.weight_decay = 1e-5 optim.optimizer = 'AdamW' optim.lr = 1e-5 optim.beta1 = 0.9 optim.eps = 1e-8 optim.warmup = 5000 optim.grad_clip = 1. # eval config.eval = eval_config = ml_collections.ConfigDict() eval_config.batch_size = 2 eval_config.idx = 0 eval_config.bin_size = 30 eval_config.eval_dir = "PLACEHOLDER" eval_config.ckpt_path = "PLACEHOLDER" config.seed = 42 config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') return config ================================================ FILE: GMeshDiffusion/lib/dataset/gshell_dataset.py ================================================ import torch import numpy as np from torch.utils.data import Dataset class GShellDataset(Dataset): def __init__(self, filepath_metafile, extension='pt'): super().__init__() with open(filepath_metafile, 'r') as f: self.filepath_list = [fpath.rstrip() for fpath in f] self.extension = extension assert self.extension in ['pt', 'npy'] def __len__(self): return len(self.filepath_list) def __getitem__(self, idx): with torch.no_grad(): if self.extension == 'pt': datum = torch.load(self.filepath_list[idx], map_location='cpu') else: datum = torch.tensor(np.load(self.filepath_list[idx])) return datum ================================================ FILE: GMeshDiffusion/lib/dataset/gshell_dataset_aug.py ================================================ import torch from torch.utils.data import Dataset class GShellAugDataset(Dataset): def __init__(self, FLAGS, extension='pt'): super().__init__() with open(FLAGS.data.grid_metafile, 'r') as f: self.filepath_list = [fpath.rstrip() for fpath in f] with 
open(FLAGS.data.occgrid_metafile, 'r') as f: self.occ_filepath_list = [fpath.rstrip() for fpath in f] self.extension = extension self.num_channels = FLAGS.data.num_channels print('num_channels: ', self.num_channels) assert self.extension in ['pt', 'npy'] def __len__(self): return len(self.filepath_list) def __getitem__(self, idx): with torch.no_grad(): grid = torch.load(self.filepath_list[idx], map_location='cpu') try: occ_grid = torch.load(self.occ_filepath_list[idx], map_location='cpu') except: print(self.occ_filepath_list[idx]) raise return (grid[:self.num_channels], occ_grid) @staticmethod def collate(data): return { 'grid': torch.stack([x[0] for x in data]), 'occgrid': torch.stack([x[1] for x in data]), } ================================================ FILE: GMeshDiffusion/lib/diffusion/evaler.py ================================================ import os import sys import numpy as np import tqdm import logging from . import losses from .models import utils as mutils from .models.ema import ExponentialMovingAverage from . import sde_lib import torch from .utils import restore_checkpoint from . 
import sampling def uncond_gen( config ): """ Unconditional Generation """ with torch.no_grad(): eval_dir, ckpt_path = config.eval.eval_dir, config.eval.ckpt_path idx = config.eval.idx bin_size = config.eval.bin_size print(f"idx to save: {idx * bin_size} to {idx * bin_size + bin_size - 1}") # Create directory to eval_folder os.makedirs(eval_dir, exist_ok=True) scaler, inverse_scaler = lambda x: x, lambda x: x # Initialize model score_model = mutils.create_model(config, use_parallel=False) optimizer = losses.get_optimizer(config, score_model.parameters()) ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate) state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0) # Setup SDEs sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales) sampling_eps = 1e-3 sampling_shape = (config.eval.batch_size, config.data.num_channels, config.data.grid_size, config.data.grid_size, config.data.grid_size) sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape, inverse_scaler, sampling_eps) assert os.path.exists(ckpt_path) print('ckpt path:', ckpt_path) try: state = restore_checkpoint(ckpt_path, state, device=config.device) except: raise ema.copy_to(score_model.parameters()) print(f"loaded model is trained till iter {state['step'] // config.training.num_grad_acc_steps}") for k in range(bin_size): save_file_path = os.path.join(eval_dir, f"{idx * bin_size + k}") print(f'check: {save_file_path}') if os.path.exists(save_file_path + '.pt'): # continue pass print(f'will save to: {save_file_path}') samples, n = sampling_fn(score_model) if type(samples) != tuple: print(samples[:, 0].unique()) torch.save(samples, save_file_path + '.pt') samples = samples.cpu().numpy() # np.save(save_file_path, samples) else: print(samples[0][:, 0].unique()) torch.save(samples[0], save_file_path + '.pt') torch.save(samples[1], save_file_path + '_occ.pt') # samples, occ = samples[0].cpu().numpy(), 
samples[1].cpu().numpy() # np.save(save_file_path + '.npy, samples) def slerp(z1, z2, alpha): ''' Spherical Linear Interpolation ''' theta = torch.acos(torch.sum(z1 * z2) / (torch.norm(z1) * torch.norm(z2))) return ( torch.sin((1 - alpha) * theta) / torch.sin(theta) * z1 + torch.sin(alpha * theta) / torch.sin(theta) * z2 ) def uncond_gen_interp( config, idx=0, ): """ Generation with interpolation between initial noises Used for DDIM """ with torch.no_grad(): eval_dir, ckpt_path = config.eval.eval_dir, config.eval.ckpt_path # Create directory to eval_folder os.makedirs(eval_dir, exist_ok=True) scaler, inverse_scaler = lambda x: x, lambda x: x # Initialize model score_model = mutils.create_model(config, use_parallel=False) optimizer = losses.get_optimizer(config, score_model.parameters()) ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate) state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0) # Setup SDEs sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales) sampling_eps = 1e-3 sampling_shape = (config.eval.batch_size, config.data.num_channels, config.data.grid_size, config.data.grid_size, config.data.grid_size) sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape, inverse_scaler, sampling_eps) assert os.path.exists(ckpt_path) print('ckpt path:', ckpt_path) try: state = restore_checkpoint(ckpt_path, state, device=config.device) except: raise ema.copy_to(score_model.parameters()) print(f"loaded model is trained till iter {state['step'] // config.training.num_grad_acc_steps}") idx = config.eval.idx bin_size = config.eval.bin_size config.eval.interp_batch_size = 32 print(f"idx to save: {idx * bin_size} to {idx * bin_size + bin_size - 1}") for k in range(bin_size): save_file_path = os.path.join(eval_dir, f"{idx * bin_size + k}") noise = sde.prior_sampling( (2, config.data.num_channels, config.data.grid_size, config.data.grid_size, 
config.data.grid_size) ).to(config.device) interp_sampling_shape = (config.eval.interp_batch_size, config.data.num_channels, config.data.grid_size, config.data.grid_size, config.data.grid_size) x0 = torch.zeros(interp_sampling_shape, device=config.device) x0[0] = noise[0] x0[-1] = noise[1] for i in range(1, config.eval.interp_batch_size - 1): x0[i] = slerp(x0[0], x0[-1], i / float(config.eval.interp_batch_size - 1)) if config.model.use_occ_grid: noise_occ = sde.prior_sampling( (2, 1, config.data.grid_size * 2, config.data.grid_size * 2, config.data.grid_size * 2) ).to(config.device) interp_sampling_shape = (config.eval.interp_batch_size, 1, config.data.grid_size * 2, config.data.grid_size * 2, config.data.grid_size * 2) x0_occ = torch.zeros(interp_sampling_shape, device=config.device) x0_occ[0] = noise_occ[0] x0_occ[-1] = noise_occ[1] for i in range(1, config.eval.interp_batch_size - 1): x0_occ[i] = slerp(x0_occ[0], x0_occ[-1], i / float(config.eval.interp_batch_size - 1)) else: x0_occ = None sample_list = [] sample_occ_list = [] for i in tqdm.trange(config.eval.interp_batch_size): samples, n = sampling_fn(score_model, x0=x0[i:i+1], x0_occ=x0_occ[i:i+1]) if type(samples) != tuple: # samples = samples.cpu() sample_list.append(samples.cpu()) else: # samples = samples.cpu() sample_list.append(samples[0].cpu()) sample_occ_list.append(samples[1].cpu()) # np.save(save_file_path, np.concatenate(sample_list, axis=0)) torch.save(torch.cat(sample_list, dim=0), save_file_path + '.pt') if config.model.use_occ_grid: torch.save(torch.cat(sample_occ_list, dim=0), save_file_path + '_occ.pt') def cond_gen( config, save_fname='0', ): """ Conditional Generation with partially completed dmtet from a 2.5D view (converted into a cubic grid) """ with torch.no_grad(): eval_dir, ckpt_path = config.eval.eval_dir, config.eval.ckpt_path # Create directory to eval_folder os.makedirs(eval_dir, exist_ok=True) scaler, inverse_scaler = lambda x: x, lambda x: x # Initialize model score_model = 
mutils.create_model(config) optimizer = losses.get_optimizer(config, score_model.parameters()) ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate) state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0) # Setup SDEs sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales) resolution = config.data.image_size grid_mask = torch.load(f'./data/grid_mask_{resolution}.pt').view(1, 1, resolution, resolution, resolution).to("cuda") grid_mask = grid_mask[:, :, :config.data.input_size, :config.data.input_size, :config.data.input_size] sampling_eps = 1e-3 sampling_shape = (config.eval.batch_size, config.data.num_channels, # resolution, resolution, resolution) config.data.input_size, config.data.input_size, config.data.input_size) sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape, inverse_scaler, sampling_eps, grid_mask=grid_mask) assert os.path.exists(ckpt_path) print('ckpt path:', ckpt_path) try: state = restore_checkpoint(ckpt_path, state, device=config.device) except: raise ema.copy_to(score_model.parameters()) print(f"loaded model is trained till iter {state['step'] // config.training.iter_size}") save_file_path = os.path.join(eval_dir, f"{save_fname}.npy") ### Conditional but free gradients; start from small t partial_dict = torch.load(config.eval.partial_dmtet_path) partial_sdf = partial_dict['sdf'] partial_mask = partial_dict['vis'] ### compute the mapping from tet indices to 3D cubic grid vertex indices tet_path = config.eval.tet_path tet = np.load(tet_path) vertices = torch.tensor(tet['vertices']) vertices_unique = vertices[:].unique() dx = vertices_unique[1] - vertices_unique[0] ind_to_coord = (torch.round( (vertices - vertices.min()) / dx) ).long() partial_sdf_grid = torch.zeros((1, 1, resolution, resolution, resolution)) partial_sdf_grid[0, 0, ind_to_coord[:, 0], ind_to_coord[:, 1], ind_to_coord[:, 2]] = partial_sdf partial_mask_grid = 
torch.zeros((1, 1, resolution, resolution, resolution)) partial_mask_grid[0, 0, ind_to_coord[:, 0], ind_to_coord[:, 1], ind_to_coord[:, 2]] = partial_mask.float() samples, n = sampling_fn( score_model, partial=partial_sdf_grid.cuda(), partial_mask=partial_mask_grid.cuda(), freeze_iters=config.eval.freeze_iters ) samples = samples.cpu().numpy() np.save(save_file_path, samples) ================================================ FILE: GMeshDiffusion/lib/diffusion/likelihood.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # pylint: skip-file # pytype: skip-file """Various sampling methods.""" import torch import numpy as np from scipy import integrate from .models import utils as mutils def get_div_fn(fn): """Create the divergence function of `fn` using the Hutchinson-Skilling trace estimator.""" def div_fn(x, t, eps): with torch.enable_grad(): x.requires_grad_(True) fn_eps = torch.sum(fn(x, t) * eps) grad_fn_eps = torch.autograd.grad(fn_eps, x)[0] x.requires_grad_(False) return torch.sum(grad_fn_eps * eps, dim=tuple(range(1, len(x.shape)))) return div_fn def get_likelihood_fn(sde, inverse_scaler, hutchinson_type='Rademacher', rtol=1e-5, atol=1e-5, method='RK45', eps=1e-5): """Create a function to compute the unbiased log-likelihood estimate of a given data point. Args: sde: A `sde_lib.SDE` object that represents the forward SDE. inverse_scaler: The inverse data normalizer. 
hutchinson_type: "Rademacher" or "Gaussian". The type of noise for Hutchinson-Skilling trace estimator. rtol: A `float` number. The relative tolerance level of the black-box ODE solver. atol: A `float` number. The absolute tolerance level of the black-box ODE solver. method: A `str`. The algorithm for the black-box ODE solver. See documentation for `scipy.integrate.solve_ivp`. eps: A `float` number. The probability flow ODE is integrated to `eps` for numerical stability. Returns: A function that a batch of data points and returns the log-likelihoods in bits/dim, the latent code, and the number of function evaluations cost by computation. """ def drift_fn(model, x, t): """The drift function of the reverse-time SDE.""" score_fn = mutils.get_score_fn(sde, model, train=False, continuous=True) # Probability flow ODE is a special case of Reverse SDE rsde = sde.reverse(score_fn, probability_flow=True) return rsde.sde(x, t)[0] def div_fn(model, x, t, noise): return get_div_fn(lambda xx, tt: drift_fn(model, xx, tt))(x, t, noise) def likelihood_fn(model, data): """Compute an unbiased estimate to the log-likelihood in bits/dim. Args: model: A score model. data: A PyTorch tensor. Returns: bpd: A PyTorch tensor of shape [batch size]. The log-likelihoods on `data` in bits/dim. z: A PyTorch tensor of the same shape as `data`. The latent representation of `data` under the probability flow ODE. nfe: An integer. The number of function evaluations used for running the black-box ODE solver. """ with torch.no_grad(): shape = data.shape if hutchinson_type == 'Gaussian': epsilon = torch.randn_like(data) elif hutchinson_type == 'Rademacher': epsilon = torch.randint_like(data, low=0, high=2).float() * 2 - 1. 
else: raise NotImplementedError(f"Hutchinson type {hutchinson_type} unknown.") def ode_func(t, x): sample = mutils.from_flattened_numpy(x[:-shape[0]], shape).to(data.device).type(torch.float32) vec_t = torch.ones(sample.shape[0], device=sample.device) * t drift = mutils.to_flattened_numpy(drift_fn(model, sample, vec_t)) logp_grad = mutils.to_flattened_numpy(div_fn(model, sample, vec_t, epsilon)) return np.concatenate([drift, logp_grad], axis=0) init = np.concatenate([mutils.to_flattened_numpy(data), np.zeros((shape[0],))], axis=0) solution = integrate.solve_ivp(ode_func, (eps, sde.T), init, rtol=rtol, atol=atol, method=method) nfe = solution.nfev zp = solution.y[:, -1] z = mutils.from_flattened_numpy(zp[:-shape[0]], shape).to(data.device).type(torch.float32) delta_logp = mutils.from_flattened_numpy(zp[-shape[0]:], (shape[0],)).to(data.device).type(torch.float32) prior_logp = sde.prior_logp(z) bpd = -(prior_logp + delta_logp) / np.log(2) N = np.prod(shape[1:]) bpd = bpd / N # A hack to convert log-likelihoods to bits/dim offset = 7. - inverse_scaler(-1.) bpd = bpd + offset return bpd, z, nfe return likelihood_fn ================================================ FILE: GMeshDiffusion/lib/diffusion/losses.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """All functions related to loss computation and optimization. 
""" import torch import torch.optim as optim import numpy as np from .models import utils as mutils def get_optimizer(config, params): """Returns a flax optimizer object based on `config`.""" if config.optim.optimizer == 'Adam': optimizer = optim.Adam(params, lr=config.optim.lr, betas=(config.optim.beta1, 0.999), eps=config.optim.eps, weight_decay=config.optim.weight_decay) elif config.optim.optimizer == 'AdamW': optimizer = optim.AdamW(params, lr=config.optim.lr, betas=(config.optim.beta1, 0.999), eps=config.optim.eps, weight_decay=config.optim.weight_decay) else: raise NotImplementedError( f'Optimizer {config.optim.optimizer} not supported yet!') return optimizer def optimization_manager(config): """Returns an optimize_fn based on `config`.""" def optimize_fn(optimizer, params, step, lr=config.optim.lr, warmup=config.optim.warmup, grad_clip=config.optim.grad_clip, gradscaler=None): """Optimizes with warmup and gradient clipping (disabled if negative).""" if warmup > 0: for g in optimizer.param_groups: g['lr'] = lr * np.minimum(step / warmup, 1.0) if grad_clip >= 0: gradscaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(params, max_norm=grad_clip) gradscaler.step(optimizer) gradscaler.update() # optimizer.step() return optimize_fn def get_ddpm_loss_fn(vpsde, train, loss_type='l2', pred_type='noise', use_vis_mask=False, use_occ=False, use_aux=False): """Legacy code to reproduce previous results on DDPM. 
Not recommended for new work.""" if use_occ: def loss_fn(model, batch, use_mesh_reg=False, verts_discretiezd=None, midpoints_discretiezd=None, edges=None): batch, batch_occ = batch['grid'], batch['occgrid'] model_fn = mutils.get_model_fn(model, train=train) labels = torch.randint(0, vpsde.N, (batch.shape[0],), device=batch.device) sqrt_alphas_cumprod = vpsde.sqrt_alphas_cumprod.to(batch.device) sqrt_1m_alphas_cumprod = vpsde.sqrt_1m_alphas_cumprod.to(batch.device) with torch.no_grad(): noise = torch.randn_like(batch, device=batch.device) noise_occ = torch.randn_like(batch_occ, device=batch.device) perturbed_data = sqrt_alphas_cumprod[labels, None, None, None, None] * batch + \ sqrt_1m_alphas_cumprod[labels, None, None, None, None] * noise perturbed_data = perturbed_data.type(batch.dtype) perturbed_data_occ = sqrt_alphas_cumprod[labels, None, None, None, None] * batch_occ + \ sqrt_1m_alphas_cumprod[labels, None, None, None, None] * noise_occ perturbed_data_occ = perturbed_data_occ.type(batch_occ.dtype) with torch.cuda.amp.autocast(dtype=torch.bfloat16): pred, pred_occ = model_fn((perturbed_data, perturbed_data_occ), labels) pred, pred_occ = pred.float(), pred_occ.float() alphas1 = sqrt_alphas_cumprod[labels, None, None, None, None] alphas2 = sqrt_1m_alphas_cumprod[labels, None, None, None, None] if pred_type == 'noise': score = pred score_occ = pred_occ x0 = (perturbed_data - score * alphas2) / alphas1 x0_occ = (perturbed_data_occ - score_occ * alphas2) / alphas1 elif pred_type == 'x0': x0 = pred x0_occ = pred_occ score = (perturbed_data - x0 * alphas1) / alphas2 score_occ = (perturbed_data_occ - pred_occ * alphas1) / alphas2 # noise = noise[:, :, :score.size(2), :score.size(3), :score.size(4)] ### to accommodate change of size due to arch if loss_type == 'l2': losses = torch.square(score - noise) losses_occ = torch.square(score_occ - noise_occ) assert losses_occ.size(1) == 1 elif loss_type == 'l1': raise NotImplementedError losses = torch.abs(score - noise) else: 
raise NotImplementedError mask = model.module.feature_mask occ_mask = model.module.occ_mask assert len(mask.size()) == 5 assert mask.size(1) == losses.size(1) assert occ_mask.size(1) == losses_occ.size(1) assert losses.size(0) == losses_occ.size(0) if mask is not None: losses = losses * mask losses_occ = losses_occ * occ_mask occ_loss_scale = 1.0 if not use_aux else 1.0 loss = (torch.sum(losses) + torch.sum(losses_occ)) / (mask.sum() + occ_mask.sum()) / losses.size(0) else: raise NotImplementedError if use_aux: pred_vis = model.module.extract_vis_from_cubicgrid(x0, x0_occ) with torch.no_grad(): gt_vis = model.module.extract_vis_from_cubicgrid(batch, batch_occ.view(*x0_occ.size())) reg_loss = ( (pred_vis - gt_vis).pow(2).view(x0.size(0), -1).mean(dim=-1) * sqrt_alphas_cumprod[labels] ).mean() else: reg_loss = torch.zeros_like(loss) total_loss = loss + reg_loss return total_loss, loss, reg_loss else: def loss_fn(model, batch, use_mesh_reg=False, verts_discretiezd=None, midpoints_discretiezd=None, edges=None): model_fn = mutils.get_model_fn(model, train=train) labels = torch.randint(0, vpsde.N, (batch.shape[0],), device=batch.device) sqrt_alphas_cumprod = vpsde.sqrt_alphas_cumprod.to(batch.device) sqrt_1m_alphas_cumprod = vpsde.sqrt_1m_alphas_cumprod.to(batch.device) noise = torch.randn_like(batch, device=batch.device) perturbed_data = sqrt_alphas_cumprod[labels, None, None, None, None] * batch + \ sqrt_1m_alphas_cumprod[labels, None, None, None, None] * noise perturbed_data = perturbed_data.type(batch.dtype) with torch.cuda.amp.autocast(dtype=torch.bfloat16): pred = model_fn(perturbed_data, labels) pred = pred.float() alphas1 = sqrt_alphas_cumprod[labels, None, None, None, None] alphas2 = sqrt_1m_alphas_cumprod[labels, None, None, None, None] if pred_type == 'noise': score = pred x0 = (perturbed_data - score * alphas2) / alphas1 elif pred_type == 'x0': x0 = pred score = (perturbed_data - x0 * alphas1) / alphas2 if use_vis_mask: assert x0.size(0) == 1 vis_mask = 
model.extract_vismask_from_cubicgrid(x0) # noise = noise[:, :, :score.size(2), :score.size(3), :score.size(4)] ### to accommodate change of size due to arch if loss_type == 'l2': losses = torch.square((score - noise) * vis_mask) elif loss_type == 'l1': losses = torch.abs((score - noise) * vis_mask) else: raise NotImplementedError else: if loss_type == 'l2': losses = torch.square(score - noise) elif loss_type == 'l1': losses = torch.abs(score - noise) else: raise NotImplementedError mask = model.module.feature_mask assert len(mask.size()) == 5 assert mask.size(1) == losses.size(1) if mask is not None: losses = losses * mask loss = torch.sum(losses) / mask.sum() / losses.size(0) else: raise NotImplementedError reg_loss = torch.zeros_like(loss) total_loss = loss return total_loss, loss, reg_loss return loss_fn def get_step_fn(sde, train, optimize_fn=None, loss_type='l2', pred_type='noise', use_vis_mask=False, use_occ=False, use_aux=False): """Create a one-step training/evaluation function. Args: sde: An `sde_lib.SDE` object that represents the forward SDE. optimize_fn: An optimization function. reduce_mean: If `True`, average the loss across data dimensions. Otherwise sum the loss across data dimensions. continuous: `True` indicates that the model is defined to take continuous time steps. likelihood_weighting: If `True`, weight the mixture of score matching losses according to https://arxiv.org/abs/2101.09258; otherwise use the weighting recommended by our paper. Returns: A one-step function for training or evaluation. """ loss_fn = get_ddpm_loss_fn(sde, train, loss_type=loss_type, pred_type=pred_type, use_vis_mask=use_vis_mask, use_occ=use_occ, use_aux=use_aux) def step_fn(state, batch, clear_grad=True, update_param=True, gradscaler=None): """Running one step of training or evaluation. This function will undergo `jax.lax.scan` so that multiple steps can be pmapped and jit-compiled together for faster execution. 
Args: state: A dictionary of training information, containing the score model, optimizer, EMA status, and number of optimization steps. batch: A mini-batch of training/evaluation data. Returns: loss: The average loss value of this state. """ model = state['model'] if train: optimizer = state['optimizer'] if clear_grad: optimizer.zero_grad() loss_total, loss_score, loss_reg = loss_fn(model, batch) gradscaler.scale(loss_total).backward() if update_param: optimize_fn(optimizer, model.parameters(), step=state['step'], gradscaler=gradscaler) state['step'] += 1 state['ema'].update(model.parameters()) else: with torch.no_grad(): ema = state['ema'] ema.store(model.parameters()) ema.copy_to(model.parameters()) loss_total, loss_score, loss_reg = loss_fn(model, batch) ema.restore(model.parameters()) return { 'loss_total': loss_total, 'loss_score': loss_score, 'loss_reg': loss_reg, } return step_fn ================================================ FILE: GMeshDiffusion/lib/diffusion/models/__init__.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
================================================
FILE: GMeshDiffusion/lib/diffusion/models/ema.py
================================================
# Modified from https://raw.githubusercontent.com/fadel/pytorch_ema/master/torch_ema/ema.py

from __future__ import division
from __future__ import unicode_literals

import torch


# Partially based on: https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py
class ExponentialMovingAverage:
    """
    Maintains (exponential) moving average of a set of parameters.
    """

    def __init__(self, parameters, decay, use_num_updates=True):
        """
        Args:
          parameters: Iterable of `torch.nn.Parameter`; usually the result of
            `model.parameters()`.
          decay: The exponential decay.
          use_num_updates: Whether to use number of updates when computing
            averages.
        """
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')
        self.decay = decay
        # num_updates is None when update counting is disabled; otherwise the
        # effective decay is warmed up as (1+n)/(10+n) until it reaches `decay`.
        self.num_updates = 0 if use_num_updates else None
        # Shadow copies track only trainable (requires_grad) parameters.
        self.shadow_params = [p.clone().detach()
                              for p in parameters if p.requires_grad]
        self.collected_params = []

    def update(self, parameters):
        """
        Update currently maintained parameters.

        Call this every time the parameters are updated, such as the result of
        the `optimizer.step()` call.

        Args:
          parameters: Iterable of `torch.nn.Parameter`; usually the same set of
            parameters used to initialize this object.
        """
        decay = self.decay
        if self.num_updates is not None:
            self.num_updates += 1
            # Warm-up schedule: use a smaller decay early in training.
            decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))
        one_minus_decay = 1.0 - decay
        with torch.no_grad():
            # Filter to requires_grad so the zip below aligns with
            # shadow_params, which was built with the same filter.
            parameters = [p for p in parameters if p.requires_grad]
            for s_param, param in zip(self.shadow_params, parameters):
                # print(s_param.device, s_param.device, param.device)
                # In-place: s = s - (1 - decay) * (s - p)  ==  decay*s + (1-decay)*p
                s_param.sub_(one_minus_decay * (s_param - param))

    def copy_to(self, parameters):
        """
        Copy current parameters into given collection of parameters.

        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            updated with the stored moving averages.
        """
        parameters = [p for p in parameters if p.requires_grad]
        for s_param, param in zip(self.shadow_params, parameters):
            if param.requires_grad:
                param.data.copy_(s_param.data)

    def store(self, parameters):
        """
        Save the current parameters for restoring later.

        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            temporarily stored.
        """
        # NOTE(review): unlike copy_to/update, store/restore do not filter on
        # requires_grad — they operate on whatever iterable the caller passes,
        # consistently with each other.
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.

        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.

        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)

    def state_dict(self):
        """Return a serializable dict of the EMA state (decay, step count, shadows)."""
        return dict(decay=self.decay, num_updates=self.num_updates,
                    shadow_params=self.shadow_params)

    def load_state_dict(self, state_dict, device='cuda'):
        """Load EMA state saved by `state_dict`, moving shadows to `device`."""
        self.decay = state_dict['decay']
        self.num_updates = state_dict['num_updates']
        self.shadow_params = state_dict['shadow_params']
        for k, _ in enumerate(self.shadow_params):
            self.shadow_params[k] = self.shadow_params[k].to(device)
        # for k in self.shadow_params:
        #     print(k.device)
        # raise



================================================
FILE: GMeshDiffusion/lib/diffusion/models/functional.py
================================================
#################################################################################################
# Copyright (c) 2023 Ali Hassani.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
#################################################################################################
import torch
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd

try:
    from natten import _C
except ImportError:
    # NOTE(review): the f-strings below contain no placeholders, and the
    # concatenation mixes `+` with implicit adjacency — harmless but odd.
    raise ImportError(
        f"Failed to import NATTEN's CPP backend. "
        + f"This could be due to an invalid/incomplete install. "
        + f"Please uninstall NATTEN (pip uninstall natten) and re-install with the"
        f" correct torch build: "
        + f"shi-labs.com/natten"
    )


# Thin wrappers over the compiled backend's capability / configuration flags.
def has_cuda():
    return _C.has_cuda()


def has_half():
    return _C.has_half()


def has_bfloat():
    return _C.has_bfloat()


def has_gemm():
    return _C.has_gemm()


def enable_tf32():
    return _C.set_gemm_tf32(True)


def disable_tf32():
    return _C.set_gemm_tf32(False)


def enable_tiled_na():
    return _C.set_tiled_na(True)


def disable_tiled_na():
    return _C.set_tiled_na(False)


def enable_gemm_na():
    return _C.set_gemm_na(True)


def disable_gemm_na():
    return _C.set_gemm_na(False)


class NeighborhoodAttention1DQKAutogradFunction(Function):
    """Autograd wrapper for the 1D neighborhood-attention QK^T kernel."""

    @staticmethod
    @custom_fwd
    def forward(ctx, query, key, rpb, kernel_size, dilation):
        query = query.contiguous()
        key = key.contiguous()
        attn = _C.na1d_qk_forward(query, key, rpb, kernel_size, dilation)
        ctx.save_for_backward(query, key)
        ctx.kernel_size = kernel_size
        ctx.dilation = dilation
        # Remember whether a relative positional bias was supplied so the
        # backward pass knows whether to produce its gradient.
        ctx.bias = rpb is not None
        return attn

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        outputs = _C.na1d_qk_backward(
            grad_out.contiguous(),
            ctx.saved_tensors[0],
            ctx.saved_tensors[1],
            ctx.bias,
            ctx.kernel_size,
            ctx.dilation,
        )
        d_query, d_key, d_rpb = outputs
        # No gradients for kernel_size / dilation (non-tensor args).
        return d_query, d_key, d_rpb, None, None


class NeighborhoodAttention1DAVAutogradFunction(Function):
    """Autograd wrapper for the 1D neighborhood-attention A*V kernel."""

    @staticmethod
    @custom_fwd
    def forward(ctx, attn, value, kernel_size, dilation):
        attn = attn.contiguous()
        value = value.contiguous()
        out = _C.na1d_av_forward(attn, value, kernel_size, dilation)
        ctx.save_for_backward(attn, value)
        ctx.kernel_size = kernel_size
        ctx.dilation = dilation
        return out

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        outputs = _C.na1d_av_backward(
            grad_out.contiguous(),
            ctx.saved_tensors[0],
            ctx.saved_tensors[1],
            ctx.kernel_size,
            ctx.dilation,
        )
        d_attn, d_value = outputs
        return d_attn, d_value, None, None


class NeighborhoodAttention2DQKAutogradFunction(Function):
    """Autograd wrapper for the 2D neighborhood-attention QK^T kernel."""

    @staticmethod
    @custom_fwd
    def forward(ctx, query, key, rpb, kernel_size, dilation):
        query = query.contiguous()
        key = key.contiguous()
        if rpb is not None:
            # NOTE(review): only the 2D variants cast rpb/attn dtypes; the
            # 1D/3D variants do not — presumably an AMP fix applied only
            # where it was needed.
            rpb = rpb.to(key.dtype)
        attn = _C.na2d_qk_forward(query, key, rpb, kernel_size, dilation)
        ctx.save_for_backward(query, key)
        ctx.kernel_size = kernel_size
        ctx.dilation = dilation
        ctx.bias = rpb is not None
        return attn

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        outputs = _C.na2d_qk_backward(
            grad_out.contiguous(),
            ctx.saved_tensors[0],
            ctx.saved_tensors[1],
            ctx.bias,
            ctx.kernel_size,
            ctx.dilation,
        )
        d_query, d_key, d_rpb = outputs
        return d_query, d_key, d_rpb, None, None


class NeighborhoodAttention2DAVAutogradFunction(Function):
    """Autograd wrapper for the 2D neighborhood-attention A*V kernel."""

    @staticmethod
    @custom_fwd
    def forward(ctx, attn, value, kernel_size, dilation):
        attn = attn.contiguous().to(value.dtype)
        value = value.contiguous()
        out = _C.na2d_av_forward(attn, value, kernel_size, dilation)
        ctx.save_for_backward(attn, value)
        ctx.kernel_size = kernel_size
        ctx.dilation = dilation
        return out

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        outputs = _C.na2d_av_backward(
            grad_out.contiguous(),
            ctx.saved_tensors[0],
            ctx.saved_tensors[1],
            ctx.kernel_size,
            ctx.dilation,
        )
        d_attn, d_value = outputs
        return d_attn, d_value, None, None


class NeighborhoodAttention3DQKAutogradFunction(Function):
    """Autograd wrapper for the 3D neighborhood-attention QK^T kernel.

    The depth axis has its own kernel_size_d / dilation_d.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, query, key, rpb, kernel_size_d, kernel_size, dilation_d, dilation):
        query = query.contiguous()
        key = key.contiguous()
        attn = _C.na3d_qk_forward(
            query, key, rpb, kernel_size, dilation, kernel_size_d, dilation_d
        )
        ctx.save_for_backward(query, key)
        ctx.kernel_size_d = kernel_size_d
        ctx.kernel_size = kernel_size
        ctx.dilation_d = dilation_d
        ctx.dilation = dilation
        ctx.bias = rpb is not None
        return attn

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        outputs = _C.na3d_qk_backward(
            grad_out.contiguous(),
            ctx.saved_tensors[0],
            ctx.saved_tensors[1],
            ctx.bias,
            ctx.kernel_size,
            ctx.dilation,
            ctx.kernel_size_d,
            ctx.dilation_d,
        )
        d_query, d_key, d_rpb = outputs
        return d_query, d_key, d_rpb, None, None, None, None


class NeighborhoodAttention3DAVAutogradFunction(Function):
    """Autograd wrapper for the 3D neighborhood-attention A*V kernel."""

    @staticmethod
    @custom_fwd
    def forward(ctx, attn, value, kernel_size_d, kernel_size, dilation_d, dilation):
        attn = attn.contiguous()
        value = value.contiguous()
        out = _C.na3d_av_forward(
            attn, value, kernel_size, dilation, kernel_size_d, dilation_d
        )
        ctx.save_for_backward(attn, value)
        ctx.kernel_size_d = kernel_size_d
        ctx.kernel_size = kernel_size
        ctx.dilation_d = dilation_d
        ctx.dilation = dilation
        return out

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        outputs = _C.na3d_av_backward(
            grad_out.contiguous(),
            ctx.saved_tensors[0],
            ctx.saved_tensors[1],
            ctx.kernel_size,
            ctx.dilation,
            ctx.kernel_size_d,
            ctx.dilation_d,
        )
        d_attn, d_value = outputs
        return d_attn, d_value, None, None, None, None


# Public functional API: qkrpb (with positional bias), qk (without), av.
def natten1dqkrpb(query, key, rpb, kernel_size, dilation):
    return NeighborhoodAttention1DQKAutogradFunction.apply(
        query, key, rpb, kernel_size, dilation
    )


def natten1dqk(query, key, kernel_size, dilation):
    return NeighborhoodAttention1DQKAutogradFunction.apply(
        query, key, None, kernel_size, dilation
    )


def natten1dav(attn, value, kernel_size, dilation):
    return NeighborhoodAttention1DAVAutogradFunction.apply(
        attn, value, kernel_size, dilation
    )


def natten2dqkrpb(query, key, rpb, kernel_size, dilation):
    return NeighborhoodAttention2DQKAutogradFunction.apply(
        query, key, rpb, kernel_size, dilation
    )


def natten2dqk(query, key, kernel_size, dilation):
    return NeighborhoodAttention2DQKAutogradFunction.apply(
        query, key, None, kernel_size, dilation
    )


def natten2dav(attn, value, kernel_size, dilation):
    return NeighborhoodAttention2DAVAutogradFunction.apply(
        attn, value, kernel_size, dilation
    )


def natten3dqkrpb(query, key, rpb, kernel_size_d, kernel_size, dilation_d, dilation):
    return NeighborhoodAttention3DQKAutogradFunction.apply(
        query, key, rpb, kernel_size_d, kernel_size, dilation_d, dilation
    )


def natten3dqk(query, key, kernel_size_d, kernel_size, dilation_d, dilation):
    return NeighborhoodAttention3DQKAutogradFunction.apply(
        query, key, None, kernel_size_d, kernel_size, dilation_d, dilation
    )


def natten3dav(attn, value, kernel_size_d, kernel_size, dilation_d, dilation):
    return NeighborhoodAttention3DAVAutogradFunction.apply(
        attn, value, kernel_size_d, kernel_size, dilation_d, dilation
    )



================================================
FILE: GMeshDiffusion/lib/diffusion/models/layers.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: skip-file
"""Common layers for defining score networks.
""" import math import string from functools import partial import torch.nn as nn import torch import torch.nn.functional as F import numpy as np from .normalization import ConditionalInstanceNorm3dPlus class GroupNormFloat32(nn.GroupNorm): def forward(self, input): with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=False): return F.group_norm( input.float(), self.num_groups, self.weight, self.bias, self.eps) def get_act_fn(act_name): """Get activation functions from the config file.""" if act_name.lower() == 'elu': return nn.ELU() elif act_name.lower() == 'relu': return nn.ReLU() elif act_name.lower() == 'lrelu': return nn.LeakyReLU(negative_slope=0.2) elif act_name.lower() == 'swish' or act_name.lower() == 'silu': return nn.SiLU() else: raise NotImplementedError('activation function does not exist!') def variance_scaling(scale, mode, distribution, in_axis=1, out_axis=0, dtype=torch.float32, device='cpu'): """Ported from JAX. """ def _compute_fans(shape, in_axis=1, out_axis=0): receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis] fan_in = shape[in_axis] * receptive_field_size fan_out = shape[out_axis] * receptive_field_size return fan_in, fan_out def init(shape, dtype=dtype, device=device): fan_in, fan_out = _compute_fans(shape, in_axis, out_axis) if mode == "fan_in": denominator = fan_in elif mode == "fan_out": denominator = fan_out elif mode == "fan_avg": denominator = (fan_in + fan_out) / 2 else: raise ValueError( "invalid mode for variance scaling initializer: {}".format(mode)) variance = scale / denominator if distribution == "normal": return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance) elif distribution == "uniform": return (torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.) 
* np.sqrt(3 * variance) else: raise ValueError("invalid distribution for variance scaling initializer") return init def default_init(scale=1.): """The same initialization used in DDPM.""" scale = 1e-10 if scale == 0 else scale return variance_scaling(scale, 'fan_avg', 'uniform') class Dense(nn.Module): """Linear layer with `default_init`.""" def __init__(self): super().__init__() def conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1., padding=0): """1x1 convolution with DDPM initialization.""" conv = nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias) conv.weight.data = default_init(init_scale)(conv.weight.data.shape) nn.init.zeros_(conv.bias) return conv def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1): """3x3 convolution with DDPM initialization.""" conv = nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias) conv.weight.data = default_init(init_scale)(conv.weight.data.shape) nn.init.zeros_(conv.bias) return conv def conv5x5(in_planes, out_planes, stride=2, bias=True, dilation=1, init_scale=1., padding=2): """3x3 convolution with DDPM initialization.""" conv = nn.Conv3d(in_planes, out_planes, kernel_size=5, stride=stride, padding=padding, dilation=dilation, bias=bias) conv.weight.data = default_init(init_scale)(conv.weight.data.shape) nn.init.zeros_(conv.bias) return conv def conv3x3_transposed(in_planes, out_planes, stride=2, bias=True, dilation=1, init_scale=1., padding=1): """3x3 convolution with DDPM initialization.""" conv = nn.ConvTranspose3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=bias) conv.weight.data = default_init(init_scale)(conv.weight.data.shape) nn.init.zeros_(conv.bias) return conv def conv5x5_transposed(in_planes, out_planes, stride=2, bias=True, dilation=1, init_scale=1., padding=2): """3x3 convolution with DDPM initialization.""" conv = 
nn.ConvTranspose3d(in_planes, out_planes, kernel_size=5, stride=stride, padding=padding, dilation=dilation, bias=bias) conv.weight.data = default_init(init_scale)(conv.weight.data.shape) nn.init.zeros_(conv.bias) return conv ########################################################################### # Functions below are ported over from the DDPM codebase: # https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py ########################################################################### def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000): with torch.no_grad(): assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32 half_dim = embedding_dim // 2 # magic number 10000 is from transformers emb = math.log(max_positions) / (half_dim - 1) # emb = math.log(2.) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb) # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :] # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :] emb = timesteps[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) if embedding_dim % 2 == 1: # zero pad emb = F.pad(emb, (0, 1), mode='constant') assert emb.shape == (timesteps.shape[0], embedding_dim) return emb class AttnBlock(nn.Module): """Channel-wise self-attention block.""" def __init__(self, channels, num_groups=32): super().__init__() self.GroupNorm_0 = GroupNormFloat32(num_groups=num_groups, num_channels=channels, eps=1e-6) self.NIN_0 = conv1x1(channels, channels) self.NIN_1 = conv1x1(channels, channels) self.NIN_2 = conv1x1(channels, channels) self.NIN_3 = conv1x1(channels, channels, init_scale=0.) 
def forward(self, x): B, C, D, H, W = x.shape h = self.GroupNorm_0(x) q = self.NIN_0(h) k = self.NIN_1(h) v = self.NIN_2(h) # q = q.view(B, C, -1).permute(0, 2, 1) # k = k.view(B, C, -1).permute(0, 2, 1) # v = v.view(B, C, -1).permute(0, 2, 1) # with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=False): # h = F.scaled_dot_product_attention(q.float(), k.float(), v.float()) # h = h.permute(0, 2, 1).view(B, C, D, H, W) with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=False): w = torch.einsum('bcdhw,bckij->bdhwkij', q.float(), k.float()) * (int(C) ** (-0.5)) w = torch.reshape(w, (B, D, H, W, D * H * W)) w = F.softmax(w, dim=-1) w = torch.reshape(w, (B, D, H, W, D, H, W)) h = torch.einsum('bdhwkij,bckij->bcdhw', w, v) h = self.NIN_3(h) return x + h class Upsample(nn.Module): def __init__(self, channels, with_conv=False): super().__init__() if with_conv: self.Conv_0 = conv3x3(channels, channels) self.with_conv = with_conv def forward(self, x, temb=None): B, C, D, H, W = x.shape h = F.interpolate(x.float(), (D * 2, H * 2, W * 2), mode='nearest') if self.with_conv: h = self.Conv_0(h) return h class Downsample(nn.Module): def __init__(self, channels, with_conv=False): super().__init__() if with_conv: self.Conv_0 = conv3x3(channels, channels, stride=2, padding=0) self.with_conv = with_conv def forward(self, x, temb=None): B, C, D, H, W = x.shape # Emulate 'SAME' padding if self.with_conv: x = F.pad(x, (0, 1, 0, 1, 0, 1)) x = self.Conv_0(x) else: x = F.avg_pool3d(x, kernel_size=2, stride=2, padding=0) assert x.shape == (B, C, D // 2, H // 2, W // 2) return x class ResBlock(nn.Module): """The ResNet Blocks used in DDPM.""" def __init__(self, act_fn, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False, dropout=0.1, num_groups=32): super().__init__() if out_ch is None: out_ch = in_ch self.GroupNorm_0 = GroupNormFloat32(num_groups=num_groups, num_channels=in_ch, eps=1e-6) self.act = act_fn self.Conv_0 = conv3x3(in_ch, out_ch) if temb_dim 
is not None: self.Dense_0 = nn.Linear(temb_dim, out_ch) self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape) nn.init.zeros_(self.Dense_0.bias) self.GroupNorm_1 = GroupNormFloat32(num_groups=num_groups, num_channels=out_ch, eps=1e-6) self.Dropout_0 = nn.Dropout(dropout) self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=0.) if in_ch != out_ch: if conv_shortcut: self.Conv_2 = conv3x3(in_ch, out_ch) else: self.NIN_0 = conv1x1(in_ch, out_ch) self.out_ch = out_ch self.in_ch = in_ch self.conv_shortcut = conv_shortcut def forward(self, x, temb=None): B, C, D, H, W = x.shape assert C == self.in_ch out_ch = self.out_ch if self.out_ch else self.in_ch h = self.act(self.GroupNorm_0(x)) h = self.Conv_0(h) # Add bias to each feature map conditioned on the time embedding if temb is not None: h += self.Dense_0(self.act(temb))[:, :, None, None, None] h = self.act(self.GroupNorm_1(h)) h = self.Dropout_0(h) h = self.Conv_1(h) if C != out_ch: if self.conv_shortcut: x = self.Conv_2(x) else: x = self.NIN_0(x) return x + h class AttnResBlock(ResBlock): """The ResNet Blocks used in DDPM.""" def __init__(self, act_fn, in_ch, out_ch, temb_dim=None, conv_shortcut=False, dropout=0.1, num_groups=32): super().__init__(act_fn, in_ch, out_ch, temb_dim, conv_shortcut, dropout, num_groups=num_groups) self.attn_block = AttnBlock(out_ch, num_groups=num_groups) def forward(self, x, temb=None): h = super().forward(x, temb) return self.attn_block(h) ================================================ FILE: GMeshDiffusion/lib/diffusion/models/normalization.py ================================================ # coding=utf-8 # Copyright 2020 The Google Research Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Normalization layers."""
import torch.nn as nn
import torch
import functools


def get_normalization(config, conditional=False):
    """Obtain normalization modules from the config file."""
    norm = config.model.normalization
    if conditional:
        if norm == 'InstanceNorm++':
            return functools.partial(ConditionalInstanceNorm3dPlus, num_classes=config.model.num_classes)
        else:
            raise NotImplementedError(f'{norm} not implemented yet.')
    else:
        if norm == 'InstanceNorm':
            return nn.InstanceNorm3d
        elif norm == 'InstanceNorm++':
            return InstanceNorm3dPlus
        elif norm == 'VarianceNorm':
            return VarianceNorm3d
        elif norm == 'GroupNorm':
            return nn.GroupNorm
        else:
            raise ValueError('Unknown normalization: %s' % norm)


class ConditionalBatchNorm3d(nn.Module):
    """BatchNorm3d whose affine parameters are looked up per class label."""

    def __init__(self, num_features, num_classes, bias=True):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        self.bn = nn.BatchNorm3d(num_features, affine=False)
        if self.bias:
            self.embed = nn.Embedding(num_classes, num_features * 2)
            # Initialise scale with uniform_() — NOTE(review): the upstream
            # comment claimed N(1, 0.02), which does not match the code.
            self.embed.weight.data[:, :num_features].uniform_()
            self.embed.weight.data[:, num_features:].zero_()  # Initialise bias at 0
        else:
            self.embed = nn.Embedding(num_classes, num_features)
            self.embed.weight.data.uniform_()

    def forward(self, x, y):
        out = self.bn(x)
        if self.bias:
            gamma, beta = self.embed(y).chunk(2, dim=1)
            out = gamma.view(-1, self.num_features, 1, 1, 1) * out + beta.view(-1, self.num_features, 1, 1, 1)
        else:
            gamma = self.embed(y)
            out = gamma.view(-1, self.num_features, 1, 1, 1) * out
        return out


class ConditionalInstanceNorm3d(nn.Module):
    """InstanceNorm3d with class-conditional scale (and optional bias)."""

    def __init__(self, num_features, num_classes, bias=True):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        self.instance_norm = nn.InstanceNorm3d(num_features, affine=False, track_running_stats=False)
        if bias:
            self.embed = nn.Embedding(num_classes, num_features * 2)
            # NOTE(review): scale init is uniform_() despite the upstream
            # comment claiming N(1, 0.02).
            self.embed.weight.data[:, :num_features].uniform_()
            self.embed.weight.data[:, num_features:].zero_()  # Initialise bias at 0
        else:
            self.embed = nn.Embedding(num_classes, num_features)
            self.embed.weight.data.uniform_()

    def forward(self, x, y):
        h = self.instance_norm(x)
        if self.bias:
            gamma, beta = self.embed(y).chunk(2, dim=-1)
            out = gamma.view(-1, self.num_features, 1, 1, 1) * h + beta.view(-1, self.num_features, 1, 1, 1)
        else:
            gamma = self.embed(y)
            out = gamma.view(-1, self.num_features, 1, 1, 1) * h
        return out


class ConditionalVarianceNorm3d(nn.Module):
    """Variance-only normalization with class-conditional scale."""

    def __init__(self, num_features, num_classes, bias=False):
        super().__init__()
        self.num_features = num_features
        self.bias = bias  # NOTE(review): stored but never used in forward.
        self.embed = nn.Embedding(num_classes, num_features)
        self.embed.weight.data.normal_(1, 0.02)

    def forward(self, x, y):
        # Normalize by per-sample spatial variance only (no mean removal).
        vars = torch.var(x, dim=(2, 3, 4), keepdim=True)
        h = x / torch.sqrt(vars + 1e-5)
        gamma = self.embed(y)
        out = gamma.view(-1, self.num_features, 1, 1, 1) * h
        return out


class VarianceNorm3d(nn.Module):
    """Unconditional variance-only normalization with a learned scale."""

    def __init__(self, num_features, bias=False):
        super().__init__()
        self.num_features = num_features
        self.bias = bias  # NOTE(review): stored but never used in forward.
        self.alpha = nn.Parameter(torch.zeros(num_features))
        self.alpha.data.normal_(1, 0.02)

    def forward(self, x):
        vars = torch.var(x, dim=(2, 3, 4), keepdim=True)
        h = x / torch.sqrt(vars + 1e-5)
        out = self.alpha.view(-1, self.num_features, 1, 1, 1) * h
        return out


class ConditionalNoneNorm3d(nn.Module):
    """No normalization; only a class-conditional affine transform."""

    def __init__(self, num_features, num_classes, bias=True):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        if bias:
            self.embed = nn.Embedding(num_classes, num_features * 2)
            # NOTE(review): scale init is uniform_() despite the upstream
            # comment claiming N(1, 0.02).
            self.embed.weight.data[:, :num_features].uniform_()
            self.embed.weight.data[:, num_features:].zero_()  # Initialise bias at 0
        else:
            self.embed = nn.Embedding(num_classes, num_features)
            self.embed.weight.data.uniform_()

    def forward(self, x, y):
        if self.bias:
            gamma, beta = self.embed(y).chunk(2, dim=-1)
            out = gamma.view(-1, self.num_features, 1, 1, 1) * x + beta.view(-1, self.num_features, 1, 1, 1)
        else:
            gamma = self.embed(y)
            out = gamma.view(-1, self.num_features, 1, 1, 1) * x
        return out


class NoneNorm3d(nn.Module):
    """Identity; arguments are accepted only for interface compatibility."""

    def __init__(self, num_features, bias=True):
        super().__init__()

    def forward(self, x):
        return x


class InstanceNorm3dPlus(nn.Module):
    """InstanceNorm++ (NCSN): instance norm plus a mean-shift correction term."""

    def __init__(self, num_features, bias=True):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        self.instance_norm = nn.InstanceNorm3d(num_features, affine=False, track_running_stats=False)
        self.alpha = nn.Parameter(torch.zeros(num_features))
        self.gamma = nn.Parameter(torch.zeros(num_features))
        self.alpha.data.normal_(1, 0.02)
        self.gamma.data.normal_(1, 0.02)
        if bias:
            self.beta = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        # Re-inject the (standardized) per-channel means that instance norm
        # removes, scaled by alpha — the "++" correction.
        means = torch.mean(x, dim=(2, 3, 4))
        m = torch.mean(means, dim=-1, keepdim=True)
        v = torch.var(means, dim=-1, keepdim=True)
        means = (means - m) / (torch.sqrt(v + 1e-5))
        h = self.instance_norm(x)
        if self.bias:
            h = h + means[..., None, None, None] * self.alpha[..., None, None, None]
            out = self.gamma.view(-1, self.num_features, 1, 1, 1) * h + self.beta.view(-1, self.num_features, 1, 1, 1)
        else:
            h = h + means[..., None, None, None] * self.alpha[..., None, None, None]
            out = self.gamma.view(-1, self.num_features, 1, 1, 1) * h
        return out


class ConditionalInstanceNorm3dPlus(nn.Module):
    """Class-conditional InstanceNorm++ (gamma/alpha/beta from an embedding)."""

    def __init__(self, num_features, num_classes, bias=True):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        self.instance_norm = nn.InstanceNorm3d(num_features, affine=False, track_running_stats=False)
        if bias:
            self.embed = nn.Embedding(num_classes, num_features * 3)
            self.embed.weight.data[:, :2 * num_features].normal_(1, 0.02)  # Initialise scale at N(1, 0.02)
            self.embed.weight.data[:, 2 * num_features:].zero_()  # Initialise bias at 0
        else:
            self.embed = nn.Embedding(num_classes, 2 * num_features)
            self.embed.weight.data.normal_(1, 0.02)

    def forward(self, x, y):
        means = torch.mean(x, dim=(2, 3, 4))
        m = torch.mean(means, dim=-1, keepdim=True)
        v = torch.var(means, dim=-1, keepdim=True)
        means = (means - m) / (torch.sqrt(v + 1e-5))
        h = self.instance_norm(x)
        if self.bias:
            gamma, alpha, beta = self.embed(y).chunk(3, dim=-1)
            h = h + means[..., None, None, None] * alpha[..., None, None, None]
            out = gamma.view(-1, self.num_features, 1, 1, 1) * h + beta.view(-1, self.num_features, 1, 1, 1)
        else:
            gamma, alpha = self.embed(y).chunk(2, dim=-1)
            h = h + means[..., None, None, None] * alpha[..., None, None, None]
            out = gamma.view(-1, self.num_features, 1, 1, 1) * h
        return out



================================================
FILE: GMeshDiffusion/lib/diffusion/models/unet3d_occgrid.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: skip-file
"""DDPM model.

This code is the pytorch equivalent of:
https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/models/unet.py
"""
import torch
import torch.nn as nn
import functools
import numpy as np

from .
import torch
import torch.nn as nn
import functools
import numpy as np

from . import utils
from .layers import ResBlock, AttnResBlock, Upsample, Downsample, conv1x1, conv3x3, conv5x5, get_act_fn, default_init, get_timestep_embedding, GroupNormFloat32

import sys


def str_to_class(classname):
    """Resolve a class defined in this module from its string name.

    Turns config strings such as "ResBlock"/"AttnResBlock" into the
    corresponding layer classes.
    """
    return getattr(sys.modules[__name__], classname)


@utils.register_model(name='unet3d_occgrid')
class UNet3D(nn.Module):
    """3D U-Net denoiser over a feature grid plus an occupancy grid.

    forward() takes a pair (feature grid, occupancy grid) and returns a pair of
    predictions with the same shapes. Fixed binary masks loaded from disk zero
    out the unused cells of both grids on input and output.

    NOTE(review): mask tensors are hard-coded to 128^3 (features) and 256^3
    (occupancy) resolutions, and several lookup tables are moved to CUDA at
    construction time — the module requires a GPU to build.
    """

    def __init__(self, config):
        super().__init__()
        self.act_fn = get_act_fn(config.model.act_fn)
        self.nf = nf = config.model.base_channels  # base channel width
        data_ch = config.data.num_channels
        ch_mult = config.model.ch_mult  # per-level channel multipliers
        # Binary masks marking the used entries of the feature/occupancy grids.
        feature_mask = torch.load(config.model.feature_mask_path, map_location='cpu').view(1, data_ch, 128, 128, 128)
        pixcat_mask = torch.load(config.model.pixcat_mask_path, map_location='cpu').view(1, 1, 128, 128, 128)
        occ_mask_path = config.data.occ_mask_path
        occ_mask = torch.load(occ_mask_path, map_location='cpu').view(1, 1, 256, 256, 256)
        # Tetrahedral-grid lookup tables (edge/node locations within the grids).
        tet_info = torch.load(config.data.tet_info_path)
        self.tet_edge_vpos = tet_info['tet_edge_vpos'].cuda()
        self.tet_edge_pix_loc = tet_info['tet_edge_pix_loc'].cuda().view(-1, 2, 3)
        # NOTE(review): this second .view(-1, 2, 3) is redundant (same shape as above).
        self.tet_edge_pix_loc = self.tet_edge_pix_loc.view(-1, 2, 3)
        # self.tet_center_loc = tet_info['tet_center_loc'].cuda()
        self.vis_edges = tet_info['vis_edges'].cuda()
        self.occ_edge_cano_order = tet_info['occ_edge_cano_order'].cuda()
        # Midpoint of each edge's two pixel locations, truncated to integer coords.
        self.tet_edgenode_loc = self.tet_edge_pix_loc.float().mean(dim=1).long()
        self.occ_edge_loc = self.tet_edgenode_loc.view(-1, 6, 3)[:, self.vis_edges.view(-1)].view(-1, 2, 3)
        # Occupancy nodes live on the twice-finer grid, hence the * 2.0 before truncation.
        self.occ_node_loc = (self.occ_edge_loc.view(-1, 12, 2, 3).float().mean(dim=-2) * 2.0).long().view(-1, 3)
        print(self.tet_edgenode_loc.size(), self.vis_edges.size(), self.occ_edge_loc.size(), self.occ_node_loc.size())  # debug
        self.tet_edge_pix_loc = self.tet_edge_pix_loc.view(-1, 3)
        # Non-trainable Parameters so the masks follow .to()/.cuda() and are
        # saved in the state dict.
        self.feature_mask = torch.nn.Parameter(feature_mask, requires_grad=False)
        self.pixcat_mask = torch.nn.Parameter(pixcat_mask, requires_grad=False)
        self.occ_mask = torch.nn.Parameter(occ_mask, requires_grad=False)
        self.down_block_types = config.model.down_block_types
        self.up_block_types = config.model.up_block_types
        self.num_res_blocks = config.model.num_res_blocks
        self.num_res_blocks_1st_layer = config.model.num_res_blocks_1st_layer
        resamp_with_conv = config.model.resamp_with_conv
        dropout = config.model.dropout
        assert len(self.down_block_types) == len(self.up_block_types)
        # Residual-block factories with activation/temb/dropout pre-bound.
        module_dict = {
            module: functools.partial(str_to_class(module), act_fn=self.act_fn, temb_dim=4 * nf, dropout=dropout)
            for module in ["ResBlock", "AttnResBlock"]
        }
        # Condition on noise levels: timestep-embedding MLP (nf -> 4*nf).
        noise_temb_layers = [nn.Linear(nf, nf * 4), nn.SiLU(), nn.Linear(nf * 4, nf * 4)]
        noise_temb_layers[0].weight.data = default_init()(noise_temb_layers[0].weight.data.shape)
        nn.init.zeros_(noise_temb_layers[0].bias)
        noise_temb_layers[2].weight.data = default_init()(noise_temb_layers[2].weight.data.shape)
        nn.init.zeros_(noise_temb_layers[2].bias)
        self.noise_temb_layers = nn.Sequential(*noise_temb_layers)
        # Strided convs bring the finer occupancy inputs down to the feature resolution.
        self.occ_conv = conv3x3(1, nf, stride=2, padding=1)
        self.occ_mask_conv = conv3x3(1, nf, stride=2, padding=1)
        # Downsampling block
        self.mask_layer = conv5x5(1, nf, stride=1, padding=2)
        self.input_layer = conv5x5(data_ch, nf, stride=1, padding=2)
        hs_c = [nf]  # channel counts of skip connections; mirrored by `hs` in forward()
        in_ch = nf
        modules = []
        for i_level, down_block_type in enumerate(self.down_block_types):
            curr_block = module_dict[down_block_type]
            # Residual blocks for this resolution
            num_res_blocks = self.num_res_blocks_1st_layer if i_level == 0 else self.num_res_blocks
            for i_block in range(num_res_blocks):
                out_ch = nf * ch_mult[i_level]
                modules.append(curr_block(in_ch=in_ch, out_ch=out_ch))
                in_ch = out_ch
                hs_c.append(in_ch)
            if i_level != len(self.down_block_types) - 1:
                modules.append(Downsample(channels=in_ch, with_conv=resamp_with_conv))
                hs_c.append(in_ch)
        # Bottleneck: one attention residual block followed by a plain one.
        in_ch = hs_c[-1]
        modules.append(module_dict["AttnResBlock"](in_ch=in_ch, out_ch=in_ch))
        modules.append(module_dict["ResBlock"](in_ch=in_ch))
        # Upsampling block (consumes hs_c in reverse as skip-connection channels)
        for i_level, up_block_type in enumerate(self.up_block_types):
            curr_block = module_dict[up_block_type]
            num_res_blocks = self.num_res_blocks_1st_layer if i_level == len(self.up_block_types) - 1 else self.num_res_blocks
            for i_block in range(num_res_blocks + 1):
                out_ch = nf * ch_mult[len(self.up_block_types) - i_level - 1]
                modules.append(curr_block(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
                in_ch = out_ch
            if i_level != len(self.up_block_types) - 1:
                modules.append(Upsample(channels=in_ch, with_conv=resamp_with_conv))
        self.all_modules = nn.ModuleList(modules)
        self.output_norm_layer = nn.Sequential(
            GroupNormFloat32(num_channels=in_ch, num_groups=32, eps=1e-6),
            nn.SiLU(),
        )
        self.output_layer = conv5x5(in_ch, data_ch, init_scale=0., stride=1, padding=2)
        # Transposed conv maps features back up to the occupancy-grid resolution.
        self.occ_output_layer = nn.ConvTranspose3d(in_ch, 1, 4, stride=2, padding=1)

    def sequentially_call_module(self, idx, x, temb=None):
        """Apply all_modules[idx] and return (idx + 1, output) to walk the list in order."""
        return idx + 1, self.all_modules[idx](x, temb)

    def forward(self, x, labels):
        """Run one denoising pass.

        Args:
          x: pair (feature grid, occupancy grid) of 5-D tensors.
          labels: diffusion timesteps, one per batch element.

        Returns:
          (grid, grid_occ): masked predictions for the two grids.
        """
        # NOTE(review): `modules` is never used below; all access goes through
        # sequentially_call_module.
        modules = self.all_modules
        with torch.no_grad():
            x, occ_grid = x[0], x[1]
            # NOTE(review): `True or self.centered` always takes the first branch,
            # and `self.centered` is never assigned in this class — the [0, 1]
            # rescaling branch is dead code.
            if True or self.centered:
                # Input is in [-1, 1]
                x = x
            else:
                # Input is in [0, 1]
                x = 2 * x - 1.
            # Mask out unused values
            x = x * self.feature_mask
            occ_grid = occ_grid * self.occ_mask
        # autocast is constructed with enabled=False — currently a no-op toggle point.
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=False):
            # timestep/scale embedding
            timesteps = labels
            temb = get_timestep_embedding(timesteps.float(), self.nf)
            temb = self.noise_temb_layers(temb)
            # Stem: features + mask embedding + downsampled occupancy (+ its mask).
            hs = [self.input_layer(x) + self.mask_layer(self.pixcat_mask) + self.occ_conv(occ_grid) + self.occ_mask_conv(self.occ_mask)]
            m_idx = 0
            # Downsampling path; `hs` collects skip activations.
            for i_level in range(len(self.down_block_types)):
                num_res_blocks = self.num_res_blocks_1st_layer if i_level == 0 else self.num_res_blocks
                for i_block in range(num_res_blocks):
                    m_idx, h = self.sequentially_call_module(m_idx, hs[-1], temb)
                    hs.append(h)
                if i_level != len(self.down_block_types) - 1:
                    m_idx, h = self.sequentially_call_module(m_idx, hs[-1])
                    hs.append(h)
            # Bottleneck.
            h = hs[-1]
            m_idx, h = self.sequentially_call_module(m_idx, h, temb)
            m_idx, h = self.sequentially_call_module(m_idx, h, temb)
            # Upsampling path; consumes the skip stack.
            for i_level in range(len(self.up_block_types)):
                num_res_blocks = self.num_res_blocks_1st_layer if i_level == len(self.up_block_types) - 1 else self.num_res_blocks
                for i_block in range(num_res_blocks + 1):
                    hspop = hs.pop()
                    h = torch.cat([h, hspop], dim=1)
                    m_idx, h = self.sequentially_call_module(m_idx, h, temb)
                if i_level != len(self.up_block_types) - 1:
                    m_idx, h = self.sequentially_call_module(m_idx, h, temb)
            assert not hs  # every skip connection must have been consumed
            h = self.output_norm_layer(h)
            grid = self.output_layer(h)
            grid_occ = self.occ_output_layer(h)
            # Mask out unused values
            grid = grid * self.feature_mask
            grid_occ = grid_occ * self.occ_mask
        return grid, grid_occ
_MODELS = {}


def register_model(cls=None, *, name=None):
    """A decorator for registering model classes.

    Usable bare (`@register_model`) or with a name (`@register_model(name=...)`).
    Raises ValueError on duplicate registration.
    """
    def _register(cls):
        if name is None:
            local_name = cls.__name__
        else:
            local_name = name
        if local_name in _MODELS:
            raise ValueError(f'Already registered model with name: {local_name}')
        _MODELS[local_name] = cls
        return cls

    if cls is None:
        return _register
    else:
        return _register(cls)


def get_model(name):
    """Look up a registered model class by name."""
    return _MODELS[name]


def get_sigmas(config):
    """Get sigmas --- the set of noise levels for SMLD from config files.

    Args:
      config: A ConfigDict object parsed from the config file

    Returns:
      sigmas: a numpy array of noise levels, geometrically spaced from
        sigma_max down to sigma_min.
    """
    sigmas = np.exp(
        np.linspace(np.log(config.model.sigma_max), np.log(config.model.sigma_min), config.model.num_scales))
    return sigmas


def get_ddpm_params(config):
    """Get betas and alphas --- parameters used in the original DDPM paper."""
    num_diffusion_timesteps = 1000
    # parameters need to be adapted if number of time steps differs from 1000
    beta_start = config.model.beta_min / config.model.num_scales
    beta_end = config.model.beta_max / config.model.num_scales
    betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
    alphas = 1. - betas
    alphas_cumprod = np.cumprod(alphas, axis=0)
    sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
    sqrt_1m_alphas_cumprod = np.sqrt(1. - alphas_cumprod)
    return {
        'betas': betas,
        'alphas': alphas,
        'alphas_cumprod': alphas_cumprod,
        'sqrt_alphas_cumprod': sqrt_alphas_cumprod,
        'sqrt_1m_alphas_cumprod': sqrt_1m_alphas_cumprod,
        'beta_min': beta_start * (num_diffusion_timesteps - 1),
        'beta_max': beta_end * (num_diffusion_timesteps - 1),
        'num_diffusion_timesteps': num_diffusion_timesteps
    }


def create_model(config, use_parallel=True, ddp=False, rank=None):
    """Create the score model, optionally wrapped for DDP or DataParallel.

    Args:
      config: config with `model.name` and `device`.
      use_parallel: wrap the model for multi-GPU execution.
      ddp: use DistributedDataParallel (requires `rank`) instead of DataParallel.
      rank: device rank for DDP.
    """
    model_name = config.model.name
    score_model = get_model(model_name)(config)
    if use_parallel:
        if ddp:
            score_model = score_model.to(rank)
            score_model = torch.nn.parallel.DistributedDataParallel(
                score_model,
                find_unused_parameters=False,
                gradient_as_bucket_view=True,
                device_ids=[rank])
        else:
            score_model = torch.nn.DataParallel(score_model).to(config.device)
    else:
        score_model = score_model.to(config.device)
    return score_model


def get_model_fn(model, train=False):
    """Create a function to give the output of the score-based model.

    Args:
      model: The score model.
      train: `True` for training and `False` for evaluation.

    Returns:
      A model function that sets train/eval mode before each call.
    """
    def model_fn(x, labels):
        """Compute the output of the score-based model.

        Args:
          x: A mini-batch of input data.
          labels: A mini-batch of conditioning variables for time steps.

        Returns:
          The model output.
        """
        if not train:
            model.eval()
            return model(x, labels)
        else:
            model.train()
            return model(x, labels)

    return model_fn


def get_reg_fn(model, train=False):
    """Create a function returning the model's regularizer output.

    Models without a working `get_reg` contribute zero regularization.

    Args:
      model: The score model.
      train: `True` for training and `False` for evaluation.

    Returns:
      A function mapping x to the regularizer value (or zeros).
    """
    def model_fn(x):
        if not train:
            model.eval()
        else:
            model.train()
        try:
            return model.get_reg(x)
        except Exception:
            # BUGFIX: was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit. Best-effort fallback preserved.
            return torch.zeros_like(x)

    return model_fn


def get_score_fn(sde, model, train=False, continuous=False, std_scale=True, pred_type='noise'):
    """Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.

    Args:
      sde: An `sde_lib.SDE` object that represents the forward SDE.
      model: A score model.
      train: `True` for training and `False` for evaluation.
      continuous: must be False; only discrete-time models are supported here.
      std_scale: whether to scale the score by -1/std. Set False for DDIM sampling.
      pred_type: 'noise' if the model predicts noise, 'x0' if it predicts x_0.

    Returns:
      A score function.
    """
    model_fn = get_model_fn(model, train=train)
    assert not continuous
    # NOTE: `sde_lib` is imported at module level (`from .. import sde_lib`).
    if isinstance(sde, sde_lib.VPSDE):
        if not std_scale:
            def score_fn(x, t):
                # For VP-trained models, t=0 corresponds to the lowest noise level.
                labels = t * (sde.N - 1)
                pred = model_fn(x, labels)
                if pred_type == 'x0':
                    labels = labels.long()
                    alphas1 = sde.sqrt_alphas_cumprod[labels, None, None, None, None].cuda()
                    alphas2 = sde.sqrt_1m_alphas_cumprod[labels, None, None, None, None].cuda()
                    # Invert x_t = a1 * x0 + a2 * eps to recover the noise estimate.
                    score = (x - pred * alphas1) / alphas2
                elif pred_type == 'noise':
                    score = pred
                else:
                    # BUGFIX: previously `score` was unbound for unknown pred_type.
                    raise ValueError(f"Unknown pred_type: {pred_type}")
                return score
        else:
            def score_fn(x, t):
                # For VP-trained models, t=0 corresponds to the lowest noise level
                labels = t * (sde.N - 1)
                pred = model_fn(x, labels)
                if pred_type == 'x0':
                    labels = labels.long()
                    alphas1 = sde.sqrt_alphas_cumprod[labels, None, None, None, None].cuda()
                    alphas2 = sde.sqrt_1m_alphas_cumprod[labels, None, None, None, None].cuda()
                    score = (x - pred * alphas1) / alphas2
                elif pred_type == 'noise':
                    score = pred
                else:
                    raise ValueError(f"Unknown pred_type: {pred_type}")
                # Convert noise prediction to a score: score = -eps / std.
                std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[labels.long()]
                score = -score / std[:, None, None, None, None]
                return score
    else:
        raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")

    return score_fn


def to_flattened_numpy(x):
    """Flatten a torch tensor `x` and convert it to numpy."""
    return x.detach().cpu().numpy().reshape((-1,))


def from_flattened_numpy(x, shape):
    """Form a torch tensor with the given `shape` from a flattened numpy array `x`."""
    return torch.from_numpy(x.reshape(shape))
_CORRECTORS = {}
_PREDICTORS = {}


def register_predictor(cls=None, *, name=None):
    """Decorator registering a predictor class under `name` (or its class name)."""
    def _do_register(target):
        key = name if name is not None else target.__name__
        if key in _PREDICTORS:
            raise ValueError(f'Already registered model with name: {key}')
        _PREDICTORS[key] = target
        return target

    return _do_register if cls is None else _do_register(cls)


def register_corrector(cls=None, *, name=None):
    """Decorator registering a corrector class under `name` (or its class name)."""
    def _do_register(target):
        key = name if name is not None else target.__name__
        if key in _CORRECTORS:
            raise ValueError(f'Already registered model with name: {key}')
        _CORRECTORS[key] = target
        return target

    return _do_register if cls is None else _do_register(cls)


def get_predictor(name):
    """Look up a registered predictor class by name."""
    return _PREDICTORS[name]


def get_corrector(name):
    """Look up a registered corrector class by name."""
    return _CORRECTORS[name]
class Predictor(abc.ABC):
    """Abstract base class for predictor algorithms.

    A predictor advances the sample one step along the reverse-time SDE/ODE.
    Concrete subclasses implement `update_fn`.
    """

    def __init__(self, sde, score_fn, probability_flow=False):
        super().__init__()
        self.sde = sde
        self.score_fn = score_fn
        # Reverse-time SDE/ODE derived from the forward SDE and the score model.
        self.rsde = sde.reverse(score_fn, probability_flow)

    @abc.abstractmethod
    def update_fn(self, x, t):
        """One update of the predictor.

        Args:
          x: A PyTorch tensor representing the current state.
          t: A PyTorch tensor representing the current time step.

        Returns:
          x: A PyTorch tensor of the next state.
          x_mean: A PyTorch tensor. The next state without random noise.
            Useful for denoising.
        """
        pass
@register_predictor(name='euler_maruyama')
class EulerMaruyamaPredictor(Predictor):
    """Euler-Maruyama discretization of the reverse-time SDE."""

    def __init__(self, sde, score_fn, probability_flow=False):
        super().__init__(sde, score_fn, probability_flow)

    def update_fn(self, x, t):
        step = -1. / self.rsde.N
        noise = torch.randn_like(x)
        drift, diffusion = self.rsde.sde(x, t)
        # Deterministic (mean) part of the update ...
        x_mean = x + drift * step
        # ... plus the stochastic part, scaled by sqrt(|dt|).
        x = x_mean + diffusion[:, None, None, None, None] * np.sqrt(-step) * noise
        return x, x_mean


@register_predictor(name='reverse_diffusion')
class ReverseDiffusionPredictor(Predictor):
    """One step of the discretized reverse diffusion process."""

    def __init__(self, sde, score_fn, probability_flow=False):
        super().__init__(sde, score_fn, probability_flow)

    def update_fn(self, x, t):
        f, G = self.rsde.discretize(x, t)
        noise = torch.randn_like(x)
        x_mean = x - f
        x = x_mean + G[:, None, None, None, None] * noise
        return x, x_mean
@register_predictor(name='none')
class NonePredictor(Predictor):
    """An empty predictor that does nothing."""

    def __init__(self, sde, score_fn, probability_flow=False):
        # Intentionally skips Predictor.__init__ (no reverse SDE is needed).
        pass

    def update_fn(self, x, t):
        return x, x


@register_predictor(name='ddim')
class DDIMPredictor(Predictor):
    """DDIM predictor: delegates to the reverse SDE's DDIM discretization,
    which returns the next state together with the predicted x_0."""

    def __init__(self, sde, score_fn, probability_flow=False):
        super().__init__(sde, score_fn, probability_flow)

    def update_fn(self, x, t, tprev=None):
        # NOTE(review): signature adds `tprev`, diverging from the base class's
        # two-argument `update_fn` contract.
        x, x0_pred = self.rsde.discretize_ddim(x, t, tprev=tprev)
        return x, x0_pred


@register_corrector(name='langevin')
class LangevinCorrector(Corrector):
    """Langevin MCMC corrector with a signal-to-noise-controlled step size."""

    def __init__(self, sde, score_fn, snr, n_steps):
        super().__init__(sde, score_fn, snr, n_steps)
        if not isinstance(sde, sde_lib.VPSDE):
            raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")

    def update_fn(self, x, t):
        sde = self.sde
        score_fn = self.score_fn
        n_steps = self.n_steps
        target_snr = self.snr
        if isinstance(sde, sde_lib.VPSDE):
            # Map continuous t in [0, T] to a discrete timestep index.
            timestep = (t * (sde.N - 1) / sde.T).long()
            alpha = sde.alphas.to(t.device)[timestep]
        else:
            alpha = torch.ones_like(t)

        # NOTE(review): if n_steps < 1, `x_mean` below is never assigned and the
        # final return raises UnboundLocalError.
        for i in range(n_steps):
            grad = score_fn(x, t)
            noise = torch.randn_like(x)
            # Step size chosen so the update's signal-to-noise ratio matches
            # `target_snr` (batch-averaged gradient/noise norms).
            grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
            noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
            step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha
            x_mean = x + step_size[:, None, None, None, None] * grad
            x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None, None] * noise

        return x, x_mean
@register_corrector(name='none')
class NoneCorrector(Corrector):
    """A no-op corrector: returns the state unchanged."""

    def __init__(self, sde, score_fn, snr, n_steps):
        # Intentionally skips Corrector.__init__ (no state is needed).
        pass

    def update_fn(self, x, t):
        return x, x


def shared_predictor_update_fn(x, t, sde, model, predictor, probability_flow, continuous, pred_type='noise'):
    """Instantiate the requested predictor (or a no-op) and run one update."""
    score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous, pred_type=pred_type)
    if predictor is None:
        # Corrector-only sampling: the predictor degenerates to the identity.
        chosen = NonePredictor(sde, score_fn, probability_flow)
    else:
        chosen = predictor(sde, score_fn, probability_flow)
    return chosen.update_fn(x, t)


def shared_corrector_update_fn(x, t, sde, model, corrector, continuous, snr, n_steps, pred_type='noise'):
    """Instantiate the requested corrector (or a no-op) and run one update."""
    score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous, pred_type=pred_type)
    if corrector is None:
        # Predictor-only sampling: the corrector degenerates to the identity.
        chosen = NoneCorrector(sde, score_fn, snr, n_steps)
    else:
        chosen = corrector(sde, score_fn, snr, n_steps)
    return chosen.update_fn(x, t)
def get_pc_sampler(sde, shape, predictor, corrector, inverse_scaler, snr,
                   n_steps=1, probability_flow=False, continuous=False,
                   denoise=True, eps=1e-3, device='cuda', grid_mask=None,
                   return_traj=False, pred_type='noise', use_occ=False):
    """Create a Predictor-Corrector (PC) sampler.

    Args:
      sde: An `sde_lib.SDE` object representing the forward SDE.
      shape: A sequence of integers. The expected shape of a single sample.
      predictor: A subclass of `sampling.Predictor` for the predictor algorithm.
      corrector: A subclass of `sampling.Corrector` for the corrector algorithm.
      inverse_scaler: The inverse data normalizer.
      snr: A `float`. The signal-to-noise ratio for configuring correctors.
      n_steps: An integer. The number of corrector steps per predictor update.
      probability_flow: If `True`, solve the reverse-time probability flow ODE
        when running the predictor.
      continuous: `True` indicates the score model was continuously trained.
      denoise: If `True`, add one-step denoising to the final samples.
      eps: A `float`. The reverse-time SDE/ODE is integrated to `eps` to avoid
        numerical issues.
      device: PyTorch device.
      grid_mask: mask applied to samples in the conditional (partial) branch.
      return_traj: if `True`, return intermediate x_0 predictions instead.
      pred_type: 'noise' or 'x0', matching the model's prediction target.
      use_occ: unused here; kept for interface parity with get_ddim_sampler.

    Returns:
      A sampling function returning samples and the number of function
      evaluations used during sampling.
    """
    # Create predictor & corrector update functions
    predictor_update_fn = functools.partial(shared_predictor_update_fn,
                                            sde=sde,
                                            predictor=predictor,
                                            probability_flow=probability_flow,
                                            continuous=continuous,
                                            pred_type=pred_type)
    corrector_update_fn = functools.partial(shared_corrector_update_fn,
                                            sde=sde,
                                            corrector=corrector,
                                            continuous=continuous,
                                            snr=snr,
                                            n_steps=n_steps,
                                            pred_type=pred_type)

    def pc_sampler(model, partial=None, partial_grid_mask=None, partial_channel=0, freeze_iters=None):
        """The PC sampler function.

        Args:
          model: A score model (expected to expose `feature_mask`).
          partial: optional partial observation for inpainting-style conditioning.
          partial_grid_mask: mask selecting the conditioned region of `partial`.
          partial_channel: channel index that `partial` constrains.
          freeze_iters: iterations during which the conditioned region is clamped.

        Returns:
          Samples (or a trajectory buffer) and the number of function evaluations.
        """
        with torch.no_grad():
            if freeze_iters is None:
                freeze_iters = sde.N + 10  # larger than sde.N so freezing never stops by default
            timesteps = torch.linspace(sde.T, eps, sde.N, device=device)
            mask = model.feature_mask

            if pred_type == 'noise':
                def compute_xzero(sde, model, x, t, grid_mask_input):
                    # Invert x_t = a1 * x0 + a2 * eps using the model's noise prediction.
                    timestep_int = (t * (sde.N - 1) / sde.T).long()
                    alphas1 = sde.sqrt_alphas_cumprod[timestep_int].cuda()
                    alphas2 = sde.sqrt_1m_alphas_cumprod[timestep_int].cuda()
                    score_pred = model(x, t * torch.ones(shape[0], device=x.device))
                    x0_pred = (x - alphas2 * score_pred) / alphas1
                    x0_pred = x0_pred.clamp(-1, 1)
                    return x0_pred * grid_mask_input
            elif pred_type == 'x0':
                def compute_xzero(sde, model, x, t, grid_mask_input):
                    # Model directly predicts x_0.
                    x0_pred = model(x, t * torch.ones(shape[0], device=x.device))
                    return x0_pred * grid_mask_input

            # Initial sample
            x = sde.prior_sampling(shape).to(device)
            assert len(x.size()) == 5
            traj_buffer = []

            if partial is not None:
                # Inpainting-style conditioning: clamp the conditioned channel to
                # the (re-noised) partial observation where partial_grid_mask is set.
                assert len(partial.size()) == 5
                t = timesteps[0]
                vec_t = torch.ones(shape[0], device=t.device) * t
                x[:, partial_channel] = partial[:, partial_channel] * grid_mask[:, partial_channel]
                partial_mean, partial_std = sde.marginal_prob(x, vec_t)
                # BUGFIX: `partial_std` must broadcast against the already
                # channel-selected (4-D) mean, so expand to 4 dims (was 5, which
                # produced a spurious batch dimension), and `sampled_update` is
                # already channel-selected, so it must not be indexed again.
                sampled_update = partial_mean[:, partial_channel] + partial_std[:, None, None, None] * torch.randn_like(partial_mean[:, partial_channel], device=partial_std.device)
                x[:, partial_channel] = (
                    x[:, partial_channel] * (1 - partial_grid_mask[:, partial_channel])
                    + sampled_update * partial_grid_mask[:, partial_channel]
                ) * grid_mask[:, partial_channel]

            if partial is not None:
                x_mean = x
                for i in tqdm.trange(sde.N):
                    t = timesteps[i]
                    vec_t = torch.ones(shape[0], device=t.device) * t
                    x, x_mean = corrector_update_fn(x, vec_t, model=model)
                    x, x_mean = x * grid_mask, x_mean * grid_mask
                    x, x_mean = predictor_update_fn(x, vec_t, model=model)
                    x, x_mean = x * grid_mask, x_mean * grid_mask
                    if i != sde.N - 1 and i < freeze_iters:
                        # BUGFIX: was `partial_mask`, an undefined name — the
                        # parameter is `partial_grid_mask`.
                        x[:, partial_channel] = (x[:, partial_channel] * (1 - partial_grid_mask[:, partial_channel]) + partial[:, partial_channel] * partial_grid_mask[:, partial_channel]) * grid_mask[:, partial_channel]
                        x_mean[:, partial_channel] = (x_mean[:, partial_channel] * (1 - partial_grid_mask[:, partial_channel]) + partial[:, partial_channel] * partial_grid_mask[:, partial_channel]) * grid_mask[:, partial_channel]

                        ### add noise to the condition x0_star
                        partial_mean, partial_std = sde.marginal_prob(x, timesteps[i] * torch.ones(shape[0], device=t.device))
                        sampled_update = partial_mean[:, partial_channel] + partial_std[:, None, None, None] * torch.randn_like(partial_mean[:, partial_channel], device=partial_std.device)
                        x[:, partial_channel] = (
                            x[:, partial_channel] * (1 - partial_grid_mask[:, partial_channel])
                            + sampled_update * partial_grid_mask[:, partial_channel]
                        ) * grid_mask[:, partial_channel]
                        x_mean[:, partial_channel] = x[:, partial_channel]
            else:
                # NOTE(review): this branch runs sde.N - 1 steps while the
                # conditional branch runs sde.N; kept as-is.
                for i in tqdm.trange(sde.N - 1):
                    t = timesteps[i]
                    with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=False):
                        vec_t = torch.ones(shape[0], device=t.device) * t
                        x, x_mean = corrector_update_fn(x, vec_t, model=model)
                        x, x_mean = x * mask, x_mean * mask
                        x, x_mean = predictor_update_fn(x, vec_t, model=model)
                        x, x_mean = x * mask, x_mean * mask
                        print(x.min(), x.max())  # debug
                        if return_traj and i >= 700 and i % 10 == 0:
                            traj_buffer.append(compute_xzero(sde, model, x, t, grid_mask))

            if return_traj:
                return traj_buffer, sde.N * (n_steps + 1)
            return inverse_scaler(x_mean * mask if denoise else x * mask), sde.N * (n_steps + 1)

    return pc_sampler
def ddim_predictor_update_fn(x, t, tprev, sde, model, predictor, probability_flow, continuous, pred_type='noise'):
    """Configure a DDIM predictor and run one update from t to tprev."""
    assert not continuous
    # std_scale=False: DDIM consumes the raw model output rather than the
    # std-rescaled score.
    score_fn = mutils.get_score_fn(sde, model, train=False, continuous=False, std_scale=False, pred_type=pred_type)
    if predictor is None:
        # Corrector-only sampler
        predictor_obj = NonePredictor(sde, score_fn, probability_flow)
    else:
        predictor_obj = predictor(sde, score_fn, probability_flow)
    return predictor_obj.update_fn(x, t, tprev)


def get_ddim_sampler(sde, shape, predictor, inverse_scaler, n_steps=1, denoise=False,
                     eps=1e-3, device='cuda', grid_mask=None, pred_type='noise', use_occ=False):
    """Create a DDIM sampler.

    Args:
      sde: An `sde_lib.SDE` object that represents the forward SDE.
      shape: A sequence of integers. The expected shape of a single sample.
      predictor: DDIM predictor class.
      inverse_scaler: The inverse data normalizer.
      n_steps: used only for the reported evaluation count.
      denoise: If `True`, return the final x_0 prediction instead of the state.
      eps: A `float`. The reverse-time SDE/ODE is integrated to `eps`.
      device: PyTorch device.
      grid_mask: unused here; the model's own masks are used instead.
      pred_type: 'noise' or 'x0', matching the model's prediction target.
      use_occ: if `True`, sample a (feature grid, occupancy grid) pair.

    Returns:
      A sampling function that returns samples and the number of function
      evaluations during sampling.
    """
    predictor_update_fn = functools.partial(ddim_predictor_update_fn,
                                            sde=sde,
                                            predictor=predictor,
                                            probability_flow=False,
                                            continuous=False,
                                            pred_type=pred_type)

    def ddim_sampler(model, schedule='quad', num_steps=100, x0=None, x0_occ=None,
                     partial=None, partial_grid_mask=None, partial_channel=0):
        """The DDIM sampler function.

        Args:
          model: A score model (expected to expose `feature_mask` and, with
            use_occ, `occ_mask`).
          schedule: 'uniform' or 'quad' timestep schedule.
          num_steps: number of DDIM steps.
          x0/x0_occ: optional initial states (otherwise sampled from the prior).
          partial, partial_grid_mask, partial_channel: inpainting-style
            conditioning on one channel.

        Returns:
          Samples and the number of function evaluations.
        """
        with torch.no_grad():
            print(device)  # debug
            if x0 is not None:
                x = x0.to(device)
            else:
                # Initial sample
                x = sde.prior_sampling(shape).to(device)
            mask = model.feature_mask
            if use_occ:
                occ_mask = model.occ_mask.float()
                if x0_occ is not None:
                    x_occ = x0_occ.to(device)
                else:
                    # Occupancy grid lives on the twice-finer grid.
                    x_occ = sde.prior_sampling((x.size(0), 1, x.size(2) * 2, x.size(3) * 2, x.size(4) * 2)).to(device)
                x = (x.float() * mask, x_occ.float() * occ_mask)
            if partial is not None:
                # BUGFIX: was `partial_mask`, an undefined name — the parameter
                # is `partial_grid_mask`.
                # NOTE(review): with use_occ=True, x is a tuple here and this
                # indexing would fail; conditioning is only supported without occ.
                x[:, partial_channel] = x[:, partial_channel] * (1 - partial_grid_mask) + partial * partial_grid_mask

            timesteps = torch.linspace(sde.T, eps, sde.N, device=device)
            if schedule == 'uniform':
                # NOTE(review): `seq` is unused in this branch; the full
                # `timesteps` grid is kept, matching the original behavior.
                skip = sde.N // num_steps
                seq = range(0, sde.N, skip)
            elif schedule == 'quad':
                # BUGFIX: honor `num_steps` (the step count was hard-coded to
                # 100, which coincided with the default).
                seq = (
                    np.linspace(
                        0, np.sqrt(sde.N * 0.8), num_steps
                    ) ** 2
                )
                seq = [int(s) for s in list(seq)]
                timesteps = torch.tensor(seq, device=device) / sde.N

            for i in tqdm.tqdm(list(reversed(range(1, len(timesteps)))), leave=False):
                t = timesteps[i]
                tprev = timesteps[i - 1]
                vec_t = torch.ones(1, device=t.device) * t
                vec_tprev = torch.ones(1, device=t.device) * tprev
                with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=False):
                    x, x0_pred = predictor_update_fn(x, vec_t, model=model, tprev=vec_tprev)
                if use_occ:
                    x = (x[0] * mask, x[1] * occ_mask)
                    x0_pred = (x0_pred[0] * mask, x0_pred[1] * occ_mask)
                else:
                    x, x0_pred = x * mask, x0_pred * mask
                if partial is not None:
                    x[:, partial_channel] = x[:, partial_channel] * (1 - partial_grid_mask) + partial * partial_grid_mask
                    x0_pred[:, partial_channel] = x0_pred[:, partial_channel] * (1 - partial_grid_mask) + partial * partial_grid_mask

            encode = False
            if use_occ:
                return (
                    inverse_scaler(x0_pred[0] * mask if (denoise and not encode) else x[0] * mask),
                    inverse_scaler(x0_pred[1] * occ_mask if (denoise and not encode) else x[1] * occ_mask)
                ), sde.N * (n_steps + 1)
            return inverse_scaler(x0_pred * mask if (denoise and not encode) else x * mask), sde.N * (n_steps + 1)

    return ddim_sampler
ddim_sampler


================================================
FILE: GMeshDiffusion/lib/diffusion/sde_lib.py
================================================
"""Abstract SDE classes, Reverse SDE, and VE/VP SDEs."""
import abc
import torch
import numpy as np
import torch.nn.functional as F
import time


class SDE(abc.ABC):
    """SDE abstract class. Functions are designed for a mini-batch of inputs."""

    def __init__(self, N):
        """Construct an SDE.

        Args:
            N: number of discretization time steps.
        """
        super().__init__()
        self.N = N

    @property
    @abc.abstractmethod
    def T(self):
        """End time of the SDE."""
        pass

    @abc.abstractmethod
    def sde(self, x, t):
        # Return (drift, diffusion) of the forward SDE at state x, time t.
        pass

    @abc.abstractmethod
    def marginal_prob(self, x, t):
        """Parameters to determine the marginal distribution of the SDE, $p_t(x)$."""
        pass

    @abc.abstractmethod
    def prior_sampling(self, shape):
        """Generate one sample from the prior distribution, $p_T(x)$."""
        pass

    @abc.abstractmethod
    def prior_logp(self, z):
        """Compute log-density of the prior distribution.

        Useful for computing the log-likelihood via probability flow ODE.

        Args:
            z: latent code
        Returns:
            log probability density
        """
        pass

    def discretize(self, x, t):
        """Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.

        Useful for reverse diffusion sampling and probabiliy flow sampling.
        Defaults to Euler-Maruyama discretization.

        Args:
            x: a torch tensor
            t: a torch float representing the time step (from 0 to `self.T`)

        Returns:
            f, G
        """
        dt = 1 / self.N
        drift, diffusion = self.sde(x, t)
        f = drift * dt
        G = diffusion * torch.sqrt(torch.tensor(dt, device=t.device))
        return f, G

    def reverse(self, score_fn, probability_flow=False):
        """Create the reverse-time SDE/ODE.

        Args:
            score_fn: A time-dependent score-based model that takes x and t and returns the score.
            probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling.
        """
        N = self.N
        T = self.T
        sde_fn = self.sde
        discretize_fn = self.discretize
        # NOTE(review): these attributes are only defined on VPSDE, not on the
        # SDE base class — reverse() is effectively VPSDE-only; confirm.
        sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
        sqrt_1m_alphas_cumprod = self.sqrt_1m_alphas_cumprod

        # Build the class for reverse-time SDE.
        class RSDE(self.__class__):
            def __init__(self):
                self.N = N
                self.probability_flow = probability_flow

            @property
            def T(self):
                return T

            def sde(self, x, t):
                """Create the drift and diffusion functions for the reverse SDE/ODE."""
                drift, diffusion = sde_fn(x, t)
                score = score_fn(x, t)
                # Reverse drift: f(x,t) - g(t)^2 * score (halved for the ODE).
                drift = drift - diffusion[:, None, None, None, None] ** 2 * score * (0.5 if self.probability_flow else 1.)
                # Set the diffusion function to zero for ODEs.
                diffusion = 0. if self.probability_flow else diffusion
                return drift, diffusion

            def discretize(self, x, t):
                """Create discretized iteration rules for the reverse diffusion sampler."""
                f, G = discretize_fn(x, t)
                rev_f = f - G[:, None, None, None, None] ** 2 * score_fn(x, t) * (0.5 if self.probability_flow else 1.)
                rev_G = torch.zeros_like(G) if self.probability_flow else G
                return rev_f, rev_G

            def discretize_ddim(self, x, t, tprev=None, encode=False):
                """DDPM discretization.

                One DDIM step from time t to tprev. `x` is either a single
                tensor, or a (features, occupancy) tuple whose score_fn then
                returns a pair of predictions. Returns (x_new, x0_pred) with
                the same tuple-ness as the input.
                """
                timestep = (t * (N - 1) / T).long()
                timestep_prev = (tprev * (N - 1) / T).long()
                if type(x) == torch.Tensor:
                    score = score_fn(x.float(), t.float())
                    # alphas1prev_div_alphas1 = torch.exp(log_diff)
                    # alphas1 = sqrt(alpha_bar_t), alphas2 = sqrt(1 - alpha_bar_t).
                    alphas1 = sqrt_alphas_cumprod[timestep].cuda()[:, None, None, None, None]
                    alphas2 = sqrt_1m_alphas_cumprod[timestep].cuda()[:, None, None, None, None]
                    alphas1_prev = sqrt_alphas_cumprod[timestep_prev].cuda()[:, None, None, None, None]
                    alphas2_prev = sqrt_1m_alphas_cumprod[timestep_prev].cuda()[:, None, None, None, None]
                    alphas1prev_div_alphas1 = alphas1_prev.double() / alphas1.double()
                    alphas2prev_div_alphas2 = alphas2_prev.double() / alphas2.double()
                    # x0 prediction still scaled by sqrt(alpha_bar_t); `score`
                    # here is the predicted noise ('noise' pred_type).
                    x0_pred_scaled = (x.double() - alphas2.double() * score.double())
                    use_clip = False
                    if use_clip:
                        # raise NotImplementedError
                        x0_pred_scaled = x0_pred_scaled.clamp(-alphas1[0].squeeze(), alphas1[0].squeeze())
                    score_scaled_t = x - x0_pred_scaled
                    x0_pred = x0_pred_scaled / alphas1
                    # Deterministic (eta=0) DDIM update in double precision.
                    x_new = (
                        alphas1prev_div_alphas1.double() * x
                        + (-alphas1prev_div_alphas1 + alphas2prev_div_alphas2.double()) * score_scaled_t.double()
                    )
                    return x_new, x0_pred
                else:
                    # Tuple input: joint feature + occupancy sampling.
                    score, score_occ = score_fn(x, t.float())
                    x, x_occ = x
                    alphas1 = sqrt_alphas_cumprod[timestep].cuda()[:, None, None, None, None]
                    alphas2 = sqrt_1m_alphas_cumprod[timestep].cuda()[:, None, None, None, None]
                    alphas1_prev = sqrt_alphas_cumprod[timestep_prev].cuda()[:, None, None, None, None]
                    alphas2_prev = sqrt_1m_alphas_cumprod[timestep_prev].cuda()[:, None, None, None, None]
                    alphas1prev_div_alphas1 = alphas1_prev / alphas1
                    alphas2prev_div_alphas2 = alphas2_prev / alphas2
                    x0_pred_scaled = (x - alphas2 * score)
                    x0_occ_pred_scaled = (x_occ - alphas2 * score_occ)
                    use_clip = False
                    if use_clip:
                        x0_pred_scaled = x0_pred_scaled.clamp(-alphas1[0].squeeze(), alphas1[0].squeeze())
                        x0_occ_pred_scaled = x0_occ_pred_scaled.clamp(-alphas1[0].squeeze(), alphas1[0].squeeze())
                    score_scaled_t = x - x0_pred_scaled
                    x0_pred = x0_pred_scaled / alphas1
                    score_occ_scaled_t = x_occ - x0_occ_pred_scaled
                    x0_occ_pred = x0_occ_pred_scaled / alphas1
                    x_new = (
                        alphas1prev_div_alphas1 * x
                        + (-alphas1prev_div_alphas1 + alphas2prev_div_alphas2) * score_scaled_t
                    )
                    x_occ_new = (
                        alphas1prev_div_alphas1 * x_occ
                        + (-alphas1prev_div_alphas1 + alphas2prev_div_alphas2) * score_occ_scaled_t
                    )
                    return (x_new, x_occ_new), (x0_pred, x0_occ_pred)

            def discretize_conditional_ddpm(self, x, t, tprev=None, condition_func=None, condition=False):
                """DDPM discretization."""
                # One conditional update step: the state is nudged by
                # `condition_func` instead of the plain DDIM drift.
                timestep = (t * (N - 1) / T).long()
                timestep_prev = (tprev * (N - 1) / T).long()
                score = score_fn(x.float(), t.float())
                # alphas1prev_div_alphas1 = torch.exp(log_diff)
                alphas1 = sqrt_alphas_cumprod[timestep].cuda()[:, None, None, None, None]
                alphas2 = sqrt_1m_alphas_cumprod[timestep].cuda()[:, None, None, None, None]
                alphas1_prev = sqrt_alphas_cumprod[timestep_prev].cuda()[:, None, None, None, None]
                alphas1prev_div_alphas1 = alphas1_prev.double() / alphas1.double()
                x0_pred_scaled = (x.double() - alphas2.double() * score.double())
                # Clamping is unconditional here (unlike discretize_ddim).
                x0_pred_scaled = x0_pred_scaled.clamp(-alphas1[0].squeeze(), alphas1[0].squeeze())
                x0_pred = x0_pred_scaled / alphas1
                # NOTE(review): default `condition=False` is not None, so the
                # `condition_update = 0` branch only triggers when None is
                # passed explicitly — confirm intended.
                if condition is None:
                    condition_update = 0
                else:
                    # Near t ~= 1, apply the condition to the x0 prediction.
                    if (t - 0.99).mean() < 1e-3:
                        x = x0_pred
                    condition_update = condition_func(x.float(), condition)
                x_new = (
                    x
                    - alphas1prev_div_alphas1.double() * condition_update
                )
                return x_new, x0_pred
        return RSDE()


class VPSDE(SDE):
    def __init__(self, beta_min=0.1, beta_max=20, N=1000):
        """Construct a Variance Preserving SDE.

        Args:
            beta_min: value of beta(0)
            beta_max: value of beta(1)
            N: number of discretization steps
        """
        super().__init__(N)
        self.beta_0 = beta_min
        self.beta_1 = beta_max
        self.N = N
        # Discrete DDPM beta schedule and derived alpha products (on GPU).
        self.discrete_betas = torch.linspace(beta_min / N, beta_max / N, N).cuda()
        self.alphas = 1. - self.discrete_betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # Extended product with a leading ~1 entry for step -1.
        self.alphas_cumprod_ext = torch.cat([torch.tensor([1.0 - 1e-4]).cuda(), torch.cumprod(self.alphas, dim=0)], dim=0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_1m_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)
        # The two self-assignments below are no-ops kept from the original.
        self.alphas_cumprod = self.alphas_cumprod
        self.alphas_cumprod_ext = self.alphas_cumprod_ext

    @property
    def T(self):
        return 1

    def sde(self, x, t):
        # Forward VP-SDE drift/diffusion at continuous time t.
        beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0)
        drift = -0.5 * beta_t[:, None, None, None, None] * x
        diffusion = torch.sqrt(beta_t)
        return drift, diffusion

    def marginal_prob(self, x, t):
        # Closed-form mean/std of p_t(x | x_0) for the VP-SDE.
        log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
        mean = torch.exp(log_mean_coeff[:, None, None, None, None]) * x
        std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
        return mean, std

    def prior_sampling(self, shape):
        # p_T is a standard Gaussian.
        return torch.randn(*shape)

    def prior_logp(self, z):
        # Standard-normal log-density over all non-batch dims.
        shape = z.shape
        N = np.prod(shape[1:])
        logps = -N / 2. * np.log(2 * np.pi) - torch.sum(z ** 2, dim=(1, 2, 3, 4)) / 2.
return logps def discretize(self, x, t): """DDPM discretization.""" timestep = (t * (self.N - 1) / self.T).long() beta = self.discrete_betas.to(x.device)[timestep] alpha = self.alphas.to(x.device)[timestep] sqrt_beta = torch.sqrt(beta) f = torch.sqrt(alpha)[:, None, None, None, None] * x - x G = sqrt_beta return f, G ================================================ FILE: GMeshDiffusion/lib/diffusion/trainer.py ================================================ import os import sys import numpy as np import logging # Keep the import below for registering all model definitions from .models import unet3d, unet3d_occgrid, unet3d_tet_aware, unet3d_occgrid_v2, unet3d_meshdiffusion from . import losses from .models import utils as mutils from .models.ema import ExponentialMovingAverage from . import sde_lib import torch from torch.utils import tensorboard from .utils import save_checkpoint, restore_checkpoint from ..dataset.gshell_dataset import GShellDataset from ..dataset.gshell_dataset_aug import GShellAugDataset def train(config): """Runs the training pipeline. Args: config: Configuration to use. workdir: Working directory for checkpoints and TF summaries. If this contains checkpoint training will be resumed from the latest checkpoint. """ workdir = config.training.train_dir # Create directories for experimental logs logging.info("working dir: {:s}".format(workdir)) tb_dir = os.path.join(workdir, "tensorboard") writer = tensorboard.SummaryWriter(tb_dir) # Initialize model. 
score_model = mutils.create_model(config) ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate) optimizer = losses.get_optimizer(config, score_model.parameters()) gradscaler = torch.cuda.amp.GradScaler(enabled=True) state = dict(optimizer=optimizer, model=score_model, ema=ema, gradscaler=gradscaler, step=0) # Create checkpoints directory checkpoint_dir = os.path.join(workdir, "checkpoints") # Intermediate checkpoints to resume training after pre-emption in cloud environments checkpoint_meta_dir = os.path.join(workdir, "checkpoints-meta", "checkpoint.pth") os.makedirs(checkpoint_dir, exist_ok=True) os.makedirs(os.path.dirname(checkpoint_meta_dir), exist_ok=True) # Resume training when intermediate checkpoints are detected state = restore_checkpoint(checkpoint_meta_dir, state, config.device) initial_step = int(state['step']) print(f"work dir: {workdir}") try: use_occ_grid = config.data.use_occ_grid except: use_occ_grid = False if use_occ_grid: train_dataset = GShellAugDataset(config) else: train_dataset = GShellDataset(config.data.dataset_metapath) try: collate_fn = train_dataset.collate except: collate_fn = None train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=config.training.batch_size, shuffle=True, num_workers=config.data.num_workers, collate_fn=collate_fn, pin_memory=True ) data_iter = iter(train_loader) print("data loader set") # Setup SDEs sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales) # Build one-step training and evaluation functions optimize_fn = losses.optimization_manager(config) try: use_vis_mask = config.model.use_vis_mask except: use_vis_mask = False print('use_vis_mask', use_vis_mask) train_step_fn = losses.get_step_fn(sde, train=True, optimize_fn=optimize_fn, loss_type=config.training.loss_type, pred_type=config.model.pred_type, use_vis_mask=use_vis_mask, use_occ=use_occ_grid, use_aux=config.training.use_aux_loss) num_train_steps = 
config.training.n_iters # In case there are multiple hosts (e.g., TPU pods), only log to host 0 logging.info("Starting training loop at step %d." % (initial_step // config.training.num_grad_acc_steps,)) iter_size = config.training.num_grad_acc_steps for step in range(initial_step // iter_size, num_train_steps + 1): tmp_loss_dict = { 'loss_total': 0.0, 'loss_score': 0.0, 'loss_reg': 0.0, } for step_inner in range(iter_size): try: # batch, batch_mask = next(data_iter) batch = next(data_iter) except StopIteration: # StopIteration is thrown if dataset ends # reinitialize data loader data_iter = iter(train_loader) batch = next(data_iter) if type(batch) == dict: for k in batch: batch[k] = batch[k].to('cuda', non_blocking=False) else: batch = batch.to('cuda', non_blocking=False) # Execute one training step clear_grad_flag = (step_inner == 0) update_param_flag = (step_inner == iter_size - 1) loss_dict = train_step_fn(state, batch, clear_grad=clear_grad_flag, update_param=update_param_flag, gradscaler=gradscaler) for key in loss_dict: tmp_loss_dict[key] += loss_dict[key].item() / iter_size # print(torch.cuda.memory_summary()) if step % config.training.log_freq == 0: # logging.info("step: %d, training_loss: %.5e" % (step, tmp_loss)) logging.info( "step: %d, loss_total: %.5e, loss_score: %.5e, loss_reg: %.5e" % (step, tmp_loss_dict['loss_total'], tmp_loss_dict['loss_score'], tmp_loss_dict['loss_reg']) ) sys.stdout.flush() writer.add_scalar("loss_total", tmp_loss_dict['loss_total'], step) writer.add_scalar("loss_score", tmp_loss_dict['loss_score'], step) writer.add_scalar("loss_reg", tmp_loss_dict['loss_reg'], step) # Save a temporary checkpoint to resume training after pre-emption periodically if step != 0 and step % config.training.snapshot_freq_for_preemption == 0: logging.info(f"save meta at iter {step}") save_checkpoint(checkpoint_meta_dir, state) # Save a checkpoint periodically and generate samples if needed if step != 0 and step % config.training.snapshot_freq == 0 or 
step == num_train_steps: logging.info(f"save model: {step}-th") save_checkpoint(os.path.join(checkpoint_dir, f'checkpoint_{step}.pth'), state) ================================================ FILE: GMeshDiffusion/lib/diffusion/trainer_ddp.py ================================================ import os import sys import numpy as np import logging # Keep the import below for registering all model definitions from .models import unet3d, unet3d_occgrid, unet3d_tet_aware, unet3d_occgrid_v2, unet3d_meshdiffusion from . import losses from .models import utils as mutils from .models.ema import ExponentialMovingAverage from . import sde_lib import torch from torch.utils import tensorboard from .utils import save_checkpoint, restore_checkpoint from ..dataset.gshell_dataset import GShellDataset from ..dataset.gshell_dataset_aug import GShellAugDataset from .lion.lion import Lion import torch.distributed as dist def train(config): """Runs the training pipeline. Args: config: Configuration to use. workdir: Working directory for checkpoints and TF summaries. If this contains checkpoint training will be resumed from the latest checkpoint. """ dist.init_process_group("nccl") rank = dist.get_rank() torch.cuda.set_device(rank) device = torch.device("cuda", rank) print(f"Start running basic DDP example on rank {rank}.") # create model and move it to GPU with id rank world_size = torch.cuda.device_count() device_id = rank % torch.cuda.device_count() workdir = config.training.train_dir # Create directories for experimental logs logging.info("working dir: {:s}".format(workdir)) tb_dir = os.path.join(workdir, "tensorboard") writer = tensorboard.SummaryWriter(tb_dir) # Initialize model. 
score_model = mutils.create_model(config, ddp=True, rank=rank) ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate) optimizer = losses.get_optimizer(config, score_model.parameters()) gradscaler = torch.cuda.amp.GradScaler(growth_interval=config.training.gradscaler_growth_interval) state = dict(optimizer=optimizer, model=score_model, ema=ema, gradscaler=gradscaler, step=0) # Create checkpoints directory checkpoint_dir = os.path.join(workdir, "checkpoints") # Intermediate checkpoints to resume training after pre-emption in cloud environments checkpoint_meta_dir = os.path.join(workdir, "checkpoints-meta", "checkpoint.pth") os.makedirs(checkpoint_dir, exist_ok=True) os.makedirs(os.path.dirname(checkpoint_meta_dir), exist_ok=True) # Resume training when intermediate checkpoints are detected state = restore_checkpoint(checkpoint_meta_dir, state, config.device, rank=rank) initial_step = int(state['step']) print(f"work dir: {workdir}") try: use_occ_grid = config.data.use_occ_grid except: use_occ_grid = False if use_occ_grid: train_dataset = GShellAugDataset(config) else: train_dataset = GShellDataset(config.data.dataset_metapath) train_sampler = torch.utils.data.distributed.DistributedSampler( train_dataset, num_replicas=world_size, rank=rank ) try: collate_fn = train_dataset.collate except: collate_fn = None train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=config.training.batch_size, num_workers=config.data.num_workers, # pin_memory=True, sampler=train_sampler, collate_fn=collate_fn ) data_iter = iter(train_loader) print("data loader set") # Setup SDEs sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales) # Build one-step training and evaluation functions optimize_fn = losses.optimization_manager(config) try: use_vis_mask = config.model.use_vis_mask except: use_vis_mask = False print('use_vis_mask', use_vis_mask) train_step_fn = losses.get_step_fn(sde, 
train=True, optimize_fn=optimize_fn, loss_type=config.training.loss_type, pred_type=config.model.pred_type, use_vis_mask=use_vis_mask, use_occ=use_occ_grid, use_aux=config.training.use_aux_loss) num_train_steps = config.training.n_iters # In case there are multiple hosts (e.g., TPU pods), only log to host 0 logging.info("Starting training loop at step %d." % (initial_step // config.training.num_grad_acc_steps,)) iter_size = config.training.num_grad_acc_steps epoch = 0 train_sampler.set_epoch(epoch) for step in range(initial_step // iter_size, num_train_steps + 1): tmp_loss_dict = { 'loss_total': 0.0, 'loss_score': 0.0, 'loss_reg': 0.0, } for step_inner in range(iter_size): try: # batch, batch_mask = next(data_iter) batch = next(data_iter) except StopIteration: # StopIteration is thrown if dataset ends # reinitialize data loader epoch += 1 train_sampler.set_epoch(epoch) data_iter = iter(train_loader) batch = next(data_iter) if type(batch) == dict: for k in batch: batch[k] = batch[k].to(rank, non_blocking=False) else: batch = batch.to(rank, non_blocking=False) # Execute one training step clear_grad_flag = (step_inner == 0) update_param_flag = (step_inner == iter_size - 1) if not update_param_flag: with score_model.no_sync(): loss_dict = train_step_fn(state, batch, clear_grad=clear_grad_flag, update_param=update_param_flag, gradscaler=gradscaler) else: loss_dict = train_step_fn(state, batch, clear_grad=clear_grad_flag, update_param=update_param_flag, gradscaler=gradscaler) for key in loss_dict: tmp_loss_dict[key] += loss_dict[key].item() / iter_size # print(torch.cuda.memory_summary()) if step % config.training.log_freq == 0: loss = tmp_loss_dict['loss_total'] loss = torch.tensor(loss / world_size).to(rank) # logging.info("step: %d, training_loss: %.5e" % (step, tmp_loss)) dist.reduce(loss, dst=0, op=dist.ReduceOp.SUM) if rank == 0: loss = loss.item() logging.info("step: %d, loss: %.5e, scale: %.5e" % (step, loss, gradscaler.get_scale())) sys.stdout.flush() 
writer.add_scalar("loss", loss, step) if rank == 0: # Save a temporary checkpoint to resume training after pre-emption periodically if step != 0 and step % config.training.snapshot_freq_for_preemption == 0: logging.info(f"save meta at iter {step}") save_checkpoint(checkpoint_meta_dir, state) # Save a checkpoint periodically and generate samples if needed if step != 0 and step % config.training.snapshot_freq == 0 or step == num_train_steps: logging.info(f"save model: {step}-th") save_checkpoint(os.path.join(checkpoint_dir, f'checkpoint_{step}.pth'), state) dist.destroy_process_group() ================================================ FILE: GMeshDiffusion/lib/diffusion/utils.py ================================================ import torch import os import logging def restore_checkpoint(ckpt_dir, state, device, strict=False, rank=None): if not os.path.exists(ckpt_dir): os.makedirs(os.path.dirname(ckpt_dir), exist_ok=True) logging.warning(f"No checkpoint found at {ckpt_dir}. " f"Returned the same state as input") if strict: raise return state else: if rank is not None: device = f"cuda:{rank}" # loaded_state = torch.load(ckpt_dir, map_location=device) loaded_state = torch.load(ckpt_dir, map_location='cpu') state['optimizer'].load_state_dict(loaded_state['optimizer']) try: state['model'].load_state_dict(loaded_state['model'], strict=False) except: consume_prefix_in_state_dict_if_present(loaded_state['model']) state['model'].load_state_dict(loaded_state['model'], strict=False) state['ema'].load_state_dict(loaded_state['ema'], device=device) state['step'] = loaded_state['step'] state['model'].to(device) try: state['gradscaler'].load_state_dict(loaded_state['gradscaler']) # state['gradscaler'].to(device) except: # raise pass torch.cuda.empty_cache() return state def save_checkpoint(ckpt_dir, state): saved_state = { 'optimizer': state['optimizer'].state_dict(), 'model': state['model'].state_dict(), 'ema': state['ema'].state_dict(), 'step': state['step'], 'gradscaler': 
            state['gradscaler'].state_dict()
    }
    torch.save(saved_state, ckpt_dir)


================================================
FILE: GMeshDiffusion/main_diffusion.py
================================================
"""Training and evaluation"""

from absl import app
from absl import flags
from ml_collections.config_flags import config_flags

import lib.diffusion.trainer as trainer
import lib.diffusion.evaler as evaler

FLAGS = flags.FLAGS

config_flags.DEFINE_config_file(
    "config", None, "diffusion configs", lock_config=False)
flags.DEFINE_enum("mode", None, ["train", "uncond_gen", "cond_gen", "uncond_gen_interp"], "Running mode")
flags.mark_flags_as_required(["config", "mode"])


def main(argv):
    # Dispatch on --mode: training or one of the evaluation/generation modes.
    if FLAGS.mode == 'train':
        trainer.train(FLAGS.config)
    elif FLAGS.mode == 'uncond_gen':
        evaler.uncond_gen(FLAGS.config)
    elif FLAGS.mode == 'uncond_gen_interp':
        evaler.uncond_gen_interp(FLAGS.config)
    elif FLAGS.mode == 'cond_gen':
        evaler.cond_gen(FLAGS.config)

if __name__ == "__main__":
    app.run(main)


================================================
FILE: GMeshDiffusion/main_diffusion_ddp.py
================================================
"""Training and evaluation"""

from absl import app
from absl import flags
from ml_collections.config_flags import config_flags

import lib.diffusion.trainer_ddp as trainer
import lib.diffusion.evaler as evaler

FLAGS = flags.FLAGS

config_flags.DEFINE_config_file(
    "config", None, "diffusion configs", lock_config=False)
# NOTE(review): only "train" is handled in main() below; the eval modes are
# accepted by the flag but silently ignored in the DDP entry point.
flags.DEFINE_enum("mode", None, ["train", "uncond_gen", "cond_gen", "uncond_gen_interp"], "Running mode")
flags.mark_flags_as_required(["config", "mode"])


def main(argv):
    print(FLAGS.config)
    if FLAGS.mode == 'train':
        trainer.train(FLAGS.config)

if __name__ == "__main__":
    app.run(main)


================================================
FILE: GMeshDiffusion/metadata/get_splits_lower.py
================================================
import os
import random

# Fixed seed so the train/test split is reproducible.
random.seed(42)

split_ratio = 0.9

data_root = 'PLACEHOLDER'
grid_root =
os.path.join(data_root, 'grid')
occgrid_root = os.path.join(data_root, 'grid_aug')

# NOTE(review): `grid_root` is never used below — the listing walks
# `data_root` itself, which also contains the 'grid'/'grid_aug' subdirs.
# This looks like it was meant to be `os.listdir(grid_root)`; confirm before
# regenerating splits (changing it would alter the split files).
data_path_list = sorted([os.path.join(data_root, fpath) for fpath in os.listdir(data_root)])
random.shuffle(data_path_list)

n_train = int(len(data_path_list) * split_ratio)
train_list = data_path_list[:n_train]
test_list = data_path_list[n_train:]

with open('lower_res64_grid_train.txt', 'w') as f:
    f.write('\n'.join(train_list))

with open('lower_res64_grid_test.txt', 'w') as f:
    f.write('\n'.join(test_list))

# Mirror the same split into the augmented-occupancy-grid directory by
# re-rooting each file name under `occgrid_root`.
occgrid_train_list = [os.path.join(occgrid_root, x.split('/')[-1]) for x in train_list]
occgrid_test_list = [os.path.join(occgrid_root, x.split('/')[-1]) for x in test_list]

with open('lower_res64_occgrid_train.txt', 'w') as f:
    f.write('\n'.join(occgrid_train_list))

with open('lower_res64_occgrid_test.txt', 'w') as f:
    f.write('\n'.join(occgrid_test_list))


================================================
FILE: GMeshDiffusion/metadata/get_splits_upper.py
================================================
import os
import random

# Fixed seed so the train/test split is reproducible.
random.seed(42)

split_ratio = 0.9

data_root = 'PLACEHOLDER'
# NOTE(review): same as get_splits_lower.py — `grid_root` is unused and the
# listing walks `data_root`; confirm intended directory.
grid_root = os.path.join(data_root, 'grid')
occgrid_root = os.path.join(data_root, 'grid_aug')

data_path_list = sorted([os.path.join(data_root, fpath) for fpath in os.listdir(data_root)])
random.shuffle(data_path_list)

n_train = int(len(data_path_list) * split_ratio)
train_list = data_path_list[:n_train]
test_list = data_path_list[n_train:]

with open('upper_res64_grid_train.txt', 'w') as f:
    f.write('\n'.join(train_list))

with open('upper_res64_grid_test.txt', 'w') as f:
    f.write('\n'.join(test_list))

occgrid_train_list = [os.path.join(occgrid_root, x.split('/')[-1]) for x in train_list]
occgrid_test_list = [os.path.join(occgrid_root, x.split('/')[-1]) for x in test_list]

with open('upper_res64_occgrid_train.txt', 'w') as f:
    f.write('\n'.join(occgrid_train_list))

with open('upper_res64_occgrid_test.txt', 'w') as f:
    f.write('\n'.join(occgrid_test_list))


================================================
FILE: GMeshDiffusion/metadata/save_tet_info.py
================================================
'''
Storing tet-grid related meta-info into a single file
'''

import numpy as np
import torch
import os
import tqdm
import argparse
from itertools import combinations


def tet_to_grids(vertices, values_list, grid_size):
    # Scatter per-vertex values into a dense cubic grid.
    # NOTE(review): only channels 0 and 1:4 are ever written although the
    # grid has 12 channels — confirm the remaining channels are filled later.
    grid = torch.zeros(12, grid_size, grid_size, grid_size, device=vertices.device)
    with torch.no_grad():
        for k, values in enumerate(values_list):
            if k == 0:
                grid[k, vertices[:, 0], vertices[:, 1], vertices[:, 2]] = values.squeeze()
            else:
                grid[1:4, vertices[:, 0], vertices[:, 1], vertices[:, 2]] = values.transpose(0, 1)
    return grid


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='nvdiffrec')
    parser.add_argument('-res', '--resolution', type=int)
    parser.add_argument('-r', '--root', type=str)
    parser.add_argument('-s', '--source', type=str)
    parser.add_argument('-t', '--target', type=str)
    FLAGS = parser.parse_args()

    # Load the pre-cropped, reordered tetrahedral grid for this resolution.
    tet_path = f'./tets/{FLAGS.resolution}_tets_cropped_reordered.npz'
    tet = np.load(tet_path)
    vertices = torch.tensor(tet['vertices']).cuda()
    indices = torch.tensor(tet['indices']).long().cuda()
    edges = torch.tensor(tet['edges']).long().cuda()
    tet_edges = torch.tensor(tet['tet_edges']).long().view(-1, 2).cuda()

    # Grid spacing from unique vertex coordinates, halved for a denser grid.
    vertices_unique = vertices[:].unique()
    dx = vertices_unique[1] - vertices_unique[0]
    dx = dx / 2.0 ### denser grid

    # Integer (pixel) coordinates of every tet vertex on the dense grid.
    vertices_discretized = (
        ((vertices - vertices.min()) / dx)
    ).long()

    midpoints = (vertices_discretized[edges[:, 0]] + vertices_discretized[edges[:, 1]]) / 2.0
    midpoints_dicretized = midpoints.long()

    tet_verts = vertices_discretized[indices.view(-1)].view(-1, 4, 3)
    tet_center = tet_verts.float().mean(dim=1)
    tet_center_discretized = tet_center.long()

    # The 6 edges of a tet by local vertex index; collect the pairs of edges
    # that share a vertex — these define the 12 mSDF edge-midpoint pairs.
    edge_ind_list = [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]
    msdf_tetedges = []
    msdf_from_tetverts = []
    for i in range(5):
        for j in range(i+1, 6):
            if (edge_ind_list[i][0] == edge_ind_list[j][0]
                or edge_ind_list[i][0] == edge_ind_list[j][1]
                or edge_ind_list[i][1] == edge_ind_list[j][0]
                or edge_ind_list[i][1] == edge_ind_list[j][1]
            ):
                msdf_tetedges.append(i)
                msdf_tetedges.append(j)
                msdf_from_tetverts.extend([edge_ind_list[i][0], edge_ind_list[i][1], edge_ind_list[j][0], edge_ind_list[j][1]])
    msdf_tetedges = torch.tensor(msdf_tetedges)
    msdf_from_tetverts = torch.tensor(msdf_from_tetverts)
    print(msdf_tetedges)
    print(msdf_tetedges.size())

    tet_edges = tet_edges.view(-1, 2)
    msdf_tetedges = msdf_tetedges.view(-1)
    # Midpoint positions of every tet edge, grouped 6 per tet.
    tet_edgenodes_pos = (vertices_discretized[tet_edges[:, 0]] + vertices_discretized[tet_edges[:, 1]]) / 2.0
    # NOTE(review): positions are 3-D, so `.view(-1, 6, 2)` looks suspicious
    # (possibly should be `(-1, 6, 3)`) — confirm against the extracted repo.
    tet_edgenodes_pos = tet_edgenodes_pos.view(-1, 6, 2)
    occ_edge_pos = tet_edgenodes_pos[:, msdf_tetedges].view(-1, 12, 2, 3)
    # Canonicalize the ordering of each midpoint pair by encoding the sign of
    # the coordinate difference into a comparable integer code and sorting.
    edge_twopoint_order = torch.sign(occ_edge_pos[:, :, 0, :] - occ_edge_pos[:, :, 1, :])
    edge_twopoint_order_binary_code = (edge_twopoint_order * torch.tensor([16, 4, 1], device=edge_twopoint_order.device).view(1, 1, -1)).sum(dim=-1)
    edge_twopoint_order_binary_code = torch.stack([edge_twopoint_order_binary_code, -edge_twopoint_order_binary_code], dim=-1)
    _, edge_twopoint_order = edge_twopoint_order_binary_code.sort(dim=-1)
    occ_edge_cano_order = torch.arange(2).view(1, 1, 2).expand(occ_edge_pos.size(0), 12, 2).cuda()
    occ_edge_cano_order = torch.gather(
        input=occ_edge_cano_order,
        dim=-1,
        index=edge_twopoint_order
    )

    tet_edges = tet_edges.view(-1)

    # Bundle all tet-grid metadata into a single file for the dataset loaders.
    torch.save({
        'tet_v_pos': vertices,
        'tet_edge_vpos': vertices[tet_edges].view(-1, 2, 3),
        'tet_edge_pix_loc': vertices_discretized[tet_edges].view(-1, 2, 3),
        'tet_center_loc': tet_center_discretized,
        'msdf_edges': msdf_tetedges.view(12, 2),
        'occ_edge_cano_order': occ_edge_cano_order
    }, 'tet_info.pt')


================================================
FILE: GMeshDiffusion/metadata/tet_to_cubic_grid_dataset.py
================================================
import numpy as np
import torch
import os
import tqdm
import argparse


def tet_to_grids(vertices, values_list, grid_size):
    # Scatter per-vertex values into a 4-channel dense cubic grid.
    grid = torch.zeros(4, grid_size, grid_size, grid_size, device=vertices.device)
    with
torch.no_grad(): for k, values in enumerate(values_list): if k == 0: grid[k, vertices[:, 0], vertices[:, 1], vertices[:, 2]] = values.squeeze() else: grid[1:4, vertices[:, 0], vertices[:, 1], vertices[:, 2]] = values.transpose(0, 1) return grid if __name__ == "__main__": parser = argparse.ArgumentParser(description='nvdiffrec') parser.add_argument('-res', '--resolution', type=int) parser.add_argument('-ss', '--split-size', type=int, default=int(1e8)) parser.add_argument('-ind', '--index', type=int) parser.add_argument('-r', '--root', type=str) parser.add_argument('-s', '--source', type=str) parser.add_argument('-t', '--target', type=str) FLAGS = parser.parse_args() tet_path = f'./tets/{FLAGS.resolution}_tets_cropped_reordered.npz' tet = np.load(tet_path) vertices = torch.tensor(tet['vertices']).cuda() indices = torch.tensor(tet['indices']).long().cuda() edges = torch.tensor(tet['edges']).long().cuda() tet_edges = torch.tensor(tet['tet_edges']).long().view(-1, 2).cuda() vertices_unique = vertices[:].unique() dx = vertices_unique[1] - vertices_unique[0] dx = dx / 2.0 ### denser grid vertices_discretized = ( ((vertices - vertices.min()) / dx) ).long() print(vertices_discretized.size()) midpoints = (vertices_discretized[edges[:, 0]] + vertices_discretized[edges[:, 1]]) / 2.0 midpoints_dicretized = midpoints.long() tet_verts = vertices_discretized[indices.view(-1)].view(-1, 4, 3) tet_center = tet_verts.float().mean(dim=1) tet_center_discretized = tet_center.long() global_mask = torch.zeros(4, FLAGS.resolution * 2, FLAGS.resolution * 2, FLAGS.resolution * 2).cuda() cat_mask = torch.zeros(FLAGS.resolution * 2, FLAGS.resolution * 2, FLAGS.resolution * 2).cuda() global_mask[:4, vertices_discretized[:, 0], vertices_discretized[:, 1], vertices_discretized[:, 2]] += 1.0 cat_mask[vertices_discretized[:, 0], vertices_discretized[:, 1], vertices_discretized[:, 2]] = 1 global_mask[0, midpoints_dicretized[:, 0], midpoints_dicretized[:, 1], midpoints_dicretized[:, 2]] += 1.0 
cat_mask[midpoints_dicretized[:, 0], midpoints_dicretized[:, 1], midpoints_dicretized[:, 2]] = -1 torch.save(global_mask, f'global_mask_res{FLAGS.resolution}.pt') torch.save(cat_mask, f'cat_mask_res{FLAGS.resolution}.pt') save_folder = FLAGS.root grid_folder_base = os.path.join(save_folder, FLAGS.target) os.makedirs(grid_folder_base, exist_ok=True) print(grid_folder_base) edge_ind_list = [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]] msdf_tetedges = [] msdf_from_tetverts = [] for i in range(5): for j in range(i+1, 6): if (edge_ind_list[i][0] == edge_ind_list[j][0] or edge_ind_list[i][0] == edge_ind_list[j][1] or edge_ind_list[i][1] == edge_ind_list[j][0] or edge_ind_list[i][1] == edge_ind_list[j][1] ): msdf_tetedges.append(i) msdf_tetedges.append(j) msdf_from_tetverts.extend([edge_ind_list[i][0], edge_ind_list[i][1], edge_ind_list[j][0], edge_ind_list[j][1]]) msdf_tetedges = torch.tensor(msdf_tetedges) msdf_from_tetverts = torch.tensor(msdf_from_tetverts) print(msdf_tetedges) print(msdf_tetedges.size()) occgrid_mask_already_saved = False tets_folder = os.path.join(save_folder, FLAGS.source) with torch.no_grad(): for k in tqdm.trange(FLAGS.split_size): global_index = k + FLAGS.index * FLAGS.split_size tet_path = os.path.join(tets_folder, 'dmt_dict_{:05d}.pt'.format(global_index)) if os.path.exists(os.path.join(grid_folder_base, 'grid_{:05d}.pt'.format(global_index))): # continue pass try: if os.path.exists(tet_path): tet = torch.load(tet_path, map_location="cuda") sdf = tet['sdf'].view(-1, 1) msdf = tet['msdf'].view(-1, 1) deform = tet['deform'] ### resetting sdfs and offsets of all non-mesh-generating tet vertices tet_edges = tet_edges.view(-1, 2) tet_edge_mask = ((torch.sign(sdf[tet_edges[:, 0]]) - torch.sign(sdf[tet_edges[:, 1]])) != 0).bool().squeeze(-1).view(-1, 6) tet_sdf_coeff = ( torch.abs(sdf[tet_edges[:, 0]]) / (torch.abs(sdf[tet_edges[:, 0]] - sdf[tet_edges[:, 1]]) + 1e-10) ).squeeze(-1) tet_sdf_coeff = tet_sdf_coeff.view(-1, 1) midpoint_msdf_tet = 
msdf[tet_edges[:, 0]] * (1 - tet_sdf_coeff) + msdf[tet_edges[:, 1]] * tet_sdf_coeff midpoint_msdf_tet = midpoint_msdf_tet.view(-1, 6) tet_mask = ((midpoint_msdf_tet > 0) & tet_edge_mask).sum(dim=-1).bool() vert_mask = torch.zeros_like(sdf.squeeze()) vert_mask[indices[tet_mask].view(-1)] = 1.0 vert_mask = ~vert_mask.bool() msdf[vert_mask] = -1.0 deform[vert_mask] = 0.0 tet_nonallnegmsdf = (torch.sign(msdf[indices.view(-1)].view(-1, 4)).sum(dim=-1) != -4) vert_mask_nonallnegmsdf = torch.zeros_like(sdf.squeeze()) vert_mask_nonallnegmsdf[indices[tet_nonallnegmsdf].view(-1)] = 1.0 vert_mask_nonallnegmsdf = ~vert_mask_nonallnegmsdf.bool() sdf[vert_mask_nonallnegmsdf] = 1.0 mask = ( (torch.sign(sdf[edges[:, 0]]) - torch.sign(sdf[edges[:, 1]]) != 0).bool() ) nan_mask = ( ((torch.sign(sdf[edges[:, 0]]) + torch.sign(sdf[edges[:, 1]])) == 2) | ((torch.sign(sdf[edges[:, 0]]) + torch.sign(sdf[edges[:, 1]])) == -2) ).bool().squeeze(-1) original_sdf_coeff = torch.abs(sdf[edges[:, 0]]) / (torch.abs(sdf[edges[:, 0]] - sdf[edges[:, 1]]) + 1e-10) original_sdf_coeff[nan_mask] = torch.nan normalized_sdf_coeff = ((original_sdf_coeff - 0.5) * 2.0) normalized_sdf_coeff = torch.nan_to_num(normalized_sdf_coeff) assert torch.all(normalized_sdf_coeff.abs() <= 1.0) sdf_sign = torch.sign(sdf) sdf_sign[sdf_sign == 0] = 1 midpoint_msdf = msdf[edges[:, 0]] * (1 - original_sdf_coeff.view(-1, 1)) + msdf[edges[:, 1]] * original_sdf_coeff.view(-1, 1) midpoint_msdf_sign = torch.sign(midpoint_msdf) midpoint_msdf_sign[midpoint_msdf_sign == 0] = -1 midpoint_msdf_sign = midpoint_msdf_sign * mask - (1.0 - mask.float()) ############################ Occ Grid ############################ tet_edges = tet_edges.view(-1, 2) tet_edge_mask = ((torch.sign(sdf[tet_edges[:, 0]]) - torch.sign(sdf[tet_edges[:, 1]])) != 0).bool().squeeze(-1).view(-1, 6) tet_sdf_coeff = ( torch.abs(sdf[tet_edges[:, 0]]) / (torch.abs(sdf[tet_edges[:, 0]] - sdf[tet_edges[:, 1]]) + 1e-10) ).squeeze(-1) tet_sdf_coeff = tet_sdf_coeff * 
tet_edge_mask.view(-1) tet_sdf_coeff = tet_sdf_coeff.view(-1, 1) nan_mask = ( ((torch.sign(sdf[tet_edges[:, 0]]) + torch.sign(sdf[tet_edges[:, 1]])) == 2) | ((torch.sign(sdf[tet_edges[:, 0]]) + torch.sign(sdf[tet_edges[:, 1]])) == -2) ).bool().squeeze(-1) tet_sdf_coeff[nan_mask] = torch.nan midpoint_msdf_tet = msdf[tet_edges[:, 0]] * (1 - tet_sdf_coeff) + msdf[tet_edges[:, 1]] * tet_sdf_coeff midpoint_msdf_tet = midpoint_msdf_tet.view(-1, 6) inscribed_edge_twopoint_msdf = midpoint_msdf_tet[:, msdf_tetedges.view(-1)].view(-1, 12, 2) assert (( (tet_edges.view(-1, 6, 2)[:, msdf_tetedges.view(-1), :].view(-1, 24, 2).sum(dim=-1)) - indices[:, msdf_from_tetverts].view(-1, 24, 2).sum(dim=-1) ).sum().item() == 0) assert msdf_tetedges.view(-1).size(0) == 24 inscribed_tet_fourpoint_pos = vertices_discretized[indices[:, msdf_from_tetverts].view(-1)].view(-1, 12, 4, 3).to(torch.float64) inscribed_edge_twopoint_pos = inscribed_tet_fourpoint_pos.view(-1, 12, 2, 2, 3).mean(dim=-2) occgrid_loc = inscribed_edge_twopoint_pos.mean(dim=-2) occgrid_loc = (occgrid_loc * 2).to(torch.int64).view(-1, 3) edge_twopoint_order = torch.sign(inscribed_edge_twopoint_pos[:, :, 0, :] - inscribed_edge_twopoint_pos[:, :, 1, :]) edge_twopoint_order_binary_code = (edge_twopoint_order * torch.tensor([16, 4, 1], device=edge_twopoint_order.device).view(1, 1, -1)).sum(dim=-1) edge_twopoint_order_binary_code = torch.stack([edge_twopoint_order_binary_code, -edge_twopoint_order_binary_code], dim=-1) _, edge_twopoint_order = edge_twopoint_order_binary_code.sort(dim=-1) inscribed_edge_twopoint_msdf = torch.gather( input=inscribed_edge_twopoint_msdf, dim=-1, index=edge_twopoint_order ) mask_msdf = ( ((inscribed_edge_twopoint_msdf[:, :, 0] > 0) & (inscribed_edge_twopoint_msdf[:, :, 1] <= 0)) | ((inscribed_edge_twopoint_msdf[:, :, 0] <= 0) & (inscribed_edge_twopoint_msdf[:, :, 1] > 0)) ) msdf_coeff_12 = ( torch.abs(inscribed_edge_twopoint_msdf[:, :, 0]) / ( torch.abs(inscribed_edge_twopoint_msdf[:, :, 0] - 
inscribed_edge_twopoint_msdf[:, :, 1]) + 1e-10 ) ) msdf_coeff_12 = (msdf_coeff_12 - 0.5) * 2.0 * mask_msdf msdf_coeff_12 = torch.nan_to_num(msdf_coeff_12) occ_grid = torch.zeros(256, 256, 256, dtype=torch.float, device=msdf_coeff_12.device) occ_grid[occgrid_loc[:, 0], occgrid_loc[:, 1], occgrid_loc[:, 2]] = msdf_coeff_12.view(-1).to(torch.float) if not occgrid_mask_already_saved: occ_grid_mask = torch.zeros(256, 256, 256, dtype=torch.float, device=msdf_coeff_12.device) occ_grid_mask[occgrid_loc[:, 0], occgrid_loc[:, 1], occgrid_loc[:, 2]] = 1 torch.save(occ_grid_mask, f'occ_mask_res{FLAGS.resolution}.pt') occgrid_mask_already_saved = True # ################# torch.cuda.empty_cache() grid = torch.zeros(4, FLAGS.resolution * 2, FLAGS.resolution * 2, FLAGS.resolution * 2).cuda() grid[0, vertices_discretized[:, 0], vertices_discretized[:, 1], vertices_discretized[:, 2]] = sdf_sign.squeeze() grid[1:4, vertices_discretized[:, 0], vertices_discretized[:, 1], vertices_discretized[:, 2]] = deform.transpose(0, 1) grid[0, midpoints_dicretized[:, 0], midpoints_dicretized[:, 1], midpoints_dicretized[:, 2]] = midpoint_msdf_sign.squeeze() assert grid.abs().max() <= 1 save_path = os.path.join(grid_folder_base, 'grid_{:05d}.pt'.format(global_index)) torch.save(grid, save_path) save_path = os.path.join(grid_folder_base, 'occgrid_{:05d}.pt'.format(global_index)) torch.save(occ_grid, save_path) except: raise ================================================ FILE: GMeshDiffusion/scripts/run_eval_lower_occgrid_normalized.sh ================================================ python main_diffusion.py --mode uncond_gen --config diffusion_configs/config_lower_occgrid_normalized.py \ --config.eval.eval_dir=$EVAL_DIR \ --config.data.root_dir=$REPO_ROOT_DIR \ --config.sampling.method=ddim \ --config.eval.ckpt_path=$CKPT_PATH \ --config.eval.bin_size=30 \ --config.eval.idx $1 ================================================ FILE: GMeshDiffusion/scripts/run_eval_upper_occgrid_normalized.sh 
================================================ python main_diffusion.py --mode uncond_gen --config diffusion_configs/config_upper_occgrid_normalized.py \ --config.eval.eval_dir=$EVAL_DIR \ --config.data.root_dir=$REPO_ROOT_DIR \ --config.sampling.method=ddim \ --config.eval.ckpt_path=$CKPT_PATH \ --config.eval.bin_size=10 \ --config.eval.idx $1 ================================================ FILE: GMeshDiffusion/scripts/run_lower_occgrid_normalized_ddp.sh ================================================ torchrun --nnodes=1 --nproc_per_node=8 main_diffusion_ddp.py --mode=train --config=diffusion_configs/config_lower_occgrid_normalized.py \ --config.training.train_dir=$SAVE_DIR --config.data.root_dir=$REPO_ROOT_DIR ================================================ FILE: GMeshDiffusion/scripts/run_upper_occgrid_normalized_ddp.sh ================================================ torchrun --nnodes=1 --nproc_per_node=8 main_diffusion_ddp.py --mode=train --config=diffusion_configs/config_upper_occgrid_normalized.py \ --config.training.train_dir=$SAVE_DIR --config.data.root_dir=$REPO_ROOT_DIR ================================================ FILE: README.md ================================================
# Ghost on the Shell: An Expressive Representation of General 3D Shapes
## Introduction This is the official implementation of our paper (ICLR 2024 oral) "Ghost on the Shell: An Expressive Representation of General 3D Shapes" (G-Shell). G-Shell is a generic and differentiable representation for both watertight and non-watertight meshes. It enables 1) efficient and robust rasterization-based multiview reconstruction and 2) template-free generation of non-watertight meshes. Please refer to [our project page](https://gshell3d.github.io) and [our paper](https://gshell3d.github.io/static/paper/gshell.pdf) for more details. ## Getting Started ### Requirements - Python >= 3.8 - CUDA 11.8 - PyTorch == 1.13.1 (Conda installation recommended) #### Reconstruction Run the following ``` pip install ninja imageio PyOpenGL glfw xatlas gdown pip install git+https://github.com/NVlabs/nvdiffrast/ pip install --global-option="--no-networks" git+https://github.com/NVlabs/tiny-cuda-nn#subdirectory=bindings/torch ``` Follow the instructions [here](https://github.com/NVIDIAGameWorks/kaolin/) to install kaolin. Download the tet-grid files ([res128](https://drive.google.com/file/d/1u5FzpuY_BOAg8-g9lRwvah7mbCBOfNVg/view?usp=sharing), [res256](https://drive.google.com/file/d/1JnFoPEGcTLFJ7OHSWrI72h1H9_yOxUP6/view?usp=sharing)) & [res64 for G-MeshDiffusion](https://drive.google.com/file/d/1YQuU4D-0q8kwrzEfla3hGzBg4erBhand/view?usp=drive_link) to `data/tets` folder under the root directory. Alternatively, you may follow https://github.com/crawforddoran/quartet and `data/tets/generate_tets.py` to create the tet-grid files. 
#### Generation Install the following - Pytorch3D - ml_collections ## To-dos - [x] Code for reconstruction - [ ] DeepFashion3D multiview image dataset for metallic surfaces - [x] Code for generative models - [ ] Code for DeepFashion3D dataset preparation - [ ] Evaluation code for generative models ## Reconstruction ### Datasets #### DeepFashion3D mesh dataset We provide ground-truth images (rendered under realistic environment light with Blender) for 9 instances in [DeepFashion3D-v2 dataset](https://github.com/GAP-LAB-CUHK-SZ/deepFashion3D). The download links for the raw meshes can be found in their repo. non-metallic material: [training data](https://drive.google.com/file/d/1LwBqLYzamFLyBIiNpD6kEkvySrq2nruG/view?usp=sharing), [test data](https://drive.google.com/file/d/1-47dH_yJrUzKVdKbJenslpdyHwDI6QVo/view?usp=sharing) #### NeRF synthetic dataset Download the [NeRF synthetic dataset archive](https://drive.google.com/uc?export=download&id=18JxhpWD-4ZmuFKLzKlAw-w5PpzZxXOcG) and unzip it into the `data/` folder. #### Hat dataset Download link: https://drive.google.com/file/d/18UmT1NM5wJQ-ZM-rtUXJHXkDc-ba-xVk/view?usp=sharing RGB images, segmentation masks and the corresponding camera poses are included. Alternatively, you may choose to 1) generate the camera poses with COLMAP and 2) create binary segmentation masks by yourself. ### Training #### DeepFashion3D-v2 instances The mesh instances' IDs are [30, 92, 117, 133, 164, 320, 448, 522, 591]. 
To reconstruct the `$INDEX`-th mesh (0-8) in the list using tet-based G-Shell, run ``` python train_gshelltet_deepfashion.py --config configs/deepfashion_mc_256.json --index $INDEX --trainset_path $TRAINSET_PATH --testset_path $TESTSET_PATH --o $OUTPUT_PATH ``` For FlexiCubes + G-Shell, run ``` python train_gflexicubes_deepfashion.py --config configs/deepfashion_mc_80.json --index $INDEX --trainset_path $TRAINSET_PATH --testset_path $TESTSET_PATH --o $OUTPUT_PATH ``` #### Synthetic data ``` python train_gshelltet_synthetic.py --config configs/nerf_chair.json --o $OUTPUT_PATH ``` #### Hat data ``` python train_gshelltet_polycam.py --config configs/polycam_mc_128.json --trainset_path $TRAINSET_PATH --o $OUTPUT_PATH ``` #### On config files You may consider modifying the following, depending on your needs: - `gshell_grid`: the G-Shell grid size. For tet-based G-Shell, please make sure the corresponding tet-grid file exists under `data/tets` (e.g., `256_tets.npz`). Otherwise, follow https://github.com/crawforddoran/quartet and `data/tets/generate_tets.py` to generate the desired tet-grid file. - `n_samples`: the number of MC samples for light rays per rasterized pixel. The higher the better (at a cost of memory and speed). - `batch_size`: how many views are sampled in each iteration. - `iteration`: total number of iterations. - `kd_min`, `kd_max`, etc.: the min/max of the corresponding PBR material parameter. 
## Generation ### Preparation Download info files for the underlying tet grids and the binary masks indicating which locations store useful values in the cubic grids from [tet_info.pt](https://drive.google.com/file/d/19Dw_hOpcVHazpm2_1qA7T7j5xABOUWxv/view?usp=drive_link), [global_mask_res64.pt](https://drive.google.com/file/d/1mlSnu23_u08HH5aO3x5z1V9GzguFzoiT/view?usp=drive_link), [cat_mask_res64.pt](https://drive.google.com/file/d/11Bm4CQX-y1X7R47AfQQz20s7oP6AbbNK/view?usp=drive_link) and [occ_mask_res64.pt](https://drive.google.com/file/d/1qEqqLfZe633GdVkj5MGOCON_kf0l4e4G/view?usp=drive_link). Put these files under `GMeshDiffusion/metadata/`. #### For inference Download the pretrained models for upper-body and lower-body garments [here](https://huggingface.co/lzzcd001/GMeshDiffusion-Models). #### For training 1) Download the processed Cloth3D garment dataset (for upper-body & lower-body garments) from [link](https://huggingface.co/datasets/lzzcd001/Cloth3D-GShell-Dataset). Alternatively, you may create a grid dataset for your own objects by a) normalizing your datapoints (re-centering and re-scaling the meshes), b) fitting G-Shell representations and c) turning these representations into cubic grids by running `GMeshDiffusion/metadata/tet_to_cubic_grid_dataset.py`. 2) Run `GMeshDiffusion/metadata/get_splits_lower.py` and/or `GMeshDiffusion/metadata/get_splits_upper.py` to generate lists of training and test datapoints. ### Inference 1. Modify the batch size in config files in `GMeshDiffusion/diffusion_configs/` and enter the desired directories and values (for model checkpoints, where to store generated samples, etc.) in `GMeshDiffusion/scripts`. 2. Run the eval scripts in `GMeshDiffusion/scripts`. 3. Run `eval_gmeshdiffusion_generated_samples.py` to extract triangular meshes. ### Training 1. Enter the desired directories (for model checkpoints and where to store generated samples) in `GMeshDiffusion/scripts`. 2. Modify the config files if necessary. 3. 
Run the training scripts in `GMeshDiffusion/scripts`. ## Citation If you find our work useful to your research, please consider citing: ``` @inproceedings{liu2024gshell, title={Ghost on the Shell: An Expressive Representation of General 3D Shapes}, author={Liu, Zhen and Feng, Yao and Xiu, Yuliang and Liu, Weiyang and Paull, Liam and Black, Michael J. and Sch{\"o}lkopf, Bernhard}, booktitle={The Twelfth International Conference on Learning Representations}, year={2024}, } ``` ## Acknowledgement We sincerely thank the authors of [Nvdiffrecmc](https://github.com/NVlabs/nvdiffrecmc), [FlexiCubes](https://github.com/nv-tlabs/FlexiCubes) and https://github.com/yang-song/score_sde_pytorch for sharing their code. Our repo is adapted from [MeshDiffusion](https://github.com/lzzcd001/MeshDiffusion/). ================================================ FILE: configs/deepfashion_mc.json ================================================ { "ref_mesh": "data/spot/spot.obj", "random_textures": true, "iter": 5000, "save_interval": 100, "texture_res": [ 1024, 1024 ], "train_res": [1024, 1024], "batch": 2, "learning_rate": [0.03, 0.005], "ks_min" : [0, 0.001, 0.0], "ks_max" : [0, 1.0, 1.0], "envlight": "data/irrmaps/aerodynamics_workshop_2k.hdr", "lock_pos" : false, "display": [{"latlong" : true}], "background" : "white", "denoiser": "bilateral", "n_samples" : 24, "env_scale" : 2.0, "gshell_grid" : 128, "validate" : true, "laplace_scale" : 6000, "boxscale": [1, 1, 1], "aabb": [-1, -1, -1, 1, 1, 1] } ================================================ FILE: configs/deepfashion_mc_256.json ================================================ { "ref_mesh": "data/spot/spot.obj", "random_textures": true, "iter": 5000, "save_interval": 100, "texture_res": [ 1024, 1024 ], "train_res": [1024, 1024], "batch": 2, "learning_rate": [0.03, 0.005], "ks_min" : [0, 0.001, 0.0], "ks_max" : [0, 1.0, 1.0], "envlight": "data/irrmaps/aerodynamics_workshop_2k.hdr", "lock_pos" : false, "display": [{"latlong" : true}], 
"background" : "white", "denoiser": "bilateral", "n_samples" : 24, "env_scale" : 2.0, "gshell_grid" : 256, "validate" : true, "laplace_scale" : 6000, "boxscale": [1, 1, 1], "aabb": [-1, -1, -1, 1, 1, 1] } ================================================ FILE: configs/deepfashion_mc_512.json ================================================ { "ref_mesh": "data/spot/spot.obj", "random_textures": true, "iter": 5000, "save_interval": 100, "texture_res": [ 1024, 1024 ], "train_res": [1024, 1024], "batch": 2, "learning_rate": [0.03, 0.005], "ks_min" : [0, 0.001, 0.0], "ks_max" : [0, 1.0, 1.0], "envlight": "data/irrmaps/aerodynamics_workshop_2k.hdr", "lock_pos" : false, "display": [{"latlong" : true}], "background" : "white", "denoiser": "bilateral", "n_samples" : 12, "env_scale" : 2.0, "gshell_grid" : 512, "validate" : true, "laplace_scale" : 6000, "boxscale": [1, 1, 1], "aabb": [-1, -1, -1, 1, 1, 1] } ================================================ FILE: configs/deepfashion_mc_80.json ================================================ { "ref_mesh": "data/spot/spot.obj", "random_textures": true, "iter": 5000, "save_interval": 100, "texture_res": [ 1024, 1024 ], "train_res": [1024, 1024], "batch": 2, "learning_rate": [0.03, 0.005], "ks_min" : [0, 0.001, 0.0], "ks_max" : [0, 1.0, 1.0], "envlight": "data/irrmaps/aerodynamics_workshop_2k.hdr", "lock_pos" : false, "display": [{"latlong" : true}], "background" : "white", "denoiser": "bilateral", "n_samples" : 24, "env_scale" : 2.0, "gshell_grid" : 80, "validate" : true, "laplace_scale" : 6000, "boxscale": [1, 1, 1], "aabb": [-1, -1, -1, 1, 1, 1] } ================================================ FILE: configs/nerf_chair.json ================================================ { "ref_mesh": "data/nerf_synthetic/chair", "random_textures": true, "iter": 5000, "save_interval": 100, "texture_res": [ 1024, 1024 ], "train_res": [800, 800], "batch": 2, "learning_rate": [0.03, 0.005], "gshell_grid" : 128, "mesh_scale" : 
2.1, "validate" : true, "n_samples" : 8, "denoiser" : "bilateral", "display": [{"latlong" : true}, {"bsdf" : "kd"}, {"bsdf" : "ks"}, {"bsdf" : "normal"}], "background" : "white", "boxscale": [1, 1, 1], "aabb": [-1, -1, -1, 1, 1, 1] } ================================================ FILE: configs/polycam_mc.json ================================================ { "ref_mesh": "data/spot/spot.obj", "random_textures": true, "iter": 5000, "save_interval": 100, "texture_res": [ 1024, 1024 ], "train_res": [768, 1024], "batch": 2, "learning_rate": [0.03, 0.005], "ks_min" : [0, 0.001, 0.0], "ks_max" : [0, 1.0, 1.0], "envlight": "data/irrmaps/aerodynamics_workshop_2k.hdr", "lock_pos" : false, "display": [{"latlong" : true}], "background" : "white", "denoiser": "bilateral", "n_samples" : 8, "env_scale" : 2.0, "gshell_grid" : 256, "validate" : true, "laplace_scale" : 6000, "boxscale": [1, 1, 1], "aabb": [-1, -1, -1, 1, 1, 1] } ================================================ FILE: configs/polycam_mc_128.json ================================================ { "ref_mesh": "data/spot/spot.obj", "random_textures": true, "iter": 5000, "save_interval": 100, "texture_res": [ 1024, 1024 ], "train_res": [768, 1024], "batch": 2, "learning_rate": [0.03, 0.005], "ks_min" : [0, 0.001, 0.0], "ks_max" : [0, 1.0, 1.0], "envlight": "data/irrmaps/aerodynamics_workshop_2k.hdr", "lock_pos" : false, "display": [{"latlong" : true}], "background" : "white", "denoiser": "bilateral", "n_samples" : 8, "env_scale" : 2.0, "gshell_grid" : 128, "validate" : true, "laplace_scale" : 6000, "boxscale": [1, 1, 1], "aabb": [-1, -1, -1, 1, 1, 1] } ================================================ FILE: configs/polycam_mc_16samples.json ================================================ { "ref_mesh": "data/spot/spot.obj", "random_textures": true, "iter": 5000, "save_interval": 100, "texture_res": [ 1024, 1024 ], "train_res": [768, 1024], "batch": 2, "learning_rate": [0.03, 0.005], "ks_min" : [0, 0.001, 0.0], "ks_max" 
: [0, 1.0, 1.0], "envlight": "data/irrmaps/aerodynamics_workshop_2k.hdr", "lock_pos" : false, "display": [{"latlong" : true}], "background" : "white", "denoiser": "bilateral", "n_samples" : 16, "env_scale" : 2.0, "gshell_grid" : 256, "validate" : true, "laplace_scale" : 6000, "boxscale": [1, 1, 1], "aabb": [-1, -1, -1, 1, 1, 1] } ================================================ FILE: data/tets/generate_tets.py ================================================ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual # property and proprietary rights in and to this material, related # documentation and any modifications thereto. Any use, reproduction, # disclosure or distribution of this material and related documentation # without an express license agreement from NVIDIA CORPORATION or # its affiliates is strictly prohibited. import os import numpy as np ''' This code segment shows how to use Quartet: https://github.com/crawforddoran/quartet, to generate a tet grid 1) Download, compile and run Quartet as described in the link above. 
Example usage `quartet meshes/cube.obj 0.5 cube_5.tet` 2) Run the function below to generate a file `cube_32_tet.tet` ''' def generate_tetrahedron_grid_file(res=32, root='..'): frac = 1.0 / res command = 'cd %s/quartet; ' % (root) + \ './quartet meshes/cube.obj %f meshes/cube_%f_tet.tet -s meshes/cube_boundary_%f.obj' % (frac, res, res) os.system(command) ''' This code segment shows how to convert from a quartet .tet file to compressed npz file ''' def convert_from_quartet_to_npz(quartetfile = 'cube_32_tet.tet', npzfile = '32_tets'): file1 = open(quartetfile, 'r') header = file1.readline() numvertices = int(header.split(" ")[1]) numtets = int(header.split(" ")[2]) print(numvertices, numtets) # load vertices vertices = np.loadtxt(quartetfile, skiprows=1, max_rows=numvertices) print(vertices.shape) # load indices indices = np.loadtxt(quartetfile, dtype=int, skiprows=1+numvertices, max_rows=numtets) print(indices.shape) np.savez_compressed(npzfile, vertices=vertices, indices=indices) ================================================ FILE: dataset/__init__.py ================================================ # Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. from .dataset import Dataset from .dataset_mesh import DatasetMesh from .dataset_nerf import DatasetNERF from .dataset_llff import DatasetLLFF ================================================ FILE: dataset/dataset.py ================================================ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual # property and proprietary rights in and to this material, related # documentation and any modifications thereto. Any use, reproduction, # disclosure or distribution of this material and related documentation # without an express license agreement from NVIDIA CORPORATION or # its affiliates is strictly prohibited. import torch class Dataset(torch.utils.data.Dataset): """Basic dataset interface""" def __init__(self): super().__init__() def __len__(self): raise NotImplementedError def __getitem__(self): raise NotImplementedError def collate(self, batch): iter_res, iter_spp = batch[0]['resolution'], batch[0]['spp'] return { 'mv' : torch.cat(list([item['mv'] for item in batch]), dim=0), 'mvp' : torch.cat(list([item['mvp'] for item in batch]), dim=0), 'campos' : torch.cat(list([item['campos'] for item in batch]), dim=0), 'resolution' : iter_res, 'spp' : iter_spp, 'img' : torch.cat(list([item['img'] for item in batch]), dim=0) if 'img' in batch[0] else None, 'img_second' : torch.cat(list([item['img_second'] for item in batch]), dim=0) if 'img_second' in batch[0] else None, 'invdepth' : torch.cat(list([item['invdepth'] for item in batch]), dim=0)if 'invdepth' in batch[0] else None, 'invdepth_second' : torch.cat(list([item['invdepth_second'] for item in batch]), dim=0) if 'invdepth_second' in batch[0] else None, 'envlight_transform': torch.cat(list([item['envlight_transform'] for item in batch]), dim=0) if 'envlight_transform' in batch and batch[0]['envlight_transform'] is not None else None, } ================================================ FILE: dataset/dataset_deepfashion.py ================================================ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual # property and proprietary rights in and to this material, related # documentation and any modifications thereto. 
Any use, reproduction, # disclosure or distribution of this material and related documentation # without an express license agreement from NVIDIA CORPORATION or # its affiliates is strictly prohibited. import os import glob import json import torch import numpy as np from render import util from .dataset import Dataset import cv2 as cv # This function is borrowed from IDR: https://github.com/lioryariv/idr def load_K_Rt_from_P(filename, P=None): if P is None: lines = open(filename).read().splitlines() if len(lines) == 4: lines = lines[1:] lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] P = np.asarray(lines).astype(np.float32).squeeze() out = cv.decomposeProjectionMatrix(P) K = out[0] R = out[1] t = out[2] K = K / K[2, 2] intrinsics = np.eye(4) intrinsics[:3, :3] = K pose = np.eye(4, dtype=np.float32) pose[:3, :3] = R.transpose() pose[:3, 3] = (t[:3] / t[3])[:, 0] return intrinsics, pose def _load_img(path): img = util.load_image_raw(path) if img.dtype != np.float32: # LDR image img = torch.tensor(img / 255, dtype=torch.float32) img[..., 0:3] = util.srgb_to_rgb(img[..., 0:3]) else: img = torch.tensor(img, dtype=torch.float32) return img class DatasetDeepFashion(Dataset): def __init__(self, base_dir, FLAGS, examples=None): self.FLAGS = FLAGS self.examples = examples self.base_dir = base_dir # Load config / transforms self.n_images = 72 ### hardcoded self.fovy = np.deg2rad(60) self.proj_mtx = util.perspective( self.fovy, self.FLAGS.display_res[1] / self.FLAGS.display_res[0], self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1] ) camera_dict = np.load(os.path.join(self.base_dir, 'cameras_sphere.npz')) # world_mat is a projection matrix from world to image self.world_mats_np = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)] self.scale_mats_np = [] # scale_mat: used for coordinate normalization, we assume the scene to render is inside a unit sphere at origin. 
self.scale_mats_np = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)] self.intrinsics_all = [] self.pose_all = [] for scale_mat, world_mat in zip(self.scale_mats_np, self.world_mats_np): P = world_mat @ scale_mat P = P[:3, :4] intrinsics, pose = load_K_Rt_from_P(None, P) self.intrinsics_all.append(torch.from_numpy(intrinsics).float()) self.pose_all.append(torch.from_numpy(pose).float()) # Determine resolution & aspect ratio self.resolution = _load_img(os.path.join(self.base_dir, '{:03d}.png'.format(0))).shape[0:2] self.aspect = self.resolution[1] / self.resolution[0] if self.FLAGS.local_rank == 0: print("DatasetNERF: %d images with shape [%d, %d]" % (self.n_images, self.resolution[0], self.resolution[1])) def _parse_frame(self, idx): # Load image data and modelview matrix img = _load_img(os.path.join(self.base_dir, '{:03d}.png'.format(idx))) img[:,:,:3] = img[:,:,:3] * img[:,:,3:] img[:,:,3] = torch.sign(img[:,:,3]) assert img.size(-1) == 4 flip_mat = torch.tensor([ [ 1, 0, 0, 0], [ 0, -1, 0, 0], [ 0, 0, -1, 0], [ 0, 0, 0, 1] ], dtype=torch.float) mv = flip_mat @ torch.linalg.inv(self.pose_all[idx]) campos = torch.linalg.inv(mv)[:3, 3] mvp = self.proj_mtx @ mv return img[None, ...].cuda(), mv[None, ...].cuda(), mvp[None, ...].cuda(), campos[None, ...].cuda() # Add batch dimension def __len__(self): return self.n_images if self.examples is None else self.examples def __getitem__(self, itr): iter_res = self.FLAGS.train_res img = [] img, mv, mvp, campos = self._parse_frame(itr % self.n_images) return { 'mv' : mv, 'mvp' : mvp, 'campos' : campos, 'resolution' : iter_res, 'spp' : self.FLAGS.spp, 'img' : img } ================================================ FILE: dataset/dataset_deepfashion_testset.py ================================================ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual # property and proprietary rights in and to this material, related # documentation and any modifications thereto. Any use, reproduction, # disclosure or distribution of this material and related documentation # without an express license agreement from NVIDIA CORPORATION or # its affiliates is strictly prohibited. import os import glob import json import torch import numpy as np from render import util from .dataset import Dataset import cv2 as cv # This function is borrowed from IDR: https://github.com/lioryariv/idr def load_K_Rt_from_P(filename, P=None): if P is None: lines = open(filename).read().splitlines() if len(lines) == 4: lines = lines[1:] lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] P = np.asarray(lines).astype(np.float32).squeeze() out = cv.decomposeProjectionMatrix(P) K = out[0] R = out[1] t = out[2] K = K / K[2, 2] intrinsics = np.eye(4) intrinsics[:3, :3] = K pose = np.eye(4, dtype=np.float32) pose[:3, :3] = R.transpose() pose[:3, 3] = (t[:3] / t[3])[:, 0] return intrinsics, pose def _load_img(path): img = util.load_image_raw(path) if img.dtype != np.float32: # LDR image img = torch.tensor(img / 255, dtype=torch.float32) img[..., 0:3] = util.srgb_to_rgb(img[..., 0:3]) else: img = torch.tensor(img, dtype=torch.float32) return img def _load_mask(path): img = util.load_image_raw(path) if img.dtype != np.float32: # LDR image img = torch.tensor(img / 255, dtype=torch.float32) else: img = torch.tensor(img, dtype=torch.float32) return img class DatasetDeepFashionTestset(Dataset): def __init__(self, base_dir, FLAGS, examples=None): self.FLAGS = FLAGS self.examples = examples self.base_dir = base_dir # Load config / transforms self.n_images = 200 ### hardcoded proj_mtx_all = np.load(os.path.join(self.base_dir, 'proj_mtx_all.npy')) self.intrinsics_all = [] self.pose_all = [] self.fovy = np.deg2rad(60) self.proj_mtx = util.perspective( self.fovy, 
self.FLAGS.display_res[1] / self.FLAGS.display_res[0], self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1] ) for i in range(proj_mtx_all.shape[0]): P = proj_mtx_all[i] P = P[:3, :4] intrinsics, pose = load_K_Rt_from_P(None, P) self.intrinsics_all.append(torch.from_numpy(intrinsics).float()) self.pose_all.append(torch.from_numpy(pose).float()) # Determine resolution & aspect ratio self.resolution = _load_img(os.path.join(self.base_dir, '{:03d}.png'.format(0))).shape[0:2] self.aspect = self.resolution[1] / self.resolution[0] if self.FLAGS.local_rank == 0: print("DatasetNERF: %d images with shape [%d, %d]" % (self.n_images, self.resolution[0], self.resolution[1])) def _parse_frame(self, idx): # Load image data and modelview matrix img = _load_img(os.path.join(self.base_dir, '{:03d}.png'.format(idx))) assert img.size(-1) == 4 flip_mat = torch.tensor([ [ 1, 0, 0, 0], [ 0, -1, 0, 0], [ 0, 0, -1, 0], [ 0, 0, 0, 1] ], dtype=torch.float) mv = flip_mat @ torch.linalg.inv(self.pose_all[idx]) campos = torch.linalg.inv(mv)[:3, 3] mvp = self.proj_mtx @ mv return img[None, ...].cuda(), mv[None, ...].cuda(), mvp[None, ...].cuda(), campos[None, ...].cuda() # Add batch dimension def __len__(self): return self.n_images if self.examples is None else self.examples def __getitem__(self, itr): iter_res = self.FLAGS.train_res img = [] img, mv, mvp, campos = self._parse_frame(itr % self.n_images) return { 'mv' : mv, 'mvp' : mvp, 'campos' : campos, 'resolution' : iter_res, 'spp' : self.FLAGS.spp, 'img' : img } ================================================ FILE: dataset/dataset_llff.py ================================================ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual # property and proprietary rights in and to this material, related # documentation and any modifications thereto. 
Any use, reproduction, # disclosure or distribution of this material and related documentation # without an express license agreement from NVIDIA CORPORATION or # its affiliates is strictly prohibited. import os import glob import torch import numpy as np from render import util from .dataset import Dataset def _load_mask(fn): img = torch.tensor(util.load_image(fn), dtype=torch.float32) if len(img.shape) == 2: img = img[..., None].repeat(1, 1, 3) return img def _load_img(fn): img = util.load_image_raw(fn) if img.dtype != np.float32: # LDR image img = torch.tensor(img / 255, dtype=torch.float32) img[..., 0:3] = util.srgb_to_rgb(img[..., 0:3]) else: img = torch.tensor(img, dtype=torch.float32) return img ############################################################################### # LLFF datasets (real world camera lightfields) ############################################################################### class DatasetLLFF(Dataset): def __init__(self, base_dir, FLAGS, examples=None): self.FLAGS = FLAGS self.base_dir = base_dir self.examples = examples # Enumerate all image files and get resolution all_img = [f for f in sorted(glob.glob(os.path.join(self.base_dir, "images", "*"))) if f.lower().endswith('png') or f.lower().endswith('jpg') or f.lower().endswith('jpeg')] self.resolution = _load_img(all_img[0]).shape[0:2] # Load camera poses poses_bounds = np.load(os.path.join(self.base_dir, 'poses_bounds.npy')) poses = poses_bounds[:, :-2].reshape([-1, 3, 5]).transpose([1,2,0]) poses = np.concatenate([poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1) # Taken from nerf, swizzles from LLFF to expected coordinate system poses = np.moveaxis(poses, -1, 0).astype(np.float32) lcol = np.array([0,0,0,1], dtype=np.float32)[None, None, :].repeat(poses.shape[0], 0) self.imvs = torch.tensor(np.concatenate((poses[:, :, 0:4], lcol), axis=1), dtype=torch.float32) self.aspect = self.resolution[1] / self.resolution[0] # width / height self.fovy = 
util.focal_length_to_fovy(poses[:, 2, 4], poses[:, 0, 4]) # Recenter scene so lookat position is origin center = util.lines_focal(self.imvs[..., :3, 3], -self.imvs[..., :3, 2]) self.imvs[..., :3, 3] = self.imvs[..., :3, 3] - center[None, ...] if self.FLAGS.local_rank == 0: print("DatasetLLFF: %d images with shape [%d, %d]" % (len(all_img), self.resolution[0], self.resolution[1])) print("DatasetLLFF: auto-centering at %s" % (center.cpu().numpy())) # Pre-load from disc to avoid slow png parsing if self.FLAGS.pre_load: self.preloaded_data = [] for i in range(self.imvs.shape[0]): self.preloaded_data += [self._parse_frame(i)] def _parse_frame(self, idx): all_img = [f for f in sorted(glob.glob(os.path.join(self.base_dir, "images", "*"))) if f.lower().endswith('png') or f.lower().endswith('jpg') or f.lower().endswith('jpeg')] all_mask = [f for f in sorted(glob.glob(os.path.join(self.base_dir, "masks", "*"))) if f.lower().endswith('png') or f.lower().endswith('jpg') or f.lower().endswith('jpeg')] assert len(all_img) == self.imvs.shape[0] and len(all_mask) == self.imvs.shape[0] # Load image+mask data img = _load_img(all_img[idx]) mask = _load_mask(all_mask[idx]) img = torch.cat((img, mask[..., 0:1]), dim=-1) # Setup transforms proj = util.perspective(self.fovy[idx, ...], self.aspect, self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1]) mv = torch.linalg.inv(self.imvs[idx, ...]) campos = torch.linalg.inv(mv)[:3, 3] mvp = proj @ mv return img[None, ...], mv[None, ...], mvp[None, ...], campos[None, ...] 
# Add batch dimension def __len__(self): return self.imvs.shape[0] if self.examples is None else self.examples def __getitem__(self, itr): if self.FLAGS.pre_load: img, mv, mvp, campos = self.preloaded_data[itr % self.imvs.shape[0]] else: img, mv, mvp, campos = self._parse_frame(itr % self.imvs.shape[0]) return { 'mv' : mv, 'mvp' : mvp, 'campos' : campos, 'resolution' : self.resolution, 'spp' : self.FLAGS.spp, 'img' : img } ================================================ FILE: dataset/dataset_mesh.py ================================================ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual # property and proprietary rights in and to this material, related # documentation and any modifications thereto. Any use, reproduction, # disclosure or distribution of this material and related documentation # without an express license agreement from NVIDIA CORPORATION or # its affiliates is strictly prohibited. 
import numpy as np import torch from render import util from render import mesh from render import render from render import light from .dataset import Dataset ############################################################################### # Reference dataset using mesh & rendering ############################################################################### class DatasetMesh(Dataset): def __init__(self, ref_mesh, glctx, cam_radius, FLAGS, validate=False, mesh_center=None): # Init self.glctx = glctx self.cam_radius = cam_radius self.FLAGS = FLAGS self.validate = validate self.fovy = np.deg2rad(45) self.aspect = FLAGS.train_res[1] / FLAGS.train_res[0] self.random_lgt = FLAGS.random_lgt self.camera_lgt = False self.mesh_center = mesh_center if self.FLAGS.local_rank == 0: print("DatasetMesh: ref mesh has %d triangles and %d vertices" % (ref_mesh.t_pos_idx.shape[0], ref_mesh.v_pos.shape[0])) # Sanity test training texture resolution ref_texture_res = np.maximum(ref_mesh.material['kd'].getRes(), ref_mesh.material['ks'].getRes()) if 'normal' in ref_mesh.material: ref_texture_res = np.maximum(ref_texture_res, ref_mesh.material['normal'].getRes()) if self.FLAGS.local_rank == 0 and FLAGS.texture_res[0] < ref_texture_res[0] or FLAGS.texture_res[1] < ref_texture_res[1]: print("---> WARNING: Picked a texture resolution lower than the reference mesh [%d, %d] < [%d, %d]" % (FLAGS.texture_res[0], FLAGS.texture_res[1], ref_texture_res[0], ref_texture_res[1])) # Load environment map texture self.envlight = light.load_env(FLAGS.envmap, scale=FLAGS.env_scale) self.ref_mesh = mesh.compute_tangents(ref_mesh) def _rotate_scene(self, itr): proj_mtx = util.perspective(self.fovy, self.FLAGS.display_res[1] / self.FLAGS.display_res[0], self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1]) # Smooth rotation for display. 
ang = (itr / 50) * np.pi * 2 mv = util.translate(0, 0, -self.cam_radius) @ (util.rotate_x(-0.4) @ util.rotate_y(ang)) mvp = proj_mtx @ mv campos = torch.linalg.inv(mv)[:3, 3] return mv[None, ...].cuda(), mvp[None, ...].cuda(), campos[None, ...].cuda(), self.FLAGS.display_res, self.FLAGS.spp def _random_scene(self): # ============================================================================================== # Setup projection matrix # ============================================================================================== iter_res = self.FLAGS.train_res proj_mtx = util.perspective(self.fovy, iter_res[1] / iter_res[0], self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1]) # ============================================================================================== # Random camera & light position # ============================================================================================== # Random rotation/translation matrix for optimization. if self.mesh_center is not None: mv = ( util.translate(-self.mesh_center[0], -self.mesh_center[1], -self.mesh_center[2]-self.cam_radius) @ util.random_rotation_translation(0.25) ) else: mv = util.translate(0, 0, -self.cam_radius) @ util.random_rotation_translation(0.25) mvp = proj_mtx @ mv campos = torch.linalg.inv(mv)[:3, 3] return mv[None, ...].cuda(), mvp[None, ...].cuda(), campos[None, ...].cuda(), iter_res, self.FLAGS.spp # Add batch dimension def __len__(self): return 50 if self.validate else (self.FLAGS.iter + 1) * self.FLAGS.batch def __getitem__(self, itr): # ============================================================================================== # Randomize scene parameters # ============================================================================================== if self.validate: mv, mvp, campos, iter_res, iter_spp = self._rotate_scene(itr) camera_mv = None else: mv, mvp, campos, iter_res, iter_spp = self._random_scene() if self.random_lgt: rnd_rot = util.random_rotation() camera_mv = 
rnd_rot.unsqueeze(0).clone() elif self.camera_lgt: camera_mv = mv.clone() else: camera_mv = None with torch.no_grad(): rendered = render.render_mesh(self.glctx, self.ref_mesh, mvp, campos, self.envlight, iter_res, spp=iter_spp, num_layers=self.FLAGS.layers, msaa=True, background=None, shade_data=True) return { 'mv' : mv, 'mvp' : mvp, 'campos' : campos, 'resolution' : iter_res, 'spp' : iter_spp, 'img' : rendered['shaded'], 'img_second' : rendered['shaded_second'], 'invdepth' : rendered['invdepth'], 'invdepth_second' : rendered['invdepth_second'], 'envlight_transform': camera_mv } ================================================ FILE: dataset/dataset_nerf.py ================================================ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual # property and proprietary rights in and to this material, related # documentation and any modifications thereto. Any use, reproduction, # disclosure or distribution of this material and related documentation # without an express license agreement from NVIDIA CORPORATION or # its affiliates is strictly prohibited. 
import os import glob import json import torch import numpy as np from render import util from .dataset import Dataset ############################################################################### # NERF image based dataset (synthetic) ############################################################################### def _load_img(path): files = glob.glob(path + '.*') assert len(files) > 0, "Tried to find image file for: %s, but found 0 files" % (path) img = util.load_image_raw(files[0]) if img.dtype != np.float32: # LDR image img = torch.tensor(img / 255, dtype=torch.float32) img[..., 0:3] = util.srgb_to_rgb(img[..., 0:3]) else: img = torch.tensor(img, dtype=torch.float32) return img class DatasetNERF(Dataset): def __init__(self, cfg_path, FLAGS, examples=None): self.FLAGS = FLAGS self.examples = examples self.base_dir = os.path.dirname(cfg_path) # Load config / transforms self.cfg = json.load(open(cfg_path, 'r')) self.n_images = len(self.cfg['frames']) # Determine resolution & aspect ratio self.resolution = _load_img(os.path.join(self.base_dir, self.cfg['frames'][0]['file_path'])).shape[0:2] self.aspect = self.resolution[1] / self.resolution[0] if self.FLAGS.local_rank == 0: print("DatasetNERF: %d images with shape [%d, %d]" % (self.n_images, self.resolution[0], self.resolution[1])) # Pre-load from disc to avoid slow png parsing if self.FLAGS.pre_load: self.preloaded_data = [] for i in range(self.n_images): self.preloaded_data += [self._parse_frame(self.cfg, i)] def _parse_frame(self, cfg, idx): # Config projection matrix (static, so could be precomputed) fovy = util.fovx_to_fovy(cfg['camera_angle_x'], self.aspect) proj = util.perspective(fovy, self.aspect, self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1]) # Load image data and modelview matrix img = _load_img(os.path.join(self.base_dir, cfg['frames'][idx]['file_path'])) mv = torch.linalg.inv(torch.tensor(cfg['frames'][idx]['transform_matrix'], dtype=torch.float32)) mv = mv @ util.rotate_x(-np.pi / 2) campos 
= torch.linalg.inv(mv)[:3, 3] mvp = proj @ mv return img[None, ...], mv[None, ...], mvp[None, ...], campos[None, ...] # Add batch dimension def __len__(self): return self.n_images if self.examples is None else self.examples def __getitem__(self, itr): iter_res = self.FLAGS.train_res img = [] fovy = util.fovx_to_fovy(self.cfg['camera_angle_x'], self.aspect) if self.FLAGS.pre_load: img, mv, mvp, campos = self.preloaded_data[itr % self.n_images] else: img, mv, mvp, campos = self._parse_frame(self.cfg, itr % self.n_images) return { 'mv' : mv, 'mvp' : mvp, 'campos' : campos, 'resolution' : iter_res, 'spp' : self.FLAGS.spp, 'img' : img } ================================================ FILE: dataset/dataset_nerf_colmap.py ================================================ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual # property and proprietary rights in and to this material, related # documentation and any modifications thereto. Any use, reproduction, # disclosure or distribution of this material and related documentation # without an express license agreement from NVIDIA CORPORATION or # its affiliates is strictly prohibited. 
import os import glob import json import torch import numpy as np from render import util from .dataset import Dataset ############################################################################### # NERF image based dataset (synthetic) ############################################################################### def _load_img(path): img = util.load_image_raw(path) if img.dtype != np.float32: # LDR image img = torch.tensor(img / 255, dtype=torch.float32) img[..., 0:3] = util.srgb_to_rgb(img[..., 0:3]) else: img = torch.tensor(img, dtype=torch.float32) return img class DatasetNERF(Dataset): def __init__(self, cfg_path, FLAGS, examples=None): self.FLAGS = FLAGS self.examples = examples self.base_dir = os.path.dirname(cfg_path) # Load config / transforms self.cfg = json.load(open(cfg_path, 'r')) self.n_images = len(self.cfg['frames']) # Determine resolution & aspect ratio self.resolution = _load_img(os.path.join(self.base_dir, self.cfg['frames'][0]['file_path'])).shape[0:2] self.aspect = self.resolution[1] / self.resolution[0] if self.FLAGS.local_rank == 0: print("DatasetNERF: %d images with shape [%d, %d]" % (self.n_images, self.resolution[0], self.resolution[1])) # Pre-load from disc to avoid slow png parsing if self.FLAGS.pre_load: self.preloaded_data = [] for i in range(self.n_images): self.preloaded_data += [self._parse_frame(self.cfg, i)] def _parse_frame(self, cfg, idx): # Config projection matrix (static, so could be precomputed) fovy = util.fovx_to_fovy(cfg['frames'][idx]['camera_angle_x'], self.aspect) proj = util.perspective(fovy, self.aspect, self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1]) # Load image data and modelview matrix img = _load_img(os.path.join(self.base_dir, cfg['frames'][idx]['file_path'])) mask = _load_img(os.path.join(self.base_dir, cfg['frames'][idx]['file_path']).replace('/image/', '/mask/').replace('.jpg', '.png')) img = torch.cat([img, mask[:,:,:1]], dim=-1) mv = 
torch.linalg.inv(torch.tensor(cfg['frames'][idx]['transform_matrix'], dtype=torch.float32)) mv = mv @ util.rotate_x(-np.pi / 2) campos = torch.linalg.inv(mv)[:3, 3] mvp = proj @ mv return img[None, ...], mv[None, ...], mvp[None, ...], campos[None, ...] # Add batch dimension def __len__(self): return self.n_images if self.examples is None else self.examples def __getitem__(self, itr): iter_res = self.FLAGS.train_res img = [] fovy = util.fovx_to_fovy(self.cfg['frames'][itr % self.n_images]['camera_angle_x'], self.aspect) if self.FLAGS.pre_load: img, mv, mvp, campos = self.preloaded_data[itr % self.n_images] else: img, mv, mvp, campos = self._parse_frame(self.cfg, itr % self.n_images) return { 'mv' : mv, 'mvp' : mvp, 'campos' : campos, 'resolution' : iter_res, 'spp' : self.FLAGS.spp, 'img' : img } ================================================ FILE: denoiser/denoiser.py ================================================ import os import torch import numpy as np import math from render import util if "TWOSIDED_TEXTURE" not in os.environ or os.environ["TWOSIDED_TEXTURE"] == "True": from render import optixutils as ou else: from render import optixutils_single_sided as ou ############################################################################### # Bilateral denoiser # # Loosely based on SVGF, but removing temporal components and variance stopping guides. # https://research.nvidia.com/publication/2017-07_spatiotemporal-variance-guided-filtering-real-time-reconstruction-path-traced ############################################################################### class BilateralDenoiser(torch.nn.Module): def __init__(self, influence=1.0): super(BilateralDenoiser, self).__init__() self.set_influence(influence) def set_influence(self, factor): self.sigma = max(factor * 2, 0.0001) self.variance = self.sigma**2. 
self.N = 2 * math.ceil(self.sigma * 2.5) + 1 def forward(self, input): col = input[..., 0:3] nrm = util.safe_normalize(input[..., 3:6]) # Bent normals can produce normals of length < 1 here zdz = input[..., 6:8] return ou.bilateral_denoiser(col, nrm, zdz, self.sigma) ================================================ FILE: eval_gmeshdiffusion_generated_samples.py ================================================ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual # property and proprietary rights in and to this material, related # documentation and any modifications thereto. Any use, reproduction, # disclosure or distribution of this material and related documentation # without an express license agreement from NVIDIA CORPORATION or # its affiliates is strictly prohibited. import os import argparse import json import numpy as np import torch # Import topology / geometry trainers from geometry.gshell_tets_geometry import GShellTetsGeometry from render import texture import pymeshlab from pytorch3d.io import save_obj import tqdm RADIUS = 4.0 # RADIUS = 2.5 # Enable to debug back-prop anomalies # torch.autograd.set_detect_anomaly(True) #---------------------------------------------------------------------------- # Main function. 
#----------------------------------------------------------------------------

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='nvdiffrec')
    parser.add_argument('--config', type=str, default=None, help='Config file')
    parser.add_argument('-i', '--iter', type=int, default=5000)
    parser.add_argument('-b', '--batch', type=int, default=1)
    parser.add_argument('-s', '--spp', type=int, default=1)
    parser.add_argument('-l', '--layers', type=int, default=1)
    parser.add_argument('-r', '--train-res', nargs=2, type=int, default=[512, 512])
    parser.add_argument('-dr', '--display-res', type=int, default=None)
    parser.add_argument('-tr', '--texture-res', nargs=2, type=int, default=[1024, 1024])
    parser.add_argument('-di', '--display-interval', type=int, default=0)
    parser.add_argument('-si', '--save-interval', type=int, default=1000)
    parser.add_argument('-lr', '--learning-rate', type=float, default=0.01)
    parser.add_argument('-mr', '--min-roughness', type=float, default=0.08)
    parser.add_argument('-mip', '--custom-mip', action='store_true', default=False)
    parser.add_argument('-rt', '--random-textures', action='store_true', default=False)
    parser.add_argument('-bg', '--background', default='checker', choices=['black', 'white', 'checker', 'reference'])
    parser.add_argument('--loss', default='logl1', choices=['logl1', 'logl2', 'mse', 'smape', 'relmse'])
    parser.add_argument('-o', '--out-dir', type=str, default=None)
    parser.add_argument('-rm', '--ref_mesh', type=str)
    parser.add_argument('-bm', '--base-mesh', type=str, default=None)
    parser.add_argument('--validate', type=bool, default=True)
    parser.add_argument('--grid_root', type=str)

    FLAGS = parser.parse_args()

    # Defaults for flags not exposed on the command line; a JSON --config may
    # override any of them (including adding new ones, e.g. FLAGS.aabb below).
    FLAGS.mtl_override        = None                     # Override material of model
    FLAGS.dmtet_grid          = 64                       # Resolution of initial tet grid. We provide 64 and 128 resolution grids. Other resolutions can be generated with https://github.com/crawforddoran/quartet
    FLAGS.mesh_scale          = 2.3                      # Scale of tet grid box. Adjust to cover the model
    FLAGS.env_scale           = 1.0                      # Env map intensity multiplier
    FLAGS.envmap              = None                     # HDR environment probe
    FLAGS.display             = None                     # Conf validation window/display. E.g. [{"relight" : <path to envlight>}]
    FLAGS.camera_space_light  = False                    # Fixed light in camera space. This is needed for setups like ethiopian head where the scanned object rotates on a stand.
    FLAGS.lock_light          = False                    # Disable light optimization in the second pass
    FLAGS.lock_pos            = False                    # Disable vertex position optimization in the second pass
    FLAGS.sdf_regularizer     = 0.2                      # Weight for sdf regularizer (see paper for details)
    FLAGS.laplace             = "relative"               # Mesh Laplacian ["absolute", "relative"]
    FLAGS.laplace_scale       = 10000.0                  # Weight for Laplacian regularizer. Default is relative with large weight
    FLAGS.pre_load            = True                     # Pre-load entire dataset into memory for faster training
    FLAGS.kd_min              = [ 0.0,  0.0,  0.0,  0.0] # Limits for kd
    FLAGS.kd_max              = [ 1.0,  1.0,  1.0,  1.0]
    FLAGS.ks_min              = [ 0.0, 0.08,  0.0]       # Limits for ks
    FLAGS.ks_max              = [ 1.0,  1.0,  1.0]
    FLAGS.nrm_min             = [-1.0, -1.0,  0.0]       # Limits for normal map
    FLAGS.nrm_max             = [ 1.0,  1.0,  1.0]
    FLAGS.cam_near_far        = [0.1, 1000.0]
    FLAGS.use_tanh_deform     = False
    FLAGS.use_sdf_mlp         = False
    FLAGS.force_default_mtl   = True
    FLAGS.twosided_texture    = True
    FLAGS.random_lgt          = False
    FLAGS.sphere_init         = False
    FLAGS.num_smooth_steps    = 3
    FLAGS.use_msdf_mlp        = False

    if FLAGS.config is not None:
        data = json.load(open(FLAGS.config, 'r'))
        for key in data:
            FLAGS.__dict__[key] = data[key]

    os.makedirs(FLAGS.out_dir, exist_ok=True)

    # Flat diffuse material applied to every extracted mesh (no texture optimization here).
    mtl_default_diffuse = {
        'name' : '_default_mat',
        'bsdf': 'diffuse',
        'uniform': True,
        'kd'   : texture.Texture2D(torch.tensor([0.75, 0.3, 0.6], dtype=torch.float32, device='cuda')),
        'ks'   : texture.Texture2D(torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device='cuda'))
    }

    if FLAGS.force_default_mtl:
        mtl_default = mtl_default_diffuse
    else:
        mtl_default = None

    # Map tet vertices and edge midpoints onto an integer (presumably 128^3 — see
    # sdf_coeff/msdf_sign below) grid so diffusion-generated grids can be sampled.
    tet_path = './data/tets/64_tets_cropped_reordered.npz'
    tet = np.load(tet_path)
    vertices = torch.tensor(tet['vertices'])
    edges = torch.tensor(tet['edges']).long()
    vertices_unique = vertices[:].unique()
    # Half the tet-grid spacing, so edge midpoints also land on integer coordinates.
    dx = (vertices_unique[1] - vertices_unique[0]) / 2.0

    vertices_discretized = (torch.round(
        (vertices - vertices.min()) / dx)
    ).long()

    midpoints = (vertices[edges[:, 0]] + vertices[edges[:, 1]]) / 2.0
    midpoints_dicretized = (torch.round(
        (midpoints - vertices.min()) / dx)
    ).long()

    # FLAGS.aabb must come from the JSON config; it is not set anywhere above.
    aabb = torch.tensor(FLAGS.aabb, dtype=torch.float).cuda().view(2, 3)
    # NOTE(review): dividing the AABB midpoint by 2 looks intentional for this
    # dataset's normalization, but confirm — the plain midpoint would be mean(0).
    center = aabb.mean(0, keepdim=True) / 2.0
    mesh_scale = 3.8
    mesh_scale = mesh_scale / torch.max(aabb[1] - aabb[0]).item()

    count = 0

    grid_root = FLAGS.grid_root
    geometry = GShellTetsGeometry(FLAGS.dmtet_grid, FLAGS.mesh_scale, FLAGS, tet_init_file=tet_path, extract_from_generative=True)

    with torch.no_grad():
        # Each .pt file holds a batch of generated grids; its companion
        # '*_occ.pt' holds the matching occupancy grids and is skipped here.
        for grid_name in tqdm.tqdm(sorted(list(os.listdir(grid_root)))):
            if '_occ' in grid_name:
                continue
            grid_all = torch.load(
                os.path.join(grid_root, grid_name), map_location='cuda'
            )
            occgrid_all = torch.load(
                os.path.join(grid_root, grid_name).replace('.pt', '_occ.pt'), map_location='cuda'
            )[:, 0]
            for i in tqdm.trange(grid_all.size(0), leave=False):
                mesh_path = FLAGS.out_dir
                os.makedirs(mesh_path, exist_ok=True)
                mesh_savepath = os.path.join(mesh_path, '{:06d}.obj'.format(count))
                # Resume support: skip meshes that were already extracted.
                if os.path.exists(mesh_savepath):
                    count += 1
                    continue
                grid = grid_all[i]
                occgrid = occgrid_all[i]

                # Channel 0 is the SDF; channels 1:4 are per-vertex deformations.
                sdf_sign = (
                    grid[0, vertices_discretized[:, 0], vertices_discretized[:, 1], vertices_discretized[:, 2]]
                ).cuda().float()
                geometry.deform.data[:] = (
                    grid[1:4, vertices_discretized[:, 0], vertices_discretized[:, 1], vertices_discretized[:, 2]]
                ).cuda().transpose(0, 1).float().clamp(-1, 1)

                # Constant 0.5 coefficient grid; mSDF sign is sampled at edge midpoints.
                sdf_coeff = torch.ones(128, 128, 128).float().cuda() * 0.5
                msdf_sign = torch.zeros(128, 128, 128).float().cuda()
                msdf_sign[midpoints_dicretized[:, 0], midpoints_dicretized[:, 1], midpoints_dicretized[:, 2]] = torch.sign(
                    grid[0, midpoints_dicretized[:, 0], midpoints_dicretized[:, 1], midpoints_dicretized[:, 2]].cuda()
                ).float()

                geometry.deform.data[:] = geometry.deform.data[:].clip(-1.0, 1.0)
                geometry.deform_scale = 2.0

                base_mesh = geometry.getMesh_from_augmented_grid_withocc(mtl_default, torch.sign(sdf_sign), sdf_coeff, msdf_sign, occgrid=occgrid)['imesh']

                ### rescale and translate back to align with the dataset
                base_mesh.v_pos = (base_mesh.v_pos / mesh_scale) + center

                ### save post-processed mesh
                save_obj(
                    verts=base_mesh.v_pos,
                    faces=base_mesh.t_pos_idx,
                    f=mesh_savepath
                )
                # Post-process in-place with pymeshlab: clean, remesh, smooth, remesh.
                ms = pymeshlab.MeshSet()
                ms.load_new_mesh(mesh_savepath)
                ms.meshing_remove_unreferenced_vertices()
                ms.meshing_isotropic_explicit_remeshing()
                ms.apply_coord_laplacian_smoothing(stepsmoothnum=FLAGS.num_smooth_steps, cotangentweight=True)
                # ms.apply_coord_hc_laplacian_smoothing()
                # ms.apply_coord_laplacian_smoothing(stepsmoothnum=3, cotangentweight=True) ## for smoother surface
                ms.meshing_isotropic_explicit_remeshing()
                ms.apply_filter_script()
                ms.save_current_mesh(mesh_savepath)
                count += 1



================================================
FILE: geometry/embedding.py
================================================
import torch
from torch import nn

class Embedding(nn.Module):
    def __init__(self, in_channels, N_freqs, logscale=True):
        """
        Defines a function that embeds x to (x, sin(2^k x), cos(2^k x), ...)
        in_channels: number of input channels (3 for both xyz and direction)
        """
        super(Embedding, self).__init__()
        self.N_freqs = N_freqs
        self.in_channels = in_channels
        self.funcs = [torch.sin, torch.cos]
        # Output: identity plus sin/cos at each frequency band.
        self.out_channels = in_channels*(len(self.funcs)*N_freqs+1)

        if logscale:
            self.freq_bands = 2**torch.linspace(0, N_freqs-1, N_freqs)
        else:
            self.freq_bands = torch.linspace(1, 2**(N_freqs-1), N_freqs)

    def forward(self, x):
        """
        Embeds x to (x, sin(2^k x), cos(2^k x), ...)
Different from the paper, "x" is also in the output See https://github.com/bmild/nerf/issues/12 Inputs: x: (B, self.in_channels) Outputs: out: (B, self.out_channels) """ out = [x] for freq in self.freq_bands: for func in self.funcs: out += [func(freq*x)] return torch.cat(out, -1) ================================================ FILE: geometry/flexicubes_table.py ================================================ # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. dmc_table = [ [[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, 
-1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 2, 8, 11, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 4, 5, -1, -1, -1], [2, 3, 11, -1, -1, 
-1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[5, 7, 8, 9, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 5, 7, 8, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 4, 7, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 2, 9, 10, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], 
[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[5, 7, 8, 9, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 5, 7, 9, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[4, 7, 8, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 9, 10, 11, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[4, 5, 9, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 8, 10, 11, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[5, 7, 8, 9, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], 
[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 5, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 5, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 8, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 3, 8, 9, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 9, -1, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 4, 5, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 3, 4, 5, 8, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, 
-1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[4, 5, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 2, 6, 7, 8, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 4, 5, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 2, 4, 5, 6, 7, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 2, 3, 5, 6, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, 
-1, -1, -1, -1, -1, -1]], [[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 2, 9, 10, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[2, 3, 8, 9, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[4, 6, 8, 11, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 4, 6, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 2, 9, 10, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[2, 3, 4, 6, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1]], [[0, 2, 4, 5, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[2, 3, 4, 5, 8, 10, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[5, 6, 8, 9, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 5, 6, 9, 11, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 2, 5, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[2, 3, 5, 6, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], 
[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 4, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[4, 5, 9, -1, -1, -1, -1], [1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 6, 7, 8, 10, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 4, 5, 6, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[4, 5, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 3, 5, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 5, 6, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 
8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 9, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 3, 8, 9, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 4, 7, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 3, 4, 7, 9, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 8, -1, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 2, 8, 11, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 9, -1, -1, 
-1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 2, 8, 9, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 2, 4, 7, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1]], [[1, 2, 4, 7, 9, 11, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[4, 6, 9, 10, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 2, 8, 11, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 4, 6, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 2, 4, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[6, 7, 8, 9, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 2, 6, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 6, 7, 8, 10, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 2, 6, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, 
-1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[4, 7, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 4, 7, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 2, 5, 6, 9, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[2, 3, 4, 5, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 2, 3, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[4, 7, 8, -1, -1, -1, -1], [1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, 
-1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 4, 5, 6, 7, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 5, 6, 9, 11, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[4, 5, 6, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 4, 6, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 3, 6, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 6, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 8, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 9, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 3, 8, 9, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, 
-1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 9, -1, -1, -1, -1], [4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 3, 4, 5, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 8, -1, -1, -1, -1], [4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 3, 4, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 9, -1, -1, -1, -1], [2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 2, 5, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, 
-1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 2, 4, 5, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 2, 4, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 2, 3, 4, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 2, 3, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[2, 3, 5, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 2, 3, 4, 5, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 2, 4, 5, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], 
[-1, -1, -1, -1, -1, -1, -1]], [[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[2, 3, 4, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 2, 3, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 4, 5, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], 
[-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 4, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], [[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]] ] num_vd_table = [0, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 3, 1, 2, 2, 2, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 1, 2, 3, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 2, 3, 2, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 2, 2, 2, 2, 1, 3, 4, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2, 3, 2, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 3, 2, 3, 2, 4, 2, 2, 2, 2, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0] check_table = [ [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 
0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 1, 0, 0, 194], [1, -1, 0, 0, 193], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 1, 0, 164], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, -1, 0, 161], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 1, 152], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 1, 145], [1, 0, 0, 1, 144], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, -1, 137], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 1, 0, 133], [1, 0, 1, 0, 132], [1, 1, 0, 0, 131], [1, 1, 0, 0, 130], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 
0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 1, 100], [0, 0, 0, 0, 0], [1, 0, 0, 1, 98], [0, 0, 0, 0, 0], [1, 0, 0, 1, 96], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 1, 0, 88], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, -1, 0, 82], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 1, 0, 74], [0, 0, 0, 0, 0], [1, 0, 1, 0, 72], [0, 0, 0, 0, 0], [1, 0, 0, -1, 70], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, -1, 0, 0, 67], [0, 0, 0, 0, 0], [1, -1, 0, 0, 65], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 1, 0, 0, 56], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, -1, 0, 0, 52], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 1, 0, 0, 44], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 1, 0, 0, 40], [0, 0, 0, 0, 0], [1, 0, 0, -1, 38], [1, 0, -1, 0, 37], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, -1, 0, 33], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, -1, 0, 0, 28], [0, 0, 0, 0, 0], [1, 0, -1, 0, 26], [1, 0, 0, -1, 25], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, -1, 0, 0, 20], [0, 0, 0, 0, 0], [1, 0, -1, 0, 18], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, -1, 9], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, -1, 6], [0, 0, 0, 0, 0], 
[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0] ] tet_table = [ [-1, -1, -1, -1, -1, -1], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1], [4, 4, 4, 4, 4, 4], [0, 0, 0, 0, 0, 0], [4, 0, 0, 4, 4, -1], [1, 1, 1, 1, 1, 1], [4, 4, 4, 4, 4, 4], [0, 4, 0, 4, 4, -1], [0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1], [5, 5, 5, 5, 5, 5], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2], [0, 0, 0, 0, 0, 0], [2, 0, 2, -1, 0, 2], [1, 1, 1, 1, 1, 1], [2, -1, 2, 4, 4, 2], [0, 0, 0, 0, 0, 0], [2, 0, 2, 4, 4, 2], [1, 1, 1, 1, 1, 1], [2, 4, 2, 4, 4, 2], [0, 4, 0, 4, 4, 0], [2, 0, 2, 0, 0, 2], [1, 1, 1, 1, 1, 1], [2, 5, 2, 5, 5, 2], [0, 0, 0, 0, 0, 0], [2, 0, 2, 0, 0, 2], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [0, 1, 1, -1, 0, 1], [0, 0, 0, 0, 0, 0], [2, 2, 2, 2, 2, 2], [4, 1, 1, 4, 4, 1], [0, 1, 1, 0, 0, 1], [4, 0, 0, 4, 4, 0], [2, 2, 2, 2, 2, 2], [-1, 1, 1, 4, 4, 1], [0, 1, 1, 4, 4, 1], [0, 0, 0, 0, 0, 0], [2, 2, 2, 2, 2, 2], [5, 1, 1, 5, 5, 1], [0, 1, 1, 0, 0, 1], [0, 0, 0, 0, 0, 0], [2, 2, 2, 2, 2, 2], [1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [8, 8, 8, 8, 8, 8], [1, 1, 1, 4, 4, 1], [0, 0, 0, 0, 0, 0], [4, 0, 0, 4, 4, 0], [4, 4, 4, 4, 4, 4], [1, 1, 1, 4, 4, 1], [0, 4, 0, 4, 4, 0], [0, 0, 0, 0, 0, 0], [4, 4, 4, 4, 4, 4], [1, 1, 1, 5, 5, 1], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [5, 5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6], [6, -1, 0, 6, 0, 6], [6, 0, 0, 6, 0, 6], [6, 1, 1, 6, 1, 6], [4, 4, 4, 4, 4, 4], [0, 0, 0, 0, 0, 0], [4, 0, 0, 4, 4, 4], [1, 1, 1, 1, 1, 1], [6, 4, -1, 6, 4, 6], [6, 4, 0, 6, 4, 6], [6, 0, 0, 6, 0, 6], [6, 1, 1, 6, 1, 6], [5, 5, 5, 5, 5, 5], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2], [0, 0, 0, 0, 0, 0], [2, 0, 2, 2, 0, 2], [1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2], [0, 0, 0, 0, 0, 0], [2, 0, 2, 2, 2, 2], [1, 1, 1, 1, 1, 1], [2, 4, 2, 2, 4, 2], [0, 4, 0, 4, 4, 0], [2, 0, 2, 2, 0, 2], [1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2], [0, 0, 0, 0, 0, 0], [0, 
0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1], [6, 1, 1, 6, -1, 6], [6, 1, 1, 6, 0, 6], [6, 0, 0, 6, 0, 6], [6, 2, 2, 6, 2, 6], [4, 1, 1, 4, 4, 1], [0, 1, 1, 0, 0, 1], [4, 0, 0, 4, 4, 4], [2, 2, 2, 2, 2, 2], [6, 1, 1, 6, 4, 6], [6, 1, 1, 6, 4, 6], [6, 0, 0, 6, 0, 6], [6, 2, 2, 6, 2, 6], [5, 1, 1, 5, 5, 1], [0, 1, 1, 0, 0, 1], [0, 0, 0, 0, 0, 0], [2, 2, 2, 2, 2, 2], [1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [6, 6, 6, 6, 6, 6], [1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [4, 4, 4, 4, 4, 4], [1, 1, 1, 1, 4, 1], [0, 4, 0, 4, 4, 0], [0, 0, 0, 0, 0, 0], [4, 4, 4, 4, 4, 4], [1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0], [0, 5, 0, 5, 0, 5], [5, 5, 5, 5, 5, 5], [5, 5, 5, 5, 5, 5], [0, 5, 0, 5, 0, 5], [-1, 5, 0, 5, 0, 5], [1, 5, 1, 5, 1, 5], [4, 5, -1, 5, 4, 5], [0, 5, 0, 5, 0, 5], [4, 5, 0, 5, 4, 5], [1, 5, 1, 5, 1, 5], [4, 4, 4, 4, 4, 4], [0, 4, 0, 4, 4, 4], [0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1], [6, 6, 6, 6, 6, 6], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1], [2, 5, 2, 5, -1, 5], [0, 5, 0, 5, 0, 5], [2, 5, 2, 5, 0, 5], [1, 5, 1, 5, 1, 5], [2, 5, 2, 5, 4, 5], [0, 5, 0, 5, 0, 5], [2, 5, 2, 5, 4, 5], [1, 5, 1, 5, 1, 5], [2, 4, 2, 4, 4, 2], [0, 4, 0, 4, 4, 4], [2, 0, 2, 0, 0, 2], [1, 1, 1, 1, 1, 1], [2, 6, 2, 6, 6, 2], [0, 0, 0, 0, 0, 0], [2, 0, 2, 0, 0, 2], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [0, 1, 1, 1, 0, 1], [0, 0, 0, 0, 0, 0], [2, 2, 2, 2, 2, 2], [4, 1, 1, 1, 4, 1], [0, 1, 1, 1, 0, 1], [4, 0, 0, 4, 4, 0], [2, 2, 2, 2, 2, 2], [1, 1, 1, 1, 1, 1], [0, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0], [2, 2, 2, 2, 2, 2], [1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [2, 2, 2, 2, 2, 2], [1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [5, 5, 5, 5, 5, 5], [1, 1, 1, 1, 4, 1], [0, 0, 0, 0, 0, 0], [4, 0, 0, 4, 4, 0], [4, 4, 4, 4, 4, 4], [1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [4, 4, 4, 4, 4, 4], [1, 1, 1, 1, 1, 1], [6, 0, 0, 6, 0, 6], [0, 0, 0, 0, 0, 0], [6, 6, 6, 6, 6, 6], [5, 5, 5, 5, 5, 5], [5, 5, 0, 5, 0, 5], 
[5, 5, 0, 5, 0, 5], [5, 5, 1, 5, 1, 5], [4, 4, 4, 4, 4, 4], [0, 0, 0, 0, 0, 0], [4, 4, 0, 4, 4, 4], [1, 1, 1, 1, 1, 1], [4, 4, 4, 4, 4, 4], [4, 4, 0, 4, 4, 4], [0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1], [8, 8, 8, 8, 8, 8], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2], [0, 0, 0, 0, 0, 0], [2, 2, 2, 2, 0, 2], [1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2], [0, 0, 0, 0, 0, 0], [2, 2, 2, 2, 2, 2], [1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [4, 1, 1, 4, 4, 1], [2, 2, 2, 2, 2, 2], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 0, 1], [0, 0, 0, 0, 0, 0], [2, 2, 2, 2, 2, 2], [1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [2, 4, 2, 4, 4, 2], [1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0], [2, 2, 2, 2, 2, 2], [1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [2, 2, 2, 2, 2, 2], [1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [5, 5, 5, 5, 5, 5], [1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [4, 4, 4, 4, 4, 4], [1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [4, 4, 4, 4, 4, 4], [1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [12, 12, 12, 12, 12, 12] ] gflex_num_triangles_table = [0,1,1,2,1,2,2,1] gflex_configuration_table = [ ## 000 [-1, -1, -1, -1, -1, -1], ## 001 [ 4, 2, 5, -1, -1, -1], ## 010 [ 3, 1, 4, -1, -1, -1], ## 011 [ 3, 1, 2, 3, 2, 5], ## 100 [ 0, 3, 5, -1, -1, -1], ## 101 [ 0, 3, 4, 0, 4, 2], ## 110 [ 0, 1, 4, 0, 4, 5], ## 111 [ 0, 1, 2, -1, -1, -1], ] ================================================ FILE: geometry/gshell_flexicubes.py ================================================ # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. 
# Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.

import torch

from .flexicubes_table import *

__all__ = [
    'GShellFlexiCubes'
]


class GShellFlexiCubes:
    """
    This class implements the FlexiCubes method for extracting meshes from scalar fields.
    It maintains a series of lookup tables and indices to support the mesh extraction process.
    FlexiCubes, a differentiable variant of the Dual Marching Cubes (DMC) scheme, enhances
    the geometric fidelity and mesh quality of reconstructed meshes by dynamically adjusting
    the surface representation through gradient-based optimization.

    During instantiation, the class loads DMC tables from a file and transforms them into
    PyTorch tensors on the specified device.

    Attributes:
        device (str): Specifies the computational device (default is "cuda").
        dmc_table (torch.Tensor): Dual Marching Cubes (DMC) table that encodes the edges
            associated with each dual vertex in 256 Marching Cubes (MC) configurations.
        num_vd_table (torch.Tensor): Table holding the number of dual vertices in each of
            the 256 MC configurations.
        check_table (torch.Tensor): Table resolving ambiguity in cases C16 and C19
            of the DMC configurations.
        tet_table (torch.Tensor): Lookup table used in tetrahedralizing the isosurface.
        quad_split_1 (torch.Tensor): Indices for splitting a quad into two triangles
            along one diagonal.
        quad_split_2 (torch.Tensor): Alternative indices for splitting a quad into
            two triangles along the other diagonal.
        quad_split_train (torch.Tensor): Indices for splitting a quad into four triangles
            during training by connecting all edges to their midpoints.
        cube_corners (torch.Tensor): Defines the positions of a standard unit cube's
            eight corners in 3D space, ordered starting from the origin (0,0,0),
            moving along the x-axis, then y-axis, and finally z-axis.
            Used as a blueprint for generating a voxel grid.
        cube_corners_idx (torch.Tensor): Cube corners indexed as powers of 2, used
            to retrieve the case id.
        cube_edges (torch.Tensor): Edge connections in a cube, listed in pairs.
            Used to retrieve edge vertices in DMC.
        edge_dir_table (torch.Tensor): A mapping tensor that associates edge indices with
            their corresponding axis. For instance, edge_dir_table[0] = 0 indicates that the
            first edge is oriented along the x-axis.
        dir_faces_table (torch.Tensor): A tensor that maps the corresponding axis of shared edges
            across four adjacent cubes to the shared faces of these cubes. For instance,
            dir_faces_table[0] = [5, 4] implies that for four cubes sharing an edge along
            the x-axis, the first and second cubes share faces indexed as 5 and 4, respectively.
            This tensor is only utilized during isosurface tetrahedralization.
        adj_pairs (torch.Tensor):
            A tensor containing index pairs that correspond to neighboring cubes that share the same edge.
        qef_reg_scale (float):
            The scaling factor applied to the regularization loss to prevent issues with singularity
            when solving the QEF. This parameter is only used when a 'grad_func' is specified.
        weight_scale (float):
            The scale of weights in FlexiCubes. Should be between 0 and 1.
    """

    def __init__(self, device="cuda", qef_reg_scale=1e-3, weight_scale=0.99):
        self.device = device
        # All lookup tables below come from the star-import of .flexicubes_table.
        self.dmc_table = torch.tensor(dmc_table, dtype=torch.long, device=device, requires_grad=False)
        self.num_vd_table = torch.tensor(num_vd_table, dtype=torch.long,
                                         device=device, requires_grad=False)
        self.check_table = torch.tensor(
            check_table,
            dtype=torch.long, device=device, requires_grad=False)

        self.tet_table = torch.tensor(tet_table, dtype=torch.long, device=device, requires_grad=False)
        self.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False)
        self.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False)
        self.quad_split_train = torch.tensor(
            [0, 1, 1, 2, 2, 3, 3, 0], dtype=torch.long, device=device, requires_grad=False)

        self.cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [
            1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.float, device=device)
        # NOTE(review): created on the default (CPU) device — _get_case_id moves it
        # with .to(self.device) at use time.
        self.cube_corners_idx = torch.pow(2, torch.arange(8, requires_grad=False))
        self.cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6,
                                        2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, device=device, requires_grad=False)

        self.edge_dir_table = torch.tensor([0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1],
                                           dtype=torch.long, device=device)
        self.dir_faces_table = torch.tensor([
            [[5, 4], [3, 2], [4, 5], [2, 3]],
            [[5, 4], [1, 0], [4, 5], [0, 1]],
            [[3, 2], [1, 0], [2, 3], [0, 1]]
        ], dtype=torch.long, device=device)
        self.adj_pairs = torch.tensor([0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.long, device=device)
        self.qef_reg_scale = qef_reg_scale
        self.weight_scale = weight_scale
        # GShell extension: tables used by _triangulate_msdf to re-triangulate faces
        # cut by the zero level set of the mSDF (3-bit occupancy -> 0/1/2 triangles).
        self.gflex_num_triangles_table = torch.tensor(gflex_num_triangles_table, dtype=torch.long, device=device, requires_grad=False)
        self.gflex_configuration_table = torch.tensor(gflex_configuration_table, dtype=torch.long, device=device, requires_grad=False)

    def construct_voxel_grid(self, res):
        """
        Generates a voxel grid based on the specified resolution.

        Args:
            res (int or list[int]): The resolution of the voxel grid. If an integer
                is provided, it is used for all three dimensions. If a list or tuple
                of 3 integers is provided, they define the resolution for the x,
                y, and z dimensions respectively.

        Returns:
            (torch.Tensor, torch.Tensor): Returns the vertices and the indices of the
                cube corners (index into vertices) of the constructed voxel grid.
                The vertices are centered at the origin, with the length of each
                dimension in the grid being one.
        """
        base_cube_f = torch.arange(8).to(self.device)
        if isinstance(res, int):
            res = (res, res, res)
        voxel_grid_template = torch.ones(res, device=self.device)

        res = torch.tensor([res], dtype=torch.float, device=self.device)
        coords = torch.nonzero(voxel_grid_template).float() / res  # N, 3
        # Per-cell corner positions; duplicates at shared corners are removed below.
        verts = (self.cube_corners.unsqueeze(0) / res + coords.unsqueeze(1)).reshape(-1, 3)
        cubes = (base_cube_f.unsqueeze(0) +
                 torch.arange(coords.shape[0], device=self.device).unsqueeze(1) * 8).reshape(-1)

        # Round to 5 decimals so floating-point noise does not defeat deduplication.
        verts_rounded = torch.round(verts * 10**5) / (10**5)
        verts_unique, inverse_indices = torch.unique(verts_rounded, dim=0, return_inverse=True)
        cubes = inverse_indices[cubes.reshape(-1)].reshape(-1, 8)

        return verts_unique - 0.5, cubes

    def __call__(self, x_nx3, s_n, nu_n, cube_fx8, res, beta_fx12=None, alpha_fx8=None, gamma_f=None,
                 training=False, output_tetmesh=False, grad_func=None):
        r"""
        Main function for mesh extraction from scalar field using FlexiCubes. This function converts
        discrete signed distance fields, encoded on voxel grids and additional per-cube parameters,
        to triangle or tetrahedral meshes using a differentiable operation as described in
        `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_. FlexiCubes enhances
        mesh quality and geometric fidelity by adjusting the surface representation based on gradient
        optimization. The output surface is differentiable with respect to the input vertex positions,
        scalar field values, and weight parameters.

        If you intend to extract a surface mesh from a fixed Signed Distance Field without the
        optimization of parameters, it is suggested to provide the "grad_func" which should
        return the surface gradient at any given 3D position. When grad_func is provided, the process
        to determine the dual vertex position adapts to solve a Quadratic Error Function (QEF), as
        described in the `Manifold Dual Contouring`_ paper, and employs an smart splitting strategy.
        Please note, this approach is non-differentiable.

        For more details and example usage in optimization, refer to the
        `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_ SIGGRAPH 2023 paper.

        Args:
            x_nx3 (torch.Tensor): Coordinates of the voxel grid vertices, can be deformed.
            s_n (torch.Tensor): Scalar field values at each vertex of the voxel grid. Negative values
                denote that the corresponding vertex resides inside the isosurface. This affects
                the directions of the extracted triangle faces and volume to be tetrahedralized.
            nu_n (torch.Tensor): GShell addition — a second scalar field at each grid vertex
                (mSDF), interpolated onto the extracted surface and used by _triangulate_msdf
                to cut open boundaries.  Presumably >= 0 means "keep"; TODO confirm against caller.
            cube_fx8 (torch.Tensor): Indices of 8 vertices for each cube in the voxel grid.
            res (int or list[int]): The resolution of the voxel grid. If an integer is provided, it
                is used for all three dimensions. If a list or tuple of 3 integers is provided, they
                specify the resolution for the x, y, and z dimensions respectively.
            beta_fx12 (torch.Tensor, optional): Weight parameters for the cube edges to adjust dual
                vertices positioning. Defaults to uniform value for all edges.
            alpha_fx8 (torch.Tensor, optional): Weight parameters for the cube corners to adjust dual
                vertices positioning. Defaults to uniform value for all vertices.
            gamma_f (torch.Tensor, optional): Weight parameters to control the splitting of
                quadrilaterals into triangles. Defaults to uniform value for all cubes.
            training (bool, optional): If set to True, applies differentiable quad splitting for
                training. Defaults to False.
            output_tetmesh (bool, optional): If set to True, outputs a tetrahedral mesh, otherwise,
                outputs a triangular mesh. Defaults to False.
            grad_func (callable, optional): A function to compute the surface gradient at specified
                3D positions (input: Nx3 positions). The function should return gradients as an Nx3
                tensor. If None, the original FlexiCubes algorithm is utilized. Defaults to None.

        Returns:
            (torch.Tensor, torch.LongTensor, torch.Tensor): Tuple containing:
                - Vertices for the extracted triangular/tetrahedral mesh.
                - Faces for the extracted triangular/tetrahedral mesh.
                - Regularizer L_dev, computed per dual vertex.

        .. _Flexible Isosurface Extraction for Gradient-Based Mesh Optimization:
            https://research.nvidia.com/labs/toronto-ai/flexicubes/
        .. _Manifold Dual Contouring:
            https://people.engr.tamu.edu/schaefer/research/dualsimp_tvcg.pdf
        """
        surf_cubes, occ_fx8 = self._identify_surf_cubes(s_n, cube_fx8)
        if surf_cubes.sum() == 0:
            # No surface-crossing cubes: return empty mesh tensors.
            return torch.zeros(
                (0, 3),
                device=self.device), torch.zeros(
                (0, 4),
                dtype=torch.long, device=self.device) if output_tetmesh else torch.zeros(
                (0, 3),
                dtype=torch.long, device=self.device), torch.zeros(
                (0),
                device=self.device)
        beta_fx12, alpha_fx8, gamma_f = self._normalize_weights(beta_fx12, alpha_fx8, gamma_f, surf_cubes)

        case_ids = self._get_case_id(occ_fx8, surf_cubes, res)

        surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges(s_n, cube_fx8, surf_cubes)

        vd, nu_d, nu_d_stopvgd, L_dev, vd_gamma, vd_idx_map = self._compute_vd(
            x_nx3, cube_fx8[surf_cubes], surf_edges, s_n, nu_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func)
        vertices, nus, nus_stopvgd, faces, s_edges, edge_indices = self._triangulate(
            s_n, surf_edges, vd, nu_d, nu_d_stopvgd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func)

        # GShell step: cut the watertight mesh open along the mSDF zero level set.
        vertices_open, faces_open, nus_open_stopvgd, nus_boundary_stopvgd = self._triangulate_msdf(vertices, faces, nus, nus_stopvgd)

        if not output_tetmesh:
            extra = {
                'n_verts_watertight': vertices.size(0),
                'vertices_watertight': vertices,
                'faces_watertight': faces,
                'msdf': nus_open_stopvgd,
                'msdf_watertight': nus,
                'msdf_boundary': nus_boundary_stopvgd,
            }
            # print(torch.any(torch.isnan(nus_open_stopvgd)))
            return vertices_open, faces_open, L_dev, extra
        else:
            # Tetmesh output is disabled in GShell; the code below is unreachable.
            raise NotImplementedError
            vertices, tets = self._tetrahedralize(
                x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices,
                surf_cubes, training)
            return vertices, tets, L_dev

    def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges):
        """
        Regularizer L_dev as in Equation 8
        """
        # Distance from each dual vertex to the zero-crossings of its member edges.
        dist = torch.norm(ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1)
        mean_l2 = torch.zeros_like(vd[:, 0])
        mean_l2 = (mean_l2).index_add_(0, edge_group_to_vd, dist) / vd_num_edges.squeeze(1).float()
        # Mean absolute deviation of those distances, per dual vertex.
        mad = (dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0)).abs()
        return mad

    def _normalize_weights(self, beta_fx12, alpha_fx8, gamma_f, surf_cubes):
        """
        Normalizes the given weights to be non-negative. If input weights are None, it creates
        and returns a set of weights of ones.
        """
        n_cubes = surf_cubes.shape[0]

        if beta_fx12 is not None:
            # tanh keeps beta in (1 - weight_scale, 1 + weight_scale), strictly positive.
            beta_fx12 = (torch.tanh(beta_fx12) * self.weight_scale + 1)
        else:
            beta_fx12 = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device)

        if alpha_fx8 is not None:
            alpha_fx8 = (torch.tanh(alpha_fx8) * self.weight_scale + 1)
        else:
            alpha_fx8 = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device)

        if gamma_f is not None:
            # sigmoid keeps gamma in ((1 - weight_scale)/2, (1 + weight_scale)/2).
            gamma_f = torch.sigmoid(gamma_f) * self.weight_scale + (1 - self.weight_scale)/2
        else:
            gamma_f = torch.ones((n_cubes), dtype=torch.float, device=self.device)

        # Only surface cubes participate in extraction.
        return beta_fx12[surf_cubes], alpha_fx8[surf_cubes], gamma_f[surf_cubes]

    @torch.no_grad()
    def _get_case_id(self, occ_fx8, surf_cubes, res):
        """
        Obtains the ID of topology cases based on cell corner occupancy. This function resolves the
        ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the
        supplementary material. It should be noted that this function assumes a regular grid.
        """
        case_ids = (occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0)).sum(-1)

        problem_config = self.check_table.to(self.device)[case_ids]
        to_check = problem_config[..., 0] == 1
        problem_config = problem_config[to_check]
        if not isinstance(res, (list, tuple)):
            res = [res, res, res]

        # The 'problematic_configs' only contain configurations for surface cubes. Next, we construct a 3D array,
        # 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes).
        # This allows efficient checking on adjacent cubes.
        problem_config_full = torch.zeros(list(res) + [5], device=self.device, dtype=torch.long)
        vol_idx = torch.nonzero(problem_config_full[..., 0] == 0)  # N, 3
        vol_idx_problem = vol_idx[surf_cubes][to_check]
        problem_config_full[vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2]] = problem_config
        vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4]

        within_range = (
            vol_idx_problem_adj[..., 0] >= 0) & (
            vol_idx_problem_adj[..., 0] < res[0]) & (
            vol_idx_problem_adj[..., 1] >= 0) & (
            vol_idx_problem_adj[..., 1] < res[1]) & (
            vol_idx_problem_adj[..., 2] >= 0) & (
            vol_idx_problem_adj[..., 2] < res[2])

        vol_idx_problem = vol_idx_problem[within_range]
        vol_idx_problem_adj = vol_idx_problem_adj[within_range]
        problem_config = problem_config[within_range]
        problem_config_adj = problem_config_full[vol_idx_problem_adj[..., 0],
                                                 vol_idx_problem_adj[..., 1], vol_idx_problem_adj[..., 2]]
        # If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted.
        to_invert = (problem_config_adj[..., 0] == 1)
        idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][within_range][to_invert]
        case_ids.index_put_((idx,), problem_config[to_invert][..., -1])
        return case_ids

    @torch.no_grad()
    def _identify_surf_edges(self, s_n, cube_fx8, surf_cubes):
        """
        Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge
        can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge
        and marks the cube edges with this index.
        """
        occ_n = s_n < 0
        all_edges = cube_fx8[surf_cubes][:, self.cube_edges].reshape(-1, 2)
        unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True)

        unique_edges = unique_edges.long()
        # An edge crosses the surface iff exactly one endpoint is inside.
        mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1

        surf_edges_mask = mask_edges[_idx_map]
        counts = counts[_idx_map]

        mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=cube_fx8.device) * -1
        mapping[mask_edges] = torch.arange(mask_edges.sum(), device=cube_fx8.device)
        # Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index
        # for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1.
        idx_map = mapping[_idx_map]
        surf_edges = unique_edges[mask_edges]
        return surf_edges, idx_map, counts, surf_edges_mask

    @torch.no_grad()
    def _identify_surf_cubes(self, s_n, cube_fx8):
        """
        Identifies grid cubes that intersect with the underlying surface by checking if the signs at
        all corners are not identical.
        """
        occ_n = s_n < 0
        occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8)
        _occ_sum = torch.sum(occ_fx8, -1)
        surf_cubes = (_occ_sum > 0) & (_occ_sum < 8)
        return surf_cubes, occ_fx8

    def _linear_interp(self, edges_weight, edges_x):
        """
        Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'.
        """
        edge_dim = edges_weight.dim() - 2
        assert edges_weight.shape[edge_dim] == 2
        # Swap the two endpoint weights and negate the first: standard MC lerp
        # x = (w1*x0 - w0*x1) / (w1 - w0).
        edges_weight = torch.cat([torch.index_select(input=edges_weight, index=torch.tensor(1, device=self.device), dim=edge_dim),
                                  - torch.index_select(input=edges_weight, index=torch.tensor(0, device=self.device), dim=edge_dim)], edge_dim)
        denominator = edges_weight.sum(edge_dim)
        ue = (edges_x * edges_weight).sum(edge_dim) / denominator
        return ue

    def _linear_interp_nonan(self, edges_weight, edges_x):
        """
        Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'.
        """
        # Variant of _linear_interp that guards against a zero denominator
        # (degenerate edges) instead of producing NaN; such edges get a zero result.
        edge_dim = edges_weight.dim() - 2
        assert edges_weight.shape[edge_dim] == 2
        edges_weight = torch.cat([torch.index_select(input=edges_weight, index=torch.tensor(1, device=self.device), dim=edge_dim),
                                  - torch.index_select(input=edges_weight, index=torch.tensor(0, device=self.device), dim=edge_dim)], edge_dim)
        denominator = edges_weight.sum(edge_dim, keepdim=True).expand(-1, 2, 1)
        with torch.no_grad():
            nonzero_mask = (denominator.abs() > 0)
        scale = torch.zeros_like(edges_weight)
        scale[nonzero_mask] = edges_weight[nonzero_mask] / denominator[nonzero_mask]
        ue = (edges_x * scale).sum(edge_dim)
        return ue

    def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3=None):
        # Solve the Quadratic Error Function for dual vertex placement (used only
        # when grad_func is given).  Regularized toward the centroid c_bx3 with
        # qef_reg_scale to avoid singular systems.
        p_bxnx3 = p_bxnx3.reshape(-1, 7, 3)
        norm_bxnx3 = norm_bxnx3.reshape(-1, 7, 3)
        c_bx3 = c_bx3.reshape(-1, 3)
        A = norm_bxnx3
        B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True)

        A_reg = (torch.eye(3, device=p_bxnx3.device) * self.qef_reg_scale).unsqueeze(0).repeat(p_bxnx3.shape[0], 1, 1)
        B_reg = (self.qef_reg_scale * c_bx3).unsqueeze(-1)
        A = torch.cat([A, A_reg], 1)
        B = torch.cat([B, B_reg], 1)
        dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1)
        return dual_verts

    def _compute_vd(self, x_nx3, surf_cubes_fx8, surf_edges, s_n, nu_n, case_ids, beta_fx12, alpha_fx8, gamma_f,
                    idx_map, grad_func):
        """
        Computes the location of dual vertices as described in Section 4.2
        """
        # NOTE(review): in the `grad_func is not None` path, nu_d and nu_d_stopvgd are
        # never assigned, yet they are returned at the bottom — calling this with a
        # grad_func would raise NameError. Presumably GShell never passes grad_func;
        # confirm against callers.
        alpha_nx12x2 = torch.index_select(input=alpha_fx8, index=self.cube_edges, dim=1).reshape(-1, 12, 2)
        surf_edges_x = torch.index_select(input=x_nx3, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 3)
        surf_edges_s = torch.index_select(input=s_n, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 1)
        surf_edges_nu = torch.index_select(input=nu_n, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 1)
        zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x)

        idx_map = idx_map.reshape(-1, 12)
        num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0)
        edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = [], [], [], [], []

        total_num_vd = 0
        vd_idx_map = torch.zeros((case_ids.shape[0], 12), dtype=torch.long, device=self.device, requires_grad=False)
        if grad_func is not None:
            normals = torch.nn.functional.normalize(grad_func(zero_crossing), dim=-1)
            vd = []
        for num in torch.unique(num_vd):
            cur_cubes = (num_vd == num)  # consider cubes with the same numbers of vd emitted (for batching)
            curr_num_vd = cur_cubes.sum() * num
            curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape(-1, num * 7)
            curr_edge_group_to_vd = torch.arange(
                curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7) + total_num_vd
            total_num_vd += curr_num_vd
            curr_edge_group_to_cube = torch.arange(idx_map.shape[0], device=self.device)[
                cur_cubes].unsqueeze(-1).repeat(1, num * 7).reshape_as(curr_edge_group)

            curr_mask = (curr_edge_group != -1)
            edge_group.append(torch.masked_select(curr_edge_group, curr_mask))
            edge_group_to_vd.append(torch.masked_select(curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask))
            edge_group_to_cube.append(torch.masked_select(curr_edge_group_to_cube, curr_mask))
            vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True))
            vd_gamma.append(torch.masked_select(gamma_f, cur_cubes).unsqueeze(-1).repeat(1, num).reshape(-1))

            if grad_func is not None:
                with torch.no_grad():
                    cube_e_verts_idx = idx_map[cur_cubes]
                    curr_edge_group[~curr_mask] = 0

                    verts_group_idx = torch.gather(input=cube_e_verts_idx, dim=1, index=curr_edge_group)
                    verts_group_idx[verts_group_idx == -1] = 0
                    verts_group_pos = torch.index_select(
                        input=zero_crossing, index=verts_group_idx.reshape(-1), dim=0).reshape(-1, num.item(), 7, 3)
                    v0 = x_nx3[surf_cubes_fx8[cur_cubes][:, 0]].reshape(-1, 1, 1, 3).repeat(1, num.item(), 1, 1)
                    curr_mask = curr_mask.reshape(-1, num.item(), 7, 1)
                    verts_centroid = (verts_group_pos * curr_mask).sum(2) / (curr_mask.sum(2))

                    normals_bx7x3 = torch.index_select(input=normals, index=verts_group_idx.reshape(-1), dim=0).reshape(
                        -1, num.item(), 7, 3)
                    curr_mask = curr_mask.squeeze(2)
                    vd.append(self._solve_vd_QEF((verts_group_pos - v0) * curr_mask, normals_bx7x3 * curr_mask,
                                                 verts_centroid - v0.squeeze(2)) + v0.reshape(-1, 3))

        edge_group = torch.cat(edge_group)
        edge_group_to_vd = torch.cat(edge_group_to_vd)
        edge_group_to_cube = torch.cat(edge_group_to_cube)
        vd_num_edges = torch.cat(vd_num_edges)
        vd_gamma = torch.cat(vd_gamma)

        if grad_func is not None:
            vd = torch.cat(vd)
            L_dev = torch.zeros([1], device=self.device)
        else:
            vd = torch.zeros((total_num_vd, 3), device=self.device)
            nu_d = torch.zeros((total_num_vd, 1), device=self.device)
            beta_sum = torch.zeros((total_num_vd, 1), device=self.device)

            idx_group = torch.gather(input=idx_map.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group)

            x_group = torch.index_select(input=surf_edges_x, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 3)
            s_group = torch.index_select(input=surf_edges_s, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 1)
            nu_group = torch.index_select(input=surf_edges_nu, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 1)

            zero_crossing_group = torch.index_select(
                input=zero_crossing, index=idx_group.reshape(-1), dim=0).reshape(-1, 3)

            alpha_group = torch.index_select(input=alpha_nx12x2.reshape(-1, 2), dim=0,
                                             index=edge_group_to_cube * 12 + edge_group).reshape(-1, 2, 1)
            interp_coeff_group = s_group * alpha_group
            ue_group = self._linear_interp(interp_coeff_group, x_group)
            # Interpolate the mSDF onto the edge crossings; the *_stopvgd variant detaches
            # the interpolation coefficients so gradients do not flow into s/alpha.
            nu_e_group = self._linear_interp(interp_coeff_group, nu_group)
            nu_e_stopvgd_group = self._linear_interp(interp_coeff_group.detach(), nu_group)

            beta_group = torch.gather(input=beta_fx12.reshape(-1), dim=0,
                                      index=edge_group_to_cube * 12 + edge_group).reshape(-1, 1)
            beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group)
            # Dual vertex = beta-weighted mean of its member edge crossings.
            vd = vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) / beta_sum
            nu_d = nu_d.index_add_(0, index=edge_group_to_vd, source=nu_e_group * beta_group) / beta_sum
            nu_d_stopvgd = nu_d.index_add_(0, index=edge_group_to_vd, source=nu_e_stopvgd_group * beta_group.detach()) / beta_sum.detach()
            L_dev = self._compute_reg_loss(vd, zero_crossing_group, edge_group_to_vd, vd_num_edges)

        v_idx = torch.arange(vd.shape[0], device=self.device)  # + total_num_vd

        vd_idx_map = (vd_idx_map.reshape(-1)).scatter(dim=0, index=edge_group_to_cube *
                                                      12 + edge_group, src=v_idx[edge_group_to_vd])

        return vd, nu_d, nu_d_stopvgd, L_dev, vd_gamma, vd_idx_map

    def _triangulate(self, s_n, surf_edges, vd, nu_d, nu_d_stopvgd, vd_gamma, edge_counts, idx_map, vd_idx_map,
                     surf_edges_mask, training, grad_func):
        """
        Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then
        split into triangles based on the gamma parameter, as described in Section 4.3.
        """
        with torch.no_grad():
            group_mask = (edge_counts == 4) & surf_edges_mask  # surface edges shared by 4 cubes.
            group = idx_map.reshape(-1)[group_mask]
            vd_idx = vd_idx_map[group_mask]
            edge_indices, indices = torch.sort(group, stable=True)
            quad_vd_idx = vd_idx[indices].reshape(-1, 4)

            # Ensure all face directions point towards the positive SDF to maintain consistent winding.
            s_edges = s_n[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)].reshape(-1, 2)
            flip_mask = s_edges[:, 0] > 0
            quad_vd_idx = torch.cat((quad_vd_idx[flip_mask][:, [0, 1, 3, 2]],
                                     quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]]))
        if grad_func is not None:
            # when grad_func is given, split quadrilaterals along the diagonals with more consistent gradients.
            with torch.no_grad():
                vd_gamma = torch.nn.functional.normalize(grad_func(vd), dim=-1)
                quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3)
                gamma_02 = (quad_gamma[:, 0] * quad_gamma[:, 2]).sum(-1, keepdims=True)
                gamma_13 = (quad_gamma[:, 1] * quad_gamma[:, 3]).sum(-1, keepdims=True)
        else:
            quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4)
            gamma_02 = torch.index_select(input=quad_gamma, index=torch.tensor(
                0, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(2, device=self.device), dim=1)
            gamma_13 = torch.index_select(input=quad_gamma, index=torch.tensor(
                1, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(3, device=self.device), dim=1)
        if not training:
            # Inference: pick one diagonal per quad (hard split).
            mask = (gamma_02 > gamma_13).squeeze(1)
            faces = torch.zeros((quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device)
            faces[mask] = quad_vd_idx[mask][:, self.quad_split_1]
            faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2]
            faces = faces.reshape(-1, 3)
        else:
            # Training: differentiable 4-way split through a gamma-weighted center vertex,
            # with the mSDF (nu) interpolated to the new center as well.
            vd_quad = torch.index_select(input=vd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3)
            nu_d_quad = torch.index_select(input=nu_d, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 1)
            nu_d_stopvgd_quad = torch.index_select(input=nu_d_stopvgd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 1)
            vd_02 = (torch.index_select(input=vd_quad, index=torch.tensor(0, device=self.device), dim=1) +
                     torch.index_select(input=vd_quad, index=torch.tensor(2, device=self.device), dim=1)) / 2
            nu_d_02 = (torch.index_select(input=nu_d_quad, index=torch.tensor(0, device=self.device), dim=1) +
                       torch.index_select(input=nu_d_quad, index=torch.tensor(2, device=self.device), dim=1)) / 2
            nu_d_stopvgd_02 = (torch.index_select(input=nu_d_stopvgd_quad, index=torch.tensor(0, device=self.device), dim=1) +
                               torch.index_select(input=nu_d_stopvgd_quad, index=torch.tensor(2, device=self.device), dim=1)) / 2
            vd_13 = (torch.index_select(input=vd_quad, index=torch.tensor(1, device=self.device), dim=1) +
                     torch.index_select(input=vd_quad, index=torch.tensor(3, device=self.device), dim=1)) / 2
            nu_d_13 = (torch.index_select(input=nu_d_quad, index=torch.tensor(1, device=self.device), dim=1) +
                       torch.index_select(input=nu_d_quad, index=torch.tensor(3, device=self.device), dim=1)) / 2
            nu_d_stopvgd_13 = (torch.index_select(input=nu_d_stopvgd_quad, index=torch.tensor(1, device=self.device), dim=1) +
                               torch.index_select(input=nu_d_stopvgd_quad, index=torch.tensor(3, device=self.device), dim=1)) / 2
            weight_sum = (gamma_02 + gamma_13) + 1e-8
            vd_center = ((vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) / weight_sum.unsqueeze(-1)).squeeze(1)
            nu_d_center = ((nu_d_02 * gamma_02.unsqueeze(-1) + nu_d_13 * gamma_13.unsqueeze(-1)) / weight_sum.unsqueeze(-1)).squeeze(1)
            nu_d_stopvgd_center = ((nu_d_stopvgd_02 * gamma_02.unsqueeze(-1).detach() + nu_d_stopvgd_13 * gamma_13.unsqueeze(-1).detach()) /
                                   weight_sum.unsqueeze(-1).detach()).squeeze(1)
            vd_center_idx = torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0]
            vd = torch.cat([vd, vd_center])
            nu_d = torch.cat([nu_d, nu_d_center])
            nu_d_stopvgd = torch.cat([nu_d_stopvgd, nu_d_stopvgd_center])
            faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2)
            faces = torch.cat([faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1).reshape(-1, 3)
        return vd, nu_d, nu_d_stopvgd, faces, s_edges, edge_indices

    def _triangulate_msdf(self, vertices, faces, nu_n, nu_n_stopvgd):
        # GShell step: cut the watertight FlexiCubes mesh open along the mSDF (nu) zero
        # level set. Faces whose 3 vertices all have nu >= 0 are kept; faces with mixed
        # signs are re-triangulated against the nu zero-crossings on their edges.
        # NOTE(review): devices are hardcoded to "cuda" below instead of self.device —
        # this breaks non-CUDA use of the class; verify before running on CPU.
        with torch.no_grad():
            mocc_n = nu_n >= 0
            mocc_fx3 = mocc_n[faces.reshape(-1)].reshape(-1,3)
            mocc_sum = torch.sum(mocc_fx3, -1)
            uncut_faces_mask = (mocc_sum == 3)
            cut_faces_mask = (mocc_sum < 3) & (mocc_sum > 0)
            uncut_faces = faces[uncut_faces_mask]
            cut_faces = faces[cut_faces_mask]

        # NOTE(review): this early-out triggers when there are no *uncut* faces; one
        # might expect the no-*cut*-faces case instead — confirm the intent.
        if uncut_faces.size(0) == 0:
            return vertices, faces, nu_n, nu_n[:1].detach() * 0.0

        # For each cut face, gather its 3 edges (01, 12, 20) as endpoint pairs.
        vertices_cut_edges_fx2 = vertices[cut_faces[:, [0,1,1,2,2,0]].view(-1)].view(-1, 2, 3)
        nu_cut_edges_fx2 = nu_n[cut_faces[:, [0,1,1,2,2,0]].view(-1)].view(-1, 2, 1)
        nu_cut_edges_fx2_stopvgd = nu_n_stopvgd[cut_faces[:, [0,1,1,2,2,0]].view(-1)].view(-1, 2, 1)
        assert vertices_cut_edges_fx2.size(0) == cut_faces.size(0) * 3 ### DEBUG
        msdf_zero_crossing = self._linear_interp_nonan(nu_cut_edges_fx2, vertices_cut_edges_fx2)
        nu_boundary_stopvgd = self._linear_interp_nonan(nu_cut_edges_fx2_stopvgd.detach(), nu_cut_edges_fx2_stopvgd)

        vertices_open = torch.cat([vertices, msdf_zero_crossing], dim=0)
        nus_open_stopvgd = torch.cat([nu_n_stopvgd, nu_boundary_stopvgd], dim=0)

        with torch.no_grad():
            v_id = torch.flip(torch.pow(2, torch.arange(3, dtype=torch.long, device="cuda")), dims=[0]) ## do this flip because the triangle table uses a different assumption by mistake..
            configuration_idx = (mocc_fx3[cut_faces_mask] * v_id.unsqueeze(0)).sum(-1)
            # Columns 0-2: original face vertices; columns 3-5: the new edge-crossing vertices.
            idx_map = torch.cat([cut_faces, vertices.size(0) + torch.arange(cut_faces.size(0) * 3, device='cuda').view(-1, 3)], dim=-1)
            num_triangles = self.gflex_num_triangles_table[configuration_idx]

        faces_open = torch.cat([
            uncut_faces,
            torch.gather(input=idx_map[num_triangles == 1], dim=1, index=self.gflex_configuration_table[configuration_idx[num_triangles == 1]][:, :3]).view(-1, 3),
            torch.gather(input=idx_map[num_triangles == 2], dim=1, index=self.gflex_configuration_table[configuration_idx[num_triangles == 2]][:, :6]).view(-1, 3),
        ])

        return vertices_open, faces_open, nus_open_stopvgd, nu_boundary_stopvgd

    def _tetrahedralize(
            self, x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices,
            surf_cubes, training):
        """
        Tetrahedralizes the interior volume to produce a tetrahedral mesh, as described in Section 4.5.
        """
        # NOTE: unreachable from __call__ in GShell (the tetmesh branch raises
        # NotImplementedError before calling this).
        occ_n = s_n < 0
        occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8)
        occ_sum = torch.sum(occ_fx8, -1)

        inside_verts = x_nx3[occ_n]
        mapping_inside_verts = torch.ones((occ_n.shape[0]), dtype=torch.long, device=self.device) * -1
        mapping_inside_verts[occ_n] = torch.arange(occ_n.sum(), device=self.device) + vertices.shape[0]
        """
        For each grid edge connecting two grid vertices with different
        signs, we first form a four-sided pyramid by connecting one
        of the grid vertices with four mesh vertices that correspond
        to the grid edge and then subdivide the pyramid into two tetrahedra
        """
        inside_verts_idx = mapping_inside_verts[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1, 2)[
            s_edges < 0]]
        if not training:
            inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 2).reshape(-1)
        else:
            inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 4).reshape(-1)

        tets_surface = torch.cat([faces, inside_verts_idx.unsqueeze(-1)], -1)
        """
        For each grid edge connecting two grid vertices with the same sign, the tetrahedron is
        formed by the two grid vertices and two vertices in consecutive adjacent cells
        """
        inside_cubes = (occ_sum == 8)
        inside_cubes_center = x_nx3[cube_fx8[inside_cubes].reshape(-1)].reshape(-1, 8, 3).mean(1)
        inside_cubes_center_idx = torch.arange(
            inside_cubes_center.shape[0], device=inside_cubes.device) + vertices.shape[0] + inside_verts.shape[0]

        surface_n_inside_cubes = surf_cubes | inside_cubes
        edge_center_vertex_idx = torch.ones(((surface_n_inside_cubes).sum(), 13),
                                            dtype=torch.long, device=x_nx3.device) * -1
        surf_cubes = surf_cubes[surface_n_inside_cubes]
        inside_cubes = inside_cubes[surface_n_inside_cubes]
        edge_center_vertex_idx[surf_cubes, :12] = vd_idx_map.reshape(-1, 12)
        edge_center_vertex_idx[inside_cubes, 12] = inside_cubes_center_idx

        all_edges = cube_fx8[surface_n_inside_cubes][:, self.cube_edges].reshape(-1, 2)
        unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True)
        unique_edges = unique_edges.long()
        # Interior edges: both endpoints inside the isosurface.
        mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 2
        mask = mask_edges[_idx_map]
        counts = counts[_idx_map]
        mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=self.device) * -1
        mapping[mask_edges] = torch.arange(mask_edges.sum(), device=self.device)
        idx_map = mapping[_idx_map]

        group_mask = (counts == 4) & mask
        group = idx_map.reshape(-1)[group_mask]
        edge_indices, indices = torch.sort(group)
        cube_idx = torch.arange((_idx_map.shape[0] // 12), dtype=torch.long,
                                device=self.device).unsqueeze(1).expand(-1, 12).reshape(-1)[group_mask]
        edge_idx = torch.arange((12), dtype=torch.long, device=self.device).unsqueeze(
            0).expand(_idx_map.shape[0] // 12, -1).reshape(-1)[group_mask]
        # Identify the face shared by the adjacent cells.
        cube_idx_4 = cube_idx[indices].reshape(-1, 4)
        edge_dir = self.edge_dir_table[edge_idx[indices]].reshape(-1, 4)[..., 0]
        shared_faces_4x2 = self.dir_faces_table[edge_dir].reshape(-1)
        cube_idx_4x2 = cube_idx_4[:, self.adj_pairs].reshape(-1)

        # Identify an edge of the face with different signs and
        # select the mesh vertex corresponding to the identified edge.
        case_ids_expand = torch.ones((surface_n_inside_cubes).sum(), dtype=torch.long, device=x_nx3.device) * 255
        case_ids_expand[surf_cubes] = case_ids
        cases = case_ids_expand[cube_idx_4x2]
        quad_edge = edge_center_vertex_idx[cube_idx_4x2, self.tet_table[cases, shared_faces_4x2]].reshape(-1, 2)
        mask = (quad_edge == -1).sum(-1) == 0
        inside_edge = mapping_inside_verts[unique_edges[mask_edges][edge_indices].reshape(-1)].reshape(-1, 2)
        tets_inside = torch.cat([quad_edge, inside_edge], -1)[mask]

        tets = torch.cat([tets_surface, tets_inside])
        vertices = torch.cat([vertices, inside_verts, inside_cubes_center])
        return vertices, tets


================================================
FILE: geometry/gshell_flexicubes_geometry.py
================================================
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.
Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited

import os
from tqdm import trange
import numpy as np
import torch
import torch.nn.functional as F

from render import mesh
from render import render
import render.optixutils as ou
from render import regularizer
from .gshell_flexicubes import GShellFlexiCubes
from render import util

import kaolin

from .mlp import MLP

###############################################################################
#  Regularizer
###############################################################################

def compute_sdf_reg_loss(sdf, all_edges):
    """BCE regularizer over grid edges whose SDF endpoint signs disagree.

    For each sign-crossing edge, each endpoint's SDF value (used as a logit)
    is pushed toward the sign of the opposite endpoint, discouraging spurious
    sign flips. `all_edges` holds vertex-index pairs into `sdf`.
    """
    sdf_f1x6x2 = sdf[all_edges.reshape(-1)].reshape(-1,2)
    mask = torch.sign(sdf_f1x6x2[...,0]) != torch.sign(sdf_f1x6x2[...,1])
    sdf_f1x6x2 = sdf_f1x6x2[mask]  # only edges that currently cross the surface
    sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[...,0], (sdf_f1x6x2[...,1] > 0).float()) + \
            torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[...,1], (sdf_f1x6x2[...,0] > 0).float())
    return sdf_diff

###############################################################################
#  Geometry interface
###############################################################################

class GShellFlexiCubesGeometry(torch.nn.Module):
    """Optimizable G-Shell geometry on a FlexiCubes voxel grid.

    Learnable state: `sdf` (or an MLP `sdf_net`), `msdf`, per-cube FlexiCubes
    `weight`s, and per-vertex `deform` offsets.
    """
    def __init__(self, grid_res, scale, FLAGS):
        super(GShellFlexiCubesGeometry, self).__init__()

        self.FLAGS         = FLAGS
        self.grid_res      = grid_res
        self.gflexicubes   = GShellFlexiCubes()
        verts, indices     = self.gflexicubes.construct_voxel_grid(grid_res)
        # Per-axis bounding-box scaling, broadcast over all vertices.
        self.boxscale = torch.tensor(FLAGS.boxscale).view(1, 3).cuda()

        with torch.no_grad():
            self.optix_ctx = ou.OptiXContext()

        n_cubes = indices.shape[0]
        # 21 weights per cube: 12 + 8 + 1, split in getMesh().
        per_cube_weights = torch.ones((n_cubes, 21), dtype=torch.float, device='cuda')

        self.verts = verts * scale * self.boxscale
        self.indices = indices
        print("FlexiCubes grid min/max", torch.min(self.verts).item(), torch.max(self.verts).item())
        self.generate_edges()

        if self.FLAGS.use_sdf_mlp:
            # Placeholder parameter so the module still exposes `sdf`; the real
            # SDF comes from `sdf_net` in getMesh().
            self.sdf = torch.nn.Parameter(torch.zeros_like(self.verts[:, 0]), requires_grad=True) ## placeholder
            self.register_parameter('sdf', self.sdf)
            self.sdf_net = MLP(
                skip_in=self.FLAGS.skip_in,
                n_freq=self.FLAGS.n_freq,
                n_hidden=self.FLAGS.n_hidden,
                d_hidden=self.FLAGS.d_hidden,
                use_float16=self.FLAGS.use_float16
            )
            self.sdf_net.cuda()
            # Pretrain the MLP toward a sphere SDF of radius sphere_init_norm.
            # NOTE(review): the network input is the unscaled self.verts while the
            # target norm uses boxscale-normalized verts — confirm intentional.
            optimizer = torch.optim.Adam(self.sdf_net.parameters(), lr=1e-3)
            for _ in trange(self.FLAGS.sdf_mlp_pretrain_steps):
                scaled_verts = self.verts / self.boxscale
                loss = (self.sdf_net(self.verts) - (scaled_verts.norm(dim=1, keepdim=True) - self.FLAGS.sphere_init_norm)).pow(2).mean()
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            print('sdf net trained with loss:', loss)
        else:
            # Random init
            if not self.FLAGS.sphere_init:
                sdf = torch.rand_like(self.verts[:,0]) - 0.1  # biased toward positive (outside)
            else:
                scaled_verts = self.verts / self.boxscale
                sdf = scaled_verts.norm(dim=1) - 0.5  # sphere of radius 0.5 in normalized coords
            self.sdf = torch.nn.Parameter(sdf.clone().detach(), requires_grad=True)
            self.register_parameter('sdf', self.sdf)

        self.per_cube_weights = torch.nn.Parameter(torch.ones_like(per_cube_weights), requires_grad=True)
        self.register_parameter('weight', self.per_cube_weights)

        # mSDF init: slight negative bias, clamped to the working range.
        msdf = (torch.rand_like(self.verts[:,0]) - 0.01).clamp(-1, 1)
        self.msdf = torch.nn.Parameter(msdf.clone().detach(), requires_grad=True)
        self.register_parameter('msdf', self.msdf)

        self.deform = torch.nn.Parameter(torch.zeros_like(self.verts), requires_grad=True)
        self.register_parameter('deform', self.deform)
        self.clamp_deform()

    @torch.no_grad()
    def generate_edges(self):
        """Collect unique undirected cube edges; cache max vertex displacement
        as a quarter of the mean edge length."""
        with torch.no_grad():
            edges = self.gflexicubes.cube_edges
            all_edges = self.indices[:,edges].reshape(-1,2)
            all_edges_sorted = torch.sort(all_edges, dim=1)[0]
            self.all_edges = torch.unique(all_edges_sorted, dim=0)
            self.max_displacement = util.length(self.verts[self.all_edges[:, 0]] - self.verts[self.all_edges[:, 1]]).mean() / 4

    @torch.no_grad()
    def getAABB(self):
        """Axis-aligned bounding box of the (undeformed) grid vertices."""
        return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values
    @torch.no_grad()
    def clamp_deform(self):
        """Clamp deform offsets to [-1, 1] (unless tanh squashing is used) and
        msdf to [-2, 2], in place on the parameter data."""
        if not self.FLAGS.use_tanh_deform:
            self.deform.data[:] = self.deform.clamp(-1.0, 1.0)
        self.msdf.data[:] = self.msdf.clamp(-2.0, 2.0)

    @torch.no_grad()
    def map_uv2(self, faces):
        """Trivial UVs: every face maps to the same unit triangle."""
        uvs = torch.tensor([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]], dtype=torch.float, device='cuda')
        uv_idx = torch.tensor([0,1,2], dtype=torch.long, device='cuda').repeat(faces.shape[0],1)
        return uvs, uv_idx

    @torch.no_grad()
    def map_uv(self, face_gidx, max_idx):
        """Pack faces into an NxN atlas of small quads (two triangles per cell)."""
        N = int(np.ceil(np.sqrt((max_idx+1)//2)))
        # NOTE(review): no indexing= argument — torch.meshgrid defaults to 'ij'
        # (with a deprecation warning on recent torch); the sibling
        # GShell_Tets.map_uv passes indexing='ij' explicitly.
        tex_y, tex_x = torch.meshgrid(
            torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device="cuda"),
            torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device="cuda")
        )

        pad = 0.9 / N

        uvs = torch.stack([
            tex_x      , tex_y,
            tex_x + pad, tex_y,
            tex_x + pad, tex_y + pad,
            tex_x      , tex_y + pad
        ], dim=-1).view(-1, 2)

        def _idx(tet_idx, N):
            x = tet_idx % N
            y = torch.div(tet_idx, N, rounding_mode='floor')
            return y * N + x

        # NOTE(review): divides face_gidx by N here, whereas GShell_Tets.map_uv
        # divides by 2 (two faces per cell) — confirm which layout is intended.
        tet_idx = _idx(torch.div(face_gidx, N, rounding_mode='floor'), N)
        tri_idx = face_gidx % 2

        uv_idx = torch.stack((
            tet_idx * 4, tet_idx * 4 + tri_idx + 1, tet_idx * 4 + tri_idx + 2
        ), dim = -1).view(-1, 3)

        return uvs, uv_idx

    def getMesh(self, material, _training=False):
        """Extract the current G-Shell mesh via FlexiCubes.

        Returns a dict with the open mesh (`imesh`), the SDF/mSDF tensors used,
        boundary/watertight mSDF slices, and (optionally) the watertight mesh.
        Also rebuilds the OptiX BVH for the new geometry.
        """
        v_deformed = self.verts + self.max_displacement * self.deform
        if self.FLAGS.use_sdf_mlp:
            sdf = self.sdf_net(v_deformed)
        else:
            sdf = self.sdf
        if self.FLAGS.use_msdf_mlp:
            msdf = self.msdf_net(v_deformed)
        else:
            msdf = self.msdf

        # per_cube_weights split: [:12] edge weights, [12:20] corner weights, [20] center weight.
        verts, faces, reg_loss, extra = self.gflexicubes(v_deformed, sdf, msdf, self.indices, self.grid_res,
            self.per_cube_weights[:,:12], self.per_cube_weights[:,12:20], self.per_cube_weights[:,20],
            training=_training)
        self.gflexi_reg_loss = reg_loss.mean()  # consumed later in tick()

        face_gidx = torch.arange(faces.shape[0], dtype=torch.long, device="cuda")
        uvs, uv_idx = self.map_uv(face_gidx, faces.shape[0])

        imesh = mesh.Mesh(verts, faces, v_tex=uvs, t_tex_idx=uv_idx, material=material)

        with torch.no_grad():
            ou.optix_build_bvh(self.optix_ctx, imesh.v_pos.contiguous(), imesh.t_pos_idx.int(), rebuild=1)

        # Run mesh operations to generate tangent space
        imesh = mesh.auto_normals(imesh)

        return_dict = {
            'imesh': imesh,
            'sdf': sdf,
            'msdf': extra['msdf'],
            'msdf_watertight': extra['msdf_watertight'],
            'msdf_boundary': extra['msdf_boundary'],
            'n_verts_watertight': extra['n_verts_watertight'],
        }
        if self.FLAGS.visualize_watertight:
            imesh_watertight = mesh.Mesh(extra['vertices_watertight'], extra['faces_watertight'],
                v_tex=None, t_tex_idx=None, material=material)
            imesh_watertight = mesh.auto_normals(imesh_watertight)
            return_dict['imesh_watertight'] = imesh_watertight
        return return_dict

    def render(self, glctx, target, lgt, opt_material, bsdf=None, denoiser=None, shadow_scale=1.0,
               use_uv=False, training=False):
        """Extract the mesh and rasterize it under the `target` camera; adds
        rendered `buffers` (and watertight buffers if enabled) plus 50k surface
        sample points (for the Eikonal term) to the mesh dict."""
        opt_mesh_dict = self.getMesh(opt_material)
        opt_mesh = opt_mesh_dict['imesh']
        opt_mesh_watertight = opt_mesh_dict['imesh_watertight'] if 'imesh_watertight' in opt_mesh_dict else None
        if opt_mesh.v_pos.size(0) != 0:
            sampled_pts = kaolin.ops.mesh.sample_points(opt_mesh.v_pos[None,...], opt_mesh.t_pos_idx, 50000)[0][0]
            opt_mesh_dict['sampled_pts'] = sampled_pts
        else:
            opt_mesh_dict['sampled_pts'] = None  # empty mesh: nothing to sample
        extra_dict =
{
            'msdf': opt_mesh_dict['msdf'],
        }
        opt_mesh_dict['buffers'] = render.render_mesh(
            self.FLAGS, glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'],
            spp=target['spp'], msaa=True, background=target['background'], bsdf=bsdf, use_uv=use_uv,
            optix_ctx=self.optix_ctx, denoiser=denoiser, shadow_scale=shadow_scale, extra_dict=extra_dict)
        if self.FLAGS.visualize_watertight:
            opt_mesh_dict['buffers_watertight'] = render.render_mesh(
                self.FLAGS, glctx, opt_mesh_watertight, target['mvp'], target['campos'], lgt, target['resolution'],
                spp=target['spp'], msaa=True, background=target['background'], bsdf=bsdf, use_uv=use_uv,
                optix_ctx=self.optix_ctx, denoiser=denoiser, shadow_scale=shadow_scale, extra_dict=extra_dict)
        return opt_mesh_dict

    def tick(self, glctx, target, lgt, opt_material, loss_fn, iteration, denoiser):
        """One optimization step's forward pass.

        Renders the current geometry and returns (img_loss, depth_loss, reg_loss);
        the caller is responsible for backprop.  `loss_fn` is the image loss on
        premultiplied RGB.
        """
        t_iter = iteration / self.FLAGS.iter

        # ==============================================================================================
        #  Render optimizable object with identical conditions
        # ==============================================================================================
        shadow_ramp = min(iteration / 1000, 1.0)  # ramp shadows in over the first 1000 iters
        if denoiser is not None:
            denoiser.set_influence(shadow_ramp)
        opt_mesh_dict = self.render(glctx, target, lgt, opt_material, denoiser=denoiser,
                                    shadow_scale=shadow_ramp, training=True)
        buffers = opt_mesh_dict['buffers']

        # ==============================================================================================
        #  Compute loss
        # ==============================================================================================
        with torch.no_grad():
            # Image-space loss, split into a coverage component and a color component
            color_ref = target['img']
            gt_mask = color_ref[..., 3:]

        # Alpha (coverage) + premultiplied-color terms.
        img_loss = F.mse_loss(buffers['shaded'][..., 3:], color_ref[..., 3:])
        img_loss = img_loss + loss_fn(buffers['shaded'][..., 0:3] * color_ref[..., 3:], color_ref[..., 0:3] * color_ref[..., 3:])
        # mSDF image supervision: push mSDF non-positive outside the GT mask and
        # non-negative inside it.
        img_loss = img_loss + 5e-1 * F.l1_loss(buffers['msdf_image'].clamp(min=0) * (gt_mask == 0).float(), torch.zeros_like(gt_mask))
        img_loss = img_loss + 5e-1 * F.l1_loss(buffers['msdf_image'].clamp(max=0) * (gt_mask == 1).float(), torch.ones_like(gt_mask))

        if self.FLAGS.use_img_2nd_layer:
            # Same losses on the second depth-peeled layer.
            color_ref_2nd = target['img_second']
            img_loss = img_loss + F.mse_loss(buffers['shaded_second'][..., 3:], color_ref_2nd[..., 3:])
            img_loss = img_loss + loss_fn(buffers['shaded_second'][..., 0:3] * color_ref_2nd[..., 3:], color_ref_2nd[..., 0:3] * color_ref_2nd[..., 3:])

        if self.FLAGS.use_depth:
            depth_loss_scale = 100.
            depth_loss = depth_loss_scale * ((buffers['invdepth'][:, :, :, :1] - target['invdepth'][:, :, :, :1]).abs()).mean()
            if self.FLAGS.use_depth_2nd_layer:
                depth_loss += 0.1 * depth_loss_scale * ((buffers['invdepth_second'][:, :, :, :1] - target['invdepth_second'][:, :, :, :1]).abs()).mean()
        else:
            depth_loss = torch.tensor(0., device=img_loss.device)

        # Eikonal
        if self.FLAGS.use_sdf_mlp and self.FLAGS.use_eikonal and opt_mesh_dict['sampled_pts'] is not None:
            v = opt_mesh_dict['sampled_pts'].detach()
            v.requires_grad = True
            sdf_eik = self.sdf_net(v)
            if self.FLAGS.eikonal_scale is None:
                ### Default hardcoded Eikonal loss schedule
                # NOTE(review): the <1000 and <2000 branches both yield 1e-1.
                if iteration < 500:
                    eik_coeff = 3e-1
                elif iteration < 1000:
                    eik_coeff = 1e-1
                elif iteration < 2000:
                    eik_coeff = 1e-1
                else:
                    eik_coeff = 1e-2
            else:
                eik_coeff = self.FLAGS.eikonal_scale
            # ||∇sdf|| should be 1 at sampled surface points.
            eik_loss = eik_coeff * (
                torch.autograd.grad(sdf_eik.sum(), v, create_graph=True)[0].pow(2).sum(dim=-1).sqrt() - 1
            ).pow(2).mean()
        else:
            eik_loss = torch.tensor(0., device=img_loss.device)

        if self.FLAGS.use_mesh_msdf_reg:
            mesh_msdf_regscale = (64 / self.grid_res) ** 3 # scale inversely proportional to grid_res^3
            eps = 1e-3
            open_scale = self.FLAGS.msdf_reg_open_scale
            close_scale = self.FLAGS.msdf_reg_close_scale
            eps = torch.tensor([eps]).cuda()
            # "Open" term: pull mesh mSDF down toward -eps (favor open surfaces).
            mesh_msdf_reg_loss = open_scale * mesh_msdf_regscale * F.huber_loss(
                opt_mesh_dict['msdf'].clamp(min=-eps).squeeze(),
                -eps.expand(opt_mesh_dict['msdf'].size(0)),
                reduction='sum'
            )
            if close_scale != 0:
                # "Close" term: only for boundary vertices that are visible in
                # this view, pull their mSDF up toward +eps.
                with torch.no_grad():
                    visible_verts = (opt_mesh_dict['imesh'].t_pos_idx[buffers['visible_triangles']]).unique()
                    # Boundary vertices are appended after the watertight ones.
                    visible_boundary_verts = visible_verts[visible_verts >= opt_mesh_dict['n_verts_watertight']] - opt_mesh_dict['n_verts_watertight']
                    visible_boundary_mask = torch.zeros(opt_mesh_dict['msdf_boundary'].size(0)).cuda()
                    visible_boundary_mask[visible_boundary_verts] = 1
                    visible_boundary_mask = visible_boundary_mask.bool()
                boundary_msdf = opt_mesh_dict['msdf_boundary']
                boundary_msdf = boundary_msdf[visible_boundary_mask]
                mesh_msdf_reg_loss += close_scale * mesh_msdf_regscale * F.huber_loss(
                    boundary_msdf.clamp(max=eps).squeeze(),
                    eps.expand(boundary_msdf.size(0)),
                    reduction='sum'
                )
        else:
            mesh_msdf_reg_loss = torch.tensor(0., device=img_loss.device)

        # SDF regularizer (weight annealed from sdf_regularizer down to 0.01
        # over the first quarter of training).
        sdf_weight = self.FLAGS.sdf_regularizer - (self.FLAGS.sdf_regularizer - 0.01) * min(1.0, 4.0 * t_iter)
        sdf_reg_loss = compute_sdf_reg_loss(opt_mesh_dict['sdf'], self.all_edges).mean() * sdf_weight

        # Monochrome shading regularizer
        if 'diffuse_light' not in buffers:
            monochrome_loss = torch.zeros_like(img_loss)
        else:
            monochrome_loss = regularizer.shading_loss(buffers['diffuse_light'], buffers['specular_light'], color_ref, self.FLAGS.lambda_diffuse, self.FLAGS.lambda_specular)

        # Material smoothness regularizer
        mtl_smooth_loss = regularizer.material_smoothness_grad(
            buffers['kd_grad'], buffers['ks_grad'], buffers['normal_grad'],
            lambda_kd=self.FLAGS.lambda_kd, lambda_ks=self.FLAGS.lambda_ks, lambda_nrm=self.FLAGS.lambda_nrm)

        # Chroma regularizer
        chroma_loss = regularizer.chroma_loss(buffers['kd'], color_ref, self.FLAGS.lambda_chroma)
        assert 'perturbed_nrm' not in buffers # disable normal map in first pass

        # FlexiCubes reg loss
        flexicube_reg_loss = self.gflexi_reg_loss * 0.25

        geo_reg_loss = sdf_reg_loss + eik_loss + mesh_msdf_reg_loss + flexicube_reg_loss
        shading_reg_loss = monochrome_loss + mtl_smooth_loss + chroma_loss
        reg_loss = geo_reg_loss + shading_reg_loss

        return img_loss, depth_loss, reg_loss



================================================
FILE: geometry/gshell_tets.py
================================================
import numpy as np
import torch

from render import util

######################################################################################
# Simple smooth vertex normal computation
######################################################################################
def auto_normals(v_pos, t_pos_idx):
    """Area-weighted smooth vertex normals via scatter-add of face normals.

    Returns (v_nrm, t_pos_idx); degenerate (near-zero) normals are replaced
    with +Z before normalization.
    """
    i0 = t_pos_idx[:, 0]
    i1 = t_pos_idx[:, 1]
    i2 = t_pos_idx[:, 2]

    v0 = v_pos[i0, :]
    v1 = v_pos[i1, :]
    v2 = v_pos[i2, :]

    face_normals = torch.cross(v1 - v0, v2 - v0)

    # Splat face normals to vertices
    v_nrm = torch.zeros_like(v_pos)
    v_nrm.scatter_add_(0, i0[:, None].repeat(1,3), face_normals)
    v_nrm.scatter_add_(0, i1[:, None].repeat(1,3), face_normals)
    v_nrm.scatter_add_(0, i2[:, None].repeat(1,3), face_normals)

    # Normalize, replace zero (degenerated) normals with some default value
    v_nrm = torch.where(util.dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device='cuda'))
    v_nrm = util.safe_normalize(v_nrm)

    if torch.is_anomaly_enabled():
        assert torch.all(torch.isfinite(v_nrm))

    return v_nrm, t_pos_idx

######################################################################################
# Compute tangent space from texture map coordinates
# Follows http://www.mikktspace.com/ conventions
######################################################################################
def compute_tangents(v_pos, v_tex, v_nrm, t_pos_idx, t_tex_idx, t_nrm_idx):
    """Per-vertex tangents from UVs (MikkTSpace-style), averaged over incident
    triangles and orthogonalized against the vertex normals."""
    vn_idx = [None] * 3
    pos = [None] * 3
    tex = [None] * 3
    for i in range(0,3):
        pos[i] = v_pos[t_pos_idx[:, i]]
        tex[i] = v_tex[t_tex_idx[:, i]]
        vn_idx[i] = t_nrm_idx[:, i]

    tangents = torch.zeros_like(v_nrm)
    tansum   = torch.zeros_like(v_nrm)

    # Compute tangent space for each triangle
    uve1 = tex[1] - tex[0]
    uve2 = tex[2] - tex[0]
    pe1  = pos[1] - pos[0]
    pe2  = pos[2] - pos[0]

    nom   = (pe1 * uve2[..., 1:2] - pe2 * uve1[..., 1:2])
    denom =
(uve1[..., 0:1] * uve2[..., 1:2] - uve1[..., 1:2] * uve2[..., 0:1])

    # Avoid division by zero for degenerated texture coordinates
    tang = nom / torch.where(denom > 0.0, torch.clamp(denom, min=1e-6), torch.clamp(denom, max=-1e-6))

    # Update all 3 vertices
    for i in range(0,3):
        idx = vn_idx[i][:, None].repeat(1,3)
        tangents.scatter_add_(0, idx, tang)                # tangents[n_i] = tangents[n_i] + tang
        tansum.scatter_add_(0, idx, torch.ones_like(tang)) # tansum[n_i] = tansum[n_i] + 1
    tangents = tangents / tansum

    # Normalize and make sure tangent is perpendicular to normal
    tangents = util.safe_normalize(tangents)
    tangents = util.safe_normalize(tangents - util.dot(tangents, v_nrm) * v_nrm)

    if torch.is_anomaly_enabled():
        assert torch.all(torch.isfinite(tangents))

    return tangents, t_nrm_idx


class GShell_Tets:
    """Marching-tetrahedra variant of G-Shell: extracts an open mesh from a
    tet grid given an SDF (surface) and an mSDF (openness) field, using the
    lookup tables built below."""
    def __init__(self):
        # Standard marching-tets triangulation table: per 4-bit SDF sign case,
        # up to 2 triangles over the (up to 6) edge-crossing vertices; -1 pads.
        self.triangle_table = torch.tensor([
            [-1, -1, -1, -1, -1, -1],
            [ 1,  0,  2, -1, -1, -1],
            [ 4,  0,  3, -1, -1, -1],
            [ 1,  4,  2,  1,  3,  4],
            [ 3,  1,  5, -1, -1, -1],
            [ 2,  3,  0,  2,  5,  3],
            [ 1,  4,  0,  1,  5,  4],
            [ 4,  2,  5, -1, -1, -1],
            [ 4,  5,  2, -1, -1, -1],
            [ 4,  1,  0,  4,  5,  1],
            [ 3,  2,  0,  3,  5,  2],
            [ 1,  3,  5, -1, -1, -1],
            [ 4,  1,  2,  4,  3,  1],
            [ 3,  0,  4, -1, -1, -1],
            [ 2,  0,  1, -1, -1, -1],
            [-1, -1, -1, -1, -1, -1]
        ], dtype=torch.long, device='cuda')

        # Per SDF case: the boundary polygon (triangle or quad) vertex cycle of
        # the extracted cross-section; -1 pads.
        self.mesh_edge_table = torch.tensor([
            [-1, -1, -1, -1, -1, -1],
            [ 1,  0,  2,  1, -1, -1],
            [ 4,  0,  3,  4, -1, -1],
            [ 1,  3,  4,  2,  1, -1],
            [ 3,  1,  5,  3, -1, -1],
            [ 2,  5,  3,  0,  2, -1],
            [ 1,  5,  4,  0,  1, -1],
            [ 4,  2,  5,  4, -1, -1],
            [ 4,  5,  2,  4, -1, -1],
            [ 4,  5,  1,  0,  4, -1],
            [ 3,  5,  2,  0,  3, -1],
            [ 1,  3,  5,  1, -1, -1],
            [ 4,  3,  1,  2,  4, -1],
            [ 3,  0,  4,  3, -1, -1],
            [ 2,  0,  1,  2, -1, -1],
            [-1, -1, -1, -1, -1, -1]
        ], dtype=torch.long, device='cuda')

        # Re-triangulation of a triangle face cut by the mSDF zero level set;
        # indexed by the 3-bit mSDF sign pattern of its corners.
        self.triangle_table_tri = torch.tensor([
            ## 000
            [-1, -1, -1, -1, -1, -1],
            ## 001
            [ 4,  2,  5, -1, -1, -1],
            ## 010
            [ 3,  1,  4, -1, -1, -1],
            ## 011
            [ 3,  1,  2,  3,  2,  5],
            ## 100
            [ 0,  3,  5, -1, -1, -1],
            ## 101
            [ 0,  3,  4,  0,  4,  2],
            ## 110
            [ 0,  1,  4,  0,  4,  5],
            ## 111
            [ 0,  1,  2, -1, -1, -1],
        ], dtype=torch.long, device='cuda')

        # Re-triangulation of a quad face cut by the mSDF zero level set;
        # indexed by the 4-bit mSDF sign pattern of its corners.
        self.triangle_table_quad = torch.tensor([
            ### in the order of [0, 1, 2, 3]
            ### so 1000 corresponds to single positive mSDF vertex of index 0
            ## 0000
            [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
            ## 0001
            [ 6,  3,  7, -1, -1, -1, -1, -1, -1, -1, -1, -1],
            ## 0010
            [ 5,  2,  6, -1, -1, -1, -1, -1, -1, -1, -1, -1],
            ## 0011
            [ 5,  2,  7,  3,  7,  2, -1, -1, -1, -1, -1, -1],
            ## 0100
            [ 4,  1,  5, -1, -1, -1, -1, -1, -1, -1, -1, -1],
            ## 0101
            [ 4,  1,  5,  4,  5,  7,  5,  6,  7,  7,  6,  3],
            ## 0110
            [ 4,  1,  2,  6,  4,  2, -1, -1, -1, -1, -1, -1],
            ## 0111
            [ 4,  1,  2,  7,  4,  2,  7,  2,  3, -1, -1, -1],
            ## 1000
            [ 0,  4,  7, -1, -1, -1, -1, -1, -1, -1, -1, -1],
            ## 1001
            [ 0,  4,  6,  3,  0,  6, -1, -1, -1, -1, -1, -1],
            ## 1010
            [ 0,  4,  5,  0,  5,  2,  0,  2,  6,  0,  6,  7],
            ## 1011
            [ 0,  4,  5,  0,  5,  2,  0,  2,  3, -1, -1, -1],
            ## 1100
            [ 0,  1,  5,  7,  0,  5, -1, -1, -1, -1, -1, -1],
            ## 1101
            [ 0,  1,  5,  0,  5,  6,  0,  6,  3, -1, -1, -1],
            ## 1110
            [ 0,  1,  2,  0,  2,  6,  0,  6,  7, -1, -1, -1],
            ## 1111
            [ 0,  1,  2,  0,  2,  3, -1, -1, -1, -1, -1, -1],
        ], dtype=torch.long, device='cuda')

        # Triangle counts per lookup case for each table above.
        self.num_triangles_table = torch.tensor([0,1,1,2,1,2,2,1,1,2,2,1,2,1,1,0], dtype=torch.long, device='cuda')
        # The 6 edges of a tet as (vertex, vertex) index pairs, flattened.
        self.base_tet_edges = torch.tensor([0,1,0,2,0,3,1,2,1,3,2,3], dtype=torch.long, device='cuda')
        self.num_triangles_tri_table = torch.tensor([0,1,1,2,1,2,2,1], dtype=torch.long, device='cuda')
        self.num_triangles_quad_table = torch.tensor([0,1,1,2,1,4,2,3,1,2,4,3,2,3,3,2], dtype=torch.long, device='cuda')

        # For every pair of tet edges sharing a vertex, record the 4 tet-vertex
        # indices of the two edges (used to gather mSDF values per edge pair).
        edge_ind_list = [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]
        msdf_from_tetverts = []
        for i in range(5):
            for j in range(i+1, 6):
                if (edge_ind_list[i][0] == edge_ind_list[j][0]
                        or edge_ind_list[i][0] == edge_ind_list[j][1]
                        or edge_ind_list[i][1] == edge_ind_list[j][0]
                        or edge_ind_list[i][1] == edge_ind_list[j][1]
                    ):
                    msdf_from_tetverts.extend([edge_ind_list[i][0], edge_ind_list[i][1], edge_ind_list[j][0], edge_ind_list[j][1]])
        self.msdf_from_tetverts = torch.tensor(msdf_from_tetverts)
    ###############################################################################
    # Utility functions
    ###############################################################################

    def sort_edges(self, edges_ex2):
        """Order each (E, 2) edge so the smaller vertex index comes first.
        Returns shape (E, 1, 2) (gather keeps the singleton dim, then stack)."""
        with torch.no_grad():
            order = (edges_ex2[:,0] > edges_ex2[:,1]).long()
            order = order.unsqueeze(dim=1)

            a = torch.gather(input=edges_ex2, index=order, dim=1)
            b = torch.gather(input=edges_ex2, index=1-order, dim=1)

        return torch.stack([a, b],-1)

    def map_uv(self, face_gidx, max_idx):
        """Pack faces into an NxN UV atlas: one cell per tet (face_gidx // 2),
        two triangles per cell, with a 0.9/N pad per quad."""
        N = int(np.ceil(np.sqrt((max_idx+1)//2)))
        tex_y, tex_x = torch.meshgrid(
            torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device="cuda"),
            torch.linspace(0, 1 - (1 / N), N, dtype=torch.float32, device="cuda"),
            indexing='ij'
        )

        pad = 0.9 / N

        uvs = torch.stack([
            tex_x      , tex_y,
            tex_x + pad, tex_y,
            tex_x + pad, tex_y + pad,
            tex_x      , tex_y + pad
        ], dim=-1).view(-1, 2)

        def _idx(tet_idx, N):
            x = tet_idx % N
            y = torch.div(tet_idx, N, rounding_mode='trunc')
            return y * N + x

        tet_idx = _idx(torch.div(face_gidx, 2, rounding_mode='trunc'), N)
        tri_idx = face_gidx % 2

        uv_idx = torch.stack((
            tet_idx * 4, tet_idx * 4 + tri_idx + 1, tet_idx * 4 + tri_idx + 2
        ), dim = -1).view(-1, 3)

        return uvs, uv_idx

    ###############################################################################
    # Marching tets implementation
    ###############################################################################

    def __call__(self, pos_nx3, sdf_n, msdf_n, tet_fx4, output_watertight_template=True):
        """Extract the open G-Shell mesh from a tet grid.

        pos_nx3: (N, 3) vertex positions; sdf_n / msdf_n: per-vertex SDF and
        mSDF values; tet_fx4: (F, 4) tet vertex indices.  Returns
        (verts_aug, faces_aug, None, None, v_tng_aug, extra) where the `_aug`
        tensors include boundary vertices appended after the watertight ones.
        """
        sdf_n = sdf_n.float()
        with torch.no_grad():
            ### To determine if tets are valid
            ### Step 1: SDF criteria
            occ_n = sdf_n > 0
            occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1,4)
            occ_sum = torch.sum(occ_fx4, -1)

            ### Step 2: pre-filtering with mSDF - mSDF cannot be all non-negative
            msdf_fx4 = msdf_n[tet_fx4.reshape(-1)].reshape(-1,4)
            msdf_sign_fx4 = msdf_fx4 > 0
            msdf_sign_sum = torch.sum(msdf_sign_fx4, -1)

            if output_watertight_template:
                valid_tets = (occ_sum>0) & (occ_sum<4)
            else:
                valid_tets = (occ_sum>0) & (occ_sum<4) & (msdf_sign_sum > 0)

            # find all vertices
            all_edges = tet_fx4[valid_tets][:,self.base_tet_edges].reshape(-1,2)
            all_edges = self.sort_edges(all_edges)
            unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)

            unique_edges = unique_edges.long()
            # Edges crossing the SDF zero level set (exactly one occupied endpoint).
            mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1,2).sum(-1) == 1
            mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device="cuda") * -1
            mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long, device="cuda")
            idx_map = mapping[idx_map] # map edges to verts

        # Linear interpolation along crossing edges; gradients flow through SDF.
        interp_v = unique_edges[mask_edges]
        edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1,2,3)
        edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1,2,1)
        edges_to_interp_sdf[:,-1] *= -1
        denominator = edges_to_interp_sdf.sum(1, keepdim = True)
        denominator = torch.sign(denominator) * (denominator.abs() + 1e-12)
        denominator[denominator == 0] = 1e-12

        edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
        verts = (edges_to_interp * edges_to_interp_sdf).sum(1)

        # mSDF interpolated onto the new surface vertices (and a stop-gradient
        # variant that blocks grads through the interpolation coefficients).
        msdf_to_interp = msdf_n[interp_v.reshape(-1)].reshape(-1,2)
        msdf_vert = (msdf_to_interp * edges_to_interp_sdf.squeeze(-1)).sum(1)
        msdf_vert_stopvgd = (msdf_to_interp * edges_to_interp_sdf.squeeze(-1).detach()).sum(1)

        # (M, 6), M: num of pre-filtered tets, storing indices (besides -1) from 0 to num_mask_edges
        idx_map = idx_map.reshape(-1,6)

        v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device="cuda"))
        tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)

        # triangle count
        num_triangles = self.num_triangles_table[tetindex]

        # Get global face index (static, does not depend on topology), before mSDF processing
        num_tets = tet_fx4.shape[0]
        tet_gidx = torch.arange(num_tets, dtype=torch.long, device="cuda")[valid_tets]
        face_gidx_pre = torch.cat((
            tet_gidx[num_triangles == 1]*2,
            torch.stack((tet_gidx[num_triangles == 2]*2, tet_gidx[num_triangles == 2]*2 + 1), dim=-1).view(-1)
        ), dim=0)

        # Get uv before mSDF processing
        uvs_pre, uv_idx_pre = self.map_uv(face_gidx_pre, num_tets*2)

        # Generate triangle indices before msdf processing
        faces = torch.cat((
            torch.gather(input=idx_map[num_triangles == 1], dim=1,
                index=self.triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1,3),
            torch.gather(input=idx_map[num_triangles == 2], dim=1,
                index=self.triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1,3),
        ), dim=0)

        v_nrm, t_nrm_idx = auto_normals(verts, faces)
        v_tng, _ = compute_tangents(verts, uvs_pre, v_nrm, faces, faces, faces)

        ###### Triangulation with mSDF
        ### Note: we allow area-0 triangular faces for convenience. Can always remove them during post-processing
        with torch.no_grad():
            # Edge cycles of the per-tet cross-section polygons (tri or quad).
            mesh_edge_tri = torch.gather(input=idx_map[num_triangles == 1], dim=1,
                index=self.mesh_edge_table[tetindex[num_triangles == 1]][:, [0, 1, 1, 2, 2, 0]]
            ).view(-1, 3, 2)
            mesh_edge_quad = torch.gather(input=idx_map[num_triangles == 2], dim=1,
                index=self.mesh_edge_table[tetindex[num_triangles == 2]][:, [0, 1, 1, 2, 2, 3, 3, 0]]
            ).view(-1, 4, 2)
            # mSDF sign pattern of each polygon's corners (lookup-table keys).
            mocc_fx3 = (msdf_vert[mesh_edge_tri[:, :, 0].reshape(-1)].reshape(-1, 3) > 0).long()
            mocc_fx4 = (msdf_vert[mesh_edge_quad[:, :, 0].reshape(-1)].reshape(-1, 4) > 0).long()

        ### Attributes to be interpolated for (non-watertight) mesh vertices on the boundary
        edges_to_interp_vpos_tri = verts[mesh_edge_tri.reshape(-1)].reshape(-1,2,3)
        edges_to_interp_vpos_quad = verts[mesh_edge_quad.reshape(-1)].reshape(-1,2,3)
        edges_to_interp_tng_tri = v_tng[mesh_edge_tri.reshape(-1)].reshape(-1,2,3)
        edges_to_interp_tng_quad = v_tng[mesh_edge_quad.reshape(-1)].reshape(-1,2,3)
        edges_to_interp_msdf_tri = msdf_vert[mesh_edge_tri.reshape(-1)].reshape(-1,2,1)
        edges_to_interp_msdf_quad = msdf_vert[mesh_edge_quad.reshape(-1)].reshape(-1,2,1)
        edges_to_interp_msdf_tri_stopvgd = msdf_vert_stopvgd[mesh_edge_tri.reshape(-1)].reshape(-1,2,1)
        edges_to_interp_msdf_quad_stopvgd = msdf_vert_stopvgd[mesh_edge_quad.reshape(-1)].reshape(-1,2,1)

        ### Linear interpolation on mesh edges (triangle / quad faces)
        # An edge is interpolatable only if its two mSDF signs differ...
        denominator_tri_nonzero = torch.sign(edges_to_interp_msdf_tri[:,:,0]).sum(dim=1).abs() != 2
        denominator_quad_nonzero = torch.sign(edges_to_interp_msdf_quad[:,:,0]).sum(dim=1).abs() != 2
        edges_to_interp_msdf_tri[:,-1] *= -1
        edges_to_interp_msdf_quad[:,-1] *= -1
        denominator_tri = edges_to_interp_msdf_tri.sum(1, keepdim=True)
        denominator_quad = edges_to_interp_msdf_quad.sum(1, keepdim=True)
        # ...and the denominator is not (near) zero.
        denominator_tri_nonzero = (denominator_tri[:,0,0].abs() > 1e-12) & denominator_tri_nonzero
        denominator_quad_nonzero = (denominator_quad[:,0,0].abs() > 1e-12) & denominator_quad_nonzero

        edges_to_interp_msdf_tri_new = torch.zeros_like(edges_to_interp_msdf_tri)
        edges_to_interp_msdf_quad_new = torch.zeros_like(edges_to_interp_msdf_quad)
        edges_to_interp_msdf_tri_new[denominator_tri_nonzero] = torch.flip(edges_to_interp_msdf_tri[denominator_tri_nonzero], [1]) / denominator_tri[denominator_tri_nonzero]
        edges_to_interp_msdf_quad_new[denominator_quad_nonzero] = torch.flip(edges_to_interp_msdf_quad[denominator_quad_nonzero], [1]) / denominator_quad[denominator_quad_nonzero]
        edges_to_interp_msdf_tri = edges_to_interp_msdf_tri_new
        edges_to_interp_msdf_quad = edges_to_interp_msdf_quad_new

        ### Append additional boundary vertices (with negligible corner cases). Notice that unused vertices are included for efficiency reasons.
        verts_aug = torch.cat([
            verts,
            (edges_to_interp_vpos_tri * edges_to_interp_msdf_tri).sum(1),
            (edges_to_interp_vpos_quad * edges_to_interp_msdf_quad).sum(1)
        ], dim=0)
        v_tng_aug = torch.cat([
            v_tng,
            (edges_to_interp_tng_tri * edges_to_interp_msdf_tri).sum(1),
            (edges_to_interp_tng_quad * edges_to_interp_msdf_quad).sum(1)
        ], dim=0)

        ### NOTE: important to stop gradients from passing through the 'interpolation coefficients' (basically the 'coordinates' of boundary vertices)
        msdf_vert_tri_stopvgd = (edges_to_interp_msdf_tri_stopvgd * edges_to_interp_msdf_tri.detach()).sum(1).squeeze(dim=-1)
        msdf_vert_quad_stopvgd = (edges_to_interp_msdf_quad_stopvgd * edges_to_interp_msdf_quad.detach()).sum(1).squeeze(dim=-1)
        msdf_vert_aug_stopvgd = torch.cat([
            msdf_vert_stopvgd,
            msdf_vert_tri_stopvgd,
            msdf_vert_quad_stopvgd,
        ])
        msdf_vert_boundary_stopvgd = msdf_vert_aug_stopvgd[msdf_vert.size(0):] ## not all boundary vertices but good enough

        ### Determine how to cut polygon faces by checking the look-up tables
        with torch.no_grad():
            v_id_msdf_tri = torch.flip(torch.pow(2, torch.arange(3, dtype=torch.long, device="cuda")), dims=[0])
            v_id_msdf_quad = torch.flip(torch.pow(2, torch.arange(4, dtype=torch.long, device="cuda")), dims=[0])
            mesh_index_tri = (mocc_fx3 * v_id_msdf_tri.unsqueeze(0)).sum(-1)
            mesh_index_quad = (mocc_fx4 * v_id_msdf_quad.unsqueeze(0)).sum(-1)

            # Per polygon: its original corner ids followed by the ids of the
            # freshly appended boundary vertices (one per polygon edge).
            idx_map_tri = torch.cat([mesh_edge_tri[:, :, 0], verts.size(0) + torch.arange(mesh_edge_tri.size(0) * 3, device='cuda').view(-1, 3)], dim=-1)
            idx_map_quad = torch.cat([mesh_edge_quad[:, :, 0], verts.size(0) + mesh_edge_tri.size(0) * 3 + torch.arange(mesh_edge_quad.size(0) * 4, device='cuda').view(-1, 4)], dim=-1)

            num_triangles_tri = self.num_triangles_tri_table[mesh_index_tri]
            num_triangles_quad = self.num_triangles_quad_table[mesh_index_quad]

        ### Cut the polygon faces (case-by-case)
        faces_aug = torch.cat((
            torch.gather(input=idx_map_tri[num_triangles_tri == 1], dim=1,
                index=self.triangle_table_tri[mesh_index_tri[num_triangles_tri == 1]][:, :3]).view(-1, 3),
            torch.gather(input=idx_map_tri[num_triangles_tri == 2], dim=1,
                index=self.triangle_table_tri[mesh_index_tri[num_triangles_tri == 2]][:, :6]).view(-1, 3),
            torch.gather(input=idx_map_quad[num_triangles_quad == 1], dim=1,
                index=self.triangle_table_quad[mesh_index_quad[num_triangles_quad == 1]][:, :3]).view(-1, 3),
            torch.gather(input=idx_map_quad[num_triangles_quad == 2], dim=1,
                index=self.triangle_table_quad[mesh_index_quad[num_triangles_quad == 2]][:, :6]).view(-1, 3),
            torch.gather(input=idx_map_quad[num_triangles_quad == 3], dim=1,
                index=self.triangle_table_quad[mesh_index_quad[num_triangles_quad == 3]][:, :9]).view(-1, 3),
            torch.gather(input=idx_map_quad[num_triangles_quad == 4], dim=1,
                index=self.triangle_table_quad[mesh_index_quad[num_triangles_quad == 4]][:, :12]).view(-1, 3),
        ), dim=0)

        ### Mark all unused vertices (only for convenience in visualization; not necessary)
        with torch.no_grad():
            referenced_vert_idx = faces_aug.unique()
            mask = torch.ones(verts_aug.size(0))
            mask[referenced_vert_idx] = 0
            verts_aug[mask.bool()] = 0  # zero out positions of unreferenced vertices

        if output_watertight_template:
            extra = {
                'n_verts_watertight': verts.size(0),
                'vertices_watertight': verts,
                'faces_watertight': faces,
                'v_tng_watertight': v_tng,
                'msdf': msdf_vert_aug_stopvgd,
                'msdf_watertight': msdf_vert_stopvgd,
                'msdf_boundary': msdf_vert_boundary_stopvgd,
            }
        else:
            extra = {
                'msdf': msdf_vert_aug_stopvgd,
                'msdf_watertight': msdf_vert_stopvgd,
                'msdf_boundary': msdf_vert_boundary_stopvgd,
            }

        return verts_aug, faces_aug, None, None, v_tng_aug, extra

    @torch.no_grad()
    def marching_from_auggrid(self, pos_nx3, sdf_n, tet_fx4, sorted_tet_edges_fx6x2, coeff_sdf_interp,
            verts_discretized, midpoint_msdf_sign_n, occgrid
        ):
        """Inference-only marching tets driven by precomputed (augmented-grid)
        interpolation coefficients and occupancy — presumably used to decode
        diffusion-generated grids; confirm against caller."""
        sdf_n = sdf_n.float()

        ### To determine if tets are valid
        ### Step 1: SDF criteria
        occ_n = sdf_n > 0
        occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1,4)
        occ_sum = torch.sum(occ_fx4, -1)
        valid_tets = (occ_sum>0) & (occ_sum<4)

        v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device="cuda"))
        tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)

        # find all vertices
        all_edges = sorted_tet_edges_fx6x2.reshape(-1, 6, 2)[valid_tets].reshape(-1, 2)
        all_edges = all_edges.view(-1, 1, 2)
        unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)

        unique_edges = unique_edges.long()
        mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1,2).sum(-1) == 1
        mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device="cuda") * -1
        mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long, device="cuda")
        idx_map = mapping[idx_map] # map edges to verts

        interp_v = unique_edges[mask_edges]
        edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1,2,3)
        edges_to_interp_canonical = verts_discretized[interp_v.reshape(-1)].reshape(-1,2,3).float()
        verts_canonical = (edges_to_interp_canonical[:, 0] + edges_to_interp_canonical[:, 1]) / 2.0
        # Edge midpoints in discretized coords index the precomputed grids.
        tetedge_cano_midpts = verts_discretized[interp_v.reshape(-1)].float().reshape(-1,2,3).mean(dim=1).long()
        coeff_sdf_interp = coeff_sdf_interp[tetedge_cano_midpts[:, 0], tetedge_cano_midpts[:, 1], tetedge_cano_midpts[:, 2]].view(-1, 1).clamp(0, 1)
        verts = edges_to_interp[:, 1] * coeff_sdf_interp + edges_to_interp[:, 0] * (1 - coeff_sdf_interp)
        msdf_vert =
midpoint_msdf_sign_n[tetedge_cano_midpts[:, 0], tetedge_cano_midpts[:, 1], tetedge_cano_midpts[:, 2]] # (M, 6), M: num of pre-filtered tets, storing indices (besides -1) from 0 to num_mask_edges idx_map = idx_map.reshape(-1,6) v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device="cuda")) tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1) # triangle count num_triangles = self.num_triangles_table[tetindex] # Get global face index (static, does not depend on topology), before mSDF processing num_tets = tet_fx4.shape[0] tet_gidx = torch.arange(num_tets, dtype=torch.long, device="cuda")[valid_tets] face_gidx_pre = torch.cat(( tet_gidx[num_triangles == 1]*2, torch.stack((tet_gidx[num_triangles == 2]*2, tet_gidx[num_triangles == 2]*2 + 1), dim=-1).view(-1) ), dim=0) valid_tet_gidx = torch.cat([tet_gidx[num_triangles == 1], tet_gidx[num_triangles == 2]], dim=0) # Get uv before mSDF processing uvs_pre, uv_idx_pre = self.map_uv(face_gidx_pre, num_tets*2) # Generate triangle indices before vis processing faces = torch.cat(( torch.gather(input=idx_map[num_triangles == 1], dim=1, index=self.triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1,3), torch.gather(input=idx_map[num_triangles == 2], dim=1, index=self.triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1,3), ), dim=0) v_nrm, t_nrm_idx = auto_normals(verts, faces) v_tng, _ = compute_tangents(verts, uvs_pre, v_nrm, faces, faces, faces) ###### Triangulation with mSDF # edge_indices_tri = self.pre_mesh_edge_table[tetindex[num_triangles == 1]][:, [0, 1, 1, 2, 2, 0]] # edge_indices_quad = self.pre_mesh_edge_table[tetindex[num_triangles == 2]][:, [0, 1, 1, 2, 2, 3, 3, 0]] # pre_mesh_edge_tri = torch.gather(input=idx_map[num_triangles == 1], dim=1, # index=edge_indices_tri # ).view(-1, 3, 2) # pre_mesh_edge_quad = torch.gather(input=idx_map[num_triangles == 2], dim=1, # index=edge_indices_quad # ).view(-1, 4, 2) pre_mesh_edge_tri = torch.gather(input=idx_map[num_triangles == 1], dim=1, 
index=self.mesh_edge_table[tetindex[num_triangles == 1]][:, [0, 1, 1, 2, 2, 0]] ).view(-1, 3, 2) pre_mesh_edge_quad = torch.gather(input=idx_map[num_triangles == 2], dim=1, index=self.mesh_edge_table[tetindex[num_triangles == 2]][:, [0, 1, 1, 2, 2, 3, 3, 0]] ).view(-1, 4, 2) msdf_positive_fx3 = (msdf_vert[pre_mesh_edge_tri[:, :, 0].reshape(-1)].reshape(-1, 3) > 0).long() msdf_positive_fx4 = (msdf_vert[pre_mesh_edge_quad[:, :, 0].reshape(-1)].reshape(-1, 4) > 0).long() edges_to_interp_prevert_tri = verts[pre_mesh_edge_tri.reshape(-1)].reshape(-1,2,3) edges_to_interp_prevert_quad = verts[pre_mesh_edge_quad.reshape(-1)].reshape(-1,2,3) edges_to_interp_pretng_tri = v_tng[pre_mesh_edge_tri.reshape(-1)].reshape(-1,2,3) edges_to_interp_pretng_quad = v_tng[pre_mesh_edge_quad.reshape(-1)].reshape(-1,2,3) edges_to_interp_sort_tri = verts_canonical[pre_mesh_edge_tri.reshape(-1)].reshape(-1,2,3) edges_to_interp_sort_quad = verts_canonical[pre_mesh_edge_quad.reshape(-1)].reshape(-1,2,3) meshocc_loc_tri = (edges_to_interp_sort_tri.mean(dim=1) * 2.0).long() meshocc_loc_quad = (edges_to_interp_sort_quad.mean(dim=1) * 2.0).long() msdf_coeff_tri = occgrid[meshocc_loc_tri[:, 0], meshocc_loc_tri[:, 1], meshocc_loc_tri[:, 2]] * 0.5 + 0.5 msdf_coeff_quad = occgrid[meshocc_loc_quad[:, 0], meshocc_loc_quad[:, 1], meshocc_loc_quad[:, 2]] * 0.5 + 0.5 msdf_coeff_tri = torch.stack([msdf_coeff_tri, 1 - msdf_coeff_tri], dim=-1) msdf_coeff_quad = torch.stack([msdf_coeff_quad, 1 - msdf_coeff_quad], dim=-1) inscribed_edge_twopoint_order_tri = torch.sign(edges_to_interp_sort_tri[:, 0, :] - edges_to_interp_sort_tri[:, 1, :]) inscribed_edge_twopoint_order_tri = (inscribed_edge_twopoint_order_tri * torch.tensor([16, 4, 1], device=inscribed_edge_twopoint_order_tri.device).view(1, -1)).sum(dim=-1) inscribed_edge_twopoint_order_tri = torch.stack([inscribed_edge_twopoint_order_tri, -inscribed_edge_twopoint_order_tri], dim=-1) _, inscribed_edge_twopoint_order_tri = 
inscribed_edge_twopoint_order_tri.sort(dim=-1, descending=True) inscribed_edge_twopoint_order_quad = torch.sign(edges_to_interp_sort_quad[:, 0, :] - edges_to_interp_sort_quad[:, 1, :]) inscribed_edge_twopoint_order_quad = (inscribed_edge_twopoint_order_quad * torch.tensor([16, 4, 1], device=inscribed_edge_twopoint_order_quad.device).view(1, -1)).sum(dim=-1) inscribed_edge_twopoint_order_quad = torch.stack([inscribed_edge_twopoint_order_quad, -inscribed_edge_twopoint_order_quad], dim=-1) _, inscribed_edge_twopoint_order_quad = inscribed_edge_twopoint_order_quad.sort(dim=-1, descending=True) msdf_coeff_tri = torch.gather( input=msdf_coeff_tri, dim=-1, index=inscribed_edge_twopoint_order_tri.view(-1, 2) ).view(-1, 2, 1) msdf_coeff_quad = torch.gather( input=msdf_coeff_quad, dim=-1, index=inscribed_edge_twopoint_order_quad.view(-1, 2) ).view(-1, 2, 1) msdf_coeff_tri = msdf_coeff_tri.view(-1, 2, 1) msdf_coeff_quad = msdf_coeff_quad.view(-1, 2, 1) verts_aug = torch.cat([ verts, (edges_to_interp_prevert_tri * msdf_coeff_tri).sum(1), (edges_to_interp_prevert_quad * msdf_coeff_quad).sum(1), ], dim=0) v_tng_aug = torch.cat([ v_tng, (edges_to_interp_pretng_tri * msdf_coeff_tri).sum(1), (edges_to_interp_pretng_quad * msdf_coeff_quad).sum(1), ], dim=0) msdf_vert_aug = torch.cat([ msdf_vert, torch.zeros(v_tng_aug.size(0) - v_tng.size(0)).cuda() ]) v_id_msdf_tri = torch.flip(torch.pow(2, torch.arange(3, dtype=torch.long, device="cuda")), dims=[0]) ## do this flip because the triangle table uses a different assumption by mistake.. 
v_id_msdf_quad = torch.flip(torch.pow(2, torch.arange(4, dtype=torch.long, device="cuda")), dims=[0]) premesh_index_tri = (msdf_positive_fx3 * v_id_msdf_tri.unsqueeze(0)).sum(-1) premesh_index_quad = (msdf_positive_fx4 * v_id_msdf_quad.unsqueeze(0)).sum(-1) idx_map_tri = torch.cat([pre_mesh_edge_tri[:, :, 0], verts.size(0) + torch.arange(pre_mesh_edge_tri.size(0) * 3, device='cuda').view(-1, 3)], dim=-1) idx_map_quad = torch.cat([pre_mesh_edge_quad[:, :, 0], verts.size(0) + pre_mesh_edge_tri.size(0) * 3 + torch.arange(pre_mesh_edge_quad.size(0) * 4, device='cuda').view(-1, 4)], dim=-1) num_triangles_tri = self.num_triangles_tri_table[premesh_index_tri] num_triangles_quad = self.num_triangles_quad_table[premesh_index_quad] faces_aug = torch.cat(( torch.gather(input=idx_map_tri[num_triangles_tri == 1], dim=1, index=self.triangle_table_tri[premesh_index_tri[num_triangles_tri == 1]][:, :3]).view(-1, 3), torch.gather(input=idx_map_tri[num_triangles_tri == 2], dim=1, index=self.triangle_table_tri[premesh_index_tri[num_triangles_tri == 2]][:, :6]).view(-1, 3), torch.gather(input=idx_map_quad[num_triangles_quad == 1], dim=1, index=self.triangle_table_quad[premesh_index_quad[num_triangles_quad == 1]][:, :3]).view(-1, 3), torch.gather(input=idx_map_quad[num_triangles_quad == 2], dim=1, index=self.triangle_table_quad[premesh_index_quad[num_triangles_quad == 2]][:, :6]).view(-1, 3), torch.gather(input=idx_map_quad[num_triangles_quad == 3], dim=1, index=self.triangle_table_quad[premesh_index_quad[num_triangles_quad == 3]][:, :9]).view(-1, 3), torch.gather(input=idx_map_quad[num_triangles_quad == 4], dim=1, index=self.triangle_table_quad[premesh_index_quad[num_triangles_quad == 4]][:, :12]).view(-1, 3), ), dim=0) return verts_aug, faces_aug, None, None, v_tng_aug, verts, valid_tet_gidx, msdf_vert_aug, msdf_vert ================================================ FILE: geometry/gshell_tets_geometry.py ================================================ # Copyright (c) 2020-2022 NVIDIA 
CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual # property and proprietary rights in and to this material, related # documentation and any modifications thereto. Any use, reproduction, # disclosure or distribution of this material and related documentation # without an express license agreement from NVIDIA CORPORATION or # its affiliates is strictly prohibited. import os from tqdm import trange import numpy as np import torch import torch.nn.functional as F from render import mesh from render import render import render.optixutils as ou from render import regularizer from .gshell_tets import GShell_Tets import kaolin from .mlp import MLP ############################################################################### # Regularizer ############################################################################### def compute_sdf_reg_loss(sdf, all_edges): sdf_f1x6x2 = sdf[all_edges.reshape(-1)].reshape(-1,2) mask = torch.sign(sdf_f1x6x2[...,0]) != torch.sign(sdf_f1x6x2[...,1]) sdf_f1x6x2 = sdf_f1x6x2[mask] sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[...,0], (sdf_f1x6x2[...,1] > 0).float()) + \ torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[...,1], (sdf_f1x6x2[...,0] > 0).float()) return sdf_diff ############################################################################### # Geometry interface ############################################################################### class GShellTetsGeometry(torch.nn.Module): def __init__(self, grid_res, scale, FLAGS, offset=None, tet_init_file=None, extract_from_generative=False): super(GShellTetsGeometry, self).__init__() self.FLAGS = FLAGS self.grid_res = grid_res self.gshell_tets = GShell_Tets() self.scale = scale self.boxscale = torch.tensor(FLAGS.boxscale).view(1, 3).cuda() with torch.no_grad(): self.optix_ctx = ou.OptiXContext() if tet_init_file is None: tets = 
np.load('data/tets/{}_tets.npz'.format(self.grid_res)) else: tets = np.load(tet_init_file) print(f'using resolution {self.grid_res}') self.verts = torch.tensor(tets['vertices'], dtype=torch.float32, device='cuda') self.original_verts = self.verts.clone() if extract_from_generative else None self.verts = self.verts - self.verts.mean(dim=0) self.verts = self.verts * scale * self.boxscale self.indices = torch.tensor(tets['indices'], dtype=torch.long, device='cuda') self.generate_edges() if extract_from_generative: self.sorted_tetedges = torch.tensor(tets['tet_edges'], dtype=torch.long, device='cuda') vertices = torch.tensor(tets['vertices'], dtype=torch.float32, device='cuda') vertices_unique = vertices.view(-1).unique() dx = (vertices_unique[1] - vertices_unique[0]) / 2.0 ### denser grid for edge + tet features vertices_discretized = ( ((vertices - vertices.min()) / dx) ).long() self.verts_discretized = vertices_discretized.long().float() ### used to identify where to store edge + tet features if offset is None: offset = 0.0 else: offset = torch.tensor(offset).cuda().view(1, 3) self.offset = offset if self.FLAGS.use_sdf_mlp: self.sdf = torch.nn.Parameter(torch.zeros_like(self.verts[:, 0]), requires_grad=True) ## placeholder self.register_parameter('sdf', self.sdf) self.sdf_net = MLP( skip_in=self.FLAGS.skip_in, n_freq=self.FLAGS.n_freq, n_hidden=self.FLAGS.n_hidden, d_hidden=self.FLAGS.d_hidden, use_float16=self.FLAGS.use_float16 ) self.sdf_net.cuda() optimizer = torch.optim.Adam(self.sdf_net.parameters(), lr=1e-3) for _ in trange(self.FLAGS.sdf_mlp_pretrain_steps): scaled_verts = self.verts / self.boxscale loss = (self.sdf_net(self.verts) - (scaled_verts.norm(dim=1, keepdim=True) - self.FLAGS.sphere_init_norm)).pow(2).mean() optimizer.zero_grad() loss.backward() optimizer.step() print('sdf net trained with loss:', loss) else: # Random init if not self.FLAGS.sphere_init: sdf = torch.rand_like(self.verts[:,0]) - 0.1 else: scaled_verts = self.verts / self.boxscale sdf 
= scaled_verts.norm(dim=1) - 0.5 self.sdf = torch.nn.Parameter(sdf.clone().detach(), requires_grad=True) self.register_parameter('sdf', self.sdf) if self.FLAGS.use_msdf_mlp: self.msdf = torch.nn.Parameter(torch.zeros_like(self.verts[:, 0]), requires_grad=True) ## placeholder self.register_parameter('msdf', self.msdf) self.msdf_net = MLP( skip_in=self.FLAGS.skip_in, n_freq=self.FLAGS.n_freq, n_hidden=self.FLAGS.n_hidden, d_hidden=self.FLAGS.d_hidden, use_float16=self.FLAGS.use_float16 ) self.msdf_net.cuda() optimizer = torch.optim.Adam(self.msdf_net.parameters(), lr=1e-3) for _ in trange(100): scaled_verts = self.verts / self.boxscale loss = (self.msdf_net(self.verts) - 0.1).pow(2).mean() optimizer.zero_grad() loss.backward() optimizer.step() print('sdf net trained with loss:', loss) del optimizer else: msdf = (torch.rand_like(self.verts[:,0]) - 0.01).clamp(-1, 1) self.msdf = torch.nn.Parameter(msdf.clone().detach(), requires_grad=True) self.register_parameter('msdf', self.msdf) self.deform = torch.nn.Parameter(torch.zeros_like(self.verts), requires_grad=True) self.register_parameter('deform', self.deform) self.clamp_deform() @torch.no_grad() def generate_edges(self): with torch.no_grad(): edges = torch.tensor([0,1,0,2,0,3,1,2,1,3,2,3], dtype = torch.long, device = "cuda") all_edges = self.indices[:,edges].reshape(-1,2) all_edges_sorted = torch.sort(all_edges, dim=1)[0] self.all_edges = torch.unique(all_edges_sorted, dim=0) self.max_displacement = 1.0 / self.grid_res * self.scale / 2.1 @torch.no_grad() def getAABB(self): return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values @torch.no_grad() def clamp_deform(self): if not self.FLAGS.use_tanh_deform: self.deform.data[:] = self.deform.clamp(-1.0, 1.0) self.msdf.data[:] = self.msdf.clamp(-2.0, 2.0) def getMesh_from_augmented_grid_withocc(self, material, sdf_sign, sdf_coeff, msdf_sign, occgrid): # Run DM tet to get a base mesh v_deformed = self.verts + self.max_displacement * self.deform if 
self.FLAGS.use_sdf_mlp:
            sdf = self.sdf_net(v_deformed)
        else:
            sdf = self.sdf
        verts, faces, uvs, uv_idx, v_tng, v_pos_original, tet_gidx, v_msdf, msdf_vert_original = self.gshell_tets.marching_from_auggrid(
            v_deformed, sdf_sign, self.indices, self.sorted_tetedges, sdf_coeff,
            self.verts_discretized, msdf_sign, occgrid)
        imesh = mesh.Mesh(verts, faces, v_tex=uvs, t_tex_idx=uv_idx, material=material)

        # Run mesh operations to generate tangent space
        imesh = mesh.auto_normals(imesh)
        imesh = mesh.compute_tangents(imesh, v_tng=v_tng)
        return {
            'imesh': imesh,
            'sdf': sdf,
            'v_msdf': v_msdf,
        }

    def getMesh(self, material):
        """Differentiably extract the current G-Shell mesh and build its OptiX BVH.

        Returns a dict with the extracted mesh plus the SDF/mSDF tensors needed by
        the regularizers in tick(); when FLAGS.visualize_watertight is set, also
        includes the watertight template mesh.
        """
        v_deformed = self.verts + self.max_displacement * self.deform
        if self.FLAGS.use_sdf_mlp:
            sdf = self.sdf_net(v_deformed)
        else:
            sdf = self.sdf
        if self.FLAGS.use_msdf_mlp:
            msdf = self.msdf_net(v_deformed)
        else:
            msdf = self.msdf
        v_deformed = v_deformed + self.offset
        verts, faces, uvs, uv_idx, v_tng, extra = self.gshell_tets(
            v_deformed, sdf, msdf, self.indices)
        imesh = mesh.Mesh(verts, faces, v_tex=uvs, t_tex_idx=uv_idx, material=material)

        # BVH rebuild is bookkeeping only; keep it out of the autograd graph.
        with torch.no_grad():
            ou.optix_build_bvh(self.optix_ctx, imesh.v_pos.contiguous(), imesh.t_pos_idx.int(), rebuild=1)

        # Run mesh operations to generate tangent space
        imesh = mesh.auto_normals(imesh)
        # NOTE(review): 'n_verts_watertight' is only present in `extra` when the
        # extractor was asked for the watertight template — confirm the call above
        # always requests it.
        return_dict = {
            'imesh': imesh,
            'sdf': sdf,
            'msdf': extra['msdf'],
            'msdf_watertight': extra['msdf_watertight'],
            'msdf_boundary': extra['msdf_boundary'],
            'n_verts_watertight': extra['n_verts_watertight'],
        }
        if self.FLAGS.visualize_watertight:
            imesh_watertight = mesh.Mesh(extra['vertices_watertight'], extra['faces_watertight'], v_tex=None, t_tex_idx=None, material=material)
            imesh_watertight = mesh.auto_normals(imesh_watertight)
            return_dict['imesh_watertight'] = imesh_watertight
        return return_dict

    def render(self, glctx, target, lgt, opt_material, bsdf=None, denoiser=None, shadow_scale=1.0, use_uv=False):
        """Extract the mesh, sample surface points (for the Eikonal term) and
        rasterize it into `buffers`; returns the getMesh() dict augmented with
        'sampled_pts', 'buffers' and optionally 'buffers_watertight'."""
        opt_mesh_dict = self.getMesh(opt_material)
        opt_mesh = opt_mesh_dict['imesh']
        opt_mesh_watertight = opt_mesh_dict['imesh_watertight'] if 'imesh_watertight' in opt_mesh_dict else None
        if opt_mesh.v_pos.size(0) != 0:
            sampled_pts = kaolin.ops.mesh.sample_points(opt_mesh.v_pos[None,...], opt_mesh.t_pos_idx, 50000)[0][0]
            opt_mesh_dict['sampled_pts'] = sampled_pts
        else:
            # Degenerate (empty) mesh — downstream Eikonal loss is skipped.
            opt_mesh_dict['sampled_pts'] = None
        extra_dict = {
            'msdf': opt_mesh_dict['msdf'],
        }
        opt_mesh_dict['buffers'] = render.render_mesh(
            self.FLAGS, glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'],
            spp=target['spp'], msaa=True, background=target['background'], bsdf=bsdf,
            use_uv=use_uv, optix_ctx=self.optix_ctx, denoiser=denoiser, shadow_scale=shadow_scale,
            extra_dict=extra_dict)
        if self.FLAGS.visualize_watertight:
            opt_mesh_dict['buffers_watertight'] = render.render_mesh(
                self.FLAGS, glctx, opt_mesh_watertight, target['mvp'], target['campos'], lgt, target['resolution'],
                spp=target['spp'], msaa=True, background=target['background'], bsdf=bsdf,
                use_uv=use_uv, optix_ctx=self.optix_ctx, denoiser=denoiser, shadow_scale=shadow_scale,
                extra_dict=extra_dict)
        return opt_mesh_dict

    def tick(self, glctx, target, lgt, opt_material, loss_fn, iteration, denoiser):
        """One optimization step's losses: returns (img_loss, depth_loss, reg_loss)."""
        t_iter = iteration / self.FLAGS.iter

        # ==============================================================================================
        #  Render optimizable object with identical conditions
        # ==============================================================================================
        shadow_ramp = min(iteration / 1000, 1.0) ### set occlusion ray influence
        if denoiser is not None:
            denoiser.set_influence(shadow_ramp)
        opt_mesh_dict = self.render(glctx, target, lgt, opt_material, denoiser=denoiser, shadow_scale=shadow_ramp)
        buffers = opt_mesh_dict['buffers']

        # ==============================================================================================
        #  Compute loss
        # ==============================================================================================
        with torch.no_grad():
            # Image-space loss, split into a coverage component and a color component.
            # (Reference image only — the rendered buffers below must keep gradients.)
            color_ref = target['img']
            gt_mask = color_ref[..., 3:]
        img_loss = F.mse_loss(buffers['shaded'][..., 3:], color_ref[..., 3:])
        img_loss = img_loss + loss_fn(buffers['shaded'][..., 0:3] * color_ref[..., 3:], color_ref[..., 0:3] * color_ref[..., 3:])
        # Supervise the rendered mSDF: negative outside the GT mask, positive inside.
        img_loss = img_loss + 5e-1 * F.l1_loss(buffers['msdf_image'].clamp(min=0) * (gt_mask == 0).float(), torch.zeros_like(gt_mask))
        # NOTE(review): target of ones for a max-clamped (<= 0) quantity looks
        # suspicious — verify against the repo whether zeros_like was intended.
        img_loss = img_loss + 5e-1 * F.l1_loss(buffers['msdf_image'].clamp(max=0) * (gt_mask == 1).float(), torch.ones_like(gt_mask))

        if self.FLAGS.use_img_2nd_layer:
            color_ref_2nd = target['img_second']
            img_loss = img_loss + F.mse_loss(buffers['shaded_second'][..., 3:], color_ref_2nd[..., 3:])
            img_loss = img_loss + loss_fn(buffers['shaded_second'][..., 0:3] * color_ref_2nd[..., 3:], color_ref_2nd[..., 0:3] * color_ref_2nd[..., 3:])

        if self.FLAGS.use_depth:
            depth_loss_scale = 100.
            depth_loss = depth_loss_scale * ((buffers['invdepth'][:, :, :, :1] - target['invdepth'][:, :, :, :1]).abs()).mean()
            if self.FLAGS.use_depth_2nd_layer:
                depth_loss += 0.1 * depth_loss_scale * ((buffers['invdepth_second'][:, :, :, :1] - target['invdepth_second'][:, :, :, :1]).abs()).mean()
        else:
            depth_loss = torch.tensor(0., device=img_loss.device)

        # Eikonal: unit-gradient-norm penalty of the SDF MLP at sampled surface points.
        if self.FLAGS.use_sdf_mlp and self.FLAGS.use_eikonal and opt_mesh_dict['sampled_pts'] is not None:
            v = opt_mesh_dict['sampled_pts'].detach()
            v.requires_grad = True
            sdf_eik = self.sdf_net(v)
            if self.FLAGS.eikonal_scale is None:
                ### Default hardcoded Eikonal loss schedule
                if iteration < 500:
                    eik_coeff = 3e-1
                elif iteration < 1000:
                    eik_coeff = 1e-1
                elif iteration < 2000:
                    eik_coeff = 1e-1
                else:
                    eik_coeff = 1e-2
            else:
                eik_coeff = self.FLAGS.eikonal_scale
            eik_loss = eik_coeff * (
                torch.autograd.grad(sdf_eik.sum(), v, create_graph=True)[0].pow(2).sum(dim=-1).sqrt() - 1
            ).pow(2).mean()
        else:
            eik_loss = torch.tensor(0., device=img_loss.device)

        if self.FLAGS.use_mesh_msdf_reg:
            mesh_msdf_regscale = (64 / self.grid_res) ** 3 # scale inversely proportional to grid_res^3
            eps = 1e-3
            open_scale = self.FLAGS.msdf_reg_open_scale
            close_scale = self.FLAGS.msdf_reg_close_scale
            eps = torch.tensor([eps]).cuda()
            # "Open" term: pull all mesh mSDF values down toward -eps.
            if open_scale > 0:
                mesh_msdf_reg_loss = open_scale * mesh_msdf_regscale * F.huber_loss(
                    opt_mesh_dict['msdf'].clamp(min=-eps).squeeze(),
                    -eps.expand(opt_mesh_dict['msdf'].size(0)),
                    reduction='sum'
                )
            else:
                mesh_msdf_reg_loss = torch.tensor(0., device=img_loss.device)
            # "Close" term: push mSDF up (toward +eps) only on *visible* boundary verts.
            if close_scale != 0:
                with torch.no_grad():
                    visible_verts = (opt_mesh_dict['imesh'].t_pos_idx[buffers['visible_triangles']]).unique()
                    # Boundary verts are appended after the watertight verts, hence the offset.
                    visible_boundary_verts = visible_verts[visible_verts >= opt_mesh_dict['n_verts_watertight']] - opt_mesh_dict['n_verts_watertight']
                    visible_boundary_mask = torch.zeros(opt_mesh_dict['msdf_boundary'].size(0)).cuda()
                    visible_boundary_mask[visible_boundary_verts] = 1
                    visible_boundary_mask = visible_boundary_mask.bool()
                boundary_msdf = opt_mesh_dict['msdf_boundary']
                boundary_msdf = boundary_msdf[visible_boundary_mask]
                mesh_msdf_reg_loss += close_scale * mesh_msdf_regscale * F.huber_loss(
                    boundary_msdf.clamp(max=eps).squeeze(),
                    eps.expand(boundary_msdf.size(0)),
                    reduction='sum'
                )
        else:
            mesh_msdf_reg_loss = torch.tensor(0., device=img_loss.device)

        # SDF regularizer (weight annealed from FLAGS.sdf_regularizer down to 0.01)
        sdf_weight = self.FLAGS.sdf_regularizer - (self.FLAGS.sdf_regularizer - 0.01) * min(1.0, 4.0 * t_iter)
        sdf_reg_loss = compute_sdf_reg_loss(opt_mesh_dict['sdf'], self.all_edges).mean() * sdf_weight

        # Monochrome shading regularizer
        if 'diffuse_light' not in buffers:
            monochrome_loss = torch.zeros_like(img_loss)
        else:
            monochrome_loss = regularizer.shading_loss(buffers['diffuse_light'], buffers['specular_light'], color_ref, self.FLAGS.lambda_diffuse, self.FLAGS.lambda_specular)

        # Material smoothness regularizer
        mtl_smooth_loss = regularizer.material_smoothness_grad(
            buffers['kd_grad'], buffers['ks_grad'], buffers['normal_grad'],
            lambda_kd=self.FLAGS.lambda_kd, lambda_ks=self.FLAGS.lambda_ks, lambda_nrm=self.FLAGS.lambda_nrm)

        # Chroma regularizer
        chroma_loss = regularizer.chroma_loss(buffers['kd'], color_ref,
self.FLAGS.lambda_chroma) assert 'perturbed_nrm' not in buffers # disable normal map in first pass geo_reg_loss = sdf_reg_loss + eik_loss + mesh_msdf_reg_loss shading_reg_loss = monochrome_loss + mtl_smooth_loss + chroma_loss reg_loss = geo_reg_loss + shading_reg_loss return img_loss, depth_loss, reg_loss ================================================ FILE: geometry/mlp.py ================================================ import torch import torch.nn as nn import numpy as np from .embedding import Embedding class MLP(nn.Module): def __init__(self, n_freq=6, d_hidden=128, d_out=1, n_hidden=3, skip_in=[], use_float16=False): super().__init__() self.emb = Embedding(3, n_freq) layers = [ nn.Linear(self.emb.out_channels, d_hidden), nn.Softplus(beta=100) ] count = 2 self.skip_count = [] self.skip_in = skip_in for i in range(n_hidden): if i in skip_in: layers.append(nn.Linear(d_hidden + self.emb.out_channels, d_hidden)) self.skip_count.append(count) else: layers.append(nn.Linear(d_hidden, d_hidden)) count += 1 layers.append(nn.Softplus(beta=100)) count += 1 layers.append(nn.Linear(d_hidden, d_out)) count += 1 self.net = nn.ModuleList(layers) self.use_float16 = use_float16 def forward(self, x): emb = self.emb(x) x = emb with torch.autocast('cuda', dtype=torch.float16, enabled=self.use_float16): for i, module in enumerate(self.net): if i in self.skip_count: x = module(torch.cat([x, emb], dim=-1)) else: x = module(x) return x ================================================ FILE: render/light.py ================================================ # Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. 
import os
import numpy as np
import torch
import nvdiffrast.torch as dr

from . import util
from . import renderutils as ru

######################################################################################
# Monte-carlo sampled environment light with PDF / CDF computation
######################################################################################

class EnvironmentLight:
    """Lat-long HDR environment light with a sampling PDF/CDF over its pixels."""

    LIGHT_MIN_RES = 16

    MIN_ROUGHNESS = 0.08
    MAX_ROUGHNESS = 0.5

    def __init__(self, base):
        # base: HxWx3 lat-long radiance image (torch tensor, possibly trainable).
        self.mtx = None
        self.base = base
        self.pdf_scale = (self.base.shape[0] * self.base.shape[1]) / (2 * np.pi * np.pi)
        self.update_pdf()

    def xfm(self, mtx):
        """Set the environment-to-world transform used at shading time."""
        self.mtx = mtx

    def parameters(self):
        """Trainable parameters: just the base image."""
        return [self.base]

    def clone(self):
        return EnvironmentLight(self.base.clone().detach())

    def clamp_(self, min=None, max=None):
        self.base.clamp_(min, max)

    def update_pdf(self):
        """Recompute the per-pixel sampling PDF and its row/column CDFs.

        Must be called after any update to ``self.base``.
        """
        with torch.no_grad():
            # Compute PDF
            Y = util.pixel_grid(self.base.shape[1], self.base.shape[0])[..., 1]
            self._pdf = torch.max(self.base, dim=-1)[0] * torch.sin(Y * np.pi) # Scale by sin(theta) for lat-long, https://cs184.eecs.berkeley.edu/sp18/article/25
            self._pdf = self._pdf / torch.sum(self._pdf)

            # Compute cumulative sums over the columns and rows
            self.cols = torch.cumsum(self._pdf, dim=1)
            self.rows = torch.cumsum(self.cols[:, -1:].repeat([1, self.cols.shape[1]]), dim=0)

            # Normalize (guarding against all-zero rows/columns)
            self.cols = self.cols / torch.where(self.cols[:, -1:] > 0, self.cols[:, -1:], torch.ones_like(self.cols))
            self.rows = self.rows / torch.where(self.rows[-1:, :] > 0, self.rows[-1:, :], torch.ones_like(self.rows))

    @torch.no_grad()
    def generate_image(self, res):
        """Resample the base image to ``res = (H, W)`` with bilinear filtering."""
        texcoord = util.pixel_grid(res[1], res[0])
        return dr.texture(self.base[None, ...].contiguous(), texcoord[None, ...].contiguous(), filter_mode='linear')[0]

######################################################################################
# Load and store
######################################################################################

@torch.no_grad()
def _load_env_hdr(fn, scale=1.0, res=None, trainable=False):
    """Load a lat-long .hdr file, optionally rescale/resample, wrap in EnvironmentLight."""
    latlong_img = torch.tensor(util.load_image(fn), dtype=torch.float32, device='cuda')*scale
    if res is not None:
        texcoord = util.pixel_grid(res[1], res[0])
        # Clamp away zeros so the image remains a valid sampling PDF.
        latlong_img = torch.clamp(dr.texture(latlong_img[None, ...], texcoord[None, ...], filter_mode='linear')[0], min=0.0001)
    print("EnvProbe,", latlong_img.shape, ", min/max", torch.min(latlong_img).item(), torch.max(latlong_img).item())

    if trainable:
        print("trainable light loaded")
        return EnvironmentLight(base=latlong_img.clone().detach().requires_grad_(True))
    else:
        return EnvironmentLight(base=latlong_img)

@torch.no_grad()
def load_env(fn, scale=1.0, res=None, trainable=False):
    """Dispatch environment loading by file extension (.hdr only)."""
    if os.path.splitext(fn)[1].lower() == ".hdr":
        return _load_env_hdr(fn, scale, res, trainable=trainable)
    else:
        assert False, "Unknown envlight extension %s" % os.path.splitext(fn)[1]

@torch.no_grad()
def save_env_map(fn, light):
    """Save the light's 512x1024 lat-long image as a raw image file."""
    assert isinstance(light, EnvironmentLight)
    color = light.generate_image([512, 1024])
    util.save_image_raw(fn, color.detach().cpu().numpy())

######################################################################################
# Create trainable with random initialization
######################################################################################

def create_trainable_env_rnd(base_res, scale=0.5, bias=0.25):
    """Create a trainable light with uniform-random base in [bias, bias + scale)."""
    base = torch.rand(base_res, base_res, 3, dtype=torch.float32, device='cuda') * scale + bias
    l = EnvironmentLight(base.clone().detach().requires_grad_(True))
    return l


================================================
FILE: render/material.py
================================================
# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.
Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import os
import numpy as np
import torch

from . import util
from . import texture
from . import mlptexture

######################################################################################
# .mtl material format loading / storing
######################################################################################

def load_mtl(fn, clear_ks=True):
    """Parse a Wavefront .mtl file into a list of material dicts.

    Scalar properties become cuda float tensors; kd/ks/normal are converted to
    Texture2D objects (constants become 1x1 maps) and kd is linearized from sRGB.
    """
    import re
    mtl_path = os.path.dirname(fn)

    # Read file
    with open(fn, 'r') as f:
        lines = f.readlines()

    # Parse materials
    materials = []
    for line in lines:
        split_line = re.split(' +|\t+|\n+', line.strip())
        prefix = split_line[0].lower()
        data = split_line[1:]
        if 'newmtl' in prefix:
            material = {'name' : data[0]}
            materials += [material]
        elif materials:
            # String-valued keys stay strings; everything else becomes a tensor.
            if 'bsdf' in prefix or 'map_kd' in prefix or 'map_ks' in prefix or 'bump' in prefix:
                material[prefix] = data[0]
            else:
                material[prefix] = torch.tensor(tuple(float(d) for d in data), dtype=torch.float32, device='cuda')

    # Convert everything to textures. Our code expects 'kd' and 'ks' to be texture maps. So replace constants with 1x1 maps
    for mat in materials:
        if not 'bsdf' in mat:
            mat['bsdf'] = 'pbr'

        if 'map_kd' in mat:
            mat['kd'] = texture.load_texture2D(os.path.join(mtl_path, mat['map_kd']))
        else:
            mat['kd'] = texture.Texture2D(mat['kd'])

        if 'map_ks' in mat:
            mat['ks'] = texture.load_texture2D(os.path.join(mtl_path, mat['map_ks']), channels=3)
        else:
            mat['ks'] = texture.Texture2D(mat['ks'])

        if 'bump' in mat:
            # Normal maps are remapped from [0, 1] to [-1, 1] on load.
            mat['normal'] = texture.load_texture2D(os.path.join(mtl_path, mat['bump']), lambda_fn=lambda x: x * 2 - 1, channels=3)

        # Convert Kd from sRGB to linear RGB
        mat['kd'] = texture.srgb_to_rgb(mat['kd'])

        if clear_ks:
            # Override ORM occlusion (red) channel by zeros. We hijack this channel
            for mip in mat['ks'].getMips():
                mip[..., 0] = 0.0

    return materials

def save_mtl(fn, material):
    """Write a single material as 'defaultMat' to an .mtl file.

    Texture maps are saved as PNGs next to the .mtl; a None material writes
    neutral constant properties instead.
    """
    folder = os.path.dirname(fn)
    with open(fn, "w") as f:
        f.write('newmtl defaultMat\n')
        if material is not None:
            f.write('bsdf %s\n' % material['bsdf'])
            if 'kd' in material.keys():
                f.write('map_Kd texture_kd.png\n')
                # kd is stored linear; convert back to sRGB for the PNG.
                texture.save_texture2D(os.path.join(folder, 'texture_kd.png'), texture.rgb_to_srgb(material['kd']))
            if 'ks' in material.keys():
                f.write('map_Ks texture_ks.png\n')
                texture.save_texture2D(os.path.join(folder, 'texture_ks.png'), material['ks'])
            if 'normal' in material.keys():
                f.write('bump texture_n.png\n')
                # Remap unit normals from [-1, 1] back to [0, 1] for storage.
                texture.save_texture2D(os.path.join(folder, 'texture_n.png'), material['normal'], lambda_fn=lambda x:(util.safe_normalize(x)+1)*0.5)
        else:
            f.write('Kd 1 1 1\n')
            f.write('Ks 0 0 0\n')
            f.write('Ka 0 0 0\n')
            f.write('Tf 1 1 1\n')
            f.write('Ni 1\n')
            f.write('Ns 0\n')

######################################################################################
# Utility function to convert an existing material and make all textures trainable
######################################################################################

def create_trainable(material):
    """Shallow-copy a material, replacing each Texture2D with a trainable one."""
    result = material.copy()
    for key, val in result.items():
        if isinstance(val, texture.Texture2D):
            result[key] = texture.create_trainable(val)
    return result

def get_parameters(material):
    """Collect trainable parameters from all texture-valued entries."""
    trainable = []
    for key, val in material.items():
        if isinstance(val, texture.Texture2D) or isinstance(val, mlptexture.MLPTexture3D):
            trainable += val.parameters()
    return trainable

######################################################################################
# Merge multiple materials into a single uber-material
######################################################################################

def _upscale_replicate(x, full_res):
    """Pad an NHWC tensor to full_res (H, W) by replicating its border pixels."""
    x = x.permute(0, 3, 1, 2)
    x = torch.nn.functional.pad(x, (0, full_res[1] - x.shape[3], 0, full_res[0] - x.shape[2]), 'replicate')
    return x.permute(0, 2, 3, 1).contiguous()

def
merge_materials(materials, texcoords, tfaces, mfaces): assert len(materials) > 0 for mat in materials: assert mat['bsdf'] == materials[0]['bsdf'], "All materials must have the same BSDF (uber shader)" assert ('normal' in mat) is ('normal' in materials[0]), "All materials must have either normal map enabled or disabled" uber_material = { 'name' : 'uber_material', 'bsdf' : materials[0]['bsdf'], } textures = ['kd', 'ks', 'normal'] # Find maximum texture resolution across all materials and textures max_res = None for mat in materials: for tex in textures: tex_res = np.array(mat[tex].getRes()) if tex in mat else np.array([1, 1]) max_res = np.maximum(max_res, tex_res) if max_res is not None else tex_res # Compute size of compund texture and round up to nearest PoT full_res = 2**np.ceil(np.log2(max_res * np.array([1, len(materials)]))).astype(np.int) # Normalize texture resolution across all materials & combine into a single large texture for tex in textures: if tex in materials[0]: tex_data = torch.cat(tuple(util.scale_img_nhwc(mat[tex].data, tuple(max_res)) for mat in materials), dim=2) # Lay out all textures horizontally, NHWC so dim2 is x tex_data = _upscale_replicate(tex_data, full_res) uber_material[tex] = texture.Texture2D(tex_data) # Compute scaling values for used / unused texture area s_coeff = [full_res[0] / max_res[0], full_res[1] / max_res[1]] # Recompute texture coordinates to cooincide with new composite texture new_tverts = {} new_tverts_data = [] for fi in range(len(tfaces)): matIdx = mfaces[fi] for vi in range(3): ti = tfaces[fi][vi] if not (ti in new_tverts): new_tverts[ti] = {} if not (matIdx in new_tverts[ti]): # create new vertex new_tverts_data.append([(matIdx + texcoords[ti][0]) / s_coeff[1], texcoords[ti][1] / s_coeff[0]]) # Offset texture coodrinate (x direction) by material id & scale to local space. 
Note, texcoords are (u,v) but texture is stored (w,h) so the indexes swap here new_tverts[ti][matIdx] = len(new_tverts_data) - 1 tfaces[fi][vi] = new_tverts[ti][matIdx] # reindex vertex return uber_material, new_tverts_data, tfaces ================================================ FILE: render/mesh.py ================================================ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual # property and proprietary rights in and to this material, related # documentation and any modifications thereto. Any use, reproduction, # disclosure or distribution of this material and related documentation # without an express license agreement from NVIDIA CORPORATION or # its affiliates is strictly prohibited. import os import numpy as np import torch from . import obj from . import util ###################################################################################### # Base mesh class ###################################################################################### class Mesh: def __init__(self, v_pos=None, t_pos_idx=None, v_nrm=None, t_nrm_idx=None, v_tex=None, t_tex_idx=None, v_tng=None, t_tng_idx=None, material=None, base=None): self.v_pos = v_pos self.v_nrm = v_nrm self.v_tex = v_tex self.v_tng = v_tng self.t_pos_idx = t_pos_idx self.t_nrm_idx = t_nrm_idx self.t_tex_idx = t_tex_idx self.t_tng_idx = t_tng_idx self.material = material if base is not None: self.copy_none(base) def copy_none(self, other): if self.v_pos is None: self.v_pos = other.v_pos if self.t_pos_idx is None: self.t_pos_idx = other.t_pos_idx if self.v_nrm is None: self.v_nrm = other.v_nrm if self.t_nrm_idx is None: self.t_nrm_idx = other.t_nrm_idx if self.v_tex is None: self.v_tex = other.v_tex if self.t_tex_idx is None: self.t_tex_idx = other.t_tex_idx if self.v_tng is None: self.v_tng = other.v_tng if self.t_tng_idx is None: self.t_tng_idx = other.t_tng_idx if self.material is 
None: self.material = other.material def clone(self): out = Mesh(base=self) if out.v_pos is not None: out.v_pos = out.v_pos.clone().detach() if out.t_pos_idx is not None: out.t_pos_idx = out.t_pos_idx.clone().detach() if out.v_nrm is not None: out.v_nrm = out.v_nrm.clone().detach() if out.t_nrm_idx is not None: out.t_nrm_idx = out.t_nrm_idx.clone().detach() if out.v_tex is not None: out.v_tex = out.v_tex.clone().detach() if out.t_tex_idx is not None: out.t_tex_idx = out.t_tex_idx.clone().detach() if out.v_tng is not None: out.v_tng = out.v_tng.clone().detach() if out.t_tng_idx is not None: out.t_tng_idx = out.t_tng_idx.clone().detach() return out ###################################################################################### # Mesh loeading helper ###################################################################################### def load_mesh(filename, mtl_override=None, mtl_default=None, mtl_type_override=None): name, ext = os.path.splitext(filename) if ext == ".obj": return obj.load_obj(filename, clear_ks=True, mtl_override=mtl_override, mtl_default=mtl_default, mtl_type_override=mtl_type_override) assert False, "Invalid mesh file extension" ###################################################################################### # Compute AABB ###################################################################################### def aabb(mesh): return torch.min(mesh.v_pos, dim=0).values, torch.max(mesh.v_pos, dim=0).values ###################################################################################### # Compute AABB with only used vertices ###################################################################################### def aabb_clean(mesh): v_pos_clean = mesh.v_pos[mesh.t_pos_idx.unique()] return torch.min(v_pos_clean, dim=0).values, torch.max(v_pos_clean, dim=0).values ###################################################################################### # Compute unique edge list from attribute/vertex index list 
######################################################################################
def compute_edges(attr_idx, return_inverse=False):
    """Return the unique undirected edges of a triangle index list.

    attr_idx: (F, 3) integer tensor of per-triangle indices.
    Returns the unique (E, 2) edges with the smaller index first; when
    return_inverse is True, torch.unique additionally returns the mapping
    from each of the 3*F packed edges to its row in the unique list.
    """
    with torch.no_grad():
        # Create all edges, packed by triangle
        all_edges = torch.cat((
            torch.stack((attr_idx[:, 0], attr_idx[:, 1]), dim=-1),
            torch.stack((attr_idx[:, 1], attr_idx[:, 2]), dim=-1),
            torch.stack((attr_idx[:, 2], attr_idx[:, 0]), dim=-1),
        ), dim=-1).view(-1, 2)

        # Swap edge order so min index is always first
        order = (all_edges[:, 0] > all_edges[:, 1]).long().unsqueeze(dim=1)
        sorted_edges = torch.cat((
            torch.gather(all_edges, 1, order),
            torch.gather(all_edges, 1, 1 - order)
        ), dim=-1)

        # Eliminate duplicates and return inverse mapping
        return torch.unique(sorted_edges, dim=0, return_inverse=return_inverse)

######################################################################################
# Compute unique edge to face mapping from attribute/vertex index list
######################################################################################
def compute_edge_to_face_mapping(attr_idx, return_inverse=False):
    """Map each unique edge to the triangles that share it.

    Returns an (E, 2) int64 CUDA tensor: column 0 receives the triangle whose
    packed edge already had ascending index order, column 1 the one whose edge
    was stored descending (so a manifold interior edge fills both columns).
    NOTE(review): `return_inverse` is accepted but never used in this body.
    """
    with torch.no_grad():
        # Get unique edges
        # Create all edges, packed by triangle
        all_edges = torch.cat((
            torch.stack((attr_idx[:, 0], attr_idx[:, 1]), dim=-1),
            torch.stack((attr_idx[:, 1], attr_idx[:, 2]), dim=-1),
            torch.stack((attr_idx[:, 2], attr_idx[:, 0]), dim=-1),
        ), dim=-1).view(-1, 2)

        # Swap edge order so min index is always first
        order = (all_edges[:, 0] > all_edges[:, 1]).long().unsqueeze(dim=1)
        sorted_edges = torch.cat((
            torch.gather(all_edges, 1, order),
            torch.gather(all_edges, 1, 1 - order)
        ), dim=-1)

        # Eliminate duplicates and return inverse mapping
        unique_edges, idx_map = torch.unique(sorted_edges, dim=0, return_inverse=True)

        # Triangle id of each packed edge (edge 3*t+k belongs to triangle t)
        tris = torch.arange(attr_idx.shape[0]).repeat_interleave(3).cuda()

        tris_per_edge = torch.zeros((unique_edges.shape[0], 2), dtype=torch.int64).cuda()

        # Compute edge to face table
        mask0 = order[:,0] == 0
        mask1 = order[:,0] == 1
        tris_per_edge[idx_map[mask0], 0] = tris[mask0]
        tris_per_edge[idx_map[mask1], 1] = tris[mask1]

        return tris_per_edge

######################################################################################
# Align base mesh to reference mesh:move & rescale to match bounding boxes.
######################################################################################
def unit_size(mesh):
    """Center the mesh on the origin and rescale so the longest AABB side is 2."""
    with torch.no_grad():
        vmin, vmax = aabb(mesh)
        scale = 2 / torch.max(vmax - vmin).item()
        v_pos = mesh.v_pos - (vmax + vmin) / 2 # Center mesh on origin
        v_pos = v_pos * scale                  # Rescale to unit size
        return Mesh(v_pos, base=mesh)

######################################################################################
# Center & scale mesh for rendering
######################################################################################
def center_by_reference(base_mesh, ref_aabb, scale):
    """Center on the reference AABB midpoint and scale relative to its longest side."""
    center = (ref_aabb[0] + ref_aabb[1]).cuda() * 0.5
    scale = scale / torch.max(ref_aabb[1] - ref_aabb[0]).item()
    v_pos = (base_mesh.v_pos - center[None, ...]) * scale
    return Mesh(v_pos, base=base_mesh)

def center_by_reference_noscale(base_mesh, ref_aabb, scale=None):
    """Center on the reference AABB midpoint without rescaling (`scale` is ignored)."""
    center = (ref_aabb[0] + ref_aabb[1]) * 0.5
    v_pos = (base_mesh.v_pos - center[None, ...])
    return Mesh(v_pos, base=base_mesh)

def center_with_global_aabb(base_mesh, ref_aabb, scale):
    """Center/scale using a global AABB given as a stackable tensor (mean over dim 0)."""
    # center = (base_mesh.v_pos.min(dim=0).values + base_mesh.v_pos.max(dim=0).values) * 0.5 ### used for the experiments... wrong
    center = ref_aabb.mean(dim=0).cuda()
    scale = scale / torch.max(ref_aabb[1] - ref_aabb[0]).item() * 2.0
    v_pos = (base_mesh.v_pos - center[None, ...]) * scale
    return Mesh(v_pos, base=base_mesh)

def center_with_global_aabb_perdim(base_mesh, ref_aabb, scale):
    """Center on the mesh's own AABB midpoint; scale each axis independently by the global AABB extent."""
    center = (base_mesh.v_pos.min(dim=0).values + base_mesh.v_pos.max(dim=0).values) * 0.5
    scale = scale / (ref_aabb[1] - ref_aabb[0])
    v_pos = (base_mesh.v_pos - center[None, ...]) * scale.view(1, 3)
    return Mesh(v_pos, base=base_mesh)

def scale_with_global_aabb(base_mesh, ref_aabb, scale):
    """Scale (no centering) relative to the global AABB's longest side."""
    scale = scale / torch.max(ref_aabb[1] - ref_aabb[0]).item() * 0.5
    v_pos = base_mesh.v_pos * scale
    return Mesh(v_pos, base=base_mesh)

def scale_with_global_aabb_perdim(base_mesh, ref_aabb, scale):
    """Scale (no centering) each axis independently by the global AABB extent."""
    scale = scale / (ref_aabb[1] - ref_aabb[0])
    v_pos = base_mesh.v_pos * scale.view(1, 3)
    return Mesh(v_pos, base=base_mesh)

######################################################################################
# Simple smooth vertex normal computation
######################################################################################
def auto_normals(imesh):
    """Compute smooth per-vertex normals by splatting (unnormalized, hence
    area-weighted) face normals onto the position index topology."""
    i0 = imesh.t_pos_idx[:, 0]
    i1 = imesh.t_pos_idx[:, 1]
    i2 = imesh.t_pos_idx[:, 2]

    v0 = imesh.v_pos[i0, :]
    v1 = imesh.v_pos[i1, :]
    v2 = imesh.v_pos[i2, :]

    face_normals = torch.cross(v1 - v0, v2 - v0)

    # Splat face normals to vertices
    v_nrm = torch.zeros_like(imesh.v_pos)
    v_nrm.scatter_add_(0, i0[:, None].repeat(1,3), face_normals)
    v_nrm.scatter_add_(0, i1[:, None].repeat(1,3), face_normals)
    v_nrm.scatter_add_(0, i2[:, None].repeat(1,3), face_normals)

    # Normalize, replace zero (degenerated) normals with some default value
    v_nrm = torch.where(util.dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device='cuda'))
    v_nrm = util.safe_normalize(v_nrm)

    if torch.is_anomaly_enabled():
        assert torch.all(torch.isfinite(v_nrm))

    return Mesh(v_nrm=v_nrm, t_nrm_idx=imesh.t_pos_idx, base=imesh)

######################################################################################
# Compute tangent space from texture map coordinates
# Follows http://www.mikktspace.com/ conventions
######################################################################################
def compute_tangents(imesh, v_tng=None):
    """Compute per-vertex tangents from UVs; or, when `v_tng` is given,
    just orthonormalize it against the vertex normals."""
    if v_tng is not None:
        v_tng = util.safe_normalize(v_tng)
        # Gram-Schmidt: make tangent perpendicular to the normal
        v_tng = util.safe_normalize(v_tng - util.dot(v_tng, imesh.v_nrm) * imesh.v_nrm)
        return Mesh(v_tng=v_tng, t_tng_idx=imesh.t_nrm_idx, base=imesh)

    vn_idx = [None] * 3
    pos = [None] * 3
    tex = [None] * 3
    for i in range(0,3):
        pos[i] = imesh.v_pos[imesh.t_pos_idx[:, i]]
        tex[i] = imesh.v_tex[imesh.t_tex_idx[:, i]]
        vn_idx[i] = imesh.t_nrm_idx[:, i]

    tangents = torch.zeros_like(imesh.v_nrm)
    tansum = torch.zeros_like(imesh.v_nrm)

    # Compute tangent space for each triangle
    uve1 = tex[1] - tex[0]
    uve2 = tex[2] - tex[0]
    pe1 = pos[1] - pos[0]
    pe2 = pos[2] - pos[0]

    nom = (pe1 * uve2[..., 1:2] - pe2 * uve1[..., 1:2])
    denom = (uve1[..., 0:1] * uve2[..., 1:2] - uve1[..., 1:2] * uve2[..., 0:1])

    # Avoid division by zero for degenerated texture coordinates
    tang = nom / torch.where(denom > 0.0, torch.clamp(denom, min=1e-6), torch.clamp(denom, max=-1e-6))

    # Update all 3 vertices
    for i in range(0,3):
        idx = vn_idx[i][:, None].repeat(1,3)
        tangents.scatter_add_(0, idx, tang)                # tangents[n_i] = tangents[n_i] + tang
        tansum.scatter_add_(0, idx, torch.ones_like(tang)) # tansum[n_i] = tansum[n_i] + 1
    tangents = tangents / tansum

    # Normalize and make sure tangent is perpendicular to normal
    tangents = util.safe_normalize(tangents)
    tangents = util.safe_normalize(tangents - util.dot(tangents, imesh.v_nrm) * imesh.v_nrm)

    if torch.is_anomaly_enabled():
        assert torch.all(torch.isfinite(tangents))

    return Mesh(v_tng=tangents, t_tng_idx=imesh.t_nrm_idx, base=imesh)

================================================ FILE: render/mlptexture.py ================================================
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES.
All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import torch
import tinycudann as tcnn
import numpy as np

#######################################################################################################################################################
# Small MLP using PyTorch primitives, internal helper class
#######################################################################################################################################################

class _MLP(torch.nn.Module):
    """Plain fully-connected ReLU network used as the neural-texture decoder.

    cfg keys used: 'n_input_dims', 'n_output_dims', 'n_neurons', 'n_hidden_layers'.
    loss_scale: when != 1.0, a full backward hook multiplies the grad_input by
    loss_scale on the way out of the network.
    """
    def __init__(self, cfg, loss_scale=1.0):
        super(_MLP, self).__init__()
        self.loss_scale = loss_scale
        # Bias-free Linear + ReLU hidden layers; final Linear has no activation
        net = (torch.nn.Linear(cfg['n_input_dims'], cfg['n_neurons'], bias=False), torch.nn.ReLU())
        for i in range(cfg['n_hidden_layers']-1):
            net = net + (torch.nn.Linear(cfg['n_neurons'], cfg['n_neurons'], bias=False), torch.nn.ReLU())
        net = net + (torch.nn.Linear(cfg['n_neurons'], cfg['n_output_dims'], bias=False),)
        self.net = torch.nn.Sequential(*net).cuda()
        self.net.apply(self._init_weights)

        if self.loss_scale != 1.0:
            self.net.register_full_backward_hook(lambda module, grad_i, grad_o: (grad_i[0] * self.loss_scale, ))

    def forward(self, x):
        # Cast to fp32: inputs may arrive in half precision (e.g. from a tcnn encoder)
        return self.net(x.to(torch.float32))

    @staticmethod
    def _init_weights(m):
        # Kaiming init for linear layers; layers above are bias-free, so the
        # bias branch is normally a no-op (hasattr(None, 'data') is False)
        if type(m) == torch.nn.Linear:
            torch.nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
            if hasattr(m.bias, 'data'):
                m.bias.data.fill_(0.0)

#######################################################################################################################################################
# Outward visible MLP class
####################################################################################################################################################### class MLPTexture3D(torch.nn.Module): def __init__(self, AABB, channels = 3, internal_dims = 32, hidden = 2, min_max = None, use_float16=False): super(MLPTexture3D, self).__init__() self.channels = channels self.internal_dims = internal_dims self.AABB = AABB self.min_max = min_max self.use_float16 = use_float16 # Setup positional encoding, see https://github.com/NVlabs/tiny-cuda-nn for details desired_resolution = 4096 base_grid_resolution = 16 num_levels = 16 per_level_scale = np.exp(np.log(desired_resolution / base_grid_resolution) / (num_levels-1)) enc_cfg = { "otype": "HashGrid", "n_levels": num_levels, "n_features_per_level": 2, "log2_hashmap_size": 19, "base_resolution": base_grid_resolution, "per_level_scale" : per_level_scale } gradient_scaling = 128.0 self.encoder = tcnn.Encoding(3, enc_cfg) self.encoder.register_full_backward_hook(lambda module, grad_i, grad_o: (grad_i[0] / gradient_scaling, )) # Setup MLP mlp_cfg = { "n_input_dims" : self.encoder.n_output_dims, "n_output_dims" : self.channels, "n_hidden_layers" : hidden, "n_neurons" : self.internal_dims } self.net = _MLP(mlp_cfg, gradient_scaling) print("Encoder output: %d dims" % (self.encoder.n_output_dims)) # Sample texture at a given location def sample(self, texc): _texc = (texc.view(-1, 3) - self.AABB[0][None, ...]) / (self.AABB[1][None, ...] 
- self.AABB[0][None, ...]) _texc = torch.clamp(_texc, min=0, max=1) p_enc = self.encoder(_texc.contiguous()) with torch.autocast('cuda', dtype=torch.float16, enabled=self.use_float16): out = self.net.forward(p_enc) # Sigmoid limit and scale to the allowed range out = torch.sigmoid(out) * (self.min_max[1][None, :] - self.min_max[0][None, :]) + self.min_max[0][None, :] return out.view(*texc.shape[:-1], self.channels) # Remap to [n, h, w, c] # In-place clamp with no derivative to make sure values are in valid range after training def clamp_(self): pass def cleanup(self): tcnn.free_temporary_memory() ================================================ FILE: render/obj.py ================================================ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual # property and proprietary rights in and to this material, related # documentation and any modifications thereto. Any use, reproduction, # disclosure or distribution of this material and related documentation # without an express license agreement from NVIDIA CORPORATION or # its affiliates is strictly prohibited. import os import torch from . import texture from . import mesh from . 
import material ###################################################################################### # Utility functions ###################################################################################### def _find_mat(materials, name): for mat in materials: if mat['name'] == name: return mat return materials[0] # Materials 0 is the default ###################################################################################### # Create mesh object from objfile ###################################################################################### def load_obj(filename, clear_ks=True, mtl_override=None, mtl_default=None, mtl_type_override=None): obj_path = os.path.dirname(filename) # Read entire file with open(filename, 'r') as f: lines = f.readlines() # Load materials if mtl_default is None: all_materials = [ { 'name' : '_default_mat', 'bsdf' : 'pbr', 'kd' : texture.Texture2D(torch.tensor([0.5, 0.5, 0.5], dtype=torch.float32, device='cuda')), 'ks' : texture.Texture2D(torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32, device='cuda')) } ] if mtl_override is None: for line in lines: if len(line.split()) == 0: continue if line.split()[0] == 'mtllib': all_materials += material.load_mtl(os.path.join(obj_path, line.split()[1]), clear_ks) # Read in entire material library else: all_materials += material.load_mtl(mtl_override) else: print("Load use-defined default mtl") all_materials = [mtl_default] if mtl_type_override is not None: for k in range(len(all_materials)): all_materials[k]['bsdf'] = mtl_type_override # load vertices vertices, texcoords, normals = [], [], [] for line in lines: if len(line.split()) == 0: continue prefix = line.split()[0].lower() if prefix == 'v': vertices.append([float(v) for v in line.split()[1:]]) elif prefix == 'vt': val = [float(v) for v in line.split()[1:]] texcoords.append([val[0], 1.0 - val[1]]) elif prefix == 'vn': normals.append([float(v) for v in line.split()[1:]]) # load faces activeMatIdx = None used_materials = [] faces, tfaces, 
nfaces, mfaces = [], [], [], [] for line in lines: if len(line.split()) == 0: continue prefix = line.split()[0].lower() if prefix == 'usemtl': # Track used materials mat = _find_mat(all_materials, line.split()[1]) if not mat in used_materials: used_materials.append(mat) activeMatIdx = used_materials.index(mat) elif prefix == 'f': # Parse face vs = line.split()[1:] nv = len(vs) vv = vs[0].split('/') v0 = int(vv[0]) - 1 t0 = int(vv[1]) - 1 if vv[1] != "" else -1 n0 = int(vv[2]) - 1 if vv[2] != "" else -1 for i in range(nv - 2): # Triangulate polygons vv = vs[i + 1].split('/') v1 = int(vv[0]) - 1 t1 = int(vv[1]) - 1 if vv[1] != "" else -1 n1 = int(vv[2]) - 1 if vv[2] != "" else -1 vv = vs[i + 2].split('/') v2 = int(vv[0]) - 1 t2 = int(vv[1]) - 1 if vv[1] != "" else -1 n2 = int(vv[2]) - 1 if vv[2] != "" else -1 mfaces.append(activeMatIdx) faces.append([v0, v1, v2]) tfaces.append([t0, t1, t2]) nfaces.append([n0, n1, n2]) assert len(tfaces) == len(faces) and len(nfaces) == len (faces) # Create an "uber" material by combining all textures into a larger texture if len(used_materials) > 1: uber_material, texcoords, tfaces = material.merge_materials(used_materials, texcoords, tfaces, mfaces) else: uber_material = used_materials[0] # elif len(used_materials) == 1: # uber_material = used_materials[0] # else: # uber_material = [all_materials[0]] vertices = torch.tensor(vertices, dtype=torch.float32, device='cuda') texcoords = torch.tensor(texcoords, dtype=torch.float32, device='cuda') if len(texcoords) > 0 else None normals = torch.tensor(normals, dtype=torch.float32, device='cuda') if len(normals) > 0 else None faces = torch.tensor(faces, dtype=torch.int64, device='cuda') tfaces = torch.tensor(tfaces, dtype=torch.int64, device='cuda') if texcoords is not None else None nfaces = torch.tensor(nfaces, dtype=torch.int64, device='cuda') if normals is not None else None vertices = vertices[:, :3] return mesh.Mesh(vertices, faces, normals, nfaces, texcoords, tfaces, 
material=uber_material) ###################################################################################### # Save mesh object to objfile ###################################################################################### def write_obj(folder, mesh, save_material=True): obj_file = os.path.join(folder, 'mesh.obj') print("Writing mesh: ", obj_file) with open(obj_file, "w") as f: f.write("mtllib mesh.mtl\n") f.write("g default\n") v_pos = mesh.v_pos.detach().cpu().numpy() if mesh.v_pos is not None else None v_nrm = mesh.v_nrm.detach().cpu().numpy() if mesh.v_nrm is not None else None v_tex = mesh.v_tex.detach().cpu().numpy() if mesh.v_tex is not None else None t_pos_idx = mesh.t_pos_idx.detach().cpu().numpy() if mesh.t_pos_idx is not None else None t_nrm_idx = mesh.t_nrm_idx.detach().cpu().numpy() if mesh.t_nrm_idx is not None else None t_tex_idx = mesh.t_tex_idx.detach().cpu().numpy() if mesh.t_tex_idx is not None else None print(" writing %d vertices" % len(v_pos)) for v in v_pos: f.write('v {} {} {} \n'.format(v[0], v[1], v[2])) if v_tex is not None: print(" writing %d texcoords" % len(v_tex)) assert(len(t_pos_idx) == len(t_tex_idx)) for v in v_tex: f.write('vt {} {} \n'.format(v[0], 1.0 - v[1])) if v_nrm is not None: print(" writing %d normals" % len(v_nrm)) assert(len(t_pos_idx) == len(t_nrm_idx)) for v in v_nrm: f.write('vn {} {} {}\n'.format(v[0], v[1], v[2])) # faces f.write("s 1 \n") f.write("g pMesh1\n") f.write("usemtl defaultMat\n") # Write faces print(" writing %d faces" % len(t_pos_idx)) for i in range(len(t_pos_idx)): f.write("f ") for j in range(3): f.write(' %s/%s/%s' % (str(t_pos_idx[i][j]+1), '' if v_tex is None else str(t_tex_idx[i][j]+1), '' if v_nrm is None else str(t_nrm_idx[i][j]+1))) f.write("\n") if save_material: mtl_file = os.path.join(folder, 'mesh.mtl') print("Writing material: ", mtl_file) material.save_mtl(mtl_file, mesh.material) print("Done exporting mesh") ================================================ FILE: 
render/optixutils/__init__.py ================================================ # Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. from .ops import OptiXContext, optix_build_bvh, optix_env_shade, bilateral_denoiser __all__ = ["OptiXContext", "optix_build_bvh", "optix_env_shade", 'bilateral_denoiser'] ================================================ FILE: render/optixutils/c_src/accessor.h ================================================ // Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. // // NVIDIA CORPORATION and its licensors retain all intellectual property // and proprietary rights in and to this software, related documentation // and any modifications thereto. Any use, reproduction, disclosure or // distribution of this software and related documentation without an express // license agreement from NVIDIA CORPORATION is strictly prohibited. // Stripped down version from pytorch. 
Made to work with optix kernels where it's // hard to include dependencies // https://github.com/pytorch/pytorch/blob/dc169d53aa266560750ea25ee0cf31c7e614550d/aten/src/ATen/core/TensorAccessor.h ///////////////////////////////////////////////////////////////////////////// // From PyTorch: // Copyright (c) 2016- Facebook, Inc (Adam Paszke) // Copyright (c) 2014- Facebook, Inc (Soumith Chintala) // Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) // Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) // Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) // Copyright (c) 2011-2013 NYU (Clement Farabet) // Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) // Copyright (c) 2006 Idiap Research Institute (Samy Bengio) // Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) // From Caffe2: // Copyright (c) 2016-present, Facebook Inc. All rights reserved. // All contributions by Facebook: // Copyright (c) 2016 Facebook Inc. // All contributions by Google: // Copyright (c) 2015 Google Inc. // All rights reserved. // All contributions by Yangqing Jia: // Copyright (c) 2015 Yangqing Jia // All rights reserved. // All contributions by Kakao Brain: // Copyright 2019-2020 Kakao Brain // All contributions by Cruise LLC: // Copyright (c) 2022 Cruise LLC. // All rights reserved. // All contributions from Caffe: // Copyright(c) 2013, 2014, 2015, the respective contributors // All rights reserved. // All other contributions: // Copyright(c) 2015, 2016 the respective contributors // All rights reserved. // Caffe2 uses a copyright model similar to Caffe: each contributor holds // copyright over their contributions to Caffe2. The project versioning records // all such contribution and copyright details. 
If a contributor wants to further // mark their specific copyright on a particular contribution, they should // indicate their copyright solely in the commit message of the change when it is // committed. // All rights reserved. // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are met: // 1. Redistributions of source code must retain the above copyright // notice, this list of conditions and the following disclaimer. // 2. Redistributions in binary form must reproduce the above copyright // notice, this list of conditions and the following disclaimer in the // documentation and/or other materials provided with the distribution. // 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America // and IDIAP Research Institute nor the names of its contributors may be // used to endorse or promote products derived from this software without // specific prior written permission. // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE // POSSIBILITY OF SUCH DAMAGE. 
///////////////////////////////////////////////////////////////////////////// #pragma once #if defined(__OPTIX__) typedef int int32_t; typedef long long int64_t; #else #include #endif #ifdef __CUDACC__ #ifdef __CUDA_ARCH__ #define C10_DEVICE __device__ #define C10_HOST_DEVICE __device__ #else #define C10_DEVICE __device__ #define C10_HOST __host__ #define C10_HOST_DEVICE __host__ __device__ #endif #else #include #define C10_HOST_DEVICE #define C10_HOST #endif // The PtrTraits argument to the TensorAccessor/GenericPackedTensorAccessor // is used to enable the __restrict__ keyword/modifier for the data // passed to cuda. template struct DefaultPtrTraits { typedef T* PtrType; }; #if defined(__CUDACC__) || defined(__HIPCC__) template struct RestrictPtrTraits { typedef T* __restrict__ PtrType; }; #endif // TensorAccessorBase and TensorAccessor are used for both CPU and CUDA tensors. // For CUDA tensors it is used in device code (only). This means that we restrict ourselves // to functions and types available there (e.g. IntArrayRef isn't). // The PtrTraits argument is only relevant to cuda to support `__restrict__` pointers. template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> class TensorAccessorBase { public: typedef typename PtrTraits::PtrType PtrType; C10_HOST_DEVICE TensorAccessorBase( PtrType data_, const index_t* sizes_, const index_t* strides_) : data_(data_), sizes_(sizes_), strides_(strides_) {} C10_HOST_DEVICE index_t stride(index_t i) const { return strides_[i]; } C10_HOST_DEVICE index_t size(index_t i) const { return sizes_[i]; } C10_HOST_DEVICE PtrType data() { return data_; } C10_HOST_DEVICE const PtrType data() const { return data_; } protected: PtrType data_; const index_t* sizes_; const index_t* strides_; }; // The `TensorAccessor` is typically instantiated for CPU `Tensor`s using // `Tensor.accessor()`. 
// For CUDA `Tensor`s, `GenericPackedTensorAccessor` is used on the host and only // indexing on the device uses `TensorAccessor`s. template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> class TensorAccessor : public TensorAccessorBase { public: typedef typename PtrTraits::PtrType PtrType; C10_HOST_DEVICE TensorAccessor( PtrType data_, const index_t* sizes_, const index_t* strides_) : TensorAccessorBase(data_,sizes_,strides_) {} C10_HOST_DEVICE TensorAccessor operator[](index_t i) { return TensorAccessor(this->data_ + this->strides_[0]*i,this->sizes_+1,this->strides_+1); } C10_HOST_DEVICE const TensorAccessor operator[](index_t i) const { return TensorAccessor(this->data_ + this->strides_[0]*i,this->sizes_+1,this->strides_+1); } }; template class PtrTraits, typename index_t> class TensorAccessor : public TensorAccessorBase { public: typedef typename PtrTraits::PtrType PtrType; C10_HOST_DEVICE TensorAccessor( PtrType data_, const index_t* sizes_, const index_t* strides_) : TensorAccessorBase(data_,sizes_,strides_) {} C10_HOST_DEVICE T & operator[](index_t i) { // NOLINTNEXTLINE(clang-analyzer-core.NullDereference) return this->data_[this->strides_[0]*i]; } C10_HOST_DEVICE const T & operator[](index_t i) const { return this->data_[this->strides_[0]*i]; } }; // GenericPackedTensorAccessorBase and GenericPackedTensorAccessor are used on for CUDA `Tensor`s on the host // and as // In contrast to `TensorAccessor`s, they copy the strides and sizes on instantiation (on the host) // in order to transfer them on the device when calling kernels. // On the device, indexing of multidimensional tensors gives to `TensorAccessor`s. // Use RestrictPtrTraits as PtrTraits if you want the tensor's data pointer to be marked as __restrict__. // Instantiation from data, sizes, strides is only needed on the host and std::copy isn't available // on the device, so those functions are host only. 
template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> class GenericPackedTensorAccessorBase { public: typedef typename PtrTraits::PtrType PtrType; #if !defined(__CUDACC__) C10_HOST GenericPackedTensorAccessorBase() {} C10_HOST GenericPackedTensorAccessorBase( PtrType data_, const index_t* sizes_, const index_t* strides_) : data_(data_) { std::copy(sizes_, sizes_ + N, std::begin(this->sizes_)); std::copy(strides_, strides_ + N, std::begin(this->strides_)); } // if index_t is not int64_t, we want to have an int64_t constructor template ::value>::type> C10_HOST GenericPackedTensorAccessorBase( PtrType data_, const source_index_t* sizes_, const source_index_t* strides_) : data_(data_) { for (const auto i : c10::irange(N)) { this->sizes_[i] = sizes_[i]; this->strides_[i] = strides_[i]; } } #endif C10_HOST_DEVICE index_t stride(index_t i) const { return strides_[i]; } C10_HOST_DEVICE index_t size(index_t i) const { return sizes_[i]; } C10_HOST_DEVICE PtrType data() { return data_; } C10_HOST_DEVICE const PtrType data() const { return data_; } protected: PtrType data_; index_t sizes_[N]; index_t strides_[N]; }; template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> class GenericPackedTensorAccessor : public GenericPackedTensorAccessorBase { public: typedef typename PtrTraits::PtrType PtrType; #if !defined(__CUDACC__) C10_HOST GenericPackedTensorAccessor() : GenericPackedTensorAccessorBase() {} C10_HOST GenericPackedTensorAccessor( PtrType data_, const index_t* sizes_, const index_t* strides_) : GenericPackedTensorAccessorBase(data_, sizes_, strides_) {} // if index_t is not int64_t, we want to have an int64_t constructor template ::value>::type> C10_HOST GenericPackedTensorAccessor( PtrType data_, const source_index_t* sizes_, const source_index_t* strides_) : GenericPackedTensorAccessorBase(data_, sizes_, strides_) {} #else C10_DEVICE TensorAccessor operator[](index_t i) { index_t* new_sizes = this->sizes_ + 1; index_t* new_strides = 
this->strides_ + 1; return TensorAccessor(this->data_ + this->strides_[0]*i, new_sizes, new_strides); } C10_DEVICE const TensorAccessor operator[](index_t i) const { const index_t* new_sizes = this->sizes_ + 1; const index_t* new_strides = this->strides_ + 1; return TensorAccessor(this->data_ + this->strides_[0]*i, new_sizes, new_strides); } #endif }; template class PtrTraits, typename index_t> class GenericPackedTensorAccessor : public GenericPackedTensorAccessorBase { public: typedef typename PtrTraits::PtrType PtrType; #if !defined(__CUDACC__) C10_HOST GenericPackedTensorAccessor() : GenericPackedTensorAccessorBase() {} C10_HOST GenericPackedTensorAccessor( PtrType data_, const index_t* sizes_, const index_t* strides_) : GenericPackedTensorAccessorBase(data_, sizes_, strides_) {} // if index_t is not int64_t, we want to have an int64_t constructor template ::value>::type> C10_HOST GenericPackedTensorAccessor( PtrType data_, const source_index_t* sizes_, const source_index_t* strides_) : GenericPackedTensorAccessorBase(data_, sizes_, strides_) {} #else C10_DEVICE T & operator[](index_t i) { return this->data_[this->strides_[0] * i]; } C10_DEVICE const T& operator[](index_t i) const { return this->data_[this->strides_[0]*i]; } #endif }; template class PtrTraits = DefaultPtrTraits> using PackedTensorAccessor32 = GenericPackedTensorAccessor; template class PtrTraits = DefaultPtrTraits> using PackedTensorAccessor64 = GenericPackedTensorAccessor; ================================================ FILE: render/optixutils/c_src/bsdf.h ================================================ // Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. // // NVIDIA CORPORATION and its licensors retain all intellectual property // and proprietary rights in and to this software, related documentation // and any modifications thereto. 
Any use, reproduction, disclosure or // distribution of this software and related documentation without an express // license agreement from NVIDIA CORPORATION is strictly prohibited. #pragma once #ifdef __CUDACC__ #define SPECULAR_EPSILON 1e-4f #ifndef M_PI #define M_PI 3.14159265358979323846f #endif //------------------------------------------------------------------------ // Lambert functions __device__ inline float fwdLambert(const float3 nrm, const float3 wi) { return max(dot(nrm, wi) / M_PI, 0.0f); } __device__ inline void bwdLambert(const float3 nrm, const float3 wi, float3& d_nrm, float3& d_wi, const float d_out) { if (dot(nrm, wi) > 0.0f) bwd_dot(nrm, wi, d_nrm, d_wi, d_out / M_PI); } //------------------------------------------------------------------------ // Fresnel Schlick __device__ inline float fwdFresnelSchlick(const float f0, const float f90, const float cosTheta) { float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); float scale = powf(1.0f - _cosTheta, 5.0f); return f0 * (1.0f - scale) + f90 * scale; } __device__ inline void bwdFresnelSchlick(const float f0, const float f90, const float cosTheta, float& d_f0, float& d_f90, float& d_cosTheta, const float d_out) { float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); float scale = pow(max(1.0f - _cosTheta, 0.0f), 5.0f); d_f0 += d_out * (1.0 - scale); d_f90 += d_out * scale; if (cosTheta >= SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON) { d_cosTheta += d_out * (f90 - f0) * -5.0f * powf(1.0f - cosTheta, 4.0f); } } __device__ inline float3 fwdFresnelSchlick(const float3 f0, const float3 f90, const float cosTheta) { float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); float scale = powf(1.0f - _cosTheta, 5.0f); return f0 * (1.0f - scale) + f90 * scale; } __device__ inline void bwdFresnelSchlick(const float3 f0, const float3 f90, const float cosTheta, float3& d_f0, float3& d_f90, float& d_cosTheta, const float3 d_out) { 
float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); float scale = pow(max(1.0f - _cosTheta, 0.0f), 5.0f); d_f0 += d_out * (1.0 - scale); d_f90 += d_out * scale; if (cosTheta >= SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON) { d_cosTheta += sum(d_out * (f90 - f0) * -5.0f * powf(1.0f - cosTheta, 4.0f)); } } //------------------------------------------------------------------------ // Ndf GGX __device__ inline float fwdNdfGGX(const float alphaSqr, const float cosTheta) { float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); float d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1.0f; return alphaSqr / (d * d * M_PI); } __device__ inline void bwdNdfGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out) { // Torch only back propagates if clamp doesn't trigger float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); float cosThetaSqr = _cosTheta * _cosTheta; d_alphaSqr += d_out * (1.0f - (alphaSqr + 1.0f) * cosThetaSqr) / (M_PI * powf((alphaSqr - 1.0) * cosThetaSqr + 1.0f, 3.0f)); if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON) { d_cosTheta += d_out * -(4.0f * (alphaSqr - 1.0f) * alphaSqr * cosTheta) / (M_PI * powf((alphaSqr - 1.0) * cosThetaSqr + 1.0f, 3.0f)); } } //------------------------------------------------------------------------ // Lambda GGX __device__ inline float fwdLambdaGGX(const float alphaSqr, const float cosTheta) { float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); float cosThetaSqr = _cosTheta * _cosTheta; float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr; float res = 0.5f * (sqrtf(1.0f + alphaSqr * tanThetaSqr) - 1.0f); return res; } __device__ inline void bwdLambdaGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out) { float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); float cosThetaSqr = 
_cosTheta * _cosTheta; float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr; float res = 0.5f * (sqrtf(1.0f + alphaSqr * tanThetaSqr) - 1.0f); d_alphaSqr += d_out * (0.25 * tanThetaSqr) / sqrtf(alphaSqr * tanThetaSqr + 1.0f); if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON) d_cosTheta += d_out * -(0.5 * alphaSqr) / (powf(_cosTheta, 3.0f) * sqrtf(alphaSqr / cosThetaSqr - alphaSqr + 1.0f)); } //------------------------------------------------------------------------ // Masking GGX __device__ inline float fwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO) { float lambdaI = fwdLambdaGGX(alphaSqr, cosThetaI); float lambdaO = fwdLambdaGGX(alphaSqr, cosThetaO); return 1.0f / (1.0f + lambdaI + lambdaO); } __device__ inline void bwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO, float& d_alphaSqr, float& d_cosThetaI, float& d_cosThetaO, const float d_out) { // FWD eval float lambdaI = fwdLambdaGGX(alphaSqr, cosThetaI); float lambdaO = fwdLambdaGGX(alphaSqr, cosThetaO); // BWD eval float d_lambdaIO = -d_out / powf(1.0f + lambdaI + lambdaO, 2.0f); bwdLambdaGGX(alphaSqr, cosThetaI, d_alphaSqr, d_cosThetaI, d_lambdaIO); bwdLambdaGGX(alphaSqr, cosThetaO, d_alphaSqr, d_cosThetaO, d_lambdaIO); } //------------------------------------------------------------------------ // GGX specular __device__ float3 fwdPbrSpecular(const float3 col, const float3 nrm, const float3 wo, const float3 wi, const float alpha, const float min_roughness) { float _alpha = clamp(alpha, min_roughness * min_roughness, 1.0f); float alphaSqr = _alpha * _alpha; float3 h = safe_normalize(wo + wi); float woDotN = dot(wo, nrm); float wiDotN = dot(wi, nrm); float woDotH = dot(wo, h); float nDotH = dot(nrm, h); float D = fwdNdfGGX(alphaSqr, nDotH); float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN); float3 F = fwdFresnelSchlick(col, make_float3(1.0f), woDotH); float3 w = F * D * G * 0.25 / 
woDotN; bool frontfacing = (woDotN > SPECULAR_EPSILON) & (wiDotN > SPECULAR_EPSILON); return frontfacing ? w : make_float3(0.0f); } __device__ void bwdPbrSpecular( const float3 col, const float3 nrm, const float3 wo, const float3 wi, const float alpha, const float min_roughness, float3& d_col, float3& d_nrm, float3& d_wo, float3& d_wi, float& d_alpha, const float3 d_out) { /////////////////////////////////////////////////////////////////////// // FWD eval float _alpha = clamp(alpha, min_roughness * min_roughness, 1.0f); float alphaSqr = _alpha * _alpha; float3 h = safe_normalize(wo + wi); float woDotN = dot(wo, nrm); float wiDotN = dot(wi, nrm); float woDotH = dot(wo, h); float nDotH = dot(nrm, h); float D = fwdNdfGGX(alphaSqr, nDotH); float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN); float3 F = fwdFresnelSchlick(col, make_float3(1.0f), woDotH); float3 w = F * D * G * 0.25 / woDotN; bool frontfacing = (woDotN > SPECULAR_EPSILON) & (wiDotN > SPECULAR_EPSILON); if (frontfacing) { /////////////////////////////////////////////////////////////////////// // BWD eval float3 d_F = d_out * D * G * 0.25f / woDotN; float d_D = sum(d_out * F * G * 0.25f / woDotN); float d_G = sum(d_out * F * D * 0.25f / woDotN); float d_woDotN = -sum(d_out * F * D * G * 0.25f / (woDotN * woDotN)); float3 d_f90 = make_float3(0); float d_woDotH = 0, d_wiDotN = 0, d_nDotH = 0, d_alphaSqr = 0; bwdFresnelSchlick(col, make_float3(1.0f), woDotH, d_col, d_f90, d_woDotH, d_F); bwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN, d_alphaSqr, d_woDotN, d_wiDotN, d_G); bwdNdfGGX(alphaSqr, nDotH, d_alphaSqr, d_nDotH, d_D); float3 d_h = make_float3(0); bwd_dot(nrm, h, d_nrm, d_h, d_nDotH); bwd_dot(wo, h, d_wo, d_h, d_woDotH); bwd_dot(wi, nrm, d_wi, d_nrm, d_wiDotN); bwd_dot(wo, nrm, d_wo, d_nrm, d_woDotN); float3 d_h_unnorm = make_float3(0); bwd_safe_normalize(wo + wi, d_h_unnorm, d_h); d_wo += d_h_unnorm; d_wi += d_h_unnorm; if (alpha > min_roughness * min_roughness) d_alpha += d_alphaSqr 
* 2 * alpha; } } //------------------------------------------------------------------------ // Full PBR BSDF __device__ void fwdPbrBSDF(const float3 kd, const float3 arm, const float3 pos, const float3 nrm, const float3 view_pos, const float3 wi, const float min_roughness, float3 &diffuse, float3 &specular) { float3 wo = safe_normalize(view_pos - pos); float alpha = arm.y * arm.y; float3 spec_col = (make_float3(0.04f) * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x); // Removed because of demodulated albedo. // float3 diff_col = kd * (1.0f - arm.z); float diff = 0.0f; diff = fwdLambert(nrm, wi); diffuse = make_float3(diff);//diff_col * diff; specular = fwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness); } __device__ void bwdPbrBSDF( const float3 kd, const float3 arm, const float3 pos, const float3 nrm, const float3 view_pos, const float3 wi, const float min_roughness, float3& d_kd, float3& d_arm, float3& d_pos, float3& d_nrm, float3& d_view_pos, float3& d_wi, const float3 d_diffuse, float3 d_specular) { //////////////////////////////////////////////////////////////////////// // FWD float3 _wo = view_pos - pos; float3 wo = safe_normalize(_wo); float alpha = arm.y * arm.y; float3 spec_col = (make_float3(0.04f) * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x); //////////////////////////////////////////////////////////////////////// // BWD float d_alpha = 0; d_wi = make_float3(0); float3 d_spec_col = make_float3(0), d_wo = make_float3(0); bwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness, d_spec_col, d_nrm, d_wo, d_wi, d_alpha, d_specular); // float d_diff = sum(diff_col * d_diffuse); float d_diff = sum(d_diffuse); bwdLambert(nrm, wi, d_nrm, d_wi, d_diff); // Backprop: spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x) d_kd -= d_spec_col * (arm.x - 1.0f) * arm.z; d_arm.x += sum(d_spec_col * (arm.z * (make_float3(0.04f) - kd) - 0.04f)); d_arm.z -= sum(d_spec_col * (kd - make_float3(0.04f)) * (arm.x - 1.0f)); // Backprop: alpha = arm.y * arm.y 
d_arm.y += d_alpha * 2 * arm.y; // Backprop: float3 wo = safe_normalize(view_pos - pos); float3 d__wo = make_float3(0); bwd_safe_normalize(_wo, d__wo, d_wo); d_view_pos += d__wo; d_pos -= d__wo; } #endif ================================================ FILE: render/optixutils/c_src/common.h ================================================ // Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. // // NVIDIA CORPORATION and its licensors retain all intellectual property // and proprietary rights in and to this software, related documentation // and any modifications thereto. Any use, reproduction, disclosure or // distribution of this software and related documentation without an express // license agreement from NVIDIA CORPORATION is strictly prohibited. #pragma once // Helper functions to do broadcast guarded fetches #if defined(__CUDACC__) template static __device__ inline float3 fetch3(const T &tensor, U idx, Args... args) { return tensor.size(0) == 1 ? fetch3(tensor[0], args...) : fetch3(tensor[idx], args...); } template static __device__ inline float3 fetch3(const T &tensor) { return tensor.size(0) == 1 ? make_float3(tensor[0], tensor[0], tensor[0]) : make_float3(tensor[0], tensor[1], tensor[2]); } template static __device__ inline float2 fetch2(const T &tensor, U idx, Args... args) { return tensor.size(0) == 1 ? fetch2(tensor[0], args...) : fetch2(tensor[idx], args...); } template static __device__ inline float2 fetch2(const T &tensor) { return tensor.size(0) == 1 ? 
make_float2(tensor[0], tensor[0]) : make_float2(tensor[0], tensor[1]); } #include "math_utils.h" #include "bsdf.h" #endif //------------------------------------------------------------------------------ // CUDA error-checking macros //------------------------------------------------------------------------------ #define CUDA_CHECK( call ) \ do \ { \ cudaError_t error = call; \ if( error != cudaSuccess ) \ { \ std::stringstream ss; \ ss << "CUDA call (" << #call << " ) failed with error: '" \ << cudaGetErrorString( error ) \ << "' (" __FILE__ << ":" << __LINE__ << ")\n"; \ } \ } while( 0 ) #define OPTIX_CHECK( call ) \ do \ { \ OptixResult res = call; \ if( res != OPTIX_SUCCESS ) \ { \ std::stringstream ss; \ ss << "Optix call '" << #call << "' failed: " __FILE__ ":" \ << __LINE__ << ")\n"; \ } \ } while( 0 ) #define OPTIX_CHECK_LOG( call ) \ do \ { \ OptixResult res = call; \ const size_t sizeof_log_returned = sizeof_log; \ sizeof_log = sizeof( log ); /* reset sizeof_log for future calls */ \ if( res != OPTIX_SUCCESS ) \ { \ std::stringstream ss; \ ss << "Optix call '" << #call << "' failed: " __FILE__ ":" \ << __LINE__ << ")\nLog:\n" << log \ << ( sizeof_log_returned > sizeof( log ) ? "" : "" ) \ << "\n"; \ } \ } while( 0 ) #define NVRTC_CHECK_ERROR( func ) \ do \ { \ nvrtcResult code = func; \ if( code != NVRTC_SUCCESS ) \ throw std::runtime_error( "ERROR: " __FILE__ "(): " + std::string( nvrtcGetErrorString( code ) ) ); \ } while( 0 ) ================================================ FILE: render/optixutils/c_src/denoising.cu ================================================ // Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. // // NVIDIA CORPORATION and its licensors retain all intellectual property // and proprietary rights in and to this software, related documentation // and any modifications thereto. 
Any use, reproduction, disclosure or // distribution of this software and related documentation without an express // license agreement from NVIDIA CORPORATION is strictly prohibited. #include "common.h" #include "denoising.h" #define FLT_EPS 0.0001f __global__ void bilateral_denoiser_fwd_kernel(BilateralDenoiserParams params) { uint3 idx = make_uint3(blockIdx.x * blockDim.x + threadIdx.x, blockIdx.y * blockDim.y + threadIdx.y, blockIdx.z * blockDim.z + threadIdx.z); if (idx.z >= params.col.size(0) || idx.y >= params.col.size(1) || idx.x >= params.col.size(2)) return; // Fetch central tap float3 c_nrm = fetch3(params.nrm, idx.z, idx.y, idx.x); float2 c_zdz = fetch2(params.zdz, idx.z, idx.y, idx.x); float variance = params.sigma * params.sigma; int filter_rad = 2 * ceil(params.sigma * 2.5) + 1; float accum_w = 0.0f; float3 accum_col = make_float3(0.0f); for (int32_t fy = -filter_rad; fy <= filter_rad; ++fy) { for (int32_t fx = -filter_rad; fx <= filter_rad; ++fx) { // Compute tap coordinates, used for input activations and bilateral guides int32_t y = idx.y + fy; int32_t x = idx.x + fx; if (y < 0 || x < 0 || y >= params.col.size(1) || x >= params.col.size(2)) continue; // Fetch current tap float3 t_col = fetch3(params.col, idx.z, y, x); float3 t_nrm = fetch3(params.nrm, idx.z, y, x); float2 t_zdz = fetch2(params.zdz, idx.z, y, x); ///////////////////////////////////////////////////////// // Compute bilateral weight ///////////////////////////////////////////////////////// // Distance float dist_sqr = fx * fx + fy * fy; float dist = sqrtf(dist_sqr); float w_xy = expf(-dist_sqr / (2.0f * variance)); // Normal float w_normal = powf(min(max(dot(t_nrm, c_nrm), FLT_EPS), 1.0f), 128.0f); // Depth float w_depth = expf(-(abs(t_zdz.x - c_zdz.x) / max(c_zdz.y * dist, FLT_EPS))); float w = w_xy * w_normal * w_depth; accum_col = accum_col + t_col * w; accum_w += w; } } params.out[idx.z][idx.y][idx.x][0] = accum_col.x; params.out[idx.z][idx.y][idx.x][1] = accum_col.y; 
params.out[idx.z][idx.y][idx.x][2] = accum_col.z; params.out[idx.z][idx.y][idx.x][3] = max(accum_w, 0.0001f); } __global__ void bilateral_denoiser_bwd_kernel(BilateralDenoiserParams params) { uint3 idx = make_uint3(blockIdx.x * blockDim.x + threadIdx.x, blockIdx.y * blockDim.y + threadIdx.y, blockIdx.z * blockDim.z + threadIdx.z); if (idx.z >= params.col.size(0) || idx.y >= params.col.size(1) || idx.x >= params.col.size(2)) return; // Fetch central tap float3 c_nrm = fetch3(params.nrm, idx.z, idx.y, idx.x); float2 c_zdz = fetch2(params.zdz, idx.z, idx.y, idx.x); float variance = params.sigma * params.sigma; int filter_rad = 2 * ceil(params.sigma * 2.5) + 1; float3 accum_grad = make_float3(0.0f); for (int32_t fy = -filter_rad; fy <= filter_rad; ++fy) { for (int32_t fx = -filter_rad; fx <= filter_rad; ++fx) { // Compute tap coordinates, used for input activations and bilateral guides int32_t y = idx.y + fy; int32_t x = idx.x + fx; if (y < 0 || x < 0 || y >= params.col.size(1) || x >= params.col.size(2)) continue; // Fetch current tap float3 t_col = fetch3(params.col, idx.z, y, x); float3 t_nrm = fetch3(params.nrm, idx.z, y, x); float2 t_zdz = fetch2(params.zdz, idx.z, y, x); ///////////////////////////////////////////////////////// // Compute bilateral weight ///////////////////////////////////////////////////////// // Distance, transposing fx & fy doesn't affect distance float dist_sqr = fx * fx + fy * fy; float dist = sqrtf(dist_sqr); float w_xy = expf(-dist_sqr / (2.0f * variance)); // Normal, transpose c_ and t_ (it's symmetric so doesn't matter) float w_normal = powf(min(max(dot(t_nrm, c_nrm), FLT_EPS), 1.0f), 128.0f); // Depth, transpose c_ and t_ (matters for the denominator) float w_depth = expf(-(abs(t_zdz.x - c_zdz.x) / max(t_zdz.y * dist, FLT_EPS))); float w = w_xy * w_normal * w_depth; float3 t_col_grad = w * fetch3(params.out_grad, idx.z, y, x); accum_grad += t_col_grad; } } params.col_grad[idx.z][idx.y][idx.x][0] = accum_grad.x; 
params.col_grad[idx.z][idx.y][idx.x][1] = accum_grad.y; params.col_grad[idx.z][idx.y][idx.x][2] = accum_grad.z; } ================================================ FILE: render/optixutils/c_src/denoising.h ================================================ // Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. // // NVIDIA CORPORATION and its licensors retain all intellectual property // and proprietary rights in and to this software, related documentation // and any modifications thereto. Any use, reproduction, disclosure or // distribution of this software and related documentation without an express // license agreement from NVIDIA CORPORATION is strictly prohibited. #pragma once #include "accessor.h" struct BilateralDenoiserParams { PackedTensorAccessor32 col; PackedTensorAccessor32 col_grad; PackedTensorAccessor32 nrm; PackedTensorAccessor32 zdz; PackedTensorAccessor32 out; PackedTensorAccessor32 out_grad; float sigma; }; ================================================ FILE: render/optixutils/c_src/envsampling/kernel.cu ================================================ // Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. // // NVIDIA CORPORATION and its licensors retain all intellectual property // and proprietary rights in and to this software, related documentation // and any modifications thereto. Any use, reproduction, disclosure or // distribution of this software and related documentation without an express // license agreement from NVIDIA CORPORATION is strictly prohibited. 
#define OPTIXU_MATH_DEFINE_IN_NAMESPACE

// NOTE(review): the two angle-bracket include targets were stripped by the
// repo dump; <optix.h> and <math_constants.h> (for CUDART_PI) are the
// obvious reconstruction — confirm against the original checkout.
#include <optix.h>
#include <math_constants.h>

#include "params.h"
#include "../common.h"

#define MIN_ROUGHNESS 0.08f

extern "C" {
__constant__ EnvSamplingParams params;
}

//==============================================================================
// Math / utility functions
//==============================================================================

#include "../bsdf.h"

// from https://www.reedbeta.com/blog/hash-functions-for-gpu-rendering/
__device__ unsigned int rand_pcg(unsigned int &rng_state)
{
    unsigned int word = ((rng_state >> ((rng_state >> 28u) + 4u)) ^ rng_state) * 277803737u;
    rng_state = rng_state * 747796405u + 2891336453u;
    return (word >> 22u) ^ word;
}

__device__ unsigned int hash_pcg(unsigned int global_seed, unsigned int sample_seed)
{
    return rand_pcg(global_seed) ^ rand_pcg(sample_seed);
}

// Uniform float in [0, 1) from the PCG stream (24 mantissa-safe bits).
__device__ float uniform_pcg(unsigned int &rng_state)
{
    return (float)(rand_pcg(rng_state) & 0xFFFFFF) / (float)0x1000000;
}

// World -> local frame (u, v, w must be orthonormal).
__device__ float3 tolocal(const float3& a, const float3& u, const float3& v, const float3& w)
{
    return make_float3(dot(a, u), dot(a, v), dot(a, w));
}

// Local -> world frame.
__device__ float3 toworld(const float3& a, const float3& u, const float3& v, const float3& w)
{
    return u * a.x + v * a.y + w * a.z;
}

// Cosine-weighted hemisphere sample around N; writes the sample pdf.
__device__ float3 cosine_sample(float3 N, float u, float v, float& pdf)
{
    // construct local frame
    N = safe_normalize(N);
    float3 dx, dy;
    branchlessONB(N, dx, dy);

    // cosine sampling in local frame
    float phi = 2.0 * CUDART_PI * u;
    float costheta = sqrt(v);
    float sintheta = sqrt(1.0 - v);

    // Cartesian vector in local space
    float x = cos(phi) * sintheta;
    float y = sin(phi) * sintheta;
    float z = costheta;

    pdf = max(0.000001f, costheta / CUDART_PI);

    // Local to world
    float3 vec = dx * x + dy * y + N * z;
    return safe_normalize(vec);
}

// Luminance of the Schlick Fresnel term; 0 for back-facing wo.
// (eta is currently unused in the visible body.)
__device__ float albedo(const float3& baseColor, const float eta, const float3& wo, const float3& N)
{
    // Construct tangent frame
    float3 W = safe_normalize(N);
    float3 U, V;
    branchlessONB(W, U, V);

    float3 wo_l = safe_normalize(tolocal(wo, U, V, W));
    const float cosNO = wo_l.z;
    if (!(cosNO > 0))
        return 0.0f;

    return luminance(fwdFresnelSchlick(baseColor, make_float3(1.f, 1.f, 1.f), cosNO));
}

//==============================================================================
// Shadow ray test. Note: This code ignores the shadow gradient boundary term.
// We saw no benefit to boundary term gradients in our experiments.
//==============================================================================

__device__ float shadow_test(uint3 idx, float3 ray_origin, float3 ray_dir, float vis_grad)
{
    unsigned int isVisible = 0;
    optixTrace(
        params.handle,
        ray_origin,
        ray_dir,
        0.0f,                     // Min intersection distance
        1e16,                     // Max intersection distance
        0.0f,                     // rayTime -- used for motion blur
        OptixVisibilityMask(0xFF),
        OPTIX_RAY_FLAG_DISABLE_ANYHIT | OPTIX_RAY_FLAG_DISABLE_CLOSESTHIT | OPTIX_RAY_FLAG_TERMINATE_ON_FIRST_HIT,
        0, // SBT offset
        0, // SBT stride
        0, // missSBTIndex
        isVisible);
    return isVisible ? 1.0f : 0.0f;
}

//==============================================================================
// Light probe functions
//==============================================================================

// Direction -> lat-long texture coordinate.
__device__ float2 _dir_to_tc(float3 dir)
{
    float u = atan2f(dir.x, -dir.z) / (2.0f * CUDART_PI) + 0.5f;
    float v = acosf(clamp(dir.y, -1.0f, 1.0f)) / CUDART_PI;
    return make_float2(u, v);
}

// Lat-long texture coordinate -> direction (inverse of _dir_to_tc).
__device__ float3 _tc_to_dir(float2 uv)
{
    float sinphi, cosphi;
    sincos((uv.x * 2.0f - 1.0f) * CUDART_PI, &sinphi, &cosphi);
    float sintheta, costheta;
    sincos(uv.y * CUDART_PI, &sintheta, &costheta);
    return make_float3(sintheta * sinphi, costheta, -sintheta * cosphi);
}

// Inverse-CDF sample of a 1D discretized distribution. Writes the chosen
// bin index and its probability mass; returns the fractional position
// inside the bin, in [0, 1).
// NOTE(review): the template header and the binary-search loop head were
// stripped by the repo dump and are reconstructed here.
template<typename T>
__device__ float sample_cdf(const T &cdf, float x, unsigned int &idx, float &pdf)
{
    x = min(x, 0.99999994f);

    // Binary search to find next index above
    unsigned int _min = 0;
    unsigned int _max = cdf.size(0) - 1;
    unsigned int m = int(ceil(log2((float)_max))) + 1;  // enough iterations to converge
    for (int i = 0; i < m; ++i)
    {
        unsigned int mid = (_min + _max) / 2;
        _min = x >= cdf[mid] ? mid : _min;
        _max = x < cdf[mid] ? mid : _max;
    }
    idx = _max;

    float sample;
    if (idx == 0)
    {
        pdf = cdf[0];
        sample = x;
    }
    else
    {
        float data0 = cdf[idx];
        float data1 = cdf[idx - 1];
        pdf = data0 - data1;
        sample = (x - data1);
    }

    // keep result in [0,1)
    return min(sample / pdf, 0.99999994f);
}

// Solid-angle pdf of sampling `dir` from the environment-map distribution.
__device__ float lightPDF(const float3& dir)
{
    // Sample light
    float2 coord = _dir_to_tc(dir);

    // retrieve nearest neighbor
    int x = clamp((int)(coord.x * params.pdf.size(1)), 0, params.pdf.size(1) - 1);
    int y = clamp((int)(coord.y * params.pdf.size(0)), 0, params.pdf.size(0) - 1);

    // Texel pdf -> solid-angle pdf (sin factor from the lat-long mapping).
    float pdf_weight = params.cols.size(0) * params.cols.size(1) / (2.0f * CUDART_PI * CUDART_PI * max(sinf(coord.y * CUDART_PI), 0.0001f));
    return params.pdf[y][x] * pdf_weight;
}

// Importance-sample a direction from the env-map marginal/conditional CDFs.
__device__ float3 lightSample(float u, float v, float& pdf)
{
    float row_pdf, col_pdf;
    unsigned int x, y;
    float ry = sample_cdf(params.rows, v, y, row_pdf);
    float rx = sample_cdf(params.cols[y], u, x, col_pdf);
    float3 rnd_dir = _tc_to_dir(make_float2((x + rx) / params.cols.size(1), (y + ry) / params.cols.size(0)));
    pdf = lightPDF(rnd_dir);
    return rnd_dir;
}

// Nearest-neighbor env-map lookup.
__device__ float3 eval_light_fwd(float2 coord)
{
    coord = coord * make_float2(params.light.size(1), params.light.size(0));
    int x = clamp((int)coord.x, 0, params.light.size(1) - 1);
    int y = clamp((int)coord.y, 0, params.light.size(0) - 1);
    return fetch3(params.light, y, x);
}

// Scatter env-map gradients; atomics because many rays hit the same texel.
// NOTE(review): "&params" was mangled to an HTML pilcrow entity in the dump;
// restored here.
__device__ void eval_light_bwd(float2 coord, float3 light_grad)
{
    coord = coord * make_float2(params.light.size(1), params.light.size(0));
    int x = clamp((int)coord.x, 0, params.light.size(1) - 1);
    int y = clamp((int)coord.y, 0, params.light.size(0) - 1);
    atomicAdd(&params.light_grad[y][x][0], light_grad.x);
    atomicAdd(&params.light_grad[y][x][1], light_grad.y);
    atomicAdd(&params.light_grad[y][x][2], light_grad.z);
}

//==============================================================================
// BSDF evaluation & importance sampling
//==============================================================================

__device__ float evalNdfGGX(float alpha, float cosTheta)
{
    float a2 = alpha * alpha;
    float d = ((cosTheta * a2 - cosTheta) * cosTheta + 1);
    return a2 / (d * d * CUDART_PI);
}

__device__ float evalG1GGX(float alphaSqr, float cosTheta)
{
    if (cosTheta <= 0)
        return 0;
    float cosThetaSqr = cosTheta * cosTheta;
    float tanThetaSqr = max(1.0f - cosThetaSqr, 0.0f) / cosThetaSqr;
    return 2 / (1 + sqrt(1 + alphaSqr * tanThetaSqr));
}

__device__ float evalPdfGGX_VNDF(float alpha, float3 wo, float3 h)
{
    float G1 = evalG1GGX(alpha * alpha, wo.z);
    float D = evalNdfGGX(alpha, h.z);
    return G1 * D * max(0.f, dot(wo, h)) / wo.z;
}

// Samples the GGX (Trowbridge-Reitz) using the distribution of visible normals (VNDF).
// See http://jcgt.org/published/0007/04/01/paper.pdf
__device__ float3 sampleGGX_VNDF(float alpha, float3 wo, float ux, float uy, float& pdf)
{
    // Transform the view vector to the hemisphere configuration.
    float3 Vh = safe_normalize(make_float3(alpha * wo.x, alpha * wo.y, wo.z));

    // Construct orthonormal basis (Vh,T1,T2).
    float3 T1 = (Vh.z < 0.9999f) ? safe_normalize(cross(make_float3(0.f, 0.f, 1.f), Vh)) : make_float3(1.f, 0.f, 0.f);
    float3 T2 = cross(Vh, T1);

    // Parameterization of the projected area of the hemisphere.
    float r = sqrtf(ux);
    float phi = (2.f * M_PI) * uy;
    float t1 = r * cos(phi);
    float t2 = r * sin(phi);
    float s = 0.5f * (1.f + Vh.z);
    t2 = (1.f - s) * sqrtf(1.f - t1 * t1) + s * t2;

    // Reproject onto hemisphere.
    float3 Nh = T1 * t1 + T2 * t2 + Vh * sqrtf(max(0.f, 1.f - t1 * t1 - t2 * t2));

    // Transform the normal back to the ellipsoid configuration. This is our half vector.
    float3 h = safe_normalize(make_float3(alpha * Nh.x, alpha * Nh.y, max(0.f, Nh.z)));
    pdf = evalPdfGGX_VNDF(alpha, wo, h);
    return h;
}

// Sample an incident direction from the GGX VNDF around N; pdf is in
// solid-angle measure (includes the reflection Jacobian).
__device__ float3 ggx_sample(float3 N, float3 wo, float u, float v, float alpha, float& pdf)
{
    // Construct tangent frame
    float3 W = safe_normalize(N);
    float3 U, V;
    branchlessONB(W, U, V);

    float3 wo_l = safe_normalize(tolocal(wo, U, V, W));
    const float cosNO = wo_l.z;
    if (!(cosNO > 0))
    {
        pdf = 0.f;
        return make_float3(0.f, 0.f, 0.f);
    }

    float3 h = sampleGGX_VNDF(alpha, wo_l, u, v, pdf); // pdf = G1(wo) * D(h) * max(0,dot(wo,h)) / wo.z

    // Reflect the outgoing direction to find the incident direction.
    float woDotH = dot(wo_l, h);
    float3 wi_l = h * woDotH * 2.0f - wo_l;
    pdf /= (4.0f * woDotH); // Jacobian of the reflection operator.

    float3 wi_o = toworld(wi_l, U, V, W);
    return safe_normalize(wi_o);
}

__device__ float evalLambdaGGX(float alphaSqr, float cosTheta)
{
    if (cosTheta <= 0)
        return 0;
    float cosThetaSqr = cosTheta * cosTheta;
    float tanThetaSqr = max(1 - cosThetaSqr, 0.0f) / cosThetaSqr;
    return 0.5 * (-1 + sqrt(1 + alphaSqr * tanThetaSqr));
}

// pdf of ggx_sample() producing wi (both directions in world space).
__device__ float ggx_pdf(float3 N, const float3 wo, const float3 wi, float alpha)
{
    // Construct tangent frame
    float3 W = safe_normalize(N);
    float3 U, V;
    branchlessONB(W, U, V);

    // wo_l : V
    // wi_l : L
    float3 wo_l = tolocal(wo, U, V, W);
    float3 wi_l = tolocal(wi, U, V, W);

    float pdf = 0.0f;
    if (wo_l.z > 0 && wi_l.z > 0)
    {
        float3 m = safe_normalize(wi_l + wo_l);
        const float woDotH = dot(m, wo_l);
        const float D = evalNdfGGX(alpha, m.z);
        float G1 = evalG1GGX(alpha * alpha, wo_l.z);
        pdf = G1 * D * max(0.f, dot(wo_l, m)) / wo_l.z;
        pdf /= (4 * woDotH);
    }
    return pdf;
}

// Accumulate opdf weighted by branch probability b (skipped when negligible).
__device__ void update_pdf(float* pdf, float opdf, float b)
{
    if (b > 0.000001f)
    {
        opdf *= b;
        *pdf += opdf;
    }
}

// Sample the combined diffuse+specular BSDF: pick a lobe from rnd.z, sample
// it, then fold in the other lobe's pdf for the chosen direction (one-sample
// MIS over lobes).
__device__ float3 bsdf_sample(float pDiffuse, float pSpecular, float3 N, float3 wo, float3 s, float alpha, float& pdf)
{
    float3 rnd = s;
    pdf = 0.0f;
    float3 wi_o;
    if (rnd.z < pDiffuse) // Sample diffuse lobe
    {
        if (pDiffuse < 0.0001f)
        {
            pdf = 1.0f;
            return N;
        }

        wi_o = cosine_sample(N, rnd.x, rnd.y, pdf);
        pdf *= pDiffuse;

        // we sampled the diffuse lobe, now figure out how much the other bsdf contribute to the chosen direction
        if (pSpecular > 0)
        {
            float bsdf_pdf = ggx_pdf(N, wo, wi_o, alpha);
            update_pdf(&pdf, bsdf_pdf, 1.0f - pDiffuse);
        }
    }
    else // Sample specular lobe
    {
        wi_o = ggx_sample(N, wo, rnd.x, rnd.y, alpha, pdf);
        pdf *= 1.f - pDiffuse;

        // we sampled PDF 1, now figure out how much the other bsdf contribute to the chosen direction
        if (pDiffuse > 0)
        {
            float bsdf_pdf = max(dot(N, wi_o), 0.0) / CUDART_PI; // cosine sampling pdf
            update_pdf(&pdf, bsdf_pdf, pDiffuse);
        }
    }
    return wi_o;
}

// pdf of bsdf_sample() producing wi; 1.0 for degenerate grazing setups.
__device__ float bsdf_pdf(float pDiffuse, float pSpecular, float3 N, const float3 wo, const float3 wi, float alpha)
{
    // Check that L and V are in the positive hemisphere.
    // The G term on the correlated form is not robust for NdotL = NdotV = 0.0.
    float NdotL = dot(N, wi);
    float NdotV = dot(N, wo);
    static const float kMinCosTheta = 1e-6f;

    float pdf = 0.0f;
    if (min(NdotV, NdotL) < kMinCosTheta)
        return 1.0f;

    if (pDiffuse > 0)
    {
        float bsdf_pdf = max(dot(N, wi), 0.0) / CUDART_PI; // cosine sampling pdf
        update_pdf(&pdf, bsdf_pdf, pDiffuse);
    }

    if (pSpecular > 0)
    {
        float bsdf_pdf = ggx_pdf(N, wo, wi, alpha); // ggx sampling pdf
        update_pdf(&pdf, bsdf_pdf, 1.0f - pDiffuse);
    }
    return pdf;
}

//==============================================================================
// Optix kernels
//==============================================================================

// (continues past this chunk — body below is the visible prefix, unchanged)
__device__ void process_sample(
    uint3 idx, float3 ray_origin, float3 ray_dir, float3 gb_pos, float3 gb_normal, float3 gb_view_pos, float3 gb_kd, float3 gb_ks,
    float pdfSum, float weight, float3 &diff, float3 &spec, float3 diff_grad, float3 spec_grad)
{
    float2 coord = _dir_to_tc(ray_dir);
    float3 light_col = eval_light_fwd(coord);

    float mis_weight = 1.0 / max(pdfSum, 0.0001f); // MIS balance heuristic

    // float alpha = gb_ks.y * gb_ks.y;
    float3 _diff =
make_float3(0), _spec = make_float3(0); if (params.BSDF == 1 || params.BSDF == 2) _diff = make_float3(fwdLambert(gb_normal, ray_dir)); else fwdPbrBSDF(gb_kd, gb_ks, gb_pos, gb_normal, gb_view_pos, ray_dir, 0.08f, _diff, _spec); // Trace shadow ray for current sample float V_grad = sum((diff_grad * _diff + spec_grad * _spec) * light_col * mis_weight * weight) * params.shadow_scale; float V = shadow_test(idx, ray_origin, ray_dir, V_grad) * params.shadow_scale + (1 - params.shadow_scale); if (params.backward) { float3 light_grad = (diff_grad * _diff + spec_grad * _spec) * V * mis_weight * weight; eval_light_bwd(coord, light_grad); float3 _diff_grad = diff_grad * light_col * V * mis_weight * weight; float3 _spec_grad = spec_grad * light_col * V * mis_weight * weight; float3 gb_kd_grad = make_float3(0), gb_ks_grad = make_float3(0), gb_pos_grad = make_float3(0), gb_normal_grad = make_float3(0), gb_view_pos_grad = make_float3(0), ray_dir_grad = make_float3(0); if (params.BSDF == 1 || params.BSDF == 2) // params.BSDF : 0 : 'pbr', 1 : 'diffuse', 2 : 'white' { float3 wi_grad = make_float3(0); float lambert = fwdLambert(gb_normal, ray_dir); float lambert_grad = sum(_diff_grad); bwdLambert(gb_normal, ray_dir, gb_normal_grad, wi_grad, lambert_grad); } else { bwdPbrBSDF( gb_kd, gb_ks, gb_pos, gb_normal, gb_view_pos, ray_dir, 0.08f, gb_kd_grad, gb_ks_grad, gb_pos_grad, gb_normal_grad, gb_view_pos_grad, ray_dir_grad, _diff_grad, _spec_grad); } params.gb_pos_grad[idx.z][idx.y][idx.x][0] += gb_pos_grad.x; params.gb_pos_grad[idx.z][idx.y][idx.x][1] += gb_pos_grad.y; params.gb_pos_grad[idx.z][idx.y][idx.x][2] += gb_pos_grad.z; params.gb_normal_grad[idx.z][idx.y][idx.x][0] += gb_normal_grad.x; params.gb_normal_grad[idx.z][idx.y][idx.x][1] += gb_normal_grad.y; params.gb_normal_grad[idx.z][idx.y][idx.x][2] += gb_normal_grad.z; params.gb_kd_grad[idx.z][idx.y][idx.x][0] += gb_kd_grad.x; params.gb_kd_grad[idx.z][idx.y][idx.x][1] += gb_kd_grad.y; params.gb_kd_grad[idx.z][idx.y][idx.x][2] += 
gb_kd_grad.z; params.gb_ks_grad[idx.z][idx.y][idx.x][0] += gb_ks_grad.x; params.gb_ks_grad[idx.z][idx.y][idx.x][1] += gb_ks_grad.y; params.gb_ks_grad[idx.z][idx.y][idx.x][2] += gb_ks_grad.z; } diff = _diff * light_col * V * mis_weight * weight; spec = _spec * light_col * V * mis_weight * weight; } extern "C" __global__ void __raygen__rg() { // Lookup our location within the launch grid const uint3 idx = optixGetLaunchIndex(); const uint3 dim = optixGetLaunchDimensions(); // Read per-pixel constant input tensors, ray_origin, g-buffer entries etc. float mask = params.mask[idx.z][idx.y][idx.x]; float3 ray_origin = fetch3(params.ro, idx.z, idx.y, idx.x); float3 gb_pos = fetch3(params.gb_pos, idx.z, idx.y, idx.x); float3 gb_normal = fetch3(params.gb_normal, idx.z, idx.y, idx.x); float3 gb_view_pos = fetch3(params.gb_view_pos, idx.z, idx.y, idx.x); float3 gb_kd = fetch3(params.gb_kd, idx.z, idx.y, idx.x); float3 gb_ks = fetch3(params.gb_ks, idx.z, idx.y, idx.x); if (mask <= 0) return; // Early exit masked pixels float3 diff_grad, spec_grad; if (params.backward) { diff_grad = fetch3(params.diff_grad, idx.z, idx.y, idx.x); spec_grad = fetch3(params.spec_grad, idx.z, idx.y, idx.x); } float3 diffAccum = make_float3(0.0f, 0.0f, 0.0f); float3 specAccum = make_float3(0.0f, 0.0f, 0.0f); float strata_frac = 1.0f / params.n_samples_x; float sample_frac = 1.0f / (params.n_samples_x * params.n_samples_x); float alpha = gb_ks.y * gb_ks.y; // roughness squared float3 wo = safe_normalize(gb_view_pos - gb_pos); // view direction float metallic = gb_ks.z; float3 baseColor = gb_kd; float3 specColor = make_float3(0.04f, 0.04f, 0.04f) * (1.0f - metallic) + baseColor * metallic; float diffuseWeight = (1.f - metallic) * luminance(baseColor); float eta = 1.0f; float specularWeight = albedo(specColor, eta, wo, gb_normal); float pDiffuse = (diffuseWeight + specularWeight) > 0.f ? 
diffuseWeight / (diffuseWeight + specularWeight) : 1.f; float pSpecular = 1.0f - pDiffuse; unsigned int rng_state = hash_pcg(params.rnd_seed, (idx.z * dim.y + idx.y) * dim.x + idx.x); unsigned int lightIdx = rand_pcg(rng_state) % params.perms.size(0), bsdfIdx = rand_pcg(rng_state) % params.perms.size(0); for (int i = 0; i < params.n_samples_x * params.n_samples_x; ++i) { float3 ray_dir, diff, spec; float sx, sy, sz = 0.f, pdf_light, pdf_bsdf; // Light importance sampling sx = ((float)(params.perms[lightIdx][i] % params.n_samples_x) + uniform_pcg(rng_state)) * strata_frac; sy = ((float)(params.perms[lightIdx][i] / params.n_samples_x) + uniform_pcg(rng_state)) * strata_frac; ray_dir = lightSample(sx, sy, pdf_light); pdf_bsdf = bsdf_pdf(pDiffuse, pSpecular, gb_normal, wo, ray_dir, alpha); process_sample(idx, ray_origin, ray_dir, gb_pos, gb_normal, gb_view_pos, gb_kd, gb_ks, pdf_light + pdf_bsdf, sample_frac, diff, spec, diff_grad, spec_grad); diffAccum = diffAccum + diff; specAccum = specAccum + spec; // BSDF sampling (sample either the diffuse or specular lobe) sx = ((float)(params.perms[bsdfIdx][i] % params.n_samples_x) + uniform_pcg(rng_state)) * strata_frac; sy = ((float)(params.perms[bsdfIdx][i] / params.n_samples_x) + uniform_pcg(rng_state)) * strata_frac; sz = uniform_pcg(rng_state); ray_dir = bsdf_sample(pDiffuse, pSpecular, gb_normal, wo, make_float3(sx, sy, sz), alpha, pdf_bsdf); pdf_light = lightPDF(ray_dir); process_sample(idx, ray_origin, ray_dir, gb_pos, gb_normal, gb_view_pos, gb_kd, gb_ks, pdf_light + pdf_bsdf, sample_frac, diff, spec, diff_grad, spec_grad); diffAccum = diffAccum + diff; specAccum = specAccum + spec; } // Record results in our output raster if (!params.backward) { params.diff[idx.z][idx.y][idx.x][0] = diffAccum.x; params.diff[idx.z][idx.y][idx.x][1] = diffAccum.y; params.diff[idx.z][idx.y][idx.x][2] = diffAccum.z; params.spec[idx.z][idx.y][idx.x][0] = specAccum.x; params.spec[idx.z][idx.y][idx.x][1] = specAccum.y; 
params.spec[idx.z][idx.y][idx.x][2] = specAccum.z; } } extern "C" __global__ void __miss__ms() { optixSetPayload_0(1); } ================================================ FILE: render/optixutils/c_src/envsampling/params.h ================================================ // Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. // // NVIDIA CORPORATION and its licensors retain all intellectual property // and proprietary rights in and to this software, related documentation // and any modifications thereto. Any use, reproduction, disclosure or // distribution of this software and related documentation without an express // license agreement from NVIDIA CORPORATION is strictly prohibited. #include "../accessor.h" struct EnvSamplingParams { // Ray data PackedTensorAccessor32 ro; // ray origin // GBuffer PackedTensorAccessor32 mask; PackedTensorAccessor32 gb_pos; PackedTensorAccessor32 gb_pos_grad; PackedTensorAccessor32 gb_normal; PackedTensorAccessor32 gb_normal_grad; PackedTensorAccessor32 gb_view_pos; PackedTensorAccessor32 gb_kd; PackedTensorAccessor32 gb_kd_grad; PackedTensorAccessor32 gb_ks; PackedTensorAccessor32 gb_ks_grad; // Light PackedTensorAccessor32 light; PackedTensorAccessor32 light_grad; PackedTensorAccessor32 pdf; // light pdf PackedTensorAccessor32 rows; // light sampling cdf PackedTensorAccessor32 cols; // light sampling cdf // Output PackedTensorAccessor32 diff; PackedTensorAccessor32 diff_grad; PackedTensorAccessor32 spec; PackedTensorAccessor32 spec_grad; // Table with random permutations for stratified sampling PackedTensorAccessor32 perms; OptixTraversableHandle handle; unsigned int BSDF; unsigned int n_samples_x; unsigned int rnd_seed; unsigned int backward; float shadow_scale; }; ================================================ FILE: render/optixutils/c_src/math_utils.h ================================================ // Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. 
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto.  Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#pragma once

#ifdef __CUDACC__

// Generic clamp. NOTE(review): the template parameter list was lost in the
// text extraction (angle-bracket contents stripped); presumably `template<class T>`.
template static __device__ __inline__ T clamp(T x, T _min, T _max)
{
    return min(_max, max(_min, x));
}

// Broadcast a scalar to all three components.
static __device__ inline float3 make_float3(float a) { return make_float3(a, a, a); }

//------------------------------------------------------------------------
// Component-wise float2 arithmetic (compound, binary, scalar, unary).
static __device__ inline float2& operator/=  (float2& a, const float2& b) { a.x /= b.x; a.y /= b.y; return a; }
static __device__ inline float2& operator*=  (float2& a, const float2& b) { a.x *= b.x; a.y *= b.y; return a; }
static __device__ inline float2& operator+=  (float2& a, const float2& b) { a.x += b.x; a.y += b.y; return a; }
static __device__ inline float2& operator-=  (float2& a, const float2& b) { a.x -= b.x; a.y -= b.y; return a; }
static __device__ inline float2& operator/=  (float2& a, float b) { a.x /= b; a.y /= b; return a; }
static __device__ inline float2& operator*=  (float2& a, float b) { a.x *= b; a.y *= b; return a; }
static __device__ inline float2& operator+=  (float2& a, float b) { a.x += b; a.y += b; return a; }
static __device__ inline float2& operator-=  (float2& a, float b) { a.x -= b; a.y -= b; return a; }
static __device__ inline float2 operator/ (const float2& a, const float2& b) { return make_float2(a.x / b.x, a.y / b.y); }
static __device__ inline float2 operator* (const float2& a, const float2& b) { return make_float2(a.x * b.x, a.y * b.y); }
static __device__ inline float2 operator+ (const float2& a, const float2& b) { return make_float2(a.x + b.x, a.y + b.y); }
static __device__ inline float2 operator- (const float2& a, const float2& b) { return make_float2(a.x - b.x, a.y - b.y); }
static __device__ inline float2 operator/ (const float2& a, float b) { return make_float2(a.x / b, a.y / b); }
static __device__ inline float2 operator* (const float2& a, float b) { return make_float2(a.x * b, a.y * b); }
static __device__ inline float2 operator+ (const float2& a, float b) { return make_float2(a.x + b, a.y + b); }
static __device__ inline float2 operator- (const float2& a, float b) { return make_float2(a.x - b, a.y - b); }
static __device__ inline float2 operator/ (float a, const float2& b) { return make_float2(a / b.x, a / b.y); }
static __device__ inline float2 operator* (float a, const float2& b) { return make_float2(a * b.x, a * b.y); }
static __device__ inline float2 operator+ (float a, const float2& b) { return make_float2(a + b.x, a + b.y); }
static __device__ inline float2 operator- (float a, const float2& b) { return make_float2(a - b.x, a - b.y); }
static __device__ inline float2 operator- (const float2& a) { return make_float2(-a.x, -a.y); }

//------------------------------------------------------------------------
// Component-wise float3 arithmetic.
static __device__ inline float3& operator/=  (float3& a, const float3& b) { a.x /= b.x; a.y /= b.y; a.z /= b.z; return a; }
static __device__ inline float3& operator*=  (float3& a, const float3& b) { a.x *= b.x; a.y *= b.y; a.z *= b.z; return a; }
static __device__ inline float3& operator+=  (float3& a, const float3& b) { a.x += b.x; a.y += b.y; a.z += b.z; return a; }
static __device__ inline float3& operator-=  (float3& a, const float3& b) { a.x -= b.x; a.y -= b.y; a.z -= b.z; return a; }
static __device__ inline float3& operator/=  (float3& a, float b) { a.x /= b; a.y /= b; a.z /= b; return a; }
static __device__ inline float3& operator*=  (float3& a, float b) { a.x *= b; a.y *= b; a.z *= b; return a; }
static __device__ inline float3& operator+=  (float3& a, float b) { a.x += b; a.y += b; a.z += b; return a; }
static __device__ inline float3& operator-=  (float3& a, float b) { a.x -= b; a.y -= b; a.z -= b; return a; }
static __device__ inline float3 operator/ (const float3& a, const float3& b) { return make_float3(a.x / b.x, a.y / b.y, a.z / b.z); }
static __device__ inline float3 operator* (const float3& a, const float3& b) { return make_float3(a.x * b.x, a.y * b.y, a.z * b.z); }
static __device__ inline float3 operator+ (const float3& a, const float3& b) { return make_float3(a.x + b.x, a.y + b.y, a.z + b.z); }
static __device__ inline float3 operator- (const float3& a, const float3& b) { return make_float3(a.x - b.x, a.y - b.y, a.z - b.z); }
static __device__ inline float3 operator/ (const float3& a, float b) { return make_float3(a.x / b, a.y / b, a.z / b); }
static __device__ inline float3 operator* (const float3& a, float b) { return make_float3(a.x * b, a.y * b, a.z * b); }
static __device__ inline float3 operator+ (const float3& a, float b) { return make_float3(a.x + b, a.y + b, a.z + b); }
static __device__ inline float3 operator- (const float3& a, float b) { return make_float3(a.x - b, a.y - b, a.z - b); }
static __device__ inline float3 operator/ (float a, const float3& b) { return make_float3(a / b.x, a / b.y, a / b.z); }
static __device__ inline float3 operator* (float a, const float3& b) { return make_float3(a * b.x, a * b.y, a * b.z); }
static __device__ inline float3 operator+ (float a, const float3& b) { return make_float3(a + b.x, a + b.y, a + b.z); }
static __device__ inline float3 operator- (float a, const float3& b) { return make_float3(a - b.x, a - b.y, a - b.z); }
static __device__ inline float3 operator- (const float3& a) { return make_float3(-a.x, -a.y, -a.z); }

//------------------------------------------------------------------------
// Component-wise float4 arithmetic.
static __device__ inline float4& operator/=  (float4& a, const float4& b) { a.x /= b.x; a.y /= b.y; a.z /= b.z; a.w /= b.w; return a; }
static __device__ inline float4& operator*=  (float4& a, const float4& b) { a.x *= b.x; a.y *= b.y; a.z *= b.z; a.w *= b.w; return a; }
static __device__ inline float4& operator+=  (float4& a, const float4& b) { a.x += b.x; a.y += b.y; a.z += b.z; a.w += b.w; return a; }
static __device__ inline float4& operator-=  (float4& a, const float4& b) { a.x -= b.x; a.y -= b.y; a.z -= b.z; a.w -= b.w; return a; }
static __device__ inline float4& operator/=  (float4& a, float b) { a.x /= b; a.y /= b; a.z /= b; a.w /= b; return a; }
static __device__ inline float4& operator*=  (float4& a, float b) { a.x *= b; a.y *= b; a.z *= b; a.w *= b; return a; }
static __device__ inline float4& operator+=  (float4& a, float b) { a.x += b; a.y += b; a.z += b; a.w += b; return a; }
static __device__ inline float4& operator-=  (float4& a, float b) { a.x -= b; a.y -= b; a.z -= b; a.w -= b; return a; }
static __device__ inline float4 operator/ (const float4& a, const float4& b) { return make_float4(a.x / b.x, a.y / b.y, a.z / b.z, a.w / b.w); }
static __device__ inline float4 operator* (const float4& a, const float4& b) { return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w); }
static __device__ inline float4 operator+ (const float4& a, const float4& b) { return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w); }
static __device__ inline float4 operator- (const float4& a, const float4& b) { return make_float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w); }
static __device__ inline float4 operator/ (const float4& a, float b) { return make_float4(a.x / b, a.y / b, a.z / b, a.w / b); }
static __device__ inline float4 operator* (const float4& a, float b) { return make_float4(a.x * b, a.y * b, a.z * b, a.w * b); }
static __device__ inline float4 operator+ (const float4& a, float b) { return make_float4(a.x + b, a.y + b, a.z + b, a.w + b); }
static __device__ inline float4 operator- (const float4& a, float b) { return make_float4(a.x - b, a.y - b, a.z - b, a.w - b); }
static __device__ inline float4 operator/ (float a, const float4& b) { return make_float4(a / b.x, a / b.y, a / b.z, a / b.w); }
static __device__ inline float4 operator* (float a, const float4& b) { return make_float4(a * b.x, a * b.y, a * b.z, a * b.w); }
static __device__ inline float4 operator+ (float a, const float4& b) { return make_float4(a + b.x, a + b.y, a + b.z, a + b.w); }
static __device__ inline float4 operator- (float a, const float4& b) { return make_float4(a - b.x, a - b.y, a - b.z, a - b.w); }
static __device__ inline float4 operator- (const float4& a) { return make_float4(-a.x, -a.y, -a.z, -a.w); }

//------------------------------------------------------------------------
// Vector reductions and their manual backward passes (d_* accumulate).

// Sum of components.
static __device__ inline float sum(float3 a)
{
    return a.x + a.y + a.z;
}

// Dot product and its gradient w.r.t. both inputs.
static __device__ inline float dot(float3 a, float3 b)
{
    return a.x * b.x + a.y * b.y + a.z * b.z;
}

static __device__ inline void bwd_dot(float3 a, float3 b, float3& d_a, float3& d_b, float d_out)
{
    d_a.x += d_out * b.x; d_a.y += d_out * b.y; d_a.z += d_out * b.z;
    d_b.x += d_out * a.x; d_b.y += d_out * a.y; d_b.z += d_out * a.z;
}

// Rec. 709 luma weights.
static __device__ inline float luminance(const float3 rgb)
{
    return dot(rgb, make_float3(0.2126f, 0.7152f, 0.0722f));
}

// Cross product and its gradient.
static __device__ inline float3 cross(float3 a, float3 b)
{
    float3 out;
    out.x = a.y * b.z - a.z * b.y;
    out.y = a.z * b.x - a.x * b.z;
    out.z = a.x * b.y - a.y * b.x;
    return out;
}

static __device__ inline void bwd_cross(float3 a, float3 b, float3 &d_a, float3 &d_b, float3 d_out)
{
    d_a.x += d_out.z * b.y - d_out.y * b.z;
    d_a.y += d_out.x * b.z - d_out.z * b.x;
    d_a.z += d_out.y * b.x - d_out.x * b.y;
    d_b.x += d_out.y * a.z - d_out.z * a.y;
    d_b.y += d_out.z * a.x - d_out.x * a.z;
    d_b.z += d_out.x * a.y - d_out.y * a.x;
}

// Reflect x about (assumed unit-length) normal n, and the matching gradient.
static __device__ inline float3 reflect(float3 x, float3 n)
{
    return n * 2.0f * dot(n, x) - x;
}

static __device__ inline void bwd_reflect(float3 x, float3 n, float3& d_x, float3& d_n, float3 d_out)
{
    d_x.x += d_out.x * (2 * n.x * n.x - 1) + d_out.y * (2 * n.x * n.y) + d_out.z * (2 * n.x * n.z);
    d_x.y += d_out.x * (2 * n.x * n.y) + d_out.y * (2 * n.y * n.y - 1) + d_out.z * (2 * n.y * n.z);
    d_x.z += d_out.x * (2 * n.x * n.z) + d_out.y * (2 * n.y * n.z) + d_out.z * (2 * n.z * n.z - 1);
    d_n.x += d_out.x * (2 * (2 * n.x * x.x + n.y * x.y + n.z * x.z)) + d_out.y * (2 * n.y * x.x) + d_out.z * (2 * n.z * x.x);
    d_n.y += d_out.x * (2 * n.x * x.y) + d_out.y * (2 * (n.x * x.x + 2 * n.y * x.y + n.z * x.z)) + d_out.z * (2 * n.z * x.y);
    d_n.z += d_out.x * (2 * n.x * x.z) + d_out.y * (2 * n.y * x.z) + d_out.z * (2 * (n.x * x.x + n.y * x.y + 2 * n.z * x.z));
}

// Normalize, returning the zero vector for zero-length input (no NaNs).
static __device__ inline float3 safe_normalize(float3 v)
{
    float l = sqrtf(v.x * v.x + v.y * v.y + v.z * v.z);
    return l > 0.0f ? (v / l) : make_float3(0.0f);
}

static __device__ inline void bwd_safe_normalize(const float3 v, float3& d_v, float3 d_out)
{
    float l = sqrtf(v.x * v.x + v.y * v.y + v.z * v.z);
    if (l > 0.0f)
    {
        float fac = 1.0 / powf(v.x * v.x + v.y * v.y + v.z * v.z, 1.5f);
        d_v.x += (d_out.x * (v.y * v.y + v.z * v.z) - d_out.y * (v.x * v.y) - d_out.z * (v.x * v.z)) * fac;
        d_v.y += (d_out.y * (v.x * v.x + v.z * v.z) - d_out.x * (v.y * v.x) - d_out.z * (v.y * v.z)) * fac;
        d_v.z += (d_out.z * (v.x * v.x + v.y * v.y) - d_out.x * (v.z * v.x) - d_out.y * (v.z * v.y)) * fac;
    }
}

// Code from
// https://graphics.pixar.com/library/OrthonormalB/paper.pdf
// Builds an orthonormal basis (b1, b2) around unit vector n without branches.
static __device__ inline void branchlessONB(const float3 &n, float3 &b1, float3 &b2)
{
    float sign = copysignf(1.0f, n.z);
    const float a = -1.0f / (sign + n.z);
    const float b = n.x * n.y * a;
    b1 = make_float3(1.0f + sign * n.x * n.x * a, sign * b, -sign * n.x);
    b2 = make_float3(b, sign + n.y * n.y * a, -n.y);
}
#endif

================================================
FILE: render/optixutils/c_src/optix_wrapper.cpp
================================================
// Copyright (c) 2020-2022, NVIDIA CORPORATION.  All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto.  Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
// NOTE(review): the header names of every #include below were lost in the
// text extraction (angle-bracket contents stripped). See the original
// repository for the actual headers (optix/nvrtc/cuda/stl).
#ifdef _MSC_VER
#pragma warning(push, 0)
#include
#pragma warning(pop)
#else
#include
#endif

#include
#include
#include
#include
#include
#include
#include
#include
#include

#include "common.h"
#include "optix_wrapper.h"

// NVRTC compiler options
#define CUDA_NVRTC_OPTIONS  \
    "-std=c++11", \
    "-arch", \
    "compute_70", \
    "-use_fast_math", \
    "-lineinfo", \
    "-default-device", \
    "-rdc", \
    "true", \
    "-D__x86_64", \
    "-D__OPTIX__"

// OptiX device-context log callback: echo every message to stderr.
static void context_log_cb( unsigned int level, const char* tag, const char* message, void* /*cbdata */)
{
    std::cerr << "[" << std::setw( 2 ) << level << "][" << std::setw( 12 ) << tag << "]: " << message << "\n";
}

// Read an entire file into `str`; returns false if it cannot be opened.
static bool readSourceFile( std::string& str, const std::string& filename )
{
    // Try to open file
    std::ifstream file( filename.c_str(), std::ios::binary );
    if( file.good() )
    {
        // Found usable source file
        std::vector buffer = std::vector( std::istreambuf_iterator( file ), {} );   // NOTE(review): template args stripped by extraction
        str.assign(buffer.begin(), buffer.end());
        return true;
    }
    return false;
}

// Like readSourceFile, but throws on failure.
static void getCuStringFromFile( std::string& cu, const char* filename )
{
    // Try to get source code from file
    if( readSourceFile( cu, filename ) )
        return;

    // Wasn't able to find or open the requested file
    throw std::runtime_error( "Couldn't open source file " + std::string( filename ) );
}

// JIT-compile CUDA source `cu_source` to PTX with NVRTC, writing into `ptx`.
// Throws (with the NVRTC log appended) on compile failure.
static void getPtxFromCuString( std::string& ptx, const char* include_dir, const char* optix_include_dir, const char* cuda_include_dir, const char* cu_source, const char* name, const char** log_string )
{
    // Create program
    nvrtcProgram prog = 0;
    NVRTC_CHECK_ERROR( nvrtcCreateProgram( &prog, cu_source, name, 0, NULL, NULL ) );

    // Gather NVRTC options
    std::vector options;   // NOTE(review): element type stripped by extraction (presumably const char*)

    std::string sample_dir;
    sample_dir = std::string( "-I" ) + include_dir;
    options.push_back( sample_dir.c_str() );

    // Collect include dirs
    std::vector include_dirs;   // NOTE(review): presumably std::vector<std::string>
    include_dirs.push_back( std::string( "-I" ) + optix_include_dir );
    include_dirs.push_back( std::string( "-I" ) + cuda_include_dir );
    for( const std::string& dir : include_dirs)
        options.push_back( dir.c_str() );

    // Collect NVRTC options
    const char*  compiler_options[] = {CUDA_NVRTC_OPTIONS};
    std::copy( std::begin( compiler_options ), std::end( compiler_options ), std::back_inserter( options ) );

    // JIT compile CU to PTX
    const nvrtcResult compileRes = nvrtcCompileProgram( prog, (int)options.size(), options.data() );

    // Retrieve log output
    std::string g_nvrtcLog;
    size_t log_size = 0;
    NVRTC_CHECK_ERROR( nvrtcGetProgramLogSize( prog, &log_size ) );
    g_nvrtcLog.resize( log_size );
    if( log_size > 1 )
    {
        NVRTC_CHECK_ERROR( nvrtcGetProgramLog( prog, &g_nvrtcLog[0] ) );
        // NOTE(review): g_nvrtcLog is a local; *log_string dangles once this
        // function returns. Callers appear to use it only for diagnostics —
        // verify before relying on it.
        if( log_string )
            *log_string = g_nvrtcLog.c_str();
    }
    if( compileRes != NVRTC_SUCCESS )
        throw std::runtime_error( "NVRTC Compilation failed.\n" + g_nvrtcLog );

    // Retrieve PTX code
    size_t ptx_size = 0;
    NVRTC_CHECK_ERROR( nvrtcGetPTXSize( prog, &ptx_size ) );
    ptx.resize( ptx_size );
    NVRTC_CHECK_ERROR( nvrtcGetPTX( prog, &ptx[0] ) );

    // Cleanup
    NVRTC_CHECK_ERROR( nvrtcDestroyProgram( &prog ) );
}

// Load a .cu file, compile it to PTX, and return a pointer to the PTX bytes
// (length in dataSize). The std::string holding the PTX is intentionally
// heap-allocated and never freed so the returned c_str() stays valid.
const char* getInputData( const char* filename, const char* include_dir, const char* optix_include_dir, const char* cuda_include_dir, const char* name, size_t& dataSize, const char** log)
{
    if( log )
        *log = NULL;

    std::string * ptx, cu;   // ptx is a pointer; cu is a value
    ptx = new std::string();
    getCuStringFromFile( cu, filename );
    getPtxFromCuString( *ptx, include_dir, optix_include_dir, cuda_include_dir, cu.c_str(), name, log );

    dataSize = ptx->size();
    return ptx->c_str();
}

// Minimal SBT record: header only, no per-record payload.
struct SbtRecord
{
    __align__( OPTIX_SBT_RECORD_ALIGNMENT ) char header[OPTIX_SBT_RECORD_HEADER_SIZE];
};

// Compile <path>/c_src/<kernel_name>/kernel.cu, build the raygen+miss
// program groups, link the pipeline, and populate the shader binding table.
void createPipeline(const OptixDeviceContext context, const std::string& path, const std::string& cuda_path, const std::string& kernel_name, OptixModule* module, OptixPipeline* pipeline, OptixShaderBindingTable& sbt)
{
    char log[2048];

    OptixPipelineCompileOptions pipeline_compile_options = {};
    {
        OptixModuleCompileOptions module_compile_options = {};
        module_compile_options.maxRegisterCount = OPTIX_COMPILE_DEFAULT_MAX_REGISTER_COUNT;
        module_compile_options.optLevel = OPTIX_COMPILE_OPTIMIZATION_DEFAULT;
        module_compile_options.debugLevel = OPTIX_COMPILE_DEBUG_LEVEL_DEFAULT;

        pipeline_compile_options.usesMotionBlur = false;
        pipeline_compile_options.traversableGraphFlags = OPTIX_TRAVERSABLE_GRAPH_FLAG_ALLOW_SINGLE_GAS;
        pipeline_compile_options.numPayloadValues = 1;
        pipeline_compile_options.numAttributeValues = 2;
        pipeline_compile_options.exceptionFlags = OPTIX_EXCEPTION_FLAG_NONE;
        pipeline_compile_options.pipelineLaunchParamsVariableName = "params";
        pipeline_compile_options.usesPrimitiveTypeFlags = OPTIX_PRIMITIVE_TYPE_FLAGS_TRIANGLE;

        size_t inputSize = 0;
        std::string shaderFile = path + "/c_src/" + kernel_name + "/kernel.cu";
        std::string includeDir = path + "/c_src/" + kernel_name;
        std::string optix_include_dir = path + "/include";
        std::string cuda_include_dir = cuda_path + "/include";
        const char* input = getInputData(shaderFile.c_str(), includeDir.c_str(), optix_include_dir.c_str(), cuda_include_dir.c_str(), "kernel", inputSize, (const char**)&log);
        size_t sizeof_log = sizeof( log );

        OPTIX_CHECK_LOG( optixModuleCreateFromPTX(
                            context,
                            &module_compile_options,
                            &pipeline_compile_options,
                            input,
                            inputSize,
                            log,
                            &sizeof_log,
                            module) );
    }

    //
    // Create program groups
    //
    OptixProgramGroup raygen_prog_group = nullptr;
    OptixProgramGroup miss_prog_group = nullptr;
    {
        OptixProgramGroupOptions program_group_options = {}; // Initialize to zeros

        OptixProgramGroupDesc raygen_prog_group_desc = {}; //
        raygen_prog_group_desc.kind = OPTIX_PROGRAM_GROUP_KIND_RAYGEN;
        raygen_prog_group_desc.raygen.module = *module;
        raygen_prog_group_desc.raygen.entryFunctionName = "__raygen__rg";
        size_t sizeof_log = sizeof( log );
        OPTIX_CHECK_LOG( optixProgramGroupCreate(
                            context,
                            &raygen_prog_group_desc,
                            1, // num program groups
                            &program_group_options,
                            log,
                            &sizeof_log,
                            &raygen_prog_group ) );

        OptixProgramGroupDesc miss_prog_group_desc = {};
        miss_prog_group_desc.kind = OPTIX_PROGRAM_GROUP_KIND_MISS;
        miss_prog_group_desc.miss.module = *module;
        miss_prog_group_desc.miss.entryFunctionName = "__miss__ms";
        sizeof_log = sizeof( log );
        OPTIX_CHECK_LOG( optixProgramGroupCreate(
                            context,
                            &miss_prog_group_desc,
                            1, // num program groups
                            &program_group_options,
                            log,
                            &sizeof_log,
                            &miss_prog_group ) );
    }

    //
    // Link pipeline
    //
    {
        const uint32_t max_trace_depth = 1;
        OptixProgramGroup program_groups[] = { raygen_prog_group, miss_prog_group };

        OptixPipelineLinkOptions pipeline_link_options = {};
        pipeline_link_options.maxTraceDepth = max_trace_depth;
        pipeline_link_options.debugLevel = OPTIX_COMPILE_DEBUG_LEVEL_DEFAULT;
        size_t sizeof_log = sizeof( log );
        OPTIX_CHECK_LOG( optixPipelineCreate(
                            context,
                            &pipeline_compile_options,
                            &pipeline_link_options,
                            program_groups,
                            sizeof( program_groups ) / sizeof( program_groups[0] ),
                            log,
                            &sizeof_log,
                            pipeline ) );

        OptixStackSizes stack_sizes = {};
        for( auto& prog_group : program_groups )
        {
            OPTIX_CHECK( optixUtilAccumulateStackSizes( prog_group, &stack_sizes ) );
        }

        uint32_t direct_callable_stack_size_from_traversal;
        uint32_t direct_callable_stack_size_from_state;
        uint32_t continuation_stack_size;
        OPTIX_CHECK( optixUtilComputeStackSizes( &stack_sizes, max_trace_depth,
                                                 0,  // maxCCDepth
                                                 0,  // maxDCDEpth
                                                 &direct_callable_stack_size_from_traversal,
                                                 &direct_callable_stack_size_from_state, &continuation_stack_size ) );
        OPTIX_CHECK( optixPipelineSetStackSize( *pipeline, direct_callable_stack_size_from_traversal,
                                                direct_callable_stack_size_from_state, continuation_stack_size,
                                                1  // maxTraversableDepth
                                                ) );
    }

    //
    // Set up shader binding table
    //
    {
        CUdeviceptr  raygen_record;
        const size_t raygen_record_size = sizeof( SbtRecord );
        CUDA_CHECK( cudaMalloc( reinterpret_cast( &raygen_record ), raygen_record_size ) );   // NOTE(review): cast target type stripped by extraction
        SbtRecord rg_sbt;
        OPTIX_CHECK( optixSbtRecordPackHeader( raygen_prog_group, &rg_sbt ) );
        CUDA_CHECK( cudaMemcpy( reinterpret_cast( raygen_record ), &rg_sbt, raygen_record_size, cudaMemcpyHostToDevice ) );

        CUdeviceptr miss_record;
        size_t miss_record_size = sizeof( SbtRecord );
        CUDA_CHECK( cudaMalloc( reinterpret_cast( &miss_record ), miss_record_size ) );
        SbtRecord ms_sbt;
        OPTIX_CHECK( optixSbtRecordPackHeader( miss_prog_group, &ms_sbt ) );
        CUDA_CHECK( cudaMemcpy( reinterpret_cast( miss_record ), &ms_sbt, miss_record_size, cudaMemcpyHostToDevice ) );

        sbt.raygenRecord = raygen_record;
        sbt.missRecordBase = miss_record;
        sbt.missRecordStrideInBytes = sizeof( SbtRecord );
        sbt.missRecordCount = 1;
    }
}

// Create the OptiX device context and the env-sampling pipeline; owns pState.
OptiXStateWrapper::OptiXStateWrapper(const std::string& path, const std::string& cuda_path)
{
    pState = new OptiXState();
    memset(pState, 0, sizeof(OptiXState));

    // create OptiX context
    pState->context = nullptr;
    {
        // Initialize the OptiX API, loading all API entry points
        OPTIX_CHECK( optixInit() );

        // Specify context options
        OptixDeviceContextOptions options = {};
        options.logCallbackFunction = &context_log_cb;
        options.logCallbackLevel = 0;

        // Associate a CUDA context (and therefore a specific GPU) with this
        // device context
        CUcontext cuCtx = 0;  // zero means take the current context
        OPTIX_CHECK( optixDeviceContextCreate( cuCtx, &options, &pState->context ) );
    }

    // Create pipelines
    pState->moduleEnvSampling = nullptr;
    pState->pipelineEnvSampling = nullptr;
    pState->sbtEnvSampling = {};
    createPipeline(pState->context, path, cuda_path, "envsampling", &pState->moduleEnvSampling, &pState->pipelineEnvSampling, pState->sbtEnvSampling);

    printf("End of OptiXStateWrapper \n");
}

// Tear down pipeline, SBT records, module, BVH output buffer and context.
OptiXStateWrapper::~OptiXStateWrapper(void)
{
    OPTIX_CHECK( optixPipelineDestroy( pState->pipelineEnvSampling ) );
    CUDA_CHECK( cudaFree( reinterpret_cast( pState->sbtEnvSampling.raygenRecord ) ) );
    CUDA_CHECK( cudaFree( reinterpret_cast( pState->sbtEnvSampling.missRecordBase ) ) );
    OPTIX_CHECK( optixModuleDestroy( pState->moduleEnvSampling ) );
    CUDA_CHECK( cudaFree( reinterpret_cast( pState->d_gas_output_buffer ) ) );
    OPTIX_CHECK( optixDeviceContextDestroy( pState->context ) );
    delete pState;
    printf("OptiXStateWrapper destructor \n");
}
================================================
FILE: render/optixutils/c_src/optix_wrapper.h
================================================
// Copyright (c) 2020-2022, NVIDIA CORPORATION.  All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto.  Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#pragma once

// NOTE(review): the two #include header names were lost in the text
// extraction (angle-bracket contents stripped).
#include
#include

//------------------------------------------------------------------------
// Python OptiX state wrapper.

// All OptiX handles owned by one wrapper instance; zero-initialized by the
// OptiXStateWrapper constructor in optix_wrapper.cpp.
struct OptiXState
{
    OptixDeviceContext          context;
    OptixTraversableHandle      gas_handle;          // triangle GAS built by optix_build_bvh
    CUdeviceptr                 d_gas_output_buffer; // device memory backing gas_handle

    // Differentiable env sampling
    OptixPipeline               pipelineEnvSampling;
    OptixShaderBindingTable     sbtEnvSampling;
    OptixModule                 moduleEnvSampling;
};

// RAII owner of an OptiXState, exposed to Python via the torch bindings.
class OptiXStateWrapper
{
public:
    OptiXStateWrapper   (const std::string &path, const std::string &cuda_path);
    ~OptiXStateWrapper  (void);
    OptiXState*         pState;
};

================================================
FILE: render/optixutils/c_src/torch_bindings.cpp
================================================
// Copyright (c) 2020-2022, NVIDIA CORPORATION.  All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto.  Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
#ifdef _MSC_VER #pragma warning(push, 0) #include #pragma warning(pop) #else #include #endif #include #include #include #include #include #include "common.h" #include "optix_wrapper.h" #include "denoising.h" #include "envsampling/params.h" //------------------------------------------------------------------------ // CUDA kernels void bilateral_denoiser_fwd_kernel(BilateralDenoiserParams params); void bilateral_denoiser_bwd_kernel(BilateralDenoiserParams params); //------------------------------------------------------------------------ // OptiX tracer void optix_build_bvh(OptiXStateWrapper& stateWrapper,torch::Tensor grid_verts, torch::Tensor grid_tris, unsigned int rebuild) { // // accel handling // // Clear BVH GPU memory { // Use default options for simplicity. In a real use case we would want to // enable compaction, etc OptixAccelBuildOptions accel_options = {}; accel_options.buildFlags = OPTIX_BUILD_FLAG_ALLOW_COMPACTION | OPTIX_BUILD_FLAG_ALLOW_UPDATE | OPTIX_BUILD_FLAG_ALLOW_RANDOM_VERTEX_ACCESS; if (rebuild > 0) { CUDA_CHECK( cudaFree( reinterpret_cast( stateWrapper.pState->d_gas_output_buffer ) ) ); accel_options.operation = OPTIX_BUILD_OPERATION_BUILD; } else { accel_options.operation = OPTIX_BUILD_OPERATION_UPDATE; } CUdeviceptr d_vertices = (CUdeviceptr)grid_verts.data_ptr(); CUdeviceptr d_indices = (CUdeviceptr)grid_tris.data_ptr(); // Our build input is a simple list of non-indexed triangle vertices const uint32_t triangle_input_flags[1] = { OPTIX_GEOMETRY_FLAG_NONE }; OptixBuildInput triangle_input = {}; triangle_input.type = OPTIX_BUILD_INPUT_TYPE_TRIANGLES; triangle_input.triangleArray.vertexFormat = OPTIX_VERTEX_FORMAT_FLOAT3; triangle_input.triangleArray.numVertices = (uint32_t)grid_verts.size(0); triangle_input.triangleArray.vertexBuffers = &d_vertices; triangle_input.triangleArray.indexFormat = OPTIX_INDICES_FORMAT_UNSIGNED_INT3; triangle_input.triangleArray.numIndexTriplets = (uint32_t)grid_tris.size(0); 
triangle_input.triangleArray.indexBuffer = d_indices; triangle_input.triangleArray.flags = triangle_input_flags; triangle_input.triangleArray.numSbtRecords = 1; OptixAccelBufferSizes gas_buffer_sizes; OPTIX_CHECK( optixAccelComputeMemoryUsage( stateWrapper.pState->context, &accel_options, &triangle_input, 1, // Number of build inputs &gas_buffer_sizes ) ); CUdeviceptr d_temp_buffer_gas; CUDA_CHECK( cudaMalloc( reinterpret_cast( &d_temp_buffer_gas ), gas_buffer_sizes.tempSizeInBytes ) ); if (rebuild > 0) { CUDA_CHECK( cudaMalloc( reinterpret_cast( &stateWrapper.pState->d_gas_output_buffer ), gas_buffer_sizes.outputSizeInBytes ) ); } OPTIX_CHECK( optixAccelBuild( stateWrapper.pState->context, 0, // CUDA stream &accel_options, &triangle_input, 1, // num build inputs d_temp_buffer_gas, gas_buffer_sizes.tempSizeInBytes, stateWrapper.pState->d_gas_output_buffer, gas_buffer_sizes.outputSizeInBytes, &stateWrapper.pState->gas_handle, nullptr, // emitted property list 0 // num emitted properties ) ); // We can now free the scratch space buffer used during build and the vertex // inputs, since they are not needed by our trivial shading method CUDA_CHECK( cudaFree( reinterpret_cast( d_temp_buffer_gas ) ) ); } } template class PtrTraits = DefaultPtrTraits> PackedTensorAccessor32 packed_accessor32(torch::Tensor tensor) { return PackedTensorAccessor32(static_cast::PtrType>(tensor.data_ptr()), tensor.sizes().data(), tensor.strides().data()); } std::tuple env_shade_fwd( OptiXStateWrapper& stateWrapper, torch::Tensor mask, torch::Tensor ro, torch::Tensor gb_pos, torch::Tensor gb_normal, torch::Tensor gb_view_pos, torch::Tensor gb_kd, torch::Tensor gb_ks, torch::Tensor light, torch::Tensor pdf, torch::Tensor rows, torch::Tensor cols, torch::Tensor perms, unsigned int BSDF, unsigned int n_samples_x, unsigned int rnd_seed, float shadow_scale) { // // launch OptiX kernel // cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // Allocate output tensors. 
torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); torch::Tensor diff = torch::zeros({ ro.size(0), ro.size(1), ro.size(2), 3 }, opts) ; torch::Tensor spec = torch::zeros({ ro.size(0), ro.size(1), ro.size(2), 3 }, opts) ; EnvSamplingParams params; params.handle = stateWrapper.pState->gas_handle; params.mask = packed_accessor32(mask); params.ro = packed_accessor32(ro); params.gb_pos = packed_accessor32(gb_pos); params.gb_normal = packed_accessor32(gb_normal); params.gb_view_pos = packed_accessor32(gb_view_pos); params.gb_kd = packed_accessor32(gb_kd); params.gb_ks = packed_accessor32(gb_ks); params.light = packed_accessor32(light); params.pdf = packed_accessor32(pdf); params.rows = packed_accessor32(rows); params.cols = packed_accessor32(cols); params.diff = packed_accessor32(diff); params.spec = packed_accessor32(spec); params.perms = packed_accessor32(perms); params.BSDF = BSDF; params.n_samples_x = n_samples_x; params.rnd_seed = rnd_seed; params.backward = 0; params.shadow_scale = shadow_scale; CUdeviceptr d_param; CUDA_CHECK( cudaMalloc( reinterpret_cast( &d_param ), sizeof( EnvSamplingParams ) ) ); CUDA_CHECK( cudaMemcpy( reinterpret_cast( d_param ), ¶ms, sizeof( params ), cudaMemcpyHostToDevice ) ); OPTIX_CHECK( optixLaunch( stateWrapper.pState->pipelineEnvSampling, stream, d_param, sizeof( EnvSamplingParams ), &stateWrapper.pState->sbtEnvSampling, ro.size(2), ro.size(1), ro.size(0) ) ); CUDA_CHECK( cudaStreamSynchronize( stream ) ); return std::tuple(diff, spec); } std::tuple env_shade_bwd( OptiXStateWrapper& stateWrapper, torch::Tensor mask, torch::Tensor ro, torch::Tensor gb_pos, torch::Tensor gb_normal, torch::Tensor gb_view_pos, torch::Tensor gb_kd, torch::Tensor gb_ks, torch::Tensor light, torch::Tensor pdf, torch::Tensor rows, torch::Tensor cols, torch::Tensor perms, unsigned int BSDF, unsigned int n_samples_x, unsigned int rnd_seed, float shadow_scale, torch::Tensor diff_grad, torch::Tensor spec_grad) { // // 
launch OptiX kernel // cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // Allocate output tensors. torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); EnvSamplingParams params; params.handle = stateWrapper.pState->gas_handle; params.mask = packed_accessor32(mask); params.ro = packed_accessor32(ro); params.gb_pos = packed_accessor32(gb_pos); params.gb_normal = packed_accessor32(gb_normal); params.gb_view_pos = packed_accessor32(gb_view_pos); params.gb_kd = packed_accessor32(gb_kd); params.gb_ks = packed_accessor32(gb_ks); params.light = packed_accessor32(light); params.pdf = packed_accessor32(pdf); params.rows = packed_accessor32(rows); params.cols = packed_accessor32(cols); params.diff_grad = packed_accessor32(diff_grad); params.spec_grad = packed_accessor32(spec_grad); params.perms = packed_accessor32(perms); params.BSDF = BSDF; params.n_samples_x = n_samples_x; params.rnd_seed = rnd_seed; params.backward = 1; params.shadow_scale = shadow_scale; // Create gradient tensor for pos torch::Tensor gb_pos_grad = torch::zeros({ ro.size(0), ro.size(1), ro.size(2), gb_pos.size(3) }, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); params.gb_pos_grad = packed_accessor32(gb_pos_grad); torch::Tensor gb_normal_grad = torch::zeros({ ro.size(0), ro.size(1), ro.size(2), gb_normal.size(3) }, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); params.gb_normal_grad = packed_accessor32(gb_normal_grad); torch::Tensor gb_kd_grad = torch::zeros({ ro.size(0), ro.size(1), ro.size(2), gb_kd.size(3) }, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); params.gb_kd_grad = packed_accessor32(gb_kd_grad); torch::Tensor gb_ks_grad = torch::zeros({ ro.size(0), ro.size(1), ro.size(2), gb_ks.size(3) }, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); params.gb_ks_grad = packed_accessor32(gb_ks_grad); // Create gradient tensor for light torch::Tensor light_grad = 
torch::zeros({ light.size(0), light.size(1), light.size(2) }, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); params.light_grad = packed_accessor32(light_grad); CUdeviceptr d_param; CUDA_CHECK( cudaMalloc( reinterpret_cast( &d_param ), sizeof( EnvSamplingParams ) ) ); CUDA_CHECK( cudaMemcpy( reinterpret_cast( d_param ), ¶ms, sizeof( params ), cudaMemcpyHostToDevice ) ); OPTIX_CHECK( optixLaunch( stateWrapper.pState->pipelineEnvSampling, stream, d_param, sizeof( EnvSamplingParams ), &stateWrapper.pState->sbtEnvSampling, ro.size(2), ro.size(1), ro.size(0) ) ); CUDA_CHECK( cudaStreamSynchronize( stream ) ); return std::tuple(gb_pos_grad, gb_normal_grad, gb_kd_grad, gb_ks_grad, light_grad); } torch::Tensor bilateral_denoiser_fwd(torch::Tensor col, torch::Tensor nrm, torch::Tensor zdz, float sigma) { cudaStream_t stream = at::cuda::getCurrentCUDAStream(); torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); torch::Tensor out = torch::zeros({ col.size(0), col.size(1), col.size(2), 4 }, opts); dim3 blockSize(8, 8, 1); dim3 gridSize((col.size(2) - 1) / blockSize.x + 1, (col.size(1) - 1) / blockSize.y + 1, (col.size(0) - 1) / blockSize.z + 1); BilateralDenoiserParams params; params.col = packed_accessor32(col); params.nrm = packed_accessor32(nrm); params.zdz = packed_accessor32(zdz); params.out = packed_accessor32(out); params.sigma = sigma; void *args[] = {¶ms}; CUDA_CHECK(cudaLaunchKernel((const void *)bilateral_denoiser_fwd_kernel, gridSize, blockSize, args, 0, stream)); return out; } torch::Tensor bilateral_denoiser_bwd(torch::Tensor col, torch::Tensor nrm, torch::Tensor zdz, float sigma, torch::Tensor out_grad) { cudaStream_t stream = at::cuda::getCurrentCUDAStream(); torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); torch::Tensor col_grad = torch::zeros({ col.size(0), col.size(1), col.size(2), col.size(3) }, opts); dim3 blockSize(8, 8, 1); dim3 
gridSize((col.size(2) - 1) / blockSize.x + 1, (col.size(1) - 1) / blockSize.y + 1, (col.size(0) - 1) / blockSize.z + 1); BilateralDenoiserParams params; params.col = packed_accessor32(col); params.nrm = packed_accessor32(nrm); params.zdz = packed_accessor32(zdz); params.out_grad = packed_accessor32(out_grad); params.col_grad = packed_accessor32(col_grad); params.sigma = sigma; void *args[] = {¶ms}; CUDA_CHECK(cudaLaunchKernel((const void *)bilateral_denoiser_bwd_kernel, gridSize, blockSize, args, 0, stream)); return col_grad; } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { pybind11::class_(m, "OptiXStateWrapper").def(pybind11::init()); m.def("env_shade_fwd", &env_shade_fwd, "env_shade_fwd"); m.def("env_shade_bwd", &env_shade_bwd, "env_shade_bwd"); // m.def("env_shade_single_sided_fwd", &env_shade_fwd, "env_shade_single_sided_fwd"); // m.def("env_shade_single_sided_bwd", &env_shade_bwd, "env_shade_single_sided_bwd"); m.def("optix_build_bvh", &optix_build_bvh, "optix_build_bvh"); m.def("bilateral_denoiser_fwd", &bilateral_denoiser_fwd, "bilateral_denoiser_fwd"); m.def("bilateral_denoiser_bwd", &bilateral_denoiser_bwd, "bilateral_denoiser_bwd"); } ================================================ FILE: render/optixutils/include/internal/optix_7_device_impl.h ================================================ /* * Copyright (c) 2021 NVIDIA Corporation. All rights reserved. * * NVIDIA Corporation and its licensors retain all intellectual property and proprietary * rights in and to this software, related documentation and any modifications thereto. * Any use, reproduction, disclosure or distribution of this software and related * documentation without an express license agreement from NVIDIA Corporation is strictly * prohibited. 
 *
 * TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, THIS SOFTWARE IS PROVIDED *AS IS*
 * AND NVIDIA AND ITS SUPPLIERS DISCLAIM ALL WARRANTIES, EITHER EXPRESS OR IMPLIED,
 * INCLUDING, BUT NOT LIMITED TO, IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
 * PARTICULAR PURPOSE. IN NO EVENT SHALL NVIDIA OR ITS SUPPLIERS BE LIABLE FOR ANY
 * SPECIAL, INCIDENTAL, INDIRECT, OR CONSEQUENTIAL DAMAGES WHATSOEVER (INCLUDING, WITHOUT
 * LIMITATION, DAMAGES FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF
 * BUSINESS INFORMATION, OR ANY OTHER PECUNIARY LOSS) ARISING OUT OF THE USE OF OR
 * INABILITY TO USE THIS SOFTWARE, EVEN IF NVIDIA HAS BEEN ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGES
 */
/**
 * @file optix_7_device_impl.h
 * @author NVIDIA Corporation
 * @brief OptiX public API
 *
 * OptiX public API Reference - Device side implementation
 */

#if !defined( __OPTIX_INCLUDE_INTERNAL_HEADERS__ )
#error("optix_7_device_impl.h is an internal header file and must not be used directly. Please use optix_device.h or optix.h instead.")
#endif

#ifndef __optix_optix_7_device_impl_h__
#define __optix_optix_7_device_impl_h__

#include "internal/optix_7_device_impl_exception.h"
#include "internal/optix_7_device_impl_transformations.h"

// Each optixTrace overload below forwards to the _optix_trace_typed_32 PTX
// intrinsic, which always takes 32 payload registers; unused ones are declared
// as locals and discarded via the (void) casts. The first "r" input operand is
// the payload count actually used by the overload.

// 0-payload-register variant.
static __forceinline__ __device__ void optixTrace( OptixTraversableHandle handle,
                                                   float3                 rayOrigin,
                                                   float3                 rayDirection,
                                                   float                  tmin,
                                                   float                  tmax,
                                                   float                  rayTime,
                                                   OptixVisibilityMask    visibilityMask,
                                                   unsigned int           rayFlags,
                                                   unsigned int           SBToffset,
                                                   unsigned int           SBTstride,
                                                   unsigned int           missSBTIndex )
{
    float ox = rayOrigin.x, oy = rayOrigin.y, oz = rayOrigin.z;
    float dx = rayDirection.x, dy = rayDirection.y, dz = rayDirection.z;
    unsigned int p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11, p12, p13, p14, p15, p16, p17, p18, p19, p20, p21,
        p22, p23, p24, p25, p26, p27, p28, p29, p30, p31;
    asm volatile(
        "call"
        "(%0,%1,%2,%3,%4,%5,%6,%7,%8,%9,%10,%11,%12,%13,%14,%15,%16,%17,%18,%19,%20,%21,%22,%23,%24,%25,%26,%27,%28,%"
        "29,%30,%31),"
        "_optix_trace_typed_32,"
        "(%32,%33,%34,%35,%36,%37,%38,%39,%40,%41,%42,%43,%44,%45,%46,%47,%48,%49,%50,%51,%52,%53,%54,%55,%56,%57,%58,%"
        "59,%60,%61,%62,%63,%64,%65,%66,%67,%68,%69,%70,%71,%72,%73,%74,%75,%76,%77,%78,%79,%80);"
        : "=r"( p0 ), "=r"( p1 ), "=r"( p2 ), "=r"( p3 ), "=r"( p4 ), "=r"( p5 ), "=r"( p6 ), "=r"( p7 ), "=r"( p8 ),
          "=r"( p9 ), "=r"( p10 ), "=r"( p11 ), "=r"( p12 ), "=r"( p13 ), "=r"( p14 ), "=r"( p15 ), "=r"( p16 ),
          "=r"( p17 ), "=r"( p18 ), "=r"( p19 ), "=r"( p20 ), "=r"( p21 ), "=r"( p22 ), "=r"( p23 ), "=r"( p24 ),
          "=r"( p25 ), "=r"( p26 ), "=r"( p27 ), "=r"( p28 ), "=r"( p29 ), "=r"( p30 ), "=r"( p31 )
        : "r"( 0 ), "l"( handle ), "f"( ox ), "f"( oy ), "f"( oz ), "f"( dx ), "f"( dy ), "f"( dz ), "f"( tmin ),
          "f"( tmax ), "f"( rayTime ), "r"( visibilityMask ), "r"( rayFlags ), "r"( SBToffset ), "r"( SBTstride ),
          "r"( missSBTIndex ), "r"( 0 ), "r"( p0 ), "r"( p1 ), "r"( p2 ), "r"( p3 ), "r"( p4 ), "r"( p5 ), "r"( p6 ),
          "r"( p7 ), "r"( p8 ), "r"( p9 ), "r"( p10 ), "r"( p11 ), "r"( p12 ), "r"( p13 ), "r"( p14 ), "r"( p15 ),
          "r"( p16 ), "r"( p17 ), "r"( p18 ), "r"( p19 ), "r"( p20 ), "r"( p21 ), "r"( p22 ), "r"( p23 ), "r"( p24 ),
          "r"( p25 ), "r"( p26 ), "r"( p27 ), "r"( p28 ), "r"( p29 ), "r"( p30 ), "r"( p31 )
        : );
    (void)p0, (void)p1, (void)p2, (void)p3, (void)p4, (void)p5, (void)p6, (void)p7, (void)p8, (void)p9, (void)p10,
        (void)p11, (void)p12, (void)p13, (void)p14, (void)p15, (void)p16, (void)p17, (void)p18, (void)p19, (void)p20,
        (void)p21, (void)p22, (void)p23, (void)p24, (void)p25, (void)p26, (void)p27, (void)p28, (void)p29, (void)p30,
        (void)p31;
}

// 1-payload-register variant: p0 is read/written by the traced programs.
static __forceinline__ __device__ void optixTrace( OptixTraversableHandle handle,
                                                   float3                 rayOrigin,
                                                   float3                 rayDirection,
                                                   float                  tmin,
                                                   float                  tmax,
                                                   float                  rayTime,
                                                   OptixVisibilityMask    visibilityMask,
                                                   unsigned int           rayFlags,
                                                   unsigned int           SBToffset,
                                                   unsigned int           SBTstride,
                                                   unsigned int           missSBTIndex,
                                                   unsigned int&          p0 )
{
    float ox = rayOrigin.x, oy = rayOrigin.y, oz = rayOrigin.z;
    float dx = rayDirection.x, dy = rayDirection.y, dz = rayDirection.z;
    unsigned int p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11, p12, p13, p14, p15, p16, p17, p18, p19, p20, p21, p22,
        p23, p24, p25, p26, p27, p28, p29, p30, p31;
    asm volatile(
        "call"
        "(%0,%1,%2,%3,%4,%5,%6,%7,%8,%9,%10,%11,%12,%13,%14,%15,%16,%17,%18,%19,%20,%21,%22,%23,%24,%25,%26,%27,%28,%"
        "29,%30,%31),"
        "_optix_trace_typed_32,"
        "(%32,%33,%34,%35,%36,%37,%38,%39,%40,%41,%42,%43,%44,%45,%46,%47,%48,%49,%50,%51,%52,%53,%54,%55,%56,%57,%58,%"
        "59,%60,%61,%62,%63,%64,%65,%66,%67,%68,%69,%70,%71,%72,%73,%74,%75,%76,%77,%78,%79,%80);"
        : "=r"( p0 ), "=r"( p1 ), "=r"( p2 ), "=r"( p3 ), "=r"( p4 ), "=r"( p5 ), "=r"( p6 ), "=r"( p7 ), "=r"( p8 ),
          "=r"( p9 ), "=r"( p10 ), "=r"( p11 ), "=r"( p12 ), "=r"( p13 ), "=r"( p14 ), "=r"( p15 ), "=r"( p16 ),
          "=r"( p17 ), "=r"( p18 ), "=r"( p19 ), "=r"( p20 ), "=r"( p21 ), "=r"( p22 ), "=r"( p23 ), "=r"( p24 ),
          "=r"( p25 ), "=r"( p26 ), "=r"( p27 ), "=r"( p28 ), "=r"( p29 ), "=r"( p30 ), "=r"( p31 )
        : "r"( 0 ), "l"( handle ), "f"( ox ), "f"( oy ), "f"( oz ), "f"( dx ), "f"( dy ), "f"( dz ), "f"( tmin ),
          "f"( tmax ), "f"( rayTime ), "r"( visibilityMask ), "r"( rayFlags ), "r"( SBToffset ), "r"( SBTstride ),
          "r"( missSBTIndex ), "r"( 1 ), "r"( p0 ), "r"( p1 ), "r"( p2 ), "r"( p3 ), "r"( p4 ), "r"( p5 ), "r"( p6 ),
          "r"( p7 ), "r"( p8 ), "r"( p9 ), "r"( p10 ), "r"( p11 ), "r"( p12 ), "r"( p13 ), "r"( p14 ), "r"( p15 ),
          "r"( p16 ), "r"( p17 ), "r"( p18 ), "r"( p19 ), "r"( p20 ), "r"( p21 ), "r"( p22 ), "r"( p23 ), "r"( p24 ),
          "r"( p25 ), "r"( p26 ), "r"( p27 ), "r"( p28 ), "r"( p29 ), "r"( p30 ), "r"( p31 )
        : );
    (void)p1, (void)p2, (void)p3, (void)p4, (void)p5, (void)p6, (void)p7, (void)p8, (void)p9, (void)p10, (void)p11,
        (void)p12, (void)p13, (void)p14, (void)p15, (void)p16, (void)p17, (void)p18, (void)p19, (void)p20, (void)p21,
        (void)p22, (void)p23, (void)p24, (void)p25, (void)p26, (void)p27, (void)p28, (void)p29, (void)p30, (void)p31;
}

// 2-payload-register variant.
static __forceinline__ __device__ void optixTrace( OptixTraversableHandle handle,
                                                   float3                 rayOrigin,
                                                   float3                 rayDirection,
                                                   float                  tmin,
                                                   float                  tmax,
                                                   float                  rayTime,
                                                   OptixVisibilityMask    visibilityMask,
                                                   unsigned int           rayFlags,
                                                   unsigned int           SBToffset,
                                                   unsigned int           SBTstride,
                                                   unsigned int           missSBTIndex,
                                                   unsigned int&          p0,
                                                   unsigned int&          p1 )
{
    float ox = rayOrigin.x, oy = rayOrigin.y, oz = rayOrigin.z;
    float dx = rayDirection.x, dy = rayDirection.y, dz = rayDirection.z;
    unsigned int p2, p3, p4, p5, p6, p7, p8, p9, p10, p11, p12, p13, p14, p15, p16, p17, p18, p19, p20, p21, p22, p23,
        p24, p25, p26, p27, p28, p29, p30, p31;
    asm volatile(
        "call"
        "(%0,%1,%2,%3,%4,%5,%6,%7,%8,%9,%10,%11,%12,%13,%14,%15,%16,%17,%18,%19,%20,%21,%22,%23,%24,%25,%26,%27,%28,%"
        "29,%30,%31),"
        "_optix_trace_typed_32,"
        "(%32,%33,%34,%35,%36,%37,%38,%39,%40,%41,%42,%43,%44,%45,%46,%47,%48,%49,%50,%51,%52,%53,%54,%55,%56,%57,%58,%"
        "59,%60,%61,%62,%63,%64,%65,%66,%67,%68,%69,%70,%71,%72,%73,%74,%75,%76,%77,%78,%79,%80);"
        : "=r"( p0 ), "=r"( p1 ), "=r"( p2 ), "=r"( p3 ), "=r"( p4 ), "=r"( p5 ), "=r"( p6 ), "=r"( p7 ), "=r"( p8 ),
          "=r"( p9 ), "=r"( p10 ), "=r"( p11 ), "=r"( p12 ), "=r"( p13 ), "=r"( p14 ), "=r"( p15 ), "=r"( p16 ),
          "=r"( p17 ), "=r"( p18 ), "=r"( p19 ), "=r"( p20 ), "=r"( p21 ), "=r"( p22 ), "=r"( p23 ), "=r"( p24 ),
          "=r"( p25 ), "=r"( p26 ), "=r"( p27 ), "=r"( p28 ), "=r"( p29 ), "=r"( p30 ), "=r"( p31 )
        : "r"( 0 ), "l"( handle ), "f"( ox ), "f"( oy ), "f"( oz ), "f"( dx ), "f"( dy ), "f"( dz ), "f"( tmin ),
          "f"( tmax ), "f"( rayTime ), "r"( visibilityMask ), "r"( rayFlags ), "r"( SBToffset ), "r"( SBTstride ),
          "r"( missSBTIndex ), "r"( 2 ), "r"( p0 ), "r"( p1 ), "r"( p2 ), "r"( p3 ), "r"( p4 ), "r"( p5 ), "r"( p6 ),
          "r"( p7 ), "r"( p8 ), "r"( p9 ), "r"( p10 ), "r"( p11 ), "r"( p12 ), "r"( p13 ), "r"( p14 ), "r"( p15 ),
          "r"( p16 ), "r"( p17 ), "r"( p18 ), "r"( p19 ), "r"( p20 ), "r"( p21 ), "r"( p22 ), "r"( p23 ), "r"( p24 ),
          "r"( p25 ), "r"( p26 ), "r"( p27 ), "r"( p28 ), "r"( p29 ), "r"( p30 ), "r"( p31 )
        : );
    (void)p2, (void)p3, (void)p4, (void)p5, (void)p6, (void)p7, (void)p8, (void)p9, (void)p10, (void)p11, (void)p12,
        (void)p13, (void)p14, (void)p15, (void)p16, (void)p17, (void)p18, (void)p19, (void)p20, (void)p21, (void)p22,
        (void)p23, (void)p24, (void)p25, (void)p26, (void)p27, (void)p28, (void)p29, (void)p30, (void)p31;
}

// 3-payload-register variant.
static __forceinline__ __device__ void optixTrace( OptixTraversableHandle handle,
                                                   float3                 rayOrigin,
                                                   float3                 rayDirection,
                                                   float                  tmin,
                                                   float                  tmax,
                                                   float                  rayTime,
                                                   OptixVisibilityMask    visibilityMask,
                                                   unsigned int           rayFlags,
                                                   unsigned int           SBToffset,
                                                   unsigned int           SBTstride,
                                                   unsigned int           missSBTIndex,
                                                   unsigned int&          p0,
                                                   unsigned int&          p1,
                                                   unsigned int&          p2 )
{
    float ox = rayOrigin.x, oy = rayOrigin.y, oz = rayOrigin.z;
    float dx = rayDirection.x, dy = rayDirection.y, dz = rayDirection.z;
    unsigned int p3, p4, p5, p6, p7, p8, p9, p10, p11, p12, p13, p14, p15, p16, p17, p18, p19, p20, p21, p22, p23,
        p24, p25, p26, p27, p28, p29, p30, p31;
    asm volatile(
        "call"
        "(%0,%1,%2,%3,%4,%5,%6,%7,%8,%9,%10,%11,%12,%13,%14,%15,%16,%17,%18,%19,%20,%21,%22,%23,%24,%25,%26,%27,%28,%"
        "29,%30,%31),"
        "_optix_trace_typed_32,"
        "(%32,%33,%34,%35,%36,%37,%38,%39,%40,%41,%42,%43,%44,%45,%46,%47,%48,%49,%50,%51,%52,%53,%54,%55,%56,%57,%58,%"
        "59,%60,%61,%62,%63,%64,%65,%66,%67,%68,%69,%70,%71,%72,%73,%74,%75,%76,%77,%78,%79,%80);"
        : "=r"( p0 ), "=r"( p1 ), "=r"( p2 ), "=r"( p3 ), "=r"( p4 ), "=r"( p5 ), "=r"( p6 ), "=r"( p7 ), "=r"( p8 ),
          "=r"( p9 ), "=r"( p10 ), "=r"( p11 ), "=r"( p12 ), "=r"( p13 ), "=r"( p14 ), "=r"( p15 ), "=r"( p16 ),
          "=r"( p17 ), "=r"( p18 ), "=r"( p19 ), "=r"( p20 ), "=r"( p21 ), "=r"( p22 ), "=r"( p23 ), "=r"( p24 ),
          "=r"( p25 ), "=r"( p26 ), "=r"( p27 ), "=r"( p28 ), "=r"( p29 ), "=r"( p30 ), "=r"( p31 )
        : "r"( 0 ), "l"( handle ), "f"( ox ), "f"( oy ), "f"( oz ), "f"( dx ), "f"( dy ), "f"( dz ), "f"( tmin ),
          "f"( tmax ), "f"( rayTime ), "r"( visibilityMask ), "r"( rayFlags ), "r"( SBToffset ), "r"( SBTstride ),
          "r"( missSBTIndex ), "r"( 3 ), "r"( p0 ), "r"( p1 ), "r"( p2 ), "r"( p3 ), "r"( p4 ), "r"( p5 ), "r"( p6 ),
          "r"( p7 ), "r"( p8 ), "r"( p9 ), "r"( p10 ), "r"( p11 ), "r"( p12 ), "r"( p13 ), "r"( p14 ), "r"( p15 ),
          "r"( p16 ), "r"( p17 ), "r"( p18 ), "r"( p19 ), "r"( p20 ), "r"( p21 ), "r"( p22 ), "r"( p23 ), "r"( p24 ),
          "r"( p25 ), "r"( p26 ), "r"( p27 ), "r"( p28 ), "r"( p29 ), "r"( p30 ), "r"( p31 )
        : );
    (void)p3, (void)p4, (void)p5, (void)p6, (void)p7, (void)p8, (void)p9, (void)p10, (void)p11, (void)p12, (void)p13,
        (void)p14, (void)p15, (void)p16, (void)p17, (void)p18, (void)p19, (void)p20, (void)p21, (void)p22, (void)p23,
        (void)p24, (void)p25, (void)p26, (void)p27, (void)p28, (void)p29, (void)p30, (void)p31;
}

// 4-payload-register variant.
static __forceinline__ __device__ void optixTrace( OptixTraversableHandle handle,
                                                   float3                 rayOrigin,
                                                   float3                 rayDirection,
                                                   float                  tmin,
                                                   float                  tmax,
                                                   float                  rayTime,
                                                   OptixVisibilityMask    visibilityMask,
                                                   unsigned int           rayFlags,
                                                   unsigned int           SBToffset,
                                                   unsigned int           SBTstride,
                                                   unsigned int           missSBTIndex,
                                                   unsigned int&          p0,
                                                   unsigned int&          p1,
                                                   unsigned int&          p2,
                                                   unsigned int&          p3 )
{
    float ox = rayOrigin.x, oy = rayOrigin.y, oz = rayOrigin.z;
    float dx = rayDirection.x, dy = rayDirection.y, dz = rayDirection.z;
    unsigned int p4, p5, p6, p7, p8, p9, p10, p11, p12, p13, p14, p15, p16, p17, p18, p19, p20, p21, p22, p23, p24,
        p25, p26, p27, p28, p29, p30, p31;
    asm volatile(
        "call"
        "(%0,%1,%2,%3,%4,%5,%6,%7,%8,%9,%10,%11,%12,%13,%14,%15,%16,%17,%18,%19,%20,%21,%22,%23,%24,%25,%26,%27,%28,%"
        "29,%30,%31),"
        "_optix_trace_typed_32,"
        "(%32,%33,%34,%35,%36,%37,%38,%39,%40,%41,%42,%43,%44,%45,%46,%47,%48,%49,%50,%51,%52,%53,%54,%55,%56,%57,%58,%"
        "59,%60,%61,%62,%63,%64,%65,%66,%67,%68,%69,%70,%71,%72,%73,%74,%75,%76,%77,%78,%79,%80);"
        : "=r"( p0 ), "=r"( p1 ), "=r"( p2 ), "=r"( p3 ), "=r"( p4 ), "=r"( p5 ), "=r"( p6 ), "=r"( p7 ), "=r"( p8 ),
          "=r"( p9 ), "=r"( p10 ), "=r"( p11 ), "=r"( p12 ), "=r"( p13 ), "=r"( p14 ), "=r"( p15 ), "=r"( p16 ),
          "=r"( p17 ), "=r"( p18 ), "=r"( p19 ), "=r"( p20 ), "=r"( p21 ), "=r"( p22 ), "=r"( p23 ), "=r"( p24 ),
          "=r"( p25 ), "=r"( p26 ), "=r"( p27 ), "=r"( p28 ), "=r"( p29 ), "=r"( p30 ), "=r"( p31 )
        : "r"( 0 ), "l"( handle ), "f"( ox ), "f"( oy ), "f"( oz ), "f"( dx ), "f"( dy ), "f"( dz ), "f"( tmin ),
          "f"( tmax ), "f"( rayTime ), "r"( visibilityMask ), "r"( rayFlags ), "r"( SBToffset ), "r"( SBTstride ),
          "r"( missSBTIndex ), "r"( 4 ), "r"( p0 ), "r"( p1 ), "r"( p2 ), "r"( p3 ), "r"( p4 ), "r"( p5 ), "r"( p6 ),
          "r"( p7 ), "r"( p8 ), "r"( p9 ), "r"( p10 ), "r"( p11 ), "r"( p12 ), "r"( p13 ), "r"( p14 ), "r"( p15 ),
          "r"( p16 ), "r"( p17 ), "r"( p18 ), "r"( p19 ), "r"( p20 ), "r"( p21 ), "r"( p22 ), "r"( p23 ), "r"( p24 ),
          "r"( p25 ), "r"( p26 ), "r"( p27 ), "r"( p28 ), "r"( p29 ), "r"( p30 ), "r"( p31 )
        : );
    (void)p4, (void)p5, (void)p6, (void)p7, (void)p8, (void)p9, (void)p10, (void)p11, (void)p12, (void)p13, (void)p14,
        (void)p15, (void)p16, (void)p17, (void)p18, (void)p19, (void)p20, (void)p21, (void)p22, (void)p23, (void)p24,
        (void)p25, (void)p26, (void)p27, (void)p28, (void)p29, (void)p30, (void)p31;
}

// 5-payload-register variant.
static __forceinline__ __device__ void optixTrace( OptixTraversableHandle handle,
                                                   float3                 rayOrigin,
                                                   float3                 rayDirection,
                                                   float                  tmin,
                                                   float                  tmax,
                                                   float                  rayTime,
                                                   OptixVisibilityMask    visibilityMask,
                                                   unsigned int           rayFlags,
                                                   unsigned int           SBToffset,
                                                   unsigned int           SBTstride,
                                                   unsigned int           missSBTIndex,
                                                   unsigned int&          p0,
                                                   unsigned int&          p1,
                                                   unsigned int&          p2,
                                                   unsigned int&          p3,
                                                   unsigned int&          p4 )
{
    float ox = rayOrigin.x, oy = rayOrigin.y, oz = rayOrigin.z;
    float dx = rayDirection.x, dy = rayDirection.y, dz = rayDirection.z;
    unsigned int p5, p6, p7, p8, p9, p10, p11, p12, p13, p14, p15, p16, p17, p18, p19, p20, p21, p22, p23, p24, p25,
        p26, p27, p28, p29, p30, p31;
    asm volatile(
        "call"
        "(%0,%1,%2,%3,%4,%5,%6,%7,%8,%9,%10,%11,%12,%13,%14,%15,%16,%17,%18,%19,%20,%21,%22,%23,%24,%25,%26,%27,%28,%"
        "29,%30,%31),"
        "_optix_trace_typed_32,"
        "(%32,%33,%34,%35,%36,%37,%38,%39,%40,%41,%42,%43,%44,%45,%46,%47,%48,%49,%50,%51,%52,%53,%54,%55,%56,%57,%58,%"
        "59,%60,%61,%62,%63,%64,%65,%66,%67,%68,%69,%70,%71,%72,%73,%74,%75,%76,%77,%78,%79,%80);"
        : "=r"( p0 ), "=r"( p1 ), "=r"( p2 ), "=r"( p3 ), "=r"( p4 ), "=r"( p5 ), "=r"( p6 ), "=r"( p7 ), "=r"( p8 ),
          "=r"( p9 ), "=r"( p10 ), "=r"( p11 ), "=r"( p12 ), "=r"( p13 ), "=r"( p14 ), "=r"( p15 ), "=r"( p16 ),
          "=r"( p17 ), "=r"( p18 ), "=r"( p19 ), "=r"( p20 ), "=r"( p21 ), "=r"( p22 ), "=r"( p23 ), "=r"( p24 ),
          "=r"( p25 ), "=r"( p26 ), "=r"( p27 ), "=r"( p28 ), "=r"( p29 ), "=r"( p30 ), "=r"( p31 )
        : "r"( 0 ), "l"( handle ), "f"( ox ), "f"( oy ), "f"( oz ), "f"( dx ), "f"( dy ), "f"( dz ), "f"( tmin ),
          "f"( tmax ), "f"( rayTime ), "r"( visibilityMask ), "r"( rayFlags ), "r"( SBToffset ), "r"( SBTstride ),
          "r"( missSBTIndex ), "r"( 5 ), "r"( p0 ), "r"( p1 ), "r"( p2 ), "r"( p3 ), "r"( p4 ), "r"( p5 ), "r"( p6 ),
          "r"( p7 ), "r"( p8 ), "r"( p9 ), "r"( p10 ), "r"( p11 ), "r"( p12 ), "r"( p13 ), "r"( p14 ), "r"( p15 ),
          "r"( p16 ), "r"( p17 ), "r"( p18 ), "r"( p19 ), "r"( p20 ), "r"( p21 ), "r"( p22 ), "r"( p23 ), "r"( p24 ),
          "r"( p25 ), "r"( p26 ), "r"( p27 ), "r"( p28 ), "r"( p29 ), "r"( p30 ), "r"( p31 )
        : );
    (void)p5, (void)p6, (void)p7, (void)p8, (void)p9, (void)p10, (void)p11, (void)p12, (void)p13, (void)p14, (void)p15,
        (void)p16, (void)p17, (void)p18, (void)p19, (void)p20, (void)p21, (void)p22, (void)p23, (void)p24, (void)p25,
        (void)p26, (void)p27, (void)p28, (void)p29, (void)p30, (void)p31;
}

// 6-payload-register variant.
static __forceinline__ __device__ void optixTrace( OptixTraversableHandle handle,
                                                   float3                 rayOrigin,
                                                   float3                 rayDirection,
                                                   float                  tmin,
                                                   float                  tmax,
                                                   float                  rayTime,
                                                   OptixVisibilityMask    visibilityMask,
                                                   unsigned int           rayFlags,
                                                   unsigned int           SBToffset,
                                                   unsigned int           SBTstride,
                                                   unsigned int           missSBTIndex,
                                                   unsigned int&          p0,
                                                   unsigned int&          p1,
                                                   unsigned int&          p2,
                                                   unsigned int&          p3,
                                                   unsigned int&          p4,
                                                   unsigned int&          p5 )
{
    float ox = rayOrigin.x, oy = rayOrigin.y, oz = rayOrigin.z;
    float dx = rayDirection.x, dy = rayDirection.y, dz = rayDirection.z;
    unsigned int p6, p7, p8, p9, p10, p11, p12, p13, p14, p15, p16, p17, p18, p19, p20, p21, p22, p23, p24, p25, p26,
        p27, p28, p29, p30, p31;
    asm volatile(
        "call"
        "(%0,%1,%2,%3,%4,%5,%6,%7,%8,%9,%10,%11,%12,%13,%14,%15,%16,%17,%18,%19,%20,%21,%22,%23,%24,%25,%26,%27,%28,%"
        "29,%30,%31),"
        "_optix_trace_typed_32,"
        "(%32,%33,%34,%35,%36,%37,%38,%39,%40,%41,%42,%43,%44,%45,%46,%47,%48,%49,%50,%51,%52,%53,%54,%55,%56,%57,%58,%"
        "59,%60,%61,%62,%63,%64,%65,%66,%67,%68,%69,%70,%71,%72,%73,%74,%75,%76,%77,%78,%79,%80);"
        : "=r"( p0 ), "=r"( p1 ), "=r"( p2 ), "=r"( p3 ), "=r"( p4 ), "=r"( p5 ), "=r"( p6 ), "=r"( p7 ), "=r"( p8 ),
          "=r"( p9 ), "=r"( p10 ), "=r"( p11 ), "=r"( p12 ), "=r"( p13 ), "=r"( p14 ), "=r"( p15 ), "=r"( p16 ),
          "=r"( p17 ), "=r"( p18 ), "=r"( p19 ), "=r"( p20 ), "=r"( p21 ), "=r"( p22 ), "=r"( p23 ), "=r"( p24 ),
          "=r"( p25 ), "=r"( p26 ), "=r"( p27 ), "=r"( p28 ), "=r"( p29 ), "=r"( p30 ), "=r"( p31 )
        : "r"( 0 ), "l"( handle ), "f"( ox ), "f"( oy ), "f"( oz ), "f"( dx ), "f"( dy ), "f"( dz ), "f"( tmin ),
          "f"( tmax ), "f"( rayTime ), "r"( visibilityMask ), "r"( rayFlags ), "r"( SBToffset ), "r"( SBTstride ),
          "r"( missSBTIndex ), "r"( 6 ), "r"( p0 ), "r"( p1 ), "r"( p2 ), "r"( p3 ), "r"( p4 ), "r"( p5 ), "r"( p6 ),
          "r"( p7 ), "r"( p8 ), "r"( p9 ), "r"( p10 ), "r"( p11 ), "r"( p12 ), "r"( p13 ), "r"( p14 ), "r"( p15 ),
          "r"( p16 ), "r"( p17 ), "r"( p18 ), "r"( p19 ), "r"( p20 ), "r"( p21 ), "r"( p22 ), "r"( p23 ), "r"( p24 ),
          "r"( p25 ), "r"( p26 ), "r"( p27 ), "r"( p28 ), "r"( p29 ), "r"( p30 ), "r"( p31 )
        : );
    (void)p6, (void)p7, (void)p8, (void)p9, (void)p10, (void)p11, (void)p12, (void)p13, (void)p14, (void)p15, (void)p16,
        (void)p17, (void)p18, (void)p19, (void)p20, (void)p21, (void)p22, (void)p23, (void)p24, (void)p25, (void)p26,
        (void)p27, (void)p28, (void)p29, (void)p30, (void)p31;
}

// 7-payload-register variant. (Continues on the following source line.)
static __forceinline__ __device__ void optixTrace( OptixTraversableHandle handle,
                                                   float3                 rayOrigin,
                                                   float3                 rayDirection,
                                                   float                  tmin,
                                                   float                  tmax,
                                                   float                  rayTime,
                                                   OptixVisibilityMask    visibilityMask,
                                                   unsigned int           rayFlags,
                                                   unsigned int           SBToffset,
                                                   unsigned int           SBTstride,
                                                   unsigned int           missSBTIndex,
                                                   unsigned int&          p0,
                                                   unsigned int&          p1,
                                                   unsigned int&          p2,
                                                   unsigned int&          p3,
                                                   unsigned int&          p4,
                                                   unsigned int&          p5,
                                                   unsigned int&          p6 )
{
    float ox = rayOrigin.x, oy = rayOrigin.y, oz = rayOrigin.z;
    float dx = rayDirection.x, dy = rayDirection.y, dz = rayDirection.z;
    unsigned int p7, p8, p9, p10, p11, p12, p13, p14, p15, p16, p17, p18, p19, p20, p21, p22, p23, p24, p25, p26, p27,
        p28, p29, p30, p31;
    asm volatile(
        "call"
        "(%0,%1,%2,%3,%4,%5,%6,%7,%8,%9,%10,%11,%12,%13,%14,%15,%16,%17,%18,%19,%20,%21,%22,%23,%24,%25,%26,%27,%28,%"
        "29,%30,%31),"
        "_optix_trace_typed_32,"
        "(%32,%33,%34,%35,%36,%37,%38,%39,%40,%41,%42,%43,%44,%45,%46,%47,%48,%49,%50,%51,%52,%53,%54,%55,%56,%57,%58,%"
        "59,%60,%61,%62,%63,%64,%65,%66,%67,%68,%69,%70,%71,%72,%73,%74,%75,%76,%77,%78,%79,%80);"
        : "=r"( p0 ), "=r"( p1 ), "=r"( p2 ), "=r"( p3 ), "=r"( p4 ), "=r"( p5 ), "=r"( p6 ), "=r"( p7 ), "=r"( p8 ),
          "=r"( p9 ), "=r"( p10 ), "=r"( p11 ), "=r"( p12 ), "=r"( p13 ), "=r"( p14 ), "=r"( p15 ), "=r"( p16 ),
          "=r"( p17 ), "=r"( p18 ), "=r"( p19 ), "=r"( p20 ), "=r"( p21 ), "=r"( p22 ), "=r"( p23 ), "=r"( p24 ),
          "=r"( p25 ), "=r"( p26 ), "=r"( p27 ), "=r"( p28 ), "=r"( p29 ), "=r"( p30 ), "=r"( p31 )
        : "r"( 0 ), "l"( handle ), "f"( ox ), "f"( oy ), "f"( oz ), "f"( dx ), "f"( dy ), "f"( dz ), "f"( tmin ),
          "f"( tmax ), "f"( rayTime ), "r"( visibilityMask ), "r"( rayFlags ), "r"( SBToffset ), "r"( SBTstride ),
          "r"( missSBTIndex ), "r"( 7 ), "r"( p0 ), "r"( p1 ), "r"( p2 ), "r"( p3 ), "r"( p4 ), "r"( p5 ), "r"( p6 ),
          "r"( p7 ), "r"( p8 ), "r"( p9 ), "r"( p10 ), "r"( p11 ), "r"( p12 ), "r"( p13 ), "r"( p14 ), "r"( p15 ),
          "r"( p16 ), "r"( p17 ), "r"( p18 ), "r"( p19 ), "r"( p20 ), "r"( p21 ), "r"( p22 ), "r"( p23 ), "r"( p24 ),
          "r"( p25 ), "r"( p26 ), "r"( p27 ), "r"( p28 ), "r"( p29 ), "r"( p30 ), "r"( p31 )
        : );
    (void)p7, (void)p8, (void)p9, (void)p10, (void)p11, (void)p12, (void)p13, (void)p14, (void)p15, (void)p16,
        (void)p17, (void)p18,
(void)p19, (void)p20, (void)p21, (void)p22, (void)p23, (void)p24, (void)p25, (void)p26, (void)p27, (void)p28, (void)p29, (void)p30, (void)p31; } static __forceinline__ __device__ void optixTrace( OptixTraversableHandle handle, float3 rayOrigin, float3 rayDirection, float tmin, float tmax, float rayTime, OptixVisibilityMask visibilityMask, unsigned int rayFlags, unsigned int SBToffset, unsigned int SBTstride, unsigned int missSBTIndex, unsigned int& p0, unsigned int& p1, unsigned int& p2, unsigned int& p3, unsigned int& p4, unsigned int& p5, unsigned int& p6, unsigned int& p7 ) { float ox = rayOrigin.x, oy = rayOrigin.y, oz = rayOrigin.z; float dx = rayDirection.x, dy = rayDirection.y, dz = rayDirection.z; unsigned int p8, p9, p10, p11, p12, p13, p14, p15, p16, p17, p18, p19, p20, p21, p22, p23, p24, p25, p26, p27, p28, p29, p30, p31; asm volatile( "call" "(%0,%1,%2,%3,%4,%5,%6,%7,%8,%9,%10,%11,%12,%13,%14,%15,%16,%17,%18,%19,%20,%21,%22,%23,%24,%25,%26,%27,%28,%" "29,%30,%31)," "_optix_trace_typed_32," "(%32,%33,%34,%35,%36,%37,%38,%39,%40,%41,%42,%43,%44,%45,%46,%47,%48,%49,%50,%51,%52,%53,%54,%55,%56,%57,%58,%" "59,%60,%61,%62,%63,%64,%65,%66,%67,%68,%69,%70,%71,%72,%73,%74,%75,%76,%77,%78,%79,%80);" : "=r"( p0 ), "=r"( p1 ), "=r"( p2 ), "=r"( p3 ), "=r"( p4 ), "=r"( p5 ), "=r"( p6 ), "=r"( p7 ), "=r"( p8 ), "=r"( p9 ), "=r"( p10 ), "=r"( p11 ), "=r"( p12 ), "=r"( p13 ), "=r"( p14 ), "=r"( p15 ), "=r"( p16 ), "=r"( p17 ), "=r"( p18 ), "=r"( p19 ), "=r"( p20 ), "=r"( p21 ), "=r"( p22 ), "=r"( p23 ), "=r"( p24 ), "=r"( p25 ), "=r"( p26 ), "=r"( p27 ), "=r"( p28 ), "=r"( p29 ), "=r"( p30 ), "=r"( p31 ) : "r"( 0 ), "l"( handle ), "f"( ox ), "f"( oy ), "f"( oz ), "f"( dx ), "f"( dy ), "f"( dz ), "f"( tmin ), "f"( tmax ), "f"( rayTime ), "r"( visibilityMask ), "r"( rayFlags ), "r"( SBToffset ), "r"( SBTstride ), "r"( missSBTIndex ), "r"( 8 ), "r"( p0 ), "r"( p1 ), "r"( p2 ), "r"( p3 ), "r"( p4 ), "r"( p5 ), "r"( p6 ), "r"( p7 ), "r"( p8 ), "r"( p9 ), "r"( p10 ), "r"( 
p11 ), "r"( p12 ), "r"( p13 ), "r"( p14 ), "r"( p15 ), "r"( p16 ), "r"( p17 ), "r"( p18 ), "r"( p19 ), "r"( p20 ), "r"( p21 ), "r"( p22 ), "r"( p23 ), "r"( p24 ), "r"( p25 ), "r"( p26 ), "r"( p27 ), "r"( p28 ), "r"( p29 ), "r"( p30 ), "r"( p31 ) : ); (void)p8, (void)p9, (void)p10, (void)p11, (void)p12, (void)p13, (void)p14, (void)p15, (void)p16, (void)p17, (void)p18, (void)p19, (void)p20, (void)p21, (void)p22, (void)p23, (void)p24, (void)p25, (void)p26, (void)p27, (void)p28, (void)p29, (void)p30, (void)p31; } static __forceinline__ __device__ void optixSetPayload_0( unsigned int p ) { asm volatile( "call _optix_set_payload, (%0, %1);" : : "r"( 0 ), "r"( p ) : ); } static __forceinline__ __device__ void optixSetPayload_1( unsigned int p ) { asm volatile( "call _optix_set_payload, (%0, %1);" : : "r"( 1 ), "r"( p ) : ); } static __forceinline__ __device__ void optixSetPayload_2( unsigned int p ) { asm volatile( "call _optix_set_payload, (%0, %1);" : : "r"( 2 ), "r"( p ) : ); } static __forceinline__ __device__ void optixSetPayload_3( unsigned int p ) { asm volatile( "call _optix_set_payload, (%0, %1);" : : "r"( 3 ), "r"( p ) : ); } static __forceinline__ __device__ void optixSetPayload_4( unsigned int p ) { asm volatile( "call _optix_set_payload, (%0, %1);" : : "r"( 4 ), "r"( p ) : ); } static __forceinline__ __device__ void optixSetPayload_5( unsigned int p ) { asm volatile( "call _optix_set_payload, (%0, %1);" : : "r"( 5 ), "r"( p ) : ); } static __forceinline__ __device__ void optixSetPayload_6( unsigned int p ) { asm volatile( "call _optix_set_payload, (%0, %1);" : : "r"( 6 ), "r"( p ) : ); } static __forceinline__ __device__ void optixSetPayload_7( unsigned int p ) { asm volatile( "call _optix_set_payload, (%0, %1);" : : "r"( 7 ), "r"( p ) : ); } static __forceinline__ __device__ unsigned int optixGetPayload_0() { unsigned int result; asm volatile( "call (%0), _optix_get_payload, (%1);" : "=r"( result ) : "r"( 0 ) : ); return result; } static __forceinline__ 
__device__ unsigned int optixGetPayload_1() { unsigned int result; asm volatile( "call (%0), _optix_get_payload, (%1);" : "=r"( result ) : "r"( 1 ) : ); return result; } static __forceinline__ __device__ unsigned int optixGetPayload_2() { unsigned int result; asm volatile( "call (%0), _optix_get_payload, (%1);" : "=r"( result ) : "r"( 2 ) : ); return result; } static __forceinline__ __device__ unsigned int optixGetPayload_3() { unsigned int result; asm volatile( "call (%0), _optix_get_payload, (%1);" : "=r"( result ) : "r"( 3 ) : ); return result; } static __forceinline__ __device__ unsigned int optixGetPayload_4() { unsigned int result; asm volatile( "call (%0), _optix_get_payload, (%1);" : "=r"( result ) : "r"( 4 ) : ); return result; } static __forceinline__ __device__ unsigned int optixGetPayload_5() { unsigned int result; asm volatile( "call (%0), _optix_get_payload, (%1);" : "=r"( result ) : "r"( 5 ) : ); return result; } static __forceinline__ __device__ unsigned int optixGetPayload_6() { unsigned int result; asm volatile( "call (%0), _optix_get_payload, (%1);" : "=r"( result ) : "r"( 6 ) : ); return result; } static __forceinline__ __device__ unsigned int optixGetPayload_7() { unsigned int result; asm volatile( "call (%0), _optix_get_payload, (%1);" : "=r"( result ) : "r"( 7 ) : ); return result; } static __forceinline__ __device__ unsigned int optixUndefinedValue() { unsigned int u0; asm( "call (%0), _optix_undef_value, ();" : "=r"( u0 ) : ); return u0; } static __forceinline__ __device__ float3 optixGetWorldRayOrigin() { float f0, f1, f2; asm( "call (%0), _optix_get_world_ray_origin_x, ();" : "=f"( f0 ) : ); asm( "call (%0), _optix_get_world_ray_origin_y, ();" : "=f"( f1 ) : ); asm( "call (%0), _optix_get_world_ray_origin_z, ();" : "=f"( f2 ) : ); return make_float3( f0, f1, f2 ); } static __forceinline__ __device__ float3 optixGetWorldRayDirection() { float f0, f1, f2; asm( "call (%0), _optix_get_world_ray_direction_x, ();" : "=f"( f0 ) : ); asm( "call 
(%0), _optix_get_world_ray_direction_y, ();" : "=f"( f1 ) : ); asm( "call (%0), _optix_get_world_ray_direction_z, ();" : "=f"( f2 ) : ); return make_float3( f0, f1, f2 ); } static __forceinline__ __device__ float3 optixGetObjectRayOrigin() { float f0, f1, f2; asm( "call (%0), _optix_get_object_ray_origin_x, ();" : "=f"( f0 ) : ); asm( "call (%0), _optix_get_object_ray_origin_y, ();" : "=f"( f1 ) : ); asm( "call (%0), _optix_get_object_ray_origin_z, ();" : "=f"( f2 ) : ); return make_float3( f0, f1, f2 ); } static __forceinline__ __device__ float3 optixGetObjectRayDirection() { float f0, f1, f2; asm( "call (%0), _optix_get_object_ray_direction_x, ();" : "=f"( f0 ) : ); asm( "call (%0), _optix_get_object_ray_direction_y, ();" : "=f"( f1 ) : ); asm( "call (%0), _optix_get_object_ray_direction_z, ();" : "=f"( f2 ) : ); return make_float3( f0, f1, f2 ); } static __forceinline__ __device__ float optixGetRayTmin() { float f0; asm( "call (%0), _optix_get_ray_tmin, ();" : "=f"( f0 ) : ); return f0; } static __forceinline__ __device__ float optixGetRayTmax() { float f0; asm( "call (%0), _optix_get_ray_tmax, ();" : "=f"( f0 ) : ); return f0; } static __forceinline__ __device__ float optixGetRayTime() { float f0; asm( "call (%0), _optix_get_ray_time, ();" : "=f"( f0 ) : ); return f0; } static __forceinline__ __device__ unsigned int optixGetRayFlags() { unsigned int u0; asm( "call (%0), _optix_get_ray_flags, ();" : "=r"( u0 ) : ); return u0; } static __forceinline__ __device__ unsigned int optixGetRayVisibilityMask() { unsigned int u0; asm( "call (%0), _optix_get_ray_visibility_mask, ();" : "=r"( u0 ) : ); return u0; } static __forceinline__ __device__ OptixTraversableHandle optixGetInstanceTraversableFromIAS( OptixTraversableHandle ias, unsigned int instIdx ) { unsigned long long handle; asm( "call (%0), _optix_get_instance_traversable_from_ias, (%1, %2);" : "=l"( handle ) : "l"( ias ), "r"( instIdx ) ); return (OptixTraversableHandle)handle; } static __forceinline__ 
__device__ void optixGetTriangleVertexData( OptixTraversableHandle gas, unsigned int primIdx, unsigned int sbtGASIndex, float time, float3 data[3] ) { asm( "call (%0, %1, %2, %3, %4, %5, %6, %7, %8), _optix_get_triangle_vertex_data, " "(%9, %10, %11, %12);" : "=f"( data[0].x ), "=f"( data[0].y ), "=f"( data[0].z ), "=f"( data[1].x ), "=f"( data[1].y ), "=f"( data[1].z ), "=f"( data[2].x ), "=f"( data[2].y ), "=f"( data[2].z ) : "l"( gas ), "r"( primIdx ), "r"( sbtGASIndex ), "f"( time ) : ); } static __forceinline__ __device__ void optixGetLinearCurveVertexData( OptixTraversableHandle gas, unsigned int primIdx, unsigned int sbtGASIndex, float time, float4 data[2] ) { asm( "call (%0, %1, %2, %3, %4, %5, %6, %7), _optix_get_linear_curve_vertex_data, " "(%8, %9, %10, %11);" : "=f"( data[0].x ), "=f"( data[0].y ), "=f"( data[0].z ), "=f"( data[0].w ), "=f"( data[1].x ), "=f"( data[1].y ), "=f"( data[1].z ), "=f"( data[1].w ) : "l"( gas ), "r"( primIdx ), "r"( sbtGASIndex ), "f"( time ) : ); } static __forceinline__ __device__ void optixGetQuadraticBSplineVertexData( OptixTraversableHandle gas, unsigned int primIdx, unsigned int sbtGASIndex, float time, float4 data[3] ) { asm( "call (%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11), _optix_get_quadratic_bspline_vertex_data, " "(%12, %13, %14, %15);" : "=f"( data[0].x ), "=f"( data[0].y ), "=f"( data[0].z ), "=f"( data[0].w ), "=f"( data[1].x ), "=f"( data[1].y ), "=f"( data[1].z ), "=f"( data[1].w ), "=f"( data[2].x ), "=f"( data[2].y ), "=f"( data[2].z ), "=f"( data[2].w ) : "l"( gas ), "r"( primIdx ), "r"( sbtGASIndex ), "f"( time ) : ); } static __forceinline__ __device__ void optixGetCubicBSplineVertexData( OptixTraversableHandle gas, unsigned int primIdx, unsigned int sbtGASIndex, float time, float4 data[4] ) { asm( "call (%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15), " "_optix_get_cubic_bspline_vertex_data, " "(%16, %17, %18, %19);" : "=f"( data[0].x ), "=f"( data[0].y ), "=f"( data[0].z 
), "=f"( data[0].w ), "=f"( data[1].x ), "=f"( data[1].y ), "=f"( data[1].z ), "=f"( data[1].w ), "=f"( data[2].x ), "=f"( data[2].y ), "=f"( data[2].z ), "=f"( data[2].w ), "=f"( data[3].x ), "=f"( data[3].y ), "=f"( data[3].z ), "=f"( data[3].w ) : "l"( gas ), "r"( primIdx ), "r"( sbtGASIndex ), "f"( time ) : ); } static __forceinline__ __device__ OptixTraversableHandle optixGetGASTraversableHandle() { unsigned long long handle; asm( "call (%0), _optix_get_gas_traversable_handle, ();" : "=l"( handle ) : ); return (OptixTraversableHandle)handle; } static __forceinline__ __device__ float optixGetGASMotionTimeBegin( OptixTraversableHandle handle ) { float f0; asm( "call (%0), _optix_get_gas_motion_time_begin, (%1);" : "=f"( f0 ) : "l"( handle ) : ); return f0; } static __forceinline__ __device__ float optixGetGASMotionTimeEnd( OptixTraversableHandle handle ) { float f0; asm( "call (%0), _optix_get_gas_motion_time_end, (%1);" : "=f"( f0 ) : "l"( handle ) : ); return f0; } static __forceinline__ __device__ unsigned int optixGetGASMotionStepCount( OptixTraversableHandle handle ) { unsigned int u0; asm( "call (%0), _optix_get_gas_motion_step_count, (%1);" : "=r"( u0 ) : "l"( handle ) : ); return u0; } static __forceinline__ __device__ void optixGetWorldToObjectTransformMatrix( float m[12] ) { if( optixGetTransformListSize() == 0 ) { m[0] = 1.0f; m[1] = 0.0f; m[2] = 0.0f; m[3] = 0.0f; m[4] = 0.0f; m[5] = 1.0f; m[6] = 0.0f; m[7] = 0.0f; m[8] = 0.0f; m[9] = 0.0f; m[10] = 1.0f; m[11] = 0.0f; return; } float4 m0, m1, m2; optix_impl::optixGetWorldToObjectTransformMatrix( m0, m1, m2 ); m[0] = m0.x; m[1] = m0.y; m[2] = m0.z; m[3] = m0.w; m[4] = m1.x; m[5] = m1.y; m[6] = m1.z; m[7] = m1.w; m[8] = m2.x; m[9] = m2.y; m[10] = m2.z; m[11] = m2.w; } static __forceinline__ __device__ void optixGetObjectToWorldTransformMatrix( float m[12] ) { if( optixGetTransformListSize() == 0 ) { m[0] = 1.0f; m[1] = 0.0f; m[2] = 0.0f; m[3] = 0.0f; m[4] = 0.0f; m[5] = 1.0f; m[6] = 0.0f; m[7] = 0.0f; 
m[8] = 0.0f; m[9] = 0.0f; m[10] = 1.0f; m[11] = 0.0f; return; } float4 m0, m1, m2; optix_impl::optixGetObjectToWorldTransformMatrix( m0, m1, m2 ); m[0] = m0.x; m[1] = m0.y; m[2] = m0.z; m[3] = m0.w; m[4] = m1.x; m[5] = m1.y; m[6] = m1.z; m[7] = m1.w; m[8] = m2.x; m[9] = m2.y; m[10] = m2.z; m[11] = m2.w; } static __forceinline__ __device__ float3 optixTransformPointFromWorldToObjectSpace( float3 point ) { if( optixGetTransformListSize() == 0 ) return point; float4 m0, m1, m2; optix_impl::optixGetWorldToObjectTransformMatrix( m0, m1, m2 ); return optix_impl::optixTransformPoint( m0, m1, m2, point ); } static __forceinline__ __device__ float3 optixTransformVectorFromWorldToObjectSpace( float3 vec ) { if( optixGetTransformListSize() == 0 ) return vec; float4 m0, m1, m2; optix_impl::optixGetWorldToObjectTransformMatrix( m0, m1, m2 ); return optix_impl::optixTransformVector( m0, m1, m2, vec ); } static __forceinline__ __device__ float3 optixTransformNormalFromWorldToObjectSpace( float3 normal ) { if( optixGetTransformListSize() == 0 ) return normal; float4 m0, m1, m2; optix_impl::optixGetObjectToWorldTransformMatrix( m0, m1, m2 ); // inverse of optixGetWorldToObjectTransformMatrix() return optix_impl::optixTransformNormal( m0, m1, m2, normal ); } static __forceinline__ __device__ float3 optixTransformPointFromObjectToWorldSpace( float3 point ) { if( optixGetTransformListSize() == 0 ) return point; float4 m0, m1, m2; optix_impl::optixGetObjectToWorldTransformMatrix( m0, m1, m2 ); return optix_impl::optixTransformPoint( m0, m1, m2, point ); } static __forceinline__ __device__ float3 optixTransformVectorFromObjectToWorldSpace( float3 vec ) { if( optixGetTransformListSize() == 0 ) return vec; float4 m0, m1, m2; optix_impl::optixGetObjectToWorldTransformMatrix( m0, m1, m2 ); return optix_impl::optixTransformVector( m0, m1, m2, vec ); } static __forceinline__ __device__ float3 optixTransformNormalFromObjectToWorldSpace( float3 normal ) { if( optixGetTransformListSize() == 0 ) 
return normal; float4 m0, m1, m2; optix_impl::optixGetWorldToObjectTransformMatrix( m0, m1, m2 ); // inverse of optixGetObjectToWorldTransformMatrix() return optix_impl::optixTransformNormal( m0, m1, m2, normal ); } static __forceinline__ __device__ unsigned int optixGetTransformListSize() { unsigned int u0; asm( "call (%0), _optix_get_transform_list_size, ();" : "=r"( u0 ) : ); return u0; } static __forceinline__ __device__ OptixTraversableHandle optixGetTransformListHandle( unsigned int index ) { unsigned long long u0; asm( "call (%0), _optix_get_transform_list_handle, (%1);" : "=l"( u0 ) : "r"( index ) : ); return u0; } static __forceinline__ __device__ OptixTransformType optixGetTransformTypeFromHandle( OptixTraversableHandle handle ) { int i0; asm( "call (%0), _optix_get_transform_type_from_handle, (%1);" : "=r"( i0 ) : "l"( handle ) : ); return (OptixTransformType)i0; } static __forceinline__ __device__ const OptixStaticTransform* optixGetStaticTransformFromHandle( OptixTraversableHandle handle ) { unsigned long long ptr; asm( "call (%0), _optix_get_static_transform_from_handle, (%1);" : "=l"( ptr ) : "l"( handle ) : ); return (const OptixStaticTransform*)ptr; } static __forceinline__ __device__ const OptixSRTMotionTransform* optixGetSRTMotionTransformFromHandle( OptixTraversableHandle handle ) { unsigned long long ptr; asm( "call (%0), _optix_get_srt_motion_transform_from_handle, (%1);" : "=l"( ptr ) : "l"( handle ) : ); return (const OptixSRTMotionTransform*)ptr; } static __forceinline__ __device__ const OptixMatrixMotionTransform* optixGetMatrixMotionTransformFromHandle( OptixTraversableHandle handle ) { unsigned long long ptr; asm( "call (%0), _optix_get_matrix_motion_transform_from_handle, (%1);" : "=l"( ptr ) : "l"( handle ) : ); return (const OptixMatrixMotionTransform*)ptr; } static __forceinline__ __device__ unsigned int optixGetInstanceIdFromHandle( OptixTraversableHandle handle ) { int i0; asm( "call (%0), _optix_get_instance_id_from_handle, (%1);" 
: "=r"( i0 ) : "l"( handle ) : ); return i0; } static __forceinline__ __device__ OptixTraversableHandle optixGetInstanceChildFromHandle( OptixTraversableHandle handle ) { unsigned long long i0; asm( "call (%0), _optix_get_instance_child_from_handle, (%1);" : "=l"( i0 ) : "l"( handle ) : ); return (OptixTraversableHandle)i0; } static __forceinline__ __device__ const float4* optixGetInstanceTransformFromHandle( OptixTraversableHandle handle ) { unsigned long long ptr; asm( "call (%0), _optix_get_instance_transform_from_handle, (%1);" : "=l"( ptr ) : "l"( handle ) : ); return (const float4*)ptr; } static __forceinline__ __device__ const float4* optixGetInstanceInverseTransformFromHandle( OptixTraversableHandle handle ) { unsigned long long ptr; asm( "call (%0), _optix_get_instance_inverse_transform_from_handle, (%1);" : "=l"( ptr ) : "l"( handle ) : ); return (const float4*)ptr; } static __forceinline__ __device__ bool optixReportIntersection( float hitT, unsigned int hitKind ) { int ret; asm volatile( "call (%0), _optix_report_intersection_0" ", (%1, %2);" : "=r"( ret ) : "f"( hitT ), "r"( hitKind ) : ); return ret; } static __forceinline__ __device__ bool optixReportIntersection( float hitT, unsigned int hitKind, unsigned int a0 ) { int ret; asm volatile( "call (%0), _optix_report_intersection_1" ", (%1, %2, %3);" : "=r"( ret ) : "f"( hitT ), "r"( hitKind ), "r"( a0 ) : ); return ret; } static __forceinline__ __device__ bool optixReportIntersection( float hitT, unsigned int hitKind, unsigned int a0, unsigned int a1 ) { int ret; asm volatile( "call (%0), _optix_report_intersection_2" ", (%1, %2, %3, %4);" : "=r"( ret ) : "f"( hitT ), "r"( hitKind ), "r"( a0 ), "r"( a1 ) : ); return ret; } static __forceinline__ __device__ bool optixReportIntersection( float hitT, unsigned int hitKind, unsigned int a0, unsigned int a1, unsigned int a2 ) { int ret; asm volatile( "call (%0), _optix_report_intersection_3" ", (%1, %2, %3, %4, %5);" : "=r"( ret ) : "f"( hitT ), "r"( 
hitKind ), "r"( a0 ), "r"( a1 ), "r"( a2 ) : ); return ret; } static __forceinline__ __device__ bool optixReportIntersection( float hitT, unsigned int hitKind, unsigned int a0, unsigned int a1, unsigned int a2, unsigned int a3 ) { int ret; asm volatile( "call (%0), _optix_report_intersection_4" ", (%1, %2, %3, %4, %5, %6);" : "=r"( ret ) : "f"( hitT ), "r"( hitKind ), "r"( a0 ), "r"( a1 ), "r"( a2 ), "r"( a3 ) : ); return ret; } static __forceinline__ __device__ bool optixReportIntersection( float hitT, unsigned int hitKind, unsigned int a0, unsigned int a1, unsigned int a2, unsigned int a3, unsigned int a4 ) { int ret; asm volatile( "call (%0), _optix_report_intersection_5" ", (%1, %2, %3, %4, %5, %6, %7);" : "=r"( ret ) : "f"( hitT ), "r"( hitKind ), "r"( a0 ), "r"( a1 ), "r"( a2 ), "r"( a3 ), "r"( a4 ) : ); return ret; } static __forceinline__ __device__ bool optixReportIntersection( float hitT, unsigned int hitKind, unsigned int a0, unsigned int a1, unsigned int a2, unsigned int a3, unsigned int a4, unsigned int a5 ) { int ret; asm volatile( "call (%0), _optix_report_intersection_6" ", (%1, %2, %3, %4, %5, %6, %7, %8);" : "=r"( ret ) : "f"( hitT ), "r"( hitKind ), "r"( a0 ), "r"( a1 ), "r"( a2 ), "r"( a3 ), "r"( a4 ), "r"( a5 ) : ); return ret; } static __forceinline__ __device__ bool optixReportIntersection( float hitT, unsigned int hitKind, unsigned int a0, unsigned int a1, unsigned int a2, unsigned int a3, unsigned int a4, unsigned int a5, unsigned int a6 ) { int ret; asm volatile( "call (%0), _optix_report_intersection_7" ", (%1, %2, %3, %4, %5, %6, %7, %8, %9);" : "=r"( ret ) : "f"( hitT ), "r"( hitKind ), "r"( a0 ), "r"( a1 ), "r"( a2 ), "r"( a3 ), "r"( a4 ), "r"( a5 ), "r"( a6 ) : ); return ret; } static __forceinline__ __device__ bool optixReportIntersection( float hitT, unsigned int hitKind, unsigned int a0, unsigned int a1, unsigned int a2, unsigned int a3, unsigned int a4, unsigned int a5, unsigned int a6, unsigned int a7 ) { int ret; asm volatile( 
"call (%0), _optix_report_intersection_8" ", (%1, %2, %3, %4, %5, %6, %7, %8, %9, %10);" : "=r"( ret ) : "f"( hitT ), "r"( hitKind ), "r"( a0 ), "r"( a1 ), "r"( a2 ), "r"( a3 ), "r"( a4 ), "r"( a5 ), "r"( a6 ), "r"( a7 ) : ); return ret; } #define OPTIX_DEFINE_optixGetAttribute_BODY( which ) \ unsigned int ret; \ asm( "call (%0), _optix_get_attribute_" #which ", ();" : "=r"( ret ) : ); \ return ret; static __forceinline__ __device__ unsigned int optixGetAttribute_0() { OPTIX_DEFINE_optixGetAttribute_BODY( 0 ); } static __forceinline__ __device__ unsigned int optixGetAttribute_1() { OPTIX_DEFINE_optixGetAttribute_BODY( 1 ); } static __forceinline__ __device__ unsigned int optixGetAttribute_2() { OPTIX_DEFINE_optixGetAttribute_BODY( 2 ); } static __forceinline__ __device__ unsigned int optixGetAttribute_3() { OPTIX_DEFINE_optixGetAttribute_BODY( 3 ); } static __forceinline__ __device__ unsigned int optixGetAttribute_4() { OPTIX_DEFINE_optixGetAttribute_BODY( 4 ); } static __forceinline__ __device__ unsigned int optixGetAttribute_5() { OPTIX_DEFINE_optixGetAttribute_BODY( 5 ); } static __forceinline__ __device__ unsigned int optixGetAttribute_6() { OPTIX_DEFINE_optixGetAttribute_BODY( 6 ); } static __forceinline__ __device__ unsigned int optixGetAttribute_7() { OPTIX_DEFINE_optixGetAttribute_BODY( 7 ); } #undef OPTIX_DEFINE_optixGetAttribute_BODY static __forceinline__ __device__ void optixTerminateRay() { asm volatile( "call _optix_terminate_ray, ();" ); } static __forceinline__ __device__ void optixIgnoreIntersection() { asm volatile( "call _optix_ignore_intersection, ();" ); } static __forceinline__ __device__ unsigned int optixGetPrimitiveIndex() { unsigned int u0; asm( "call (%0), _optix_read_primitive_idx, ();" : "=r"( u0 ) : ); return u0; } static __forceinline__ __device__ unsigned int optixGetSbtGASIndex() { unsigned int u0; asm( "call (%0), _optix_read_sbt_gas_idx, ();" : "=r"( u0 ) : ); return u0; } static __forceinline__ __device__ unsigned int 
optixGetInstanceId() { unsigned int u0; asm( "call (%0), _optix_read_instance_id, ();" : "=r"( u0 ) : ); return u0; } static __forceinline__ __device__ unsigned int optixGetInstanceIndex() { unsigned int u0; asm( "call (%0), _optix_read_instance_idx, ();" : "=r"( u0 ) : ); return u0; } static __forceinline__ __device__ unsigned int optixGetHitKind() { unsigned int u0; asm( "call (%0), _optix_get_hit_kind, ();" : "=r"( u0 ) : ); return u0; } static __forceinline__ __device__ OptixPrimitiveType optixGetPrimitiveType(unsigned int hitKind) { unsigned int u0; asm( "call (%0), _optix_get_primitive_type_from_hit_kind, (%1);" : "=r"( u0 ) : "r"( hitKind ) ); return (OptixPrimitiveType)u0; } static __forceinline__ __device__ bool optixIsBackFaceHit( unsigned int hitKind ) { unsigned int u0; asm( "call (%0), _optix_get_backface_from_hit_kind, (%1);" : "=r"( u0 ) : "r"( hitKind ) ); return (u0 == 0x1); } static __forceinline__ __device__ bool optixIsFrontFaceHit( unsigned int hitKind ) { return !optixIsBackFaceHit( hitKind ); } static __forceinline__ __device__ OptixPrimitiveType optixGetPrimitiveType() { return optixGetPrimitiveType( optixGetHitKind() ); } static __forceinline__ __device__ bool optixIsBackFaceHit() { return optixIsBackFaceHit( optixGetHitKind() ); } static __forceinline__ __device__ bool optixIsFrontFaceHit() { return optixIsFrontFaceHit( optixGetHitKind() ); } static __forceinline__ __device__ bool optixIsTriangleHit() { return optixIsTriangleFrontFaceHit() || optixIsTriangleBackFaceHit(); } static __forceinline__ __device__ bool optixIsTriangleFrontFaceHit() { return optixGetHitKind() == OPTIX_HIT_KIND_TRIANGLE_FRONT_FACE; } static __forceinline__ __device__ bool optixIsTriangleBackFaceHit() { return optixGetHitKind() == OPTIX_HIT_KIND_TRIANGLE_BACK_FACE; } static __forceinline__ __device__ float optixGetCurveParameter() { return __int_as_float( optixGetAttribute_0() ); } static __forceinline__ __device__ float2 optixGetTriangleBarycentrics() { float f0, 
f1; asm( "call (%0, %1), _optix_get_triangle_barycentrics, ();" : "=f"( f0 ), "=f"( f1 ) : ); return make_float2( f0, f1 ); } static __forceinline__ __device__ uint3 optixGetLaunchIndex() { unsigned int u0, u1, u2; asm( "call (%0), _optix_get_launch_index_x, ();" : "=r"( u0 ) : ); asm( "call (%0), _optix_get_launch_index_y, ();" : "=r"( u1 ) : ); asm( "call (%0), _optix_get_launch_index_z, ();" : "=r"( u2 ) : ); return make_uint3( u0, u1, u2 ); } static __forceinline__ __device__ uint3 optixGetLaunchDimensions() { unsigned int u0, u1, u2; asm( "call (%0), _optix_get_launch_dimension_x, ();" : "=r"( u0 ) : ); asm( "call (%0), _optix_get_launch_dimension_y, ();" : "=r"( u1 ) : ); asm( "call (%0), _optix_get_launch_dimension_z, ();" : "=r"( u2 ) : ); return make_uint3( u0, u1, u2 ); } static __forceinline__ __device__ CUdeviceptr optixGetSbtDataPointer() { unsigned long long ptr; asm( "call (%0), _optix_get_sbt_data_ptr_64, ();" : "=l"( ptr ) : ); return (CUdeviceptr)ptr; } static __forceinline__ __device__ void optixThrowException( int exceptionCode ) { asm volatile( "call _optix_throw_exception_0, (%0);" : /* no return value */ : "r"( exceptionCode ) : ); } static __forceinline__ __device__ void optixThrowException( int exceptionCode, unsigned int exceptionDetail0 ) { asm volatile( "call _optix_throw_exception_1, (%0, %1);" : /* no return value */ : "r"( exceptionCode ), "r"( exceptionDetail0 ) : ); } static __forceinline__ __device__ void optixThrowException( int exceptionCode, unsigned int exceptionDetail0, unsigned int exceptionDetail1 ) { asm volatile( "call _optix_throw_exception_2, (%0, %1, %2);" : /* no return value */ : "r"( exceptionCode ), "r"( exceptionDetail0 ), "r"( exceptionDetail1 ) : ); } static __forceinline__ __device__ void optixThrowException( int exceptionCode, unsigned int exceptionDetail0, unsigned int exceptionDetail1, unsigned int exceptionDetail2 ) { asm volatile( "call _optix_throw_exception_3, (%0, %1, %2, %3);" : /* no return value */ : 
"r"( exceptionCode ), "r"( exceptionDetail0 ), "r"( exceptionDetail1 ), "r"( exceptionDetail2 ) : ); } static __forceinline__ __device__ void optixThrowException( int exceptionCode, unsigned int exceptionDetail0, unsigned int exceptionDetail1, unsigned int exceptionDetail2, unsigned int exceptionDetail3 ) { asm volatile( "call _optix_throw_exception_4, (%0, %1, %2, %3, %4);" : /* no return value */ : "r"( exceptionCode ), "r"( exceptionDetail0 ), "r"( exceptionDetail1 ), "r"( exceptionDetail2 ), "r"( exceptionDetail3 ) : ); } static __forceinline__ __device__ void optixThrowException( int exceptionCode, unsigned int exceptionDetail0, unsigned int exceptionDetail1, unsigned int exceptionDetail2, unsigned int exceptionDetail3, unsigned int exceptionDetail4 ) { asm volatile( "call _optix_throw_exception_5, (%0, %1, %2, %3, %4, %5);" : /* no return value */ : "r"( exceptionCode ), "r"( exceptionDetail0 ), "r"( exceptionDetail1 ), "r"( exceptionDetail2 ), "r"( exceptionDetail3 ), "r"( exceptionDetail4 ) : ); } static __forceinline__ __device__ void optixThrowException( int exceptionCode, unsigned int exceptionDetail0, unsigned int exceptionDetail1, unsigned int exceptionDetail2, unsigned int exceptionDetail3, unsigned int exceptionDetail4, unsigned int exceptionDetail5 ) { asm volatile( "call _optix_throw_exception_6, (%0, %1, %2, %3, %4, %5, %6);" : /* no return value */ : "r"( exceptionCode ), "r"( exceptionDetail0 ), "r"( exceptionDetail1 ), "r"( exceptionDetail2 ), "r"( exceptionDetail3 ), "r"( exceptionDetail4 ), "r"( exceptionDetail5 ) : ); } static __forceinline__ __device__ void optixThrowException( int exceptionCode, unsigned int exceptionDetail0, unsigned int exceptionDetail1, unsigned int exceptionDetail2, unsigned int exceptionDetail3, unsigned int exceptionDetail4, unsigned int exceptionDetail5, unsigned int exceptionDetail6 ) { asm volatile( "call _optix_throw_exception_7, (%0, %1, %2, %3, %4, %5, %6, %7);" : /* no return value */ : "r"( exceptionCode ), 
"r"( exceptionDetail0 ), "r"( exceptionDetail1 ), "r"( exceptionDetail2 ), "r"( exceptionDetail3 ), "r"( exceptionDetail4 ), "r"( exceptionDetail5 ), "r"( exceptionDetail6 ) : ); } static __forceinline__ __device__ void optixThrowException( int exceptionCode, unsigned int exceptionDetail0, unsigned int exceptionDetail1, unsigned int exceptionDetail2, unsigned int exceptionDetail3, unsigned int exceptionDetail4, unsigned int exceptionDetail5, unsigned int exceptionDetail6, unsigned int exceptionDetail7 ) { asm volatile( "call _optix_throw_exception_8, (%0, %1, %2, %3, %4, %5, %6, %7, %8);" : /* no return value */ : "r"( exceptionCode ), "r"( exceptionDetail0 ), "r"( exceptionDetail1 ), "r"( exceptionDetail2 ), "r"( exceptionDetail3 ), "r"( exceptionDetail4 ), "r"( exceptionDetail5 ), "r"( exceptionDetail6 ), "r"( exceptionDetail7 ) : ); } static __forceinline__ __device__ int optixGetExceptionCode() { int s0; asm( "call (%0), _optix_get_exception_code, ();" : "=r"( s0 ) : ); return s0; } #define OPTIX_DEFINE_optixGetExceptionDetail_BODY( which ) \ unsigned int ret; \ asm( "call (%0), _optix_get_exception_detail_" #which ", ();" : "=r"( ret ) : ); \ return ret; static __forceinline__ __device__ unsigned int optixGetExceptionDetail_0() { OPTIX_DEFINE_optixGetExceptionDetail_BODY( 0 ); } static __forceinline__ __device__ unsigned int optixGetExceptionDetail_1() { OPTIX_DEFINE_optixGetExceptionDetail_BODY( 1 ); } static __forceinline__ __device__ unsigned int optixGetExceptionDetail_2() { OPTIX_DEFINE_optixGetExceptionDetail_BODY( 2 ); } static __forceinline__ __device__ unsigned int optixGetExceptionDetail_3() { OPTIX_DEFINE_optixGetExceptionDetail_BODY( 3 ); } static __forceinline__ __device__ unsigned int optixGetExceptionDetail_4() { OPTIX_DEFINE_optixGetExceptionDetail_BODY( 4 ); } static __forceinline__ __device__ unsigned int optixGetExceptionDetail_5() { OPTIX_DEFINE_optixGetExceptionDetail_BODY( 5 ); } static __forceinline__ __device__ unsigned int 
optixGetExceptionDetail_6() { OPTIX_DEFINE_optixGetExceptionDetail_BODY( 6 ); } static __forceinline__ __device__ unsigned int optixGetExceptionDetail_7() { OPTIX_DEFINE_optixGetExceptionDetail_BODY( 7 ); } #undef OPTIX_DEFINE_optixGetExceptionDetail_BODY static __forceinline__ __device__ OptixTraversableHandle optixGetExceptionInvalidTraversable() { unsigned long long handle; asm( "call (%0), _optix_get_exception_invalid_traversable, ();" : "=l"( handle ) : ); return (OptixTraversableHandle)handle; } static __forceinline__ __device__ int optixGetExceptionInvalidSbtOffset() { int s0; asm( "call (%0), _optix_get_exception_invalid_sbt_offset, ();" : "=r"( s0 ) : ); return s0; } static __forceinline__ __device__ OptixInvalidRayExceptionDetails optixGetExceptionInvalidRay() { float rayOriginX, rayOriginY, rayOriginZ, rayDirectionX, rayDirectionY, rayDirectionZ, tmin, tmax, rayTime; asm( "call (%0, %1, %2, %3, %4, %5, %6, %7, %8), _optix_get_exception_invalid_ray, ();" : "=f"( rayOriginX ), "=f"( rayOriginY ), "=f"( rayOriginZ ), "=f"( rayDirectionX ), "=f"( rayDirectionY ), "=f"( rayDirectionZ ), "=f"( tmin ), "=f"( tmax ), "=f"( rayTime ) : ); OptixInvalidRayExceptionDetails ray; ray.origin = make_float3( rayOriginX, rayOriginY, rayOriginZ ); ray.direction = make_float3( rayDirectionX, rayDirectionY, rayDirectionZ ); ray.tmin = tmin; ray.tmax = tmax; ray.time = rayTime; return ray; } static __forceinline__ __device__ OptixParameterMismatchExceptionDetails optixGetExceptionParameterMismatch() { unsigned int expected, actual, sbtIdx; unsigned long long calleeName; asm( "call (%0, %1, %2, %3), _optix_get_exception_parameter_mismatch, ();" : "=r"(expected), "=r"(actual), "=r"(sbtIdx), "=l"(calleeName) : ); OptixParameterMismatchExceptionDetails details; details.expectedParameterCount = expected; details.passedArgumentCount = actual; details.sbtIndex = sbtIdx; details.callableName = (char*)calleeName; return details; } static __forceinline__ __device__ char* 
optixGetExceptionLineInfo()
{
    // Returns human-readable source-location info for the current exception
    // (signature head `static __forceinline__ __device__ char*` ends the previous line).
    unsigned long long ptr;
    asm( "call (%0), _optix_get_exception_line_info, ();" : "=l"(ptr) : );
    return (char*)ptr;
}

// Invokes the direct callable program at the given SBT index with the supplied arguments.
// FIX(extraction): the template parameter list had been stripped to a bare `template`;
// restored per the OptiX 7.3 SDK header.
template <typename ReturnT, typename... ArgTypes>
static __forceinline__ __device__ ReturnT optixDirectCall( unsigned int sbtIndex, ArgTypes... args )
{
    unsigned long long func;
    asm( "call (%0), _optix_call_direct_callable,(%1);" : "=l"( func ) : "r"( sbtIndex ) : );
    using funcT = ReturnT ( * )( ArgTypes... );
    funcT call = ( funcT )( func );
    return call( args... );
}

// Invokes the continuation callable program at the given SBT index
// (template parameter list restored as above).
template <typename ReturnT, typename... ArgTypes>
static __forceinline__ __device__ ReturnT optixContinuationCall( unsigned int sbtIndex, ArgTypes... args )
{
    unsigned long long func;
    asm( "call (%0), _optix_call_continuation_callable,(%1);" : "=l"( func ) : "r"( sbtIndex ) : );
    using funcT = ReturnT ( * )( ArgTypes... );
    funcT call = ( funcT )( func );
    return call( args... );
}
#endif

// Texture footprint query (implicit single-level variant).
// FIX(extraction): both reinterpret_cast calls had lost their <unsigned long long>
// template argument; restored per the OptiX 7.3 SDK header.
static __forceinline__ __device__ uint4 optixTexFootprint2D( unsigned long long tex, unsigned int texInfo, float x, float y, unsigned int* singleMipLevel )
{
    uint4 result;
    unsigned long long resultPtr = reinterpret_cast<unsigned long long>( &result );
    unsigned long long singleMipLevelPtr = reinterpret_cast<unsigned long long>( singleMipLevel );
    // Cast float args to integers, because the intrinics take .b32 arguments when compiled to PTX.
    asm volatile( "call _optix_tex_footprint_2d_v2" ", (%0, %1, %2, %3, %4, %5);"
                  :
                  : "l"( tex ), "r"( texInfo ), "r"( __float_as_uint( x ) ), "r"( __float_as_uint( y ) ),
                    "l"( singleMipLevelPtr ), "l"( resultPtr )
                  : );
    return result;
}

// Texture footprint query with explicit derivatives; `coarse` selects the coarser
// of the two candidate mip levels.  Same <unsigned long long> / <unsigned int>
// template arguments restored.
static __forceinline__ __device__ uint4 optixTexFootprint2DGrad( unsigned long long tex, unsigned int texInfo, float x, float y, float dPdx_x, float dPdx_y, float dPdy_x, float dPdy_y, bool coarse, unsigned int* singleMipLevel )
{
    uint4 result;
    unsigned long long resultPtr = reinterpret_cast<unsigned long long>( &result );
    unsigned long long singleMipLevelPtr = reinterpret_cast<unsigned long long>( singleMipLevel );
    // Cast float args to integers, because the intrinics take .b32 arguments when compiled to PTX.
    asm volatile( "call _optix_tex_footprint_2d_grad_v2" ", (%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10);"
                  :
                  : "l"( tex ), "r"( texInfo ), "r"( __float_as_uint( x ) ), "r"( __float_as_uint( y ) ),
                    "r"( __float_as_uint( dPdx_x ) ), "r"( __float_as_uint( dPdx_y ) ), "r"( __float_as_uint( dPdy_x ) ),
                    "r"( __float_as_uint( dPdy_y ) ), "r"( static_cast<unsigned int>( coarse ) ),
                    "l"( singleMipLevelPtr ), "l"( resultPtr )
                  : );
    return result;
}

// Texture footprint query with an explicit level of detail.
static __forceinline__ __device__ uint4 optixTexFootprint2DLod( unsigned long long tex, unsigned int texInfo, float x, float y, float level, bool coarse, unsigned int* singleMipLevel )
{
    uint4 result;
    unsigned long long resultPtr = reinterpret_cast<unsigned long long>( &result );
    unsigned long long singleMipLevelPtr = reinterpret_cast<unsigned long long>( singleMipLevel );
    // Cast float args to integers, because the intrinics take .b32 arguments when compiled to PTX.
    asm volatile( "call _optix_tex_footprint_2d_lod_v2" ", (%0, %1, %2, %3, %4, %5, %6, %7);"
                  :
                  : "l"( tex ), "r"( texInfo ), "r"( __float_as_uint( x ) ), "r"( __float_as_uint( y ) ),
                    "r"( __float_as_uint( level ) ), "r"( static_cast<unsigned int>( coarse ) ),
                    "l"( singleMipLevelPtr ), "l"( resultPtr )
                  : );
    return result;
}
================================================ FILE: render/optixutils/include/internal/optix_7_device_impl_exception.h ================================================ /* * Copyright (c) 2021 NVIDIA Corporation. All rights reserved. * * NVIDIA Corporation and its licensors retain all intellectual property and proprietary * rights in and to this software, related documentation and any modifications thereto. * Any use, reproduction, disclosure or distribution of this software and related * documentation without an express license agreement from NVIDIA Corporation is strictly * prohibited. 
* * TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, THIS SOFTWARE IS PROVIDED *AS IS* * AND NVIDIA AND ITS SUPPLIERS DISCLAIM ALL WARRANTIES, EITHER EXPRESS OR IMPLIED, * INCLUDING, BUT NOT LIMITED TO, IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A * PARTICULAR PURPOSE. IN NO EVENT SHALL NVIDIA OR ITS SUPPLIERS BE LIABLE FOR ANY * SPECIAL, INCIDENTAL, INDIRECT, OR CONSEQUENTIAL DAMAGES WHATSOEVER (INCLUDING, WITHOUT * LIMITATION, DAMAGES FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF * BUSINESS INFORMATION, OR ANY OTHER PECUNIARY LOSS) ARISING OUT OF THE USE OF OR * INABILITY TO USE THIS SOFTWARE, EVEN IF NVIDIA HAS BEEN ADVISED OF THE POSSIBILITY OF * SUCH DAMAGES */ /** * @file optix_7_device_impl_exception.h * @author NVIDIA Corporation * @brief OptiX public API * * OptiX public API Reference - Device side implementation for exception helper function. */ #if !defined( __OPTIX_INCLUDE_INTERNAL_HEADERS__ )
// FIX(extraction): the #error directive below had been split across two lines (invalid
// preprocessor syntax); rejoined onto a single line.
#error("optix_7_device_impl_exception.h is an internal header file and must not be used directly. Please use optix_device.h or optix.h instead.")
#endif
#ifndef __optix_optix_7_device_impl_exception_h__
#define __optix_optix_7_device_impl_exception_h__
#if !defined(__CUDACC_RTC__)
// FIX(extraction): the header name had been stripped from this #include; <stdio.h> is
// the original (the retained `/* for printf */` comment pins it).
#include <stdio.h> /* for printf */
#endif
namespace optix_impl {

// Prints (via device printf) the OptixStaticTransform behind `handle`, prefixed by the
// current launch index; no-op when the handle does not resolve to a static transform.
static __forceinline__ __device__ void optixDumpStaticTransformFromHandle( OptixTraversableHandle handle )
{
    const OptixStaticTransform* traversable = optixGetStaticTransformFromHandle( handle );
    if( traversable )
    {
        const uint3 index = optixGetLaunchIndex();
        printf( "(%4i,%4i,%4i) OptixStaticTransform@%p = {\n" " child = %p,\n" " transform = { %f,%f,%f,%f,\n" " %f,%f,%f,%f,\n" " %f,%f,%f,%f } }\n",
                index.x,index.y,index.z, traversable, (void*)traversable->child,
                traversable->transform[0], traversable->transform[1], traversable->transform[2], traversable->transform[3],
                traversable->transform[4], traversable->transform[5], traversable->transform[6], traversable->transform[7],
                traversable->transform[8], traversable->transform[9], traversable->transform[10], traversable->transform[11] );
    }
}

// Same, for OptixMatrixMotionTransform handles; only the first matrix key is printed
// (the format string continues on the next source line).
static __forceinline__ __device__ void optixDumpMotionMatrixTransformFromHandle( OptixTraversableHandle handle )
{
    const OptixMatrixMotionTransform* traversable = optixGetMatrixMotionTransformFromHandle( handle );
    if( traversable )
    {
        const uint3 index = optixGetLaunchIndex();
        printf( "(%4i,%4i,%4i) OptixMatrixMotionTransform@%p = {\n" " child = %p,\n" " motionOptions = { numKeys = %i, flags = %i, timeBegin = %f, timeEnd = %f },\n" " transform = { { %f,%f,%f,%f,\n" " %f,%f,%f,%f,\n" " %f,%f,%f,%f }, ... 
}\n", index.x,index.y,index.z, traversable, (void*)traversable->child, (int)traversable->motionOptions.numKeys, (int)traversable->motionOptions.flags, traversable->motionOptions.timeBegin, traversable->motionOptions.timeEnd, traversable->transform[0][0], traversable->transform[0][1], traversable->transform[0][2], traversable->transform[0][3], traversable->transform[0][4], traversable->transform[0][5], traversable->transform[0][6], traversable->transform[0][7], traversable->transform[0][8], traversable->transform[0][9], traversable->transform[0][10], traversable->transform[0][11] ); } } static __forceinline__ __device__ void optixDumpSrtMatrixTransformFromHandle( OptixTraversableHandle handle ) { const OptixSRTMotionTransform* traversable = optixGetSRTMotionTransformFromHandle( handle ); if( traversable ) { const uint3 index = optixGetLaunchIndex(); printf( "(%4i,%4i,%4i) OptixSRTMotionTransform@%p = {\n" " child = %p,\n" " motionOptions = { numKeys = %i, flags = %i, timeBegin = %f, timeEnd = %f },\n" " srtData = { { sx = %f, a = %f, b = %f, pvx = %f,\n" " sy = %f, c = %f, pvy = %f, sz = %f,\n" " pvz = %f, qx = %f, qy = %f, qz = %f,\n" " qw = %f, tx = %f, ty = %f, tz = %f }, ... 
}\n", index.x,index.y,index.z, traversable, (void*)traversable->child, (int)traversable->motionOptions.numKeys, (int)traversable->motionOptions.flags, traversable->motionOptions.timeBegin, traversable->motionOptions.timeEnd, traversable->srtData[0].sx, traversable->srtData[0].a, traversable->srtData[0].b, traversable->srtData[0].pvx, traversable->srtData[0].sy, traversable->srtData[0].c, traversable->srtData[0].pvy,traversable->srtData[0].sz, traversable->srtData[0].pvz,traversable->srtData[0].qx,traversable->srtData[0].qy, traversable->srtData[0].qz, traversable->srtData[0].qw, traversable->srtData[0].tx,traversable->srtData[0].ty, traversable->srtData[0].tz ); } } static __forceinline__ __device__ void optixDumpInstanceFromHandle( OptixTraversableHandle handle ) { if( optixGetTransformTypeFromHandle( handle ) == OPTIX_TRANSFORM_TYPE_INSTANCE ) { unsigned int instanceId = optixGetInstanceIdFromHandle( handle ); const float4* transform = optixGetInstanceTransformFromHandle( handle ); const uint3 index = optixGetLaunchIndex(); printf( "(%4i,%4i,%4i) OptixInstance = {\n" " instanceId = %i,\n" " transform = { %f,%f,%f,%f,\n" " %f,%f,%f,%f,\n" " %f,%f,%f,%f } }\n", index.x,index.y,index.z, instanceId, transform[0].x, transform[0].y, transform[0].z, transform[0].w, transform[1].x, transform[1].y, transform[1].z, transform[1].w, transform[2].x, transform[2].y, transform[2].z, transform[2].w ); } } static __forceinline__ __device__ void optixDumpTransform( OptixTraversableHandle handle ) { const OptixTransformType type = optixGetTransformTypeFromHandle( handle ); const uint3 index = optixGetLaunchIndex(); switch( type ) { case OPTIX_TRANSFORM_TYPE_NONE: break; case OPTIX_TRANSFORM_TYPE_STATIC_TRANSFORM: optixDumpStaticTransformFromHandle( handle ); break; case OPTIX_TRANSFORM_TYPE_MATRIX_MOTION_TRANSFORM: optixDumpMotionMatrixTransformFromHandle( handle ); break; case OPTIX_TRANSFORM_TYPE_SRT_MOTION_TRANSFORM: optixDumpSrtMatrixTransformFromHandle( handle ); break; case 
OPTIX_TRANSFORM_TYPE_INSTANCE: optixDumpInstanceFromHandle( handle ); break; default: break; } } static __forceinline__ __device__ void optixDumpTransformList() { const int tlistSize = optixGetTransformListSize(); const uint3 index = optixGetLaunchIndex(); printf("(%4i,%4i,%4i) transform list of size %i:\n", index.x,index.y,index.z, tlistSize); for( unsigned int i = 0 ; i < tlistSize ; ++i ) { OptixTraversableHandle handle = optixGetTransformListHandle( i ); printf("(%4i,%4i,%4i) transform[%i] = %p\n", index.x, index.y, index.z, i, (void*)handle); optixDumpTransform(handle); } } static __forceinline__ __device__ void optixDumpExceptionDetails() { bool dumpTlist = false; const int exceptionCode = optixGetExceptionCode(); const uint3 index = optixGetLaunchIndex(); if( exceptionCode == OPTIX_EXCEPTION_CODE_STACK_OVERFLOW ) { printf("(%4i,%4i,%4i) error: stack overflow\n", index.x,index.y,index.z); } else if( exceptionCode == OPTIX_EXCEPTION_CODE_TRACE_DEPTH_EXCEEDED ) { printf("(%4i,%4i,%4i) error: trace depth exceeded\n", index.x,index.y,index.z); } else if( exceptionCode == OPTIX_EXCEPTION_CODE_TRAVERSAL_DEPTH_EXCEEDED ) { printf("(%4i,%4i,%4i) error: traversal depth exceeded\n", index.x,index.y,index.z); dumpTlist = true; } else if( exceptionCode == OPTIX_EXCEPTION_CODE_TRAVERSAL_INVALID_TRAVERSABLE ) { OptixTraversableHandle handle = optixGetExceptionInvalidTraversable(); printf("(%4i,%4i,%4i) error: invalid traversable %p\n", index.x,index.y,index.z, (void*)handle); dumpTlist = true; } else if( exceptionCode == OPTIX_EXCEPTION_CODE_TRAVERSAL_INVALID_MISS_SBT ) { int sbtOffset = optixGetExceptionInvalidSbtOffset(); printf("(%4i,%4i,%4i) error: invalid miss sbt of %i\n", index.x,index.y,index.z, sbtOffset); } else if( exceptionCode == OPTIX_EXCEPTION_CODE_TRAVERSAL_INVALID_HIT_SBT ) { int sbtOffset = optixGetExceptionInvalidSbtOffset(); printf("(%4i,%4i,%4i) error: invalid hit sbt of %i at primitive with gas sbt index %i\n", index.x,index.y,index.z, sbtOffset, 
optixGetSbtGASIndex() ); dumpTlist = true; } else if( exceptionCode == OPTIX_EXCEPTION_CODE_UNSUPPORTED_PRIMITIVE_TYPE ) { dumpTlist = true; printf( "(%4i,%4i,%4i) error: shader encountered unsupported builtin type\n" " call location: %s\n", index.x, index.y, index.z, optixGetExceptionLineInfo() ); } else if( exceptionCode == OPTIX_EXCEPTION_CODE_INVALID_RAY ) { OptixInvalidRayExceptionDetails ray = optixGetExceptionInvalidRay(); printf( "(%4i,%4i,%4i) error: encountered ray with nan or inf values:\n", index.x, index.y, index.z ); printf( " origin: [%f, %f, %f]\n" " direction: [%f, %f, %f]\n" " tmin: %f\n" " tmax: %f\n" " rayTime: %f\n" " call location: %s\n", ray.origin.x, ray.origin.y, ray.origin.z, ray.direction.x, ray.direction.y, ray.direction.z, ray.tmin, ray.tmax, ray.time, optixGetExceptionLineInfo() ); } else if( exceptionCode == OPTIX_EXCEPTION_CODE_CALLABLE_PARAMETER_MISMATCH ) { OptixParameterMismatchExceptionDetails details = optixGetExceptionParameterMismatch(); printf( "(%4i,%4i,%4i) error: parameter mismatch in callable call.\n", index.x, index.y, index.z ); printf( " passed packed arguments: %u 32 Bit values\n" " expected packed parameters: %u 32 Bit values\n" " SBT index: %u\n" " called function: %s\n" " call location: %s\n", details.passedArgumentCount, details.expectedParameterCount, details.sbtIndex, details.callableName, optixGetExceptionLineInfo() ); } else if( exceptionCode == OPTIX_EXCEPTION_CODE_BUILTIN_IS_MISMATCH ) { dumpTlist = true; printf("(%4i,%4i,%4i) error: mismatch between builtin IS shader and build input\n" " call location: %s\n", index.x,index.y,index.z, optixGetExceptionLineInfo() ); } else if( exceptionCode == OPTIX_EXCEPTION_CODE_CALLABLE_INVALID_SBT ) { int sbtOffset = optixGetExceptionInvalidSbtOffset(); printf( "(%4i,%4i,%4i) error: invalid sbt offset of %i for callable program\n", index.x, index.y, index.z, sbtOffset ); } else if( exceptionCode == OPTIX_EXCEPTION_CODE_CALLABLE_NO_DC_SBT_RECORD ) { int sbtOffset = 
optixGetExceptionInvalidSbtOffset(); printf( "(%4i,%4i,%4i) error: invalid sbt offset of %i for direct callable program\n", index.x, index.y, index.z, sbtOffset ); } else if( exceptionCode == OPTIX_EXCEPTION_CODE_CALLABLE_NO_CC_SBT_RECORD ) { int sbtOffset = optixGetExceptionInvalidSbtOffset(); printf( "(%4i,%4i,%4i) error: invalid sbt offset of %i for continuation callable program\n", index.x, index.y, index.z, sbtOffset ); } else if( exceptionCode == OPTIX_EXCEPTION_CODE_UNSUPPORTED_SINGLE_LEVEL_GAS ) { OptixTraversableHandle handle = optixGetExceptionInvalidTraversable(); printf("(%4i,%4i,%4i) error: unsupported single GAS traversable graph %p\n", index.x,index.y,index.z, (void*)handle); dumpTlist = true; } else if( ( exceptionCode <= OPTIX_EXCEPTION_CODE_INVALID_VALUE_ARGUMENT_0 ) && ( exceptionCode >= OPTIX_EXCEPTION_CODE_INVALID_VALUE_ARGUMENT_2 ) ) { printf("(%4i,%4i,%4i) error: invalid value for argument %i\n", index.x,index.y,index.z, -(exceptionCode - OPTIX_EXCEPTION_CODE_INVALID_VALUE_ARGUMENT_0) ); } else if( exceptionCode == OPTIX_EXCEPTION_CODE_UNSUPPORTED_DATA_ACCESS ) { printf("(%4i,%4i,%4i) error: unsupported random data access\n", index.x,index.y,index.z); } else if( exceptionCode >= 0 ) { dumpTlist = true; printf( "(%4i,%4i,%4i) error: user exception with error code %i\n" " call location: %s\n", index.x, index.y, index.z, exceptionCode, optixGetExceptionLineInfo() ); } else { printf("(%4i,%4i,%4i) error: unknown exception with error code %i\n", index.x,index.y,index.z, exceptionCode); } if( dumpTlist ) optixDumpTransformList(); } } // namespace optix_impl #endif ================================================ FILE: render/optixutils/include/internal/optix_7_device_impl_transformations.h ================================================ /* * Copyright (c) 2021 NVIDIA Corporation. All rights reserved. 
* * NVIDIA Corporation and its licensors retain all intellectual property and proprietary * rights in and to this software, related documentation and any modifications thereto. * Any use, reproduction, disclosure or distribution of this software and related * documentation without an express license agreement from NVIDIA Corporation is strictly * prohibited. * * TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, THIS SOFTWARE IS PROVIDED *AS IS* * AND NVIDIA AND ITS SUPPLIERS DISCLAIM ALL WARRANTIES, EITHER EXPRESS OR IMPLIED, * INCLUDING, BUT NOT LIMITED TO, IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A * PARTICULAR PURPOSE. IN NO EVENT SHALL NVIDIA OR ITS SUPPLIERS BE LIABLE FOR ANY * SPECIAL, INCIDENTAL, INDIRECT, OR CONSEQUENTIAL DAMAGES WHATSOEVER (INCLUDING, WITHOUT * LIMITATION, DAMAGES FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF * BUSINESS INFORMATION, OR ANY OTHER PECUNIARY LOSS) ARISING OUT OF THE USE OF OR * INABILITY TO USE THIS SOFTWARE, EVEN IF NVIDIA HAS BEEN ADVISED OF THE POSSIBILITY OF * SUCH DAMAGES */ /** * @file optix_7_device_impl_transformations.h * @author NVIDIA Corporation * @brief OptiX public API * * OptiX public API Reference - Device side implementation for transformation helper functions. */ #if !defined( __OPTIX_INCLUDE_INTERNAL_HEADERS__ )
// FIX(extraction): the #error directive below had been split across two lines (invalid
// preprocessor syntax); rejoined onto a single line.
#error("optix_7_device_impl_transformations.h is an internal header file and must not be used directly. Please use optix_device.h or optix.h instead.")
#endif
#ifndef __optix_optix_7_device_impl_transformations_h__
#define __optix_optix_7_device_impl_transformations_h__
namespace optix_impl {

// Component-wise float4 helpers used by the interpolation routines below.
static __forceinline__ __device__ float4 optixAddFloat4( const float4& a, const float4& b )
{
    return make_float4( a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w );
}

static __forceinline__ __device__ float4 optixMulFloat4( const float4& a, float b )
{
    return make_float4( a.x * b, a.y * b, a.z * b, a.w * b );
}

// 16-byte load through the read-only path (cvta.to.global + ld.global.v4.u32).
static __forceinline__ __device__ uint4 optixLdg( unsigned long long addr )
{
    const uint4* ptr;
    asm volatile( "cvta.to.global.u64 %0, %1;" : "=l"( ptr ) : "l"( addr ) );
    uint4 ret;
    asm volatile( "ld.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"( ret.x ), "=r"( ret.y ), "=r"( ret.z ), "=r"( ret.w ) : "l"( ptr ) );
    return ret;
}

// Loads T in 16-byte chunks via optixLdg (assumes T is 16-byte aligned/sized — as used below).
// FIX(extraction): the template parameter list had been stripped to a bare `template`;
// restored per the OptiX 7.3 SDK header.
template <typename T>
static __forceinline__ __device__ T optixLoadReadOnlyAlign16( const T* ptr )
{
    T v;
    for( int ofs = 0; ofs < sizeof( T ); ofs += 16 )
        *(uint4*)( (char*)&v + ofs ) = optixLdg( (unsigned long long)( (char*)ptr + ofs ) );
    return v;
}

// Multiplies the row vector vec with the 3x4 matrix with rows m0, m1, and m2
static __forceinline__ __device__ float4 optixMultiplyRowMatrix( const float4 vec, const float4 m0, const float4 m1, const float4 m2 )
{
    float4 result;
    result.x = vec.x * m0.x + vec.y * m1.x + vec.z * m2.x;
    result.y = vec.x * m0.y + vec.y * m1.y + vec.z * m2.y;
    result.z = vec.x * m0.z + vec.y * m1.z + vec.z * m2.z;
    result.w = vec.x * m0.w + vec.y * m1.w + vec.z * m2.w + vec.w;
    return result;
}

// Converts the SRT transformation srt into a 3x4 matrix with rows m0, m1, and m2
// (body continues on the next source line).
static __forceinline__ __device__ void optixGetMatrixFromSrt( float4& m0, float4& m1, float4& m2, const OptixSRTData& srt )
{
    const float4 q = {srt.qx, srt.qy, srt.qz, srt.qw};
    // normalize
    const float inv_sql = 1.f / ( srt.qx * srt.qx + srt.qy * srt.qy + srt.qz * srt.qz + srt.qw * srt.qw );
    const float4 nq = optixMulFloat4( q, inv_sql );
    const float sqw = q.w * nq.w; const 
// Quaternion-product terms for the rotation part; shear (a,b,c), scale (sx,sy,sz), pivot
// (pvx,pvy,pvz) and translation (tx,ty,tz) are folded into the matrix columns below.
float sqx = q.x * nq.x; const float sqy = q.y * nq.y; const float sqz = q.z * nq.z; const float xy = q.x * nq.y; const float zw = q.z * nq.w; const float xz = q.x * nq.z; const float yw = q.y * nq.w; const float yz = q.y * nq.z; const float xw = q.x * nq.w; m0.x = ( sqx - sqy - sqz + sqw ); m0.y = 2.0f * ( xy - zw ); m0.z = 2.0f * ( xz + yw ); m1.x = 2.0f * ( xy + zw ); m1.y = ( -sqx + sqy - sqz + sqw ); m1.z = 2.0f * ( yz - xw ); m2.x = 2.0f * ( xz - yw ); m2.y = 2.0f * ( yz + xw ); m2.z = ( -sqx - sqy + sqz + sqw ); m0.w = m0.x * srt.pvx + m0.y * srt.pvy + m0.z * srt.pvz + srt.tx; m1.w = m1.x * srt.pvx + m1.y * srt.pvy + m1.z * srt.pvz + srt.ty; m2.w = m2.x * srt.pvx + m2.y * srt.pvy + m2.z * srt.pvz + srt.tz; m0.z = m0.x * srt.b + m0.y * srt.c + m0.z * srt.sz; m1.z = m1.x * srt.b + m1.y * srt.c + m1.z * srt.sz; m2.z = m2.x * srt.b + m2.y * srt.c + m2.z * srt.sz; m0.y = m0.x * srt.a + m0.y * srt.sy; m1.y = m1.x * srt.a + m1.y * srt.sy; m2.y = m2.x * srt.a + m2.y * srt.sy; m0.x = m0.x * srt.sx; m1.x = m1.x * srt.sx; m2.x = m2.x * srt.sx; } // Inverts a 3x4 matrix in place static __forceinline__ __device__ void optixInvertMatrix( float4& m0, float4& m1, float4& m2 ) { const float det3 = m0.x * ( m1.y * m2.z - m1.z * m2.y ) - m0.y * ( m1.x * m2.z - m1.z * m2.x ) + m0.z * ( m1.x * m2.y - m1.y * m2.x ); const float inv_det3 = 1.0f / det3; float inv3[3][3]; inv3[0][0] = inv_det3 * ( m1.y * m2.z - m2.y * m1.z ); inv3[0][1] = inv_det3 * ( m0.z * m2.y - m2.z * m0.y ); inv3[0][2] = inv_det3 * ( m0.y * m1.z - m1.y * m0.z ); inv3[1][0] = inv_det3 * ( m1.z * m2.x - m2.z * m1.x ); inv3[1][1] = inv_det3 * ( m0.x * m2.z - m2.x * m0.z ); inv3[1][2] = inv_det3 * ( m0.z * m1.x - m1.z * m0.x ); inv3[2][0] = inv_det3 * ( m1.x * m2.y - m2.x * m1.y ); inv3[2][1] = inv_det3 * ( m0.y * m2.x - m2.y * m0.x ); inv3[2][2] = inv_det3 * ( m0.x * m1.y - m1.x * m0.y ); const float b[3] = {m0.w, m1.w, m2.w}; m0.x = inv3[0][0]; m0.y = inv3[0][1]; m0.z = inv3[0][2]; m0.w = -inv3[0][0] * b[0] - 
// Translation column of the inverse: -R_inv * b (continues), then the motion-key loaders:
// optixLoadInterpolatedMatrixKey / optixLoadInterpolatedSrtKey lerp between two adjacent
// keys (quaternion part re-normalized afterwards).
inv3[0][1] * b[1] - inv3[0][2] * b[2]; m1.x = inv3[1][0]; m1.y = inv3[1][1]; m1.z = inv3[1][2]; m1.w = -inv3[1][0] * b[0] - inv3[1][1] * b[1] - inv3[1][2] * b[2]; m2.x = inv3[2][0]; m2.y = inv3[2][1]; m2.z = inv3[2][2]; m2.w = -inv3[2][0] * b[0] - inv3[2][1] * b[1] - inv3[2][2] * b[2]; } static __forceinline__ __device__ void optixLoadInterpolatedMatrixKey( float4& m0, float4& m1, float4& m2, const float4* matrix, const float t1 ) { m0 = optixLoadReadOnlyAlign16( &matrix[0] ); m1 = optixLoadReadOnlyAlign16( &matrix[1] ); m2 = optixLoadReadOnlyAlign16( &matrix[2] ); // The conditional prevents concurrent loads leading to spills if( t1 > 0.0f ) { const float t0 = 1.0f - t1; m0 = optixAddFloat4( optixMulFloat4( m0, t0 ), optixMulFloat4( optixLoadReadOnlyAlign16( &matrix[3] ), t1 ) ); m1 = optixAddFloat4( optixMulFloat4( m1, t0 ), optixMulFloat4( optixLoadReadOnlyAlign16( &matrix[4] ), t1 ) ); m2 = optixAddFloat4( optixMulFloat4( m2, t0 ), optixMulFloat4( optixLoadReadOnlyAlign16( &matrix[5] ), t1 ) ); } } static __forceinline__ __device__ void optixLoadInterpolatedSrtKey( float4& srt0, float4& srt1, float4& srt2, float4& srt3, const float4* srt, const float t1 ) { srt0 = optixLoadReadOnlyAlign16( &srt[0] ); srt1 = optixLoadReadOnlyAlign16( &srt[1] ); srt2 = optixLoadReadOnlyAlign16( &srt[2] ); srt3 = optixLoadReadOnlyAlign16( &srt[3] ); // The conditional prevents concurrent loads leading to spills if( t1 > 0.0f ) { const float t0 = 1.0f - t1; srt0 = optixAddFloat4( optixMulFloat4( srt0, t0 ), optixMulFloat4( optixLoadReadOnlyAlign16( &srt[4] ), t1 ) ); srt1 = optixAddFloat4( optixMulFloat4( srt1, t0 ), optixMulFloat4( optixLoadReadOnlyAlign16( &srt[5] ), t1 ) ); srt2 = optixAddFloat4( optixMulFloat4( srt2, t0 ), optixMulFloat4( optixLoadReadOnlyAlign16( &srt[6] ), t1 ) ); srt3 = optixAddFloat4( optixMulFloat4( srt3, t0 ), optixMulFloat4( optixLoadReadOnlyAlign16( &srt[7] ), t1 ) ); float inv_length = 1.f / sqrt( srt2.y * srt2.y + srt2.z * srt2.z + srt2.w * srt2.w + 
// Quaternion renormalization term (continues); optixResolveMotionKey then maps a global
// ray time onto a key index plus intra-key interpolation weight.
srt3.x * srt3.x ); srt2.y *= inv_length; srt2.z *= inv_length; srt2.w *= inv_length; srt3.x *= inv_length; } } static __forceinline__ __device__ void optixResolveMotionKey( float& localt, int& key, const OptixMotionOptions& options, const float globalt ) { const float timeBegin = options.timeBegin; const float timeEnd = options.timeEnd; const float numIntervals = (float)( options.numKeys - 1 ); // No need to check the motion flags. If data originates from a valid transform list handle, then globalt is in // range, or vanish flags are not set. const float time = max( 0.f, min( numIntervals, ( globalt - timeBegin ) * numIntervals / ( timeEnd - timeBegin ) ) ); const float fltKey = floorf( time ); localt = time - fltKey; key = (int)fltKey; } // Returns the interpolated transformation matrix for a particular matrix motion transformation and point in time. static __forceinline__ __device__ void optixGetInterpolatedTransformation( float4& trf0, float4& trf1, float4& trf2, const OptixMatrixMotionTransform* transformData, const float time ) { // Compute key and intra key time float keyTime; int key; optixResolveMotionKey( keyTime, key, optixLoadReadOnlyAlign16( transformData ).motionOptions, time ); // Get pointer to left key const float4* transform = (const float4*)( &transformData->transform[key][0] ); // Load and interpolate matrix keys optixLoadInterpolatedMatrixKey( trf0, trf1, trf2, transform, keyTime ); } // Returns the interpolated transformation matrix for a particular SRT motion transformation and point in time. 
static __forceinline__ __device__ void optixGetInterpolatedTransformation( float4& trf0, float4& trf1, float4& trf2, const OptixSRTMotionTransform* transformData, const float time )
{
    // Compute key and intra key time
    float keyTime;
    int key;
    optixResolveMotionKey( keyTime, key, optixLoadReadOnlyAlign16( transformData ).motionOptions, time );
    // Get pointer to left key
    // FIX(extraction): reinterpret_cast had lost its <const float4*> template argument;
    // restored per the OptiX 7.3 SDK header.
    const float4* dataPtr = reinterpret_cast<const float4*>( &transformData->srtData[key] );
    // Load and interpolated SRT keys
    float4 data[4];
    optixLoadInterpolatedSrtKey( data[0], data[1], data[2], data[3], dataPtr, keyTime );
    OptixSRTData srt = {data[0].x, data[0].y, data[0].z, data[0].w, data[1].x, data[1].y, data[1].z, data[1].w,
                        data[2].x, data[2].y, data[2].z, data[2].w, data[3].x, data[3].y, data[3].z, data[3].w};
    // Convert SRT into a matrix
    optixGetMatrixFromSrt( trf0, trf1, trf2, srt );
}

// Returns the interpolated transformation matrix for a particular traversable handle and point in time.
// (Body continues on the next source line with the instance/static-transform cases.)
static __forceinline__ __device__ void optixGetInterpolatedTransformationFromHandle( float4& trf0, float4& trf1, float4& trf2, const OptixTraversableHandle handle, const float time, const bool objectToWorld )
{
    const OptixTransformType type = optixGetTransformTypeFromHandle( handle );
    if( type == OPTIX_TRANSFORM_TYPE_MATRIX_MOTION_TRANSFORM || type == OPTIX_TRANSFORM_TYPE_SRT_MOTION_TRANSFORM )
    {
        if( type == OPTIX_TRANSFORM_TYPE_MATRIX_MOTION_TRANSFORM )
        {
            const OptixMatrixMotionTransform* transformData = optixGetMatrixMotionTransformFromHandle( handle );
            optixGetInterpolatedTransformation( trf0, trf1, trf2, transformData, time );
        }
        else
        {
            const OptixSRTMotionTransform* transformData = optixGetSRTMotionTransformFromHandle( handle );
            optixGetInterpolatedTransformation( trf0, trf1, trf2, transformData, time );
        }
        // Motion transforms store object-to-world; invert when the caller wants world-to-object.
        if( !objectToWorld )
            optixInvertMatrix( trf0, trf1, trf2 );
    }
    else if( type == OPTIX_TRANSFORM_TYPE_INSTANCE || type == OPTIX_TRANSFORM_TYPE_STATIC_TRANSFORM )
    {
        const float4* transform;
        if( type == 
OPTIX_TRANSFORM_TYPE_INSTANCE ) { transform = ( objectToWorld ) ? optixGetInstanceTransformFromHandle( handle ) : optixGetInstanceInverseTransformFromHandle( handle ); } else { const OptixStaticTransform* traversable = optixGetStaticTransformFromHandle( handle ); transform = (const float4*)( ( objectToWorld ) ? traversable->transform : traversable->invTransform ); } trf0 = optixLoadReadOnlyAlign16( &transform[0] ); trf1 = optixLoadReadOnlyAlign16( &transform[1] ); trf2 = optixLoadReadOnlyAlign16( &transform[2] ); } else { trf0 = {1.0f, 0.0f, 0.0f, 0.0f}; trf1 = {0.0f, 1.0f, 0.0f, 0.0f}; trf2 = {0.0f, 0.0f, 1.0f, 0.0f}; } } // Returns the world-to-object transformation matrix resulting from the current transform stack and current ray time. static __forceinline__ __device__ void optixGetWorldToObjectTransformMatrix( float4& m0, float4& m1, float4& m2 ) { const unsigned int size = optixGetTransformListSize(); const float time = optixGetRayTime(); #pragma unroll 1 for( unsigned int i = 0; i < size; ++i ) { OptixTraversableHandle handle = optixGetTransformListHandle( i ); float4 trf0, trf1, trf2; optixGetInterpolatedTransformationFromHandle( trf0, trf1, trf2, handle, time, /*objectToWorld*/ false ); if( i == 0 ) { m0 = trf0; m1 = trf1; m2 = trf2; } else { // m := trf * m float4 tmp0 = m0, tmp1 = m1, tmp2 = m2; m0 = optixMultiplyRowMatrix( trf0, tmp0, tmp1, tmp2 ); m1 = optixMultiplyRowMatrix( trf1, tmp0, tmp1, tmp2 ); m2 = optixMultiplyRowMatrix( trf2, tmp0, tmp1, tmp2 ); } } } // Returns the object-to-world transformation matrix resulting from the current transform stack and current ray time. 
// Same accumulation as above, but walking the transform list back-to-front to build the
// object-to-world matrix; followed by point/vector transform helpers.
static __forceinline__ __device__ void optixGetObjectToWorldTransformMatrix( float4& m0, float4& m1, float4& m2 ) { const int size = optixGetTransformListSize(); const float time = optixGetRayTime(); #pragma unroll 1 for( int i = size - 1; i >= 0; --i ) { OptixTraversableHandle handle = optixGetTransformListHandle( i ); float4 trf0, trf1, trf2; optixGetInterpolatedTransformationFromHandle( trf0, trf1, trf2, handle, time, /*objectToWorld*/ true ); if( i == size - 1 ) { m0 = trf0; m1 = trf1; m2 = trf2; } else { // m := trf * m float4 tmp0 = m0, tmp1 = m1, tmp2 = m2; m0 = optixMultiplyRowMatrix( trf0, tmp0, tmp1, tmp2 ); m1 = optixMultiplyRowMatrix( trf1, tmp0, tmp1, tmp2 ); m2 = optixMultiplyRowMatrix( trf2, tmp0, tmp1, tmp2 ); } } } // Multiplies the 3x4 matrix with rows m0, m1, m2 with the point p. static __forceinline__ __device__ float3 optixTransformPoint( const float4& m0, const float4& m1, const float4& m2, const float3& p ) { float3 result; result.x = m0.x * p.x + m0.y * p.y + m0.z * p.z + m0.w; result.y = m1.x * p.x + m1.y * p.y + m1.z * p.z + m1.w; result.z = m2.x * p.x + m2.y * p.y + m2.z * p.z + m2.w; return result; } // Multiplies the 3x3 linear submatrix of the 3x4 matrix with rows m0, m1, m2 with the vector v. static __forceinline__ __device__ float3 optixTransformVector( const float4& m0, const float4& m1, const float4& m2, const float3& v ) { float3 result; result.x = m0.x * v.x + m0.y * v.y + m0.z * v.z; result.y = m1.x * v.x + m1.y * v.y + m1.z * v.z; result.z = m2.x * v.x + m2.y * v.y + m2.z * v.z; return result; } // Multiplies the transpose of the 3x3 linear submatrix of the 3x4 matrix with rows m0, m1, m2 with the normal n. // Note that the given matrix is supposed to be the inverse of the actual transformation matrix. 
// optixTransformNormal closes the transformations header; the remainder of this line is the
// top of render/optixutils/include/optix.h (license + version doc comment).
static __forceinline__ __device__ float3 optixTransformNormal( const float4& m0, const float4& m1, const float4& m2, const float3& n ) { float3 result; result.x = m0.x * n.x + m1.x * n.y + m2.x * n.z; result.y = m0.y * n.x + m1.y * n.y + m2.y * n.z; result.z = m0.z * n.x + m1.z * n.y + m2.z * n.z; return result; } } // namespace optix_impl #endif ================================================ FILE: render/optixutils/include/optix.h ================================================ /* * Copyright (c) 2021 NVIDIA Corporation. All rights reserved. * * NVIDIA Corporation and its licensors retain all intellectual property and proprietary * rights in and to this software, related documentation and any modifications thereto. * Any use, reproduction, disclosure or distribution of this software and related * documentation without an express license agreement from NVIDIA Corporation is strictly * prohibited. * * TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, THIS SOFTWARE IS PROVIDED *AS IS* * AND NVIDIA AND ITS SUPPLIERS DISCLAIM ALL WARRANTIES, EITHER EXPRESS OR IMPLIED, * INCLUDING, BUT NOT LIMITED TO, IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A * PARTICULAR PURPOSE. IN NO EVENT SHALL NVIDIA OR ITS SUPPLIERS BE LIABLE FOR ANY * SPECIAL, INCIDENTAL, INDIRECT, OR CONSEQUENTIAL DAMAGES WHATSOEVER (INCLUDING, WITHOUT * LIMITATION, DAMAGES FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF * BUSINESS INFORMATION, OR ANY OTHER PECUNIARY LOSS) ARISING OUT OF THE USE OF OR * INABILITY TO USE THIS SOFTWARE, EVEN IF NVIDIA HAS BEEN ADVISED OF THE POSSIBILITY OF * SUCH DAMAGES */ /// @file /// @author NVIDIA Corporation /// @brief OptiX public API header /// /// Includes the host api if compiling host code, includes the cuda api if compiling device code. /// For the math library routines include optix_math.h #ifndef __optix_optix_h__ #define __optix_optix_h__ /// The OptiX version. 
/// /// - major = OPTIX_VERSION/10000 /// - minor = (OPTIX_VERSION%10000)/100 /// - micro = OPTIX_VERSION%100 #define OPTIX_VERSION 70300 #ifdef __CUDACC__ #include "optix_device.h" #else #include "optix_host.h" #endif #endif // __optix_optix_h__ ================================================ FILE: render/optixutils/include/optix_7_device.h ================================================ /* * Copyright (c) 2021 NVIDIA Corporation. All rights reserved. * * NVIDIA Corporation and its licensors retain all intellectual property and proprietary * rights in and to this software, related documentation and any modifications thereto. * Any use, reproduction, disclosure or distribution of this software and related * documentation without an express license agreement from NVIDIA Corporation is strictly * prohibited. * * TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, THIS SOFTWARE IS PROVIDED *AS IS* * AND NVIDIA AND ITS SUPPLIERS DISCLAIM ALL WARRANTIES, EITHER EXPRESS OR IMPLIED, * INCLUDING, BUT NOT LIMITED TO, IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A * PARTICULAR PURPOSE. IN NO EVENT SHALL NVIDIA OR ITS SUPPLIERS BE LIABLE FOR ANY * SPECIAL, INCIDENTAL, INDIRECT, OR CONSEQUENTIAL DAMAGES WHATSOEVER (INCLUDING, WITHOUT * LIMITATION, DAMAGES FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF * BUSINESS INFORMATION, OR ANY OTHER PECUNIARY LOSS) ARISING OUT OF THE USE OF OR * INABILITY TO USE THIS SOFTWARE, EVEN IF NVIDIA HAS BEEN ADVISED OF THE POSSIBILITY OF * SUCH DAMAGES */ /// @file /// @author NVIDIA Corporation /// @brief OptiX public API header /// /// OptiX public API Reference - Device API declarations #if !defined( __OPTIX_INCLUDE_INTERNAL_HEADERS__ ) #error("optix_7_device.h is an internal header file and must not be used directly. 
Please use optix_device.h or optix.h instead.") #endif #ifndef __optix_optix_7_device_h__ #define __optix_optix_7_device_h__ #if defined( __cplusplus ) && ( __cplusplus < 201103L ) && !defined( _WIN32 ) #error Device code for OptiX requires at least C++11. Consider adding "--std c++11" to the nvcc command-line. #endif #include "optix_7_types.h" /// \defgroup optix_device_api Device API /// \brief OptiX Device API /** \addtogroup optix_device_api @{ */ /// Initiates a ray tracing query starting with the given traversable (overload without payload). /// /// \param[in] handle /// \param[in] rayOrigin /// \param[in] rayDirection /// \param[in] tmin /// \param[in] tmax /// \param[in] rayTime /// \param[in] visibilityMask really only 8 bits /// \param[in] rayFlags really only 8 bits, combination of OptixRayFlags /// \param[in] SBToffset really only 8 bits /// \param[in] SBTstride really only 8 bits /// \param[in] missSBTIndex specifies the miss program invoked on a miss static __forceinline__ __device__ void optixTrace( OptixTraversableHandle handle, float3 rayOrigin, float3 rayDirection, float tmin, float tmax, float rayTime, OptixVisibilityMask visibilityMask, unsigned int rayFlags, unsigned int SBToffset, unsigned int SBTstride, unsigned int missSBTIndex ); /// Initiates a ray tracing query starting with the given traversable (overload with 1 payload registers). /// /// \see #optixTrace(OptixTraversableHandle,float3,float3,float,float,float,OptixVisibilityMask,unsigned int,unsigned int,unsigned int,unsigned int) static __forceinline__ __device__ void optixTrace( OptixTraversableHandle handle, float3 rayOrigin, float3 rayDirection, float tmin, float tmax, float rayTime, OptixVisibilityMask visibilityMask, unsigned int rayFlags, unsigned int SBToffset, unsigned int SBTstride, unsigned int missSBTIndex, unsigned int& p0 ); /// Initiates a ray tracing query starting with the given traversable (overload with 2 payload registers). 
/// /// \see #optixTrace(OptixTraversableHandle,float3,float3,float,float,float,OptixVisibilityMask,unsigned int,unsigned int,unsigned int,unsigned int) static __forceinline__ __device__ void optixTrace( OptixTraversableHandle handle, float3 rayOrigin, float3 rayDirection, float tmin, float tmax, float rayTime, OptixVisibilityMask visibilityMask, unsigned int rayFlags, unsigned int SBToffset, unsigned int SBTstride, unsigned int missSBTIndex, unsigned int& p0, unsigned int& p1 ); /// Initiates a ray tracing query starting with the given traversable (overload with 3 payload registers). /// /// \see #optixTrace(OptixTraversableHandle,float3,float3,float,float,float,OptixVisibilityMask,unsigned int,unsigned int,unsigned int,unsigned int) static __forceinline__ __device__ void optixTrace( OptixTraversableHandle handle, float3 rayOrigin, float3 rayDirection, float tmin, float tmax, float rayTime, OptixVisibilityMask visibilityMask, unsigned int rayFlags, unsigned int SBToffset, unsigned int SBTstride, unsigned int missSBTIndex, unsigned int& p0, unsigned int& p1, unsigned int& p2 ); /// Initiates a ray tracing query starting with the given traversable (overload with 4 payload registers). /// /// \see #optixTrace(OptixTraversableHandle,float3,float3,float,float,float,OptixVisibilityMask,unsigned int,unsigned int,unsigned int,unsigned int) static __forceinline__ __device__ void optixTrace( OptixTraversableHandle handle, float3 rayOrigin, float3 rayDirection, float tmin, float tmax, float rayTime, OptixVisibilityMask visibilityMask, unsigned int rayFlags, unsigned int SBToffset, unsigned int SBTstride, unsigned int missSBTIndex, unsigned int& p0, unsigned int& p1, unsigned int& p2, unsigned int& p3 ); /// Initiates a ray tracing query starting with the given traversable (overload with 5 payload registers). 
/// /// \see #optixTrace(OptixTraversableHandle,float3,float3,float,float,float,OptixVisibilityMask,unsigned int,unsigned int,unsigned int,unsigned int) static __forceinline__ __device__ void optixTrace( OptixTraversableHandle handle, float3 rayOrigin, float3 rayDirection, float tmin, float tmax, float rayTime, OptixVisibilityMask visibilityMask, unsigned int rayFlags, unsigned int SBToffset, unsigned int SBTstride, unsigned int missSBTIndex, unsigned int& p0, unsigned int& p1, unsigned int& p2, unsigned int& p3, unsigned int& p4 ); /// Initiates a ray tracing query starting with the given traversable (overload with 6 payload registers). /// /// \see #optixTrace(OptixTraversableHandle,float3,float3,float,float,float,OptixVisibilityMask,unsigned int,unsigned int,unsigned int,unsigned int) static __forceinline__ __device__ void optixTrace( OptixTraversableHandle handle, float3 rayOrigin, float3 rayDirection, float tmin, float tmax, float rayTime, OptixVisibilityMask visibilityMask, unsigned int rayFlags, unsigned int SBToffset, unsigned int SBTstride, unsigned int missSBTIndex, unsigned int& p0, unsigned int& p1, unsigned int& p2, unsigned int& p3, unsigned int& p4, unsigned int& p5 ); /// Initiates a ray tracing query starting with the given traversable (overload with 7 payload registers). /// /// \see #optixTrace(OptixTraversableHandle,float3,float3,float,float,float,OptixVisibilityMask,unsigned int,unsigned int,unsigned int,unsigned int) static __forceinline__ __device__ void optixTrace( OptixTraversableHandle handle, float3 rayOrigin, float3 rayDirection, float tmin, float tmax, float rayTime, OptixVisibilityMask visibilityMask, unsigned int rayFlags, unsigned int SBToffset, unsigned int SBTstride, unsigned int missSBTIndex, unsigned int& p0, unsigned int& p1, unsigned int& p2, unsigned int& p3, unsigned int& p4, unsigned int& p5, unsigned int& p6 ); /// Initiates a ray tracing query starting with the given traversable (overload with 8 payload registers). 
/// /// \see #optixTrace(OptixTraversableHandle,float3,float3,float,float,float,OptixVisibilityMask,unsigned int,unsigned int,unsigned int,unsigned int) static __forceinline__ __device__ void optixTrace( OptixTraversableHandle handle, float3 rayOrigin, float3 rayDirection, float tmin, float tmax, float rayTime, OptixVisibilityMask visibilityMask, unsigned int rayFlags, unsigned int SBToffset, unsigned int SBTstride, unsigned int missSBTIndex, unsigned int& p0, unsigned int& p1, unsigned int& p2, unsigned int& p3, unsigned int& p4, unsigned int& p5, unsigned int& p6, unsigned int& p7 ); /// Writes the 32-bit payload value at slot 0. static __forceinline__ __device__ void optixSetPayload_0( unsigned int p ); /// Writes the 32-bit payload value at slot 1. static __forceinline__ __device__ void optixSetPayload_1( unsigned int p ); /// Writes the 32-bit payload value at slot 2. static __forceinline__ __device__ void optixSetPayload_2( unsigned int p ); /// Writes the 32-bit payload value at slot 3. static __forceinline__ __device__ void optixSetPayload_3( unsigned int p ); /// Writes the 32-bit payload value at slot 4. static __forceinline__ __device__ void optixSetPayload_4( unsigned int p ); /// Writes the 32-bit payload value at slot 5. static __forceinline__ __device__ void optixSetPayload_5( unsigned int p ); /// Writes the 32-bit payload value at slot 6. static __forceinline__ __device__ void optixSetPayload_6( unsigned int p ); /// Writes the 32-bit payload value at slot 7. static __forceinline__ __device__ void optixSetPayload_7( unsigned int p ); /// Reads the 32-bit payload value at slot 0. static __forceinline__ __device__ unsigned int optixGetPayload_0(); /// Reads the 32-bit payload value at slot 1. static __forceinline__ __device__ unsigned int optixGetPayload_1(); /// Reads the 32-bit payload value at slot 2. static __forceinline__ __device__ unsigned int optixGetPayload_2(); /// Reads the 32-bit payload value at slot 3. 
static __forceinline__ __device__ unsigned int optixGetPayload_3(); /// Reads the 32-bit payload value at slot 4. static __forceinline__ __device__ unsigned int optixGetPayload_4(); /// Reads the 32-bit payload value at slot 5. static __forceinline__ __device__ unsigned int optixGetPayload_5(); /// Reads the 32-bit payload value at slot 6. static __forceinline__ __device__ unsigned int optixGetPayload_6(); /// Reads the 32-bit payload value at slot 7. static __forceinline__ __device__ unsigned int optixGetPayload_7(); /// Returns an undefined value. static __forceinline__ __device__ unsigned int optixUndefinedValue(); /// Returns the rayOrigin passed into optixTrace. /// /// May be more expensive to call in IS and AH than their object space counterparts, /// so effort should be made to use the object space ray in those programs. /// Only available in IS, AH, CH, MS static __forceinline__ __device__ float3 optixGetWorldRayOrigin(); /// Returns the rayDirection passed into optixTrace. /// /// May be more expensive to call in IS and AH than their object space counterparts, /// so effort should be made to use the object space ray in those programs. /// Only available in IS, AH, CH, MS static __forceinline__ __device__ float3 optixGetWorldRayDirection(); /// Returns the current object space ray origin based on the current transform stack. /// /// Only available in IS and AH. static __forceinline__ __device__ float3 optixGetObjectRayOrigin(); /// Returns the current object space ray direction based on the current transform stack. /// /// Only available in IS and AH. static __forceinline__ __device__ float3 optixGetObjectRayDirection(); /// Returns the tmin passed into optixTrace. 
/// /// Only available in IS, AH, CH, MS static __forceinline__ __device__ float optixGetRayTmin(); /// In IS and CH returns the current smallest reported hitT or the tmax passed into optixTrace if no hit has been reported /// In AH returns the hitT value as passed in to optixReportIntersection /// In MS returns the tmax passed into optixTrace /// Only available in IS, AH, CH, MS static __forceinline__ __device__ float optixGetRayTmax(); /// Returns the rayTime passed into optixTrace. /// /// Will return 0 if motion is disabled. /// Only available in IS, AH, CH, MS static __forceinline__ __device__ float optixGetRayTime(); /// Returns the rayFlags passed into optixTrace /// /// Only available in IS, AH, CH, MS static __forceinline__ __device__ unsigned int optixGetRayFlags(); /// Returns the visibilityMask passed into optixTrace /// /// Only available in IS, AH, CH, MS static __forceinline__ __device__ unsigned int optixGetRayVisibilityMask(); /// Return the traversable handle of a given instance in an Instance /// Acceleration Structure (IAS) static __forceinline__ __device__ OptixTraversableHandle optixGetInstanceTraversableFromIAS( OptixTraversableHandle ias, unsigned int instIdx ); /// Return the object space triangle vertex positions of a given triangle in a Geometry /// Acceleration Structure (GAS) at a given motion time. /// To access vertex data, the GAS must be built using the flag OPTIX_BUILD_FLAG_ALLOW_RANDOM_VERTEX_ACCESS. /// /// If motion is disabled via OptixPipelineCompileOptions::usesMotionBlur, or the GAS does not contain motion, the /// time parameter is ignored. static __forceinline__ __device__ void optixGetTriangleVertexData( OptixTraversableHandle gas, unsigned int primIdx, unsigned int sbtGASIndex, float time, float3 data[3]); /// Return the object space curve control vertex data of a linear curve in a Geometry /// Acceleration Structure (GAS) at a given motion time. 
/// To access vertex data, the GAS must be built using the flag OPTIX_BUILD_FLAG_ALLOW_RANDOM_VERTEX_ACCESS. /// /// data[i] = {x,y,z,w} with {x,y,z} the position and w the radius of control vertex i. /// If motion is disabled via OptixPipelineCompileOptions::usesMotionBlur, or the GAS does not contain motion, the /// time parameter is ignored. static __forceinline__ __device__ void optixGetLinearCurveVertexData( OptixTraversableHandle gas, unsigned int primIdx, unsigned int sbtGASIndex, float time, float4 data[2] ); /// Return the object space curve control vertex data of a quadratic BSpline curve in a Geometry /// Acceleration Structure (GAS) at a given motion time. /// To access vertex data, the GAS must be built using the flag OPTIX_BUILD_FLAG_ALLOW_RANDOM_VERTEX_ACCESS. /// /// data[i] = {x,y,z,w} with {x,y,z} the position and w the radius of control vertex i. /// If motion is disabled via OptixPipelineCompileOptions::usesMotionBlur, or the GAS does not contain motion, the /// time parameter is ignored. static __forceinline__ __device__ void optixGetQuadraticBSplineVertexData( OptixTraversableHandle gas, unsigned int primIdx, unsigned int sbtGASIndex, float time, float4 data[3] ); /// Return the object space curve control vertex data of a cubic BSpline curve in a Geometry /// Acceleration Structure (GAS) at a given motion time. /// To access vertex data, the GAS must be built using the flag OPTIX_BUILD_FLAG_ALLOW_RANDOM_VERTEX_ACCESS. /// /// data[i] = {x,y,z,w} with {x,y,z} the position and w the radius of control vertex i. /// If motion is disabled via OptixPipelineCompileOptions::usesMotionBlur, or the GAS does not contain motion, the /// time parameter is ignored. static __forceinline__ __device__ void optixGetCubicBSplineVertexData( OptixTraversableHandle gas, unsigned int primIdx, unsigned int sbtGASIndex, float time, float4 data[4] ); /// Returns the traversable handle for the Geometry Acceleration Structure (GAS) containing /// the current hit. 
May be called from IS, AH and CH. static __forceinline__ __device__ OptixTraversableHandle optixGetGASTraversableHandle(); /// Returns the motion begin time of a GAS (see OptixMotionOptions) static __forceinline__ __device__ float optixGetGASMotionTimeBegin( OptixTraversableHandle gas ); /// Returns the motion end time of a GAS (see OptixMotionOptions) static __forceinline__ __device__ float optixGetGASMotionTimeEnd( OptixTraversableHandle gas ); /// Returns the number of motion steps of a GAS (see OptixMotionOptions) static __forceinline__ __device__ unsigned int optixGetGASMotionStepCount( OptixTraversableHandle gas ); /// Returns the world-to-object transformation matrix resulting from the current active transformation list. /// /// The cost of this function may be proportional to the size of the transformation list. static __forceinline__ __device__ void optixGetWorldToObjectTransformMatrix( float m[12] ); /// Returns the object-to-world transformation matrix resulting from the current active transformation list. /// /// The cost of this function may be proportional to the size of the transformation list. static __forceinline__ __device__ void optixGetObjectToWorldTransformMatrix( float m[12] ); /// Transforms the point using world-to-object transformation matrix resulting from the current active transformation /// list. /// /// The cost of this function may be proportional to the size of the transformation list. static __forceinline__ __device__ float3 optixTransformPointFromWorldToObjectSpace( float3 point ); /// Transforms the vector using world-to-object transformation matrix resulting from the current active transformation /// list. /// /// The cost of this function may be proportional to the size of the transformation list. static __forceinline__ __device__ float3 optixTransformVectorFromWorldToObjectSpace( float3 vec ); /// Transforms the normal using world-to-object transformation matrix resulting from the current active transformation /// list. 
/// /// The cost of this function may be proportional to the size of the transformation list. static __forceinline__ __device__ float3 optixTransformNormalFromWorldToObjectSpace( float3 normal ); /// Transforms the point using object-to-world transformation matrix resulting from the current active transformation /// list. /// /// The cost of this function may be proportional to the size of the transformation list. static __forceinline__ __device__ float3 optixTransformPointFromObjectToWorldSpace( float3 point ); /// Transforms the vector using object-to-world transformation matrix resulting from the current active transformation /// list. /// /// The cost of this function may be proportional to the size of the transformation list. static __forceinline__ __device__ float3 optixTransformVectorFromObjectToWorldSpace( float3 vec ); /// Transforms the normal using object-to-world transformation matrix resulting from the current active transformation /// list. /// /// The cost of this function may be proportional to the size of the transformation list. static __forceinline__ __device__ float3 optixTransformNormalFromObjectToWorldSpace( float3 normal ); /// Returns the number of transforms on the current transform list. /// /// Only available in IS, AH, CH, EX static __forceinline__ __device__ unsigned int optixGetTransformListSize(); /// Returns the traversable handle for a transform on the current transform list. /// /// Only available in IS, AH, CH, EX static __forceinline__ __device__ OptixTraversableHandle optixGetTransformListHandle( unsigned int index ); /// Returns the transform type of a traversable handle from a transform list. static __forceinline__ __device__ OptixTransformType optixGetTransformTypeFromHandle( OptixTraversableHandle handle ); /// Returns a pointer to a OptixStaticTransform from its traversable handle. /// /// Returns 0 if the traversable is not of type OPTIX_TRANSFORM_TYPE_STATIC_TRANSFORM. 
static __forceinline__ __device__ const OptixStaticTransform* optixGetStaticTransformFromHandle( OptixTraversableHandle handle ); /// Returns a pointer to a OptixSRTMotionTransform from its traversable handle. /// /// Returns 0 if the traversable is not of type OPTIX_TRANSFORM_TYPE_SRT_MOTION_TRANSFORM. static __forceinline__ __device__ const OptixSRTMotionTransform* optixGetSRTMotionTransformFromHandle( OptixTraversableHandle handle ); /// Returns a pointer to a OptixMatrixMotionTransform from its traversable handle. /// /// Returns 0 if the traversable is not of type OPTIX_TRANSFORM_TYPE_MATRIX_MOTION_TRANSFORM. static __forceinline__ __device__ const OptixMatrixMotionTransform* optixGetMatrixMotionTransformFromHandle( OptixTraversableHandle handle ); /// Returns instanceId from an OptixInstance traversable. /// /// Returns 0 if the traversable handle does not reference an OptixInstance. static __forceinline__ __device__ unsigned int optixGetInstanceIdFromHandle( OptixTraversableHandle handle ); /// Returns child traversable handle from an OptixInstance traversable. /// /// Returns 0 if the traversable handle does not reference an OptixInstance. static __forceinline__ __device__ OptixTraversableHandle optixGetInstanceChildFromHandle( OptixTraversableHandle handle ); /// Returns object-to-world transform from an OptixInstance traversable. /// /// Returns 0 if the traversable handle does not reference an OptixInstance. static __forceinline__ __device__ const float4* optixGetInstanceTransformFromHandle( OptixTraversableHandle handle ); /// Returns world-to-object transform from an OptixInstance traversable. /// /// Returns 0 if the traversable handle does not reference an OptixInstance. static __forceinline__ __device__ const float4* optixGetInstanceInverseTransformFromHandle( OptixTraversableHandle handle ); /// Reports an intersections (overload without attributes). 
/// /// If optixGetRayTmin() <= hitT <= optixGetRayTmax(), the any hit program associated with this intersection program (via the SBT entry) is called. /// The AH program can do one of three things: /// 1. call optixIgnoreIntersection - no hit is recorded, optixReportIntersection returns false /// 2. call optixTerminateRay - hit is recorded, optixReportIntersection does not return, no further traversal occurs, /// and the associated closest hit program is called /// 3. neither - hit is recorded, optixReportIntersection returns true /// hitKind - Only the 7 least significant bits should be written [0..127]. Any values above 127 are reserved for built in intersection. The value can be queried with optixGetHitKind() in AH and CH. /// /// The attributes specified with a0..a7 are available in the AH and CH programs. /// Note that the attributes available in the CH program correspond to the closest recorded intersection. /// The number of attributes in registers and memory can be configured in the pipeline. /// /// \param[in] hitT /// \param[in] hitKind static __forceinline__ __device__ bool optixReportIntersection( float hitT, unsigned int hitKind ); /// Reports an intersection (overload with 1 attribute register). /// /// \see #optixReportIntersection(float,unsigned int) static __forceinline__ __device__ bool optixReportIntersection( float hitT, unsigned int hitKind, unsigned int a0 ); /// Reports an intersection (overload with 2 attribute registers). /// /// \see #optixReportIntersection(float,unsigned int) static __forceinline__ __device__ bool optixReportIntersection( float hitT, unsigned int hitKind, unsigned int a0, unsigned int a1 ); /// Reports an intersection (overload with 3 attribute registers). 
/// /// \see #optixReportIntersection(float,unsigned int) static __forceinline__ __device__ bool optixReportIntersection( float hitT, unsigned int hitKind, unsigned int a0, unsigned int a1, unsigned int a2 ); /// Reports an intersection (overload with 4 attribute registers). /// /// \see #optixReportIntersection(float,unsigned int) static __forceinline__ __device__ bool optixReportIntersection( float hitT, unsigned int hitKind, unsigned int a0, unsigned int a1, unsigned int a2, unsigned int a3 ); /// Reports an intersection (overload with 5 attribute registers). /// /// \see #optixReportIntersection(float,unsigned int) static __forceinline__ __device__ bool optixReportIntersection( float hitT, unsigned int hitKind, unsigned int a0, unsigned int a1, unsigned int a2, unsigned int a3, unsigned int a4 ); /// Reports an intersection (overload with 6 attribute registers). /// /// \see #optixReportIntersection(float,unsigned int) static __forceinline__ __device__ bool optixReportIntersection( float hitT, unsigned int hitKind, unsigned int a0, unsigned int a1, unsigned int a2, unsigned int a3, unsigned int a4, unsigned int a5 ); /// Reports an intersection (overload with 7 attribute registers). /// /// \see #optixReportIntersection(float,unsigned int) static __forceinline__ __device__ bool optixReportIntersection( float hitT, unsigned int hitKind, unsigned int a0, unsigned int a1, unsigned int a2, unsigned int a3, unsigned int a4, unsigned int a5, unsigned int a6 ); /// Reports an intersection (overload with 8 attribute registers). /// /// \see #optixReportIntersection(float,unsigned int) static __forceinline__ __device__ bool optixReportIntersection( float hitT, unsigned int hitKind, unsigned int a0, unsigned int a1, unsigned int a2, unsigned int a3, unsigned int a4, unsigned int a5, unsigned int a6, unsigned int a7 ); /// Returns the attribute at slot 0. static __forceinline__ __device__ unsigned int optixGetAttribute_0(); /// Returns the attribute at slot 1. 
static __forceinline__ __device__ unsigned int optixGetAttribute_1(); /// Returns the attribute at slot 2. static __forceinline__ __device__ unsigned int optixGetAttribute_2(); /// Returns the attribute at slot 3. static __forceinline__ __device__ unsigned int optixGetAttribute_3(); /// Returns the attribute at slot 4. static __forceinline__ __device__ unsigned int optixGetAttribute_4(); /// Returns the attribute at slot 5. static __forceinline__ __device__ unsigned int optixGetAttribute_5(); /// Returns the attribute at slot 6. static __forceinline__ __device__ unsigned int optixGetAttribute_6(); /// Returns the attribute at slot 7. static __forceinline__ __device__ unsigned int optixGetAttribute_7(); /// Record the hit, stops traversal, and proceeds to CH. /// /// Available only in AH. static __forceinline__ __device__ void optixTerminateRay(); /// Discards the hit, and returns control to the calling optixReportIntersection or built-in intersection routine. /// /// Available only in AH. static __forceinline__ __device__ void optixIgnoreIntersection(); /// For a given OptixBuildInputTriangleArray the number of primitives is defined as /// "(OptixBuildInputTriangleArray::indexBuffer == 0) ? OptixBuildInputTriangleArray::numVertices/3 : /// OptixBuildInputTriangleArray::numIndexTriplets;". /// For a given OptixBuildInputCustomPrimitiveArray the number of primitives is defined as /// numAabbs. /// /// The primitive index returns the index into the array of primitives /// plus the primitiveIndexOffset. /// /// In IS and AH this corresponds to the currently intersected primitive. /// In CH this corresponds to the primitive index of the closest intersected primitive. static __forceinline__ __device__ unsigned int optixGetPrimitiveIndex(); /// Returns the Sbt GAS index of the primitive associated with the current intersection. /// /// In IS and AH this corresponds to the currently intersected primitive. 
/// In CH this corresponds to the Sbt GAS index of the closest intersected primitive. /// In EX with exception code OPTIX_EXCEPTION_CODE_TRAVERSAL_INVALID_HIT_SBT corresponds to the sbt index within the hit GAS. Returns zero for all other exceptions. static __forceinline__ __device__ unsigned int optixGetSbtGASIndex(); /// Returns the OptixInstance::instanceId of the instance within the top level acceleration structure associated with the current intersection. /// /// When building an acceleration structure using OptixBuildInputInstanceArray each OptixInstance has a user supplied instanceId. /// OptixInstance objects reference another acceleration structure. During traversal the acceleration structures are visited top down. /// In the IS and AH programs the OptixInstance::instanceId corresponding to the most recently visited OptixInstance is returned when calling optixGetInstanceId(). /// In CH optixGetInstanceId() returns the OptixInstance::instanceId when the hit was recorded with optixReportIntersection. /// In the case where there is no OptixInstance visited, optixGetInstanceId returns ~0u static __forceinline__ __device__ unsigned int optixGetInstanceId(); /// Returns the zero-based index of the instance within its instance acceleration structure associated with the current intersection. /// /// In the IS and AH programs the index corresponding to the most recently visited OptixInstance is returned when calling optixGetInstanceIndex(). /// In CH optixGetInstanceIndex() returns the index when the hit was recorded with optixReportIntersection. /// In the case where there is no OptixInstance visited, optixGetInstanceIndex returns 0 static __forceinline__ __device__ unsigned int optixGetInstanceIndex(); /// Returns the 8 bit hit kind associated with the current hit. /// /// Use optixGetPrimitiveType() to interpret the hit kind. 
/// For custom intersections (primitive type OPTIX_PRIMITIVE_TYPE_CUSTOM), /// this is the 7-bit hitKind passed to optixReportIntersection(). /// Hit kinds greater than 127 are reserved for built-in primitives. /// /// Available only in AH and CH. static __forceinline__ __device__ unsigned int optixGetHitKind(); /// Function interpreting the result of #optixGetHitKind(). static __forceinline__ __device__ OptixPrimitiveType optixGetPrimitiveType( unsigned int hitKind ); /// Function interpreting the result of #optixGetHitKind(). static __forceinline__ __device__ bool optixIsFrontFaceHit( unsigned int hitKind ); /// Function interpreting the result of #optixGetHitKind(). static __forceinline__ __device__ bool optixIsBackFaceHit( unsigned int hitKind ); /// Function interpreting the hit kind associated with the current optixReportIntersection. static __forceinline__ __device__ OptixPrimitiveType optixGetPrimitiveType(); /// Function interpreting the hit kind associated with the current optixReportIntersection. static __forceinline__ __device__ bool optixIsFrontFaceHit(); /// Function interpreting the hit kind associated with the current optixReportIntersection. static __forceinline__ __device__ bool optixIsBackFaceHit(); /// Convenience function interpreting the result of #optixGetHitKind(). static __forceinline__ __device__ bool optixIsTriangleHit(); /// Convenience function interpreting the result of #optixGetHitKind(). static __forceinline__ __device__ bool optixIsTriangleFrontFaceHit(); /// Convenience function interpreting the result of #optixGetHitKind(). static __forceinline__ __device__ bool optixIsTriangleBackFaceHit(); /// Convenience function that returns the first two attributes as floats. /// /// When using OptixBuildInputTriangleArray objects, during intersection the barycentric /// coordinates are stored into the first two attribute registers. 
static __forceinline__ __device__ float2 optixGetTriangleBarycentrics(); /// Convenience function that returns the curve parameter. /// /// When using OptixBuildInputCurveArray objects, during intersection the curve parameter /// is stored into the first attribute register. static __forceinline__ __device__ float optixGetCurveParameter(); /// Available in any program, it returns the current launch index within the launch dimensions specified by optixLaunch on the host. /// /// The raygen program is typically only launched once per launch index. static __forceinline__ __device__ uint3 optixGetLaunchIndex(); /// Available in any program, it returns the dimensions of the current launch specified by optixLaunch on the host. static __forceinline__ __device__ uint3 optixGetLaunchDimensions(); /// Returns the generic memory space pointer to the data region (past the header) of the currently active SBT record corresponding to the current program. static __forceinline__ __device__ CUdeviceptr optixGetSbtDataPointer(); /// Throws a user exception with the given exception code (overload without exception details). /// /// The exception code must be in the range from 0 to 2^30 - 1. Up to 8 optional exception details can be passed. They /// can be queried in the EX program using optixGetExceptionDetail_0() to ..._8(). /// /// The exception details must not be used to encode pointers to the stack since the current stack is not preserved in /// the EX program. /// /// Not available in EX. /// /// \param[in] exceptionCode The exception code to be thrown. static __forceinline__ __device__ void optixThrowException( int exceptionCode ); /// Throws a user exception with the given exception code (overload with 1 exception detail). /// /// \see #optixThrowException(int) static __forceinline__ __device__ void optixThrowException( int exceptionCode, unsigned int exceptionDetail0 ); /// Throws a user exception with the given exception code (overload with 2 exception details). 
/// /// \see #optixThrowException(int) static __forceinline__ __device__ void optixThrowException( int exceptionCode, unsigned int exceptionDetail0, unsigned int exceptionDetail1 ); /// Throws a user exception with the given exception code (overload with 3 exception details). /// /// \see #optixThrowException(int) static __forceinline__ __device__ void optixThrowException( int exceptionCode, unsigned int exceptionDetail0, unsigned int exceptionDetail1, unsigned int exceptionDetail2 ); /// Throws a user exception with the given exception code (overload with 4 exception details). /// /// \see #optixThrowException(int) static __forceinline__ __device__ void optixThrowException( int exceptionCode, unsigned int exceptionDetail0, unsigned int exceptionDetail1, unsigned int exceptionDetail2, unsigned int exceptionDetail3 ); /// Throws a user exception with the given exception code (overload with 5 exception details). /// /// \see #optixThrowException(int) static __forceinline__ __device__ void optixThrowException( int exceptionCode, unsigned int exceptionDetail0, unsigned int exceptionDetail1, unsigned int exceptionDetail2, unsigned int exceptionDetail3, unsigned int exceptionDetail4 ); /// Throws a user exception with the given exception code (overload with 6 exception details). /// /// \see #optixThrowException(int) static __forceinline__ __device__ void optixThrowException( int exceptionCode, unsigned int exceptionDetail0, unsigned int exceptionDetail1, unsigned int exceptionDetail2, unsigned int exceptionDetail3, unsigned int exceptionDetail4, unsigned int exceptionDetail5 ); /// Throws a user exception with the given exception code (overload with 7 exception details). 
/// /// \see #optixThrowException(int) static __forceinline__ __device__ void optixThrowException( int exceptionCode, unsigned int exceptionDetail0, unsigned int exceptionDetail1, unsigned int exceptionDetail2, unsigned int exceptionDetail3, unsigned int exceptionDetail4, unsigned int exceptionDetail5, unsigned int exceptionDetail6 ); /// Throws a user exception with the given exception code (overload with 8 exception details). /// /// \see #optixThrowException(int) static __forceinline__ __device__ void optixThrowException( int exceptionCode, unsigned int exceptionDetail0, unsigned int exceptionDetail1, unsigned int exceptionDetail2, unsigned int exceptionDetail3, unsigned int exceptionDetail4, unsigned int exceptionDetail5, unsigned int exceptionDetail6, unsigned int exceptionDetail7 ); /// Returns the exception code. /// /// Only available in EX. static __forceinline__ __device__ int optixGetExceptionCode(); /// Returns the 32-bit exception detail at slot 0. /// /// The behavior is undefined if the exception is not a user exception, or the used overload #optixThrowException() did /// not provide the queried exception detail. /// /// Only available in EX. static __forceinline__ __device__ unsigned int optixGetExceptionDetail_0(); /// Returns the 32-bit exception detail at slot 1. /// /// \see #optixGetExceptionDetail_0() static __forceinline__ __device__ unsigned int optixGetExceptionDetail_1(); /// Returns the 32-bit exception detail at slot 2. /// /// \see #optixGetExceptionDetail_0() static __forceinline__ __device__ unsigned int optixGetExceptionDetail_2(); /// Returns the 32-bit exception detail at slot 3. /// /// \see #optixGetExceptionDetail_0() static __forceinline__ __device__ unsigned int optixGetExceptionDetail_3(); /// Returns the 32-bit exception detail at slot 4. /// /// \see #optixGetExceptionDetail_0() static __forceinline__ __device__ unsigned int optixGetExceptionDetail_4(); /// Returns the 32-bit exception detail at slot 5. 
/// /// \see #optixGetExceptionDetail_0() static __forceinline__ __device__ unsigned int optixGetExceptionDetail_5(); /// Returns the 32-bit exception detail at slot 6. /// /// \see #optixGetExceptionDetail_0() static __forceinline__ __device__ unsigned int optixGetExceptionDetail_6(); /// Returns the 32-bit exception detail at slot 7. /// /// \see #optixGetExceptionDetail_0() static __forceinline__ __device__ unsigned int optixGetExceptionDetail_7(); /// Returns the invalid traversable handle for exceptions with exception code OPTIX_EXCEPTION_CODE_TRAVERSAL_INVALID_TRAVERSABLE. /// /// Returns zero for all other exception codes. /// /// Only available in EX. static __forceinline__ __device__ OptixTraversableHandle optixGetExceptionInvalidTraversable(); /// Returns the invalid sbt offset for exceptions with exception code OPTIX_EXCEPTION_CODE_TRAVERSAL_INVALID_MISS_SBT and OPTIX_EXCEPTION_CODE_TRAVERSAL_INVALID_HIT_SBT. /// /// Returns zero for all other exception codes. /// /// Only available in EX. static __forceinline__ __device__ int optixGetExceptionInvalidSbtOffset(); /// Returns the invalid ray for exceptions with exception code OPTIX_EXCEPTION_CODE_INVALID_RAY. /// Exceptions of type OPTIX_EXCEPTION_CODE_INVALID_RAY are thrown when one or more values that were /// passed into optixTrace are either inf or nan. /// /// OptixInvalidRayExceptionDetails::rayTime will always be 0 if OptixPipelineCompileOptions::usesMotionBlur is 0. /// Values in the returned struct are all zero for all other exception codes. /// /// Only available in EX. static __forceinline__ __device__ OptixInvalidRayExceptionDetails optixGetExceptionInvalidRay(); /// Returns information about an exception with code OPTIX_EXCEPTION_CODE_CALLABLE_PARAMETER_MISMATCH. 
/// /// Exceptions of type OPTIX_EXCEPTION_CODE_CALLABLE_PARAMETER_MISMATCH are called when the number of /// arguments that were passed into a call to optixDirectCall or optixContinuationCall does not match /// the number of parameters of the callable that is called. /// Note that the parameters are packed by OptiX into individual 32 bit values, so the number of /// expected and passed values may not correspond to the number of arguments passed into optixDirectCall /// or optixContinuationCall. /// /// Values in the returned struct are all zero for all other exception codes. /// /// Only available in EX. static __forceinline__ __device__ OptixParameterMismatchExceptionDetails optixGetExceptionParameterMismatch(); /// Returns a string that includes information about the source location that caused the current exception. /// /// The source location is only available for exceptions of type OPTIX_EXCEPTION_CODE_CALLABLE_PARAMETER_MISMATCH, /// OPTIX_EXCEPTION_CODE_UNSUPPORTED_PRIMITIVE_TYPE, OPTIX_EXCEPTION_CODE_INVALID_RAY, and for user exceptions. /// Line information needs to be present in the input PTX and OptixModuleCompileOptions::debugLevel /// may not be set to OPTIX_COMPILE_DEBUG_LEVEL_NONE. /// /// Returns a NULL pointer if no line information is available. /// /// Only available in EX. static __forceinline__ __device__ char* optixGetExceptionLineInfo(); /// Creates a call to the direct callable program at the specified SBT entry. /// /// This will call the program that was specified in the OptixProgramGroupCallables::entryFunctionNameDC in the /// module specified by OptixProgramGroupCallables::moduleDC. /// The address of the SBT entry is calculated by OptixShaderBindingTable::callablesRecordBase + ( OptixShaderBindingTable::callablesRecordStrideInBytes * sbtIndex ). /// /// Behavior is undefined if there is no direct callable program at the specified SBT entry. 
/// /// Behavior is undefined if the number of arguments that are being passed in does not match the number of /// parameters expected by the program that is called. In that case an exception of type OPTIX_EXCEPTION_CODE_CALLABLE_PARAMETER_MISMATCH /// will be thrown if OPTIX_EXCEPTION_FLAG_DEBUG was specified for the OptixPipelineCompileOptions::exceptionFlags. /// /// \param[in] sbtIndex The offset of the SBT entry of the direct callable program to call relative to OptixShaderBindingTable::callablesRecordBase. /// \param[in] args The arguments to pass to the direct callable program. template static __forceinline__ __device__ ReturnT optixDirectCall( unsigned int sbtIndex, ArgTypes... args ); /// Creates a call to the continuation callable program at the specified SBT entry. /// /// This will call the program that was specified in the OptixProgramGroupCallables::entryFunctionNameCC in the /// module specified by OptixProgramGroupCallables::moduleCC. /// The address of the SBT entry is calculated by OptixShaderBindingTable::callablesRecordBase + ( OptixShaderBindingTable::callablesRecordStrideInBytes * sbtIndex ). /// As opposed to direct callable programs, continuation callable programs are allowed to call optixTrace recursively. /// /// Behavior is undefined if there is no continuation callable program at the specified SBT entry. /// /// Behavior is undefined if the number of arguments that are being passed in does not match the number of /// parameters expected by the program that is called. In that case an exception of type OPTIX_EXCEPTION_CODE_CALLABLE_PARAMETER_MISMATCH /// will be thrown if OPTIX_EXCEPTION_FLAG_DEBUG was specified for the OptixPipelineCompileOptions::exceptionFlags. /// /// \param[in] sbtIndex The offset of the SBT entry of the continuation callable program to call relative to OptixShaderBindingTable::callablesRecordBase. /// \param[in] args The arguments to pass to the continuation callable program. 
template static __forceinline__ __device__ ReturnT optixContinuationCall( unsigned int sbtIndex, ArgTypes... args ); /// optixTexFootprint2D calculates the footprint of a corresponding 2D texture fetch (non-mipmapped). /// /// On Turing and subsequent architectures, a texture footprint instruction allows user programs to /// determine the set of texels that would be accessed by an equivalent filtered texture lookup. /// /// \param[in] tex CUDA texture object (cast to 64-bit integer) /// \param[in] texInfo Texture info packed into 32-bit integer, described below. /// \param[in] x Texture coordinate /// \param[in] y Texture coordinate /// \param[out] singleMipLevel Result indicating whether the footprint spans only a single miplevel. /// /// The texture info argument is a packed 32-bit integer with the following layout: /// /// texInfo[31:29] = reserved (3 bits) /// texInfo[28:24] = miplevel count (5 bits) /// texInfo[23:20] = log2 of tile width (4 bits) /// texInfo[19:16] = log2 of tile height (4 bits) /// texInfo[15:10] = reserved (6 bits) /// texInfo[9:8] = horizontal wrap mode (2 bits) (CUaddress_mode) /// texInfo[7:6] = vertical wrap mode (2 bits) (CUaddress_mode) /// texInfo[5] = mipmap filter mode (1 bit) (CUfilter_mode) /// texInfo[4:0] = maximum anisotropy (5 bits) /// /// Returns a 16-byte structure (as a uint4) that stores the footprint of a texture request at a /// particular "granularity", which has the following layout: /// /// struct Texture2DFootprint /// { /// unsigned long long mask; /// unsigned int tileY : 12; /// unsigned int reserved1 : 4; /// unsigned int dx : 3; /// unsigned int dy : 3; /// unsigned int reserved2 : 2; /// unsigned int granularity : 4; /// unsigned int reserved3 : 4; /// unsigned int tileX : 12; /// unsigned int level : 4; /// unsigned int reserved4 : 16; /// }; /// /// The granularity indicates the size of texel groups that are represented by an 8x8 bitmask. 
For /// example, a granularity of 12 indicates texel groups that are 128x64 texels in size. In a /// footprint call, The returned granularity will either be the actual granularity of the result, or /// 0 if the footprint call was able to honor the requested granularity (the usual case). /// /// level is the mip level of the returned footprint. Two footprint calls are needed to get the /// complete footprint when a texture call spans multiple mip levels. /// /// mask is an 8x8 bitmask of texel groups that are covered, or partially covered, by the footprint. /// tileX and tileY give the starting position of the mask in 8x8 texel-group blocks. For example, /// suppose a granularity of 12 (128x64 texels), and tileX=3 and tileY=4. In this case, bit 0 of the /// mask (the low order bit) corresponds to texel group coordinates (3*8, 4*8), and texel /// coordinates (3*8*128, 4*8*64), within the specified mip level. /// /// If nonzero, dx and dy specify a "toroidal rotation" of the bitmask. Toroidal rotation of a /// coordinate in the mask simply means that its value is reduced by 8. Continuing the example from /// above, if dx=0 and dy=0 the mask covers texel groups (3*8, 4*8) to (3*8+7, 4*8+7) inclusive. /// If, on the other hand, dx=2, the rightmost 2 columns in the mask have their x coordinates /// reduced by 8, and similarly for dy. /// /// See the OptiX SDK for sample code that illustrates how to unpack the result. static __forceinline__ __device__ uint4 optixTexFootprint2D( unsigned long long tex, unsigned int texInfo, float x, float y, unsigned int* singleMipLevel ); /// optixTexFootprint2DLod calculates the footprint of a corresponding 2D texture fetch (tex2DLod) /// \param[in] tex CUDA texture object (cast to 64-bit integer) /// \param[in] texInfo Texture info packed into 32-bit integer, described below. 
/// \param[in] x Texture coordinate /// \param[in] y Texture coordinate /// \param[in] level Level of detail (lod) /// \param[in] coarse Requests footprint from coarse miplevel, when the footprint spans two levels. /// \param[out] singleMipLevel Result indicating whether the footprint spans only a single miplevel. /// \see #optixTexFootprint2D(unsigned long long,unsigned int,float,float,unsigned int*) static __forceinline__ __device__ uint4 optixTexFootprint2DLod( unsigned long long tex, unsigned int texInfo, float x, float y, float level, bool coarse, unsigned int* singleMipLevel ); /// optixTexFootprint2DGrad calculates the footprint of a corresponding 2D texture fetch (tex2DGrad) /// \param[in] tex CUDA texture object (cast to 64-bit integer) /// \param[in] texInfo Texture info packed into 32-bit integer, described below. /// \param[in] x Texture coordinate /// \param[in] y Texture coordinate /// \param[in] dPdx_x Derivative of x coordinate, which determines level of detail. /// \param[in] dPdx_y Derivative of x coordinate, which determines level of detail. /// \param[in] dPdy_x Derivative of y coordinate, which determines level of detail. /// \param[in] dPdy_y Derivative of y coordinate, which determines level of detail. /// \param[in] coarse Requests footprint from coarse miplevel, when the footprint spans two levels. /// \param[out] singleMipLevel Result indicating whether the footprint spans only a single miplevel. 
/// \see #optixTexFootprint2D(unsigned long long,unsigned int,float,float,unsigned int*) static __forceinline__ __device__ uint4 optixTexFootprint2DGrad( unsigned long long tex, unsigned int texInfo, float x, float y, float dPdx_x, float dPdx_y, float dPdy_x, float dPdy_y, bool coarse, unsigned int* singleMipLevel ); /*@}*/ // end group optix_device_api #include "internal/optix_7_device_impl.h" #endif // __optix_optix_7_device_h__ ================================================ FILE: render/optixutils/include/optix_7_host.h ================================================ /* * Copyright (c) 2021 NVIDIA Corporation. All rights reserved. * * NVIDIA Corporation and its licensors retain all intellectual property and proprietary * rights in and to this software, related documentation and any modifications thereto. * Any use, reproduction, disclosure or distribution of this software and related * documentation without an express license agreement from NVIDIA Corporation is strictly * prohibited. * * TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, THIS SOFTWARE IS PROVIDED *AS IS* * AND NVIDIA AND ITS SUPPLIERS DISCLAIM ALL WARRANTIES, EITHER EXPRESS OR IMPLIED, * INCLUDING, BUT NOT LIMITED TO, IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A * PARTICULAR PURPOSE. IN NO EVENT SHALL NVIDIA OR ITS SUPPLIERS BE LIABLE FOR ANY * SPECIAL, INCIDENTAL, INDIRECT, OR CONSEQUENTIAL DAMAGES WHATSOEVER (INCLUDING, WITHOUT * LIMITATION, DAMAGES FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF * BUSINESS INFORMATION, OR ANY OTHER PECUNIARY LOSS) ARISING OUT OF THE USE OF OR * INABILITY TO USE THIS SOFTWARE, EVEN IF NVIDIA HAS BEEN ADVISED OF THE POSSIBILITY OF * SUCH DAMAGES */ /// @file /// @author NVIDIA Corporation /// @brief OptiX public API header /// /// OptiX host include file -- includes the host api if compiling host code. 
/// For the math library routines include optix_math.h #if !defined( __OPTIX_INCLUDE_INTERNAL_HEADERS__ ) #error("optix_7_host.h is an internal header file and must not be used directly. Please use optix_host.h or optix.h instead.") #endif #ifndef __optix_optix_7_host_h__ #define __optix_optix_7_host_h__ #include "optix_7_types.h" #if !defined( OPTIX_DONT_INCLUDE_CUDA ) // If OPTIX_DONT_INCLUDE_CUDA is defined, cuda driver types must be defined through other // means before including optix headers. #include #endif #ifdef __cplusplus extern "C" { #endif /// \defgroup optix_host_api Host API /// \brief OptiX Host API /// \defgroup optix_host_api_error_handling Error handling /// \ingroup optix_host_api //@{ /// Returns a string containing the name of an error code in the enum. /// /// Output is a string representation of the enum. For example "OPTIX_SUCCESS" for /// OPTIX_SUCCESS and "OPTIX_ERROR_INVALID_VALUE" for OPTIX_ERROR_INVALID_VALUE. /// /// If the error code is not recognized, "Unrecognized OptixResult code" is returned. /// /// \param[in] result OptixResult enum to generate string name for /// /// \see #optixGetErrorString const char* optixGetErrorName( OptixResult result ); /// Returns the description string for an error code. /// /// Output is a string description of the enum. For example "Success" for /// OPTIX_SUCCESS and "Invalid value" for OPTIX_ERROR_INVALID_VALUE. /// /// If the error code is not recognized, "Unrecognized OptixResult code" is returned. /// /// \param[in] result OptixResult enum to generate string description for /// /// \see #optixGetErrorName const char* optixGetErrorString( OptixResult result ); //@} /// \defgroup optix_host_api_device_context Device context /// \ingroup optix_host_api //@{ /// Create a device context associated with the CUDA context specified with 'fromContext'. /// /// If zero is specified for 'fromContext', OptiX will use the current CUDA context. 
The /// CUDA context should be initialized before calling optixDeviceContextCreate. /// /// \param[in] fromContext /// \param[in] options /// \param[out] context /// \return /// - OPTIX_ERROR_CUDA_NOT_INITIALIZED /// If using zero for 'fromContext' and CUDA has not been initialized yet on the calling /// thread. /// - OPTIX_ERROR_CUDA_ERROR /// CUDA operation failed. /// - OPTIX_ERROR_HOST_OUT_OF_MEMORY /// Heap allocation failed. /// - OPTIX_ERROR_INTERNAL_ERROR /// Internal error OptixResult optixDeviceContextCreate( CUcontext fromContext, const OptixDeviceContextOptions* options, OptixDeviceContext* context ); /// Destroys all CPU and GPU state associated with the device. /// /// It will attempt to block on CUDA streams that have launch work outstanding. /// /// Any API objects, such as OptixModule and OptixPipeline, not already destroyed will be /// destroyed. /// /// Thread safety: A device context must not be destroyed while it is still in use by concurrent API calls in other threads. OptixResult optixDeviceContextDestroy( OptixDeviceContext context ); /// Query properties of a device context. /// /// \param[in] context the device context to query the property for /// \param[in] property the property to query /// \param[out] value pointer to the returned /// \param[in] sizeInBytes size of output OptixResult optixDeviceContextGetProperty( OptixDeviceContext context, OptixDeviceProperty property, void* value, size_t sizeInBytes ); /// Sets the current log callback method. /// /// See #OptixLogCallback for more details. /// /// Thread safety: It is guaranteed that the callback itself (callbackFunction and callbackData) are updated atomically. /// It is not guaranteed that the callback itself (callbackFunction and callbackData) and the callbackLevel are updated /// atomically. It is unspecified when concurrent API calls using the same context start to make use of the new /// callback method. 
/// /// \param[in] context the device context /// \param[in] callbackFunction the callback function to call /// \param[in] callbackData pointer to data passed to callback function while invoking it /// \param[in] callbackLevel callback level OptixResult optixDeviceContextSetLogCallback( OptixDeviceContext context, OptixLogCallback callbackFunction, void* callbackData, unsigned int callbackLevel ); /// Enables or disables the disk cache. /// /// If caching was previously disabled, enabling it will attempt to initialize /// the disk cache database using the currently configured cache location. An /// error will be returned if initialization fails. /// /// Note that no in-memory cache is used, so no caching behavior will be observed if the disk cache /// is disabled. /// /// The cache can be disabled by setting the environment variable OPTIX_CACHE_MAXSIZE=0. /// The environment variable takes precedence over this setting. /// See #optixDeviceContextSetCacheDatabaseSizes for additional information. /// /// Note that the disk cache can be disabled by the environment variable, but it cannot be enabled /// via the environment if it is disabled via the API. /// /// \param[in] context the device context /// \param[in] enabled 1 to enable, 0 to disable OptixResult optixDeviceContextSetCacheEnabled( OptixDeviceContext context, int enabled ); /// Sets the location of the disk cache. /// /// The location is specified by a directory. This directory should not be used for other purposes /// and will be created if it does not exist. An error will be returned if it is not possible to /// create the disk cache at the specified location for any reason (e.g., the path is invalid or /// the directory is not writable). Caching will be disabled if the disk cache cannot be /// initialized in the new location. If caching is disabled, no error will be returned until caching /// is enabled. If the disk cache is located on a network file share, behavior is undefined. 
/// /// The location of the disk cache can be overridden with the environment variable OPTIX_CACHE_PATH. /// The environment variable takes precedence over this setting. /// /// The default location depends on the operating system: /// - Windows: %LOCALAPPDATA%\\NVIDIA\\OptixCache /// - Linux: /var/tmp/OptixCache_\ (or /tmp/OptixCache_\ if the first choice is not usable), /// the underscore and username suffix are omitted if the username cannot be obtained /// - MacOS X: /Library/Application Support/NVIDIA/OptixCache /// /// \param[in] context the device context /// \param[in] location directory of disk cache OptixResult optixDeviceContextSetCacheLocation( OptixDeviceContext context, const char* location ); /// Sets the low and high water marks for disk cache garbage collection. /// /// Garbage collection is triggered when a new entry is written to the cache and /// the current cache data size plus the size of the cache entry that is about /// to be inserted exceeds the high water mark. Garbage collection proceeds until /// the size reaches the low water mark. Garbage collection will always free enough /// space to insert the new entry without exceeding the low water mark. Setting /// either limit to zero will disable garbage collection. An error will be returned /// if both limits are non-zero and the high water mark is smaller than the low water mark. /// /// Note that garbage collection is performed only on writes to the disk cache. No garbage /// collection is triggered on disk cache initialization or immediately when calling this function, /// but on subsequent inserting of data into the database. /// /// If the size of a compiled module exceeds the value configured for the high water /// mark and garbage collection is enabled, the module will not be added to the cache /// and a warning will be added to the log. /// /// The high water mark can be overridden with the environment variable OPTIX_CACHE_MAXSIZE. 
/// The environment variable takes precedence over the function parameters. The low water mark /// will be set to half the value of OPTIX_CACHE_MAXSIZE. Setting OPTIX_CACHE_MAXSIZE to 0 will /// disable the disk cache, but will not alter the contents of the cache. Negative and non-integer /// values will be ignored. /// /// \param[in] context the device context /// \param[in] lowWaterMark the low water mark /// \param[in] highWaterMark the high water mark OptixResult optixDeviceContextSetCacheDatabaseSizes( OptixDeviceContext context, size_t lowWaterMark, size_t highWaterMark ); /// Indicates whether the disk cache is enabled or disabled. /// /// \param[in] context the device context /// \param[out] enabled 1 if enabled, 0 if disabled OptixResult optixDeviceContextGetCacheEnabled( OptixDeviceContext context, int* enabled ); /// Returns the location of the disk cache. If the cache has been disabled by setting the environment /// variable OPTIX_CACHE_MAXSIZE=0, this function will return an empty string. /// /// \param[in] context the device context /// \param[out] location directory of disk cache, null terminated if locationSize > 0 /// \param[in] locationSize locationSize OptixResult optixDeviceContextGetCacheLocation( OptixDeviceContext context, char* location, size_t locationSize ); /// Returns the low and high water marks for disk cache garbage collection. If the cache has been disabled by /// setting the environment variable OPTIX_CACHE_MAXSIZE=0, this function will return 0 for the low and high /// water marks. /// /// \param[in] context the device context /// \param[out] lowWaterMark the low water mark /// \param[out] highWaterMark the high water mark OptixResult optixDeviceContextGetCacheDatabaseSizes( OptixDeviceContext context, size_t* lowWaterMark, size_t* highWaterMark ); //@} /// \defgroup optix_host_api_pipelines Pipelines /// \ingroup optix_host_api //@{ /// logString is an optional buffer that contains compiler feedback and errors. 
This /// information is also passed to the context logger (if enabled), however it may be /// difficult to correlate output to the logger to specific API invocations when using /// multiple threads. The output to logString will only contain feedback for this specific /// invocation of this API call. /// /// logStringSize as input should be a pointer to the number of bytes backing logString. /// Upon return it contains the length of the log message (including the null terminator) /// which may be greater than the input value. In this case, the log message will be /// truncated to fit into logString. /// /// If logString or logStringSize are NULL, no output is written to logString. If /// logStringSize points to a value that is zero, no output is written. This does not /// affect output to the context logger if enabled. /// /// \param[in] context /// \param[in] pipelineCompileOptions /// \param[in] pipelineLinkOptions /// \param[in] programGroups array of ProgramGroup objects /// \param[in] numProgramGroups number of ProgramGroup objects /// \param[out] logString Information will be written to this string. If logStringSize > 0 logString will be null terminated. /// \param[in,out] logStringSize /// \param[out] pipeline OptixResult optixPipelineCreate( OptixDeviceContext context, const OptixPipelineCompileOptions* pipelineCompileOptions, const OptixPipelineLinkOptions* pipelineLinkOptions, const OptixProgramGroup* programGroups, unsigned int numProgramGroups, char* logString, size_t* logStringSize, OptixPipeline* pipeline ); /// Thread safety: A pipeline must not be destroyed while it is still in use by concurrent API calls in other threads. OptixResult optixPipelineDestroy( OptixPipeline pipeline ); /// Sets the stack sizes for a pipeline. /// /// Users are encouraged to see the programming guide and the implementations of the helper functions /// to understand how to construct the stack sizes based on their particular needs. 
/// /// If this method is not used, an internal default implementation is used. The default implementation is correct (but /// not necessarily optimal) as long as the maximum depth of call trees of CC and DC programs is at most 2 and no motion transforms are used. /// /// The maxTraversableGraphDepth responds to the maximal number of traversables visited when calling trace. /// Every acceleration structure and motion transform count as one level of traversal. /// E.g., for a simple IAS (instance acceleration structure) -> GAS (geometry acceleration structure) /// traversal graph, the maxTraversableGraphDepth is two. /// For IAS -> MT (motion transform) -> GAS, the maxTraversableGraphDepth is three. /// Note that it does not matter whether a IAS or GAS has motion or not, it always counts as one. /// Launching optix with exceptions turned on (see #OPTIX_EXCEPTION_FLAG_TRACE_DEPTH) will throw an exception /// if the specified maxTraversableGraphDepth is too small. /// /// \param[in] pipeline The pipeline to configure the stack size for. /// \param[in] directCallableStackSizeFromTraversal The direct stack size requirement for direct callables invoked from IS or AH. /// \param[in] directCallableStackSizeFromState The direct stack size requirement for direct callables invoked from RG, MS, or CH. /// \param[in] continuationStackSize The continuation stack requirement. /// \param[in] maxTraversableGraphDepth The maximum depth of a traversable graph passed to trace. OptixResult optixPipelineSetStackSize( OptixPipeline pipeline, unsigned int directCallableStackSizeFromTraversal, unsigned int directCallableStackSizeFromState, unsigned int continuationStackSize, unsigned int maxTraversableGraphDepth ); //@} /// \defgroup optix_host_api_modules Modules /// \ingroup optix_host_api //@{ /// logString is an optional buffer that contains compiler feedback and errors. 
This /// information is also passed to the context logger (if enabled), however it may be /// difficult to correlate output to the logger to specific API invocations when using /// multiple threads. The output to logString will only contain feedback for this specific /// invocation of this API call. /// /// logStringSize as input should be a pointer to the number of bytes backing logString. /// Upon return it contains the length of the log message (including the null terminator) /// which may be greater than the input value. In this case, the log message will be /// truncated to fit into logString. /// /// If logString or logStringSize are NULL, no output is written to logString. If /// logStringSize points to a value that is zero, no output is written. This does not /// affect output to the context logger if enabled. /// /// \param[in] context /// \param[in] moduleCompileOptions /// \param[in] pipelineCompileOptions All modules in a pipeline need to use the same values for the pipeline compile options. /// \param[in] PTX Pointer to the PTX input string. /// \param[in] PTXsize Parsing proceeds up to PTXsize characters, or the first NUL byte, whichever occurs first. /// \param[out] logString Information will be written to this string. If logStringSize > 0 logString will be null terminated. /// \param[in,out] logStringSize /// \param[out] module /// /// \return OPTIX_ERROR_INVALID_VALUE - context is 0, moduleCompileOptions is 0, pipelineCompileOptions is 0, PTX is 0, module is 0. OptixResult optixModuleCreateFromPTX( OptixDeviceContext context, const OptixModuleCompileOptions* moduleCompileOptions, const OptixPipelineCompileOptions* pipelineCompileOptions, const char* PTX, size_t PTXsize, char* logString, size_t* logStringSize, OptixModule* module ); /// Call for OptixModule objects created with optixModuleCreateFromPTX and optixModuleDeserialize. /// /// Modules must not be destroyed while they are still used by any program group. 
/// /// Thread safety: A module must not be destroyed while it is still in use by concurrent API calls in other threads. OptixResult optixModuleDestroy( OptixModule module ); /// Returns a module containing the intersection program for the built-in primitive type specified /// by the builtinISOptions. This module must be used as the moduleIS for the OptixProgramGroupHitgroup /// in any SBT record for that primitive type. (The entryFunctionNameIS should be null.) OptixResult optixBuiltinISModuleGet( OptixDeviceContext context, const OptixModuleCompileOptions* moduleCompileOptions, const OptixPipelineCompileOptions* pipelineCompileOptions, const OptixBuiltinISOptions* builtinISOptions, OptixModule* builtinModule ); //@} /// \defgroup optix_host_api_program_groups Program groups /// \ingroup optix_host_api //@{ /// Returns the stack sizes for the given program group. /// /// \param[in] programGroup the program group /// \param[out] stackSizes the corresponding stack sizes OptixResult optixProgramGroupGetStackSize( OptixProgramGroup programGroup, OptixStackSizes* stackSizes ); /// logString is an optional buffer that contains compiler feedback and errors. This /// information is also passed to the context logger (if enabled), however it may be /// difficult to correlate output to the logger to specific API invocations when using /// multiple threads. The output to logString will only contain feedback for this specific /// invocation of this API call. /// /// logStringSize as input should be a pointer to the number of bytes backing logString. /// Upon return it contains the length of the log message (including the null terminator) /// which may be greater than the input value. In this case, the log message will be /// truncated to fit into logString. /// /// If logString or logStringSize are NULL, no output is written to logString. If /// logStringSize points to a value that is zero, no output is written. This does not /// affect output to the context logger if enabled. 
/// /// Creates numProgramGroups OptixProgramGroup objects from the specified /// OptixProgramGroupDesc array. The size of the arrays must match. /// /// \param[in] context /// \param[in] programDescriptions N * OptixProgramGroupDesc /// \param[in] numProgramGroups N /// \param[in] options /// \param[out] logString Information will be written to this string. If logStringSize > 0 logString will be null terminated. /// \param[in,out] logStringSize /// \param[out] programGroups OptixResult optixProgramGroupCreate( OptixDeviceContext context, const OptixProgramGroupDesc* programDescriptions, unsigned int numProgramGroups, const OptixProgramGroupOptions* options, char* logString, size_t* logStringSize, OptixProgramGroup* programGroups ); /// Thread safety: A program group must not be destroyed while it is still in use by concurrent API calls in other threads. OptixResult optixProgramGroupDestroy( OptixProgramGroup programGroup ); //@} /// \defgroup optix_host_api_launches Launches /// \ingroup optix_host_api //@{ /// Where the magic happens. /// /// The stream and pipeline must belong to the same device context. Multiple launches /// may be issued in parallel from multiple threads to different streams. /// /// pipelineParamsSize number of bytes are copied from the device memory pointed to by /// pipelineParams before launch. It is an error if pipelineParamsSize is greater than the /// size of the variable declared in modules and identified by /// OptixPipelineCompileOptions::pipelineLaunchParamsVariableName. If the launch params /// variable was optimized out or not found in the modules linked to the pipeline then /// the pipelineParams and pipelineParamsSize parameters are ignored. /// /// sbt points to the shader binding table, which defines shader /// groupings and their resources. See the SBT spec. 
/// /// \param[in] pipeline /// \param[in] stream /// \param[in] pipelineParams /// \param[in] pipelineParamsSize /// \param[in] sbt /// \param[in] width number of elements to compute /// \param[in] height number of elements to compute /// \param[in] depth number of elements to compute /// /// Thread safety: In the current implementation concurrent launches to the same pipeline are not /// supported. Concurrent launches require separate OptixPipeline objects. OptixResult optixLaunch( OptixPipeline pipeline, CUstream stream, CUdeviceptr pipelineParams, size_t pipelineParamsSize, const OptixShaderBindingTable* sbt, unsigned int width, unsigned int height, unsigned int depth ); /// \param[in] programGroup the program group containing the program(s) /// \param[out] sbtRecordHeaderHostPointer the result sbt record header OptixResult optixSbtRecordPackHeader( OptixProgramGroup programGroup, void* sbtRecordHeaderHostPointer ); //@} /// \defgroup optix_host_api_acceleration_structures Acceleration structures /// \ingroup optix_host_api //@{ /// \param[in] context device context of the pipeline /// \param[in] accelOptions accel options /// \param[in] buildInputs an array of OptixBuildInput objects /// \param[in] numBuildInputs number of elements in buildInputs (must be at least 1) /// \param[out] bufferSizes fills in buffer sizes OptixResult optixAccelComputeMemoryUsage( OptixDeviceContext context, const OptixAccelBuildOptions* accelOptions, const OptixBuildInput* buildInputs, unsigned int numBuildInputs, OptixAccelBufferSizes* bufferSizes ); /// \param[in] context /// \param[in] stream /// \param[in] accelOptions accel options /// \param[in] buildInputs an array of OptixBuildInput objects /// \param[in] numBuildInputs must be >= 1 for GAS, and == 1 for IAS /// \param[in] tempBuffer must be a multiple of OPTIX_ACCEL_BUFFER_BYTE_ALIGNMENT /// \param[in] tempBufferSizeInBytes /// \param[in] outputBuffer must be a multiple of OPTIX_ACCEL_BUFFER_BYTE_ALIGNMENT /// \param[in] 
outputBufferSizeInBytes /// \param[out] outputHandle /// \param[out] emittedProperties types of requested properties and output buffers /// \param[in] numEmittedProperties number of post-build properties to populate (may be zero) OptixResult optixAccelBuild( OptixDeviceContext context, CUstream stream, const OptixAccelBuildOptions* accelOptions, const OptixBuildInput* buildInputs, unsigned int numBuildInputs, CUdeviceptr tempBuffer, size_t tempBufferSizeInBytes, CUdeviceptr outputBuffer, size_t outputBufferSizeInBytes, OptixTraversableHandle* outputHandle, const OptixAccelEmitDesc* emittedProperties, unsigned int numEmittedProperties ); /// Obtain relocation information, stored in OptixAccelRelocationInfo, for a given context /// and acceleration structure's traversable handle. /// /// The relocation information can be passed to optixAccelCheckRelocationCompatibility to /// determine if an acceleration structure, referenced by 'handle', can be relocated to a /// different device's memory space (see #optixAccelCheckRelocationCompatibility). /// /// When used with optixAccelRelocate, it provides data necessary for doing the relocation. /// /// If the acceleration structure data associated with 'handle' is copied multiple times, /// the same OptixAccelRelocationInfo can also be used on all copies. /// /// \param[in] context /// \param[in] handle /// \param[out] info /// \return OPTIX_ERROR_INVALID_VALUE will be returned for traversable handles that are not from /// acceleration structure builds. OptixResult optixAccelGetRelocationInfo( OptixDeviceContext context, OptixTraversableHandle handle, OptixAccelRelocationInfo* info ); /// Checks if an acceleration structure built using another OptixDeviceContext (that was /// used to fill in 'info') is compatible with the OptixDeviceContext specified in the /// 'context' parameter. /// /// Any device is always compatible with itself. 
/// /// \param[in] context /// \param[in] info /// \param[out] compatible If OPTIX_SUCCESS is returned 'compatible' will have the value of either: /// - 0: This context is not compatible with acceleration structure data associated with 'info'. /// - 1: This context is compatible. OptixResult optixAccelCheckRelocationCompatibility( OptixDeviceContext context, const OptixAccelRelocationInfo* info, int* compatible ); /// optixAccelRelocate is called to update the acceleration structure after it has been /// relocated. Relocation is necessary when the acceleration structure's location in device /// memory has changed. optixAccelRelocate does not copy the memory. This function only /// operates on the relocated memory whose new location is specified by 'targetAccel'. /// optixAccelRelocate also returns the new OptixTraversableHandle associated with /// 'targetAccel'. The original memory (source) is not required to be valid, only the /// OptixAccelRelocationInfo. /// /// Before copying the data and calling optixAccelRelocate, /// optixAccelCheckRelocationCompatibility should be called to ensure the copy will be /// compatible with the destination device context. /// /// The memory pointed to by 'targetAccel' should be allocated with the same size as the /// source acceleration. Similar to the 'outputBuffer' used in optixAccelBuild, this /// pointer must be a multiple of OPTIX_ACCEL_BUFFER_BYTE_ALIGNMENT. /// /// The memory in 'targetAccel' must be allocated as long as the accel is in use. /// /// When relocating an accel that contains instances, 'instanceTraversableHandles' and /// 'numInstanceTraversableHandles' should be supplied. These are the traversable handles /// of the instances. These can be used when also relocating the instances. No updates to /// the bounds are performed. Use optixAccelBuild to update the bounds. /// 'instanceTraversableHandles' and 'numInstanceTraversableHandles' may be zero when /// relocating bottom level accel (i.e. 
an accel with no instances). /// /// \param[in] context /// \param[in] stream /// \param[in] info /// \param[in] instanceTraversableHandles /// \param[in] numInstanceTraversableHandles /// \param[in] targetAccel /// \param[in] targetAccelSizeInBytes /// \param[out] targetHandle OptixResult optixAccelRelocate( OptixDeviceContext context, CUstream stream, const OptixAccelRelocationInfo* info, CUdeviceptr instanceTraversableHandles, size_t numInstanceTraversableHandles, CUdeviceptr targetAccel, size_t targetAccelSizeInBytes, OptixTraversableHandle* targetHandle ); /// After building an acceleration structure, it can be copied in a compacted form to reduce /// memory. In order to be compacted, OPTIX_BUILD_FLAG_ALLOW_COMPACTION must be supplied in /// OptixAccelBuildOptions::buildFlags passed to optixAccelBuild. /// /// 'outputBuffer' is the pointer to where the compacted acceleration structure will be /// written. This pointer must be a multiple of OPTIX_ACCEL_BUFFER_BYTE_ALIGNMENT. /// /// The size of the memory specified in 'outputBufferSizeInBytes' should be at least the /// value computed using the OPTIX_PROPERTY_TYPE_COMPACTED_SIZE that was reported during /// optixAccelBuild. /// /// \param[in] context /// \param[in] stream /// \param[in] inputHandle /// \param[in] outputBuffer /// \param[in] outputBufferSizeInBytes /// \param[out] outputHandle OptixResult optixAccelCompact( OptixDeviceContext context, CUstream stream, OptixTraversableHandle inputHandle, CUdeviceptr outputBuffer, size_t outputBufferSizeInBytes, OptixTraversableHandle* outputHandle ); /// \param[in] onDevice /// \param[in] pointer pointer to traversable allocated in OptixDeviceContext. This pointer must be a multiple of OPTIX_TRANSFORM_BYTE_ALIGNMENT /// \param[in] traversableType Type of OptixTraversableHandle to create /// \param[out] traversableHandle traversable handle. 
traversableHandle must be in host memory OptixResult optixConvertPointerToTraversableHandle( OptixDeviceContext onDevice, CUdeviceptr pointer, OptixTraversableType traversableType, OptixTraversableHandle* traversableHandle ); //@} /// \defgroup optix_host_api_denoiser Denoiser /// \ingroup optix_host_api //@{ /// Creates a denoiser object with the given options, using built-in inference models /// /// 'modelKind' selects the model used for inference. /// Inference for the built-in models can be guided (giving hints to improve image quality) with /// albedo and normal vector images in the guide layer (see 'optixDenoiserInvoke'). /// Use of these images must be enabled in 'OptixDenoiserOptions'. /// /// \param[in] context /// \param[in] modelKind /// \param[in] options /// \param[out] denoiser OptixResult optixDenoiserCreate( OptixDeviceContext context, OptixDenoiserModelKind modelKind, const OptixDenoiserOptions* options, OptixDenoiser* denoiser ); /// Creates a denoiser object with the given options, using a provided inference model /// /// 'userData' and 'userDataSizeInBytes' provide a user model for inference. /// The memory passed in userData will be accessed only during the invocation of this function and /// can be freed after it returns. /// The user model must export only one weight set which determines both the model kind and the /// required set of guide images. /// /// \param[in] context /// \param[in] userData /// \param[in] userDataSizeInBytes /// \param[out] denoiser OptixResult optixDenoiserCreateWithUserModel( OptixDeviceContext context, const void* userData, size_t userDataSizeInBytes, OptixDenoiser* denoiser ); /// Destroys the denoiser object and any associated host resources. OptixResult optixDenoiserDestroy( OptixDenoiser denoiser ); /// Computes the GPU memory resources required to execute the denoiser. 
/// /// Memory for state and scratch buffers must be allocated with the sizes in 'returnSizes' and scratch memory /// passed to optixDenoiserSetup, optixDenoiserInvoke, /// optixDenoiserComputeIntensity and optixDenoiserComputeAverageColor. /// For tiled denoising an overlap area must be added to each tile on all sides which increases the amount of /// memory needed to denoise a tile. In case of tiling use withOverlapScratchSizeInBytes. /// If only full resolution images are denoised, withoutOverlapScratchSizeInBytes can be used which is always /// smaller than withOverlapScratchSizeInBytes. /// /// 'outputWidth' and 'outputHeight' is the dimension of the image to be denoised (without overlap in case tiling /// is being used). /// 'outputWidth' and 'outputHeight' must be greater than or equal to the dimensions passed to optixDenoiserSetup. /// /// \param[in] denoiser /// \param[in] outputWidth /// \param[in] outputHeight /// \param[out] returnSizes OptixResult optixDenoiserComputeMemoryResources( const OptixDenoiser denoiser, unsigned int outputWidth, unsigned int outputHeight, OptixDenoiserSizes* returnSizes ); /// Initializes the state required by the denoiser. /// /// 'inputWidth' and 'inputHeight' must include overlap on both sides of the image if tiling is being used. The overlap is /// returned by #optixDenoiserComputeMemoryResources. /// For subsequent calls to #optixDenoiserInvoke 'inputWidth' and 'inputHeight' are the maximum dimensions /// of the input layers. Dimensions of the input layers passed to #optixDenoiserInvoke may be different in each /// invocation however they always must be smaller than 'inputWidth' and 'inputHeight' passed to #optixDenoiserSetup. 
/// /// \param[in] denoiser /// \param[in] stream /// \param[in] inputWidth /// \param[in] inputHeight /// \param[in] denoiserState /// \param[in] denoiserStateSizeInBytes /// \param[in] scratch /// \param[in] scratchSizeInBytes OptixResult optixDenoiserSetup( OptixDenoiser denoiser, CUstream stream, unsigned int inputWidth, unsigned int inputHeight, CUdeviceptr denoiserState, size_t denoiserStateSizeInBytes, CUdeviceptr scratch, size_t scratchSizeInBytes ); /// Invokes denoiser on a set of input data and produces at least one output image. /// State memory must be available during the execution of the /// denoiser (or until optixDenoiserSetup is called with a new state memory pointer). /// Scratch memory passed is used only for the duration of this function. /// Scratch and state memory sizes must have a size greater than or equal to the sizes as returned by /// optixDenoiserComputeMemoryResources. /// /// 'inputOffsetX' and 'inputOffsetY' are pixel offsets in the 'inputLayers' image /// specifying the beginning of the image without overlap. When denoising an entire image without tiling /// there is no overlap and 'inputOffsetX' and 'inputOffsetY' must be zero. When denoising a tile which is /// adjacent to one of the four sides of the entire image the corresponding offsets must also be zero since /// there is no overlap at the side adjacent to the image border. /// /// 'guideLayer' provides additional information to the denoiser. When providing albedo and normal vector /// guide images, the corresponding fields in the 'OptixDenoiserOptions' must be /// enabled, see #optixDenoiserCreate. /// 'guideLayer' must not be null. If a guide image in 'OptixDenoiserOptions' is not enabled, the /// corresponding image in 'OptixDenoiserGuideLayer' is ignored. /// /// If OPTIX_DENOISER_MODEL_KIND_TEMPORAL is selected, a 2d flow image must be given in 'OptixDenoiserGuideLayer'. 
/// It describes for each pixel the flow from the previous to the current frame (a 2d vector in pixel space). /// The denoised beauty/AOV of the previous frame must be given in 'previousOutput'. /// If this image is not available in the first frame of a sequence, the noisy beauty/AOV from the first frame /// and zero flow vectors could be given as a substitute. /// For non-temporal model kinds the flow image in 'OptixDenoiserGuideLayer' is ignored. /// 'previousOutput' and /// 'output' may refer to the same buffer, i.e. 'previousOutput' is first read by this function and later /// overwritten with the denoised result. 'output' can be passed as 'previousOutput' to the next frame. /// In other model kinds (not temporal) 'previousOutput' is ignored. /// /// The beauty layer must be given as the first entry in 'layers'. /// In AOV type model kinds (OPTIX_DENOISER_MODEL_KIND_AOV or in user defined models implementing /// kernel-prediction) additional layers for the AOV images can be given. /// In each layer the noisy input image is given in 'input', the denoised output is written into the /// 'output' image. input and output images may refer to the same buffer, with the restriction that /// the pixel formats must be identical for input and output when the blend mode is selected (see /// #OptixDenoiserParams). /// /// If OPTIX_DENOISER_MODEL_KIND_TEMPORAL is selected, the /// normal vector guide image must be given as 3d vectors in camera space. In the other models only /// the x and y channels are used and other channels are ignored. 
/// /// \param[in] denoiser /// \param[in] stream /// \param[in] params /// \param[in] denoiserState /// \param[in] denoiserStateSizeInBytes /// \param[in] guideLayer /// \param[in] layers /// \param[in] numLayers /// \param[in] inputOffsetX /// \param[in] inputOffsetY /// \param[in] outputLayer /// \param[in] scratch /// \param[in] scratchSizeInBytes OptixResult optixDenoiserInvoke( OptixDenoiser denoiser, CUstream stream, const OptixDenoiserParams* params, CUdeviceptr denoiserState, size_t denoiserStateSizeInBytes, const OptixDenoiserGuideLayer* guideLayer, const OptixDenoiserLayer* layers, unsigned int numLayers, unsigned int inputOffsetX, unsigned int inputOffsetY, CUdeviceptr scratch, size_t scratchSizeInBytes ); /// Computes the logarithmic average intensity of the given image. The returned value 'outputIntensity' /// is multiplied with the RGB values of the input image/tile in optixDenoiserInvoke if given in the parameter /// OptixDenoiserParams::hdrIntensity (otherwise 'hdrIntensity' must be a null pointer). This is useful for /// denoising HDR images which are very dark or bright. /// When denoising tiles the intensity of the entire image should be computed, i.e. not per tile to get /// consistent results. /// /// For each RGB pixel in the inputImage the intensity is calculated and summed if it is greater than 1e-8f: /// intensity = log(r * 0.212586f + g * 0.715170f + b * 0.072200f). /// The function returns 0.18 / exp(sum of intensities / number of summed pixels). /// More details could be found in the Reinhard tonemapping paper: /// http://www.cmap.polytechnique.fr/~peyre/cours/x2005signal/hdr_photographic.pdf /// /// This function needs scratch memory with a size of at least /// sizeof( int ) * ( 2 + inputImage::width * inputImage::height ). When denoising entire images (no tiling) /// the same scratch memory as passed to optixDenoiserInvoke could be used. 
// /// data type unsigned char is not supported for 'inputImage', it must be 3 or 4 component half/float. /// /// \param[in] denoiser /// \param[in] stream /// \param[in] inputImage /// \param[out] outputIntensity single float /// \param[in] scratch /// \param[in] scratchSizeInBytes OptixResult optixDenoiserComputeIntensity( OptixDenoiser denoiser, CUstream stream, const OptixImage2D* inputImage, CUdeviceptr outputIntensity, CUdeviceptr scratch, size_t scratchSizeInBytes ); /// Compute average logarithmic for each of the first three channels for the given image. /// When denoising tiles the intensity of the entire image should be computed, i.e. not per tile to get /// consistent results. /// This function needs scratch memory with a size of at least /// sizeof( int ) * ( 3 + 3 * inputImage::width * inputImage::height ). When denoising entire images (no tiling) /// the same scratch memory as passed to optixDenoiserInvoke could be used. /// /// data type unsigned char is not supported for 'inputImage', it must be 3 or 4 component half/float. /// /// \param[in] denoiser /// \param[in] stream /// \param[in] inputImage /// \param[out] outputAverageColor three floats /// \param[in] scratch /// \param[in] scratchSizeInBytes OptixResult optixDenoiserComputeAverageColor( OptixDenoiser denoiser, CUstream stream, const OptixImage2D* inputImage, CUdeviceptr outputAverageColor, CUdeviceptr scratch, size_t scratchSizeInBytes ); //@} #ifdef __cplusplus } #endif #include "optix_function_table.h" #endif // __optix_optix_7_host_h__ ================================================ FILE: render/optixutils/include/optix_7_types.h ================================================ /* * Copyright (c) 2021 NVIDIA Corporation. All rights reserved. * * NVIDIA Corporation and its licensors retain all intellectual property and proprietary * rights in and to this software, related documentation and any modifications thereto. 
* Any use, reproduction, disclosure or distribution of this software and related * documentation without an express license agreement from NVIDIA Corporation is strictly * prohibited. * * TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, THIS SOFTWARE IS PROVIDED *AS IS* * AND NVIDIA AND ITS SUPPLIERS DISCLAIM ALL WARRANTIES, EITHER EXPRESS OR IMPLIED, * INCLUDING, BUT NOT LIMITED TO, IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A * PARTICULAR PURPOSE. IN NO EVENT SHALL NVIDIA OR ITS SUPPLIERS BE LIABLE FOR ANY * SPECIAL, INCIDENTAL, INDIRECT, OR CONSEQUENTIAL DAMAGES WHATSOEVER (INCLUDING, WITHOUT * LIMITATION, DAMAGES FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF * BUSINESS INFORMATION, OR ANY OTHER PECUNIARY LOSS) ARISING OUT OF THE USE OF OR * INABILITY TO USE THIS SOFTWARE, EVEN IF NVIDIA HAS BEEN ADVISED OF THE POSSIBILITY OF * SUCH DAMAGES */ /// @file /// @author NVIDIA Corporation /// @brief OptiX public API header /// /// OptiX types include file -- defines types and enums used by the API. /// For the math library routines include optix_math.h #if !defined( __OPTIX_INCLUDE_INTERNAL_HEADERS__ ) #error("optix_7_types.h is an internal header file and must not be used directly. Please use optix_types.h, optix_host.h, optix_device.h or optix.h instead.") #endif #ifndef __optix_optix_7_types_h__ #define __optix_optix_7_types_h__ #if !defined(__CUDACC_RTC__) #include <stddef.h> /* for size_t */ #endif /// \defgroup optix_types Types /// \brief OptiX Types /** \addtogroup optix_types @{ */ // This typedef should match the one in cuda.h in order to avoid compilation errors. 
#if defined(__x86_64) || defined(AMD64) || defined(_M_AMD64) || defined(__powerpc64__) || defined(__EDG_IA64_ABI)/*=NVRTC*/ || defined(__aarch64__) /// CUDA device pointer typedef unsigned long long CUdeviceptr; #else /// CUDA device pointer typedef unsigned int CUdeviceptr; #endif /// Opaque type representing a device context typedef struct OptixDeviceContext_t* OptixDeviceContext; /// Opaque type representing a module typedef struct OptixModule_t* OptixModule; /// Opaque type representing a program group typedef struct OptixProgramGroup_t* OptixProgramGroup; /// Opaque type representing a pipeline typedef struct OptixPipeline_t* OptixPipeline; /// Opaque type representing a denoiser instance typedef struct OptixDenoiser_t* OptixDenoiser; /// Traversable handle typedef unsigned long long OptixTraversableHandle; /// Visibility mask typedef unsigned int OptixVisibilityMask; /// Size of the SBT record headers. #define OPTIX_SBT_RECORD_HEADER_SIZE ( (size_t)32 ) /// Alignment requirement for device pointers in OptixShaderBindingTable. #define OPTIX_SBT_RECORD_ALIGNMENT 16ull /// Alignment requirement for output and temporary buffers for acceleration structures. #define OPTIX_ACCEL_BUFFER_BYTE_ALIGNMENT 128ull /// Alignment requirement for OptixBuildInputInstanceArray::instances. #define OPTIX_INSTANCE_BYTE_ALIGNMENT 16ull /// Alignment requirement for OptixBuildInputCustomPrimitiveArray::aabbBuffers #define OPTIX_AABB_BUFFER_BYTE_ALIGNMENT 8ull /// Alignment requirement for OptixBuildInputTriangleArray::preTransform #define OPTIX_GEOMETRY_TRANSFORM_BYTE_ALIGNMENT 16ull /// Alignment requirement for OptixStaticTransform, OptixMatrixMotionTransform, OptixSRTMotionTransform. #define OPTIX_TRANSFORM_BYTE_ALIGNMENT 64ull /// Maximum number of registers allowed. Defaults to no explicit limit. #define OPTIX_COMPILE_DEFAULT_MAX_REGISTER_COUNT 0 /// Maximum number of payload values allowed. 
#define OPTIX_COMPILE_DEFAULT_MAX_PAYLOAD_VALUE_COUNT 8 /// Result codes returned from API functions /// /// All host side API functions return OptixResult with the exception of optixGetErrorName /// and optixGetErrorString. When successful OPTIX_SUCCESS is returned. All return codes /// except for OPTIX_SUCCESS should be assumed to be errors as opposed to a warning. /// /// \see #optixGetErrorName(), #optixGetErrorString() typedef enum OptixResult { OPTIX_SUCCESS = 0, OPTIX_ERROR_INVALID_VALUE = 7001, OPTIX_ERROR_HOST_OUT_OF_MEMORY = 7002, OPTIX_ERROR_INVALID_OPERATION = 7003, OPTIX_ERROR_FILE_IO_ERROR = 7004, OPTIX_ERROR_INVALID_FILE_FORMAT = 7005, OPTIX_ERROR_DISK_CACHE_INVALID_PATH = 7010, OPTIX_ERROR_DISK_CACHE_PERMISSION_ERROR = 7011, OPTIX_ERROR_DISK_CACHE_DATABASE_ERROR = 7012, OPTIX_ERROR_DISK_CACHE_INVALID_DATA = 7013, OPTIX_ERROR_LAUNCH_FAILURE = 7050, OPTIX_ERROR_INVALID_DEVICE_CONTEXT = 7051, OPTIX_ERROR_CUDA_NOT_INITIALIZED = 7052, OPTIX_ERROR_VALIDATION_FAILURE = 7053, OPTIX_ERROR_INVALID_PTX = 7200, OPTIX_ERROR_INVALID_LAUNCH_PARAMETER = 7201, OPTIX_ERROR_INVALID_PAYLOAD_ACCESS = 7202, OPTIX_ERROR_INVALID_ATTRIBUTE_ACCESS = 7203, OPTIX_ERROR_INVALID_FUNCTION_USE = 7204, OPTIX_ERROR_INVALID_FUNCTION_ARGUMENTS = 7205, OPTIX_ERROR_PIPELINE_OUT_OF_CONSTANT_MEMORY = 7250, OPTIX_ERROR_PIPELINE_LINK_ERROR = 7251, OPTIX_ERROR_ILLEGAL_DURING_TASK_EXECUTE = 7270, OPTIX_ERROR_INTERNAL_COMPILER_ERROR = 7299, OPTIX_ERROR_DENOISER_MODEL_NOT_SET = 7300, OPTIX_ERROR_DENOISER_NOT_INITIALIZED = 7301, OPTIX_ERROR_ACCEL_NOT_COMPATIBLE = 7400, OPTIX_ERROR_NOT_SUPPORTED = 7800, OPTIX_ERROR_UNSUPPORTED_ABI_VERSION = 7801, OPTIX_ERROR_FUNCTION_TABLE_SIZE_MISMATCH = 7802, OPTIX_ERROR_INVALID_ENTRY_FUNCTION_OPTIONS = 7803, OPTIX_ERROR_LIBRARY_NOT_FOUND = 7804, OPTIX_ERROR_ENTRY_SYMBOL_NOT_FOUND = 7805, OPTIX_ERROR_LIBRARY_UNLOAD_FAILURE = 7806, OPTIX_ERROR_CUDA_ERROR = 7900, OPTIX_ERROR_INTERNAL_ERROR = 7990, OPTIX_ERROR_UNKNOWN = 7999, } OptixResult; /// Parameters used for 
#optixDeviceContextGetProperty() /// /// \see #optixDeviceContextGetProperty() typedef enum OptixDeviceProperty { /// Maximum value for OptixPipelineLinkOptions::maxTraceDepth. sizeof( unsigned int ) OPTIX_DEVICE_PROPERTY_LIMIT_MAX_TRACE_DEPTH = 0x2001, /// Maximum value to pass into optixPipelineSetStackSize for parameter /// maxTraversableGraphDepth.v sizeof( unsigned int ) OPTIX_DEVICE_PROPERTY_LIMIT_MAX_TRAVERSABLE_GRAPH_DEPTH = 0x2002, /// The maximum number of primitives (over all build inputs) as input to a single /// Geometry Acceleration Structure (GAS). sizeof( unsigned int ) OPTIX_DEVICE_PROPERTY_LIMIT_MAX_PRIMITIVES_PER_GAS = 0x2003, /// The maximum number of instances (over all build inputs) as input to a single /// Instance Acceleration Structure (IAS). sizeof( unsigned int ) OPTIX_DEVICE_PROPERTY_LIMIT_MAX_INSTANCES_PER_IAS = 0x2004, /// The RT core version supported by the device (0 for no support, 10 for version /// 1.0). sizeof( unsigned int ) OPTIX_DEVICE_PROPERTY_RTCORE_VERSION = 0x2005, /// The maximum value for #OptixInstance::instanceId. sizeof( unsigned int ) OPTIX_DEVICE_PROPERTY_LIMIT_MAX_INSTANCE_ID = 0x2006, /// The number of bits available for the #OptixInstance::visibilityMask. /// Higher bits must be set to zero. sizeof( unsigned int ) OPTIX_DEVICE_PROPERTY_LIMIT_NUM_BITS_INSTANCE_VISIBILITY_MASK = 0x2007, /// The maximum number of instances that can be added to a single Instance /// Acceleration Structure (IAS). sizeof( unsigned int ) OPTIX_DEVICE_PROPERTY_LIMIT_MAX_SBT_RECORDS_PER_GAS = 0x2008, /// The maximum value for #OptixInstance::sbtOffset. sizeof( unsigned int ) OPTIX_DEVICE_PROPERTY_LIMIT_MAX_SBT_OFFSET = 0x2009, } OptixDeviceProperty; /// Type of the callback function used for log messages. /// /// \param[in] level The log level indicates the severity of the message. See below for /// possible values. /// \param[in] tag A terse message category description (e.g., 'SCENE STAT'). 
/// \param[in] message Null terminated log message (without newline at the end). /// \param[in] cbdata Callback data that was provided with the callback pointer. /// /// It is the users responsibility to ensure thread safety within this function. /// /// The following log levels are defined. /// /// 0 disable Setting the callback level will disable all messages. The callback /// function will not be called in this case. /// 1 fatal A non-recoverable error. The context and/or OptiX itself might no longer /// be in a usable state. /// 2 error A recoverable error, e.g., when passing invalid call parameters. /// 3 warning Hints that OptiX might not behave exactly as requested by the user or /// may perform slower than expected. /// 4 print Status or progress messages. /// /// Higher levels might occur. /// /// \see #optixDeviceContextSetLogCallback(), #OptixDeviceContextOptions typedef void ( *OptixLogCallback )( unsigned int level, const char* tag, const char* message, void* cbdata ); /// Validation mode settings. /// /// When enabled, certain device code utilities will be enabled to provide as good debug and /// error checking facilities as possible. /// /// /// \see #optixDeviceContextCreate() typedef enum OptixDeviceContextValidationMode { OPTIX_DEVICE_CONTEXT_VALIDATION_MODE_OFF = 0, OPTIX_DEVICE_CONTEXT_VALIDATION_MODE_ALL = 0xFFFFFFFF } OptixDeviceContextValidationMode; /// Parameters used for #optixDeviceContextCreate() /// /// \see #optixDeviceContextCreate() typedef struct OptixDeviceContextOptions { /// Function pointer used when OptiX wishes to generate messages OptixLogCallback logCallbackFunction; /// Pointer stored and passed to logCallbackFunction when a message is generated void* logCallbackData; /// Maximum callback level to generate message for (see #OptixLogCallback) int logCallbackLevel; /// Validation mode of context. 
OptixDeviceContextValidationMode validationMode; } OptixDeviceContextOptions; /// Flags used by #OptixBuildInputTriangleArray::flags /// and #OptixBuildInputCurveArray::flags /// and #OptixBuildInputCustomPrimitiveArray::flags typedef enum OptixGeometryFlags { /// No flags set OPTIX_GEOMETRY_FLAG_NONE = 0, /// Disables the invocation of the anyhit program. /// Can be overridden by OPTIX_INSTANCE_FLAG_ENFORCE_ANYHIT and OPTIX_RAY_FLAG_ENFORCE_ANYHIT. OPTIX_GEOMETRY_FLAG_DISABLE_ANYHIT = 1u << 0, /// If set, an intersection with the primitive will trigger one and only one /// invocation of the anyhit program. Otherwise, the anyhit program may be invoked /// more than once. OPTIX_GEOMETRY_FLAG_REQUIRE_SINGLE_ANYHIT_CALL = 1u << 1 } OptixGeometryFlags; /// Legacy type: A subset of the hit kinds for built-in primitive intersections. /// It is preferred to use optixGetPrimitiveType(), together with /// optixIsFrontFaceHit() or optixIsBackFaceHit(). /// /// \see #optixGetHitKind() typedef enum OptixHitKind { /// Ray hit the triangle on the front face OPTIX_HIT_KIND_TRIANGLE_FRONT_FACE = 0xFE, /// Ray hit the triangle on the back face OPTIX_HIT_KIND_TRIANGLE_BACK_FACE = 0xFF } OptixHitKind; /// Format of indices used in #OptixBuildInputTriangleArray::indexFormat. typedef enum OptixIndicesFormat { /// No indices, this format must only be used in combination with triangle soups, i.e., numIndexTriplets must be zero OPTIX_INDICES_FORMAT_NONE = 0, /// Three shorts OPTIX_INDICES_FORMAT_UNSIGNED_SHORT3 = 0x2102, /// Three ints OPTIX_INDICES_FORMAT_UNSIGNED_INT3 = 0x2103 } OptixIndicesFormat; /// Format of vertices used in #OptixBuildInputTriangleArray::vertexFormat. 
typedef enum OptixVertexFormat
{
    OPTIX_VERTEX_FORMAT_NONE      = 0,      ///< No vertices
    OPTIX_VERTEX_FORMAT_FLOAT3    = 0x2121, ///< Vertices are represented by three floats
    OPTIX_VERTEX_FORMAT_FLOAT2    = 0x2122, ///< Vertices are represented by two floats
    OPTIX_VERTEX_FORMAT_HALF3     = 0x2123, ///< Vertices are represented by three halfs
    OPTIX_VERTEX_FORMAT_HALF2     = 0x2124, ///< Vertices are represented by two halfs
    OPTIX_VERTEX_FORMAT_SNORM16_3 = 0x2125,
    OPTIX_VERTEX_FORMAT_SNORM16_2 = 0x2126
} OptixVertexFormat;

/// Format of transform used in #OptixBuildInputTriangleArray::transformFormat.
typedef enum OptixTransformFormat
{
    OPTIX_TRANSFORM_FORMAT_NONE           = 0,      ///< no transform, default for zero initialization
    OPTIX_TRANSFORM_FORMAT_MATRIX_FLOAT12 = 0x21E1, ///< 3x4 row major affine matrix
} OptixTransformFormat;

/// Triangle inputs
///
/// \see #OptixBuildInput::triangleArray
typedef struct OptixBuildInputTriangleArray
{
    /// Points to host array of device pointers, one per motion step. Host array size must match the number of
    /// motion keys as set in #OptixMotionOptions (or an array of size 1 if OptixMotionOptions::numKeys is set
    /// to 0 or 1). Each per motion key device pointer must point to an array of vertices of the
    /// triangles in the format as described by vertexFormat. The minimum alignment must match the natural
    /// alignment of the type as specified in the vertexFormat, i.e., for OPTIX_VERTEX_FORMAT_FLOATX 4-byte,
    /// for all others a 2-byte alignment. However, a 16-byte stride (and buffer alignment) is recommended for
    /// vertices of format OPTIX_VERTEX_FORMAT_FLOAT3 for GAS build performance.
    const CUdeviceptr* vertexBuffers;

    /// Number of vertices in each buffer in OptixBuildInputTriangleArray::vertexBuffers.
    unsigned int numVertices;

    /// \see #OptixVertexFormat
    OptixVertexFormat vertexFormat;

    /// Stride between vertices. If set to zero, vertices are assumed to be tightly
    /// packed and stride is inferred from vertexFormat.
    unsigned int vertexStrideInBytes;

    /// Optional pointer to array of 16 or 32-bit int triplets, one triplet per triangle.
    /// The minimum alignment must match the natural alignment of the type as specified in the indexFormat, i.e.,
    /// for OPTIX_INDICES_FORMAT_UNSIGNED_INT3 4-byte and for OPTIX_INDICES_FORMAT_UNSIGNED_SHORT3 a 2-byte alignment.
    CUdeviceptr indexBuffer;

    /// Size of array in OptixBuildInputTriangleArray::indexBuffer. For build, needs to be zero if indexBuffer is \c nullptr.
    unsigned int numIndexTriplets;

    /// \see #OptixIndicesFormat
    OptixIndicesFormat indexFormat;

    /// Stride between triplets of indices. If set to zero, indices are assumed to be tightly
    /// packed and stride is inferred from indexFormat.
    unsigned int indexStrideInBytes;

    /// Optional pointer to array of floats
    /// representing a 3x4 row major affine
    /// transformation matrix. This pointer must be a multiple of OPTIX_GEOMETRY_TRANSFORM_BYTE_ALIGNMENT
    CUdeviceptr preTransform;

    /// Array of flags, to specify flags per sbt record,
    /// combinations of OptixGeometryFlags describing the
    /// primitive behavior, size must match numSbtRecords
    const unsigned int* flags;

    /// Number of sbt records available to the sbt index offset override.
    unsigned int numSbtRecords;

    /// Device pointer to per-primitive local sbt index offset buffer. May be NULL.
    /// Every entry must be in range [0,numSbtRecords-1].
    /// Size needs to be the number of primitives.
    CUdeviceptr sbtIndexOffsetBuffer;

    /// Size of type of the sbt index offset. Needs to be 0, 1, 2 or 4 (8, 16 or 32 bit).
    unsigned int sbtIndexOffsetSizeInBytes;

    /// Stride between the index offsets. If set to zero, the offsets are assumed to be tightly
    /// packed and the stride matches the size of the type (sbtIndexOffsetSizeInBytes).
    unsigned int sbtIndexOffsetStrideInBytes;

    /// Primitive index bias, applied in optixGetPrimitiveIndex().
    /// Sum of primitiveIndexOffset and number of triangles must not overflow 32bits.
    unsigned int primitiveIndexOffset;

    /// \see #OptixTransformFormat
    OptixTransformFormat transformFormat;
} OptixBuildInputTriangleArray;

/// Builtin primitive types
///
typedef enum OptixPrimitiveType
{
    /// Custom primitive.
    OPTIX_PRIMITIVE_TYPE_CUSTOM = 0x2500,
    /// B-spline curve of degree 2 with circular cross-section.
    OPTIX_PRIMITIVE_TYPE_ROUND_QUADRATIC_BSPLINE = 0x2501,
    /// B-spline curve of degree 3 with circular cross-section.
    OPTIX_PRIMITIVE_TYPE_ROUND_CUBIC_BSPLINE = 0x2502,
    /// Piecewise linear curve with circular cross-section.
    OPTIX_PRIMITIVE_TYPE_ROUND_LINEAR = 0x2503,
    /// Triangle.
    OPTIX_PRIMITIVE_TYPE_TRIANGLE = 0x2531,
} OptixPrimitiveType;

/// Builtin flags may be bitwise combined.
///
/// \see #OptixPipelineCompileOptions::usesPrimitiveTypeFlags
typedef enum OptixPrimitiveTypeFlags
{
    /// Custom primitive.
    OPTIX_PRIMITIVE_TYPE_FLAGS_CUSTOM = 1 << 0,
    /// B-spline curve of degree 2 with circular cross-section.
    OPTIX_PRIMITIVE_TYPE_FLAGS_ROUND_QUADRATIC_BSPLINE = 1 << 1,
    /// B-spline curve of degree 3 with circular cross-section.
    OPTIX_PRIMITIVE_TYPE_FLAGS_ROUND_CUBIC_BSPLINE = 1 << 2,
    /// Piecewise linear curve with circular cross-section.
    OPTIX_PRIMITIVE_TYPE_FLAGS_ROUND_LINEAR = 1 << 3,
    /// Triangle.
    OPTIX_PRIMITIVE_TYPE_FLAGS_TRIANGLE = 1 << 31,
} OptixPrimitiveTypeFlags;

/// Curve inputs
///
/// A curve is a swept surface defined by a 3D spline curve and a varying width (radius). A curve (or "strand") of
/// degree d (3=cubic, 2=quadratic, 1=linear) is represented by N > d vertices and N width values, and comprises N - d segments.
/// Each segment is defined by d+1 consecutive vertices. Each curve may have a different number of vertices.
///
/// OptiX describes the curve array as a list of curve segments. The primitive id is the segment number.
/// It is the user's responsibility to maintain a mapping between curves and curve segments.
/// Each index buffer entry i = indexBuffer[primid] specifies the start of a curve segment, /// represented by d+1 consecutive vertices in the vertex buffer, /// and d+1 consecutive widths in the width buffer. Width is interpolated the same /// way vertices are interpolated, that is, using the curve basis. /// /// Each curves build input has only one SBT record. /// To create curves with different materials in the same BVH, use multiple build inputs. /// /// \see #OptixBuildInput::curveArray typedef struct OptixBuildInputCurveArray { /// Curve degree and basis /// \see #OptixPrimitiveType OptixPrimitiveType curveType; /// Number of primitives. Each primitive is a polynomial curve segment. unsigned int numPrimitives; /// Pointer to host array of device pointers, one per motion step. Host array size must match number of /// motion keys as set in #OptixMotionOptions (or an array of size 1 if OptixMotionOptions::numKeys is set /// to 1). Each per-motion-key device pointer must point to an array of floats (the vertices of the /// curves). const CUdeviceptr* vertexBuffers; /// Number of vertices in each buffer in vertexBuffers. unsigned int numVertices; /// Stride between vertices. If set to zero, vertices are assumed to be tightly /// packed and stride is sizeof( float3 ). unsigned int vertexStrideInBytes; /// Parallel to vertexBuffers: a device pointer per motion step, each with numVertices float values, /// specifying the curve width (radius) corresponding to each vertex. const CUdeviceptr* widthBuffers; /// Stride between widths. If set to zero, widths are assumed to be tightly /// packed and stride is sizeof( float ). unsigned int widthStrideInBytes; /// Reserved for future use. const CUdeviceptr* normalBuffers; /// Reserved for future use. unsigned int normalStrideInBytes; /// Device pointer to array of unsigned ints, one per curve segment. /// This buffer is required (unlike for OptixBuildInputTriangleArray). 
/// Each index is the start of degree+1 consecutive vertices in vertexBuffers, /// and corresponding widths in widthBuffers and normals in normalBuffers. /// These define a single segment. Size of array is numPrimitives. CUdeviceptr indexBuffer; /// Stride between indices. If set to zero, indices are assumed to be tightly /// packed and stride is sizeof( unsigned int ). unsigned int indexStrideInBytes; /// Combination of OptixGeometryFlags describing the /// primitive behavior. unsigned int flag; /// Primitive index bias, applied in optixGetPrimitiveIndex(). /// Sum of primitiveIndexOffset and number of primitives must not overflow 32bits. unsigned int primitiveIndexOffset; } OptixBuildInputCurveArray; /// AABB inputs typedef struct OptixAabb { float minX; ///< Lower extent in X direction. float minY; ///< Lower extent in Y direction. float minZ; ///< Lower extent in Z direction. float maxX; ///< Upper extent in X direction. float maxY; ///< Upper extent in Y direction. float maxZ; ///< Upper extent in Z direction. } OptixAabb; /// Custom primitive inputs /// /// \see #OptixBuildInput::customPrimitiveArray typedef struct OptixBuildInputCustomPrimitiveArray { /// Points to host array of device pointers to AABBs (type OptixAabb), one per motion step. /// Host array size must match number of motion keys as set in OptixMotionOptions (or an array of size 1 /// if OptixMotionOptions::numKeys is set to 1). /// Each device pointer must be a multiple of OPTIX_AABB_BUFFER_BYTE_ALIGNMENT. const CUdeviceptr* aabbBuffers; /// Number of primitives in each buffer (i.e., per motion step) in /// #OptixBuildInputCustomPrimitiveArray::aabbBuffers. unsigned int numPrimitives; /// Stride between AABBs (per motion key). If set to zero, the aabbs are assumed to be tightly /// packed and the stride is assumed to be sizeof( OptixAabb ). /// If non-zero, the value must be a multiple of OPTIX_AABB_BUFFER_BYTE_ALIGNMENT. 
unsigned int strideInBytes;

    /// Array of flags, to specify flags per sbt record,
    /// combinations of OptixGeometryFlags describing the
    /// primitive behavior, size must match numSbtRecords
    const unsigned int* flags;

    /// Number of sbt records available to the sbt index offset override.
    unsigned int numSbtRecords;

    /// Device pointer to per-primitive local sbt index offset buffer. May be NULL.
    /// Every entry must be in range [0,numSbtRecords-1].
    /// Size needs to be the number of primitives.
    CUdeviceptr sbtIndexOffsetBuffer;

    /// Size of type of the sbt index offset. Needs to be 0, 1, 2 or 4 (8, 16 or 32 bit).
    unsigned int sbtIndexOffsetSizeInBytes;

    /// Stride between the index offsets. If set to zero, the offsets are assumed to be tightly
    /// packed and the stride matches the size of the type (sbtIndexOffsetSizeInBytes).
    unsigned int sbtIndexOffsetStrideInBytes;

    /// Primitive index bias, applied in optixGetPrimitiveIndex().
    /// Sum of primitiveIndexOffset and number of primitives must not overflow 32bits.
    unsigned int primitiveIndexOffset;
} OptixBuildInputCustomPrimitiveArray;

/// Instance and instance pointer inputs
///
/// \see #OptixBuildInput::instanceArray
typedef struct OptixBuildInputInstanceArray
{
    /// If OptixBuildInput::type is OPTIX_BUILD_INPUT_TYPE_INSTANCE_POINTERS instances and
    /// aabbs should be interpreted as arrays of pointers instead of arrays of structs.
    ///
    /// This pointer must be a multiple of OPTIX_INSTANCE_BYTE_ALIGNMENT if
    /// OptixBuildInput::type is OPTIX_BUILD_INPUT_TYPE_INSTANCES. The array elements must
    /// be a multiple of OPTIX_INSTANCE_BYTE_ALIGNMENT if OptixBuildInput::type is
    /// OPTIX_BUILD_INPUT_TYPE_INSTANCE_POINTERS.
    CUdeviceptr instances;

    /// Number of elements in #OptixBuildInputInstanceArray::instances.
    unsigned int numInstances;
} OptixBuildInputInstanceArray;

/// Enum to distinguish the different build input types.
///
/// \see #OptixBuildInput::type
typedef enum OptixBuildInputType
{
    /// Triangle inputs. \see #OptixBuildInputTriangleArray
    OPTIX_BUILD_INPUT_TYPE_TRIANGLES = 0x2141,
    /// Custom primitive inputs. \see #OptixBuildInputCustomPrimitiveArray
    OPTIX_BUILD_INPUT_TYPE_CUSTOM_PRIMITIVES = 0x2142,
    /// Instance inputs. \see #OptixBuildInputInstanceArray
    OPTIX_BUILD_INPUT_TYPE_INSTANCES = 0x2143,
    /// Instance pointer inputs. \see #OptixBuildInputInstanceArray
    OPTIX_BUILD_INPUT_TYPE_INSTANCE_POINTERS = 0x2144,
    /// Curve inputs. \see #OptixBuildInputCurveArray
    OPTIX_BUILD_INPUT_TYPE_CURVES = 0x2145
} OptixBuildInputType;

/// Build inputs.
///
/// All of them support motion and the size of the data arrays needs to match the number of motion steps
///
/// \see #optixAccelComputeMemoryUsage(), #optixAccelBuild()
typedef struct OptixBuildInput
{
    /// The type of the build input.
    OptixBuildInputType type;

    union
    {
        /// Triangle inputs.
        OptixBuildInputTriangleArray triangleArray;
        /// Curve inputs.
        OptixBuildInputCurveArray curveArray;
        /// Custom primitive inputs.
        OptixBuildInputCustomPrimitiveArray customPrimitiveArray;
        /// Instance and instance pointer inputs.
        OptixBuildInputInstanceArray instanceArray;
        char pad[1024];
    };
} OptixBuildInput;

// Some 32-bit tools use this header. This static_assert fails for them because
// the default enum size is 4 bytes, rather than 8, under 32-bit compilers.
// This #ifndef allows them to disable the static assert.
// TODO Define a static assert for C/pre-C++-11
#if defined( __cplusplus ) && __cplusplus >= 201103L
static_assert( sizeof( OptixBuildInput ) == 8 + 1024, "OptixBuildInput has wrong size" );
#endif

/// Flags set on the #OptixInstance::flags.
///
/// These can be or'ed together to combine multiple flags.
typedef enum OptixInstanceFlags
{
    /// No special flag set
    OPTIX_INSTANCE_FLAG_NONE = 0,

    /// Prevent triangles from getting culled due to their orientation.
    /// Effectively ignores ray flags
    /// OPTIX_RAY_FLAG_CULL_BACK_FACING_TRIANGLES and OPTIX_RAY_FLAG_CULL_FRONT_FACING_TRIANGLES.
OPTIX_INSTANCE_FLAG_DISABLE_TRIANGLE_FACE_CULLING = 1u << 0,

    /// Flip triangle orientation.
    /// This affects front/backface culling as well as the reported face in case of a hit.
    OPTIX_INSTANCE_FLAG_FLIP_TRIANGLE_FACING = 1u << 1,

    /// Disable anyhit programs for all geometries of the instance.
    /// Can be overridden by OPTIX_RAY_FLAG_ENFORCE_ANYHIT.
    /// This flag is mutually exclusive with OPTIX_INSTANCE_FLAG_ENFORCE_ANYHIT.
    OPTIX_INSTANCE_FLAG_DISABLE_ANYHIT = 1u << 2,

    /// Enables anyhit programs for all geometries of the instance.
    /// Overrides OPTIX_GEOMETRY_FLAG_DISABLE_ANYHIT
    /// Can be overridden by OPTIX_RAY_FLAG_DISABLE_ANYHIT.
    /// This flag is mutually exclusive with OPTIX_INSTANCE_FLAG_DISABLE_ANYHIT.
    OPTIX_INSTANCE_FLAG_ENFORCE_ANYHIT = 1u << 3,

    /// Disable the instance transformation
    OPTIX_INSTANCE_FLAG_DISABLE_TRANSFORM = 1u << 6,
} OptixInstanceFlags;

/// Instances
///
/// \see #OptixBuildInputInstanceArray::instances
typedef struct OptixInstance
{
    /// affine object-to-world transformation as 3x4 matrix in row-major layout
    float transform[12];

    /// Application supplied ID. The maximal ID can be queried using OPTIX_DEVICE_PROPERTY_LIMIT_MAX_INSTANCE_ID.
    unsigned int instanceId;

    /// SBT record offset. Will only be used for instances of geometry acceleration structure (GAS) objects.
    /// Needs to be set to 0 for instances of instance acceleration structure (IAS) objects. The maximal SBT offset
    /// can be queried using OPTIX_DEVICE_PROPERTY_LIMIT_MAX_INSTANCE_SBT_OFFSET.
    unsigned int sbtOffset;

    /// Visibility mask. If rayMask & instanceMask == 0 the instance is culled. The number of available bits can be
    /// queried using OPTIX_DEVICE_PROPERTY_LIMIT_NUM_BITS_INSTANCE_VISIBILITY_MASK.
    unsigned int visibilityMask;

    /// Any combination of OptixInstanceFlags is allowed.
    unsigned int flags;

    /// Set with an OptixTraversableHandle.
    OptixTraversableHandle traversableHandle;

    /// round up to 80-byte, to ensure 16-byte alignment
    unsigned int pad[2];
} OptixInstance;

/// Builder Options
///
/// Used for #OptixAccelBuildOptions::buildFlags. Can be or'ed together.
typedef enum OptixBuildFlags
{
    /// No special flags set.
    OPTIX_BUILD_FLAG_NONE = 0,

    /// Allow updating the build with new vertex positions with subsequent calls to
    /// optixAccelBuild.
    OPTIX_BUILD_FLAG_ALLOW_UPDATE = 1u << 0,

    OPTIX_BUILD_FLAG_ALLOW_COMPACTION = 1u << 1,

    OPTIX_BUILD_FLAG_PREFER_FAST_TRACE = 1u << 2,

    OPTIX_BUILD_FLAG_PREFER_FAST_BUILD = 1u << 3,

    /// Allow random access to build input vertices
    /// See optixGetTriangleVertexData
    /// optixGetLinearCurveVertexData
    /// optixGetQuadraticBSplineVertexData
    /// optixGetCubicBSplineVertexData
    OPTIX_BUILD_FLAG_ALLOW_RANDOM_VERTEX_ACCESS = 1u << 4,

    /// Allow random access to instances
    /// See optixGetInstanceTraversableFromIAS
    OPTIX_BUILD_FLAG_ALLOW_RANDOM_INSTANCE_ACCESS = 1u << 5,
} OptixBuildFlags;

/// Enum to specify the acceleration build operation.
///
/// Used in OptixAccelBuildOptions, which is then passed to optixAccelBuild and
/// optixAccelComputeMemoryUsage, this enum indicates whether to do a build or an update
/// of the acceleration structure.
///
/// Acceleration structure updates utilize the same acceleration structure, but with
/// updated bounds. Updates are typically much faster than builds, however, large
/// perturbations can degrade the quality of the acceleration structure.
///
/// \see #optixAccelComputeMemoryUsage(), #optixAccelBuild(), #OptixAccelBuildOptions
typedef enum OptixBuildOperation
{
    /// Perform a full build operation
    OPTIX_BUILD_OPERATION_BUILD = 0x2161,
    /// Perform an update using new bounds
    OPTIX_BUILD_OPERATION_UPDATE = 0x2162,
} OptixBuildOperation;

/// Enum to specify motion flags.
///
/// \see #OptixMotionOptions::flags.
typedef enum OptixMotionFlags
{
    OPTIX_MOTION_FLAG_NONE         = 0,
    OPTIX_MOTION_FLAG_START_VANISH = 1u << 0,
    OPTIX_MOTION_FLAG_END_VANISH   = 1u << 1
} OptixMotionFlags;

/// Motion options
///
/// \see #OptixAccelBuildOptions::motionOptions, #OptixMatrixMotionTransform::motionOptions,
/// #OptixSRTMotionTransform::motionOptions
typedef struct OptixMotionOptions
{
    /// If numKeys > 1, motion is enabled. timeBegin,
    /// timeEnd and flags are all ignored when motion is disabled.
    unsigned short numKeys;

    /// Combinations of #OptixMotionFlags
    unsigned short flags;

    /// Point in time where motion starts.
    float timeBegin;

    /// Point in time where motion ends.
    float timeEnd;
} OptixMotionOptions;

/// Build options for acceleration structures.
///
/// \see #optixAccelComputeMemoryUsage(), #optixAccelBuild()
typedef struct OptixAccelBuildOptions
{
    /// Combinations of OptixBuildFlags
    unsigned int buildFlags;

    /// If OPTIX_BUILD_OPERATION_UPDATE the output buffer is assumed to contain the result
    /// of a full build with OPTIX_BUILD_FLAG_ALLOW_UPDATE set and using the same number of
    /// primitives. It is updated incrementally to reflect the current position of the
    /// primitives.
    OptixBuildOperation operation;

    /// Options for motion.
    OptixMotionOptions motionOptions;
} OptixAccelBuildOptions;

/// Struct for querying builder allocation requirements.
///
/// Once queried the sizes should be used to allocate device memory of at least these sizes.
///
/// \see #optixAccelComputeMemoryUsage()
typedef struct OptixAccelBufferSizes
{
    /// The size in bytes required for the outputBuffer parameter to optixAccelBuild when
    /// doing a build (OPTIX_BUILD_OPERATION_BUILD).
    size_t outputSizeInBytes;

    /// The size in bytes required for the tempBuffer parameter to optixAccelBuild when
    /// doing a build (OPTIX_BUILD_OPERATION_BUILD).
    size_t tempSizeInBytes;

    /// The size in bytes required for the tempBuffer parameter to optixAccelBuild
    /// when doing an update (OPTIX_BUILD_OPERATION_UPDATE). This value can be different
    /// than tempSizeInBytes used for a full build. Only non-zero if
    /// OPTIX_BUILD_FLAG_ALLOW_UPDATE flag is set in OptixAccelBuildOptions.
    size_t tempUpdateSizeInBytes;
} OptixAccelBufferSizes;

/// Properties which can be emitted during acceleration structure build.
///
/// \see #OptixAccelEmitDesc::type.
typedef enum OptixAccelPropertyType
{
    /// Size of a compacted acceleration structure. The device pointer points to a uint64.
    OPTIX_PROPERTY_TYPE_COMPACTED_SIZE = 0x2181,

    /// OptixAabb * numMotionSteps
    OPTIX_PROPERTY_TYPE_AABBS = 0x2182,
} OptixAccelPropertyType;

/// Specifies a type and output destination for emitted post-build properties.
///
/// \see #optixAccelBuild()
typedef struct OptixAccelEmitDesc
{
    /// Output buffer for the properties
    CUdeviceptr result;

    /// Requested property
    OptixAccelPropertyType type;
} OptixAccelEmitDesc;

/// Used to store information related to relocation of acceleration structures.
///
/// \see #optixAccelGetRelocationInfo(), #optixAccelCheckRelocationCompatibility(), #optixAccelRelocate()
typedef struct OptixAccelRelocationInfo
{
    /// Opaque data, used internally, should not be modified
    unsigned long long info[4];
} OptixAccelRelocationInfo;

/// Static transform
///
/// The device address of instances of this type must be a multiple of OPTIX_TRANSFORM_BYTE_ALIGNMENT.
///
/// \see #optixConvertPointerToTraversableHandle()
typedef struct OptixStaticTransform
{
    /// The traversable transformed by this transformation
    OptixTraversableHandle child;

    /// Padding to make the transformations 16 byte aligned
    unsigned int pad[2];

    /// Affine object-to-world transformation as 3x4 matrix in row-major layout
    float transform[12];

    /// Affine world-to-object transformation as 3x4 matrix in row-major layout
    /// Must be the inverse of the transform matrix
    float invTransform[12];
} OptixStaticTransform;

/// Represents a matrix motion transformation.
/// /// The device address of instances of this type must be a multiple of OPTIX_TRANSFORM_BYTE_ALIGNMENT. /// /// This struct, as defined here, handles only N=2 motion keys due to the fixed array length of its transform member. /// The following example shows how to create instances for an arbitrary number N of motion keys: /// /// \code /// float matrixData[N][12]; /// ... // setup matrixData /// /// size_t transformSizeInBytes = sizeof( OptixMatrixMotionTransform ) + ( N-2 ) * 12 * sizeof( float ); /// OptixMatrixMotionTransform* matrixMoptionTransform = (OptixMatrixMotionTransform*) malloc( transformSizeInBytes ); /// memset( matrixMoptionTransform, 0, transformSizeInBytes ); /// /// ... // setup other members of matrixMoptionTransform /// matrixMoptionTransform->motionOptions.numKeys/// = N; /// memcpy( matrixMoptionTransform->transform, matrixData, N * 12 * sizeof( float ) ); /// /// ... // copy matrixMoptionTransform to device memory /// free( matrixMoptionTransform ) /// \endcode /// /// \see #optixConvertPointerToTraversableHandle() typedef struct OptixMatrixMotionTransform { /// The traversable that is transformed by this transformation OptixTraversableHandle child; /// The motion options for this transformation OptixMotionOptions motionOptions; /// Padding to make the transformation 16 byte aligned unsigned int pad[3]; /// Affine object-to-world transformation as 3x4 matrix in row-major layout float transform[2][12]; } OptixMatrixMotionTransform; /// Represents an SRT transformation. /// /// An SRT transformation can represent a smooth rotation with fewer motion keys than a matrix transformation. Each /// motion key is constructed from elements taken from a matrix S, a quaternion R, and a translation T. 
/// /// The scaling matrix /// \f$S = \begin{bmatrix} sx & a & b & pvx \\ 0 & sy & c & pvy \\ 0 & 0 & sz & pvz \end{bmatrix}\f$ // [ sx a b pvx ] // S = [ 0 sy c pvy ] // [ 0 0 sz pvz ] /// defines an affine transformation that can include scale, shear, and a translation. /// The translation allows to define the pivot point for the subsequent rotation. /// /// The quaternion R = [ qx, qy, qz, qw ] describes a rotation with angular component qw = cos(theta/2) and other /// components [ qx, qy, qz ] = sin(theta/2) * [ ax, ay, az ] where the axis [ ax, ay, az ] is normalized. /// /// The translation matrix /// \f$T = \begin{bmatrix} 1 & 0 & 0 & tx \\ 0 & 1 & 0 & ty \\ 0 & 0 & 1 & tz \end{bmatrix}\f$ // [ 1 0 0 tx ] // T = [ 0 1 0 ty ] // [ 0 0 1 tz ] /// defines another translation that is applied after the rotation. Typically, this translation includes /// the inverse translation from the matrix S to reverse the translation for the pivot point for R. /// /// To obtain the effective transformation at time t, the elements of the components of S, R, and T will be interpolated /// linearly. The components are then multiplied to obtain the combined transformation C = T * R * S. The transformation /// C is the effective object-to-world transformations at time t, and C^(-1) is the effective world-to-object /// transformation at time t. /// /// \see #OptixSRTMotionTransform::srtData, #optixConvertPointerToTraversableHandle() typedef struct OptixSRTData { /// \name Parameters describing the SRT transformation /// @{ float sx, a, b, pvx, sy, c, pvy, sz, pvz, qx, qy, qz, qw, tx, ty, tz; /// @} } OptixSRTData; // TODO Define a static assert for C/pre-C++-11 #if defined( __cplusplus ) && __cplusplus >= 201103L static_assert( sizeof( OptixSRTData ) == 16 * 4, "OptixSRTData has wrong size" ); #endif /// Represents an SRT motion transformation. /// /// The device address of instances of this type must be a multiple of OPTIX_TRANSFORM_BYTE_ALIGNMENT. 
/// /// This struct, as defined here, handles only N=2 motion keys due to the fixed array length of its srtData member. /// The following example shows how to create instances for an arbitrary number N of motion keys: /// /// \code /// OptixSRTData srtData[N]; /// ... // setup srtData /// /// size_t transformSizeInBytes = sizeof( OptixSRTMotionTransform ) + ( N-2 ) * sizeof( OptixSRTData ); /// OptixSRTMotionTransform* srtMotionTransform = (OptixSRTMotionTransform*) malloc( transformSizeInBytes ); /// memset( srtMotionTransform, 0, transformSizeInBytes ); /// /// ... // setup other members of srtMotionTransform /// srtMotionTransform->motionOptions.numKeys = N; /// memcpy( srtMotionTransform->srtData, srtData, N * sizeof( OptixSRTData ) ); /// /// ... // copy srtMotionTransform to device memory /// free( srtMotionTransform ) /// \endcode /// /// \see #optixConvertPointerToTraversableHandle() typedef struct OptixSRTMotionTransform { /// The traversable transformed by this transformation OptixTraversableHandle child; /// The motion options for this transformation OptixMotionOptions motionOptions; /// Padding to make the SRT data 16 byte aligned unsigned int pad[3]; /// The actual SRT data describing the transformation OptixSRTData srtData[2]; } OptixSRTMotionTransform; // TODO Define a static assert for C/pre-C++-11 #if defined( __cplusplus ) && __cplusplus >= 201103L static_assert( sizeof( OptixSRTMotionTransform ) == 8 + 12 + 12 + 2 * 16 * 4, "OptixSRTMotionTransform has wrong size" ); #endif /// Traversable Handles /// /// \see #optixConvertPointerToTraversableHandle() typedef enum OptixTraversableType { /// Static transforms. \see #OptixStaticTransform OPTIX_TRAVERSABLE_TYPE_STATIC_TRANSFORM = 0x21C1, /// Matrix motion transform. \see #OptixMatrixMotionTransform OPTIX_TRAVERSABLE_TYPE_MATRIX_MOTION_TRANSFORM = 0x21C2, /// SRT motion transform. 
\see #OptixSRTMotionTransform
    OPTIX_TRAVERSABLE_TYPE_SRT_MOTION_TRANSFORM = 0x21C3,
} OptixTraversableType;

/// Pixel formats used by the denoiser.
///
/// \see #OptixImage2D::format
typedef enum OptixPixelFormat
{
    OPTIX_PIXEL_FORMAT_HALF2  = 0x2207, ///< two halfs, XY
    OPTIX_PIXEL_FORMAT_HALF3  = 0x2201, ///< three halfs, RGB
    OPTIX_PIXEL_FORMAT_HALF4  = 0x2202, ///< four halfs, RGBA
    OPTIX_PIXEL_FORMAT_FLOAT2 = 0x2208, ///< two floats, XY
    OPTIX_PIXEL_FORMAT_FLOAT3 = 0x2203, ///< three floats, RGB
    OPTIX_PIXEL_FORMAT_FLOAT4 = 0x2204, ///< four floats, RGBA
    OPTIX_PIXEL_FORMAT_UCHAR3 = 0x2205, ///< three unsigned chars, RGB
    OPTIX_PIXEL_FORMAT_UCHAR4 = 0x2206  ///< four unsigned chars, RGBA
} OptixPixelFormat;

/// Image descriptor used by the denoiser.
///
/// \see #optixDenoiserInvoke(), #optixDenoiserComputeIntensity()
typedef struct OptixImage2D
{
    /// Pointer to the actual pixel data.
    CUdeviceptr data;
    /// Width of the image (in pixels)
    unsigned int width;
    /// Height of the image (in pixels)
    unsigned int height;
    /// Stride between subsequent rows of the image (in bytes).
    unsigned int rowStrideInBytes;
    /// Stride between subsequent pixels of the image (in bytes).
    /// For now, only 0 or the value that corresponds to a dense packing of pixels (no gaps) is supported.
    unsigned int pixelStrideInBytes;
    /// Pixel format.
    OptixPixelFormat format;
} OptixImage2D;

/// Model kind used by the denoiser.
///
/// \see #optixDenoiserCreate
typedef enum OptixDenoiserModelKind
{
    /// Use the built-in model appropriate for low dynamic range input.
    OPTIX_DENOISER_MODEL_KIND_LDR = 0x2322,

    /// Use the built-in model appropriate for high dynamic range input.
    OPTIX_DENOISER_MODEL_KIND_HDR = 0x2323,

    /// Use the built-in model appropriate for high dynamic range input and support for AOVs
    OPTIX_DENOISER_MODEL_KIND_AOV = 0x2324,

    /// Use the built-in model appropriate for high dynamic range input, temporally stable
    OPTIX_DENOISER_MODEL_KIND_TEMPORAL = 0x2325,
} OptixDenoiserModelKind;

/// Options used by the denoiser
///
/// \see #optixDenoiserCreate()
typedef struct OptixDenoiserOptions
{
    // if nonzero, albedo image must be given in OptixDenoiserGuideLayer
    unsigned int guideAlbedo;

    // if nonzero, normal image must be given in OptixDenoiserGuideLayer
    unsigned int guideNormal;
} OptixDenoiserOptions;

/// Guide layer for the denoiser
///
/// \see #optixDenoiserInvoke()
typedef struct OptixDenoiserGuideLayer
{
    // albedo/bsdf image
    OptixImage2D albedo;

    // normal vector image (2d or 3d pixel format)
    OptixImage2D normal;

    // 2d flow image, pixel flow from previous to current frame for each pixel
    OptixImage2D flow;
} OptixDenoiserGuideLayer;

/// Input/Output layers for the denoiser
///
/// \see #optixDenoiserInvoke()
typedef struct OptixDenoiserLayer
{
    // input image (beauty or AOV)
    OptixImage2D input;

    // denoised output image from previous frame if temporal model kind selected
    OptixImage2D previousOutput;

    // denoised output for given input
    OptixImage2D output;
} OptixDenoiserLayer;

/// Various parameters used by the denoiser
///
/// \see #optixDenoiserInvoke()
/// \see #optixDenoiserComputeIntensity()
/// \see #optixDenoiserComputeAverageColor()
typedef struct OptixDenoiserParams
{
    /// if set to nonzero value, denoise alpha channel (if present) in first inputLayer image
    unsigned int denoiseAlpha;

    /// average log intensity of input image (default null pointer). points to a single float.
    /// with the default (null pointer) denoised results will not be optimal for very dark or
    /// bright input images.
    CUdeviceptr hdrIntensity;

    /// blend factor.
    /// If set to 0 the output is 100% of the denoised input.
If set to 1, the output is 100% of
    /// the unmodified input. Values between 0 and 1 will linearly interpolate between the denoised
    /// and unmodified input.
    float blendFactor;

    /// this parameter is used when the OPTIX_DENOISER_MODEL_KIND_AOV model kind is set.
    /// average log color of input image, separate for RGB channels (default null pointer).
    /// points to three floats. with the default (null pointer) denoised results will not be
    /// optimal.
    CUdeviceptr hdrAverageColor;
} OptixDenoiserParams;

/// Various sizes related to the denoiser.
///
/// \see #optixDenoiserComputeMemoryResources()
typedef struct OptixDenoiserSizes
{
    size_t stateSizeInBytes;
    size_t withOverlapScratchSizeInBytes;
    size_t withoutOverlapScratchSizeInBytes;
    unsigned int overlapWindowSizeInPixels;
} OptixDenoiserSizes;

/// Ray flags passed to the device function #optixTrace(). These affect the behavior of
/// traversal per invocation.
///
/// \see #optixTrace()
typedef enum OptixRayFlags
{
    /// No change from the behavior configured for the individual AS.
    OPTIX_RAY_FLAG_NONE = 0u,

    /// Disables anyhit programs for the ray.
    /// Overrides OPTIX_INSTANCE_FLAG_ENFORCE_ANYHIT.
    /// This flag is mutually exclusive with OPTIX_RAY_FLAG_ENFORCE_ANYHIT,
    /// OPTIX_RAY_FLAG_CULL_DISABLED_ANYHIT, OPTIX_RAY_FLAG_CULL_ENFORCED_ANYHIT.
    OPTIX_RAY_FLAG_DISABLE_ANYHIT = 1u << 0,

    /// Forces anyhit program execution for the ray.
    /// Overrides OPTIX_GEOMETRY_FLAG_DISABLE_ANYHIT as well as OPTIX_INSTANCE_FLAG_DISABLE_ANYHIT.
    /// This flag is mutually exclusive with OPTIX_RAY_FLAG_DISABLE_ANYHIT,
    /// OPTIX_RAY_FLAG_CULL_DISABLED_ANYHIT, OPTIX_RAY_FLAG_CULL_ENFORCED_ANYHIT.
    OPTIX_RAY_FLAG_ENFORCE_ANYHIT = 1u << 1,

    /// Terminates the ray after the first hit and executes
    /// the closesthit program of that hit.
    OPTIX_RAY_FLAG_TERMINATE_ON_FIRST_HIT = 1u << 2,

    /// Disables closesthit programs for the ray, but still executes miss program in case of a miss.
    OPTIX_RAY_FLAG_DISABLE_CLOSESTHIT = 1u << 3,

    /// Do not intersect triangle back faces
    /// (respects a possible face change due to instance flag
    /// OPTIX_INSTANCE_FLAG_FLIP_TRIANGLE_FACING).
    /// This flag is mutually exclusive with OPTIX_RAY_FLAG_CULL_FRONT_FACING_TRIANGLES.
    OPTIX_RAY_FLAG_CULL_BACK_FACING_TRIANGLES = 1u << 4,

    /// Do not intersect triangle front faces
    /// (respects a possible face change due to instance flag
    /// OPTIX_INSTANCE_FLAG_FLIP_TRIANGLE_FACING).
    /// This flag is mutually exclusive with OPTIX_RAY_FLAG_CULL_BACK_FACING_TRIANGLES.
    OPTIX_RAY_FLAG_CULL_FRONT_FACING_TRIANGLES = 1u << 5,

    /// Do not intersect geometry which disables anyhit programs
    /// (due to setting geometry flag OPTIX_GEOMETRY_FLAG_DISABLE_ANYHIT or
    /// instance flag OPTIX_INSTANCE_FLAG_DISABLE_ANYHIT).
    /// This flag is mutually exclusive with OPTIX_RAY_FLAG_CULL_ENFORCED_ANYHIT,
    /// OPTIX_RAY_FLAG_ENFORCE_ANYHIT, OPTIX_RAY_FLAG_DISABLE_ANYHIT.
    OPTIX_RAY_FLAG_CULL_DISABLED_ANYHIT = 1u << 6,

    /// Do not intersect geometry which have an enabled anyhit program
    /// (due to not setting geometry flag OPTIX_GEOMETRY_FLAG_DISABLE_ANYHIT or
    /// setting instance flag OPTIX_INSTANCE_FLAG_ENFORCE_ANYHIT).
    /// This flag is mutually exclusive with OPTIX_RAY_FLAG_CULL_DISABLED_ANYHIT,
    /// OPTIX_RAY_FLAG_ENFORCE_ANYHIT, OPTIX_RAY_FLAG_DISABLE_ANYHIT.
    OPTIX_RAY_FLAG_CULL_ENFORCED_ANYHIT = 1u << 7
} OptixRayFlags;

/// Transform
///
/// OptixTransformType is used by the device function #optixGetTransformTypeFromHandle() to
/// determine the type of the OptixTraversableHandle returned from
/// optixGetTransformListHandle().
typedef enum OptixTransformType { OPTIX_TRANSFORM_TYPE_NONE = 0, ///< Not a transformation OPTIX_TRANSFORM_TYPE_STATIC_TRANSFORM = 1, ///< \see #OptixStaticTransform OPTIX_TRANSFORM_TYPE_MATRIX_MOTION_TRANSFORM = 2, ///< \see #OptixMatrixMotionTransform OPTIX_TRANSFORM_TYPE_SRT_MOTION_TRANSFORM = 3, ///< \see #OptixSRTMotionTransform OPTIX_TRANSFORM_TYPE_INSTANCE = 4, ///< \see #OptixInstance } OptixTransformType; /// Specifies the set of valid traversable graphs that may be /// passed to invocation of #optixTrace(). Flags may be bitwise combined. typedef enum OptixTraversableGraphFlags { /// Used to signal that any traversable graphs is valid. /// This flag is mutually exclusive with all other flags. OPTIX_TRAVERSABLE_GRAPH_FLAG_ALLOW_ANY = 0, /// Used to signal that a traversable graph of a single Geometry Acceleration /// Structure (GAS) without any transforms is valid. This flag may be combined with /// other flags except for OPTIX_TRAVERSABLE_GRAPH_FLAG_ALLOW_ANY. OPTIX_TRAVERSABLE_GRAPH_FLAG_ALLOW_SINGLE_GAS = 1u << 0, /// Used to signal that a traversable graph of a single Instance Acceleration /// Structure (IAS) directly connected to Geometry Acceleration Structure (GAS) /// traversables without transform traversables in between is valid. This flag may /// be combined with other flags except for OPTIX_TRAVERSABLE_GRAPH_FLAG_ALLOW_ANY. 
OPTIX_TRAVERSABLE_GRAPH_FLAG_ALLOW_SINGLE_LEVEL_INSTANCING = 1u << 1, } OptixTraversableGraphFlags; /// Optimization levels /// /// \see #OptixModuleCompileOptions::optLevel typedef enum OptixCompileOptimizationLevel { /// Default is to run all optimizations OPTIX_COMPILE_OPTIMIZATION_DEFAULT = 0, /// No optimizations OPTIX_COMPILE_OPTIMIZATION_LEVEL_0 = 0x2340, /// Some optimizations OPTIX_COMPILE_OPTIMIZATION_LEVEL_1 = 0x2341, /// Most optimizations OPTIX_COMPILE_OPTIMIZATION_LEVEL_2 = 0x2342, /// All optimizations OPTIX_COMPILE_OPTIMIZATION_LEVEL_3 = 0x2343, } OptixCompileOptimizationLevel; /// Debug levels /// /// \see #OptixModuleCompileOptions::debugLevel typedef enum OptixCompileDebugLevel { /// Default currently is to add line info OPTIX_COMPILE_DEBUG_LEVEL_DEFAULT = 0, /// No debug information OPTIX_COMPILE_DEBUG_LEVEL_NONE = 0x2350, /// Generate lineinfo only OPTIX_COMPILE_DEBUG_LEVEL_LINEINFO = 0x2351, /// Generate dwarf debug information. OPTIX_COMPILE_DEBUG_LEVEL_FULL = 0x2352, } OptixCompileDebugLevel; /// Struct for specifying specializations for pipelineParams as specified in /// OptixPipelineCompileOptions::pipelineLaunchParamsVariableName. /// /// The bound values are supposed to represent a constant value in the /// pipelineParams. OptiX will attempt to locate all loads from the pipelineParams and /// correlate them to the appropriate bound value, but there are cases where OptiX cannot /// safely or reliably do this. For example if the pointer to the pipelineParams is passed /// as an argument to a non-inline function or the offset of the load to the /// pipelineParams cannot be statically determined (e.g. accessed in a loop). No module /// should rely on the value being specialized in order to work correctly. The values in /// the pipelineParams specified on optixLaunch should match the bound value. 
If /// validation mode is enabled on the context, OptiX will verify that the bound values /// specified matches the values in pipelineParams specified to optixLaunch. /// /// These values are compiled in to the module as constants. Once the constants are /// inserted into the code, an optimization pass will be run that will attempt to /// propagate the consants and remove unreachable code. /// /// If caching is enabled, changes in these values will result in newly compiled modules. /// /// The pipelineParamOffset and sizeInBytes must be within the bounds of the /// pipelineParams variable. OPTIX_ERROR_INVALID_VALUE will be returned from /// optixModuleCreateFromPTX otherwise. /// /// If more than one bound value overlaps or the size of a bound value is equal to 0, /// an OPTIX_ERROR_INVALID_VALUE will be returned from optixModuleCreateFromPTX. /// /// The same set of bound values do not need to be used for all modules in a pipeline, but /// overlapping values between modules must have the same value. /// OPTIX_ERROR_INVALID_VALUE will be returned from optixPipelineCreate otherwise. /// /// \see #OptixModuleCompileOptions typedef struct OptixModuleCompileBoundValueEntry { size_t pipelineParamOffsetInBytes; size_t sizeInBytes; const void* boundValuePtr; const char* annotation; // optional string to display, set to 0 if unused. If unused, // OptiX will report the annotation as "No annotation" } OptixModuleCompileBoundValueEntry; /// Compilation options for module /// /// \see #optixModuleCreateFromPTX() typedef struct OptixModuleCompileOptions { /// Maximum number of registers allowed when compiling to SASS. /// Set to 0 for no explicit limit. May vary within a pipeline. int maxRegisterCount; /// Optimization level. May vary within a pipeline. OptixCompileOptimizationLevel optLevel; /// Generate debug information. 
OptixCompileDebugLevel debugLevel; /// Ingored if numBoundValues is set to 0 const OptixModuleCompileBoundValueEntry* boundValues; /// set to 0 if unused unsigned int numBoundValues; } OptixModuleCompileOptions; /// Distinguishes different kinds of program groups. typedef enum OptixProgramGroupKind { /// Program group containing a raygen (RG) program /// \see #OptixProgramGroupSingleModule, #OptixProgramGroupDesc::raygen OPTIX_PROGRAM_GROUP_KIND_RAYGEN = 0x2421, /// Program group containing a miss (MS) program /// \see #OptixProgramGroupSingleModule, #OptixProgramGroupDesc::miss OPTIX_PROGRAM_GROUP_KIND_MISS = 0x2422, /// Program group containing an exception (EX) program /// \see OptixProgramGroupHitgroup, #OptixProgramGroupDesc::exception OPTIX_PROGRAM_GROUP_KIND_EXCEPTION = 0x2423, /// Program group containing an intersection (IS), any hit (AH), and/or closest hit (CH) program /// \see #OptixProgramGroupSingleModule, #OptixProgramGroupDesc::hitgroup OPTIX_PROGRAM_GROUP_KIND_HITGROUP = 0x2424, /// Program group containing a direct (DC) or continuation (CC) callable program /// \see OptixProgramGroupCallables, #OptixProgramGroupDesc::callables OPTIX_PROGRAM_GROUP_KIND_CALLABLES = 0x2425 } OptixProgramGroupKind; /// Flags for program groups typedef enum OptixProgramGroupFlags { /// Currently there are no flags OPTIX_PROGRAM_GROUP_FLAGS_NONE = 0 } OptixProgramGroupFlags; /// Program group representing a single module. /// /// Used for raygen, miss, and exception programs. In case of raygen and exception programs, module and entry /// function name need to be valid. For miss programs, module and entry function name might both be \c nullptr. /// /// \see #OptixProgramGroupDesc::raygen, #OptixProgramGroupDesc::miss, #OptixProgramGroupDesc::exception typedef struct OptixProgramGroupSingleModule { /// Module holding single program. OptixModule module; /// Entry function name of the single program. 
const char* entryFunctionName; } OptixProgramGroupSingleModule; /// Program group representing the hitgroup. /// /// For each of the three program types, module and entry function name might both be \c nullptr. /// /// \see #OptixProgramGroupDesc::hitgroup typedef struct OptixProgramGroupHitgroup { /// Module holding the closest hit (CH) program. OptixModule moduleCH; /// Entry function name of the closest hit (CH) program. const char* entryFunctionNameCH; /// Module holding the any hit (AH) program. OptixModule moduleAH; /// Entry function name of the any hit (AH) program. const char* entryFunctionNameAH; /// Module holding the intersection (Is) program. OptixModule moduleIS; /// Entry function name of the intersection (IS) program. const char* entryFunctionNameIS; } OptixProgramGroupHitgroup; /// Program group representing callables. /// /// Module and entry function name need to be valid for at least one of the two callables. /// /// \see ##OptixProgramGroupDesc::callables typedef struct OptixProgramGroupCallables { /// Module holding the direct callable (DC) program. OptixModule moduleDC; /// Entry function name of the direct callable (DC) program. const char* entryFunctionNameDC; /// Module holding the continuation callable (CC) program. OptixModule moduleCC; /// Entry function name of the continuation callable (CC) program. const char* entryFunctionNameCC; } OptixProgramGroupCallables; /// Descriptor for program groups. typedef struct OptixProgramGroupDesc { /// The kind of program group. 
OptixProgramGroupKind kind; /// See #OptixProgramGroupFlags unsigned int flags; union { /// \see #OPTIX_PROGRAM_GROUP_KIND_RAYGEN OptixProgramGroupSingleModule raygen; /// \see #OPTIX_PROGRAM_GROUP_KIND_MISS OptixProgramGroupSingleModule miss; /// \see #OPTIX_PROGRAM_GROUP_KIND_EXCEPTION OptixProgramGroupSingleModule exception; /// \see #OPTIX_PROGRAM_GROUP_KIND_CALLABLES OptixProgramGroupCallables callables; /// \see #OPTIX_PROGRAM_GROUP_KIND_HITGROUP OptixProgramGroupHitgroup hitgroup; }; } OptixProgramGroupDesc; /// Program group options /// /// \see #optixProgramGroupCreate() typedef struct OptixProgramGroupOptions { /// reserved value for future use. must be 0. int reserved; } OptixProgramGroupOptions; /// The following values are used to indicate which exception was thrown. typedef enum OptixExceptionCodes { /// Stack overflow of the continuation stack. /// no exception details. OPTIX_EXCEPTION_CODE_STACK_OVERFLOW = -1, /// The trace depth is exceeded. /// no exception details. OPTIX_EXCEPTION_CODE_TRACE_DEPTH_EXCEEDED = -2, /// The traversal depth is exceeded. /// Exception details: /// optixGetTransformListSize() /// optixGetTransformListHandle() OPTIX_EXCEPTION_CODE_TRAVERSAL_DEPTH_EXCEEDED = -3, /// Traversal encountered an invalid traversable type. /// Exception details: /// optixGetTransformListSize() /// optixGetTransformListHandle() /// optixGetExceptionInvalidTraversable() OPTIX_EXCEPTION_CODE_TRAVERSAL_INVALID_TRAVERSABLE = -5, /// The miss SBT record index is out of bounds /// A miss SBT record index is valid within the range [0, OptixShaderBindingTable::missRecordCount) (See optixLaunch) /// Exception details: /// optixGetExceptionInvalidSbtOffset() OPTIX_EXCEPTION_CODE_TRAVERSAL_INVALID_MISS_SBT = -6, /// The traversal hit SBT record index out of bounds. 
/// /// A traversal hit SBT record index is valid within the range [0, OptixShaderBindingTable::hitgroupRecordCount) (See optixLaunch) /// The following formula relates the // sbt-index (See optixGetExceptionInvalidSbtOffset), // sbt-instance-offset (See OptixInstance::sbtOffset), /// sbt-geometry-acceleration-structure-index (See optixGetSbtGASIndex), /// sbt-stride-from-trace-call and sbt-offset-from-trace-call (See optixTrace) /// /// sbt-index = sbt-instance-offset + (sbt-geometry-acceleration-structure-index * sbt-stride-from-trace-call) + sbt-offset-from-trace-call /// /// Exception details: /// optixGetTransformListSize() /// optixGetTransformListHandle() /// optixGetExceptionInvalidSbtOffset() /// optixGetSbtGASIndex() OPTIX_EXCEPTION_CODE_TRAVERSAL_INVALID_HIT_SBT = -7, /// The shader encountered an unsupported primitive type (See OptixPipelineCompileOptions::usesPrimitiveTypeFlags). /// no exception details. OPTIX_EXCEPTION_CODE_UNSUPPORTED_PRIMITIVE_TYPE = -8, /// The shader encountered a call to optixTrace with at least /// one of the float arguments being inf or nan. /// Exception details: /// optixGetExceptionInvalidRay() OPTIX_EXCEPTION_CODE_INVALID_RAY = -9, /// The shader encountered a call to either optixDirectCall or optixCallableCall /// where the argument count does not match the parameter count of the callable /// program which is called. /// Exception details: /// optixGetExceptionParameterMismatch OPTIX_EXCEPTION_CODE_CALLABLE_PARAMETER_MISMATCH = -10, /// The invoked builtin IS does not match the current GAS OPTIX_EXCEPTION_CODE_BUILTIN_IS_MISMATCH = -11, /// Tried to call a callable program using an SBT offset that is larger /// than the number of passed in callable SBT records. /// Exception details: /// optixGetExceptionInvalidSbtOffset() OPTIX_EXCEPTION_CODE_CALLABLE_INVALID_SBT = -12, /// Tried to call a direct callable using an SBT offset of a record that /// was built from a program group that did not include a direct callable. 
OPTIX_EXCEPTION_CODE_CALLABLE_NO_DC_SBT_RECORD = -13, /// Tried to call a continuation callable using an SBT offset of a record /// that was built from a program group that did not include a continuation callable. OPTIX_EXCEPTION_CODE_CALLABLE_NO_CC_SBT_RECORD = -14, /// Tried to directly traverse a single gas while single gas traversable graphs are not enabled /// (see OptixTraversableGraphFlags::OPTIX_TRAVERSABLE_GRAPH_FLAG_ALLOW_SINGLE_GAS). /// Exception details: /// optixGetTransformListSize() /// optixGetTransformListHandle() /// optixGetExceptionInvalidTraversable() OPTIX_EXCEPTION_CODE_UNSUPPORTED_SINGLE_LEVEL_GAS = -15, /// argument passed to an optix call is /// not within an acceptable range of values. OPTIX_EXCEPTION_CODE_INVALID_VALUE_ARGUMENT_0 = -16, OPTIX_EXCEPTION_CODE_INVALID_VALUE_ARGUMENT_1 = -17, OPTIX_EXCEPTION_CODE_INVALID_VALUE_ARGUMENT_2 = -18, /// Tried to access data on an AS without random data access support (See OptixBuildFlags). OPTIX_EXCEPTION_CODE_UNSUPPORTED_DATA_ACCESS = -32, } OptixExceptionCodes; /// Exception flags. /// /// \see #OptixPipelineCompileOptions::exceptionFlags, #OptixExceptionCodes typedef enum OptixExceptionFlags { /// No exception are enabled. OPTIX_EXCEPTION_FLAG_NONE = 0, /// Enables exceptions check related to the continuation stack. OPTIX_EXCEPTION_FLAG_STACK_OVERFLOW = 1u << 0, /// Enables exceptions check related to trace depth. OPTIX_EXCEPTION_FLAG_TRACE_DEPTH = 1u << 1, /// Enables user exceptions via optixThrowException(). This flag must be specified for all modules in a pipeline /// if any module calls optixThrowException(). OPTIX_EXCEPTION_FLAG_USER = 1u << 2, /// Enables various exceptions check related to traversal. OPTIX_EXCEPTION_FLAG_DEBUG = 1u << 3 } OptixExceptionFlags; /// Compilation options for all modules of a pipeline. /// /// Similar to #OptixModuleCompileOptions, but these options here need to be equal for all modules of a pipeline. 
/// /// \see #optixModuleCreateFromPTX(), #optixPipelineCreate() typedef struct OptixPipelineCompileOptions { /// Boolean value indicating whether motion blur could be used int usesMotionBlur; /// Traversable graph bitfield. See OptixTraversableGraphFlags unsigned int traversableGraphFlags; /// How much storage, in 32b words, to make available for the payload, [0..32] int numPayloadValues; /// How much storage, in 32b words, to make available for the attributes. The /// minimum number is 2. Values below that will automatically be changed to 2. [2..8] int numAttributeValues; /// A bitmask of OptixExceptionFlags indicating which exceptions are enabled. unsigned int exceptionFlags; /// The name of the pipeline parameter variable. If 0, no pipeline parameter /// will be available. This will be ignored if the launch param variable was /// optimized out or was not found in the modules linked to the pipeline. const char* pipelineLaunchParamsVariableName; /// Bit field enabling primitive types. See OptixPrimitiveTypeFlags. /// Setting to zero corresponds to enabling OPTIX_PRIMITIVE_TYPE_FLAGS_CUSTOM and OPTIX_PRIMITIVE_TYPE_FLAGS_TRIANGLE. unsigned int usesPrimitiveTypeFlags; // Reserved for future use.These values must be set to zero. unsigned int reserved; size_t reserved2; } OptixPipelineCompileOptions; /// Link options for a pipeline /// /// \see #optixPipelineCreate() typedef struct OptixPipelineLinkOptions { /// Maximum trace recursion depth. 0 means a ray generation program can be /// launched, but can't trace any rays. The maximum allowed value is 31. unsigned int maxTraceDepth; /// Generate debug information. OptixCompileDebugLevel debugLevel; } OptixPipelineLinkOptions; /// Describes the shader binding table (SBT) /// /// \see #optixLaunch() typedef struct OptixShaderBindingTable { /// Device address of the SBT record of the ray gen program to start launch at. The address must be a multiple of /// OPTIX_SBT_RECORD_ALIGNMENT. 
CUdeviceptr raygenRecord; /// Device address of the SBT record of the exception program. The address must be a multiple of /// OPTIX_SBT_RECORD_ALIGNMENT. CUdeviceptr exceptionRecord; /// Arrays of SBT records for miss programs. The base address and the stride must be a multiple of /// OPTIX_SBT_RECORD_ALIGNMENT. /// @{ CUdeviceptr missRecordBase; unsigned int missRecordStrideInBytes; unsigned int missRecordCount; /// @} /// Arrays of SBT records for hit groups. The base address and the stride must be a multiple of /// OPTIX_SBT_RECORD_ALIGNMENT. /// @{ CUdeviceptr hitgroupRecordBase; unsigned int hitgroupRecordStrideInBytes; unsigned int hitgroupRecordCount; /// @} /// Arrays of SBT records for callable programs. If the base address is not null, the stride and count must not be /// zero. If the base address is null, then the count needs to zero. The base address and the stride must be a /// multiple of OPTIX_SBT_RECORD_ALIGNMENT. /// @{ CUdeviceptr callablesRecordBase; unsigned int callablesRecordStrideInBytes; unsigned int callablesRecordCount; /// @} } OptixShaderBindingTable; /// Describes the stack size requirements of a program group. 
/// /// \see optixProgramGroupGetStackSize() typedef struct OptixStackSizes { /// Continuation stack size of RG programs in bytes unsigned int cssRG; /// Continuation stack size of MS programs in bytes unsigned int cssMS; /// Continuation stack size of CH programs in bytes unsigned int cssCH; /// Continuation stack size of AH programs in bytes unsigned int cssAH; /// Continuation stack size of IS programs in bytes unsigned int cssIS; /// Continuation stack size of CC programs in bytes unsigned int cssCC; /// Direct stack size of DC programs in bytes unsigned int dssDC; } OptixStackSizes; /// Options that can be passed to \c optixQueryFunctionTable() typedef enum OptixQueryFunctionTableOptions { /// Placeholder (there are no options yet) OPTIX_QUERY_FUNCTION_TABLE_OPTION_DUMMY = 0 } OptixQueryFunctionTableOptions; /// Type of the function \c optixQueryFunctionTable() typedef OptixResult( OptixQueryFunctionTable_t )( int abiId, unsigned int numOptions, OptixQueryFunctionTableOptions* /*optionKeys*/, const void** /*optionValues*/, void* functionTable, size_t sizeOfTable ); /// Specifies the options for retrieving an intersection program for a built-in primitive type. /// The primitive type must not be OPTIX_PRIMITIVE_TYPE_CUSTOM. /// /// \see #optixBuiltinISModuleGet() typedef struct OptixBuiltinISOptions { OptixPrimitiveType builtinISModuleType; /// Boolean value indicating whether vertex motion blur is used (but not motion transform blur). int usesMotionBlur; } OptixBuiltinISOptions; #if defined( __CUDACC__ ) /// Describes the ray that was passed into \c optixTrace() which caused an exception with /// exception code OPTIX_EXCEPTION_CODE_INVALID_RAY. 
/// /// \see #optixGetExceptionInvalidRay() typedef struct OptixInvalidRayExceptionDetails { float3 origin; float3 direction; float tmin; float tmax; float time; } OptixInvalidRayExceptionDetails; /// Describes the details of a call to a callable program which caused an exception with /// exception code OPTIX_EXCEPTION_CODE_CALLABLE_PARAMETER_MISMATCH, /// Note that OptiX packs the parameters into individual 32 bit values, so the number of /// expected and passed values may not correspond to the number of arguments passed into /// optixDirectCall or optixContinuationCall, or the number parameters in the definition /// of the function that is called. typedef struct OptixParameterMismatchExceptionDetails { /// Number of 32 bit values expected by the callable program unsigned int expectedParameterCount; /// Number of 32 bit values that were passed to the callable program unsigned int passedArgumentCount; /// The offset of the SBT entry of the callable program relative to OptixShaderBindingTable::callablesRecordBase unsigned int sbtIndex; /// Pointer to a string that holds the name of the callable program that was called char* callableName; } OptixParameterMismatchExceptionDetails; #endif /*@}*/ // end group optix_types #endif // __optix_optix_7_types_h__ ================================================ FILE: render/optixutils/include/optix_denoiser_tiling.h ================================================ /* * Copyright (c) 2021 NVIDIA Corporation. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions * are met: * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. 
* * Neither the name of NVIDIA CORPORATION nor the names of its * contributors may be used to endorse or promote products derived * from this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ /// @file /// @author NVIDIA Corporation /// @brief OptiX public API header #ifndef optix_denoiser_tiling_h #define optix_denoiser_tiling_h #include #include #include #ifdef __cplusplus extern "C" { #endif /** \addtogroup optix_utilities @{ */ /// Tile definition /// /// see #optixUtilDenoiserSplitImage /// struct OptixUtilDenoiserImageTile { // input tile image OptixImage2D input; // output tile image OptixImage2D output; // overlap offsets, parameters for #optixUtilDenoiserInvoke unsigned int inputOffsetX; unsigned int inputOffsetY; }; /// Return pixel stride in bytes for the given pixel format /// if the pixelStrideInBytes member of the image is zero. /// Otherwise return pixelStrideInBytes from the image. 
/// /// \param[in] image Image containing the pixel stride /// inline unsigned int optixUtilGetPixelStride( const OptixImage2D& image ) { unsigned int pixelStrideInBytes = image.pixelStrideInBytes; if( pixelStrideInBytes == 0 ) { switch( image.format ) { case OPTIX_PIXEL_FORMAT_HALF2: pixelStrideInBytes = 2 * sizeof( short ); break; case OPTIX_PIXEL_FORMAT_HALF3: pixelStrideInBytes = 3 * sizeof( short ); break; case OPTIX_PIXEL_FORMAT_HALF4: pixelStrideInBytes = 4 * sizeof( short ); break; case OPTIX_PIXEL_FORMAT_FLOAT2: pixelStrideInBytes = 2 * sizeof( float ); break; case OPTIX_PIXEL_FORMAT_FLOAT3: pixelStrideInBytes = 3 * sizeof( float ); break; case OPTIX_PIXEL_FORMAT_FLOAT4: pixelStrideInBytes = 4 * sizeof( float ); break; case OPTIX_PIXEL_FORMAT_UCHAR3: pixelStrideInBytes = 3 * sizeof( char ); break; case OPTIX_PIXEL_FORMAT_UCHAR4: pixelStrideInBytes = 4 * sizeof( char ); break; } } return pixelStrideInBytes; } /// Split image into 2D tiles given horizontal and vertical tile size /// /// \param[in] input full resolution input image to be split /// \param[in] output full resolution output image /// \param[in] overlapWindowSizeInPixels see #OptixDenoiserSizes, #optixDenoiserComputeMemoryResources /// \param[in] tileWidth maximum width of tiles /// \param[in] tileHeight maximum height of tiles /// \param[out] tiles list of tiles covering the input image /// inline OptixResult optixUtilDenoiserSplitImage( const OptixImage2D& input, const OptixImage2D& output, unsigned int overlapWindowSizeInPixels, unsigned int tileWidth, unsigned int tileHeight, std::vector& tiles ) { if( tileWidth == 0 || tileHeight == 0 ) return OPTIX_ERROR_INVALID_VALUE; unsigned int inPixelStride = optixUtilGetPixelStride( input ); unsigned int outPixelStride = optixUtilGetPixelStride( output ); int inp_w = std::min( tileWidth + 2 * overlapWindowSizeInPixels, input.width ); int inp_h = std::min( tileHeight + 2 * overlapWindowSizeInPixels, input.height ); int inp_y = 0, copied_y = 0; do { int 
inputOffsetY = inp_y == 0 ? 0 : std::max( (int)overlapWindowSizeInPixels, inp_h - ( (int)input.height - inp_y ) ); int copy_y = inp_y == 0 ? std::min( input.height, tileHeight + overlapWindowSizeInPixels ) : std::min( tileHeight, input.height - copied_y ); int inp_x = 0, copied_x = 0; do { int inputOffsetX = inp_x == 0 ? 0 : std::max( (int)overlapWindowSizeInPixels, inp_w - ( (int)input.width - inp_x ) ); int copy_x = inp_x == 0 ? std::min( input.width, tileWidth + overlapWindowSizeInPixels ) : std::min( tileWidth, input.width - copied_x ); OptixUtilDenoiserImageTile tile; tile.input.data = input.data + ( inp_y - inputOffsetY ) * input.rowStrideInBytes + ( inp_x - inputOffsetX ) * inPixelStride; tile.input.width = inp_w; tile.input.height = inp_h; tile.input.rowStrideInBytes = input.rowStrideInBytes; tile.input.pixelStrideInBytes = input.pixelStrideInBytes; tile.input.format = input.format; tile.output.data = output.data + inp_y * output.rowStrideInBytes + inp_x * outPixelStride; tile.output.width = copy_x; tile.output.height = copy_y; tile.output.rowStrideInBytes = output.rowStrideInBytes; tile.output.pixelStrideInBytes = output.pixelStrideInBytes; tile.output.format = output.format; tile.inputOffsetX = inputOffsetX; tile.inputOffsetY = inputOffsetY; tiles.push_back( tile ); inp_x += inp_x == 0 ? tileWidth + overlapWindowSizeInPixels : tileWidth; copied_x += copy_x; } while( inp_x < static_cast( input.width ) ); inp_y += inp_y == 0 ? tileHeight + overlapWindowSizeInPixels : tileHeight; copied_y += copy_y; } while( inp_y < static_cast( input.height ) ); return OPTIX_SUCCESS; } /// Run denoiser on input layers /// see #optixDenoiserInvoke /// additional parameters: /// Runs the denoiser on the input layers on a single GPU and stream using #optixDenoiserInvoke. 
/// If the input layers' dimensions are larger than the specified tile size, the image is divided into /// tiles using #optixUtilDenoiserSplitImage, and multiple back-to-back invocations are performed in /// order to reuse the scratch space. Multiple tiles can be invoked concurrently if /// #optixUtilDenoiserSplitImage is used directly and multiple scratch allocations for each concurrent /// invocation are used. /// The input parameters are the same as #optixDenoiserInvoke except for the addition of the maximum tile size. /// /// \param[in] denoiser /// \param[in] stream /// \param[in] params /// \param[in] denoiserState /// \param[in] denoiserStateSizeInBytes /// \param[in] guideLayer /// \param[in] layers /// \param[in] numLayers /// \param[in] scratch /// \param[in] scratchSizeInBytes /// \param[in] overlapWindowSizeInPixels /// \param[in] tileWidth /// \param[in] tileHeight inline OptixResult optixUtilDenoiserInvokeTiled( OptixDenoiser denoiser, CUstream stream, const OptixDenoiserParams* params, CUdeviceptr denoiserState, size_t denoiserStateSizeInBytes, const OptixDenoiserGuideLayer* guideLayer, const OptixDenoiserLayer* layers, unsigned int numLayers, CUdeviceptr scratch, size_t scratchSizeInBytes, unsigned int overlapWindowSizeInPixels, unsigned int tileWidth, unsigned int tileHeight ) { if( !guideLayer || !layers ) return OPTIX_ERROR_INVALID_VALUE; std::vector> tiles( numLayers ); std::vector> prevTiles( numLayers ); for( unsigned int l = 0; l < numLayers; l++ ) { if( const OptixResult res = optixUtilDenoiserSplitImage( layers[l].input, layers[l].output, overlapWindowSizeInPixels, tileWidth, tileHeight, tiles[l] ) ) return res; if( layers[l].previousOutput.data ) { OptixImage2D dummyOutput = layers[l].previousOutput; if( const OptixResult res = optixUtilDenoiserSplitImage( layers[l].previousOutput, dummyOutput, overlapWindowSizeInPixels, tileWidth, tileHeight, prevTiles[l] ) ) return res; } } std::vector albedoTiles; if( guideLayer->albedo.data ) { 
OptixImage2D dummyOutput = guideLayer->albedo; if( const OptixResult res = optixUtilDenoiserSplitImage( guideLayer->albedo, dummyOutput, overlapWindowSizeInPixels, tileWidth, tileHeight, albedoTiles ) ) return res; } std::vector normalTiles; if( guideLayer->normal.data ) { OptixImage2D dummyOutput = guideLayer->normal; if( const OptixResult res = optixUtilDenoiserSplitImage( guideLayer->normal, dummyOutput, overlapWindowSizeInPixels, tileWidth, tileHeight, normalTiles ) ) return res; } std::vector flowTiles; if( guideLayer->flow.data ) { OptixImage2D dummyOutput = guideLayer->flow; if( const OptixResult res = optixUtilDenoiserSplitImage( guideLayer->flow, dummyOutput, overlapWindowSizeInPixels, tileWidth, tileHeight, flowTiles ) ) return res; } for( size_t t = 0; t < tiles[0].size(); t++ ) { std::vector tlayers; for( unsigned int l = 0; l < numLayers; l++ ) { OptixDenoiserLayer layer = {}; layer.input = ( tiles[l] )[t].input; layer.output = ( tiles[l] )[t].output; if( layers[l].previousOutput.data ) layer.previousOutput = ( prevTiles[l] )[t].input; tlayers.push_back( layer ); } OptixDenoiserGuideLayer gl = {}; if( guideLayer->albedo.data ) gl.albedo = albedoTiles[t].input; if( guideLayer->normal.data ) gl.normal = normalTiles[t].input; if( guideLayer->flow.data ) gl.flow = flowTiles[t].input; if( const OptixResult res = optixDenoiserInvoke( denoiser, stream, params, denoiserState, denoiserStateSizeInBytes, &gl, &tlayers[0], numLayers, ( tiles[0] )[t].inputOffsetX, ( tiles[0] )[t].inputOffsetY, scratch, scratchSizeInBytes ) ) return res; } return OPTIX_SUCCESS; } /*@}*/ // end group optix_utilities #ifdef __cplusplus } #endif #endif // __optix_optix_stack_size_h__ ================================================ FILE: render/optixutils/include/optix_device.h ================================================ /* * Copyright (c) 2021 NVIDIA Corporation. All rights reserved. 
* * NVIDIA Corporation and its licensors retain all intellectual property and proprietary * rights in and to this software, related documentation and any modifications thereto. * Any use, reproduction, disclosure or distribution of this software and related * documentation without an express license agreement from NVIDIA Corporation is strictly * prohibited. * * TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, THIS SOFTWARE IS PROVIDED *AS IS* * AND NVIDIA AND ITS SUPPLIERS DISCLAIM ALL WARRANTIES, EITHER EXPRESS OR IMPLIED, * INCLUDING, BUT NOT LIMITED TO, IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A * PARTICULAR PURPOSE. IN NO EVENT SHALL NVIDIA OR ITS SUPPLIERS BE LIABLE FOR ANY * SPECIAL, INCIDENTAL, INDIRECT, OR CONSEQUENTIAL DAMAGES WHATSOEVER (INCLUDING, WITHOUT * LIMITATION, DAMAGES FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF * BUSINESS INFORMATION, OR ANY OTHER PECUNIARY LOSS) ARISING OUT OF THE USE OF OR * INABILITY TO USE THIS SOFTWARE, EVEN IF NVIDIA HAS BEEN ADVISED OF THE POSSIBILITY OF * SUCH DAMAGES */ /** * @file optix_device.h * @author NVIDIA Corporation * @brief OptiX public API * * OptiX public API Reference - Host/Device side */ /******************************************************************************\ * optix_cuda.h * * This file provides the nvcc interface for generating PTX that the OptiX is * capable of parsing and weaving into the final kernel. This is included by * optix.h automatically if compiling device code. It can be included explicitly * in host code if desired. 
* \******************************************************************************/ #if !defined(__OPTIX_INCLUDE_INTERNAL_HEADERS__) # define __OPTIX_INCLUDE_INTERNAL_HEADERS__ # define __UNDEF_OPTIX_INCLUDE_INTERNAL_HEADERS_OPTIX_DEVICE_H__ #endif #include "optix_7_device.h" #if defined( __UNDEF_OPTIX_INCLUDE_INTERNAL_HEADERS_OPTIX_DEVICE_H__ ) # undef __OPTIX_INCLUDE_INTERNAL_HEADERS__ # undef __UNDEF_OPTIX_INCLUDE_INTERNAL_HEADERS_OPTIX_DEVICE_H__ #endif ================================================ FILE: render/optixutils/include/optix_function_table.h ================================================ /* * Copyright (c) 2021 NVIDIA Corporation. All rights reserved. * * NVIDIA Corporation and its licensors retain all intellectual property and proprietary * rights in and to this software, related documentation and any modifications thereto. * Any use, reproduction, disclosure or distribution of this software and related * documentation without an express license agreement from NVIDIA Corporation is strictly * prohibited. * * TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, THIS SOFTWARE IS PROVIDED *AS IS* * AND NVIDIA AND ITS SUPPLIERS DISCLAIM ALL WARRANTIES, EITHER EXPRESS OR IMPLIED, * INCLUDING, BUT NOT LIMITED TO, IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A * PARTICULAR PURPOSE. IN NO EVENT SHALL NVIDIA OR ITS SUPPLIERS BE LIABLE FOR ANY * SPECIAL, INCIDENTAL, INDIRECT, OR CONSEQUENTIAL DAMAGES WHATSOEVER (INCLUDING, WITHOUT * LIMITATION, DAMAGES FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF * BUSINESS INFORMATION, OR ANY OTHER PECUNIARY LOSS) ARISING OUT OF THE USE OF OR * INABILITY TO USE THIS SOFTWARE, EVEN IF NVIDIA HAS BEEN ADVISED OF THE POSSIBILITY OF * SUCH DAMAGES */ /// @file /// @author NVIDIA Corporation /// @brief OptiX public API header #ifndef __optix_optix_function_table_h__ #define __optix_optix_function_table_h__ /// The OptiX ABI version. 
#define OPTIX_ABI_VERSION 47 #ifndef OPTIX_DEFINE_ABI_VERSION_ONLY #include "optix_types.h" #if !defined( OPTIX_DONT_INCLUDE_CUDA ) // If OPTIX_DONT_INCLUDE_CUDA is defined, cuda driver types must be defined through other // means before including optix headers. #include #endif #ifdef __cplusplus extern "C" { #endif /// \defgroup optix_function_table Function Table /// \brief OptiX Function Table /** \addtogroup optix_function_table @{ */ /// The function table containing all API functions. /// /// See #optixInit() and #optixInitWithHandle(). typedef struct OptixFunctionTable { /// \name Error handling //@ { /// See ::optixGetErrorName(). const char* ( *optixGetErrorName )( OptixResult result ); /// See ::optixGetErrorString(). const char* ( *optixGetErrorString )( OptixResult result ); //@ } /// \name Device context //@ { /// See ::optixDeviceContextCreate(). OptixResult ( *optixDeviceContextCreate )( CUcontext fromContext, const OptixDeviceContextOptions* options, OptixDeviceContext* context ); /// See ::optixDeviceContextDestroy(). OptixResult ( *optixDeviceContextDestroy )( OptixDeviceContext context ); /// See ::optixDeviceContextGetProperty(). OptixResult ( *optixDeviceContextGetProperty )( OptixDeviceContext context, OptixDeviceProperty property, void* value, size_t sizeInBytes ); /// See ::optixDeviceContextSetLogCallback(). OptixResult ( *optixDeviceContextSetLogCallback )( OptixDeviceContext context, OptixLogCallback callbackFunction, void* callbackData, unsigned int callbackLevel ); /// See ::optixDeviceContextSetCacheEnabled(). OptixResult ( *optixDeviceContextSetCacheEnabled )( OptixDeviceContext context, int enabled ); /// See ::optixDeviceContextSetCacheLocation(). OptixResult ( *optixDeviceContextSetCacheLocation )( OptixDeviceContext context, const char* location ); /// See ::optixDeviceContextSetCacheDatabaseSizes(). 
OptixResult ( *optixDeviceContextSetCacheDatabaseSizes )( OptixDeviceContext context, size_t lowWaterMark, size_t highWaterMark ); /// See ::optixDeviceContextGetCacheEnabled(). OptixResult ( *optixDeviceContextGetCacheEnabled )( OptixDeviceContext context, int* enabled ); /// See ::optixDeviceContextGetCacheLocation(). OptixResult ( *optixDeviceContextGetCacheLocation )( OptixDeviceContext context, char* location, size_t locationSize ); /// See ::optixDeviceContextGetCacheDatabaseSizes(). OptixResult ( *optixDeviceContextGetCacheDatabaseSizes )( OptixDeviceContext context, size_t* lowWaterMark, size_t* highWaterMark ); //@ } /// \name Modules //@ { /// See ::optixModuleCreateFromPTX(). OptixResult ( *optixModuleCreateFromPTX )( OptixDeviceContext context, const OptixModuleCompileOptions* moduleCompileOptions, const OptixPipelineCompileOptions* pipelineCompileOptions, const char* PTX, size_t PTXsize, char* logString, size_t* logStringSize, OptixModule* module ); /// See ::optixModuleDestroy(). OptixResult ( *optixModuleDestroy )( OptixModule module ); /// See ::optixBuiltinISModuleGet(). OptixResult( *optixBuiltinISModuleGet )( OptixDeviceContext context, const OptixModuleCompileOptions* moduleCompileOptions, const OptixPipelineCompileOptions* pipelineCompileOptions, const OptixBuiltinISOptions* builtinISOptions, OptixModule* builtinModule); //@ } /// \name Program groups //@ { /// See ::optixProgramGroupCreate(). OptixResult ( *optixProgramGroupCreate )( OptixDeviceContext context, const OptixProgramGroupDesc* programDescriptions, unsigned int numProgramGroups, const OptixProgramGroupOptions* options, char* logString, size_t* logStringSize, OptixProgramGroup* programGroups ); /// See ::optixProgramGroupDestroy(). OptixResult ( *optixProgramGroupDestroy )( OptixProgramGroup programGroup ); /// See ::optixProgramGroupGetStackSize(). 
OptixResult ( *optixProgramGroupGetStackSize )( OptixProgramGroup programGroup, OptixStackSizes* stackSizes ); //@ } /// \name Pipeline //@ { /// See ::optixPipelineCreate(). OptixResult ( *optixPipelineCreate )( OptixDeviceContext context, const OptixPipelineCompileOptions* pipelineCompileOptions, const OptixPipelineLinkOptions* pipelineLinkOptions, const OptixProgramGroup* programGroups, unsigned int numProgramGroups, char* logString, size_t* logStringSize, OptixPipeline* pipeline ); /// See ::optixPipelineDestroy(). OptixResult ( *optixPipelineDestroy )( OptixPipeline pipeline ); /// See ::optixPipelineSetStackSize(). OptixResult ( *optixPipelineSetStackSize )( OptixPipeline pipeline, unsigned int directCallableStackSizeFromTraversal, unsigned int directCallableStackSizeFromState, unsigned int continuationStackSize, unsigned int maxTraversableGraphDepth ); //@ } /// \name Acceleration structures //@ { /// See ::optixAccelComputeMemoryUsage(). OptixResult ( *optixAccelComputeMemoryUsage )( OptixDeviceContext context, const OptixAccelBuildOptions* accelOptions, const OptixBuildInput* buildInputs, unsigned int numBuildInputs, OptixAccelBufferSizes* bufferSizes ); /// See ::optixAccelBuild(). OptixResult ( *optixAccelBuild )( OptixDeviceContext context, CUstream stream, const OptixAccelBuildOptions* accelOptions, const OptixBuildInput* buildInputs, unsigned int numBuildInputs, CUdeviceptr tempBuffer, size_t tempBufferSizeInBytes, CUdeviceptr outputBuffer, size_t outputBufferSizeInBytes, OptixTraversableHandle* outputHandle, const OptixAccelEmitDesc* emittedProperties, unsigned int numEmittedProperties ); /// See ::optixAccelGetRelocationInfo(). OptixResult ( *optixAccelGetRelocationInfo )( OptixDeviceContext context, OptixTraversableHandle handle, OptixAccelRelocationInfo* info ); /// See ::optixAccelCheckRelocationCompatibility(). 
OptixResult ( *optixAccelCheckRelocationCompatibility )( OptixDeviceContext context, const OptixAccelRelocationInfo* info, int* compatible ); /// See ::optixAccelRelocate(). OptixResult ( *optixAccelRelocate )( OptixDeviceContext context, CUstream stream, const OptixAccelRelocationInfo* info, CUdeviceptr instanceTraversableHandles, size_t numInstanceTraversableHandles, CUdeviceptr targetAccel, size_t targetAccelSizeInBytes, OptixTraversableHandle* targetHandle ); /// See ::optixAccelCompact(). OptixResult ( *optixAccelCompact )( OptixDeviceContext context, CUstream stream, OptixTraversableHandle inputHandle, CUdeviceptr outputBuffer, size_t outputBufferSizeInBytes, OptixTraversableHandle* outputHandle ); /// See ::optixConvertPointerToTraversableHandle(). OptixResult ( *optixConvertPointerToTraversableHandle )( OptixDeviceContext onDevice, CUdeviceptr pointer, OptixTraversableType traversableType, OptixTraversableHandle* traversableHandle ); //@ } /// \name Launch //@ { /// See ::optixConvertPointerToTraversableHandle(). OptixResult ( *optixSbtRecordPackHeader )( OptixProgramGroup programGroup, void* sbtRecordHeaderHostPointer ); /// See ::optixConvertPointerToTraversableHandle(). OptixResult ( *optixLaunch )( OptixPipeline pipeline, CUstream stream, CUdeviceptr pipelineParams, size_t pipelineParamsSize, const OptixShaderBindingTable* sbt, unsigned int width, unsigned int height, unsigned int depth ); //@ } /// \name Denoiser //@ { /// See ::optixDenoiserCreate(). OptixResult ( *optixDenoiserCreate )( OptixDeviceContext context, OptixDenoiserModelKind modelKind, const OptixDenoiserOptions* options, OptixDenoiser* returnHandle ); /// See ::optixDenoiserDestroy(). OptixResult ( *optixDenoiserDestroy )( OptixDenoiser handle ); /// See ::optixDenoiserComputeMemoryResources(). 
OptixResult ( *optixDenoiserComputeMemoryResources )( const OptixDenoiser handle, unsigned int maximumInputWidth, unsigned int maximumInputHeight, OptixDenoiserSizes* returnSizes ); /// See ::optixDenoiserSetup(). OptixResult ( *optixDenoiserSetup )( OptixDenoiser denoiser, CUstream stream, unsigned int inputWidth, unsigned int inputHeight, CUdeviceptr state, size_t stateSizeInBytes, CUdeviceptr scratch, size_t scratchSizeInBytes ); /// See ::optixDenoiserInvoke(). OptixResult ( *optixDenoiserInvoke )( OptixDenoiser denoiser, CUstream stream, const OptixDenoiserParams* params, CUdeviceptr denoiserState, size_t denoiserStateSizeInBytes, const OptixDenoiserGuideLayer * guideLayer, const OptixDenoiserLayer * layers, unsigned int numLayers, unsigned int inputOffsetX, unsigned int inputOffsetY, CUdeviceptr scratch, size_t scratchSizeInBytes ); /// See ::optixDenoiserComputeIntensity(). OptixResult ( *optixDenoiserComputeIntensity )( OptixDenoiser handle, CUstream stream, const OptixImage2D* inputImage, CUdeviceptr outputIntensity, CUdeviceptr scratch, size_t scratchSizeInBytes ); /// See ::optixDenoiserComputeAverageColor(). OptixResult ( *optixDenoiserComputeAverageColor )( OptixDenoiser handle, CUstream stream, const OptixImage2D* inputImage, CUdeviceptr outputAverageColor, CUdeviceptr scratch, size_t scratchSizeInBytes ); /// See ::optixDenoiserCreateWithUserModel(). OptixResult ( *optixDenoiserCreateWithUserModel )( OptixDeviceContext context, const void * data, size_t dataSizeInBytes, OptixDenoiser* returnHandle ); //@ } } OptixFunctionTable; /*@}*/ // end group optix_function_table #ifdef __cplusplus } #endif #endif /* OPTIX_DEFINE_ABI_VERSION_ONLY */ #endif /* __optix_optix_function_table_h__ */ ================================================ FILE: render/optixutils/include/optix_function_table_definition.h ================================================ /* * Copyright (c) 2021 NVIDIA Corporation. All rights reserved. 
* * NVIDIA Corporation and its licensors retain all intellectual property and proprietary * rights in and to this software, related documentation and any modifications thereto. * Any use, reproduction, disclosure or distribution of this software and related * documentation without an express license agreement from NVIDIA Corporation is strictly * prohibited. * * TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, THIS SOFTWARE IS PROVIDED *AS IS* * AND NVIDIA AND ITS SUPPLIERS DISCLAIM ALL WARRANTIES, EITHER EXPRESS OR IMPLIED, * INCLUDING, BUT NOT LIMITED TO, IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A * PARTICULAR PURPOSE. IN NO EVENT SHALL NVIDIA OR ITS SUPPLIERS BE LIABLE FOR ANY * SPECIAL, INCIDENTAL, INDIRECT, OR CONSEQUENTIAL DAMAGES WHATSOEVER (INCLUDING, WITHOUT * LIMITATION, DAMAGES FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF * BUSINESS INFORMATION, OR ANY OTHER PECUNIARY LOSS) ARISING OUT OF THE USE OF OR * INABILITY TO USE THIS SOFTWARE, EVEN IF NVIDIA HAS BEEN ADVISED OF THE POSSIBILITY OF * SUCH DAMAGES */ /// @file /// @author NVIDIA Corporation /// @brief OptiX public API header #ifndef __optix_optix_function_table_definition_h__ #define __optix_optix_function_table_definition_h__ #include "optix_function_table.h" #ifdef __cplusplus extern "C" { #endif /** \addtogroup optix_function_table @{ */ /// If the stubs in optix_stubs.h are used, then the function table needs to be defined in exactly /// one translation unit. This can be achieved by including this header file in that translation /// unit. OptixFunctionTable g_optixFunctionTable; /*@}*/ // end group optix_function_table #ifdef __cplusplus } #endif #endif // __optix_optix_function_table_definition_h__ ================================================ FILE: render/optixutils/include/optix_host.h ================================================ /* * Copyright (c) 2021 NVIDIA Corporation. All rights reserved. 
* * NVIDIA Corporation and its licensors retain all intellectual property and proprietary * rights in and to this software, related documentation and any modifications thereto. * Any use, reproduction, disclosure or distribution of this software and related * documentation without an express license agreement from NVIDIA Corporation is strictly * prohibited. * * TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, THIS SOFTWARE IS PROVIDED *AS IS* * AND NVIDIA AND ITS SUPPLIERS DISCLAIM ALL WARRANTIES, EITHER EXPRESS OR IMPLIED, * INCLUDING, BUT NOT LIMITED TO, IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A * PARTICULAR PURPOSE. IN NO EVENT SHALL NVIDIA OR ITS SUPPLIERS BE LIABLE FOR ANY * SPECIAL, INCIDENTAL, INDIRECT, OR CONSEQUENTIAL DAMAGES WHATSOEVER (INCLUDING, WITHOUT * LIMITATION, DAMAGES FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF * BUSINESS INFORMATION, OR ANY OTHER PECUNIARY LOSS) ARISING OUT OF THE USE OF OR * INABILITY TO USE THIS SOFTWARE, EVEN IF NVIDIA HAS BEEN ADVISED OF THE POSSIBILITY OF * SUCH DAMAGES */ /** * @file optix_host.h * @author NVIDIA Corporation * @brief OptiX public API * * OptiX public API Reference - Host side */ #if !defined(__OPTIX_INCLUDE_INTERNAL_HEADERS__) # define __OPTIX_INCLUDE_INTERNAL_HEADERS__ # define __UNDEF_OPTIX_INCLUDE_INTERNAL_HEADERS_OPTIX_HOST_H__ #endif #include "optix_7_host.h" #if defined( __UNDEF_OPTIX_INCLUDE_INTERNAL_HEADERS_OPTIX_HOST_H__ ) # undef __OPTIX_INCLUDE_INTERNAL_HEADERS__ # undef __UNDEF_OPTIX_INCLUDE_INTERNAL_HEADERS_OPTIX_HOST_H__ #endif ================================================ FILE: render/optixutils/include/optix_stack_size.h ================================================ /* * Copyright (c) 2021 NVIDIA Corporation. All rights reserved. 
* * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions * are met: * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of NVIDIA CORPORATION nor the names of its * contributors may be used to endorse or promote products derived * from this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ /// @file /// @author NVIDIA Corporation /// @brief OptiX public API header #ifndef __optix_optix_stack_size_h__ #define __optix_optix_stack_size_h__ #include "optix.h" #include #include #ifdef __cplusplus extern "C" { #endif /** \addtogroup optix_utilities @{ */ /// Retrieves direct and continuation stack sizes for each program in the program group and accumulates the upper bounds /// in the correponding output variables based on the semantic type of the program. 
Before the first invocation of this /// function with a given instance of #OptixStackSizes, the members of that instance should be set to 0. inline OptixResult optixUtilAccumulateStackSizes( OptixProgramGroup programGroup, OptixStackSizes* stackSizes ) { if( !stackSizes ) return OPTIX_ERROR_INVALID_VALUE; OptixStackSizes localStackSizes; OptixResult result = optixProgramGroupGetStackSize( programGroup, &localStackSizes ); if( result != OPTIX_SUCCESS ) return result; stackSizes->cssRG = std::max( stackSizes->cssRG, localStackSizes.cssRG ); stackSizes->cssMS = std::max( stackSizes->cssMS, localStackSizes.cssMS ); stackSizes->cssCH = std::max( stackSizes->cssCH, localStackSizes.cssCH ); stackSizes->cssAH = std::max( stackSizes->cssAH, localStackSizes.cssAH ); stackSizes->cssIS = std::max( stackSizes->cssIS, localStackSizes.cssIS ); stackSizes->cssCC = std::max( stackSizes->cssCC, localStackSizes.cssCC ); stackSizes->dssDC = std::max( stackSizes->dssDC, localStackSizes.dssDC ); return OPTIX_SUCCESS; } /// Computes the stack size values needed to configure a pipeline. /// /// See the programming guide for an explanation of the formula. /// /// \param[in] stackSizes Accumulated stack sizes of all programs in the call graph. /// \param[in] maxTraceDepth Maximum depth of #optixTrace() calls. /// \param[in] maxCCDepth Maximum depth of calls trees of continuation callables. /// \param[in] maxDCDepth Maximum depth of calls trees of direct callables. /// \param[out] directCallableStackSizeFromTraversal Direct stack size requirement for direct callables invoked from /// IS or AH. /// \param[out] directCallableStackSizeFromState Direct stack size requirement for direct callables invoked from /// RG, MS, or CH. /// \param[out] continuationStackSize Continuation stack requirement. 
inline OptixResult optixUtilComputeStackSizes( const OptixStackSizes* stackSizes, unsigned int maxTraceDepth, unsigned int maxCCDepth, unsigned int maxDCDepth, unsigned int* directCallableStackSizeFromTraversal, unsigned int* directCallableStackSizeFromState, unsigned int* continuationStackSize ) { if( !stackSizes ) return OPTIX_ERROR_INVALID_VALUE; const unsigned int cssRG = stackSizes->cssRG; const unsigned int cssMS = stackSizes->cssMS; const unsigned int cssCH = stackSizes->cssCH; const unsigned int cssAH = stackSizes->cssAH; const unsigned int cssIS = stackSizes->cssIS; const unsigned int cssCC = stackSizes->cssCC; const unsigned int dssDC = stackSizes->dssDC; if( directCallableStackSizeFromTraversal ) *directCallableStackSizeFromTraversal = maxDCDepth * dssDC; if( directCallableStackSizeFromState ) *directCallableStackSizeFromState = maxDCDepth * dssDC; // upper bound on continuation stack used by call trees of continuation callables unsigned int cssCCTree = maxCCDepth * cssCC; // upper bound on continuation stack used by CH or MS programs including the call tree of // continuation callables unsigned int cssCHOrMSPlusCCTree = std::max( cssCH, cssMS ) + cssCCTree; // clang-format off if( continuationStackSize ) *continuationStackSize = cssRG + cssCCTree + ( std::max( maxTraceDepth, 1u ) - 1 ) * cssCHOrMSPlusCCTree + std::min( maxTraceDepth, 1u ) * std::max( cssCHOrMSPlusCCTree, cssIS + cssAH ); // clang-format on return OPTIX_SUCCESS; } /// Computes the stack size values needed to configure a pipeline. /// /// This variant is similar to #optixUtilComputeStackSizes(), except that it expects the values dssDC and /// maxDCDepth split by call site semantic. /// /// See programming guide for an explanation of the formula. /// /// \param[in] stackSizes Accumulated stack sizes of all programs in the call graph. /// \param[in] dssDCFromTraversal Accumulated direct stack size of all DC programs invoked from IS /// or AH. 
/// \param[in] dssDCFromState Accumulated direct stack size of all DC programs invoked from RG,
///            MS, or CH.
/// \param[in] maxTraceDepth Maximum depth of #optixTrace() calls.
/// \param[in] maxCCDepth Maximum depth of calls trees of continuation callables.
/// \param[in] maxDCDepthFromTraversal Maximum depth of calls trees of direct callables invoked from IS
///            or AH.
/// \param[in] maxDCDepthFromState Maximum depth of calls trees of direct callables invoked from RG,
///            MS, or CH.
/// \param[out] directCallableStackSizeFromTraversal Direct stack size requirement for direct callables invoked from
///             IS or AH.
/// \param[out] directCallableStackSizeFromState Direct stack size requirement for direct callables invoked from
///             RG, MS, or CH.
/// \param[out] continuationStackSize Continuation stack requirement.
inline OptixResult optixUtilComputeStackSizesDCSplit( const OptixStackSizes* stackSizes,
                                                      unsigned int           dssDCFromTraversal,
                                                      unsigned int           dssDCFromState,
                                                      unsigned int           maxTraceDepth,
                                                      unsigned int           maxCCDepth,
                                                      unsigned int           maxDCDepthFromTraversal,
                                                      unsigned int           maxDCDepthFromState,
                                                      unsigned int*          directCallableStackSizeFromTraversal,
                                                      unsigned int*          directCallableStackSizeFromState,
                                                      unsigned int*          continuationStackSize )
{
    if( !stackSizes )
        return OPTIX_ERROR_INVALID_VALUE;

    const unsigned int cssRG = stackSizes->cssRG;
    const unsigned int cssMS = stackSizes->cssMS;
    const unsigned int cssCH = stackSizes->cssCH;
    const unsigned int cssAH = stackSizes->cssAH;
    const unsigned int cssIS = stackSizes->cssIS;
    const unsigned int cssCC = stackSizes->cssCC;
    // use dssDCFromTraversal and dssDCFromState instead of stackSizes->dssDC

    // Output pointers are optional; only the ones provided are written.
    if( directCallableStackSizeFromTraversal )
        *directCallableStackSizeFromTraversal = maxDCDepthFromTraversal * dssDCFromTraversal;
    if( directCallableStackSizeFromState )
        *directCallableStackSizeFromState = maxDCDepthFromState * dssDCFromState;

    // upper bound on continuation stack used by call trees of continuation callables
    unsigned int cssCCTree = maxCCDepth * cssCC;

    // upper bound on continuation stack used by CH or MS programs including the call tree of
    // continuation callables
    unsigned int cssCHOrMSPlusCCTree = std::max( cssCH, cssMS ) + cssCCTree;

    // clang-format off
    if( continuationStackSize )
        *continuationStackSize
            = cssRG + cssCCTree
            + ( std::max( maxTraceDepth, 1u ) - 1 ) * cssCHOrMSPlusCCTree
            + std::min( maxTraceDepth, 1u ) * std::max( cssCHOrMSPlusCCTree, cssIS + cssAH );
    // clang-format on

    return OPTIX_SUCCESS;
}

/// Computes the stack size values needed to configure a pipeline.
///
/// This variant is similar to #optixUtilComputeStackSizes(), except that it expects the value cssCCTree
/// instead of cssCC and maxCCDepth.
///
/// See programming guide for an explanation of the formula.
///
/// \param[in] stackSizes Accumulated stack sizes of all programs in the call graph.
/// \param[in] cssCCTree Maximum stack size used by calls trees of continuation callables.
/// \param[in] maxTraceDepth Maximum depth of #optixTrace() calls.
/// \param[in] maxDCDepth Maximum depth of calls trees of direct callables.
/// \param[out] directCallableStackSizeFromTraversal Direct stack size requirement for direct callables invoked from
///             IS or AH.
/// \param[out] directCallableStackSizeFromState Direct stack size requirement for direct callables invoked from
///             RG, MS, or CH.
/// \param[out] continuationStackSize Continuation stack requirement.
inline OptixResult optixUtilComputeStackSizesCssCCTree( const OptixStackSizes* stackSizes, unsigned int cssCCTree, unsigned int maxTraceDepth, unsigned int maxDCDepth, unsigned int* directCallableStackSizeFromTraversal, unsigned int* directCallableStackSizeFromState, unsigned int* continuationStackSize ) { if( !stackSizes ) return OPTIX_ERROR_INVALID_VALUE; const unsigned int cssRG = stackSizes->cssRG; const unsigned int cssMS = stackSizes->cssMS; const unsigned int cssCH = stackSizes->cssCH; const unsigned int cssAH = stackSizes->cssAH; const unsigned int cssIS = stackSizes->cssIS; // use cssCCTree instead of stackSizes->cssCC and maxCCDepth const unsigned int dssDC = stackSizes->dssDC; if( directCallableStackSizeFromTraversal ) *directCallableStackSizeFromTraversal = maxDCDepth * dssDC; if( directCallableStackSizeFromState ) *directCallableStackSizeFromState = maxDCDepth * dssDC; // upper bound on continuation stack used by CH or MS programs including the call tree of // continuation callables unsigned int cssCHOrMSPlusCCTree = std::max( cssCH, cssMS ) + cssCCTree; // clang-format off if( continuationStackSize ) *continuationStackSize = cssRG + cssCCTree + ( std::max( maxTraceDepth, 1u ) - 1 ) * cssCHOrMSPlusCCTree + std::min( maxTraceDepth, 1u ) * std::max( cssCHOrMSPlusCCTree, cssIS + cssAH ); // clang-format on return OPTIX_SUCCESS; } /// Computes the stack size values needed to configure a pipeline. /// /// This variant is a specialization of #optixUtilComputeStackSizes() for a simple path tracer with the following /// assumptions: There are only two ray types, camera rays and shadow rays. There are only RG, MS, and CH programs, and /// no AH, IS, CC, or DC programs. The camera rays invoke only the miss and closest hit programs MS1 and CH1, /// respectively. The CH1 program might trace shadow rays, which invoke only the miss and closest hit programs MS2 and /// CH2, respectively. 
///
/// For flexibility, we allow for each of CH1 and CH2 not just one single program group, but an array of programs
/// groups, and compute the maximas of the stack size requirements per array.
///
/// See programming guide for an explanation of the formula.
inline OptixResult optixUtilComputeStackSizesSimplePathTracer( OptixProgramGroup        programGroupRG,
                                                               OptixProgramGroup        programGroupMS1,
                                                               const OptixProgramGroup* programGroupCH1,
                                                               unsigned int             programGroupCH1Count,
                                                               OptixProgramGroup        programGroupMS2,
                                                               const OptixProgramGroup* programGroupCH2,
                                                               unsigned int             programGroupCH2Count,
                                                               unsigned int*            directCallableStackSizeFromTraversal,
                                                               unsigned int*            directCallableStackSizeFromState,
                                                               unsigned int*            continuationStackSize )
{
    // A NULL CH array is only valid together with a zero count.
    if( !programGroupCH1 && ( programGroupCH1Count > 0 ) )
        return OPTIX_ERROR_INVALID_VALUE;
    if( !programGroupCH2 && ( programGroupCH2Count > 0 ) )
        return OPTIX_ERROR_INVALID_VALUE;

    OptixResult result;

    OptixStackSizes stackSizesRG = {};
    result = optixProgramGroupGetStackSize( programGroupRG, &stackSizesRG );
    if( result != OPTIX_SUCCESS )
        return result;

    OptixStackSizes stackSizesMS1 = {};
    result = optixProgramGroupGetStackSize( programGroupMS1, &stackSizesMS1 );
    if( result != OPTIX_SUCCESS )
        return result;

    // Component-wise maximum over all CH1 program groups.
    OptixStackSizes stackSizesCH1 = {};
    for( unsigned int i = 0; i < programGroupCH1Count; ++i )
    {
        result = optixUtilAccumulateStackSizes( programGroupCH1[i], &stackSizesCH1 );
        if( result != OPTIX_SUCCESS )
            return result;
    }

    OptixStackSizes stackSizesMS2 = {};
    result = optixProgramGroupGetStackSize( programGroupMS2, &stackSizesMS2 );
    if( result != OPTIX_SUCCESS )
        return result;

    // Component-wise maximum over all CH2 program groups.
    OptixStackSizes stackSizesCH2 = {};
    memset( &stackSizesCH2, 0, sizeof( OptixStackSizes ) );
    for( unsigned int i = 0; i < programGroupCH2Count; ++i )
    {
        result = optixUtilAccumulateStackSizes( programGroupCH2[i], &stackSizesCH2 );
        if( result != OPTIX_SUCCESS )
            return result;
    }

    const unsigned int cssRG  = stackSizesRG.cssRG;
    const unsigned int cssMS1 = stackSizesMS1.cssMS;
    const unsigned int cssCH1 = stackSizesCH1.cssCH;
    const unsigned int cssMS2 = stackSizesMS2.cssMS;
    const unsigned int cssCH2 = stackSizesCH2.cssCH;
    // no AH, IS, CC, or DC programs

    if( directCallableStackSizeFromTraversal )
        *directCallableStackSizeFromTraversal = 0;
    if( directCallableStackSizeFromState )
        *directCallableStackSizeFromState = 0;

    if( continuationStackSize )
        *continuationStackSize = cssRG + std::max( cssMS1, cssCH1 + std::max( cssMS2, cssCH2 ) );

    return OPTIX_SUCCESS;
}

/*@}*/ // end group optix_utilities

#ifdef __cplusplus
}
#endif

#endif // __optix_optix_stack_size_h__



================================================
FILE: render/optixutils/include/optix_stubs.h
================================================
/*
 * Copyright (c) 2021 NVIDIA Corporation. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *  * Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 *  * Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *  * Neither the name of NVIDIA CORPORATION nor the names of its
 *    contributors may be used to endorse or promote products derived
 *    from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
 * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
 * PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT OWNER OR * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ /// @file /// @author NVIDIA Corporation /// @brief OptiX public API header #ifndef __optix_optix_stubs_h__ #define __optix_optix_stubs_h__ #include "optix_function_table.h" #ifdef _WIN32 #ifndef WIN32_LEAN_AND_MEAN #define WIN32_LEAN_AND_MEAN 1 #endif #include // The cfgmgr32 header is necessary for interrogating driver information in the registry. // For convenience the library is also linked in automatically using the #pragma command. #include #pragma comment( lib, "Cfgmgr32.lib" ) #include #else #include #endif #ifdef __cplusplus extern "C" { #endif // The function table needs to be defined in exactly one translation unit. This can be // achieved by including optix_function_table_definition.h in that translation unit. 
extern OptixFunctionTable g_optixFunctionTable;

#ifdef _WIN32
// Locates and loads nvoptix.dll: first from the system directory, then — because
// nvoptix.dll has no registry entry of its own — from the display-driver directory found
// via the OpenGL driver's registry value. Returns the module handle, or NULL on failure.
static void* optixLoadWindowsDllFromName( const char* optixDllName )
{
    void* handle = NULL;

    // Get the size of the path first, then allocate
    unsigned int size = GetSystemDirectoryA( NULL, 0 );
    if( size == 0 )
    {
        // Couldn't get the system path size, so bail
        return NULL;
    }
    size_t pathSize   = size + 1 + strlen( optixDllName );
    char*  systemPath = (char*)malloc( pathSize );
    if( systemPath == NULL )
        return NULL;
    if( GetSystemDirectoryA( systemPath, size ) != size - 1 )
    {
        // Something went wrong
        free( systemPath );
        return NULL;
    }
    strcat( systemPath, "\\" );
    strcat( systemPath, optixDllName );
    handle = LoadLibraryA( systemPath );
    free( systemPath );
    if( handle )
        return handle;

    // If we didn't find it, go looking in the register store. Since nvoptix.dll doesn't
    // have its own registry entry, we are going to look for the opengl driver which lives
    // next to nvoptix.dll. 0 (null) will be returned if any errors occured.
    static const char* deviceInstanceIdentifiersGUID = "{4d36e968-e325-11ce-bfc1-08002be10318}";
    const ULONG        flags                         = CM_GETIDLIST_FILTER_CLASS | CM_GETIDLIST_FILTER_PRESENT;
    ULONG              deviceListSize                = 0;
    if( CM_Get_Device_ID_List_SizeA( &deviceListSize, deviceInstanceIdentifiersGUID, flags ) != CR_SUCCESS )
    {
        return NULL;
    }
    char* deviceNames = (char*)malloc( deviceListSize );
    if( deviceNames == NULL )
        return NULL;
    if( CM_Get_Device_ID_ListA( deviceInstanceIdentifiersGUID, deviceNames, deviceListSize, flags ) )
    {
        free( deviceNames );
        return NULL;
    }
    DEVINST devID   = 0;
    char*   dllPath = NULL;

    // Continue to the next device if errors are encountered.
    for( char* deviceName = deviceNames; *deviceName; deviceName += strlen( deviceName ) + 1 )
    {
        if( CM_Locate_DevNodeA( &devID, deviceName, CM_LOCATE_DEVNODE_NORMAL ) != CR_SUCCESS )
        {
            continue;
        }
        HKEY regKey = 0;
        // FIX(review): the extraction corrupted "&regKey" into the HTML entity "(R)Key";
        // restored the address-of expression for the regKey declared just above.
        if( CM_Open_DevNode_Key( devID, KEY_QUERY_VALUE, 0, RegDisposition_OpenExisting, &regKey, CM_REGISTRY_SOFTWARE ) != CR_SUCCESS )
        {
            continue;
        }
        const char* valueName = "OpenGLDriverName";
        DWORD       valueSize = 0;
        LSTATUS     ret       = RegQueryValueExA( regKey, valueName, NULL, NULL, NULL, &valueSize );
        if( ret != ERROR_SUCCESS )
        {
            RegCloseKey( regKey );
            continue;
        }
        char* regValue = (char*)malloc( valueSize );
        if( regValue == NULL )
        {
            RegCloseKey( regKey );
            continue;
        }
        ret = RegQueryValueExA( regKey, valueName, NULL, NULL, (LPBYTE)regValue, &valueSize );
        if( ret != ERROR_SUCCESS )
        {
            free( regValue );
            RegCloseKey( regKey );
            continue;
        }
        // Strip the opengl driver dll name from the string then create a new string with
        // the path and the nvoptix.dll name
        for( int i = (int) valueSize - 1; i >= 0 && regValue[i] != '\\'; --i )
            regValue[i] = '\0';
        size_t newPathSize = strlen( regValue ) + strlen( optixDllName ) + 1;
        dllPath = (char*)malloc( newPathSize );
        if( dllPath == NULL )
        {
            free( regValue );
            RegCloseKey( regKey );
            continue;
        }
        strcpy( dllPath, regValue );
        strcat( dllPath, optixDllName );
        free( regValue );
        RegCloseKey( regKey );
        handle = LoadLibraryA( (LPCSTR)dllPath );
        free( dllPath );
        if( handle )
            break;
    }
    free( deviceNames );
    return handle;
}

// Convenience wrapper: load the standard OptiX driver DLL by its canonical name.
static void* optixLoadWindowsDll( )
{
    return optixLoadWindowsDllFromName( "nvoptix.dll" );
}
#endif

/// \defgroup optix_utilities Utilities
/// \brief OptiX Utilities

/** \addtogroup optix_utilities
@{
*/

/// Loads the OptiX library and initializes the function table used by the stubs below.
///
/// If handlePtr is not nullptr, an OS-specific handle to the library will be returned in *handlePtr.
/// /// \see #optixUninitWithHandle inline OptixResult optixInitWithHandle( void** handlePtr ) { // Make sure these functions get initialized to zero in case the DLL and function // table can't be loaded g_optixFunctionTable.optixGetErrorName = 0; g_optixFunctionTable.optixGetErrorString = 0; if( !handlePtr ) return OPTIX_ERROR_INVALID_VALUE; #ifdef _WIN32 *handlePtr = optixLoadWindowsDll(); if( !*handlePtr ) return OPTIX_ERROR_LIBRARY_NOT_FOUND; void* symbol = GetProcAddress( (HMODULE)*handlePtr, "optixQueryFunctionTable" ); if( !symbol ) return OPTIX_ERROR_ENTRY_SYMBOL_NOT_FOUND; #else *handlePtr = dlopen( "libnvoptix.so.1", RTLD_NOW ); if( !*handlePtr ) return OPTIX_ERROR_LIBRARY_NOT_FOUND; void* symbol = dlsym( *handlePtr, "optixQueryFunctionTable" ); if( !symbol ) return OPTIX_ERROR_ENTRY_SYMBOL_NOT_FOUND; #endif OptixQueryFunctionTable_t* optixQueryFunctionTable = (OptixQueryFunctionTable_t*)symbol; return optixQueryFunctionTable( OPTIX_ABI_VERSION, 0, 0, 0, &g_optixFunctionTable, sizeof( g_optixFunctionTable ) ); } /// Loads the OptiX library and initializes the function table used by the stubs below. /// /// A variant of #optixInitWithHandle() that does not make the handle to the loaded library available. inline OptixResult optixInit( void ) { void* handle; return optixInitWithHandle( &handle ); } /// Unloads the OptiX library and zeros the function table used by the stubs below. Takes the /// handle returned by optixInitWithHandle. All OptixDeviceContext objects must be destroyed /// before calling this function, or the behavior is undefined. 
///
/// \see #optixInitWithHandle
inline OptixResult optixUninitWithHandle( void* handle )
{
    if( !handle )
        return OPTIX_ERROR_INVALID_VALUE;
#ifdef _WIN32
    if( !FreeLibrary( (HMODULE)handle ) )
        return OPTIX_ERROR_LIBRARY_UNLOAD_FAILURE;
#else
    if( dlclose( handle ) )
        return OPTIX_ERROR_LIBRARY_UNLOAD_FAILURE;
#endif
    // Zero the table so any later stub call is a clean null-pointer fault
    // rather than a call into an unloaded library.
    OptixFunctionTable empty = { 0 };
    g_optixFunctionTable     = empty;
    return OPTIX_SUCCESS;
}

/*@}*/  // end group optix_utilities

#ifndef OPTIX_DOXYGEN_SHOULD_SKIP_THIS

// Stub functions that forward calls to the corresponding function pointer in the function table.

inline const char* optixGetErrorName( OptixResult result )
{
    if( g_optixFunctionTable.optixGetErrorName )
        return g_optixFunctionTable.optixGetErrorName( result );

    // If the DLL and symbol table couldn't be loaded, provide a set of error strings
    // suitable for processing errors related to the DLL loading.
    switch( result )
    {
        case OPTIX_SUCCESS:
            return "OPTIX_SUCCESS";
        case OPTIX_ERROR_INVALID_VALUE:
            return "OPTIX_ERROR_INVALID_VALUE";
        case OPTIX_ERROR_UNSUPPORTED_ABI_VERSION:
            return "OPTIX_ERROR_UNSUPPORTED_ABI_VERSION";
        case OPTIX_ERROR_FUNCTION_TABLE_SIZE_MISMATCH:
            return "OPTIX_ERROR_FUNCTION_TABLE_SIZE_MISMATCH";
        case OPTIX_ERROR_INVALID_ENTRY_FUNCTION_OPTIONS:
            return "OPTIX_ERROR_INVALID_ENTRY_FUNCTION_OPTIONS";
        case OPTIX_ERROR_LIBRARY_NOT_FOUND:
            return "OPTIX_ERROR_LIBRARY_NOT_FOUND";
        case OPTIX_ERROR_ENTRY_SYMBOL_NOT_FOUND:
            return "OPTIX_ERROR_ENTRY_SYMBOL_NOT_FOUND";
        case OPTIX_ERROR_LIBRARY_UNLOAD_FAILURE:
            return "OPTIX_ERROR_LIBRARY_UNLOAD_FAILURE";
        default:
            return "Unknown OptixResult code";
    }
}

inline const char* optixGetErrorString( OptixResult result )
{
    if( g_optixFunctionTable.optixGetErrorString )
        return g_optixFunctionTable.optixGetErrorString( result );

    // If the DLL and symbol table couldn't be loaded, provide a set of error strings
    // suitable for processing errors related to the DLL loading.
    switch( result )
    {
        case OPTIX_SUCCESS:
            return "Success";
        case OPTIX_ERROR_INVALID_VALUE:
            return "Invalid value";
        case OPTIX_ERROR_UNSUPPORTED_ABI_VERSION:
            return "Unsupported ABI version";
        case OPTIX_ERROR_FUNCTION_TABLE_SIZE_MISMATCH:
            return "Function table size mismatch";
        case OPTIX_ERROR_INVALID_ENTRY_FUNCTION_OPTIONS:
            return "Invalid options to entry function";
        case OPTIX_ERROR_LIBRARY_NOT_FOUND:
            return "Library not found";
        case OPTIX_ERROR_ENTRY_SYMBOL_NOT_FOUND:
            return "Entry symbol not found";
        case OPTIX_ERROR_LIBRARY_UNLOAD_FAILURE:
            return "Library could not be unloaded";
        default:
            return "Unknown OptixResult code";
    }
}

// --- Device-context management stubs ---

inline OptixResult optixDeviceContextCreate( CUcontext fromContext, const OptixDeviceContextOptions* options, OptixDeviceContext* context )
{
    return g_optixFunctionTable.optixDeviceContextCreate( fromContext, options, context );
}

inline OptixResult optixDeviceContextDestroy( OptixDeviceContext context )
{
    return g_optixFunctionTable.optixDeviceContextDestroy( context );
}

inline OptixResult optixDeviceContextGetProperty( OptixDeviceContext context, OptixDeviceProperty property, void* value, size_t sizeInBytes )
{
    return g_optixFunctionTable.optixDeviceContextGetProperty( context, property, value, sizeInBytes );
}

inline OptixResult optixDeviceContextSetLogCallback( OptixDeviceContext context,
                                                     OptixLogCallback   callbackFunction,
                                                     void*              callbackData,
                                                     unsigned int       callbackLevel )
{
    return g_optixFunctionTable.optixDeviceContextSetLogCallback( context, callbackFunction, callbackData, callbackLevel );
}

inline OptixResult optixDeviceContextSetCacheEnabled( OptixDeviceContext context, int enabled )
{
    return g_optixFunctionTable.optixDeviceContextSetCacheEnabled( context, enabled );
}

inline OptixResult optixDeviceContextSetCacheLocation( OptixDeviceContext context, const char* location )
{
    return g_optixFunctionTable.optixDeviceContextSetCacheLocation( context, location );
}

inline OptixResult optixDeviceContextSetCacheDatabaseSizes( OptixDeviceContext context, size_t lowWaterMark, size_t highWaterMark )
{
    return g_optixFunctionTable.optixDeviceContextSetCacheDatabaseSizes( context, lowWaterMark, highWaterMark );
}

inline OptixResult optixDeviceContextGetCacheEnabled( OptixDeviceContext context, int* enabled )
{
    return g_optixFunctionTable.optixDeviceContextGetCacheEnabled( context, enabled );
}

inline OptixResult optixDeviceContextGetCacheLocation( OptixDeviceContext context, char* location, size_t locationSize )
{
    return g_optixFunctionTable.optixDeviceContextGetCacheLocation( context, location, locationSize );
}

inline OptixResult optixDeviceContextGetCacheDatabaseSizes( OptixDeviceContext context, size_t* lowWaterMark, size_t* highWaterMark )
{
    return g_optixFunctionTable.optixDeviceContextGetCacheDatabaseSizes( context, lowWaterMark, highWaterMark );
}

// --- Module / program-group / pipeline stubs ---

inline OptixResult optixModuleCreateFromPTX( OptixDeviceContext                 context,
                                             const OptixModuleCompileOptions*   moduleCompileOptions,
                                             const OptixPipelineCompileOptions* pipelineCompileOptions,
                                             const char*                        PTX,
                                             size_t                             PTXsize,
                                             char*                              logString,
                                             size_t*                            logStringSize,
                                             OptixModule*                       module )
{
    return g_optixFunctionTable.optixModuleCreateFromPTX( context, moduleCompileOptions, pipelineCompileOptions, PTX,
                                                          PTXsize, logString, logStringSize, module );
}

inline OptixResult optixModuleDestroy( OptixModule module )
{
    return g_optixFunctionTable.optixModuleDestroy( module );
}

inline OptixResult optixBuiltinISModuleGet( OptixDeviceContext                 context,
                                            const OptixModuleCompileOptions*   moduleCompileOptions,
                                            const OptixPipelineCompileOptions* pipelineCompileOptions,
                                            const OptixBuiltinISOptions*       builtinISOptions,
                                            OptixModule*                       builtinModule )
{
    return g_optixFunctionTable.optixBuiltinISModuleGet( context, moduleCompileOptions, pipelineCompileOptions,
                                                         builtinISOptions, builtinModule );
}

inline OptixResult optixProgramGroupCreate( OptixDeviceContext              context,
                                            const OptixProgramGroupDesc*    programDescriptions,
                                            unsigned int                    numProgramGroups,
                                            const OptixProgramGroupOptions* options,
                                            char*                           logString,
                                            size_t*                         logStringSize,
                                            OptixProgramGroup*              programGroups )
{
    return g_optixFunctionTable.optixProgramGroupCreate( context, programDescriptions, numProgramGroups, options,
                                                         logString, logStringSize, programGroups );
}

inline OptixResult optixProgramGroupDestroy( OptixProgramGroup programGroup )
{
    return g_optixFunctionTable.optixProgramGroupDestroy( programGroup );
}

inline OptixResult optixProgramGroupGetStackSize( OptixProgramGroup programGroup, OptixStackSizes* stackSizes )
{
    return g_optixFunctionTable.optixProgramGroupGetStackSize( programGroup, stackSizes );
}

inline OptixResult optixPipelineCreate( OptixDeviceContext                 context,
                                        const OptixPipelineCompileOptions* pipelineCompileOptions,
                                        const OptixPipelineLinkOptions*    pipelineLinkOptions,
                                        const OptixProgramGroup*           programGroups,
                                        unsigned int                       numProgramGroups,
                                        char*                              logString,
                                        size_t*                            logStringSize,
                                        OptixPipeline*                     pipeline )
{
    return g_optixFunctionTable.optixPipelineCreate( context, pipelineCompileOptions, pipelineLinkOptions, programGroups,
                                                     numProgramGroups, logString, logStringSize, pipeline );
}

inline OptixResult optixPipelineDestroy( OptixPipeline pipeline )
{
    return g_optixFunctionTable.optixPipelineDestroy( pipeline );
}

inline OptixResult optixPipelineSetStackSize( OptixPipeline pipeline,
                                              unsigned int  directCallableStackSizeFromTraversal,
                                              unsigned int  directCallableStackSizeFromState,
                                              unsigned int  continuationStackSize,
                                              unsigned int  maxTraversableGraphDepth )
{
    return g_optixFunctionTable.optixPipelineSetStackSize( pipeline, directCallableStackSizeFromTraversal,
                                                           directCallableStackSizeFromState, continuationStackSize,
                                                           maxTraversableGraphDepth );
}

// --- Acceleration-structure stubs ---

inline OptixResult optixAccelComputeMemoryUsage( OptixDeviceContext            context,
                                                 const OptixAccelBuildOptions* accelOptions,
                                                 const OptixBuildInput*        buildInputs,
                                                 unsigned int                  numBuildInputs,
                                                 OptixAccelBufferSizes*        bufferSizes )
{
    return g_optixFunctionTable.optixAccelComputeMemoryUsage( context, accelOptions, buildInputs, numBuildInputs, bufferSizes );
}

inline OptixResult optixAccelBuild( OptixDeviceContext            context,
                                    CUstream                      stream,
                                    const OptixAccelBuildOptions* accelOptions,
                                    const OptixBuildInput*        buildInputs,
                                    unsigned int                  numBuildInputs,
                                    CUdeviceptr                   tempBuffer,
                                    size_t                        tempBufferSizeInBytes,
                                    CUdeviceptr                   outputBuffer,
                                    size_t                        outputBufferSizeInBytes,
                                    OptixTraversableHandle*       outputHandle,
                                    const OptixAccelEmitDesc*     emittedProperties,
                                    unsigned int                  numEmittedProperties )
{
    return g_optixFunctionTable.optixAccelBuild( context, stream, accelOptions, buildInputs, numBuildInputs, tempBuffer,
                                                 tempBufferSizeInBytes, outputBuffer, outputBufferSizeInBytes,
                                                 outputHandle, emittedProperties, numEmittedProperties );
}

inline OptixResult optixAccelGetRelocationInfo( OptixDeviceContext context, OptixTraversableHandle handle, OptixAccelRelocationInfo* info )
{
    return g_optixFunctionTable.optixAccelGetRelocationInfo( context, handle, info );
}

inline OptixResult optixAccelCheckRelocationCompatibility( OptixDeviceContext context, const OptixAccelRelocationInfo* info, int* compatible )
{
    return g_optixFunctionTable.optixAccelCheckRelocationCompatibility( context, info, compatible );
}

inline OptixResult optixAccelRelocate( OptixDeviceContext              context,
                                       CUstream                        stream,
                                       const OptixAccelRelocationInfo* info,
                                       CUdeviceptr                     instanceTraversableHandles,
                                       size_t                          numInstanceTraversableHandles,
                                       CUdeviceptr                     targetAccel,
                                       size_t                          targetAccelSizeInBytes,
                                       OptixTraversableHandle*         targetHandle )
{
    return g_optixFunctionTable.optixAccelRelocate( context, stream, info, instanceTraversableHandles,
                                                    numInstanceTraversableHandles, targetAccel, targetAccelSizeInBytes,
                                                    targetHandle );
}

inline OptixResult optixAccelCompact( OptixDeviceContext      context,
                                      CUstream                stream,
                                      OptixTraversableHandle  inputHandle,
                                      CUdeviceptr             outputBuffer,
                                      size_t                  outputBufferSizeInBytes,
                                      OptixTraversableHandle* outputHandle )
{
    return g_optixFunctionTable.optixAccelCompact( context, stream, inputHandle, outputBuffer, outputBufferSizeInBytes,
                                                   outputHandle );
}

inline OptixResult optixConvertPointerToTraversableHandle( OptixDeviceContext      onDevice,
                                                           CUdeviceptr             pointer,
                                                           OptixTraversableType    traversableType,
                                                           OptixTraversableHandle* traversableHandle )
{
    return g_optixFunctionTable.optixConvertPointerToTraversableHandle( onDevice, pointer, traversableType, traversableHandle );
}

// --- Shader-binding-table / launch stubs ---

inline OptixResult optixSbtRecordPackHeader( OptixProgramGroup programGroup, void* sbtRecordHeaderHostPointer )
{
    return g_optixFunctionTable.optixSbtRecordPackHeader( programGroup, sbtRecordHeaderHostPointer );
}

inline OptixResult optixLaunch( OptixPipeline                  pipeline,
                                CUstream                       stream,
                                CUdeviceptr                    pipelineParams,
                                size_t                         pipelineParamsSize,
                                const OptixShaderBindingTable* sbt,
                                unsigned int                   width,
                                unsigned int                   height,
                                unsigned int                   depth )
{
    return g_optixFunctionTable.optixLaunch( pipeline, stream, pipelineParams, pipelineParamsSize, sbt, width, height, depth );
}

// --- Denoiser stubs ---

inline OptixResult optixDenoiserCreate( OptixDeviceContext context, OptixDenoiserModelKind modelKind, const OptixDenoiserOptions* options, OptixDenoiser* returnHandle )
{
    return g_optixFunctionTable.optixDenoiserCreate( context, modelKind, options, returnHandle );
}

inline OptixResult optixDenoiserCreateWithUserModel( OptixDeviceContext context, const void* data, size_t dataSizeInBytes, OptixDenoiser* returnHandle )
{
    return g_optixFunctionTable.optixDenoiserCreateWithUserModel( context, data, dataSizeInBytes, returnHandle );
}

inline OptixResult optixDenoiserDestroy( OptixDenoiser handle )
{
    return g_optixFunctionTable.optixDenoiserDestroy( handle );
}

inline OptixResult optixDenoiserComputeMemoryResources( const OptixDenoiser handle,
                                                        unsigned int        maximumInputWidth,
                                                        unsigned int        maximumInputHeight,
                                                        OptixDenoiserSizes* returnSizes )
{
    return g_optixFunctionTable.optixDenoiserComputeMemoryResources( handle, maximumInputWidth, maximumInputHeight, returnSizes );
}

inline OptixResult optixDenoiserSetup( OptixDenoiser denoiser,
                                       CUstream      stream,
                                       unsigned int  inputWidth,
                                       unsigned int  inputHeight,
                                       CUdeviceptr   denoiserState,
                                       size_t        denoiserStateSizeInBytes,
                                       CUdeviceptr   scratch,
                                       size_t        scratchSizeInBytes )
{
    return g_optixFunctionTable.optixDenoiserSetup( denoiser, stream, inputWidth, inputHeight, denoiserState,
                                                    denoiserStateSizeInBytes, scratch, scratchSizeInBytes );
}

inline OptixResult optixDenoiserInvoke( OptixDenoiser                  handle,
                                        CUstream                       stream,
                                        const OptixDenoiserParams*     params,
                                        CUdeviceptr                    denoiserData,
                                        size_t                         denoiserDataSize,
                                        const OptixDenoiserGuideLayer* guideLayer,
                                        const OptixDenoiserLayer*      layers,
                                        unsigned int                   numLayers,
                                        unsigned int                   inputOffsetX,
                                        unsigned int                   inputOffsetY,
                                        CUdeviceptr                    scratch,
                                        size_t                         scratchSizeInBytes )
{
    return g_optixFunctionTable.optixDenoiserInvoke( handle, stream, params, denoiserData, denoiserDataSize, guideLayer,
                                                     layers, numLayers, inputOffsetX, inputOffsetY, scratch, scratchSizeInBytes );
}

inline OptixResult optixDenoiserComputeIntensity( OptixDenoiser       handle,
                                                  CUstream            stream,
                                                  const OptixImage2D* inputImage,
                                                  CUdeviceptr         outputIntensity,
                                                  CUdeviceptr         scratch,
                                                  size_t              scratchSizeInBytes )
{
    return g_optixFunctionTable.optixDenoiserComputeIntensity( handle, stream, inputImage, outputIntensity, scratch, scratchSizeInBytes );
}

inline OptixResult optixDenoiserComputeAverageColor( OptixDenoiser       handle,
                                                     CUstream            stream,
                                                     const OptixImage2D* inputImage,
                                                     CUdeviceptr         outputAverageColor,
                                                     CUdeviceptr         scratch,
                                                     size_t              scratchSizeInBytes )
{
    return g_optixFunctionTable.optixDenoiserComputeAverageColor( handle, stream, inputImage, outputAverageColor, scratch, scratchSizeInBytes );
}

#endif  // OPTIX_DOXYGEN_SHOULD_SKIP_THIS

#ifdef __cplusplus
}
#endif

#endif  // __optix_optix_stubs_h__


================================================
FILE: render/optixutils/include/optix_types.h
================================================
/*
 * Copyright (c) 2021 NVIDIA Corporation. All rights reserved.
* * NVIDIA Corporation and its licensors retain all intellectual property and proprietary * rights in and to this software, related documentation and any modifications thereto. * Any use, reproduction, disclosure or distribution of this software and related * documentation without an express license agreement from NVIDIA Corporation is strictly * prohibited. * * TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, THIS SOFTWARE IS PROVIDED *AS IS* * AND NVIDIA AND ITS SUPPLIERS DISCLAIM ALL WARRANTIES, EITHER EXPRESS OR IMPLIED, * INCLUDING, BUT NOT LIMITED TO, IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A * PARTICULAR PURPOSE. IN NO EVENT SHALL NVIDIA OR ITS SUPPLIERS BE LIABLE FOR ANY * SPECIAL, INCIDENTAL, INDIRECT, OR CONSEQUENTIAL DAMAGES WHATSOEVER (INCLUDING, WITHOUT * LIMITATION, DAMAGES FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF * BUSINESS INFORMATION, OR ANY OTHER PECUNIARY LOSS) ARISING OUT OF THE USE OF OR * INABILITY TO USE THIS SOFTWARE, EVEN IF NVIDIA HAS BEEN ADVISED OF THE POSSIBILITY OF * SUCH DAMAGES */ /** * @file optix_types.h * @author NVIDIA Corporation * @brief OptiX public API header * */ #ifndef __optix_optix_types_h__ #define __optix_optix_types_h__ // clang-format off #if !defined(__OPTIX_INCLUDE_INTERNAL_HEADERS__) # define __OPTIX_INCLUDE_INTERNAL_HEADERS__ # define __UNDEF_OPTIX_INCLUDE_INTERNAL_HEADERS_OPTIX_TYPES_H__ #endif #include "optix_7_types.h" #if defined( __UNDEF_OPTIX_INCLUDE_INTERNAL_HEADERS_OPTIX_TYPES_H__ ) # undef __OPTIX_INCLUDE_INTERNAL_HEADERS__ # undef __UNDEF_OPTIX_INCLUDE_INTERNAL_HEADERS_OPTIX_TYPES_H__ #endif // clang-format on #endif // #ifndef __optix_optix_types_h__ ================================================ FILE: render/optixutils/ops.py ================================================ # Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. 
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import numpy as np
import os
import sys
import torch
import torch.utils.cpp_extension

#----------------------------------------------------------------------------
# C++/Cuda plugin compiler/loader.
#
# Runs once at import time: locates a host compiler (Windows only), JIT-compiles
# the OptiX CUDA/C++ sources with torch.utils.cpp_extension.load, and caches the
# resulting extension module in the module-level `_plugin`.

_plugin = None
if _plugin is None:
    # Make sure we can find the necessary compiler and library binaries.
    if os.name == 'nt':
        optix_include_dir = os.path.dirname(__file__) + r"\include"

        def find_cl_path():
            # Search standard Visual Studio install roots for a 64-bit cl.exe.
            # Returns the newest matching bin directory, or None (implicitly)
            # when no edition is installed.
            import glob
            for edition in ['Enterprise', 'Professional', 'BuildTools', 'Community']:
                vs_editions = glob.glob(r"C:\Program Files (x86)\Microsoft Visual Studio\*\%s\VC\Tools\MSVC\*\bin\Hostx64\x64" % edition) \
                            + glob.glob(r"C:\Program Files\Microsoft Visual Studio\*\%s\VC\Tools\MSVC\*\bin\Hostx64\x64" % edition)
                paths = sorted(vs_editions, reverse=True)
                if paths:
                    return paths[0]

        # If cl.exe is not on path, try to find it.
        if os.system("where cl.exe >nul 2>nul") != 0:
            cl_path = find_cl_path()
            if cl_path is None:
                raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
            os.environ['PATH'] += ';' + cl_path
    elif os.name == 'posix':
        optix_include_dir = os.path.dirname(__file__) + r"/include"

    include_paths = [optix_include_dir]

    # Compiler options.
    opts = ['-DNVDR_TORCH']

    # Linker options.
    if os.name == 'posix':
        ldflags = ['-lcuda', '-lnvrtc']
    elif os.name == 'nt':
        ldflags = ['cuda.lib', 'advapi32.lib', 'nvrtc.lib']

    # List of sources.
    source_files = [
        'c_src/denoising.cu',
        'c_src/optix_wrapper.cpp',
        'c_src/torch_bindings.cpp'
    ]

    # Some containers set this to contain old architectures that won't compile. We only need the one installed in the machine.
    os.environ['TORCH_CUDA_ARCH_LIST'] = ''

    # Compile and load.
    build_dir = os.path.join(os.path.dirname(__file__), 'build')
    os.makedirs(build_dir, exist_ok=True)
    source_paths = [os.path.join(os.path.dirname(__file__), fn) for fn in source_files]
    torch.utils.cpp_extension.load(name='optixutils_plugin', sources=source_paths, extra_cflags=opts, build_directory=build_dir,
                                   extra_cuda_cflags=opts, extra_ldflags=ldflags, extra_include_paths=include_paths, with_cuda=True, verbose=True)

    # Import, cache, and return the compiled module.
    import optixutils_plugin
    _plugin = optixutils_plugin

#----------------------------------------------------------------------------
# OptiX autograd func
#----------------------------------------------------------------------------

class _optix_env_shade_func(torch.autograd.Function):
    """Autograd wrapper around the compiled env_shade_fwd/env_shade_bwd kernels.

    Returns a (diffuse, specular) pair of shaded buffers. Gradients flow back
    to gb_pos, gb_normal, gb_kd, gb_ks and light only; all other inputs are
    treated as non-differentiable.
    """

    # Cached per-n_samples_x permutation tables, shared across instances.
    _random_perm = {}

    @staticmethod
    def forward(ctx, optix_ctx, mask, ro, gb_pos, gb_normal, gb_view_pos, gb_kd, gb_ks, light, pdf, rows, cols, BSDF, n_samples_x, rnd_seed, shadow_scale):
        # Draw a fresh seed per call unless the caller pinned one.
        _rnd_seed = np.random.randint(2**31) if rnd_seed is None else rnd_seed
        if n_samples_x not in _optix_env_shade_func._random_perm:
            # Generate (32k) tables with random permutations to decorrelate the BSDF and light stratified samples
            _optix_env_shade_func._random_perm[n_samples_x] = torch.argsort(torch.rand(32768, n_samples_x * n_samples_x, device="cuda"), dim=-1).int()
        diff, spec = _plugin.env_shade_fwd(optix_ctx.cpp_wrapper, mask, ro, gb_pos, gb_normal, gb_view_pos, gb_kd, gb_ks, light, pdf, rows, cols, _optix_env_shade_func._random_perm[n_samples_x], BSDF, n_samples_x, _rnd_seed, shadow_scale)
        ctx.save_for_backward(mask, ro, gb_pos, gb_normal, gb_view_pos, gb_kd, gb_ks, light, pdf, rows, cols)
        ctx.optix_ctx = optix_ctx
        ctx.BSDF = BSDF
        ctx.n_samples_x = n_samples_x
        # Store the *caller's* seed (possibly None) so backward re-randomizes
        # when no seed was pinned. NOTE(review): this means fwd and bwd may use
        # different seeds when rnd_seed is None — presumably intentional
        # (stochastic gradients); confirm against the CUDA kernels.
        ctx.rnd_seed = rnd_seed
        ctx.shadow_scale = shadow_scale
        return diff, spec

    @staticmethod
    def backward(ctx, diff_grad, spec_grad):
        optix_ctx = ctx.optix_ctx
        _rnd_seed = np.random.randint(2**31) if ctx.rnd_seed is None else ctx.rnd_seed
        mask, ro, gb_pos, gb_normal, gb_view_pos, gb_kd, gb_ks, light, pdf, rows, cols = ctx.saved_variables
        gb_pos_grad, gb_normal_grad, gb_kd_grad, gb_ks_grad, light_grad = _plugin.env_shade_bwd(
            optix_ctx.cpp_wrapper, mask, ro, gb_pos, gb_normal, gb_view_pos, gb_kd, gb_ks, light, pdf, rows, cols,
            _optix_env_shade_func._random_perm[ctx.n_samples_x], ctx.BSDF, ctx.n_samples_x, _rnd_seed, ctx.shadow_scale, diff_grad, spec_grad)
        # One entry per forward() input, None for non-differentiable arguments.
        return None, None, None, gb_pos_grad, gb_normal_grad, None, gb_kd_grad, gb_ks_grad, light_grad, None, None, None, None, None, None, None

class _bilateral_denoiser_func(torch.autograd.Function):
    """Autograd wrapper around the compiled bilateral denoiser kernels.

    Gradients flow back to the color buffer only; the normal and depth guide
    buffers and sigma are non-differentiable.
    """

    @staticmethod
    def forward(ctx, col, nrm, zdz, sigma):
        ctx.save_for_backward(col, nrm, zdz)
        ctx.sigma = sigma
        out = _plugin.bilateral_denoiser_fwd(col, nrm, zdz, sigma)
        return out

    @staticmethod
    def backward(ctx, out_grad):
        col, nrm, zdz = ctx.saved_variables
        col_grad = _plugin.bilateral_denoiser_bwd(col, nrm, zdz, ctx.sigma, out_grad)
        return col_grad, None, None, None

#----------------------------------------------------------------------------
# OptiX ray tracing utils
#----------------------------------------------------------------------------

class OptiXContext:
    """Thin Python handle around the C++ OptiX state wrapper."""
    def __init__(self):
        print("Cuda path", torch.utils.cpp_extension.CUDA_HOME)
        self.cpp_wrapper = _plugin.OptiXStateWrapper(os.path.dirname(__file__), torch.utils.cpp_extension.CUDA_HOME)

def optix_build_bvh(optix_ctx, verts, tris, rebuild):
    '''
    choose not to raise error since we may have msdf supervision..
    should clean the code later
    '''
    # assert tris.shape[0] > 0, "Got empty training triangle mesh (unrecoverable discontinuity)"
    # assert verts.shape[0] > 0, "Got empty training triangle mesh (unrecoverable discontinuity)"
    # Build (or refit, when rebuild is falsy) the BVH over the given mesh.
    _plugin.optix_build_bvh(optix_ctx.cpp_wrapper, verts.view(-1, 3), tris.view(-1, 3), rebuild)

def optix_env_shade(optix_ctx, mask, ro, gb_pos, gb_normal, gb_view_pos, gb_kd, gb_ks, light, pdf, rows, cols, BSDF='pbr', n_samples_x=8, rnd_seed=None, shadow_scale=1.0):
    """Shade G-buffer pixels against an environment light via the OptiX plugin."""
    iBSDF = ['pbr', 'diffuse', 'white'].index(BSDF) # Ordering important, must match the order of the fwd/bwdPbrBSDF kernel.
    return _optix_env_shade_func.apply(optix_ctx, mask, ro, gb_pos, gb_normal, gb_view_pos, gb_kd, gb_ks, light, pdf, rows, cols, iBSDF, n_samples_x, rnd_seed, shadow_scale)

def bilateral_denoiser(col, nrm, zdz, sigma):
    """Denoise `col`, then un-premultiply by the accumulated filter weight (last channel)."""
    col_w = _bilateral_denoiser_func.apply(col, nrm, zdz, sigma)
    return col_w[..., 0:3] / col_w[..., 3:4]


================================================
FILE: render/optixutils/tests/filter_test.py
================================================
# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
from pickletools import read_float8  # NOTE(review): unused — looks like an accidental IDE auto-import
import torch
import os
import sys
import math

sys.path.insert(0, os.path.join(sys.path[0], '../..'))
import optixutils as ou
import numpy as np

RES = 1024                 # test image resolution (square)
DTYPE = torch.float32

def length(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor:
    return torch.sqrt(torch.clamp(dot(x,x), min=eps)) # Clamp to avoid nan gradients because grad(sqrt(0)) = NaN

def safe_normalize(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor:
    return x / length(x, eps)

def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return torch.sum(x*y, -1, keepdim=True)

class BilateralDenoiser(torch.nn.Module):
    """Pure-PyTorch reference implementation of the SVGF-style bilateral filter.

    Used only to validate the CUDA kernel (ou.svgf) in test_filter below; the
    weights combine a spatial Gaussian with SVGF normal and depth edge-stopping
    terms.
    """
    def __init__(self, sigma=1.0):
        super(BilateralDenoiser, self).__init__()
        self.set_sigma(sigma)

    def set_sigma(self, sigma):
        # Clamp sigma away from zero; N is the filter half-width (~2.5 sigma).
        self.sigma = max(sigma, 0.0001)
        self.variance = self.sigma**2.
        self.N = 2 * math.ceil(self.sigma * 2.5) + 1

    def forward(self, input):
        # input layout (last dim): [0:3]=color, [3:6]=normal, [6:9]=kd (unused
        # here), [9:11]=(z, dz) depth and depth-derivative guides.
        eps = 0.0001
        col      = input[..., 0:3]
        nrm      = input[..., 3:6]
        kd       = input[..., 6:9]
        zdz      = input[..., 9:11]
        accum_col = torch.zeros_like(col)
        accum_w = torch.zeros_like(col[..., 0:1])
        for y in range(-self.N, self.N+1):
            for x in range(-self.N, self.N+1):
                # NOTE(review): ty/tx are loop-invariant up to the +x/+y offset
                # and could be hoisted; left as-is since this is a reference path.
                ty, tx = torch.meshgrid(torch.arange(0, input.shape[1], dtype=torch.float32, device="cuda"),
                                        torch.arange(0, input.shape[2], dtype=torch.float32, device="cuda"))
                tx = tx[None, ..., None] + x
                ty = ty[None, ..., None] + y

                # Spatial Gaussian term.
                dist_sqr = (x**2 + y**2)
                dist = np.sqrt(dist_sqr)
                w_xy = np.exp(-dist_sqr / (2 * self.variance))

                with torch.no_grad():
                    nrm_tap = torch.roll(nrm, (-y, -x), (1, 2))
                    w_normal = torch.pow(torch.clamp(dot(nrm_tap, nrm), min=eps, max=1.0), 128.0) # From SVGF

                    zdz_tap = torch.roll(zdz, (-y, -x), (1, 2))
                    w_depth = torch.exp(-(torch.abs(zdz_tap[..., 0:1] - zdz[..., 0:1]) / torch.clamp(zdz[..., 1:2] * dist, min=eps)) ) # From SVGF

                w = w_xy * w_normal * w_depth
                # Zero out taps that fell outside the image.
                w = torch.where((tx >= 0) & (tx < input.shape[2]) & (ty >= 0) & (ty < input.shape[1]), w, torch.zeros_like(w))

                col_tap = torch.roll(col, (-y, -x), (1, 2))
                accum_col += col_tap * w
                accum_w += w
        return accum_col / torch.clamp(accum_w, min=eps)

def relative_loss(name, ref, cuda):
    # Print the max relative error between the reference and CUDA results.
    ref = ref.float()
    cuda = cuda.float()
    denom = torch.where(ref > 1e-7, ref, torch.ones_like(ref))
    relative = torch.abs(ref - cuda) / denom
    print(name, torch.max(relative).item())

def test_filter():
    """Compare the Python reference filter against ou.svgf (values + gradients)."""
    img_cuda = torch.rand(1, RES, RES, 11, dtype=DTYPE, device='cuda')
    img_cuda[..., 3:6] = safe_normalize(img_cuda[..., 3:6])
    img_ref = img_cuda.clone().detach().requires_grad_(True)
    img_cuda = img_cuda.clone().detach().requires_grad_(True)
    target_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    target_ref = target_cuda.clone().detach().requires_grad_(True)

    SIGMA = 2.0

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # Reference (pure Python) forward + backward, timed with CUDA events.
    start.record()
    denoiser = BilateralDenoiser(sigma=SIGMA)
    denoised_ref = denoiser.forward(img_ref)
    ref_loss = torch.nn.MSELoss()(denoised_ref, target_ref)
    ref_loss.backward()
    end.record()
    torch.cuda.synchronize()
    print("Python:", start.elapsed_time(end))

    # CUDA kernel forward + backward.
    start.record()
    denoised_cuda = ou.svgf(img_cuda[..., 0:3], img_cuda[..., 3:6], img_cuda[..., 9:11], img_cuda[..., 6:9], SIGMA)
    cuda_loss = torch.nn.MSELoss()(denoised_cuda, target_cuda)
    cuda_loss.backward()
    end.record()
    torch.cuda.synchronize()
    print("CUDA:", start.elapsed_time(end))

    print("-------------------------------------------------------------")
    print("    Filter loss:")
    print("-------------------------------------------------------------")
    relative_loss("denoised:", denoised_ref[..., 0:3], denoised_cuda[..., 0:3])
    relative_loss("grad:", img_ref.grad[..., 0:3], img_cuda.grad[..., 0:3])

test_filter()


================================================
FILE: render/regularizer.py
================================================
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import torch
import nvdiffrast.torch as dr

from render import util

from . import mesh

def luma(x):
    # Mean of the RGB channels, broadcast back to 3 channels.
    return ((x[..., 0:1] + x[..., 1:2] + x[..., 2:3]) / 3).repeat(1, 1, 1, 3)

def value(x):
    # HSV "value": per-pixel max over RGB, broadcast back to 3 channels.
    return torch.max(x[..., 0:3], dim=-1, keepdim=True)[0].repeat(1, 1, 1, 3)

def chroma_loss(kd, color_ref, lambda_chroma):
    # Penalize chroma (color normalized by value) mismatch between the
    # optimized albedo and the reference image, masked by reference alpha.
    eps = 0.001
    ref_chroma = color_ref[..., 0:3] / torch.clip(value(color_ref), min=eps)
    opt_chroma = kd[..., 0:3] / torch.clip(value(kd), min=eps)
    return torch.mean(torch.abs((opt_chroma - ref_chroma) * color_ref[..., 3:])) * lambda_chroma

# Diffuse luma regularizer + specular
def shading_loss(diffuse_light, specular_light, color_ref, lambda_diffuse, lambda_specular):
    # Encourage the (log-tonemapped) combined shading luma to match the
    # reference luma, and keep the specular/diffuse energy ratio small.
    diffuse_luma  = luma(diffuse_light)
    specular_luma = luma(specular_light)
    ref_luma      = value(color_ref)
    eps = 0.001

    img    = util.rgb_to_srgb(torch.log(torch.clamp((diffuse_luma + specular_luma) * color_ref[..., 3:], min=0, max=65535) + 1))
    target = util.rgb_to_srgb(torch.log(torch.clamp(ref_luma * color_ref[..., 3:], min=0, max=65535) + 1))
    # error = torch.abs(img - target) * diffuse_luma / torch.clamp(diffuse_luma + specular_luma, min=eps) ### encourage specular component to take control
    error = torch.abs(img - target) ### the original version in the paper
    loss  = torch.mean(error) * lambda_diffuse
    loss += torch.mean(specular_luma) / torch.clamp(torch.mean(diffuse_luma), min=eps) * lambda_specular
    return loss

######################################################################################
# Material smoothness loss
######################################################################################
def material_smoothness_grad(kd_grad, ks_grad, nrm_grad, lambda_kd=0.25, lambda_ks=0.1, lambda_nrm=0.0):
    # Inputs are image-space finite-difference gradients (from image_grad /
    # shade); the last channel acts as a per-pixel validity weight.
    kd_luma_grad = (kd_grad[..., 0] + kd_grad[..., 1] + kd_grad[..., 2]) / 3
    loss  = torch.mean(kd_luma_grad * kd_grad[..., -1]) * lambda_kd
    loss += torch.mean(ks_grad[..., :-1] * ks_grad[..., -1:]) * lambda_ks
    loss += torch.mean(nrm_grad[..., :-1] * nrm_grad[..., -1:]) * lambda_nrm
    return loss

######################################################################################
# Computes the image gradient, useful for kd/ks smoothness losses
######################################################################################
def image_grad(buf, std=0.01):
    # Jittered re-sample of `buf` at normally-distributed texture-coordinate
    # offsets; returns |tap - buf| weighted by both alpha channels.
    t, s = torch.meshgrid(torch.linspace(-1.0 + 1.0 / buf.shape[1], 1.0 - 1.0 / buf.shape[1], buf.shape[1], device="cuda"),
                          torch.linspace(-1.0 + 1.0 / buf.shape[2], 1.0 - 1.0 / buf.shape[2], buf.shape[2], device="cuda"),
                          indexing='ij')
    tc = torch.normal(mean=0, std=std, size=(buf.shape[0], buf.shape[1], buf.shape[2], 2), device="cuda") + torch.stack((s, t), dim=-1)[None, ...]
    tap = dr.texture(buf, tc, filter_mode='linear', boundary_mode='clamp')
    return torch.abs(tap[..., :-1] - buf[..., :-1]) * tap[..., -1:] * buf[..., -1:]

######################################################################################
# Computes the average edge length of a mesh.
# Rough estimate of the tessellation of a mesh. Can be used e.g. to clamp gradients
######################################################################################
def avg_edge_length(v_pos, t_pos_idx):
    e_pos_idx = mesh.compute_edges(t_pos_idx)
    edge_len  = util.length(v_pos[e_pos_idx[:, 0]] - v_pos[e_pos_idx[:, 1]])
    return torch.mean(edge_len)

######################################################################################
# Laplacian regularization using umbrella operator (Fujiwara / Desbrun).
# https://mgarland.org/class/geom04/material/smoothing.pdf ###################################################################################### def laplace_regularizer_const(v_pos, t_pos_idx): term = torch.zeros_like(v_pos) norm = torch.zeros_like(v_pos[..., 0:1]) v0 = v_pos[t_pos_idx[:, 0], :] v1 = v_pos[t_pos_idx[:, 1], :] v2 = v_pos[t_pos_idx[:, 2], :] term.scatter_add_(0, t_pos_idx[:, 0:1].repeat(1,3), (v1 - v0) + (v2 - v0)) term.scatter_add_(0, t_pos_idx[:, 1:2].repeat(1,3), (v0 - v1) + (v2 - v1)) term.scatter_add_(0, t_pos_idx[:, 2:3].repeat(1,3), (v0 - v2) + (v1 - v2)) two = torch.ones_like(v0) * 2.0 norm.scatter_add_(0, t_pos_idx[:, 0:1], two) norm.scatter_add_(0, t_pos_idx[:, 1:2], two) norm.scatter_add_(0, t_pos_idx[:, 2:3], two) term = term / torch.clamp(norm, min=1.0) return torch.mean(term**2) ###################################################################################### # Smooth vertex normals ###################################################################################### def normal_consistency(v_pos, t_pos_idx): # Compute face normals v0 = v_pos[t_pos_idx[:, 0], :] v1 = v_pos[t_pos_idx[:, 1], :] v2 = v_pos[t_pos_idx[:, 2], :] face_normals = util.safe_normalize(torch.cross(v1 - v0, v2 - v0)) tris_per_edge = mesh.compute_edge_to_face_mapping(t_pos_idx) # Fetch normals for both faces sharind an edge n0 = face_normals[tris_per_edge[:, 0], :] n1 = face_normals[tris_per_edge[:, 1], :] # Compute error metric based on normal difference term = torch.clamp(util.dot(n0, n1), min=-1.0, max=1.0) term = (1.0 - term) * 0.5 return torch.mean(torch.abs(term)) ================================================ FILE: render/render.py ================================================ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual # property and proprietary rights in and to this material, related # documentation and any modifications thereto. 
# Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

from threading import local  # NOTE(review): appears unused in this file

import numpy as np
import torch
import nvdiffrast.torch as dr

from . import util
from . import renderutils as ru
from . import optixutils as ou
from . import light

# Global counter advanced on every shaded frame; used as the OptiX sampling
# seed unless FLAGS.decorrelated is set.
rnd_seed = 0

# ==============================================================================================
#  Helper functions
# ==============================================================================================
def interpolate(attr, rast, attr_idx, rast_db=None):
    # Interpolate per-vertex attributes over rasterized pixels. Attribute
    # derivatives are only computed when a rasterizer derivative map is given.
    return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=None if rast_db is None else 'all')

# ==============================================================================================
#  pixel shader
# ==============================================================================================
def shade(
        FLAGS,
        rast,
        gb_depth,
        gb_pos,
        gb_geometric_normal,
        gb_normal,
        gb_tangent,
        gb_texc,
        gb_texc_deriv,
        view_pos,
        lgt,
        material,
        optix_ctx,
        mesh,
        bsdf,
        denoiser,
        shadow_scale,
        use_uv=True,
        finetune_normal=True,
        xfm_lgt=None,
        shade_data=False
    ):
    """Per-pixel shading pass.

    Samples the material maps (optionally with jitter to form smoothness
    gradients), bends/perturbs the shading normal, evaluates the selected BSDF
    (via OptiX environment shading for 'pbr'/'diffuse'/'white'), and returns a
    dict of NHWC image buffers, each with an alpha channel concatenated last.
    NOTE(review): all buffers are assumed to live on the CUDA device.
    """
    # Jittered texture coordinates around each pixel center, used to estimate
    # screen-space gradients of the material maps.
    offset = torch.normal(mean=0, std=0.005, size=(gb_depth.shape[0], gb_depth.shape[1], gb_depth.shape[2], 2), device="cuda")
    jitter = (util.pixel_grid(gb_depth.shape[2], gb_depth.shape[1])[None, ...] + offset).contiguous()

    # Coverage mask: only pixels covered both at the center and at the jittered
    # tap contribute to the gradient terms.
    mask = (rast[..., -1:] > 0).float()
    mask_tap = dr.texture(mask.contiguous(), jitter, filter_mode='linear', boundary_mode='clamp')
    grad_weight = mask * mask_tap

    ################################################################################
    # Texture lookups
    ################################################################################
    perturbed_nrm = None
    if 'kd_ks' in material:
        # Combined texture, used for MLPs because lookups are expensive
        all_tex_jitter = material['kd_ks'].sample(gb_pos + torch.normal(mean=0, std=0.01, size=gb_pos.shape, device="cuda"))
        # all_tex_jitter = material['kd_ks'].sample(gb_pos + torch.normal(mean=0, std=0.002, size=gb_pos.shape, device="cuda"))
        all_tex = material['kd_ks'].sample(gb_pos)
        assert all_tex.shape[-1] == 6, "Combined kd_ks must be 6 channels"
        kd, ks = all_tex[..., 0:3], all_tex[..., 3:6]
        kd_grad = torch.abs(all_tex_jitter[..., 0:3] - kd)
        ks_grad = torch.abs(all_tex_jitter[..., 3:6] - ks) * torch.tensor([0, 1, 1], dtype=torch.float32, device='cuda')[None, None, None, :] # Omit o-component
    elif 'kd_ks_normal' in material:
        raise NotImplementedError
    else:
        # NOTE(review): both branches below are currently identical; kept as-is.
        if shade_data:
            kd = material['kd'].sample(gb_texc, gb_texc_deriv)
            ks = material['ks'].sample(gb_texc, gb_texc_deriv)[..., 0:3] # skip alpha
            if 'normal' in material:
                perturbed_nrm = material['normal'].sample(gb_texc, gb_texc_deriv)
            kd_jitter = dr.texture(kd.contiguous(), jitter, filter_mode='linear', boundary_mode='clamp')
            ks_jitter = dr.texture(ks.contiguous(), jitter, filter_mode='linear', boundary_mode='clamp')
            kd_grad = torch.abs(kd_jitter - kd) * grad_weight
            ks_grad = torch.abs(ks_jitter - ks) * torch.tensor([0, 1, 1], dtype=torch.float32, device='cuda')[None, None, None, :] * grad_weight # Omit o-component
        else:
            kd = material['kd'].sample(gb_texc, gb_texc_deriv)
            ks = material['ks'].sample(gb_texc, gb_texc_deriv)[..., 0:3] # skip alpha
            if 'normal' in material:
                perturbed_nrm = material['normal'].sample(gb_texc, gb_texc_deriv)
            kd_jitter = dr.texture(kd.contiguous(), jitter, filter_mode='linear', boundary_mode='clamp')
            ks_jitter = dr.texture(ks.contiguous(), jitter, filter_mode='linear', boundary_mode='clamp')
            kd_grad = torch.abs(kd_jitter - kd) * grad_weight
            ks_grad = torch.abs(ks_jitter - ks) * torch.tensor([0, 1, 1], dtype=torch.float32, device='cuda')[None, None, None, :] * grad_weight # Omit o-component

    # Separate kd into alpha and color, default alpha = 1
    alpha = kd[..., 3:4] if kd.shape[-1] == 4 else torch.ones_like(kd[..., 0:1])
    kd = kd[..., 0:3]

    ################################################################################
    # Normal perturbation & normal bend
    ################################################################################
    if (not finetune_normal) or ('no_perturbed_nrm' in material and material['no_perturbed_nrm']):
        perturbed_nrm = None

    # Geometric smoothed normal regularizer
    nrm_jitter = dr.texture(gb_normal.contiguous(), jitter, filter_mode='linear', boundary_mode='clamp')
    nrm_grad = torch.abs(nrm_jitter - gb_normal) * grad_weight

    if perturbed_nrm is not None:
        # Gradient of the perturbed normal, from the misalignment between the
        # jittered and centered perturbed normals.
        perturbed_nrm_jitter = dr.texture(perturbed_nrm.contiguous(), jitter, filter_mode='linear', boundary_mode='clamp')
        perturbed_nrm_grad = 1.0 - util.safe_normalize(util.safe_normalize(perturbed_nrm_jitter) + util.safe_normalize(perturbed_nrm))[..., 2:3]
        perturbed_nrm_grad = perturbed_nrm_grad.repeat(1,1,1,3) * grad_weight

    gb_normal = ru.prepare_shading_normal(gb_pos, view_pos, perturbed_nrm, gb_normal, gb_tangent, gb_geometric_normal, two_sided_shading=True, opengl=True)

    ################################################################################
    # Evaluate BSDF
    ################################################################################
    assert 'bsdf' in material or bsdf is not None, "Material must specify a BSDF type"
    bsdf = material['bsdf'] if bsdf is None else bsdf
    if bsdf == 'pbr' or bsdf == 'diffuse' or bsdf == 'white':
        kd = torch.ones_like(kd) if bsdf == 'white' else kd
        assert isinstance(lgt, light.EnvironmentLight) and optix_ctx is not None
        # Ray origins are offset slightly along the normal to avoid self-hits.
        ro = gb_pos + gb_normal*0.001
        global rnd_seed
        diffuse_accum, specular_accum = ou.optix_env_shade(optix_ctx, rast[..., -1], ro, gb_pos, gb_normal, view_pos, kd, ks, lgt.base, lgt._pdf, lgt.rows[:,0], lgt.cols, BSDF=bsdf, n_samples_x=FLAGS.n_samples, rnd_seed=None if FLAGS.decorrelated else rnd_seed, shadow_scale=shadow_scale)
        rnd_seed += 1
        # denoise demodulated shaded values if possible
        if denoiser is not None and FLAGS.denoiser_demodulate:
            diffuse_accum = denoiser.forward(torch.cat((diffuse_accum, gb_normal, gb_depth), dim=-1))
            specular_accum = denoiser.forward(torch.cat((specular_accum, gb_normal, gb_depth), dim=-1))
        if bsdf == 'white' or bsdf == 'diffuse':
            shaded_col = diffuse_accum * kd
        else:
            kd = kd * (1.0 - ks[..., 2:3]) # kd * (1.0 - metalness)
            shaded_col = diffuse_accum * kd + specular_accum
        # denoise combined shaded values if possible
        if denoiser is not None and not FLAGS.denoiser_demodulate:
            shaded_col = denoiser.forward(torch.cat((shaded_col, gb_normal, gb_depth), dim=-1))
    elif bsdf == 'normal':
        shaded_col = (gb_normal + 1.0)*0.5
    elif bsdf == 'tangent':
        shaded_col = (gb_tangent + 1.0)*0.5
    elif bsdf == 'kd':
        shaded_col = kd
    elif bsdf == 'ks':
        shaded_col = ks
    else:
        assert False, "Invalid BSDF '%s'" % bsdf

    # NOTE(review): eps/allone_map are only referenced by the commented-out
    # depth/invdepth buffers below.
    eps = 1e-8
    allone_map = torch.ones_like(alpha)

    # Return multiple buffers
    # Setting the `alphas` of depth and invdepth to 1 to avoid double blending
    # (one with background, the other in antialiasing)
    buffers = {
        'shaded' : torch.cat((shaded_col, alpha), dim=-1),
        'z_grad' : torch.cat((gb_depth, torch.zeros_like(alpha), alpha), dim=-1),
        'normal' : torch.cat((gb_normal, alpha), dim=-1),
        'geometric_normal' : torch.cat((gb_geometric_normal, alpha), dim=-1),
        'kd' : torch.cat((kd, alpha), dim=-1),
        'ks' : torch.cat((ks, alpha), dim=-1),
        'kd_grad' : torch.cat((kd_grad, alpha), dim=-1),
        'ks_grad' : torch.cat((ks_grad, alpha), dim=-1),
        'normal_grad' : torch.cat((nrm_grad, alpha), dim=-1),
        # 'depth' : torch.cat(((gb_pos - view_pos).pow(2).sum(dim=-1, keepdim=True).sqrt(), allone_map), dim=-1),
        # 'invdepth' : torch.cat((1.0 / ((gb_pos - view_pos).pow(2) + eps).sum(dim=-1, keepdim=True).sqrt(), allone_map), dim=-1),
    }

    # diffuse_accum/specular_accum only exist when the OptiX branch ran above.
    if 'diffuse_accum' in locals():
        buffers['diffuse_light'] = torch.cat((diffuse_accum, alpha), dim=-1)
    if 'specular_accum' in locals():
        buffers['specular_light'] = torch.cat((specular_accum, alpha), dim=-1)

    if perturbed_nrm is not None:
        buffers['perturbed_nrm'] = torch.cat((perturbed_nrm, alpha), dim=-1)
        buffers['perturbed_nrm_grad'] = torch.cat((perturbed_nrm_grad, alpha), dim=-1)

    return buffers

# ==============================================================================================
#  Render a depth slice of the mesh (scene), some limitations:
#  - Single mesh
#  - Single light
#  - Single material
# ==============================================================================================
def render_layer(
        FLAGS,
        v_pos_clip,
        rast,
        rast_deriv,
        mesh,
        view_pos,
        lgt,
        resolution,
        spp,
        msaa,
        optix_ctx,
        bsdf,
        denoiser,
        shadow_scale,
        use_uv=True,
        finetune_normal=True,
        extra_dict=None,
        xfm_lgt = None,
        shade_data = False
    ):
    """Rasterize-and-shade a single depth layer; returns the buffer dict from shade()."""

    full_res = [resolution[0]*spp, resolution[1]*spp]

    ################################################################################
    # Rasterize
    ################################################################################

    # Scale down to shading resolution when MSAA is enabled, otherwise shade at full resolution
    if spp > 1 and msaa:
        rast_out_s = util.scale_img_nhwc(rast, resolution, mag='nearest', min='nearest')
        rast_out_deriv_s = util.scale_img_nhwc(rast_deriv, resolution, mag='nearest', min='nearest') * spp
    else:
        rast_out_s = rast
        rast_out_deriv_s = rast_deriv

    ################################################################################
    # Interpolate attributes
    ################################################################################

    # Interpolate world space position
    gb_pos, _ = interpolate(mesh.v_pos[None, ...], rast_out_s, mesh.t_pos_idx.int())

    # Compute geometric normals. We need those because of bent normals trick (for bump mapping)
    v0 = mesh.v_pos[mesh.t_pos_idx[:, 0], :]
    v1 = mesh.v_pos[mesh.t_pos_idx[:, 1], :]
    v2 = mesh.v_pos[mesh.t_pos_idx[:, 2], :]
    face_normals = util.safe_normalize(torch.cross(v1 - v0, v2 - v0))
    face_normal_indices = (torch.arange(0, face_normals.shape[0], dtype=torch.int64, device='cuda')[:, None]).repeat(1, 3)
    gb_geometric_normal, _ = interpolate(face_normals[None, ...], rast_out_s, face_normal_indices.int())

    if use_uv:
        # Compute tangent space
        assert mesh.v_nrm is not None and mesh.v_tng is not None
        gb_normal, _ = interpolate(mesh.v_nrm[None, ...], rast_out_s, mesh.t_nrm_idx.int())
        gb_tangent, _ = interpolate(mesh.v_tng[None, ...], rast_out_s, mesh.t_tng_idx.int()) # Interpolate tangents

        # Texture coordinate
        assert mesh.v_tex is not None
        gb_texc, gb_texc_deriv = interpolate(mesh.v_tex[None, ...], rast_out_s, mesh.t_tex_idx.int(), rast_db=rast_out_deriv_s)
    else:
        # Compute tangent space
        assert mesh.v_nrm is not None
        gb_normal, _ = interpolate(mesh.v_nrm[None, ...], rast_out_s, mesh.t_nrm_idx.int())
        with torch.no_grad():
            noise = torch.randn_like(gb_normal)
            noise = noise / noise.norm(dim=-1, keepdim=True)
            gb_tangent = torch.cross(noise, gb_normal) ### since we only use tangent for adding isotropic noises but not for uv maps

        # # Texture coordinate
        gb_texc, gb_texc_deriv = None, None

    # Interpolate z and z-gradient
    with torch.no_grad():
        eps = 0.00001
        clip_pos, clip_pos_deriv = interpolate(v_pos_clip, rast_out_s, mesh.t_pos_idx.int(), rast_db=rast_out_deriv_s)
        z0 = torch.clamp(clip_pos[..., 2:3], min=eps) / torch.clamp(clip_pos[..., 3:4], min=eps)
        z1 = torch.clamp(clip_pos[..., 2:3] + torch.abs(clip_pos_deriv[..., 2:3]), min=eps) / torch.clamp(clip_pos[..., 3:4] + torch.abs(clip_pos_deriv[..., 3:4]), min=eps)
        z_grad = torch.abs(z1 - z0)
        gb_depth = torch.cat((z0, z_grad), dim=-1)

    ################################################################################
    # Shade
    ################################################################################

    buffers = shade(
        FLAGS,
        rast_out_s, gb_depth, gb_pos, gb_geometric_normal, gb_normal, gb_tangent, gb_texc, gb_texc_deriv,
        view_pos, lgt, mesh.material, optix_ctx, mesh, bsdf, denoiser, shadow_scale,
        use_uv=use_uv,
        finetune_normal=finetune_normal,
        xfm_lgt=xfm_lgt,
        shade_data=shade_data
    )

    ################################################################################
    # Prepare output
    ################################################################################

    if extra_dict is not None:
        for key in extra_dict:
            if key == 'msdf' and extra_dict[key] is not None:
                assert extra_dict[key].dim() == 1 or (extra_dict[key].dim() == 2 and extra_dict[key].size(1) == 1)
                buffers['msdf_image'], _ = interpolate(extra_dict[key].squeeze()[None, :, None], rast_out_s, mesh.t_pos_idx.int())
            elif key == 'msdf_watertight' and extra_dict[key] is not None:
                assert extra_dict[key].dim() == 1 or (extra_dict[key].dim() == 2 and extra_dict[key].size(1) == 1)
                buffers['msdf_watertight_image'], _ = interpolate(extra_dict['msdf_watertight'].squeeze()[None, :, None], rast_out_s.detach(), mesh.t_pos_idx.int()) ## maybe better to stop all gradients to vpos

    # Scale back up to visibility resolution if using MSAA
    if spp > 1 and msaa:
        for key in buffers.keys():
            buffers[key] = util.scale_img_nhwc(buffers[key], full_res, mag='nearest', min='nearest')

    # Return buffers
    return buffers

# ==============================================================================================
#  Render a depth peeled mesh (scene), some limitations:
#  - Single mesh
#  - Single light
#  - Single material
# ==============================================================================================
def render_mesh(
        FLAGS,
        ctx,
        mesh,
        mtx_in,
        view_pos,
        lgt,
        resolution,
        spp         = 1,
        num_layers  = 1,
        msaa        = False,
        background  = None,
        optix_ctx   = None,
        bsdf        = None,
        denoiser    = None,
        shadow_scale = 1.0,
        use_uv = True,
        finetune_normal = True,
        extra_dict = None,
        xfm_lgt = None,
        shade_data = False,
    ):
    """Render a mesh with a single depth layer and composite it over a background.

    Returns a dict of NHWC buffers plus 'visible_triangles' (0-based indices of
    triangles hit by the rasterizer).
    """

    def prepare_input_vector(x):
        # Promote a 2D (batch, 3) vector to broadcastable NHWC form.
        x = torch.tensor(x, dtype=torch.float32, device='cuda') if not torch.is_tensor(x) else x
        return x[:, None, None, :] if len(x.shape) == 2 else x

    def composite_buffer(key, layers, background, antialias):
        # Alpha-blend layers back-to-front over `background`, optionally antialiasing.
        accum = background
        for buffers, rast in reversed(layers):
            alpha = (rast[..., -1:] > 0).float() * buffers[key][..., -1:]
            accum = torch.lerp(accum, torch.cat((buffers[key][..., :-1], torch.ones_like(buffers[key][..., -1:])), dim=-1), alpha)
            if antialias:
                accum = dr.antialias(accum.contiguous(), rast, v_pos_clip, mesh.t_pos_idx.int())
        return accum

    '''
    choose not to raise error since it is possible that we have msdf supervision.
    should clean the code later
    '''
    # assert mesh.t_pos_idx.shape[0] > 0, "Got empty training triangle mesh (unrecoverable discontinuity)"
    # assert background is None or (background.shape[1] == resolution[0] and background.shape[2] == resolution[1])

    full_res = [resolution[0]*spp, resolution[1]*spp]

    # Convert numpy arrays to torch tensors
    mtx_in = torch.tensor(mtx_in, dtype=torch.float32, device='cuda') if not torch.is_tensor(mtx_in) else mtx_in
    view_pos = prepare_input_vector(view_pos)

    # clip space transform
    v_pos_clip = ru.xfm_points(mesh.v_pos[None, ...], mtx_in)

    # Render all layers front-to-back
    with dr.DepthPeeler(ctx, v_pos_clip, mesh.t_pos_idx.int(), full_res) as peeler:
        assert num_layers == 1
        rast, db = peeler.rasterize_next_layer()
        # Rasterizer triangle IDs are 1-based; 0 means "no hit".
        visible_triangles = rast[:,:,:,-1].long().unique()
        if visible_triangles[0] == 0:
            visible_triangles = visible_triangles[1:]
        visible_triangles = visible_triangles - 1
        layers = [
            (render_layer(
                FLAGS, v_pos_clip,
                rast, db, mesh, view_pos, lgt, resolution, spp, msaa,
                optix_ctx, bsdf, denoiser, shadow_scale,
                use_uv=use_uv, finetune_normal=finetune_normal,
                extra_dict=extra_dict,
                xfm_lgt=xfm_lgt,
                shade_data=shade_data), rast)]
        # rast, db = peeler.rasterize_next_layer()
        # layer_second = [
        #     (render_layer(
        #         FLAGS, v_pos_clip,
        #         rast, db, mesh, view_pos, lgt, resolution, spp, msaa,
        #         optix_ctx, bsdf, denoiser, shadow_scale,
        #         use_uv=use_uv, finetune_normal=finetune_normal,
        #         extra_dict=extra_dict,
        #         xfm_lgt=xfm_lgt,
        #         shade_data=shade_data),
        #     rast)]

    # Setup background
    if background is not None:
        if spp > 1:
            background = util.scale_img_nhwc(background, full_res, mag='nearest', min='nearest')
        background = torch.cat((background, torch.zeros_like(background[..., 0:1])), dim=-1)
    else:
        background = torch.zeros(1, full_res[0], full_res[1], 4, dtype=torch.float32, device='cuda')

    # Composite layers front-to-back
    out_buffers = {}
    out_buffers['visible_triangles'] = visible_triangles
    for key in layers[0][0].keys():
        if layers[0][0][key] is None:
            out_buffers[key] = None
            continue
        if key == 'shaded':
            accum = composite_buffer(key, layers, background, True)
        elif key == 'depth':
            continue
            # NOTE(review): the two lines below are unreachable after `continue`;
            # the 'depth' buffer is also commented out in shade(). Kept as-is.
            default_depth = 20.0
            accum = composite_buffer(key, layers, torch.ones_like(layers[0][0][key]) * default_depth, True)
        elif key == 'invdepth':
            accum = composite_buffer(key, layers, torch.zeros_like(layers[0][0][key]), True)
        else:
            accum = composite_buffer(key, layers, torch.zeros_like(layers[0][0][key]), True)

        # Downscale to framebuffer resolution. Use avg pooling
        out_buffers[key] = util.avg_pool_nhwc(accum, spp) if spp > 1 else accum

    # accum = composite_buffer('shaded', layer_second, background.clone(), True)
    # out_buffers['shaded_second'] = util.avg_pool_nhwc(accum, spp) if spp > 1 else accum
    # accum = composite_buffer('invdepth', layer_second, torch.zeros_like(layers[0][0]['invdepth']), True)
    # out_buffers['invdepth_second'] = util.avg_pool_nhwc(accum, spp) if spp > 1 else accum
    # accum = composite_buffer('depth', layer_second, torch.ones_like(layers[0][0]['depth']) * default_depth, True)
    # out_buffers['depth_second'] = util.avg_pool_nhwc(accum, spp) if spp > 1 else accum

    return out_buffers

# ==============================================================================================
#  Render UVs
# ==============================================================================================
def render_uv(ctx, mesh, resolution, mlp_texture):
    """Bake the MLP texture into UV space; returns (coverage mask, kd, ks)."""

    # clip space transform
    uv_clip = mesh.v_tex[None, ...]*2.0 - 1.0

    # pad to four component coordinate
    uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[...,0:1]), torch.ones_like(uv_clip[...,0:1])), dim = -1)

    # rasterize
    rast, _ = dr.rasterize(ctx, uv_clip4, mesh.t_tex_idx.int(), resolution)

    # Interpolate world space position
    gb_pos, _ = interpolate(mesh.v_pos[None, ...], rast, mesh.t_pos_idx.int())

    # Sample out textures from MLP
    all_tex = mlp_texture.sample(gb_pos)
    assert all_tex.shape[-1] == 6, "Combined kd_ks must be 6 channels"

    return (rast[..., -1:] > 0).float(), all_tex[..., 0:3], all_tex[..., 3:6]

# ================================================
# FILE: render/renderutils/__init__.py
# ================================================
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto.
# Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

# Public surface of the renderutils package: re-export the op wrappers so
# callers can use `renderutils.xfm_points` etc. directly.
from .ops import xfm_points, xfm_vectors, image_loss, diffuse_cubemap, specular_cubemap, prepare_shading_normal, lambert, frostbite_diffuse, pbr_specular, pbr_bsdf, _fresnel_shlick, _ndf_ggx, _lambda_ggx, _masking_smith

# NOTE(review): the underscore-prefixed names are deliberately exported here
# despite their "private" naming.
__all__ = ["xfm_vectors", "xfm_points", "image_loss", "diffuse_cubemap","specular_cubemap", "prepare_shading_normal", "lambert", "frostbite_diffuse", "pbr_specular", "pbr_bsdf", "_fresnel_shlick", "_ndf_ggx", "_lambda_ggx", "_masking_smith", ]

# ================================================
# FILE: render/renderutils/bsdf.py
# ================================================
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import math import torch NORMAL_THRESHOLD = 0.1 ################################################################################ # Vector utility functions ################################################################################ def _dot(x, y): return torch.sum(x*y, -1, keepdim=True) def _reflect(x, n): return 2*_dot(x, n)*n - x def _safe_normalize(x): return torch.nn.functional.normalize(x, dim = -1) def _bend_normal(view_vec, smooth_nrm, geom_nrm, two_sided_shading): # Swap normal direction for backfacing surfaces if two_sided_shading: smooth_nrm = torch.where(_dot(geom_nrm, view_vec) > 0, smooth_nrm, -smooth_nrm) geom_nrm = torch.where(_dot(geom_nrm, view_vec) > 0, geom_nrm, -geom_nrm) t = torch.clamp(_dot(view_vec, smooth_nrm) / NORMAL_THRESHOLD, min=0, max=1) return torch.lerp(geom_nrm, smooth_nrm, t) def _perturb_normal(perturbed_nrm, smooth_nrm, smooth_tng, opengl): smooth_bitang = _safe_normalize(torch.cross(smooth_tng, smooth_nrm)) if opengl: shading_nrm = smooth_tng * perturbed_nrm[..., 0:1] - smooth_bitang * perturbed_nrm[..., 1:2] + smooth_nrm * torch.clamp(perturbed_nrm[..., 2:3], min=0.0) else: shading_nrm = smooth_tng * perturbed_nrm[..., 0:1] + smooth_bitang * perturbed_nrm[..., 1:2] + smooth_nrm * torch.clamp(perturbed_nrm[..., 2:3], min=0.0) return _safe_normalize(shading_nrm) def bsdf_prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl): smooth_nrm = _safe_normalize(smooth_nrm) smooth_tng = _safe_normalize(smooth_tng) view_vec = _safe_normalize(view_pos - pos) shading_nrm = _perturb_normal(perturbed_nrm, smooth_nrm, smooth_tng, opengl) return _bend_normal(view_vec, shading_nrm, geom_nrm, two_sided_shading) ################################################################################ # Simple lambertian diffuse BSDF ################################################################################ def bsdf_lambert(nrm, wi): return torch.clamp(_dot(nrm, wi), min=0.0) / math.pi 
################################################################################
# Frostbite diffuse
################################################################################

def bsdf_frostbite(nrm, wi, wo, linearRoughness):
    """Frostbite (energy-conserving Disney-style) diffuse term.

    Zero whenever either direction is below the surface.
    """
    cos_i = _dot(wi, nrm)
    cos_o = _dot(wo, nrm)

    half = _safe_normalize(wo + wi)
    cos_ih = _dot(wi, half)

    energy_bias = 0.5 * linearRoughness
    energy_factor = 1.0 - (0.51 / 1.51) * linearRoughness
    f90 = energy_bias + 2.0 * cos_ih * cos_ih * linearRoughness
    f0 = 1.0

    scatter_i = bsdf_fresnel_shlick(f0, f90, cos_i)
    scatter_o = bsdf_fresnel_shlick(f0, f90, cos_o)
    res = scatter_i * scatter_o * energy_factor
    return torch.where((cos_i > 0.0) & (cos_o > 0.0), res, torch.zeros_like(res))

################################################################################
# Phong specular, loosely based on mitsuba implementation
################################################################################

def bsdf_phong(nrm, wo, wi, N):
    """Normalized Phong specular lobe with exponent N."""
    spec_cos = torch.clamp(_dot(_reflect(wo, nrm), wi), min=0.0, max=1.0)
    diff_cos = torch.clamp(_dot(nrm, wi), min=0.0, max=1.0)
    return (spec_cos ** N) * diff_cos * (N + 2) / (2 * math.pi)

################################################################################
# PBR's implementation of GGX specular
################################################################################

# Clamp margin keeping cosines strictly inside (0, 1) for stable gradients.
specular_epsilon = 1e-4

def bsdf_fresnel_shlick(f0, f90, cosTheta):
    """Schlick's Fresnel approximation between f0 and f90."""
    cos_c = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon)
    return f0 + (f90 - f0) * (1.0 - cos_c) ** 5.0

def bsdf_ndf_ggx(alphaSqr, cosTheta):
    """GGX normal distribution function D(h)."""
    cos_c = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon)
    denom = (cos_c * alphaSqr - cos_c) * cos_c + 1
    return alphaSqr / (denom * denom * math.pi)

def bsdf_lambda_ggx(alphaSqr, cosTheta):
    """Smith lambda auxiliary term for GGX masking-shadowing."""
    cos_c = torch.clamp(cosTheta, min=specular_epsilon, max=1.0 - specular_epsilon)
    cos_sq = cos_c * cos_c
    tan_sq = (1.0 - cos_sq) / cos_sq
    return 0.5 * (torch.sqrt(1 + alphaSqr * tan_sq) - 1.0)

def bsdf_masking_smith_ggx_correlated(alphaSqr, cosThetaI, cosThetaO):
    """Height-correlated Smith masking-shadowing term G."""
    lam_i = bsdf_lambda_ggx(alphaSqr, cosThetaI)
    lam_o = bsdf_lambda_ggx(alphaSqr, cosThetaO)
    return 1 / (1 + lam_i + lam_o)

def bsdf_pbr_specular(col, nrm, wo, wi, alpha, min_roughness=0.08):
    """Cook-Torrance GGX specular lobe; zero for back-facing configurations."""
    alpha_c = torch.clamp(alpha, min=min_roughness*min_roughness, max=1.0)
    alpha_sq = alpha_c * alpha_c

    half = _safe_normalize(wo + wi)
    cos_o = _dot(wo, nrm)
    cos_i = _dot(wi, nrm)
    cos_oh = _dot(wo, half)
    cos_nh = _dot(nrm, half)

    D = bsdf_ndf_ggx(alpha_sq, cos_nh)
    G = bsdf_masking_smith_ggx_correlated(alpha_sq, cos_o, cos_i)
    F = bsdf_fresnel_shlick(col, 1, cos_oh)
    w = F * D * G * 0.25 / torch.clamp(cos_o, min=specular_epsilon)

    frontfacing = (cos_o > specular_epsilon) & (cos_i > specular_epsilon)
    return torch.where(frontfacing, w, torch.zeros_like(w))

def bsdf_pbr(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF):
    """Full PBR BSDF: diffuse (Lambert or Frostbite, by BSDF flag) + GGX specular.

    `arm` packs (specular attenuation, roughness, metallic) in its last dim.
    """
    wo = _safe_normalize(view_pos - pos)
    wi = _safe_normalize(light_pos - pos)

    spec_str = arm[..., 0:1] # x component
    roughness = arm[..., 1:2] # y component
    metallic = arm[..., 2:3] # z component

    ks = (0.04 * (1.0 - metallic) + kd * metallic) * (1 - spec_str)
    kd = kd * (1.0 - metallic)

    if BSDF == 0:
        diffuse = kd * bsdf_lambert(nrm, wi)
    else:
        diffuse = kd * bsdf_frostbite(nrm, wi, wo, roughness)
    specular = bsdf_pbr_specular(ks, nrm, wo, wi, roughness*roughness, min_roughness=min_roughness)
    return diffuse + specular

# ================================================
# FILE: render/renderutils/c_src/bsdf.cu
# ================================================
# /*
#  * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#  *
#  * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
#  * property and proprietary rights in and to this material, related
#  * documentation and any modifications thereto.
Any use, reproduction, * disclosure or distribution of this material and related documentation * without an express license agreement from NVIDIA CORPORATION or * its affiliates is strictly prohibited. */ #include "common.h" #include "bsdf.h" #define SPECULAR_EPSILON 1e-4f //------------------------------------------------------------------------ // Lambert functions __device__ inline float fwdLambert(const vec3f nrm, const vec3f wi) { return max(dot(nrm, wi) / M_PI, 0.0f); } __device__ inline void bwdLambert(const vec3f nrm, const vec3f wi, vec3f& d_nrm, vec3f& d_wi, const float d_out) { if (dot(nrm, wi) > 0.0f) bwdDot(nrm, wi, d_nrm, d_wi, d_out / M_PI); } //------------------------------------------------------------------------ // Fresnel Schlick __device__ inline float fwdFresnelSchlick(const float f0, const float f90, const float cosTheta) { float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); float scale = powf(1.0f - _cosTheta, 5.0f); return f0 * (1.0f - scale) + f90 * scale; } __device__ inline void bwdFresnelSchlick(const float f0, const float f90, const float cosTheta, float& d_f0, float& d_f90, float& d_cosTheta, const float d_out) { float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); float scale = pow(max(1.0f - _cosTheta, 0.0f), 5.0f); d_f0 += d_out * (1.0 - scale); d_f90 += d_out * scale; if (cosTheta >= SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON) { d_cosTheta += d_out * (f90 - f0) * -5.0f * powf(1.0f - cosTheta, 4.0f); } } __device__ inline vec3f fwdFresnelSchlick(const vec3f f0, const vec3f f90, const float cosTheta) { float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); float scale = powf(1.0f - _cosTheta, 5.0f); return f0 * (1.0f - scale) + f90 * scale; } __device__ inline void bwdFresnelSchlick(const vec3f f0, const vec3f f90, const float cosTheta, vec3f& d_f0, vec3f& d_f90, float& d_cosTheta, const vec3f d_out) { float _cosTheta = clamp(cosTheta, 
SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); float scale = pow(max(1.0f - _cosTheta, 0.0f), 5.0f); d_f0 += d_out * (1.0 - scale); d_f90 += d_out * scale; if (cosTheta >= SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON) { d_cosTheta += sum(d_out * (f90 - f0) * -5.0f * powf(1.0f - cosTheta, 4.0f)); } } //------------------------------------------------------------------------ // Frostbite diffuse __device__ inline float fwdFrostbiteDiffuse(const vec3f nrm, const vec3f wi, const vec3f wo, float linearRoughness) { float wiDotN = dot(wi, nrm); float woDotN = dot(wo, nrm); if (wiDotN > 0.0f && woDotN > 0.0f) { vec3f h = safeNormalize(wo + wi); float wiDotH = dot(wi, h); float energyBias = 0.5f * linearRoughness; float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness; float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness; float f0 = 1.f; float wiScatter = fwdFresnelSchlick(f0, f90, wiDotN); float woScatter = fwdFresnelSchlick(f0, f90, woDotN); return wiScatter * woScatter * energyFactor; } else return 0.0f; } __device__ inline void bwdFrostbiteDiffuse(const vec3f nrm, const vec3f wi, const vec3f wo, float linearRoughness, vec3f& d_nrm, vec3f& d_wi, vec3f& d_wo, float &d_linearRoughness, const float d_out) { float wiDotN = dot(wi, nrm); float woDotN = dot(wo, nrm); if (wiDotN > 0.0f && woDotN > 0.0f) { vec3f h = safeNormalize(wo + wi); float wiDotH = dot(wi, h); float energyBias = 0.5f * linearRoughness; float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness; float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness; float f0 = 1.f; float wiScatter = fwdFresnelSchlick(f0, f90, wiDotN); float woScatter = fwdFresnelSchlick(f0, f90, woDotN); // -------------- BWD -------------- // Backprop: return wiScatter * woScatter * energyFactor; float d_wiScatter = d_out * woScatter * energyFactor; float d_woScatter = d_out * wiScatter * energyFactor; float d_energyFactor = d_out * wiScatter * woScatter; // Backprop: float woScatter = 
fwdFresnelSchlick(f0, f90, woDotN); float d_woDotN = 0.0f, d_f0 = 0.0, d_f90 = 0.0f; bwdFresnelSchlick(f0, f90, woDotN, d_f0, d_f90, d_woDotN, d_woScatter); // Backprop: float wiScatter = fwdFresnelSchlick(fd0, fd90, wiDotN); float d_wiDotN = 0.0f; bwdFresnelSchlick(f0, f90, wiDotN, d_f0, d_f90, d_wiDotN, d_wiScatter); // Backprop: float f90 = energyBias + 2.f * wiDotH * wiDotH * linearRoughness; float d_energyBias = d_f90; float d_wiDotH = d_f90 * 4 * wiDotH * linearRoughness; d_linearRoughness += d_f90 * 2 * wiDotH * wiDotH; // Backprop: float energyFactor = 1.0f - (0.51f / 1.51f) * linearRoughness; d_linearRoughness -= (0.51f / 1.51f) * d_energyFactor; // Backprop: float energyBias = 0.5f * linearRoughness; d_linearRoughness += 0.5 * d_energyBias; // Backprop: float wiDotH = dot(wi, h); vec3f d_h(0); bwdDot(wi, h, d_wi, d_h, d_wiDotH); // Backprop: vec3f h = safeNormalize(wo + wi); vec3f d_wo_wi(0); bwdSafeNormalize(wo + wi, d_wo_wi, d_h); d_wi += d_wo_wi; d_wo += d_wo_wi; bwdDot(wo, nrm, d_wo, d_nrm, d_woDotN); bwdDot(wi, nrm, d_wi, d_nrm, d_wiDotN); } } //------------------------------------------------------------------------ // Ndf GGX __device__ inline float fwdNdfGGX(const float alphaSqr, const float cosTheta) { float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); float d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1.0f; return alphaSqr / (d * d * M_PI); } __device__ inline void bwdNdfGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out) { // Torch only back propagates if clamp doesn't trigger float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); float cosThetaSqr = _cosTheta * _cosTheta; d_alphaSqr += d_out * (1.0f - (alphaSqr + 1.0f) * cosThetaSqr) / (M_PI * powf((alphaSqr - 1.0) * cosThetaSqr + 1.0f, 3.0f)); if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON) { d_cosTheta += d_out * -(4.0f * (alphaSqr - 1.0f) * alphaSqr * 
cosTheta) / (M_PI * powf((alphaSqr - 1.0) * cosThetaSqr + 1.0f, 3.0f)); } } //------------------------------------------------------------------------ // Lambda GGX __device__ inline float fwdLambdaGGX(const float alphaSqr, const float cosTheta) { float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); float cosThetaSqr = _cosTheta * _cosTheta; float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr; float res = 0.5f * (sqrtf(1.0f + alphaSqr * tanThetaSqr) - 1.0f); return res; } __device__ inline void bwdLambdaGGX(const float alphaSqr, const float cosTheta, float& d_alphaSqr, float& d_cosTheta, const float d_out) { float _cosTheta = clamp(cosTheta, SPECULAR_EPSILON, 1.0f - SPECULAR_EPSILON); float cosThetaSqr = _cosTheta * _cosTheta; float tanThetaSqr = (1.0 - cosThetaSqr) / cosThetaSqr; float res = 0.5f * (sqrtf(1.0f + alphaSqr * tanThetaSqr) - 1.0f); d_alphaSqr += d_out * (0.25 * tanThetaSqr) / sqrtf(alphaSqr * tanThetaSqr + 1.0f); if (cosTheta > SPECULAR_EPSILON && cosTheta < 1.0f - SPECULAR_EPSILON) d_cosTheta += d_out * -(0.5 * alphaSqr) / (powf(_cosTheta, 3.0f) * sqrtf(alphaSqr / cosThetaSqr - alphaSqr + 1.0f)); } //------------------------------------------------------------------------ // Masking GGX __device__ inline float fwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO) { float lambdaI = fwdLambdaGGX(alphaSqr, cosThetaI); float lambdaO = fwdLambdaGGX(alphaSqr, cosThetaO); return 1.0f / (1.0f + lambdaI + lambdaO); } __device__ inline void bwdMaskingSmithGGXCorrelated(const float alphaSqr, const float cosThetaI, const float cosThetaO, float& d_alphaSqr, float& d_cosThetaI, float& d_cosThetaO, const float d_out) { // FWD eval float lambdaI = fwdLambdaGGX(alphaSqr, cosThetaI); float lambdaO = fwdLambdaGGX(alphaSqr, cosThetaO); // BWD eval float d_lambdaIO = -d_out / powf(1.0f + lambdaI + lambdaO, 2.0f); bwdLambdaGGX(alphaSqr, cosThetaI, d_alphaSqr, d_cosThetaI, d_lambdaIO); 
    bwdLambdaGGX(alphaSqr, cosThetaO, d_alphaSqr, d_cosThetaO, d_lambdaIO);
}

//------------------------------------------------------------------------
// GGX specular

// Forward Cook-Torrance GGX specular lobe: F * D * G / (4 * woDotN),
// gated to zero unless both wo and wi are front-facing w.r.t. nrm.
// alpha is clamped below by min_roughness^2.
__device__ vec3f fwdPbrSpecular(const vec3f col, const vec3f nrm, const vec3f wo, const vec3f wi, const float alpha, const float min_roughness)
{
    float _alpha = clamp(alpha, min_roughness * min_roughness, 1.0f);
    float alphaSqr = _alpha * _alpha;

    vec3f h = safeNormalize(wo + wi);
    float woDotN = dot(wo, nrm);
    float wiDotN = dot(wi, nrm);
    float woDotH = dot(wo, h);
    float nDotH = dot(nrm, h);

    float D = fwdNdfGGX(alphaSqr, nDotH);
    float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN);
    vec3f F = fwdFresnelSchlick(col, 1.0f, woDotH);
    vec3f w = F * D * G * 0.25 / woDotN;

    bool frontfacing = (woDotN > SPECULAR_EPSILON) & (wiDotN > SPECULAR_EPSILON);
    return frontfacing ? w : 0.0f;
}

// Backward pass of fwdPbrSpecular. Re-runs the forward computation, then
// accumulates gradients into d_col, d_nrm, d_wo, d_wi and d_alpha.
// No gradient flows when the lobe was masked out (back-facing).
__device__ void bwdPbrSpecular(
    const vec3f col, const vec3f nrm, const vec3f wo, const vec3f wi, const float alpha, const float min_roughness,
    vec3f& d_col, vec3f& d_nrm, vec3f& d_wo, vec3f& d_wi, float& d_alpha, const vec3f d_out)
{
    ///////////////////////////////////////////////////////////////////////
    // FWD eval

    float _alpha = clamp(alpha, min_roughness * min_roughness, 1.0f);
    float alphaSqr = _alpha * _alpha;

    vec3f h = safeNormalize(wo + wi);
    float woDotN = dot(wo, nrm);
    float wiDotN = dot(wi, nrm);
    float woDotH = dot(wo, h);
    float nDotH = dot(nrm, h);

    float D = fwdNdfGGX(alphaSqr, nDotH);
    float G = fwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN);
    vec3f F = fwdFresnelSchlick(col, 1.0f, woDotH);
    vec3f w = F * D * G * 0.25 / woDotN;
    bool frontfacing = (woDotN > SPECULAR_EPSILON) & (wiDotN > SPECULAR_EPSILON);

    if (frontfacing)
    {
        ///////////////////////////////////////////////////////////////////////
        // BWD eval

        // Product-rule gradients of w = F * D * G * 0.25 / woDotN.
        vec3f d_F = d_out * D * G * 0.25f / woDotN;
        float d_D = sum(d_out * F * G * 0.25f / woDotN);
        float d_G = sum(d_out * F * D * 0.25f / woDotN);

        float d_woDotN = -sum(d_out * F * D * G * 0.25f / (woDotN * woDotN));

        vec3f d_f90(0);
        float d_woDotH(0), d_wiDotN(0), d_nDotH(0), d_alphaSqr(0);
        bwdFresnelSchlick(col, 1.0f, woDotH, d_col, d_f90, d_woDotH, d_F);
        bwdMaskingSmithGGXCorrelated(alphaSqr, woDotN, wiDotN, d_alphaSqr, d_woDotN, d_wiDotN, d_G);
        bwdNdfGGX(alphaSqr, nDotH, d_alphaSqr, d_nDotH, d_D);

        vec3f d_h(0);
        bwdDot(nrm, h, d_nrm, d_h, d_nDotH);
        bwdDot(wo, h, d_wo, d_h, d_woDotH);
        bwdDot(wi, nrm, d_wi, d_nrm, d_wiDotN);
        bwdDot(wo, nrm, d_wo, d_nrm, d_woDotN);

        vec3f d_h_unnorm(0);
        bwdSafeNormalize(wo + wi, d_h_unnorm, d_h);
        d_wo += d_h_unnorm;
        d_wi += d_h_unnorm;

        // Mirror the forward clamp: no alpha gradient when the clamp was active.
        if (alpha > min_roughness * min_roughness)
            d_alpha += d_alphaSqr * 2 * alpha;
    }
}

//------------------------------------------------------------------------
// Full PBR BSDF

// Forward full PBR BSDF (kd = albedo, arm = [occlusion, roughness, metalness]).
// BSDF == 0 selects Lambertian diffuse, otherwise Frostbite diffuse.
__device__ vec3f fwdPbrBSDF(const vec3f kd, const vec3f arm, const vec3f pos, const vec3f nrm, const vec3f view_pos, const vec3f light_pos, const float min_roughness, int BSDF)
{
    vec3f wo = safeNormalize(view_pos - pos);
    vec3f wi = safeNormalize(light_pos - pos);

    float alpha = arm.y * arm.y;
    // 0.04 is the common dielectric base reflectance; metalness lerps toward kd.
    vec3f spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x);
    vec3f diff_col = kd * (1.0f - arm.z);

    float diff = 0.0f;
    if (BSDF == 0)
        diff = fwdLambert(nrm, wi);
    else
        diff = fwdFrostbiteDiffuse(nrm, wi, wo, arm.y);
    vec3f diffuse = diff_col * diff;
    vec3f specular = fwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness);

    return diffuse + specular;
}

// Backward pass of fwdPbrBSDF; accumulates into all d_* outputs.
__device__ void bwdPbrBSDF(
    const vec3f kd, const vec3f arm, const vec3f pos, const vec3f nrm, const vec3f view_pos, const vec3f light_pos, const float min_roughness, int BSDF,
    vec3f& d_kd, vec3f& d_arm, vec3f& d_pos, vec3f& d_nrm, vec3f& d_view_pos, vec3f& d_light_pos, const vec3f d_out)
{
    ////////////////////////////////////////////////////////////////////////
    // FWD
    vec3f _wi = light_pos - pos;
    vec3f _wo = view_pos - pos;
    vec3f wi = safeNormalize(_wi);
    vec3f wo = safeNormalize(_wo);

    float alpha = arm.y * arm.y;
    vec3f spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x);
    vec3f diff_col = kd * (1.0f - arm.z);

    float diff = 0.0f;
    if (BSDF == 0)
        diff = fwdLambert(nrm, wi);
    else
        diff = fwdFrostbiteDiffuse(nrm, wi, wo, arm.y);

    ////////////////////////////////////////////////////////////////////////
    // BWD

    float d_alpha(0);
    vec3f d_spec_col(0), d_wi(0), d_wo(0);
    bwdPbrSpecular(spec_col, nrm, wo, wi, alpha, min_roughness, d_spec_col, d_nrm, d_wo, d_wi, d_alpha, d_out);

    float d_diff = sum(diff_col * d_out);
    if (BSDF == 0)
        bwdLambert(nrm, wi, d_nrm, d_wi, d_diff);
    else
        bwdFrostbiteDiffuse(nrm, wi, wo, arm.y, d_nrm, d_wi, d_wo, d_arm.y, d_diff);

    // Backprop: diff_col = kd * (1.0f - arm.z)
    vec3f d_diff_col = d_out * diff;
    d_kd += d_diff_col * (1.0f - arm.z);
    d_arm.z -= sum(d_diff_col * kd);

    // Backprop: spec_col = (0.04f * (1.0f - arm.z) + kd * arm.z) * (1.0 - arm.x)
    d_kd -= d_spec_col * (arm.x - 1.0f) * arm.z;
    d_arm.x += sum(d_spec_col * (arm.z * (0.04f - kd) - 0.04f));
    d_arm.z -= sum(d_spec_col * (kd - 0.04f) * (arm.x - 1.0f));

    // Backprop: alpha = arm.y * arm.y
    d_arm.y += d_alpha * 2 * arm.y;

    // Backprop: vec3f wi = safeNormalize(light_pos - pos);
    vec3f d__wi(0);
    bwdSafeNormalize(_wi, d__wi, d_wi);
    d_light_pos += d__wi;
    d_pos -= d__wi;

    // Backprop: vec3f wo = safeNormalize(view_pos - pos);
    vec3f d__wo(0);
    bwdSafeNormalize(_wo, d__wo, d_wo);
    d_view_pos += d__wo;
    d_pos -= d__wo;
}

//------------------------------------------------------------------------
// Kernels

// One thread per output texel; z block index selects the batch/layer slice.
__global__ void LambertFwdKernel(LambertKernelParams p)
{
    // Calculate pixel position.
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    unsigned int pz = blockIdx.z;
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    vec3f nrm = p.nrm.fetch3(px, py, pz);
    vec3f wi = p.wi.fetch3(px, py, pz);

    float res = fwdLambert(nrm, wi);
    p.out.store(px, py, pz, res);
}

// Backward of LambertFwdKernel: reads upstream grad from p.out, writes
// input grads via store_grad.
__global__ void LambertBwdKernel(LambertKernelParams p)
{
    // Calculate pixel position.
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    unsigned int pz = blockIdx.z;
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    vec3f nrm = p.nrm.fetch3(px, py, pz);
    vec3f wi = p.wi.fetch3(px, py, pz);
    float d_out = p.out.fetch1(px, py, pz);

    vec3f d_nrm(0), d_wi(0);
    bwdLambert(nrm, wi, d_nrm, d_wi, d_out);

    p.nrm.store_grad(px, py, pz, d_nrm);
    p.wi.store_grad(px, py, pz, d_wi);
}

// Per-texel forward Frostbite diffuse evaluation.
__global__ void FrostbiteDiffuseFwdKernel(FrostbiteDiffuseKernelParams p)
{
    // Calculate pixel position.
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    unsigned int pz = blockIdx.z;
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    vec3f nrm = p.nrm.fetch3(px, py, pz);
    vec3f wi = p.wi.fetch3(px, py, pz);
    vec3f wo = p.wo.fetch3(px, py, pz);
    float linearRoughness = p.linearRoughness.fetch1(px, py, pz);

    float res = fwdFrostbiteDiffuse(nrm, wi, wo, linearRoughness);
    p.out.store(px, py, pz, res);
}

// Backward of FrostbiteDiffuseFwdKernel.
__global__ void FrostbiteDiffuseBwdKernel(FrostbiteDiffuseKernelParams p)
{
    // Calculate pixel position.
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    unsigned int pz = blockIdx.z;
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    vec3f nrm = p.nrm.fetch3(px, py, pz);
    vec3f wi = p.wi.fetch3(px, py, pz);
    vec3f wo = p.wo.fetch3(px, py, pz);
    float linearRoughness = p.linearRoughness.fetch1(px, py, pz);
    float d_out = p.out.fetch1(px, py, pz);

    float d_linearRoughness = 0.0f;
    vec3f d_nrm(0), d_wi(0), d_wo(0);
    bwdFrostbiteDiffuse(nrm, wi, wo, linearRoughness, d_nrm, d_wi, d_wo, d_linearRoughness, d_out);

    p.nrm.store_grad(px, py, pz, d_nrm);
    p.wi.store_grad(px, py, pz, d_wi);
    p.wo.store_grad(px, py, pz, d_wo);
    p.linearRoughness.store_grad(px, py, pz, d_linearRoughness);
}

// Per-texel forward Fresnel-Schlick evaluation.
__global__ void FresnelShlickFwdKernel(FresnelShlickKernelParams p)
{
    // Calculate pixel position.
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    unsigned int pz = blockIdx.z;
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    vec3f f0 = p.f0.fetch3(px, py, pz);
    vec3f f90 = p.f90.fetch3(px, py, pz);
    float cosTheta = p.cosTheta.fetch1(px, py, pz);

    vec3f res = fwdFresnelSchlick(f0, f90, cosTheta);
    p.out.store(px, py, pz, res);
}

// Backward of FresnelShlickFwdKernel.
__global__ void FresnelShlickBwdKernel(FresnelShlickKernelParams p)
{
    // Calculate pixel position.
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    unsigned int pz = blockIdx.z;
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    vec3f f0 = p.f0.fetch3(px, py, pz);
    vec3f f90 = p.f90.fetch3(px, py, pz);
    float cosTheta = p.cosTheta.fetch1(px, py, pz);
    vec3f d_out = p.out.fetch3(px, py, pz);

    vec3f d_f0(0), d_f90(0);
    float d_cosTheta(0);
    bwdFresnelSchlick(f0, f90, cosTheta, d_f0, d_f90, d_cosTheta, d_out);

    p.f0.store_grad(px, py, pz, d_f0);
    p.f90.store_grad(px, py, pz, d_f90);
    p.cosTheta.store_grad(px, py, pz, d_cosTheta);
}

// Per-texel forward GGX NDF evaluation.
__global__ void ndfGGXFwdKernel(NdfGGXParams p)
{
    // Calculate pixel position.
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    unsigned int pz = blockIdx.z;
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
    float cosTheta = p.cosTheta.fetch1(px, py, pz);
    float res = fwdNdfGGX(alphaSqr, cosTheta);

    p.out.store(px, py, pz, res);
}

// Backward of ndfGGXFwdKernel.
__global__ void ndfGGXBwdKernel(NdfGGXParams p)
{
    // Calculate pixel position.
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    unsigned int pz = blockIdx.z;
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
    float cosTheta = p.cosTheta.fetch1(px, py, pz);
    float d_out = p.out.fetch1(px, py, pz);

    float d_alphaSqr(0), d_cosTheta(0);
    bwdNdfGGX(alphaSqr, cosTheta, d_alphaSqr, d_cosTheta, d_out);

    p.alphaSqr.store_grad(px, py, pz, d_alphaSqr);
    p.cosTheta.store_grad(px, py, pz, d_cosTheta);
}

// Per-texel forward Smith Lambda evaluation (shares NdfGGXParams layout).
__global__ void lambdaGGXFwdKernel(NdfGGXParams p)
{
    // Calculate pixel position.
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    unsigned int pz = blockIdx.z;
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
    float cosTheta = p.cosTheta.fetch1(px, py, pz);
    float res = fwdLambdaGGX(alphaSqr, cosTheta);

    p.out.store(px, py, pz, res);
}

// Backward of lambdaGGXFwdKernel.
__global__ void lambdaGGXBwdKernel(NdfGGXParams p)
{
    // Calculate pixel position.
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    unsigned int pz = blockIdx.z;
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
    float cosTheta = p.cosTheta.fetch1(px, py, pz);
    float d_out = p.out.fetch1(px, py, pz);

    float d_alphaSqr(0), d_cosTheta(0);
    bwdLambdaGGX(alphaSqr, cosTheta, d_alphaSqr, d_cosTheta, d_out);

    p.alphaSqr.store_grad(px, py, pz, d_alphaSqr);
    p.cosTheta.store_grad(px, py, pz, d_cosTheta);
}

// Per-texel forward height-correlated Smith masking evaluation.
__global__ void maskingSmithFwdKernel(MaskingSmithParams p)
{
    // Calculate pixel position.
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    unsigned int pz = blockIdx.z;
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
    float cosThetaI = p.cosThetaI.fetch1(px, py, pz);
    float cosThetaO = p.cosThetaO.fetch1(px, py, pz);
    float res = fwdMaskingSmithGGXCorrelated(alphaSqr, cosThetaI, cosThetaO);

    p.out.store(px, py, pz, res);
}

// Backward of maskingSmithFwdKernel.
__global__ void maskingSmithBwdKernel(MaskingSmithParams p)
{
    // Calculate pixel position.
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    unsigned int pz = blockIdx.z;
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    float alphaSqr = p.alphaSqr.fetch1(px, py, pz);
    float cosThetaI = p.cosThetaI.fetch1(px, py, pz);
    float cosThetaO = p.cosThetaO.fetch1(px, py, pz);
    float d_out = p.out.fetch1(px, py, pz);

    float d_alphaSqr(0), d_cosThetaI(0), d_cosThetaO(0);
    bwdMaskingSmithGGXCorrelated(alphaSqr, cosThetaI, cosThetaO, d_alphaSqr, d_cosThetaI, d_cosThetaO, d_out);

    p.alphaSqr.store_grad(px, py, pz, d_alphaSqr);
    p.cosThetaI.store_grad(px, py, pz, d_cosThetaI);
    p.cosThetaO.store_grad(px, py, pz, d_cosThetaO);
}

// Per-texel forward GGX specular lobe evaluation.
__global__ void pbrSpecularFwdKernel(PbrSpecular p)
{
    // Calculate pixel position.
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    unsigned int pz = blockIdx.z;
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    vec3f col = p.col.fetch3(px, py, pz);
    vec3f nrm = p.nrm.fetch3(px, py, pz);
    vec3f wo = p.wo.fetch3(px, py, pz);
    vec3f wi = p.wi.fetch3(px, py, pz);
    float alpha = p.alpha.fetch1(px, py, pz);

    vec3f res = fwdPbrSpecular(col, nrm, wo, wi, alpha, p.min_roughness);
    p.out.store(px, py, pz, res);
}

// Backward of pbrSpecularFwdKernel.
__global__ void pbrSpecularBwdKernel(PbrSpecular p)
{
    // Calculate pixel position.
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    unsigned int pz = blockIdx.z;
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    vec3f col = p.col.fetch3(px, py, pz);
    vec3f nrm = p.nrm.fetch3(px, py, pz);
    vec3f wo = p.wo.fetch3(px, py, pz);
    vec3f wi = p.wi.fetch3(px, py, pz);
    float alpha = p.alpha.fetch1(px, py, pz);
    vec3f d_out = p.out.fetch3(px, py, pz);

    float d_alpha(0);
    vec3f d_col(0), d_nrm(0), d_wo(0), d_wi(0);
    bwdPbrSpecular(col, nrm, wo, wi, alpha, p.min_roughness, d_col, d_nrm, d_wo, d_wi, d_alpha, d_out);

    p.col.store_grad(px, py, pz, d_col);
    p.nrm.store_grad(px, py, pz, d_nrm);
    p.wo.store_grad(px, py, pz, d_wo);
    p.wi.store_grad(px, py, pz, d_wi);
    p.alpha.store_grad(px, py, pz, d_alpha);
}

// Per-texel forward full PBR BSDF evaluation.
__global__ void pbrBSDFFwdKernel(PbrBSDF p)
{
    // Calculate pixel position.
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    unsigned int pz = blockIdx.z;
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    vec3f kd = p.kd.fetch3(px, py, pz);
    vec3f arm = p.arm.fetch3(px, py, pz);
    vec3f pos = p.pos.fetch3(px, py, pz);
    vec3f nrm = p.nrm.fetch3(px, py, pz);
    vec3f view_pos = p.view_pos.fetch3(px, py, pz);
    vec3f light_pos = p.light_pos.fetch3(px, py, pz);

    vec3f res = fwdPbrBSDF(kd, arm, pos, nrm, view_pos, light_pos, p.min_roughness, p.BSDF);

    p.out.store(px, py, pz, res);
}

// Backward of pbrBSDFFwdKernel.
__global__ void pbrBSDFBwdKernel(PbrBSDF p)
{
    // Calculate pixel position.
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    unsigned int pz = blockIdx.z;
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    vec3f kd = p.kd.fetch3(px, py, pz);
    vec3f arm = p.arm.fetch3(px, py, pz);
    vec3f pos = p.pos.fetch3(px, py, pz);
    vec3f nrm = p.nrm.fetch3(px, py, pz);
    vec3f view_pos = p.view_pos.fetch3(px, py, pz);
    vec3f light_pos = p.light_pos.fetch3(px, py, pz);
    vec3f d_out = p.out.fetch3(px, py, pz);

    vec3f d_kd(0), d_arm(0), d_pos(0), d_nrm(0), d_view_pos(0), d_light_pos(0);
    bwdPbrBSDF(kd, arm, pos, nrm, view_pos, light_pos, p.min_roughness, p.BSDF, d_kd, d_arm, d_pos, d_nrm, d_view_pos, d_light_pos, d_out);

    p.kd.store_grad(px, py, pz, d_kd);
    p.arm.store_grad(px, py, pz, d_arm);
    p.pos.store_grad(px, py, pz, d_pos);
    p.nrm.store_grad(px, py, pz, d_nrm);
    p.view_pos.store_grad(px, py, pz, d_view_pos);
    p.light_pos.store_grad(px, py, pz, d_light_pos);
}



================================================
FILE: render/renderutils/c_src/bsdf.h
================================================
/*
 * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 * property and proprietary rights in and to this material, related
 * documentation and any modifications thereto. Any use, reproduction,
 * disclosure or distribution of this material and related documentation
 * without an express license agreement from NVIDIA CORPORATION or
 * its affiliates is strictly prohibited.
*/ #pragma once #include "common.h" struct LambertKernelParams { Tensor nrm; Tensor wi; Tensor out; dim3 gridSize; }; struct FrostbiteDiffuseKernelParams { Tensor nrm; Tensor wi; Tensor wo; Tensor linearRoughness; Tensor out; dim3 gridSize; }; struct FresnelShlickKernelParams { Tensor f0; Tensor f90; Tensor cosTheta; Tensor out; dim3 gridSize; }; struct NdfGGXParams { Tensor alphaSqr; Tensor cosTheta; Tensor out; dim3 gridSize; }; struct MaskingSmithParams { Tensor alphaSqr; Tensor cosThetaI; Tensor cosThetaO; Tensor out; dim3 gridSize; }; struct PbrSpecular { Tensor col; Tensor nrm; Tensor wo; Tensor wi; Tensor alpha; Tensor out; dim3 gridSize; float min_roughness; }; struct PbrBSDF { Tensor kd; Tensor arm; Tensor pos; Tensor nrm; Tensor view_pos; Tensor light_pos; Tensor out; dim3 gridSize; float min_roughness; int BSDF; }; ================================================ FILE: render/renderutils/c_src/common.cpp ================================================ /* * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual * property and proprietary rights in and to this material, related * documentation and any modifications thereto. Any use, reproduction, * disclosure or distribution of this material and related documentation * without an express license agreement from NVIDIA CORPORATION or * its affiliates is strictly prohibited. */ #include #include //------------------------------------------------------------------------ // Block and grid size calculators for kernel launches. dim3 getLaunchBlockSize(int maxWidth, int maxHeight, dim3 dims) { int maxThreads = maxWidth * maxHeight; if (maxThreads <= 1 || (dims.x * dims.y) <= 1) return dim3(1, 1, 1); // Degenerate. // Start from max size. int bw = maxWidth; int bh = maxHeight; // Optimizations for weirdly sized buffers. if (dims.x < bw) { // Decrease block width to smallest power of two that covers the buffer width. 
while ((bw >> 1) >= dims.x) bw >>= 1; // Maximize height. bh = maxThreads / bw; if (bh > dims.y) bh = dims.y; } else if (dims.y < bh) { // Halve height and double width until fits completely inside buffer vertically. while (bh > dims.y) { bh >>= 1; if (bw < dims.x) bw <<= 1; } } // Done. return dim3(bw, bh, 1); } // returns the size of a block that can be reduced using horizontal SIMD operations (e.g. __shfl_xor_sync) dim3 getWarpSize(dim3 blockSize) { return dim3( std::min(blockSize.x, 32u), std::min(std::max(32u / blockSize.x, 1u), std::min(32u, blockSize.y)), std::min(std::max(32u / (blockSize.x * blockSize.y), 1u), std::min(32u, blockSize.z)) ); } dim3 getLaunchGridSize(dim3 blockSize, dim3 dims) { dim3 gridSize; gridSize.x = (dims.x - 1) / blockSize.x + 1; gridSize.y = (dims.y - 1) / blockSize.y + 1; gridSize.z = (dims.z - 1) / blockSize.z + 1; return gridSize; } //------------------------------------------------------------------------ ================================================ FILE: render/renderutils/c_src/common.h ================================================ /* * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual * property and proprietary rights in and to this material, related * documentation and any modifications thereto. Any use, reproduction, * disclosure or distribution of this material and related documentation * without an express license agreement from NVIDIA CORPORATION or * its affiliates is strictly prohibited. 
*/ #pragma once #include #include #include "vec3f.h" #include "vec4f.h" #include "tensor.h" dim3 getLaunchBlockSize(int maxWidth, int maxHeight, dim3 dims); dim3 getLaunchGridSize(dim3 blockSize, dim3 dims); #ifdef __CUDACC__ #ifdef _MSC_VER #define M_PI 3.14159265358979323846f #endif __host__ __device__ static inline dim3 getWarpSize(dim3 blockSize) { return dim3( min(blockSize.x, 32u), min(max(32u / blockSize.x, 1u), min(32u, blockSize.y)), min(max(32u / (blockSize.x * blockSize.y), 1u), min(32u, blockSize.z)) ); } __device__ static inline float clamp(float val, float mn, float mx) { return min(max(val, mn), mx); } #else dim3 getWarpSize(dim3 blockSize); #endif ================================================ FILE: render/renderutils/c_src/cubemap.cu ================================================ /* * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual * property and proprietary rights in and to this material, related * documentation and any modifications thereto. Any use, reproduction, * disclosure or distribution of this material and related documentation * without an express license agreement from NVIDIA CORPORATION or * its affiliates is strictly prohibited. 
*/

#include "common.h"
#include "cubemap.h"
// NOTE(review): the extraction left a bare `#include` here; restored to
// <float.h>, which this file needs for FLT_MAX (used in extents_1d).
#include <float.h>

// https://cgvr.cs.uni-bremen.de/teaching/cg_literatur/Spherical,%20Cubic,%20and%20Parabolic%20Environment%20Mappings.pdf
// Solid-angle weight of cubemap texel (x, y) on an N x N face.
__device__ float pixel_area(int x, int y, int N)
{
    if (N > 1)
    {
        int H = N / 2;
        // Quadrant symmetry: fold the texel into the positive quadrant.
        x = abs(x - H);
        y = abs(y - H);
        float dx = atan((float)(x + 1) / (float)H) - atan((float)x / (float)H);
        float dy = atan((float)(y + 1) / (float)H) - atan((float)y / (float)H);
        return dx * dy;
    }
    else
        return 1;
}

// Direction through the center of texel (x, y) on cube face `side` (N x N).
__device__ vec3f cube_to_dir(int x, int y, int side, int N)
{
    float fx = 2.0f * (((float)x + 0.5f) / (float)N) - 1.0f;
    float fy = 2.0f * (((float)y + 0.5f) / (float)N) - 1.0f;
    switch (side)
    {
        case 0: return safeNormalize(vec3f(1, -fy, -fx));
        case 1: return safeNormalize(vec3f(-1, -fy, fx));
        case 2: return safeNormalize(vec3f(fx, 1, fy));
        case 3: return safeNormalize(vec3f(fx, -1, -fy));
        case 4: return safeNormalize(vec3f(fx, -fy, 1));
        case 5: return safeNormalize(vec3f(-fx, -fy, -1));
    }
    return vec3f(0,0,0); // Unreachable
}

// Rotate direction v into the local frame of cube face `side` (face at z = 1).
__device__ vec3f dir_to_side(int side, vec3f v)
{
    switch (side)
    {
        case 0: return vec3f(-v.z, -v.y, v.x);
        case 1: return vec3f( v.z, -v.y, -v.x);
        case 2: return vec3f( v.x,  v.z,  v.y);
        case 3: return vec3f( v.x, -v.z, -v.y);
        case 4: return vec3f( v.x, -v.y,  v.z);
        case 5: return vec3f(-v.x, -v.y, -v.z);
    }
    return vec3f(0,0,0); // Unreachable
}

// 1D projected extent of a cone of half-angle theta around direction (x, z)
// onto the plane z = 1; +/-FLT_MAX when a cone edge crosses the horizon.
__device__ void extents_1d(float x, float z, float theta, float& _min, float& _max)
{
    float l = sqrtf(x * x + z * z);
    float pxr = x + z * tan(theta) * l, pzr = z - x * tan(theta) * l;
    float pxl = x - z * tan(theta) * l, pzl = z + x * tan(theta) * l;
    if (pzl <= 0.00001f)
        _min = pxl > 0.0f ? FLT_MAX : -FLT_MAX;
    else
        _min = pxl / pzl;
    if (pzr <= 0.00001f)
        _max = pxr > 0.0f ? FLT_MAX : -FLT_MAX;
    else
        _max = pxr / pzr;
}

// Texel-space AABB on face `side` covered by a cone of half-angle theta
// around v. Writes (-1, -1, -1, -1) for an empty box; for wide cones
// (theta >= pi/4) conservatively returns the whole face.
__device__ void dir_extents(int side, int N, vec3f v, float theta, int &_xmin, int& _xmax, int& _ymin, int& _ymax)
{
    vec3f c = dir_to_side(side, v); // remap to (x,y,z) where side is at z = 1

    if (theta < 0.785398f) // PI/4
    {
        float xmin, xmax, ymin, ymax;
        extents_1d(c.x, c.z, theta, xmin, xmax);
        extents_1d(c.y, c.z, theta, ymin, ymax);

        if (xmin > 1.0f || xmax < -1.0f || ymin > 1.0f || ymax < -1.0f)
        {
            _xmin = -1; _xmax = -1; _ymin = -1; _ymax = -1; // Bad aabb
        }
        else
        {
            _xmin = (int)min(max((xmin + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1));
            _xmax = (int)min(max((xmax + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1));
            _ymin = (int)min(max((ymin + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1));
            _ymax = (int)min(max((ymax + 1.0f) * (0.5f * (float)N), 0.0f), (float)(N - 1));
        }
    }
    else
    {
        // Fix: these out-params are int&; the original assigned float
        // literals (0.0f, (float)(N-1)) that were implicitly narrowed.
        // Same values, proper types.
        _xmin = 0;
        _xmax = N - 1;
        _ymin = 0;
        _ymax = N - 1;
    }
}

///////////////////////////////////////////////////////////////////////////////////////////////////////////
// Diffuse kernel

// Forward diffuse cubemap convolution: each output texel integrates the whole
// cubemap with cosine * solid-angle weights (brute force, O(faces * N^2)).
__global__ void DiffuseCubemapFwdKernel(DiffuseCubemapKernelParams p)
{
    // Calculate pixel position.
    int px = blockIdx.x * blockDim.x + threadIdx.x;
    int py = blockIdx.y * blockDim.y + threadIdx.y;
    int pz = blockIdx.z;
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    int Npx = p.cubemap.dims[1];
    vec3f N = cube_to_dir(px, py, pz, Npx);

    vec3f col(0);

    for (int s = 0; s < p.cubemap.dims[0]; ++s)
    {
        for (int y = 0; y < Npx; ++y)
        {
            for (int x = 0; x < Npx; ++x)
            {
                vec3f L = cube_to_dir(x, y, s, Npx);
                float costheta = min(max(dot(N, L), 0.0f), 0.999f);
                float w = costheta * pixel_area(x, y, Npx) / 3.141592f; // pi = area of positive hemisphere
                col += p.cubemap.fetch3(x, y, s) * w;
            }
        }
    }

    p.out.store(px, py, pz, col);
}

// Backward of DiffuseCubemapFwdKernel: scatters the output gradient back to
// every contributing cubemap texel (atomicAdd — many outputs touch each texel).
__global__ void DiffuseCubemapBwdKernel(DiffuseCubemapKernelParams p)
{
    // Calculate pixel position.
    int px = blockIdx.x * blockDim.x + threadIdx.x;
    int py = blockIdx.y * blockDim.y + threadIdx.y;
    int pz = blockIdx.z;
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    int Npx = p.cubemap.dims[1];
    vec3f N = cube_to_dir(px, py, pz, Npx);
    vec3f grad = p.out.fetch3(px, py, pz);

    for (int s = 0; s < p.cubemap.dims[0]; ++s)
    {
        for (int y = 0; y < Npx; ++y)
        {
            for (int x = 0; x < Npx; ++x)
            {
                vec3f L = cube_to_dir(x, y, s, Npx);
                float costheta = min(max(dot(N, L), 0.0f), 0.999f);
                float w = costheta * pixel_area(x, y, Npx) / 3.141592f; // pi = area of positive hemisphere
                // Scatter: each source texel receives grad * w from this output texel.
                atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 0), grad.x * w);
                atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 1), grad.y * w);
                atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 2), grad.z * w);
            }
        }
    }
}

///////////////////////////////////////////////////////////////////////////////////////////////////////////
// GGX splitsum kernel

// Local (non-differentiable) GGX NDF used only for prefiltering weights.
__device__ inline float ndfGGX(const float alphaSqr, const float cosTheta)
{
    float _cosTheta = clamp(cosTheta, 0.0, 1.0f);
    float d = (_cosTheta * alphaSqr - _cosTheta) * _cosTheta + 1.0f;
    return alphaSqr / (d * d * M_PI);
}

// For each output texel direction, computes per-face texel AABBs of the cone
// { L : dot(L, VNR) >= costheta_cutoff } by brute force, with 16x16-tile
// interval-arithmetic culling. Bounds are stored as 4 channels per face.
__global__ void SpecularBoundsKernel(SpecularBoundsKernelParams p)
{
    int px = blockIdx.x * blockDim.x + threadIdx.x;
    int py = blockIdx.y * blockDim.y + threadIdx.y;
    int pz = blockIdx.z;
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    int Npx = p.gridSize.x;
    vec3f VNR = cube_to_dir(px, py, pz, Npx);

    const int TILE_SIZE = 16;

    // Brute force entire cubemap and compute bounds for the cone
    for (int s = 0; s < p.gridSize.z; ++s)
    {
        // Assume empty BBox
        int _min_x = p.gridSize.x - 1, _max_x = 0;
        int _min_y = p.gridSize.y - 1, _max_y = 0;

        // For each (8x8) tile
        for (int tx = 0; tx < (p.gridSize.x + TILE_SIZE - 1) / TILE_SIZE; tx++)
        {
            for (int ty = 0; ty < (p.gridSize.y + TILE_SIZE - 1) / TILE_SIZE; ty++)
            {
                // Compute tile extents
                int tsx = tx * TILE_SIZE, tsy = ty * TILE_SIZE;
                int tex = min((tx + 1) * TILE_SIZE, p.gridSize.x), tey = min((ty + 1) * TILE_SIZE, p.gridSize.y);

                // Use some blunt interval arithmetics to cull tiles
                vec3f L0 = cube_to_dir(tsx, tsy, s, Npx), L1 = cube_to_dir(tex, tsy, s, Npx);
                vec3f L2 = cube_to_dir(tsx, tey, s, Npx), L3 = cube_to_dir(tex, tey, s, Npx);

                float minx = min(min(L0.x, L1.x), min(L2.x, L3.x)), maxx = max(max(L0.x, L1.x), max(L2.x, L3.x));
                float miny = min(min(L0.y, L1.y), min(L2.y, L3.y)), maxy = max(max(L0.y, L1.y), max(L2.y, L3.y));
                float minz = min(min(L0.z, L1.z), min(L2.z, L3.z)), maxz = max(max(L0.z, L1.z), max(L2.z, L3.z));

                // Upper bound on dot(L, VNR) over the tile's direction interval.
                float maxdp = max(minx * VNR.x, maxx * VNR.x) + max(miny * VNR.y, maxy * VNR.y) + max(minz * VNR.z, maxz * VNR.z);
                if (maxdp >= p.costheta_cutoff)
                {
                    // Test all pixels in tile.
                    for (int y = tsy; y < tey; ++y)
                    {
                        for (int x = tsx; x < tex; ++x)
                        {
                            vec3f L = cube_to_dir(x, y, s, Npx);
                            if (dot(L, VNR) >= p.costheta_cutoff)
                            {
                                _min_x = min(_min_x, x);
                                _max_x = max(_max_x, x);
                                _min_y = min(_min_y, y);
                                _max_y = max(_max_y, y);
                            }
                        }
                    }
                }
            }
        }
        p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 0), _min_x);
        p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 1), _max_x);
        p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 2), _min_y);
        p.out.store(p.out._nhwcIndex(pz, py, px, s * 4 + 3), _max_y);
    }
}

// Forward GGX prefiltering: accumulates NDF-weighted radiance over the texels
// inside the precomputed per-face bounds; stores RGB sum plus weight sum.
__global__ void SpecularCubemapFwdKernel(SpecularCubemapKernelParams p)
{
    // Calculate pixel position.
    int px = blockIdx.x * blockDim.x + threadIdx.x;
    int py = blockIdx.y * blockDim.y + threadIdx.y;
    int pz = blockIdx.z;
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    int Npx = p.cubemap.dims[1];
    vec3f VNR = cube_to_dir(px, py, pz, Npx);

    float alpha = p.roughness * p.roughness;
    float alphaSqr = alpha * alpha;

    float wsum = 0.0f;
    vec3f col(0);
    for (int s = 0; s < p.cubemap.dims[0]; ++s)
    {
        // Per-face texel bounds computed by SpecularBoundsKernel.
        int xmin, xmax, ymin, ymax;
        xmin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 0));
        xmax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 1));
        ymin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 2));
        ymax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 3));

        if (xmin <= xmax)
        {
            for (int y = ymin; y <= ymax; ++y)
            {
                for (int x = xmin; x <= xmax; ++x)
                {
                    vec3f L = cube_to_dir(x, y, s, Npx);
                    if (dot(L, VNR) >= p.costheta_cutoff)
                    {
                        vec3f H = safeNormalize(L + VNR);

                        float wiDotN = max(dot(L, VNR), 0.0f);
                        float VNRDotH = max(dot(VNR, H), 0.0f);

                        float w = wiDotN * ndfGGX(alphaSqr, VNRDotH) * pixel_area(x, y, Npx) / 4.0f;
                        col += p.cubemap.fetch3(x, y, s) * w;
                        wsum += w;
                    }
                }
            }
        }
    }

    // Channel 3 holds the weight sum so the caller can normalize.
    p.out.store(p.out._nhwcIndex(pz, py, px, 0), col.x);
    p.out.store(p.out._nhwcIndex(pz, py, px, 1), col.y);
    p.out.store(p.out._nhwcIndex(pz, py, px, 2), col.z);
    p.out.store(p.out._nhwcIndex(pz, py, px, 3), wsum);
}

// Backward of SpecularCubemapFwdKernel: recomputes the same weights and
// scatters the RGB gradient to the contributing cubemap texels (atomicAdd).
__global__ void SpecularCubemapBwdKernel(SpecularCubemapKernelParams p)
{
    // Calculate pixel position.
    int px = blockIdx.x * blockDim.x + threadIdx.x;
    int py = blockIdx.y * blockDim.y + threadIdx.y;
    int pz = blockIdx.z;
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    int Npx = p.cubemap.dims[1];
    vec3f VNR = cube_to_dir(px, py, pz, Npx);

    vec3f grad = p.out.fetch3(px, py, pz);

    float alpha = p.roughness * p.roughness;
    float alphaSqr = alpha * alpha;

    vec3f col(0);
    for (int s = 0; s < p.cubemap.dims[0]; ++s)
    {
        int xmin, xmax, ymin, ymax;
        xmin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 0));
        xmax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 1));
        ymin = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 2));
        ymax = (int)p.bounds.fetch(p.bounds._nhwcIndex(pz, py, px, s * 4 + 3));

        if (xmin <= xmax)
        {
            for (int y = ymin; y <= ymax; ++y)
            {
                for (int x = xmin; x <= xmax; ++x)
                {
                    vec3f L = cube_to_dir(x, y, s, Npx);

                    if (dot(L, VNR) >= p.costheta_cutoff)
                    {
                        vec3f H = safeNormalize(L + VNR);

                        float wiDotN = max(dot(L, VNR), 0.0f);
                        float VNRDotH = max(dot(VNR, H), 0.0f);

                        float w = wiDotN * ndfGGX(alphaSqr, VNRDotH) * pixel_area(x, y, Npx) / 4.0f;

                        atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 0), grad.x * w);
                        atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 1), grad.y * w);
                        atomicAdd((float*)p.cubemap.d_val + p.cubemap.nhwcIndexContinuous(s, y, x, 2), grad.z * w);
                    }
                }
            }
        }
    }
}



================================================
FILE: render/renderutils/c_src/cubemap.h
================================================
/*
 * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 * property and proprietary rights in and to this material, related
 * documentation and any modifications thereto.
Any use, reproduction,
 * disclosure or distribution of this material and related documentation
 * without an express license agreement from NVIDIA CORPORATION or
 * its affiliates is strictly prohibited.
 */

#pragma once

#include "common.h"

// Parameter blobs for the cubemap kernels.

struct DiffuseCubemapKernelParams
{
    Tensor  cubemap;
    Tensor  out;
    dim3    gridSize;
};

struct SpecularCubemapKernelParams
{
    Tensor  cubemap;
    Tensor  bounds;
    Tensor  out;
    dim3    gridSize;
    float   costheta_cutoff;
    float   roughness;
};

struct SpecularBoundsKernelParams
{
    float   costheta_cutoff;
    Tensor  out;
    dim3    gridSize;
};


================================================
FILE: render/renderutils/c_src/loss.cu
================================================
/*
 * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 * property and proprietary rights in and to this material, related
 * documentation and any modifications thereto. Any use, reproduction,
 * disclosure or distribution of this material and related documentation
 * without an express license agreement from NVIDIA CORPORATION or
 * its affiliates is strictly prohibited.
 */

// NOTE(review): the extraction left a bare `#include` here; restored to
// <cuda.h> (warp intrinsics / dim3) — TODO confirm against upstream.
#include <cuda.h>

#include "common.h"
#include "loss.h"

//------------------------------------------------------------------------
// Utils

// Subgradient of |x| (0 at x == 0, matching Torch's convention for abs).
__device__ inline float bwdAbs(float x) { return x == 0.0f ? 0.0f : x < 0.0f ? -1.0f : 1.0f; }

// Butterfly reduction: sums val across the full 32-lane warp.
__device__ float warpSum(float val) {
    for (int i = 1; i < 32; i *= 2)
        val += __shfl_xor_sync(0xFFFFFFFF, val, i);
    return val;
}

//------------------------------------------------------------------------
// Tonemapping

// Forward linear -> sRGB transfer function.
__device__ inline float fwdSRGB(float x)
{
    return x > 0.0031308f ? powf(max(x, 0.0031308f), 1.0f / 2.4f) * 1.055f - 0.055f : 12.92f * max(x, 0.0f);
}

// Backward of fwdSRGB; accumulates into d_x. No gradient for x <= 0.
__device__ inline void bwdSRGB(float x, float &d_x, float d_out)
{
    if (x > 0.0031308f)
        d_x += d_out * 0.439583f / powf(x, 0.583333f);
    else if (x > 0.0f)
        d_x += d_out * 12.92f;
}

// Forward log(1 + x) followed by sRGB encoding, per channel.
__device__ inline vec3f fwdTonemapLogSRGB(vec3f x)
{
    return vec3f(fwdSRGB(logf(x.x + 1.0f)), fwdSRGB(logf(x.y + 1.0f)), fwdSRGB(logf(x.z + 1.0f)));
}

// Backward of fwdTonemapLogSRGB; gradient is zeroed outside (0, 65535).
__device__ inline void bwdTonemapLogSRGB(vec3f x, vec3f& d_x, vec3f d_out)
{
    if (x.x > 0.0f && x.x < 65535.0f)
    {
        bwdSRGB(logf(x.x + 1.0f), d_x.x, d_out.x);
        d_x.x *= 1 / (x.x + 1.0f);
    }
    if (x.y > 0.0f && x.y < 65535.0f)
    {
        bwdSRGB(logf(x.y + 1.0f), d_x.y, d_out.y);
        d_x.y *= 1 / (x.y + 1.0f);
    }
    if (x.z > 0.0f && x.z < 65535.0f)
    {
        bwdSRGB(logf(x.z + 1.0f), d_x.z, d_out.z);
        d_x.z *= 1 / (x.z + 1.0f);
    }
}

// Relative MSE: (img - target)^2 / (img^2 + target^2 + eps).
__device__ inline float fwdRELMSE(float img, float target, float eps = 0.1f)
{
    return (img - target) * (img - target) / (img * img + target * target + eps);
}

// Backward of fwdRELMSE; accumulates into d_img / d_target.
__device__ inline void bwdRELMSE(float img, float target, float &d_img, float &d_target, float d_out, float eps = 0.1f)
{
    float denom = (target * target + img * img + eps);
    d_img += d_out * 2 * (img - target) * (target * (target + img) + eps) / (denom * denom);
    d_target -= d_out * 2 * (img - target) * (img * (target + img) + eps) / (denom * denom);
}

// Symmetric MAPE: |img - target| / (img + target + eps).
__device__ inline float fwdSMAPE(float img, float target, float eps = 0.01f)
{
    return abs(img - target) / (img + target + eps);
}

// Backward of fwdSMAPE; accumulates into d_img / d_target.
__device__ inline void bwdSMAPE(float img, float target, float& d_img, float& d_target, float d_out, float eps = 0.01f)
{
    float denom = (target + img + eps);
    d_img += d_out * bwdAbs(img - target) * (2 * target + eps) / (denom * denom);
    d_target -= d_out * bwdAbs(img - target) * (2 * img + eps) / (denom * denom);
}

//------------------------------------------------------------------------
// Kernels

// Forward image loss: per-pixel loss (MSE / RELMSE / SMAPE / L1), optionally
// after log-sRGB tonemapping, averaged over channels and warp-reduced so one
// value per warp tile is written to the (downscaled) output.
__global__ void imgLossFwdKernel(LossKernelParams p)
{
    // Calculate pixel position.
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; unsigned int pz = blockIdx.z; float floss = 0.0f; if (px < p.gridSize.x && py < p.gridSize.y && pz < p.gridSize.z) { vec3f img = p.img.fetch3(px, py, pz); vec3f target = p.target.fetch3(px, py, pz); img = vec3f(clamp(img.x, 0.0f, 65535.0f), clamp(img.y, 0.0f, 65535.0f), clamp(img.z, 0.0f, 65535.0f)); target = vec3f(clamp(target.x, 0.0f, 65535.0f), clamp(target.y, 0.0f, 65535.0f), clamp(target.z, 0.0f, 65535.0f)); if (p.tonemapper == TONEMAPPER_LOG_SRGB) { img = fwdTonemapLogSRGB(img); target = fwdTonemapLogSRGB(target); } vec3f vloss(0); if (p.loss == LOSS_MSE) vloss = (img - target) * (img - target); else if (p.loss == LOSS_RELMSE) vloss = vec3f(fwdRELMSE(img.x, target.x), fwdRELMSE(img.y, target.y), fwdRELMSE(img.z, target.z)); else if (p.loss == LOSS_SMAPE) vloss = vec3f(fwdSMAPE(img.x, target.x), fwdSMAPE(img.y, target.y), fwdSMAPE(img.z, target.z)); else vloss = vec3f(abs(img.x - target.x), abs(img.y - target.y), abs(img.z - target.z)); floss = sum(vloss) / 3.0f; } floss = warpSum(floss); dim3 warpSize = getWarpSize(blockDim); if (px < p.gridSize.x && py < p.gridSize.y && pz < p.gridSize.z && threadIdx.x % warpSize.x == 0 && threadIdx.y % warpSize.y == 0 && threadIdx.z % warpSize.z == 0) p.out.store(px / warpSize.x, py / warpSize.y, pz / warpSize.z, floss); } __global__ void imgLossBwdKernel(LossKernelParams p) { // Calculate pixel position. 
unsigned int px = blockIdx.x * blockDim.x + threadIdx.x; unsigned int py = blockIdx.y * blockDim.y + threadIdx.y; unsigned int pz = blockIdx.z; if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z) return; dim3 warpSize = getWarpSize(blockDim); vec3f _img = p.img.fetch3(px, py, pz); vec3f _target = p.target.fetch3(px, py, pz); float d_out = p.out.fetch1(px / warpSize.x, py / warpSize.y, pz / warpSize.z); ///////////////////////////////////////////////////////////////////// // FWD vec3f img = _img, target = _target; if (p.tonemapper == TONEMAPPER_LOG_SRGB) { img = fwdTonemapLogSRGB(img); target = fwdTonemapLogSRGB(target); } ///////////////////////////////////////////////////////////////////// // BWD vec3f d_vloss = vec3f(d_out, d_out, d_out) / 3.0f; vec3f d_img(0), d_target(0); if (p.loss == LOSS_MSE) { d_img = vec3f(d_vloss.x * 2 * (img.x - target.x), d_vloss.y * 2 * (img.y - target.y), d_vloss.x * 2 * (img.z - target.z)); d_target = -d_img; } else if (p.loss == LOSS_RELMSE) { bwdRELMSE(img.x, target.x, d_img.x, d_target.x, d_vloss.x); bwdRELMSE(img.y, target.y, d_img.y, d_target.y, d_vloss.y); bwdRELMSE(img.z, target.z, d_img.z, d_target.z, d_vloss.z); } else if (p.loss == LOSS_SMAPE) { bwdSMAPE(img.x, target.x, d_img.x, d_target.x, d_vloss.x); bwdSMAPE(img.y, target.y, d_img.y, d_target.y, d_vloss.y); bwdSMAPE(img.z, target.z, d_img.z, d_target.z, d_vloss.z); } else { d_img = d_vloss * vec3f(bwdAbs(img.x - target.x), bwdAbs(img.y - target.y), bwdAbs(img.z - target.z)); d_target = -d_img; } if (p.tonemapper == TONEMAPPER_LOG_SRGB) { vec3f d__img(0), d__target(0); bwdTonemapLogSRGB(_img, d__img, d_img); bwdTonemapLogSRGB(_target, d__target, d_target); d_img = d__img; d_target = d__target; } if (_img.x <= 0.0f || _img.x >= 65535.0f) d_img.x = 0; if (_img.y <= 0.0f || _img.y >= 65535.0f) d_img.y = 0; if (_img.z <= 0.0f || _img.z >= 65535.0f) d_img.z = 0; if (_target.x <= 0.0f || _target.x >= 65535.0f) d_target.x = 0; if (_target.y <= 0.0f || _target.y 
>= 65535.0f) d_target.y = 0; if (_target.z <= 0.0f || _target.z >= 65535.0f) d_target.z = 0; p.img.store_grad(px, py, pz, d_img); p.target.store_grad(px, py, pz, d_target); } ================================================ FILE: render/renderutils/c_src/loss.h ================================================ /* * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual * property and proprietary rights in and to this material, related * documentation and any modifications thereto. Any use, reproduction, * disclosure or distribution of this material and related documentation * without an express license agreement from NVIDIA CORPORATION or * its affiliates is strictly prohibited. */ #pragma once #include "common.h" enum TonemapperType { TONEMAPPER_NONE = 0, TONEMAPPER_LOG_SRGB = 1 }; enum LossType { LOSS_L1 = 0, LOSS_MSE = 1, LOSS_RELMSE = 2, LOSS_SMAPE = 3 }; struct LossKernelParams { Tensor img; Tensor target; Tensor out; dim3 gridSize; TonemapperType tonemapper; LossType loss; }; ================================================ FILE: render/renderutils/c_src/mesh.cu ================================================ /* * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual * property and proprietary rights in and to this material, related * documentation and any modifications thereto. Any use, reproduction, * disclosure or distribution of this material and related documentation * without an express license agreement from NVIDIA CORPORATION or * its affiliates is strictly prohibited. 
 */

#include  // NOTE(review): header name lost in extraction — likely <cuda.h>; confirm against upstream
#include  // NOTE(review): header name lost in extraction; confirm against upstream

#include "common.h"
#include "mesh.h"

//------------------------------------------------------------------------
// Kernels

// Forward point/vector transform: multiplies each point (row vector) by a
// per-batch 4x4 matrix. With p.isPoints the translation row is applied and
// a 4-component homogeneous result is written; otherwise only the 3x3 part
// is used (vector transform) and 3 components are written.
__global__ void xfmPointsFwdKernel(XfmKernelParams p)
{
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int pz = blockIdx.z * blockDim.z + threadIdx.z;

    // First 16 threads cooperatively stage the batch-pz matrix in shared
    // memory, stored as mtx[col][row] (transposed fetch).
    __shared__ float mtx[4][4];
    if (threadIdx.x < 16)
        mtx[threadIdx.x % 4][threadIdx.x / 4] = p.matrix.fetch(p.matrix.nhwcIndex(pz, threadIdx.x / 4, threadIdx.x % 4, 0));
    __syncthreads();

    if (px >= p.gridSize.x)
        return;

    vec3f pos(
        p.points.fetch(p.points.nhwcIndex(pz, px, 0, 0)),
        p.points.fetch(p.points.nhwcIndex(pz, px, 1, 0)),
        p.points.fetch(p.points.nhwcIndex(pz, px, 2, 0))
    );

    if (p.isPoints)
    {
        // Homogeneous point: include translation (mtx[3][*]) and output w.
        p.out.store(p.out.nhwcIndex(pz, px, 0, 0), pos.x * mtx[0][0] + pos.y * mtx[1][0] + pos.z * mtx[2][0] + mtx[3][0]);
        p.out.store(p.out.nhwcIndex(pz, px, 1, 0), pos.x * mtx[0][1] + pos.y * mtx[1][1] + pos.z * mtx[2][1] + mtx[3][1]);
        p.out.store(p.out.nhwcIndex(pz, px, 2, 0), pos.x * mtx[0][2] + pos.y * mtx[1][2] + pos.z * mtx[2][2] + mtx[3][2]);
        p.out.store(p.out.nhwcIndex(pz, px, 3, 0), pos.x * mtx[0][3] + pos.y * mtx[1][3] + pos.z * mtx[2][3] + mtx[3][3]);
    }
    else
    {
        // Direction vector: 3x3 rotation/scale only, no translation.
        p.out.store(p.out.nhwcIndex(pz, px, 0, 0), pos.x * mtx[0][0] + pos.y * mtx[1][0] + pos.z * mtx[2][0]);
        p.out.store(p.out.nhwcIndex(pz, px, 1, 0), pos.x * mtx[0][1] + pos.y * mtx[1][1] + pos.z * mtx[2][1]);
        p.out.store(p.out.nhwcIndex(pz, px, 2, 0), pos.x * mtx[0][2] + pos.y * mtx[1][2] + pos.z * mtx[2][2]);
    }
}

// Backward point/vector transform: multiplies the incoming output gradient
// by the matrix (transposed contraction relative to forward) and writes
// point gradients. Matrix gradients are not computed here.
__global__ void xfmPointsBwdKernel(XfmKernelParams p)
{
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int pz = blockIdx.z * blockDim.z + threadIdx.z;

    // Same shared-memory staging of the per-batch matrix as the forward pass.
    __shared__ float mtx[4][4];
    if (threadIdx.x < 16)
        mtx[threadIdx.x % 4][threadIdx.x / 4] = p.matrix.fetch(p.matrix.nhwcIndex(pz, threadIdx.x / 4, threadIdx.x % 4, 0));
    __syncthreads();

    if (px >= p.gridSize.x)
        return;

    vec3f pos(
        p.points.fetch(p.points.nhwcIndex(pz, px, 0, 0)),
        p.points.fetch(p.points.nhwcIndex(pz, px, 1, 0)),
        p.points.fetch(p.points.nhwcIndex(pz, px, 2, 0))
    );

    // Upstream gradient; 4 components in the homogeneous (isPoints) case.
    vec4f d_out(
        p.out.fetch(p.out.nhwcIndex(pz, px, 0, 0)),
        p.out.fetch(p.out.nhwcIndex(pz, px, 1, 0)),
        p.out.fetch(p.out.nhwcIndex(pz, px, 2, 0)),
        p.out.fetch(p.out.nhwcIndex(pz, px, 3, 0))
    );

    if (p.isPoints)
    {
        p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 0, 0), d_out.x * mtx[0][0] + d_out.y * mtx[0][1] + d_out.z * mtx[0][2] + d_out.w * mtx[0][3]);
        p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 1, 0), d_out.x * mtx[1][0] + d_out.y * mtx[1][1] + d_out.z * mtx[1][2] + d_out.w * mtx[1][3]);
        p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 2, 0), d_out.x * mtx[2][0] + d_out.y * mtx[2][1] + d_out.z * mtx[2][2] + d_out.w * mtx[2][3]);
    }
    else
    {
        p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 0, 0), d_out.x * mtx[0][0] + d_out.y * mtx[0][1] + d_out.z * mtx[0][2]);
        p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 1, 0), d_out.x * mtx[1][0] + d_out.y * mtx[1][1] + d_out.z * mtx[1][2]);
        p.points.store_grad(p.points.nhwcIndexContinuous(pz, px, 2, 0), d_out.x * mtx[2][0] + d_out.y * mtx[2][1] + d_out.z * mtx[2][2]);
    }
}

================================================
FILE: render/renderutils/c_src/mesh.h
================================================
/*
 * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 * property and proprietary rights in and to this material, related
 * documentation and any modifications thereto. Any use, reproduction,
 * disclosure or distribution of this material and related documentation
 * without an express license agreement from NVIDIA CORPORATION or
 * its affiliates is strictly prohibited.
 */

#pragma once
#include "common.h"

// Parameter block shared by xfmPointsFwdKernel / xfmPointsBwdKernel.
struct XfmKernelParams
{
    bool    isPoints;   // true: homogeneous points (translation + w); false: direction vectors
    Tensor  points;     // input points, batch x N x 3
    Tensor  matrix;     // per-batch 4x4 transform
    Tensor  out;        // fwd: transformed output; bwd: upstream gradient
    dim3    gridSize;
};

================================================
FILE: render/renderutils/c_src/normal.cu
================================================
/*
 * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 * property and proprietary rights in and to this material, related
 * documentation and any modifications thereto. Any use, reproduction,
 * disclosure or distribution of this material and related documentation
 * without an express license agreement from NVIDIA CORPORATION or
 * its affiliates is strictly prohibited.
 */

#include "common.h"
#include "normal.h"

#define NORMAL_THRESHOLD 0.1f

//------------------------------------------------------------------------
// Perturb shading normal by tangent frame

// Applies a tangent-space normal perturbation: builds the bitangent from
// tangent x normal (sign flipped for OpenGL convention) and combines the
// TBN basis with the perturbed normal. The z (normal) component is clamped
// to >= 0 so the perturbed normal cannot flip below the surface.
__device__ vec3f fwdPerturbNormal(const vec3f perturbed_nrm, const vec3f smooth_nrm, const vec3f smooth_tng, bool opengl)
{
    vec3f _smooth_bitng = cross(smooth_tng, smooth_nrm);
    vec3f smooth_bitng = safeNormalize(_smooth_bitng);
    vec3f _shading_nrm = smooth_tng * perturbed_nrm.x + (opengl ? -1 : 1) * smooth_bitng * perturbed_nrm.y + smooth_nrm * max(perturbed_nrm.z, 0.0f);
    return safeNormalize(_shading_nrm);
}

// Backward of fwdPerturbNormal: recomputes the forward intermediates, then
// chains gradients through safeNormalize, the TBN combination (note the
// max(z, 0) gate), and the cross product. Accumulates into the d_* outputs.
__device__ void bwdPerturbNormal(const vec3f perturbed_nrm, const vec3f smooth_nrm, const vec3f smooth_tng, vec3f &d_perturbed_nrm, vec3f &d_smooth_nrm, vec3f &d_smooth_tng, const vec3f d_out, bool opengl)
{
    ////////////////////////////////////////////////////////////////////////
    // FWD
    vec3f _smooth_bitng = cross(smooth_tng, smooth_nrm);
    vec3f smooth_bitng = safeNormalize(_smooth_bitng);
    vec3f _shading_nrm = smooth_tng * perturbed_nrm.x + (opengl ? -1 : 1) * smooth_bitng * perturbed_nrm.y + smooth_nrm * max(perturbed_nrm.z, 0.0f);

    ////////////////////////////////////////////////////////////////////////
    // BWD
    vec3f d_shading_nrm(0);
    bwdSafeNormalize(_shading_nrm, d_shading_nrm, d_out);

    vec3f d_smooth_bitng(0);

    // Gradient through the z-clamp: only flows when perturbed_nrm.z > 0.
    if (perturbed_nrm.z > 0.0f)
    {
        d_smooth_nrm += d_shading_nrm * perturbed_nrm.z;
        d_perturbed_nrm.z += sum(d_shading_nrm * smooth_nrm);
    }

    d_smooth_bitng += (opengl ? -1 : 1) * d_shading_nrm * perturbed_nrm.y;
    d_perturbed_nrm.y += (opengl ? -1 : 1) * sum(d_shading_nrm * smooth_bitng);

    d_smooth_tng += d_shading_nrm * perturbed_nrm.x;
    d_perturbed_nrm.x += sum(d_shading_nrm * smooth_tng);

    vec3f d__smooth_bitng(0);
    bwdSafeNormalize(_smooth_bitng, d__smooth_bitng, d_smooth_bitng);

    bwdCross(smooth_tng, smooth_nrm, d_smooth_tng, d_smooth_nrm, d__smooth_bitng);
}

//------------------------------------------------------------------------
#define bent_nrm_eps 0.001f

// Blends the smooth normal toward the geometric normal as the view grazes
// the surface: t = clamp(dot(view, smooth_nrm) / threshold, 0, 1), result
// is a lerp geom_nrm -> smooth_nrm by t.
__device__ vec3f fwdBendNormal(const vec3f view_vec, const vec3f smooth_nrm, const vec3f geom_nrm)
{
    float dp = dot(view_vec, smooth_nrm);
    float t = clamp(dp / NORMAL_THRESHOLD, 0.0f, 1.0f);
    return geom_nrm * (1.0f - t) + smooth_nrm * t;
}

// Backward of fwdBendNormal. Above the threshold t == 1 so only smooth_nrm
// receives gradient; inside the blend region the lerp and the (clamped)
// dot product both contribute.
__device__ void bwdBendNormal(const vec3f view_vec, const vec3f smooth_nrm, const vec3f geom_nrm, vec3f& d_view_vec, vec3f& d_smooth_nrm, vec3f& d_geom_nrm, const vec3f d_out)
{
    ////////////////////////////////////////////////////////////////////////
    // FWD
    float dp = dot(view_vec, smooth_nrm);
    float t = clamp(dp / NORMAL_THRESHOLD, 0.0f, 1.0f);

    ////////////////////////////////////////////////////////////////////////
    // BWD
    if (dp > NORMAL_THRESHOLD)
        d_smooth_nrm += d_out;
    else
    {
        // geom_nrm * (1.0f - t) + smooth_nrm * t;
        d_geom_nrm += d_out * (1.0f - t);
        d_smooth_nrm += d_out * t;
        float d_t = sum(d_out * (smooth_nrm - geom_nrm));
        // Clamp gate: no gradient through t where the clamp saturated.
        float d_dp = dp < 0.0f || dp > NORMAL_THRESHOLD ? 0.0f : d_t / NORMAL_THRESHOLD;
        bwdDot(view_vec, smooth_nrm, d_view_vec, d_smooth_nrm, d_dp);
    }
}

//------------------------------------------------------------------------
// Kernels

// Forward shading-normal preparation: normalizes inputs, applies the
// tangent-space perturbation, optionally flips for two-sided shading when
// viewing the back face, then bends toward the geometric normal at grazing
// angles.
__global__ void PrepareShadingNormalFwdKernel(PrepareShadingNormalKernelParams p)
{
    // Calculate pixel position.
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    unsigned int pz = blockIdx.z;
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    vec3f pos = p.pos.fetch3(px, py, pz);
    vec3f view_pos = p.view_pos.fetch3(px, py, pz);
    vec3f perturbed_nrm = p.perturbed_nrm.fetch3(px, py, pz);
    vec3f _smooth_nrm = p.smooth_nrm.fetch3(px, py, pz);
    vec3f _smooth_tng = p.smooth_tng.fetch3(px, py, pz);
    vec3f geom_nrm = p.geom_nrm.fetch3(px, py, pz);

    vec3f smooth_nrm = safeNormalize(_smooth_nrm);
    vec3f smooth_tng = safeNormalize(_smooth_tng);
    vec3f view_vec = safeNormalize(view_pos - pos);
    vec3f shading_nrm = fwdPerturbNormal(perturbed_nrm, smooth_nrm, smooth_tng, p.opengl);

    vec3f res;
    if (p.two_sided_shading && dot(view_vec, geom_nrm) < 0.0f)
        res = fwdBendNormal(view_vec, -shading_nrm, -geom_nrm);  // back face: flip both normals
    else
        res = fwdBendNormal(view_vec, shading_nrm, geom_nrm);

    p.out.store(px, py, pz, res);
}

// Backward of PrepareShadingNormalFwdKernel: re-runs the forward chain and
// propagates d_out through bend -> perturb -> the three safeNormalize calls.
__global__ void PrepareShadingNormalBwdKernel(PrepareShadingNormalKernelParams p)
{
    // Calculate pixel position.
    unsigned int px = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int py = blockIdx.y * blockDim.y + threadIdx.y;
    unsigned int pz = blockIdx.z;
    if (px >= p.gridSize.x || py >= p.gridSize.y || pz >= p.gridSize.z)
        return;

    vec3f pos = p.pos.fetch3(px, py, pz);
    vec3f view_pos = p.view_pos.fetch3(px, py, pz);
    vec3f perturbed_nrm = p.perturbed_nrm.fetch3(px, py, pz);
    vec3f _smooth_nrm = p.smooth_nrm.fetch3(px, py, pz);
    vec3f _smooth_tng = p.smooth_tng.fetch3(px, py, pz);
    vec3f geom_nrm = p.geom_nrm.fetch3(px, py, pz);
    vec3f d_out = p.out.fetch3(px, py, pz);  // upstream gradient, read from out

    ///////////////////////////////////////////////////////////////////////////////////////////////////
    // FWD

    vec3f smooth_nrm = safeNormalize(_smooth_nrm);
    vec3f smooth_tng = safeNormalize(_smooth_tng);
    vec3f _view_vec = view_pos - pos;
    vec3f view_vec = safeNormalize(view_pos - pos);

    vec3f shading_nrm = fwdPerturbNormal(perturbed_nrm, smooth_nrm, smooth_tng, p.opengl);

    ///////////////////////////////////////////////////////////////////////////////////////////////////
    // BWD

    vec3f d_view_vec(0), d_shading_nrm(0), d_geom_nrm(0);
    if (p.two_sided_shading && dot(view_vec, geom_nrm) < 0.0f)
    {
        // Back-facing branch of the forward pass negated both normals, so
        // negate their gradients back after the bend backward.
        bwdBendNormal(view_vec, -shading_nrm, -geom_nrm, d_view_vec, d_shading_nrm, d_geom_nrm, d_out);
        d_shading_nrm = -d_shading_nrm;
        d_geom_nrm = -d_geom_nrm;
    }
    else
        bwdBendNormal(view_vec, shading_nrm, geom_nrm, d_view_vec, d_shading_nrm, d_geom_nrm, d_out);

    vec3f d_perturbed_nrm(0), d_smooth_nrm(0), d_smooth_tng(0);
    bwdPerturbNormal(perturbed_nrm, smooth_nrm, smooth_tng, d_perturbed_nrm, d_smooth_nrm, d_smooth_tng, d_shading_nrm, p.opengl);

    vec3f d__view_vec(0), d__smooth_nrm(0), d__smooth_tng(0);
    bwdSafeNormalize(_view_vec, d__view_vec, d_view_vec);
    bwdSafeNormalize(_smooth_nrm, d__smooth_nrm, d_smooth_nrm);
    bwdSafeNormalize(_smooth_tng, d__smooth_tng, d_smooth_tng);

    // view_vec = view_pos - pos, so pos gets the negated gradient.
    p.pos.store_grad(px, py, pz, -d__view_vec);
    p.view_pos.store_grad(px, py, pz, d__view_vec);
    p.perturbed_nrm.store_grad(px, py, pz, d_perturbed_nrm);
    p.smooth_nrm.store_grad(px, py, pz, d__smooth_nrm);
    p.smooth_tng.store_grad(px, py, pz, d__smooth_tng);
    p.geom_nrm.store_grad(px, py, pz, d_geom_nrm);
}

================================================
FILE: render/renderutils/c_src/normal.h
================================================
/*
 * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 * property and proprietary rights in and to this material, related
 * documentation and any modifications thereto. Any use, reproduction,
 * disclosure or distribution of this material and related documentation
 * without an express license agreement from NVIDIA CORPORATION or
 * its affiliates is strictly prohibited.
 */

#pragma once
#include "common.h"

// Parameter block shared by PrepareShadingNormalFwdKernel / BwdKernel.
struct PrepareShadingNormalKernelParams
{
    Tensor  pos;            // surface positions
    Tensor  view_pos;       // camera positions
    Tensor  perturbed_nrm;  // tangent-space normal perturbation
    Tensor  smooth_nrm;     // interpolated vertex normal
    Tensor  smooth_tng;     // interpolated vertex tangent
    Tensor  geom_nrm;       // face (geometric) normal
    Tensor  out;            // fwd: shading normal; bwd: upstream gradient
    dim3    gridSize;
    bool    two_sided_shading, opengl;
};

================================================
FILE: render/renderutils/c_src/tensor.h
================================================
/*
 * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 * property and proprietary rights in and to this material, related
 * documentation and any modifications thereto. Any use, reproduction,
 * disclosure or distribution of this material and related documentation
 * without an express license agreement from NVIDIA CORPORATION or
 * its affiliates is strictly prohibited.
 */

#pragma once
#if defined(__CUDACC__) && defined(BFLOAT16)
#include  // bfloat16 is float32 compatible with less mantissa bits
          // NOTE(review): header name lost in extraction — likely <cuda_bf16.h>; confirm
#endif

//---------------------------------------------------------------------------------
// CUDA-side Tensor class for in/out parameter parsing.
// Can be float32 or bfloat16
struct Tensor
{
    void*   val;        // forward data buffer
    void*   d_val;      // gradient buffer (may be nullptr in forward-only use)
    int     dims[4], _dims[4];  // dims: actual sizes (1 => broadcast); _dims: output ("continuous") sizes
    int     strides[4]; // element strides per dimension
    bool    fp16;       // true when the buffers hold bfloat16, not float32

#if defined(__CUDA__) && !defined(__CUDA_ARCH__)
    // Host-side default constructor only; device code never default-constructs.
    Tensor() : val(nullptr), d_val(nullptr), fp16(true), dims{ 0, 0, 0, 0 }, _dims{ 0, 0, 0, 0 }, strides{ 0, 0, 0, 0 } {}
#endif

#ifdef __CUDACC__
    // Helpers to index and read/write a single element

    // Raw strided NHWC index, no broadcasting.
    __device__ inline int _nhwcIndex(int n, int h, int w, int c) const
    {
        return n * strides[0] + h * strides[1] + w * strides[2] + c * strides[3];
    }

    // Strided NHWC index with broadcasting: any dimension of size 1
    // contributes no offset.
    __device__ inline int nhwcIndex(int n, int h, int w, int c) const
    {
        return (dims[0] == 1 ? 0 : n * strides[0]) + (dims[1] == 1 ? 0 : h * strides[1]) + (dims[2] == 1 ? 0 : w * strides[2]) + (dims[3] == 1 ? 0 : c * strides[3]);
    }

    // Dense row-major index into the full-resolution (_dims) layout, used
    // for gradient buffers which are allocated contiguous.
    __device__ inline int nhwcIndexContinuous(int n, int h, int w, int c) const
    {
        return ((n * _dims[1] + h) * _dims[2] + w) * _dims[3] + c;
    }

#ifdef BFLOAT16
    // bf16-aware element access: converts on read/write when fp16 is set.
    __device__ inline float fetch(unsigned int idx) const { return fp16 ? __bfloat162float(((__nv_bfloat16*)val)[idx]) : ((float*)val)[idx]; }
    __device__ inline void store(unsigned int idx, float _val) { if (fp16) ((__nv_bfloat16*)val)[idx] = __float2bfloat16(_val); else ((float*)val)[idx] = _val; }
    __device__ inline void store_grad(unsigned int idx, float _val) { if (fp16) ((__nv_bfloat16*)d_val)[idx] = __float2bfloat16(_val); else ((float*)d_val)[idx] = _val; }
#else
    // float32-only element access.
    __device__ inline float fetch(unsigned int idx) const { return ((float*)val)[idx]; }
    __device__ inline void store(unsigned int idx, float _val) { ((float*)val)[idx] = _val; }
    __device__ inline void store_grad(unsigned int idx, float _val) { ((float*)d_val)[idx] = _val; }
#endif

    //////////////////////////////////////////////////////////////////////////////////////////
    // Fetch, use broadcasting for tensor dimensions of size 1
    __device__ inline float fetch1(unsigned int x, unsigned int y, unsigned int z) const
    {
        return fetch(nhwcIndex(z, y, x, 0));
    }

    __device__ inline vec3f fetch3(unsigned int x, unsigned int y, unsigned int z) const
    {
        return vec3f(
            fetch(nhwcIndex(z, y, x, 0)),
            fetch(nhwcIndex(z, y, x, 1)),
            fetch(nhwcIndex(z, y, x, 2))
        );
    }

    /////////////////////////////////////////////////////////////////////////////////////////////////////////////
    // Store, no broadcasting here. Assume we output full res gradient and then reduce using torch.sum outside
    __device__ inline void store(unsigned int x, unsigned int y, unsigned int z, float _val)
    {
        store(_nhwcIndex(z, y, x, 0), _val);
    }

    __device__ inline void store(unsigned int x, unsigned int y, unsigned int z, vec3f _val)
    {
        store(_nhwcIndex(z, y, x, 0), _val.x);
        store(_nhwcIndex(z, y, x, 1), _val.y);
        store(_nhwcIndex(z, y, x, 2), _val.z);
    }

    /////////////////////////////////////////////////////////////////////////////////////////////////////////////
    // Store gradient , no broadcasting here. Assume we output full res gradient and then reduce using torch.sum outside
    __device__ inline void store_grad(unsigned int x, unsigned int y, unsigned int z, float _val)
    {
        store_grad(nhwcIndexContinuous(z, y, x, 0), _val);
    }

    __device__ inline void store_grad(unsigned int x, unsigned int y, unsigned int z, vec3f _val)
    {
        store_grad(nhwcIndexContinuous(z, y, x, 0), _val.x);
        store_grad(nhwcIndexContinuous(z, y, x, 1), _val.y);
        store_grad(nhwcIndexContinuous(z, y, x, 2), _val.z);
    }
#endif
};

================================================
FILE: render/renderutils/c_src/torch_bindings.cpp
================================================
/*
 * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 * property and proprietary rights in and to this material, related
 * documentation and any modifications thereto. Any use, reproduction,
 * disclosure or distribution of this material and related documentation
 * without an express license agreement from NVIDIA CORPORATION or
 * its affiliates is strictly prohibited.
*/ #ifdef _MSC_VER #pragma warning(push, 0) #include #pragma warning(pop) #else #include #endif #include #include #include #include #define NVDR_CHECK_CUDA_ERROR(CUDA_CALL) { cudaError_t err = CUDA_CALL; AT_CUDA_CHECK(cudaGetLastError()); } #define NVDR_CHECK_GL_ERROR(GL_CALL) { GL_CALL; GLenum err = glGetError(); TORCH_CHECK(err == GL_NO_ERROR, "OpenGL error: ", getGLErrorString(err), "[", #GL_CALL, ";]"); } #define CHECK_TENSOR(X, DIMS, CHANNELS) \ TORCH_CHECK(X.is_cuda(), #X " must be a cuda tensor") \ TORCH_CHECK(X.scalar_type() == torch::kFloat || X.scalar_type() == torch::kBFloat16, #X " must be fp32 or bf16") \ TORCH_CHECK(X.dim() == DIMS, #X " must have " #DIMS " dimensions") \ TORCH_CHECK(X.size(DIMS - 1) == CHANNELS, #X " must have " #CHANNELS " channels") #include "common.h" #include "loss.h" #include "normal.h" #include "cubemap.h" #include "bsdf.h" #include "mesh.h" #define BLOCK_X 8 #define BLOCK_Y 8 //------------------------------------------------------------------------ // mesh.cu void xfmPointsFwdKernel(XfmKernelParams p); void xfmPointsBwdKernel(XfmKernelParams p); //------------------------------------------------------------------------ // loss.cu void imgLossFwdKernel(LossKernelParams p); void imgLossBwdKernel(LossKernelParams p); //------------------------------------------------------------------------ // normal.cu void PrepareShadingNormalFwdKernel(PrepareShadingNormalKernelParams p); void PrepareShadingNormalBwdKernel(PrepareShadingNormalKernelParams p); //------------------------------------------------------------------------ // cubemap.cu void DiffuseCubemapFwdKernel(DiffuseCubemapKernelParams p); void DiffuseCubemapBwdKernel(DiffuseCubemapKernelParams p); void SpecularBoundsKernel(SpecularBoundsKernelParams p); void SpecularCubemapFwdKernel(SpecularCubemapKernelParams p); void SpecularCubemapBwdKernel(SpecularCubemapKernelParams p); //------------------------------------------------------------------------ // bsdf.cu void 
LambertFwdKernel(LambertKernelParams p); void LambertBwdKernel(LambertKernelParams p); void FrostbiteDiffuseFwdKernel(FrostbiteDiffuseKernelParams p); void FrostbiteDiffuseBwdKernel(FrostbiteDiffuseKernelParams p); void FresnelShlickFwdKernel(FresnelShlickKernelParams p); void FresnelShlickBwdKernel(FresnelShlickKernelParams p); void ndfGGXFwdKernel(NdfGGXParams p); void ndfGGXBwdKernel(NdfGGXParams p); void lambdaGGXFwdKernel(NdfGGXParams p); void lambdaGGXBwdKernel(NdfGGXParams p); void maskingSmithFwdKernel(MaskingSmithParams p); void maskingSmithBwdKernel(MaskingSmithParams p); void pbrSpecularFwdKernel(PbrSpecular p); void pbrSpecularBwdKernel(PbrSpecular p); void pbrBSDFFwdKernel(PbrBSDF p); void pbrBSDFBwdKernel(PbrBSDF p); //------------------------------------------------------------------------ // Tensor helpers void update_grid(dim3 &gridSize, torch::Tensor x) { gridSize.x = std::max(gridSize.x, (uint32_t)x.size(2)); gridSize.y = std::max(gridSize.y, (uint32_t)x.size(1)); gridSize.z = std::max(gridSize.z, (uint32_t)x.size(0)); } template void update_grid(dim3& gridSize, torch::Tensor x, Ts&&... vs) { gridSize.x = std::max(gridSize.x, (uint32_t)x.size(2)); gridSize.y = std::max(gridSize.y, (uint32_t)x.size(1)); gridSize.z = std::max(gridSize.z, (uint32_t)x.size(0)); update_grid(gridSize, std::forward(vs)...); } Tensor make_cuda_tensor(torch::Tensor val) { Tensor res; for (int i = 0; i < val.dim(); ++i) { res.dims[i] = val.size(i); res.strides[i] = val.stride(i); } res.fp16 = val.scalar_type() == torch::kBFloat16; res.val = res.fp16 ? 
(void*)val.data_ptr() : (void*)val.data_ptr(); res.d_val = nullptr; return res; } Tensor make_cuda_tensor(torch::Tensor val, dim3 outDims, torch::Tensor* grad = nullptr) { Tensor res; for (int i = 0; i < val.dim(); ++i) { res.dims[i] = val.size(i); res.strides[i] = val.stride(i); } if (val.dim() == 4) res._dims[0] = outDims.z, res._dims[1] = outDims.y, res._dims[2] = outDims.x, res._dims[3] = val.size(3); else res._dims[0] = outDims.z, res._dims[1] = outDims.x, res._dims[2] = val.size(2), res._dims[3] = 1; // Add a trailing one for indexing math to work out res.fp16 = val.scalar_type() == torch::kBFloat16; res.val = res.fp16 ? (void*)val.data_ptr() : (void*)val.data_ptr(); res.d_val = nullptr; if (grad != nullptr) { if (val.dim() == 4) *grad = torch::empty({ outDims.z, outDims.y, outDims.x, val.size(3) }, torch::TensorOptions().dtype(res.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA)); else // 3 *grad = torch::empty({ outDims.z, outDims.x, val.size(2) }, torch::TensorOptions().dtype(res.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA)); res.d_val = res.fp16 ? (void*)grad->data_ptr() : (void*)grad->data_ptr(); } return res; } //------------------------------------------------------------------------ // prepare_shading_normal torch::Tensor prepare_shading_normal_fwd(torch::Tensor pos, torch::Tensor view_pos, torch::Tensor perturbed_nrm, torch::Tensor smooth_nrm, torch::Tensor smooth_tng, torch::Tensor geom_nrm, bool two_sided_shading, bool opengl, bool fp16) { CHECK_TENSOR(pos, 4, 3); CHECK_TENSOR(view_pos, 4, 3); CHECK_TENSOR(perturbed_nrm, 4, 3); CHECK_TENSOR(smooth_nrm, 4, 3); CHECK_TENSOR(smooth_tng, 4, 3); CHECK_TENSOR(geom_nrm, 4, 3); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // Extract input parameters. 
PrepareShadingNormalKernelParams p; p.two_sided_shading = two_sided_shading; p.opengl = opengl; p.out.fp16 = fp16; update_grid(p.gridSize, pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm); // Allocate output tensors. torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts); // Choose launch parameters. dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); // Setup tensors p.pos = make_cuda_tensor(pos, p.gridSize); p.view_pos = make_cuda_tensor(view_pos, p.gridSize); p.perturbed_nrm = make_cuda_tensor(perturbed_nrm, p.gridSize); p.smooth_nrm = make_cuda_tensor(smooth_nrm, p.gridSize); p.smooth_tng = make_cuda_tensor(smooth_tng, p.gridSize); p.geom_nrm = make_cuda_tensor(geom_nrm, p.gridSize); p.out = make_cuda_tensor(out, p.gridSize); // Launch CUDA kernel. void* args[] = { &p }; NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)PrepareShadingNormalFwdKernel, gridSize, blockSize, args, 0, stream)); return out; } std::tuple prepare_shading_normal_bwd(torch::Tensor pos, torch::Tensor view_pos, torch::Tensor perturbed_nrm, torch::Tensor smooth_nrm, torch::Tensor smooth_tng, torch::Tensor geom_nrm, torch::Tensor grad, bool two_sided_shading, bool opengl) { cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // Extract input parameters. PrepareShadingNormalKernelParams p; p.two_sided_shading = two_sided_shading; p.opengl = opengl; update_grid(p.gridSize, pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm); // Choose launch parameters. 
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); // Setup tensors torch::Tensor pos_grad, view_pos_grad, perturbed_nrm_grad, smooth_nrm_grad, smooth_tng_grad, geom_nrm_grad; p.pos = make_cuda_tensor(pos, p.gridSize, &pos_grad); p.view_pos = make_cuda_tensor(view_pos, p.gridSize, &view_pos_grad); p.perturbed_nrm = make_cuda_tensor(perturbed_nrm, p.gridSize, &perturbed_nrm_grad); p.smooth_nrm = make_cuda_tensor(smooth_nrm, p.gridSize, &smooth_nrm_grad); p.smooth_tng = make_cuda_tensor(smooth_tng, p.gridSize, &smooth_tng_grad); p.geom_nrm = make_cuda_tensor(geom_nrm, p.gridSize, &geom_nrm_grad); p.out = make_cuda_tensor(grad, p.gridSize); // Launch CUDA kernel. void* args[] = { &p }; NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)PrepareShadingNormalBwdKernel, gridSize, blockSize, args, 0, stream)); return std::tuple(pos_grad, view_pos_grad, perturbed_nrm_grad, smooth_nrm_grad, smooth_tng_grad, geom_nrm_grad); } //------------------------------------------------------------------------ // lambert torch::Tensor lambert_fwd(torch::Tensor nrm, torch::Tensor wi, bool fp16) { CHECK_TENSOR(nrm, 4, 3); CHECK_TENSOR(wi, 4, 3); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // Extract input parameters. LambertKernelParams p; p.out.fp16 = fp16; update_grid(p.gridSize, nrm, wi); // Allocate output tensors. torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts); // Choose launch parameters. dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); p.nrm = make_cuda_tensor(nrm, p.gridSize); p.wi = make_cuda_tensor(wi, p.gridSize); p.out = make_cuda_tensor(out, p.gridSize); // Launch CUDA kernel. 
void* args[] = { &p }; NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)LambertFwdKernel, gridSize, blockSize, args, 0, stream)); return out; } std::tuple lambert_bwd(torch::Tensor nrm, torch::Tensor wi, torch::Tensor grad) { cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // Extract input parameters. LambertKernelParams p; update_grid(p.gridSize, nrm, wi); // Choose launch parameters. dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); torch::Tensor nrm_grad, wi_grad; p.nrm = make_cuda_tensor(nrm, p.gridSize, &nrm_grad); p.wi = make_cuda_tensor(wi, p.gridSize, &wi_grad); p.out = make_cuda_tensor(grad, p.gridSize); // Launch CUDA kernel. void* args[] = { &p }; NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)LambertBwdKernel, gridSize, blockSize, args, 0, stream)); return std::tuple(nrm_grad, wi_grad); } //------------------------------------------------------------------------ // frostbite diffuse torch::Tensor frostbite_fwd(torch::Tensor nrm, torch::Tensor wi, torch::Tensor wo, torch::Tensor linearRoughness, bool fp16) { CHECK_TENSOR(nrm, 4, 3); CHECK_TENSOR(wi, 4, 3); CHECK_TENSOR(wo, 4, 3); CHECK_TENSOR(linearRoughness, 4, 1); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // Extract input parameters. FrostbiteDiffuseKernelParams p; p.out.fp16 = fp16; update_grid(p.gridSize, nrm, wi, wo, linearRoughness); // Allocate output tensors. torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts); // Choose launch parameters. 
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); p.nrm = make_cuda_tensor(nrm, p.gridSize); p.wi = make_cuda_tensor(wi, p.gridSize); p.wo = make_cuda_tensor(wo, p.gridSize); p.linearRoughness = make_cuda_tensor(linearRoughness, p.gridSize); p.out = make_cuda_tensor(out, p.gridSize); // Launch CUDA kernel. void* args[] = { &p }; NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)FrostbiteDiffuseFwdKernel, gridSize, blockSize, args, 0, stream)); return out; } std::tuple frostbite_bwd(torch::Tensor nrm, torch::Tensor wi, torch::Tensor wo, torch::Tensor linearRoughness, torch::Tensor grad) { cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // Extract input parameters. FrostbiteDiffuseKernelParams p; update_grid(p.gridSize, nrm, wi, wo, linearRoughness); // Choose launch parameters. dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); torch::Tensor nrm_grad, wi_grad, wo_grad, linearRoughness_grad; p.nrm = make_cuda_tensor(nrm, p.gridSize, &nrm_grad); p.wi = make_cuda_tensor(wi, p.gridSize, &wi_grad); p.wo = make_cuda_tensor(wo, p.gridSize, &wo_grad); p.linearRoughness = make_cuda_tensor(linearRoughness, p.gridSize, &linearRoughness_grad); p.out = make_cuda_tensor(grad, p.gridSize); // Launch CUDA kernel. void* args[] = { &p }; NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)FrostbiteDiffuseBwdKernel, gridSize, blockSize, args, 0, stream)); return std::tuple(nrm_grad, wi_grad, wo_grad, linearRoughness_grad); } //------------------------------------------------------------------------ // fresnel_shlick torch::Tensor fresnel_shlick_fwd(torch::Tensor f0, torch::Tensor f90, torch::Tensor cosTheta, bool fp16) { CHECK_TENSOR(f0, 4, 3); CHECK_TENSOR(f90, 4, 3); CHECK_TENSOR(cosTheta, 4, 1); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // Extract input parameters. 
FresnelShlickKernelParams p; p.out.fp16 = fp16; update_grid(p.gridSize, f0, f90, cosTheta); // Allocate output tensors. torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts); // Choose launch parameters. dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); p.f0 = make_cuda_tensor(f0, p.gridSize); p.f90 = make_cuda_tensor(f90, p.gridSize); p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize); p.out = make_cuda_tensor(out, p.gridSize); // Launch CUDA kernel. void* args[] = { &p }; NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)FresnelShlickFwdKernel, gridSize, blockSize, args, 0, stream)); return out; } std::tuple fresnel_shlick_bwd(torch::Tensor f0, torch::Tensor f90, torch::Tensor cosTheta, torch::Tensor grad) { cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // Extract input parameters. FresnelShlickKernelParams p; update_grid(p.gridSize, f0, f90, cosTheta); // Choose launch parameters. dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); torch::Tensor f0_grad, f90_grad, cosT_grad; p.f0 = make_cuda_tensor(f0, p.gridSize, &f0_grad); p.f90 = make_cuda_tensor(f90, p.gridSize, &f90_grad); p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize, &cosT_grad); p.out = make_cuda_tensor(grad, p.gridSize); // Launch CUDA kernel. 
void* args[] = { &p }; NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)FresnelShlickBwdKernel, gridSize, blockSize, args, 0, stream)); return std::tuple(f0_grad, f90_grad, cosT_grad); } //------------------------------------------------------------------------ // ndf_ggd torch::Tensor ndf_ggx_fwd(torch::Tensor alphaSqr, torch::Tensor cosTheta, bool fp16) { CHECK_TENSOR(alphaSqr, 4, 1); CHECK_TENSOR(cosTheta, 4, 1); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // Extract input parameters. NdfGGXParams p; p.out.fp16 = fp16; update_grid(p.gridSize, alphaSqr, cosTheta); // Allocate output tensors. torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts); // Choose launch parameters. dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize); p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize); p.out = make_cuda_tensor(out, p.gridSize); // Launch CUDA kernel. void* args[] = { &p }; NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)ndfGGXFwdKernel, gridSize, blockSize, args, 0, stream)); return out; } std::tuple ndf_ggx_bwd(torch::Tensor alphaSqr, torch::Tensor cosTheta, torch::Tensor grad) { cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // Extract input parameters. NdfGGXParams p; update_grid(p.gridSize, alphaSqr, cosTheta); // Choose launch parameters. dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); torch::Tensor alphaSqr_grad, cosTheta_grad; p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize, &alphaSqr_grad); p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize, &cosTheta_grad); p.out = make_cuda_tensor(grad, p.gridSize); // Launch CUDA kernel. 
void* args[] = { &p }; NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)ndfGGXBwdKernel, gridSize, blockSize, args, 0, stream)); return std::tuple(alphaSqr_grad, cosTheta_grad); } //------------------------------------------------------------------------ // lambda_ggx torch::Tensor lambda_ggx_fwd(torch::Tensor alphaSqr, torch::Tensor cosTheta, bool fp16) { CHECK_TENSOR(alphaSqr, 4, 1); CHECK_TENSOR(cosTheta, 4, 1); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // Extract input parameters. NdfGGXParams p; p.out.fp16 = fp16; update_grid(p.gridSize, alphaSqr, cosTheta); // Allocate output tensors. torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts); // Choose launch parameters. dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize); p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize); p.out = make_cuda_tensor(out, p.gridSize); // Launch CUDA kernel. void* args[] = { &p }; NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)lambdaGGXFwdKernel, gridSize, blockSize, args, 0, stream)); return out; } std::tuple lambda_ggx_bwd(torch::Tensor alphaSqr, torch::Tensor cosTheta, torch::Tensor grad) { cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // Extract input parameters. NdfGGXParams p; update_grid(p.gridSize, alphaSqr, cosTheta); // Choose launch parameters. dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); torch::Tensor alphaSqr_grad, cosTheta_grad; p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize, &alphaSqr_grad); p.cosTheta = make_cuda_tensor(cosTheta, p.gridSize, &cosTheta_grad); p.out = make_cuda_tensor(grad, p.gridSize); // Launch CUDA kernel. 
void* args[] = { &p }; NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)lambdaGGXBwdKernel, gridSize, blockSize, args, 0, stream)); return std::tuple(alphaSqr_grad, cosTheta_grad); } //------------------------------------------------------------------------ // masking_smith torch::Tensor masking_smith_fwd(torch::Tensor alphaSqr, torch::Tensor cosThetaI, torch::Tensor cosThetaO, bool fp16) { CHECK_TENSOR(alphaSqr, 4, 1); CHECK_TENSOR(cosThetaI, 4, 1); CHECK_TENSOR(cosThetaO, 4, 1); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // Extract input parameters. MaskingSmithParams p; p.out.fp16 = fp16; update_grid(p.gridSize, alphaSqr, cosThetaI, cosThetaO); // Allocate output tensors. torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 1 }, opts); // Choose launch parameters. dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize); p.cosThetaI = make_cuda_tensor(cosThetaI, p.gridSize); p.cosThetaO = make_cuda_tensor(cosThetaO, p.gridSize); p.out = make_cuda_tensor(out, p.gridSize); // Launch CUDA kernel. void* args[] = { &p }; NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)maskingSmithFwdKernel, gridSize, blockSize, args, 0, stream)); return out; } std::tuple masking_smith_bwd(torch::Tensor alphaSqr, torch::Tensor cosThetaI, torch::Tensor cosThetaO, torch::Tensor grad) { cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // Extract input parameters. MaskingSmithParams p; update_grid(p.gridSize, alphaSqr, cosThetaI, cosThetaO); // Choose launch parameters. 
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); torch::Tensor alphaSqr_grad, cosThetaI_grad, cosThetaO_grad; p.alphaSqr = make_cuda_tensor(alphaSqr, p.gridSize, &alphaSqr_grad); p.cosThetaI = make_cuda_tensor(cosThetaI, p.gridSize, &cosThetaI_grad); p.cosThetaO = make_cuda_tensor(cosThetaO, p.gridSize, &cosThetaO_grad); p.out = make_cuda_tensor(grad, p.gridSize); // Launch CUDA kernel. void* args[] = { &p }; NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)maskingSmithBwdKernel, gridSize, blockSize, args, 0, stream)); return std::tuple(alphaSqr_grad, cosThetaI_grad, cosThetaO_grad); } //------------------------------------------------------------------------ // pbr_specular torch::Tensor pbr_specular_fwd(torch::Tensor col, torch::Tensor nrm, torch::Tensor wo, torch::Tensor wi, torch::Tensor alpha, float min_roughness, bool fp16) { CHECK_TENSOR(col, 4, 3); CHECK_TENSOR(nrm, 4, 3); CHECK_TENSOR(wo, 4, 3); CHECK_TENSOR(wi, 4, 3); CHECK_TENSOR(alpha, 4, 1); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // Extract input parameters. PbrSpecular p; p.out.fp16 = fp16; p.min_roughness = min_roughness; update_grid(p.gridSize, col, nrm, wo, wi, alpha); // Allocate output tensors. torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts); // Choose launch parameters. dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); p.col = make_cuda_tensor(col, p.gridSize); p.nrm = make_cuda_tensor(nrm, p.gridSize); p.wo = make_cuda_tensor(wo, p.gridSize); p.wi = make_cuda_tensor(wi, p.gridSize); p.alpha = make_cuda_tensor(alpha, p.gridSize); p.out = make_cuda_tensor(out, p.gridSize); // Launch CUDA kernel. 
void* args[] = { &p }; NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)pbrSpecularFwdKernel, gridSize, blockSize, args, 0, stream)); return out; } std::tuple pbr_specular_bwd(torch::Tensor col, torch::Tensor nrm, torch::Tensor wo, torch::Tensor wi, torch::Tensor alpha, float min_roughness, torch::Tensor grad) { cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // Extract input parameters. PbrSpecular p; update_grid(p.gridSize, col, nrm, wo, wi, alpha); p.min_roughness = min_roughness; // Choose launch parameters. dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); torch::Tensor col_grad, nrm_grad, wo_grad, wi_grad, alpha_grad; p.col = make_cuda_tensor(col, p.gridSize, &col_grad); p.nrm = make_cuda_tensor(nrm, p.gridSize, &nrm_grad); p.wo = make_cuda_tensor(wo, p.gridSize, &wo_grad); p.wi = make_cuda_tensor(wi, p.gridSize, &wi_grad); p.alpha = make_cuda_tensor(alpha, p.gridSize, &alpha_grad); p.out = make_cuda_tensor(grad, p.gridSize); // Launch CUDA kernel. void* args[] = { &p }; NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)pbrSpecularBwdKernel, gridSize, blockSize, args, 0, stream)); return std::tuple(col_grad, nrm_grad, wo_grad, wi_grad, alpha_grad); } //------------------------------------------------------------------------ // pbr_bsdf torch::Tensor pbr_bsdf_fwd(torch::Tensor kd, torch::Tensor arm, torch::Tensor pos, torch::Tensor nrm, torch::Tensor view_pos, torch::Tensor light_pos, float min_roughness, int BSDF, bool fp16) { CHECK_TENSOR(kd, 4, 3); CHECK_TENSOR(arm, 4, 3); CHECK_TENSOR(pos, 4, 3); CHECK_TENSOR(nrm, 4, 3); CHECK_TENSOR(view_pos, 4, 3); CHECK_TENSOR(light_pos, 4, 3); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // Extract input parameters. PbrBSDF p; p.out.fp16 = fp16; p.min_roughness = min_roughness; p.BSDF = BSDF; update_grid(p.gridSize, kd, arm, pos, nrm, view_pos, light_pos); // Allocate output tensors. 
torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts); // Choose launch parameters. dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); p.kd = make_cuda_tensor(kd, p.gridSize); p.arm = make_cuda_tensor(arm, p.gridSize); p.pos = make_cuda_tensor(pos, p.gridSize); p.nrm = make_cuda_tensor(nrm, p.gridSize); p.view_pos = make_cuda_tensor(view_pos, p.gridSize); p.light_pos = make_cuda_tensor(light_pos, p.gridSize); p.out = make_cuda_tensor(out, p.gridSize); // Launch CUDA kernel. void* args[] = { &p }; NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)pbrBSDFFwdKernel, gridSize, blockSize, args, 0, stream)); return out; } std::tuple pbr_bsdf_bwd(torch::Tensor kd, torch::Tensor arm, torch::Tensor pos, torch::Tensor nrm, torch::Tensor view_pos, torch::Tensor light_pos, float min_roughness, int BSDF, torch::Tensor grad) { cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // Extract input parameters. PbrBSDF p; update_grid(p.gridSize, kd, arm, pos, nrm, view_pos, light_pos); p.min_roughness = min_roughness; p.BSDF = BSDF; // Choose launch parameters. dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); torch::Tensor kd_grad, arm_grad, pos_grad, nrm_grad, view_pos_grad, light_pos_grad; p.kd = make_cuda_tensor(kd, p.gridSize, &kd_grad); p.arm = make_cuda_tensor(arm, p.gridSize, &arm_grad); p.pos = make_cuda_tensor(pos, p.gridSize, &pos_grad); p.nrm = make_cuda_tensor(nrm, p.gridSize, &nrm_grad); p.view_pos = make_cuda_tensor(view_pos, p.gridSize, &view_pos_grad); p.light_pos = make_cuda_tensor(light_pos, p.gridSize, &light_pos_grad); p.out = make_cuda_tensor(grad, p.gridSize); // Launch CUDA kernel. 
void* args[] = { &p }; NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)pbrBSDFBwdKernel, gridSize, blockSize, args, 0, stream)); return std::tuple(kd_grad, arm_grad, pos_grad, nrm_grad, view_pos_grad, light_pos_grad); } //------------------------------------------------------------------------ // filter_cubemap torch::Tensor diffuse_cubemap_fwd(torch::Tensor cubemap) { CHECK_TENSOR(cubemap, 4, 3); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // Extract input parameters. DiffuseCubemapKernelParams p; update_grid(p.gridSize, cubemap); // Allocate output tensors. torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 3 }, opts); // Choose launch parameters. dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); // Setup tensors p.cubemap = make_cuda_tensor(cubemap, p.gridSize); p.out = make_cuda_tensor(out, p.gridSize); // Launch CUDA kernel. void* args[] = { &p }; NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)DiffuseCubemapFwdKernel, gridSize, blockSize, args, 0, stream)); return out; } torch::Tensor diffuse_cubemap_bwd(torch::Tensor cubemap, torch::Tensor grad) { CHECK_TENSOR(cubemap, 4, 3); CHECK_TENSOR(grad, 4, 3); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // Extract input parameters. DiffuseCubemapKernelParams p; update_grid(p.gridSize, cubemap); // Choose launch parameters. 
dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); // Setup tensors torch::Tensor cubemap_grad; p.cubemap = make_cuda_tensor(cubemap, p.gridSize); p.out = make_cuda_tensor(grad, p.gridSize); cubemap_grad = torch::zeros({ p.gridSize.z, p.gridSize.y, p.gridSize.x, cubemap.size(3) }, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); p.cubemap.d_val = (void*)cubemap_grad.data_ptr(); // Launch CUDA kernel. void* args[] = { &p }; NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)DiffuseCubemapBwdKernel, gridSize, blockSize, args, 0, stream)); return cubemap_grad; } torch::Tensor specular_bounds(int resolution, float costheta_cutoff) { cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // Extract input parameters. SpecularBoundsKernelParams p; p.costheta_cutoff = costheta_cutoff; p.gridSize = dim3(resolution, resolution, 6); // Allocate output tensors. torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); torch::Tensor out = torch::zeros({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 6*4 }, opts); // Choose launch parameters. dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); // Setup tensors p.out = make_cuda_tensor(out, p.gridSize); // Launch CUDA kernel. void* args[] = { &p }; NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)SpecularBoundsKernel, gridSize, blockSize, args, 0, stream)); return out; } torch::Tensor specular_cubemap_fwd(torch::Tensor cubemap, torch::Tensor bounds, float roughness, float costheta_cutoff) { CHECK_TENSOR(cubemap, 4, 3); CHECK_TENSOR(bounds, 4, 6*4); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // Extract input parameters. SpecularCubemapKernelParams p; p.roughness = roughness; p.costheta_cutoff = costheta_cutoff; update_grid(p.gridSize, cubemap); // Allocate output tensors. 
torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); torch::Tensor out = torch::empty({ p.gridSize.z, p.gridSize.y, p.gridSize.x, 4 }, opts); // Choose launch parameters. dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); // Setup tensors p.cubemap = make_cuda_tensor(cubemap, p.gridSize); p.bounds = make_cuda_tensor(bounds, p.gridSize); p.out = make_cuda_tensor(out, p.gridSize); // Launch CUDA kernel. void* args[] = { &p }; NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)SpecularCubemapFwdKernel, gridSize, blockSize, args, 0, stream)); return out; } torch::Tensor specular_cubemap_bwd(torch::Tensor cubemap, torch::Tensor bounds, torch::Tensor grad, float roughness, float costheta_cutoff) { CHECK_TENSOR(cubemap, 4, 3); CHECK_TENSOR(bounds, 4, 6*4); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // Extract input parameters. SpecularCubemapKernelParams p; p.roughness = roughness; p.costheta_cutoff = costheta_cutoff; update_grid(p.gridSize, cubemap); // Choose launch parameters. dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); // Setup tensors torch::Tensor cubemap_grad; p.cubemap = make_cuda_tensor(cubemap, p.gridSize); p.bounds = make_cuda_tensor(bounds, p.gridSize); p.out = make_cuda_tensor(grad, p.gridSize); cubemap_grad = torch::zeros({ p.gridSize.z, p.gridSize.y, p.gridSize.x, cubemap.size(3) }, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); p.cubemap.d_val = (void*)cubemap_grad.data_ptr(); // Launch CUDA kernel. 
void* args[] = { &p }; NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)SpecularCubemapBwdKernel, gridSize, blockSize, args, 0, stream)); return cubemap_grad; } //------------------------------------------------------------------------ // loss function LossType strToLoss(std::string str) { if (str == "mse") return LOSS_MSE; else if (str == "relmse") return LOSS_RELMSE; else if (str == "smape") return LOSS_SMAPE; else return LOSS_L1; } torch::Tensor image_loss_fwd(torch::Tensor img, torch::Tensor target, std::string loss, std::string tonemapper, bool fp16) { CHECK_TENSOR(img, 4, 3); CHECK_TENSOR(target, 4, 3); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // Extract input parameters. LossKernelParams p; p.out.fp16 = fp16; p.loss = strToLoss(loss); p.tonemapper = tonemapper == "log_srgb" ? TONEMAPPER_LOG_SRGB : TONEMAPPER_NONE; update_grid(p.gridSize, img, target); // Choose launch parameters. dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); dim3 warpSize = getWarpSize(blockSize); dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); // Allocate output tensors. torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); torch::Tensor out = torch::empty({ (p.gridSize.z - 1)/ warpSize.z + 1, (p.gridSize.y - 1) / warpSize.y + 1, (p.gridSize.x - 1) / warpSize.x + 1, 1 }, opts); p.img = make_cuda_tensor(img, p.gridSize); p.target = make_cuda_tensor(target, p.gridSize); p.out = make_cuda_tensor(out, p.gridSize); // Launch CUDA kernel. void* args[] = { &p }; NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)imgLossFwdKernel, gridSize, blockSize, args, 0, stream)); return out; } std::tuple image_loss_bwd(torch::Tensor img, torch::Tensor target, torch::Tensor grad, std::string loss, std::string tonemapper) { cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // Extract input parameters. 
LossKernelParams p; p.loss = strToLoss(loss); p.tonemapper = tonemapper == "log_srgb" ? TONEMAPPER_LOG_SRGB : TONEMAPPER_NONE; update_grid(p.gridSize, img, target); // Choose launch parameters. dim3 blockSize = getLaunchBlockSize(BLOCK_X, BLOCK_Y, p.gridSize); dim3 warpSize = getWarpSize(blockSize); dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); torch::Tensor img_grad, target_grad; p.img = make_cuda_tensor(img, p.gridSize, &img_grad); p.target = make_cuda_tensor(target, p.gridSize, &target_grad); p.out = make_cuda_tensor(grad, p.gridSize); // Launch CUDA kernel. void* args[] = { &p }; NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)imgLossBwdKernel, gridSize, blockSize, args, 0, stream)); return std::tuple(img_grad, target_grad); } //------------------------------------------------------------------------ // transform function torch::Tensor xfm_fwd(torch::Tensor points, torch::Tensor matrix, bool isPoints, bool fp16) { CHECK_TENSOR(points, 3, 3); CHECK_TENSOR(matrix, 3, 4); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // Extract input parameters. XfmKernelParams p; p.out.fp16 = fp16; p.isPoints = isPoints; p.gridSize.x = points.size(1); p.gridSize.y = 1; p.gridSize.z = std::max(matrix.size(0), points.size(0)); // Choose launch parameters. dim3 blockSize(BLOCK_X * BLOCK_Y, 1, 1); dim3 warpSize = getWarpSize(blockSize); dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); // Allocate output tensors. torch::TensorOptions opts = torch::TensorOptions().dtype(p.out.fp16 ? torch::kBFloat16 : torch::kFloat32).device(torch::kCUDA); torch::Tensor out = isPoints ? torch::empty({ matrix.size(0), points.size(1), 4 }, opts) : torch::empty({ matrix.size(0), points.size(1), 3 }, opts); p.points = make_cuda_tensor(points, p.gridSize); p.matrix = make_cuda_tensor(matrix, p.gridSize); p.out = make_cuda_tensor(out, p.gridSize); // Launch CUDA kernel. 
void* args[] = { &p }; NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)xfmPointsFwdKernel, gridSize, blockSize, args, 0, stream)); return out; } torch::Tensor xfm_bwd(torch::Tensor points, torch::Tensor matrix, torch::Tensor grad, bool isPoints) { cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // Extract input parameters. XfmKernelParams p; p.isPoints = isPoints; p.gridSize.x = points.size(1); p.gridSize.y = 1; p.gridSize.z = std::max(matrix.size(0), points.size(0)); // Choose launch parameters. dim3 blockSize(BLOCK_X * BLOCK_Y, 1, 1); dim3 warpSize = getWarpSize(blockSize); dim3 gridSize = getLaunchGridSize(blockSize, p.gridSize); torch::Tensor points_grad; p.points = make_cuda_tensor(points, p.gridSize, &points_grad); p.matrix = make_cuda_tensor(matrix, p.gridSize); p.out = make_cuda_tensor(grad, p.gridSize); // Launch CUDA kernel. void* args[] = { &p }; NVDR_CHECK_CUDA_ERROR(cudaLaunchKernel((const void*)xfmPointsBwdKernel, gridSize, blockSize, args, 0, stream)); return points_grad; } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("prepare_shading_normal_fwd", &prepare_shading_normal_fwd, "prepare_shading_normal_fwd"); m.def("prepare_shading_normal_bwd", &prepare_shading_normal_bwd, "prepare_shading_normal_bwd"); m.def("lambert_fwd", &lambert_fwd, "lambert_fwd"); m.def("lambert_bwd", &lambert_bwd, "lambert_bwd"); m.def("frostbite_fwd", &frostbite_fwd, "frostbite_fwd"); m.def("frostbite_bwd", &frostbite_bwd, "frostbite_bwd"); m.def("fresnel_shlick_fwd", &fresnel_shlick_fwd, "fresnel_shlick_fwd"); m.def("fresnel_shlick_bwd", &fresnel_shlick_bwd, "fresnel_shlick_bwd"); m.def("ndf_ggx_fwd", &ndf_ggx_fwd, "ndf_ggx_fwd"); m.def("ndf_ggx_bwd", &ndf_ggx_bwd, "ndf_ggx_bwd"); m.def("lambda_ggx_fwd", &lambda_ggx_fwd, "lambda_ggx_fwd"); m.def("lambda_ggx_bwd", &lambda_ggx_bwd, "lambda_ggx_bwd"); m.def("masking_smith_fwd", &masking_smith_fwd, "masking_smith_fwd"); m.def("masking_smith_bwd", &masking_smith_bwd, "masking_smith_bwd"); m.def("pbr_specular_fwd", 
&pbr_specular_fwd, "pbr_specular_fwd"); m.def("pbr_specular_bwd", &pbr_specular_bwd, "pbr_specular_bwd"); m.def("pbr_bsdf_fwd", &pbr_bsdf_fwd, "pbr_bsdf_fwd"); m.def("pbr_bsdf_bwd", &pbr_bsdf_bwd, "pbr_bsdf_bwd"); m.def("diffuse_cubemap_fwd", &diffuse_cubemap_fwd, "diffuse_cubemap_fwd"); m.def("diffuse_cubemap_bwd", &diffuse_cubemap_bwd, "diffuse_cubemap_bwd"); m.def("specular_bounds", &specular_bounds, "specular_bounds"); m.def("specular_cubemap_fwd", &specular_cubemap_fwd, "specular_cubemap_fwd"); m.def("specular_cubemap_bwd", &specular_cubemap_bwd, "specular_cubemap_bwd"); m.def("image_loss_fwd", &image_loss_fwd, "image_loss_fwd"); m.def("image_loss_bwd", &image_loss_bwd, "image_loss_bwd"); m.def("xfm_fwd", &xfm_fwd, "xfm_fwd"); m.def("xfm_bwd", &xfm_bwd, "xfm_bwd"); } ================================================ FILE: render/renderutils/c_src/vec3f.h ================================================ /* * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual * property and proprietary rights in and to this material, related * documentation and any modifications thereto. Any use, reproduction, * disclosure or distribution of this material and related documentation * without an express license agreement from NVIDIA CORPORATION or * its affiliates is strictly prohibited. 
*/

#pragma once

// vec3f: plain 3-component float vector. Host code sees only the data layout;
// constructors, operators and the forward/backward math helpers below are
// device-only (__CUDACC__).
struct vec3f
{
    float x, y, z;

#ifdef __CUDACC__
    __device__ vec3f() { }
    __device__ vec3f(float v) { x = v; y = v; z = v; }
    __device__ vec3f(float _x, float _y, float _z) { x = _x; y = _y; z = _z; }
    __device__ vec3f(float3 v) { x = v.x; y = v.y; z = v.z; }

    // Component-wise compound assignment.
    __device__ inline vec3f& operator+=(const vec3f& b) { x += b.x; y += b.y; z += b.z; return *this; }
    __device__ inline vec3f& operator-=(const vec3f& b) { x -= b.x; y -= b.y; z -= b.z; return *this; }
    __device__ inline vec3f& operator*=(const vec3f& b) { x *= b.x; y *= b.y; z *= b.z; return *this; }
    __device__ inline vec3f& operator/=(const vec3f& b) { x /= b.x; y /= b.y; z /= b.z; return *this; }
#endif
};

#ifdef __CUDACC__
// Component-wise binary operators and unary negation.
__device__ static inline vec3f operator+(const vec3f& a, const vec3f& b) { return vec3f(a.x + b.x, a.y + b.y, a.z + b.z); }
__device__ static inline vec3f operator-(const vec3f& a, const vec3f& b) { return vec3f(a.x - b.x, a.y - b.y, a.z - b.z); }
__device__ static inline vec3f operator*(const vec3f& a, const vec3f& b) { return vec3f(a.x * b.x, a.y * b.y, a.z * b.z); }
__device__ static inline vec3f operator/(const vec3f& a, const vec3f& b) { return vec3f(a.x / b.x, a.y / b.y, a.z / b.z); }
__device__ static inline vec3f operator-(const vec3f& a) { return vec3f(-a.x, -a.y, -a.z); }

// Sum of the three components.
__device__ static inline float sum(vec3f a)
{
    return a.x + a.y + a.z;
}

// Cross product a x b.
__device__ static inline vec3f cross(vec3f a, vec3f b)
{
    vec3f out;
    out.x = a.y * b.z - a.z * b.y;
    out.y = a.z * b.x - a.x * b.z;
    out.z = a.x * b.y - a.y * b.x;
    return out;
}

// Backward of cross(): accumulates (+=) gradients into d_a and d_b.
__device__ static inline void bwdCross(vec3f a, vec3f b, vec3f &d_a, vec3f &d_b, vec3f d_out)
{
    d_a.x += d_out.z * b.y - d_out.y * b.z;
    d_a.y += d_out.x * b.z - d_out.z * b.x;
    d_a.z += d_out.y * b.x - d_out.x * b.y;
    d_b.x += d_out.y * a.z - d_out.z * a.y;
    d_b.y += d_out.z * a.x - d_out.x * a.z;
    d_b.z += d_out.x * a.y - d_out.y * a.x;
}

// Dot product.
__device__ static inline float dot(vec3f a, vec3f b)
{
    return a.x * b.x + a.y * b.y + a.z * b.z;
}

// Backward of dot(): accumulates gradients into d_a and d_b.
__device__ static inline
void bwdDot(vec3f a, vec3f b, vec3f& d_a, vec3f& d_b, float d_out)
{
    d_a.x += d_out * b.x; d_a.y += d_out * b.y; d_a.z += d_out * b.z;
    d_b.x += d_out * a.x; d_b.y += d_out * a.y; d_b.z += d_out * a.z;
}

// Reflection of x about n: n * 2 * dot(n, x) - x.
// (Presumably n is expected to be unit length -- confirm at call sites.)
__device__ static inline vec3f reflect(vec3f x, vec3f n)
{
    return n * 2.0f * dot(n, x) - x;
}

// Backward of reflect(): closed-form partial derivatives, accumulated into d_x and d_n.
__device__ static inline void bwdReflect(vec3f x, vec3f n, vec3f& d_x, vec3f& d_n, const vec3f d_out)
{
    d_x.x += d_out.x * (2 * n.x * n.x - 1) + d_out.y * (2 * n.x * n.y) + d_out.z * (2 * n.x * n.z);
    d_x.y += d_out.x * (2 * n.x * n.y) + d_out.y * (2 * n.y * n.y - 1) + d_out.z * (2 * n.y * n.z);
    d_x.z += d_out.x * (2 * n.x * n.z) + d_out.y * (2 * n.y * n.z) + d_out.z * (2 * n.z * n.z - 1);
    d_n.x += d_out.x * (2 * (2 * n.x * x.x + n.y * x.y + n.z * x.z)) + d_out.y * (2 * n.y * x.x) + d_out.z * (2 * n.z * x.x);
    d_n.y += d_out.x * (2 * n.x * x.y) + d_out.y * (2 * (n.x * x.x + 2 * n.y * x.y + n.z * x.z)) + d_out.z * (2 * n.z * x.y);
    d_n.z += d_out.x * (2 * n.x * x.z) + d_out.y * (2 * n.y * x.z) + d_out.z * (2 * (n.x * x.x + n.y * x.y + 2 * n.z * x.z));
}

// Normalize v; returns the zero vector when |v| == 0 instead of dividing by zero.
__device__ static inline vec3f safeNormalize(vec3f v)
{
    float l = sqrtf(v.x * v.x + v.y * v.y + v.z * v.z);
    return l > 0.0f ? (v / l) : vec3f(0.0f);
}

// Backward of safeNormalize(): zero gradient when |v| == 0, matching the forward clamp.
__device__ static inline void bwdSafeNormalize(const vec3f v, vec3f& d_v, const vec3f d_out)
{
    float l = sqrtf(v.x * v.x + v.y * v.y + v.z * v.z);
    if (l > 0.0f)
    {
        // NOTE(review): 1.0 is a double literal, so this divide runs in double
        // precision before narrowing to float -- consider 1.0f; left as-is here.
        float fac = 1.0 / powf(v.x * v.x + v.y * v.y + v.z * v.z, 1.5f);
        d_v.x += (d_out.x * (v.y * v.y + v.z * v.z) - d_out.y * (v.x * v.y) - d_out.z * (v.x * v.z)) * fac;
        d_v.y += (d_out.y * (v.x * v.x + v.z * v.z) - d_out.x * (v.y * v.x) - d_out.z * (v.y * v.z)) * fac;
        d_v.z += (d_out.z * (v.x * v.x + v.y * v.y) - d_out.x * (v.z * v.x) - d_out.y * (v.z * v.y)) * fac;
    }
}
#endif

================================================ FILE: render/renderutils/c_src/vec4f.h ================================================
/*
 * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual * property and proprietary rights in and to this material, related * documentation and any modifications thereto. Any use, reproduction, * disclosure or distribution of this material and related documentation * without an express license agreement from NVIDIA CORPORATION or * its affiliates is strictly prohibited. */ #pragma once struct vec4f { float x, y, z, w; #ifdef __CUDACC__ __device__ vec4f() { } __device__ vec4f(float v) { x = v; y = v; z = v; w = v; } __device__ vec4f(float _x, float _y, float _z, float _w) { x = _x; y = _y; z = _z; w = _w; } __device__ vec4f(float4 v) { x = v.x; y = v.y; z = v.z; w = v.w; } #endif }; ================================================ FILE: render/renderutils/loss.py ================================================ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual # property and proprietary rights in and to this material, related # documentation and any modifications thereto. Any use, reproduction, # disclosure or distribution of this material and related documentation # without an express license agreement from NVIDIA CORPORATION or # its affiliates is strictly prohibited. 
import torch #---------------------------------------------------------------------------- # HDR image losses #---------------------------------------------------------------------------- def _tonemap_srgb(f, exposure=5): f = f * exposure return torch.where(f > 0.0031308, torch.pow(torch.clamp(f, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*f) def _SMAPE(img, target, eps=0.01): nom = torch.abs(img - target) denom = torch.abs(img) + torch.abs(target) + 0.01 return torch.mean(nom / denom) def _RELMSE(img, target, eps=0.1): nom = (img - target) * (img - target) denom = img * img + target * target + 0.1 return torch.mean(nom / denom) def image_loss_fn(img, target, loss, tonemapper): if tonemapper == 'log_srgb': img = _tonemap_srgb(torch.log(torch.clamp(img, min=0, max=65535) + 1)) target = _tonemap_srgb(torch.log(torch.clamp(target, min=0, max=65535) + 1)) if loss == 'mse': return torch.nn.functional.mse_loss(img, target) elif loss == 'smape': return _SMAPE(img, target) elif loss == 'relmse': return _RELMSE(img, target) else: return torch.nn.functional.l1_loss(img, target) ================================================ FILE: render/renderutils/ops.py ================================================ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual # property and proprietary rights in and to this material, related # documentation and any modifications thereto. Any use, reproduction, # disclosure or distribution of this material and related documentation # without an express license agreement from NVIDIA CORPORATION or # its affiliates is strictly prohibited. import numpy as np import os import sys import torch import torch.utils.cpp_extension from .bsdf import * from .loss import * #---------------------------------------------------------------------------- # C++/Cuda plugin compiler/loader. 
_cached_plugin = None def _get_plugin(): # Return cached plugin if already loaded. global _cached_plugin if _cached_plugin is not None: return _cached_plugin # Make sure we can find the necessary compiler and libary binaries. if os.name == 'nt': def find_cl_path(): import glob for edition in ['Enterprise', 'Professional', 'BuildTools', 'Community']: paths = sorted(glob.glob(r"C:\Program Files (x86)\Microsoft Visual Studio\*\%s\VC\Tools\MSVC\*\bin\Hostx64\x64" % edition), reverse=True) if paths: return paths[0] # If cl.exe is not on path, try to find it. if os.system("where cl.exe >nul 2>nul") != 0: cl_path = find_cl_path() if cl_path is None: raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") os.environ['PATH'] += ';' + cl_path # Compiler options. opts = ['-DNVDR_TORCH'] # Linker options. if os.name == 'posix': ldflags = ['-lcuda', '-lnvrtc'] elif os.name == 'nt': ldflags = ['cuda.lib', 'advapi32.lib', 'nvrtc.lib'] # List of sources. source_files = [ 'c_src/mesh.cu', 'c_src/loss.cu', 'c_src/bsdf.cu', 'c_src/normal.cu', 'c_src/cubemap.cu', 'c_src/common.cpp', 'c_src/torch_bindings.cpp' ] # Some containers set this to contain old architectures that won't compile. We only need the one installed in the machine. os.environ['TORCH_CUDA_ARCH_LIST'] = '' # Try to detect if a stray lock file is left in cache directory and show a warning. This sometimes happens on Windows if the build is interrupted at just the right moment. try: lock_fn = os.path.join(torch.utils.cpp_extension._get_build_directory('renderutils_plugin', False), 'lock') if os.path.exists(lock_fn): print("Warning: Lock file exists in build directory: '%s'" % lock_fn) except: pass # Compile and load. build_dir = os.path.join(os. path. 
dirname(__file__), 'build') os.makedirs(build_dir, exist_ok=True) source_paths = [os.path.join(os.path.dirname(__file__), fn) for fn in source_files] torch.utils.cpp_extension.load(name='renderutils_plugin', sources=source_paths, extra_cflags=opts, build_directory=build_dir, extra_cuda_cflags=opts, extra_ldflags=ldflags, with_cuda=True, verbose=True) # Import, cache, and return the compiled module. import renderutils_plugin _cached_plugin = renderutils_plugin return _cached_plugin #---------------------------------------------------------------------------- # Internal kernels, just used for testing functionality class _fresnel_shlick_func(torch.autograd.Function): @staticmethod def forward(ctx, f0, f90, cosTheta): out = _get_plugin().fresnel_shlick_fwd(f0, f90, cosTheta, False) ctx.save_for_backward(f0, f90, cosTheta) return out @staticmethod def backward(ctx, dout): f0, f90, cosTheta = ctx.saved_variables return _get_plugin().fresnel_shlick_bwd(f0, f90, cosTheta, dout) + (None,) def _fresnel_shlick(f0, f90, cosTheta, use_python=False): if use_python: out = bsdf_fresnel_shlick(f0, f90, cosTheta) else: out = _fresnel_shlick_func.apply(f0, f90, cosTheta) if torch.is_anomaly_enabled(): assert torch.all(torch.isfinite(out)), "Output of _fresnel_shlick contains inf or NaN" return out class _ndf_ggx_func(torch.autograd.Function): @staticmethod def forward(ctx, alphaSqr, cosTheta): out = _get_plugin().ndf_ggx_fwd(alphaSqr, cosTheta, False) ctx.save_for_backward(alphaSqr, cosTheta) return out @staticmethod def backward(ctx, dout): alphaSqr, cosTheta = ctx.saved_variables return _get_plugin().ndf_ggx_bwd(alphaSqr, cosTheta, dout) + (None,) def _ndf_ggx(alphaSqr, cosTheta, use_python=False): if use_python: out = bsdf_ndf_ggx(alphaSqr, cosTheta) else: out = _ndf_ggx_func.apply(alphaSqr, cosTheta) if torch.is_anomaly_enabled(): assert torch.all(torch.isfinite(out)), "Output of _ndf_ggx contains inf or NaN" return out class _lambda_ggx_func(torch.autograd.Function): 
@staticmethod def forward(ctx, alphaSqr, cosTheta): out = _get_plugin().lambda_ggx_fwd(alphaSqr, cosTheta, False) ctx.save_for_backward(alphaSqr, cosTheta) return out @staticmethod def backward(ctx, dout): alphaSqr, cosTheta = ctx.saved_variables return _get_plugin().lambda_ggx_bwd(alphaSqr, cosTheta, dout) + (None,) def _lambda_ggx(alphaSqr, cosTheta, use_python=False): if use_python: out = bsdf_lambda_ggx(alphaSqr, cosTheta) else: out = _lambda_ggx_func.apply(alphaSqr, cosTheta) if torch.is_anomaly_enabled(): assert torch.all(torch.isfinite(out)), "Output of _lambda_ggx contains inf or NaN" return out class _masking_smith_func(torch.autograd.Function): @staticmethod def forward(ctx, alphaSqr, cosThetaI, cosThetaO): ctx.save_for_backward(alphaSqr, cosThetaI, cosThetaO) out = _get_plugin().masking_smith_fwd(alphaSqr, cosThetaI, cosThetaO, False) return out @staticmethod def backward(ctx, dout): alphaSqr, cosThetaI, cosThetaO = ctx.saved_variables return _get_plugin().masking_smith_bwd(alphaSqr, cosThetaI, cosThetaO, dout) + (None,) def _masking_smith(alphaSqr, cosThetaI, cosThetaO, use_python=False): if use_python: out = bsdf_masking_smith_ggx_correlated(alphaSqr, cosThetaI, cosThetaO) else: out = _masking_smith_func.apply(alphaSqr, cosThetaI, cosThetaO) if torch.is_anomaly_enabled(): assert torch.all(torch.isfinite(out)), "Output of _masking_smith contains inf or NaN" return out #---------------------------------------------------------------------------- # Shading normal setup (bump mapping + bent normals) class _prepare_shading_normal_func(torch.autograd.Function): @staticmethod def forward(ctx, pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl): ctx.two_sided_shading, ctx.opengl = two_sided_shading, opengl out = _get_plugin().prepare_shading_normal_fwd(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl, False) ctx.save_for_backward(pos, view_pos, perturbed_nrm, smooth_nrm, 
smooth_tng, geom_nrm) return out @staticmethod def backward(ctx, dout): pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm = ctx.saved_variables return _get_plugin().prepare_shading_normal_bwd(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, dout, ctx.two_sided_shading, ctx.opengl) + (None, None, None) def prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading=True, opengl=True, use_python=False): '''Takes care of all corner cases and produces a final normal used for shading: - Constructs tangent space - Flips normal direction based on geometric normal for two sided Shading - Perturbs shading normal by normal map - Bends backfacing normals towards the camera to avoid shading artifacts All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent. Args: pos: World space g-buffer position. view_pos: Camera position in world space (typically using broadcasting). perturbed_nrm: Trangent-space normal perturbation from normal map lookup. smooth_nrm: Interpolated vertex normals. smooth_tng: Interpolated vertex tangents. geom_nrm: Geometric (face) normals. two_sided_shading: Use one/two sided shading opengl: Use OpenGL/DirectX normal map conventions use_python: Use PyTorch implementation (for validation) Returns: Final shading normal ''' if perturbed_nrm is None: perturbed_nrm = torch.tensor([0, 0, 1], dtype=torch.float32, device='cuda', requires_grad=False)[None, None, None, ...] 
if use_python: out = bsdf_prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl) else: out = _prepare_shading_normal_func.apply(pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng, geom_nrm, two_sided_shading, opengl) if torch.is_anomaly_enabled(): assert torch.all(torch.isfinite(out)), "Output of prepare_shading_normal contains inf or NaN" return out #---------------------------------------------------------------------------- # BSDF functions class _lambert_func(torch.autograd.Function): @staticmethod def forward(ctx, nrm, wi): out = _get_plugin().lambert_fwd(nrm, wi, False) ctx.save_for_backward(nrm, wi) return out @staticmethod def backward(ctx, dout): nrm, wi = ctx.saved_variables return _get_plugin().lambert_bwd(nrm, wi, dout) + (None,) def lambert(nrm, wi, use_python=False): '''Lambertian bsdf. All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent. Args: nrm: World space shading normal. wi: World space light vector. use_python: Use PyTorch implementation (for validation) Returns: Shaded diffuse value with shape [minibatch_size, height, width, 1] ''' if use_python: out = bsdf_lambert(nrm, wi) else: out = _lambert_func.apply(nrm, wi) if torch.is_anomaly_enabled(): assert torch.all(torch.isfinite(out)), "Output of lambert contains inf or NaN" return out class _frostbite_diffuse_func(torch.autograd.Function): @staticmethod def forward(ctx, nrm, wi, wo, linearRoughness): out = _get_plugin().frostbite_fwd(nrm, wi, wo, linearRoughness, False) ctx.save_for_backward(nrm, wi, wo, linearRoughness) return out @staticmethod def backward(ctx, dout): nrm, wi, wo, linearRoughness = ctx.saved_variables return _get_plugin().frostbite_bwd(nrm, wi, wo, linearRoughness, dout) + (None,) def frostbite_diffuse(nrm, wi, wo, linearRoughness, use_python=False): '''Frostbite, normalized Disney Diffuse bsdf. 
All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent. Args: nrm: World space shading normal. wi: World space light vector. wo: World space camera vector. linearRoughness: Material roughness use_python: Use PyTorch implementation (for validation) Returns: Shaded diffuse value with shape [minibatch_size, height, width, 1] ''' if use_python: out = bsdf_frostbite(nrm, wi, wo, linearRoughness) else: out = _frostbite_diffuse_func.apply(nrm, wi, wo, linearRoughness) if torch.is_anomaly_enabled(): assert torch.all(torch.isfinite(out)), "Output of lambert contains inf or NaN" return out class _pbr_specular_func(torch.autograd.Function): @staticmethod def forward(ctx, col, nrm, wo, wi, alpha, min_roughness): ctx.save_for_backward(col, nrm, wo, wi, alpha) ctx.min_roughness = min_roughness out = _get_plugin().pbr_specular_fwd(col, nrm, wo, wi, alpha, min_roughness, False) return out @staticmethod def backward(ctx, dout): col, nrm, wo, wi, alpha = ctx.saved_variables return _get_plugin().pbr_specular_bwd(col, nrm, wo, wi, alpha, ctx.min_roughness, dout) + (None, None) def pbr_specular(col, nrm, wo, wi, alpha, min_roughness=0.08, use_python=False): '''Physically-based specular bsdf. All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent unless otherwise noted. Args: col: Specular lobe color nrm: World space shading normal. wo: World space camera vector. 
wi: World space light vector alpha: Specular roughness parameter with shape [minibatch_size, height, width, 1] min_roughness: Scalar roughness clamping threshold use_python: Use PyTorch implementation (for validation) Returns: Shaded specular color ''' if use_python: out = bsdf_pbr_specular(col, nrm, wo, wi, alpha, min_roughness=min_roughness) else: out = _pbr_specular_func.apply(col, nrm, wo, wi, alpha, min_roughness) if torch.is_anomaly_enabled(): assert torch.all(torch.isfinite(out)), "Output of pbr_specular contains inf or NaN" return out class _pbr_bsdf_func(torch.autograd.Function): @staticmethod def forward(ctx, kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF): ctx.save_for_backward(kd, arm, pos, nrm, view_pos, light_pos) ctx.min_roughness = min_roughness ctx.BSDF = BSDF out = _get_plugin().pbr_bsdf_fwd(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF, False) return out @staticmethod def backward(ctx, dout): kd, arm, pos, nrm, view_pos, light_pos = ctx.saved_variables return _get_plugin().pbr_bsdf_bwd(kd, arm, pos, nrm, view_pos, light_pos, ctx.min_roughness, ctx.BSDF, dout) + (None, None, None) def pbr_bsdf(kd, arm, pos, nrm, view_pos, light_pos, min_roughness=0.08, bsdf="lambert", use_python=False): '''Physically-based bsdf, both diffuse & specular lobes All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent unless otherwise noted. Args: kd: Diffuse albedo. arm: Specular parameters (attenuation, linear roughness, metalness). pos: World space position. nrm: World space shading normal. view_pos: Camera position in world space, typically using broadcasting. light_pos: Light position in world space, typically using broadcasting. min_roughness: Scalar roughness clamping threshold bsdf: Controls diffuse BSDF, can be either 'lambert' or 'frostbite' use_python: Use PyTorch implementation (for validation) Returns: Shaded color. 
''' BSDF = 0 if bsdf == 'frostbite': BSDF = 1 if use_python: out = bsdf_pbr(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF) else: out = _pbr_bsdf_func.apply(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF) if torch.is_anomaly_enabled(): assert torch.all(torch.isfinite(out)), "Output of pbr_bsdf contains inf or NaN" return out #---------------------------------------------------------------------------- # cubemap filter with filtering across edges class _diffuse_cubemap_func(torch.autograd.Function): @staticmethod def forward(ctx, cubemap): out = _get_plugin().diffuse_cubemap_fwd(cubemap) ctx.save_for_backward(cubemap) return out @staticmethod def backward(ctx, dout): cubemap, = ctx.saved_variables cubemap_grad = _get_plugin().diffuse_cubemap_bwd(cubemap, dout) return cubemap_grad, None def diffuse_cubemap(cubemap, use_python=False): if use_python: assert False else: out = _diffuse_cubemap_func.apply(cubemap) if torch.is_anomaly_enabled(): assert torch.all(torch.isfinite(out)), "Output of diffuse_cubemap contains inf or NaN" return out class _specular_cubemap(torch.autograd.Function): @staticmethod def forward(ctx, cubemap, roughness, costheta_cutoff, bounds): out = _get_plugin().specular_cubemap_fwd(cubemap, bounds, roughness, costheta_cutoff) ctx.save_for_backward(cubemap, bounds) ctx.roughness, ctx.theta_cutoff = roughness, costheta_cutoff return out @staticmethod def backward(ctx, dout): cubemap, bounds = ctx.saved_variables cubemap_grad = _get_plugin().specular_cubemap_bwd(cubemap, bounds, dout, ctx.roughness, ctx.theta_cutoff) return cubemap_grad, None, None, None # Compute the bounds of the GGX NDF lobe to retain "cutoff" percent of the energy def __ndfBounds(res, roughness, cutoff): def ndfGGX(alphaSqr, costheta): costheta = np.clip(costheta, 0.0, 1.0) d = (costheta * alphaSqr - costheta) * costheta + 1.0 return alphaSqr / (d * d * np.pi) # Sample out cutoff angle nSamples = 1000000 costheta = np.cos(np.linspace(0, np.pi/2.0, 
nSamples)) D = np.cumsum(ndfGGX(roughness**4, costheta)) idx = np.argmax(D >= D[..., -1] * cutoff) # Brute force compute lookup table with bounds bounds = _get_plugin().specular_bounds(res, costheta[idx]) return costheta[idx], bounds __ndfBoundsDict = {} def specular_cubemap(cubemap, roughness, cutoff=0.99, use_python=False): assert cubemap.shape[0] == 6 and cubemap.shape[1] == cubemap.shape[2], "Bad shape for cubemap tensor: %s" % str(cubemap.shape) if use_python: assert False else: key = (cubemap.shape[1], roughness, cutoff) if key not in __ndfBoundsDict: __ndfBoundsDict[key] = __ndfBounds(*key) out = _specular_cubemap.apply(cubemap, roughness, *__ndfBoundsDict[key]) if torch.is_anomaly_enabled(): assert torch.all(torch.isfinite(out)), "Output of specular_cubemap contains inf or NaN" return out[..., 0:3] / out[..., 3:] #---------------------------------------------------------------------------- # Fast image loss function class _image_loss_func(torch.autograd.Function): @staticmethod def forward(ctx, img, target, loss, tonemapper): ctx.loss, ctx.tonemapper = loss, tonemapper ctx.save_for_backward(img, target) out = _get_plugin().image_loss_fwd(img, target, loss, tonemapper, False) return out @staticmethod def backward(ctx, dout): img, target = ctx.saved_variables return _get_plugin().image_loss_bwd(img, target, dout, ctx.loss, ctx.tonemapper) + (None, None, None) def image_loss(img, target, loss='l1', tonemapper='none', use_python=False): '''Compute HDR image loss. Combines tonemapping and loss into a single kernel for better perf. All tensors assume a shape of [minibatch_size, height, width, 3] or broadcastable equivalent unless otherwise noted. Args: img: Input image. target: Target (reference) image. loss: Type of loss. Valid options are ['l1', 'mse', 'smape', 'relmse'] tonemapper: Tonemapping operations. Valid options are ['none', 'log_srgb'] use_python: Use PyTorch implementation (for validation) Returns: Image space loss (scalar value). 
''' if use_python: out = image_loss_fn(img, target, loss, tonemapper) else: out = _image_loss_func.apply(img, target, loss, tonemapper) out = torch.sum(out) / (img.shape[0]*img.shape[1]*img.shape[2]) if torch.is_anomaly_enabled(): assert torch.all(torch.isfinite(out)), "Output of image_loss contains inf or NaN" return out #---------------------------------------------------------------------------- # Transform points function class _xfm_func(torch.autograd.Function): @staticmethod def forward(ctx, points, matrix, isPoints): ctx.save_for_backward(points, matrix) ctx.isPoints = isPoints return _get_plugin().xfm_fwd(points, matrix, isPoints, False) @staticmethod def backward(ctx, dout): points, matrix = ctx.saved_variables return (_get_plugin().xfm_bwd(points, matrix, dout, ctx.isPoints),) + (None, None, None) def xfm_points(points, matrix, use_python=False): '''Transform points. Args: points: Tensor containing 3D points with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3] matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4] use_python: Use PyTorch's torch.matmul (for validation) Returns: Transformed points in homogeneous 4D with shape [minibatch_size, num_vertices, 4]. ''' if use_python: out = torch.matmul(torch.nn.functional.pad(points, pad=(0,1), mode='constant', value=1.0), torch.transpose(matrix, 1, 2)) else: out = _xfm_func.apply(points, matrix, True) if torch.is_anomaly_enabled(): assert torch.all(torch.isfinite(out)), "Output of xfm_points contains inf or NaN" return out def xfm_vectors(vectors, matrix, use_python=False): '''Transform vectors. Args: vectors: Tensor containing 3D vectors with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3] matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4] use_python: Use PyTorch's torch.matmul (for validation) Returns: Transformed vectors in homogeneous 4D with shape [minibatch_size, num_vertices, 4]. 
''' if use_python: out = torch.matmul(torch.nn.functional.pad(vectors, pad=(0,1), mode='constant', value=0.0), torch.transpose(matrix, 1, 2))[..., 0:3].contiguous() else: out = _xfm_func.apply(vectors, matrix, False) if torch.is_anomaly_enabled(): assert torch.all(torch.isfinite(out)), "Output of xfm_vectors contains inf or NaN" return out ================================================ FILE: render/renderutils/tests/test_bsdf.py ================================================ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual # property and proprietary rights in and to this material, related # documentation and any modifications thereto. Any use, reproduction, # disclosure or distribution of this material and related documentation # without an express license agreement from NVIDIA CORPORATION or # its affiliates is strictly prohibited. import torch import os import sys sys.path.insert(0, os.path.join(sys.path[0], '../..')) import renderutils as ru RES = 4 DTYPE = torch.float32 def relative_loss(name, ref, cuda): ref = ref.float() cuda = cuda.float() print(name, torch.max(torch.abs(ref - cuda) / torch.abs(ref + 1e-7)).item()) def test_normal(): pos_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) pos_ref = pos_cuda.clone().detach().requires_grad_(True) view_pos_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) view_pos_ref = view_pos_cuda.clone().detach().requires_grad_(True) perturbed_nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) perturbed_nrm_ref = perturbed_nrm_cuda.clone().detach().requires_grad_(True) smooth_nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) smooth_nrm_ref = smooth_nrm_cuda.clone().detach().requires_grad_(True) smooth_tng_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', 
requires_grad=True) smooth_tng_ref = smooth_tng_cuda.clone().detach().requires_grad_(True) geom_nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) geom_nrm_ref = geom_nrm_cuda.clone().detach().requires_grad_(True) target = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda') ref = ru.prepare_shading_normal(pos_ref, view_pos_ref, perturbed_nrm_ref, smooth_nrm_ref, smooth_tng_ref, geom_nrm_ref, True, use_python=True) ref_loss = torch.nn.MSELoss()(ref, target) ref_loss.backward() cuda = ru.prepare_shading_normal(pos_cuda, view_pos_cuda, perturbed_nrm_cuda, smooth_nrm_cuda, smooth_tng_cuda, geom_nrm_cuda, True) cuda_loss = torch.nn.MSELoss()(cuda, target) cuda_loss.backward() print("-------------------------------------------------------------") print(" bent normal") print("-------------------------------------------------------------") relative_loss("res:", ref, cuda) relative_loss("pos:", pos_ref.grad, pos_cuda.grad) relative_loss("view_pos:", view_pos_ref.grad, view_pos_cuda.grad) relative_loss("perturbed_nrm:", perturbed_nrm_ref.grad, perturbed_nrm_cuda.grad) relative_loss("smooth_nrm:", smooth_nrm_ref.grad, smooth_nrm_cuda.grad) relative_loss("smooth_tng:", smooth_tng_ref.grad, smooth_tng_cuda.grad) relative_loss("geom_nrm:", geom_nrm_ref.grad, geom_nrm_cuda.grad) def test_schlick(): f0_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) f0_ref = f0_cuda.clone().detach().requires_grad_(True) f90_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) f90_ref = f90_cuda.clone().detach().requires_grad_(True) cosT_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) * 2.0 cosT_cuda = cosT_cuda.clone().detach().requires_grad_(True) cosT_ref = cosT_cuda.clone().detach().requires_grad_(True) target = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda') ref = ru._fresnel_shlick(f0_ref, f90_ref, cosT_ref, use_python=True) ref_loss = 
torch.nn.MSELoss()(ref, target) ref_loss.backward() cuda = ru._fresnel_shlick(f0_cuda, f90_cuda, cosT_cuda) cuda_loss = torch.nn.MSELoss()(cuda, target) cuda_loss.backward() print("-------------------------------------------------------------") print(" Fresnel shlick") print("-------------------------------------------------------------") relative_loss("res:", ref, cuda) relative_loss("f0:", f0_ref.grad, f0_cuda.grad) relative_loss("f90:", f90_ref.grad, f90_cuda.grad) relative_loss("cosT:", cosT_ref.grad, cosT_cuda.grad) def test_ndf_ggx(): alphaSqr_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) alphaSqr_cuda = alphaSqr_cuda.clone().detach().requires_grad_(True) alphaSqr_ref = alphaSqr_cuda.clone().detach().requires_grad_(True) cosT_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) * 3.0 - 1 cosT_cuda = cosT_cuda.clone().detach().requires_grad_(True) cosT_ref = cosT_cuda.clone().detach().requires_grad_(True) target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda') ref = ru._ndf_ggx(alphaSqr_ref, cosT_ref, use_python=True) ref_loss = torch.nn.MSELoss()(ref, target) ref_loss.backward() cuda = ru._ndf_ggx(alphaSqr_cuda, cosT_cuda) cuda_loss = torch.nn.MSELoss()(cuda, target) cuda_loss.backward() print("-------------------------------------------------------------") print(" Ndf GGX") print("-------------------------------------------------------------") relative_loss("res:", ref, cuda) relative_loss("alpha:", alphaSqr_ref.grad, alphaSqr_cuda.grad) relative_loss("cosT:", cosT_ref.grad, cosT_cuda.grad) def test_lambda_ggx(): alphaSqr_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) alphaSqr_ref = alphaSqr_cuda.clone().detach().requires_grad_(True) cosT_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) * 3.0 - 1 cosT_cuda = cosT_cuda.clone().detach().requires_grad_(True) cosT_ref = cosT_cuda.clone().detach().requires_grad_(True) target 
= torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda') ref = ru._lambda_ggx(alphaSqr_ref, cosT_ref, use_python=True) ref_loss = torch.nn.MSELoss()(ref, target) ref_loss.backward() cuda = ru._lambda_ggx(alphaSqr_cuda, cosT_cuda) cuda_loss = torch.nn.MSELoss()(cuda, target) cuda_loss.backward() print("-------------------------------------------------------------") print(" Lambda GGX") print("-------------------------------------------------------------") relative_loss("res:", ref, cuda) relative_loss("alpha:", alphaSqr_ref.grad, alphaSqr_cuda.grad) relative_loss("cosT:", cosT_ref.grad, cosT_cuda.grad) def test_masking_smith(): alphaSqr_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) alphaSqr_ref = alphaSqr_cuda.clone().detach().requires_grad_(True) cosThetaI_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) cosThetaI_ref = cosThetaI_cuda.clone().detach().requires_grad_(True) cosThetaO_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True) cosThetaO_ref = cosThetaO_cuda.clone().detach().requires_grad_(True) target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda') ref = ru._masking_smith(alphaSqr_ref, cosThetaI_ref, cosThetaO_ref, use_python=True) ref_loss = torch.nn.MSELoss()(ref, target) ref_loss.backward() cuda = ru._masking_smith(alphaSqr_cuda, cosThetaI_cuda, cosThetaO_cuda) cuda_loss = torch.nn.MSELoss()(cuda, target) cuda_loss.backward() print("-------------------------------------------------------------") print(" Smith masking term") print("-------------------------------------------------------------") relative_loss("res:", ref, cuda) relative_loss("alpha:", alphaSqr_ref.grad, alphaSqr_cuda.grad) relative_loss("cosThetaI:", cosThetaI_ref.grad, cosThetaI_cuda.grad) relative_loss("cosThetaO:", cosThetaO_ref.grad, cosThetaO_cuda.grad) def test_lambert(): normals_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) normals_ref = 
# --- tail of test_lambert(); its `def` line lies above this chunk, and the
# first line below is the continuation of the truncated `normals_ref = ...` ---
normals_cuda.clone().detach().requires_grad_(True)
wi_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
wi_ref = wi_cuda.clone().detach().requires_grad_(True)
target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda')

# Reference (pure python) path.
ref = ru.lambert(normals_ref, wi_ref, use_python=True)
ref_loss = torch.nn.MSELoss()(ref, target)
ref_loss.backward()

# CUDA kernel path.
cuda = ru.lambert(normals_cuda, wi_cuda)
cuda_loss = torch.nn.MSELoss()(cuda, target)
cuda_loss.backward()

print("-------------------------------------------------------------")
print(" Lambert")
print("-------------------------------------------------------------")
relative_loss("res:", ref, cuda)
relative_loss("nrm:", normals_ref.grad, normals_cuda.grad)
relative_loss("wi:", wi_ref.grad, wi_cuda.grad)

def test_frostbite():
    """Compare the CUDA frostbite_diffuse kernel against the python reference.

    Checks both forward outputs and gradients w.r.t. every input by
    backpropagating an MSE loss against a random target.
    """
    normals_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    normals_ref = normals_cuda.clone().detach().requires_grad_(True)
    wi_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    wi_ref = wi_cuda.clone().detach().requires_grad_(True)
    wo_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    wo_ref = wo_cuda.clone().detach().requires_grad_(True)
    rough_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)
    rough_ref = rough_cuda.clone().detach().requires_grad_(True)
    target = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda')

    # Reference (pure python) path.
    ref = ru.frostbite_diffuse(normals_ref, wi_ref, wo_ref, rough_ref, use_python=True)
    ref_loss = torch.nn.MSELoss()(ref, target)
    ref_loss.backward()

    # CUDA kernel path.
    cuda = ru.frostbite_diffuse(normals_cuda, wi_cuda, wo_cuda, rough_cuda)
    cuda_loss = torch.nn.MSELoss()(cuda, target)
    cuda_loss.backward()

    print("-------------------------------------------------------------")
    print(" Frostbite")
    print("-------------------------------------------------------------")
    relative_loss("res:", ref, cuda)
    relative_loss("nrm:", normals_ref.grad, normals_cuda.grad)
    relative_loss("wo:", wo_ref.grad, wo_cuda.grad)
    relative_loss("wi:", wi_ref.grad, wi_cuda.grad)
    relative_loss("rough:", rough_ref.grad, rough_cuda.grad)

def test_pbr_specular():
    """Compare the CUDA pbr_specular kernel against the python reference.

    Gradient comparisons are guarded by `is not None` because an input may
    receive no gradient from the specular term.
    """
    col_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    col_ref = col_cuda.clone().detach().requires_grad_(True)
    nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    nrm_ref = nrm_cuda.clone().detach().requires_grad_(True)
    wi_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    wi_ref = wi_cuda.clone().detach().requires_grad_(True)
    wo_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    wo_ref = wo_cuda.clone().detach().requires_grad_(True)
    alpha_cuda = torch.rand(1, RES, RES, 1, dtype=DTYPE, device='cuda', requires_grad=True)
    alpha_ref = alpha_cuda.clone().detach().requires_grad_(True)
    target = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda')

    # Reference (pure python) path.
    ref = ru.pbr_specular(col_ref, nrm_ref, wo_ref, wi_ref, alpha_ref, use_python=True)
    ref_loss = torch.nn.MSELoss()(ref, target)
    ref_loss.backward()

    # CUDA kernel path.
    cuda = ru.pbr_specular(col_cuda, nrm_cuda, wo_cuda, wi_cuda, alpha_cuda)
    cuda_loss = torch.nn.MSELoss()(cuda, target)
    cuda_loss.backward()

    print("-------------------------------------------------------------")
    print(" Pbr specular")
    print("-------------------------------------------------------------")
    relative_loss("res:", ref, cuda)
    if col_ref.grad is not None:
        relative_loss("col:", col_ref.grad, col_cuda.grad)
    if nrm_ref.grad is not None:
        relative_loss("nrm:", nrm_ref.grad, nrm_cuda.grad)
    if wi_ref.grad is not None:
        relative_loss("wi:", wi_ref.grad, wi_cuda.grad)
    if wo_ref.grad is not None:
        relative_loss("wo:", wo_ref.grad, wo_cuda.grad)
    if alpha_ref.grad is not None:
        relative_loss("alpha:", alpha_ref.grad, alpha_cuda.grad)

def test_pbr_bsdf(bsdf):
    """Compare the CUDA full pbr_bsdf against the python reference.

    `bsdf` selects the diffuse model ('lambert' or 'frostbite'); the
    interface of ru.pbr_bsdf takes kd, arm, position, normal, view and
    light tensors of shape (1, RES, RES, 3).
    """
    kd_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    kd_ref = kd_cuda.clone().detach().requires_grad_(True)
    arm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    arm_ref = arm_cuda.clone().detach().requires_grad_(True)
    pos_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    pos_ref = pos_cuda.clone().detach().requires_grad_(True)
    nrm_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    nrm_ref = nrm_cuda.clone().detach().requires_grad_(True)
    view_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    view_ref = view_cuda.clone().detach().requires_grad_(True)
    light_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    light_ref = light_cuda.clone().detach().requires_grad_(True)
    target = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda')

    # Reference (pure python) path.
    ref = ru.pbr_bsdf(kd_ref, arm_ref, pos_ref, nrm_ref, view_ref, light_ref, use_python=True, bsdf=bsdf)
    ref_loss = torch.nn.MSELoss()(ref, target)
    ref_loss.backward()

    # CUDA kernel path.
    cuda = ru.pbr_bsdf(kd_cuda, arm_cuda, pos_cuda, nrm_cuda, view_cuda, light_cuda, bsdf=bsdf)
    cuda_loss = torch.nn.MSELoss()(cuda, target)
    cuda_loss.backward()

    print("-------------------------------------------------------------")
    print(" Pbr BSDF")
    print("-------------------------------------------------------------")
    relative_loss("res:", ref, cuda)
    if kd_ref.grad is not None:
        relative_loss("kd:", kd_ref.grad, kd_cuda.grad)
    if arm_ref.grad is not None:
        relative_loss("arm:", arm_ref.grad, arm_cuda.grad)
    if pos_ref.grad is not None:
        relative_loss("pos:", pos_ref.grad, pos_cuda.grad)
    if nrm_ref.grad is not None:
        relative_loss("nrm:", nrm_ref.grad, nrm_cuda.grad)
    if view_ref.grad is not None:
        relative_loss("view:", view_ref.grad, view_cuda.grad)
    if light_ref.grad is not None:
        relative_loss("light:", light_ref.grad, light_cuda.grad)

# Run the full self-test suite (earlier tests are defined above this chunk).
test_normal()
test_schlick()
test_ndf_ggx()
test_lambda_ggx()
test_masking_smith()
test_lambert()
test_frostbite()
test_pbr_specular()
test_pbr_bsdf('lambert')
test_pbr_bsdf('frostbite') ================================================ FILE: render/renderutils/tests/test_loss.py ================================================ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual # property and proprietary rights in and to this material, related # documentation and any modifications thereto. Any use, reproduction, # disclosure or distribution of this material and related documentation # without an express license agreement from NVIDIA CORPORATION or # its affiliates is strictly prohibited. import torch import os import sys sys.path.insert(0, os.path.join(sys.path[0], '../..')) import renderutils as ru RES = 8 DTYPE = torch.float32 def tonemap_srgb(f): return torch.where(f > 0.0031308, torch.pow(torch.clamp(f, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*f) def l1(output, target): x = torch.clamp(output, min=0, max=65535) r = torch.clamp(target, min=0, max=65535) x = tonemap_srgb(torch.log(x + 1)) r = tonemap_srgb(torch.log(r + 1)) return torch.nn.functional.l1_loss(x,r) def relative_loss(name, ref, cuda): ref = ref.float() cuda = cuda.float() print(name, torch.max(torch.abs(ref - cuda) / torch.abs(ref + 1e-7)).item()) def test_loss(loss, tonemapper): img_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) img_ref = img_cuda.clone().detach().requires_grad_(True) target_cuda = torch.rand(1, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) target_ref = target_cuda.clone().detach().requires_grad_(True) ref_loss = ru.image_loss(img_ref, target_ref, loss=loss, tonemapper=tonemapper, use_python=True) ref_loss.backward() cuda_loss = ru.image_loss(img_cuda, target_cuda, loss=loss, tonemapper=tonemapper) cuda_loss.backward() print("-------------------------------------------------------------") print(" Loss: %s, %s" % (loss, tonemapper)) 
print("-------------------------------------------------------------") relative_loss("res:", ref_loss, cuda_loss) relative_loss("img:", img_ref.grad, img_cuda.grad) relative_loss("target:", target_ref.grad, target_cuda.grad) test_loss('l1', 'none') test_loss('l1', 'log_srgb') test_loss('mse', 'log_srgb') test_loss('smape', 'none') test_loss('relmse', 'none') test_loss('mse', 'none') ================================================ FILE: render/renderutils/tests/test_mesh.py ================================================ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual # property and proprietary rights in and to this material, related # documentation and any modifications thereto. Any use, reproduction, # disclosure or distribution of this material and related documentation # without an express license agreement from NVIDIA CORPORATION or # its affiliates is strictly prohibited. import torch import os import sys sys.path.insert(0, os.path.join(sys.path[0], '../..')) import renderutils as ru BATCH = 8 RES = 1024 DTYPE = torch.float32 torch.manual_seed(0) def tonemap_srgb(f): return torch.where(f > 0.0031308, torch.pow(torch.clamp(f, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*f) def l1(output, target): x = torch.clamp(output, min=0, max=65535) r = torch.clamp(target, min=0, max=65535) x = tonemap_srgb(torch.log(x + 1)) r = tonemap_srgb(torch.log(r + 1)) return torch.nn.functional.l1_loss(x,r) def relative_loss(name, ref, cuda): ref = ref.float() cuda = cuda.float() print(name, torch.max(torch.abs(ref - cuda) / torch.abs(ref)).item()) def test_xfm_points(): points_cuda = torch.rand(1, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) points_ref = points_cuda.clone().detach().requires_grad_(True) mtx_cuda = torch.rand(BATCH, 4, 4, dtype=DTYPE, device='cuda', requires_grad=False) mtx_ref = mtx_cuda.clone().detach().requires_grad_(True) target = 
torch.rand(BATCH, RES, 4, dtype=DTYPE, device='cuda', requires_grad=True) ref_out = ru.xfm_points(points_ref, mtx_ref, use_python=True) ref_loss = torch.nn.MSELoss()(ref_out, target) ref_loss.backward() cuda_out = ru.xfm_points(points_cuda, mtx_cuda) cuda_loss = torch.nn.MSELoss()(cuda_out, target) cuda_loss.backward() print("-------------------------------------------------------------") relative_loss("res:", ref_out, cuda_out) relative_loss("points:", points_ref.grad, points_cuda.grad) def test_xfm_vectors(): points_cuda = torch.rand(1, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True) points_ref = points_cuda.clone().detach().requires_grad_(True) points_cuda_p = points_cuda.clone().detach().requires_grad_(True) points_ref_p = points_cuda.clone().detach().requires_grad_(True) mtx_cuda = torch.rand(BATCH, 4, 4, dtype=DTYPE, device='cuda', requires_grad=False) mtx_ref = mtx_cuda.clone().detach().requires_grad_(True) target = torch.rand(BATCH, RES, 4, dtype=DTYPE, device='cuda', requires_grad=True) ref_out = ru.xfm_vectors(points_ref.contiguous(), mtx_ref, use_python=True) ref_loss = torch.nn.MSELoss()(ref_out, target[..., 0:3]) ref_loss.backward() cuda_out = ru.xfm_vectors(points_cuda.contiguous(), mtx_cuda) cuda_loss = torch.nn.MSELoss()(cuda_out, target[..., 0:3]) cuda_loss.backward() ref_out_p = ru.xfm_points(points_ref_p.contiguous(), mtx_ref, use_python=True) ref_loss_p = torch.nn.MSELoss()(ref_out_p, target) ref_loss_p.backward() cuda_out_p = ru.xfm_points(points_cuda_p.contiguous(), mtx_cuda) cuda_loss_p = torch.nn.MSELoss()(cuda_out_p, target) cuda_loss_p.backward() print("-------------------------------------------------------------") relative_loss("res:", ref_out, cuda_out) relative_loss("points:", points_ref.grad, points_cuda.grad) relative_loss("points_p:", points_ref_p.grad, points_cuda_p.grad) test_xfm_points() test_xfm_vectors() ================================================ FILE: render/renderutils/tests/test_perf.py 
================================================
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import torch

import os
import sys
sys.path.insert(0, os.path.join(sys.path[0], '../..'))
import renderutils as ru

DTYPE=torch.float32

def test_bsdf(BATCH, RES, ITR):
    """Benchmark ru.pbr_bsdf: python reference vs CUDA kernel.

    Runs ITR forward evaluations of each path on random (BATCH, RES, RES, 3)
    inputs and reports wall time measured with CUDA events.
    """
    kd_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    kd_ref = kd_cuda.clone().detach().requires_grad_(True)
    arm_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    arm_ref = arm_cuda.clone().detach().requires_grad_(True)
    pos_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    pos_ref = pos_cuda.clone().detach().requires_grad_(True)
    nrm_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    nrm_ref = nrm_cuda.clone().detach().requires_grad_(True)
    view_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    view_ref = view_cuda.clone().detach().requires_grad_(True)
    light_cuda = torch.rand(BATCH, RES, RES, 3, dtype=DTYPE, device='cuda', requires_grad=True)
    light_ref = light_cuda.clone().detach().requires_grad_(True)
    target = torch.rand(BATCH, RES, RES, 3, device='cuda')

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # Warm-up call so kernel compilation/caching is excluded from timing.
    ru.pbr_bsdf(kd_cuda, arm_cuda, pos_cuda, nrm_cuda, view_cuda, light_cuda)

    print("--- Testing: [%d, %d, %d] ---" % (BATCH, RES, RES))

    start.record()
    for i in range(ITR):
        ref = ru.pbr_bsdf(kd_ref, arm_ref, pos_ref, nrm_ref, view_ref, light_ref, use_python=True)
    end.record()
    torch.cuda.synchronize()
    print("Pbr BSDF python:", start.elapsed_time(end))

    start.record()
    for i in range(ITR):
        cuda = ru.pbr_bsdf(kd_cuda, arm_cuda, pos_cuda, nrm_cuda, view_cuda, light_cuda)
    end.record()
    torch.cuda.synchronize()
    print("Pbr BSDF cuda:", start.elapsed_time(end))

test_bsdf(1, 512, 1000)
test_bsdf(16, 512, 1000)
test_bsdf(1, 2048, 1000)



================================================
FILE: render/texture.py
================================================
# Copyright (c) 2020-2022, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import os
import numpy as np
import torch
import nvdiffrast.torch as dr

from . import util

######################################################################################
# Smooth pooling / mip computation with linear gradient upscaling
######################################################################################

class texture2d_mip(torch.autograd.Function):
    """Custom autograd op: 2x2 average-pool forward, with a backward pass
    that upsamples the incoming gradient bilinearly (via dr.texture)."""
    @staticmethod
    def forward(ctx, texture):
        return util.avg_pool_nhwc(texture, (2,2))

    @staticmethod
    def backward(ctx, dout):
        # Sample positions centered on the texels of the 2x-resolution grid.
        gy, gx = torch.meshgrid(torch.linspace(0.0 + 0.25 / dout.shape[1], 1.0 - 0.25 / dout.shape[1], dout.shape[1]*2, device="cuda"),
                                torch.linspace(0.0 + 0.25 / dout.shape[2], 1.0 - 0.25 / dout.shape[2], dout.shape[2]*2, device="cuda"))
        uv = torch.stack((gx, gy), dim=-1)
        # 0.25 distributes the pooled gradient over the four source texels.
        return dr.texture(dout * 0.25, uv[None, ...].contiguous(), filter_mode='linear', boundary_mode='clamp')

########################################################################################################
# Simple texture class.
# A texture can be either
# - A 3D tensor (using auto mipmaps)
# - A list of 3D tensors (full custom mip hierarchy)
########################################################################################################

class Texture2D:
    # Initializes a texture from image data.
    # Input can be constant value (1D array) or texture (3D array) or mip hierarchy (list of 3d arrays)
    def __init__(self, init, min_max=None):
        if isinstance(init, np.ndarray):
            init = torch.tensor(init, dtype=torch.float32, device='cuda')
        elif isinstance(init, list) and len(init) == 1:
            init = init[0]

        # Normalize storage to NHWC (or a list of NHWC mips).
        if isinstance(init, list) or len(init.shape) == 4:
            self.data = init
        elif len(init.shape) == 3:
            self.data = init[None, ...]
        else:
            self.data = init[None, None, None, :] # Convert constant to 1x1 tensor

        # Optional per-channel (min, max) clamp range used by clamp_().
        self.min_max = min_max

    # Filtered (trilinear) sample texture at a given location
    def sample(self, texc, texc_deriv, filter_mode='linear-mipmap-linear'):
        if isinstance(self.data, list):
            # Explicit mip hierarchy supplied by the caller.
            out = dr.texture(self.data[0], texc, texc_deriv, mip=self.data[1:], filter_mode=filter_mode)
        else:
            if self.data.shape[1] > 1 and self.data.shape[2] > 1:
                # Auto mipmaps: pool down to 1x1 with the custom autograd op.
                mips = [self.data]
                while mips[-1].shape[1] > 1 and mips[-1].shape[2] > 1:
                    mips += [texture2d_mip.apply(mips[-1])]
                out = dr.texture(mips[0], texc, texc_deriv, mip=mips[1:], filter_mode=filter_mode)
            else:
                # 1x1 texture: no mipmapping possible.
                out = dr.texture(self.data, texc, texc_deriv, filter_mode=filter_mode)
        return out

    def getRes(self):
        # (height, width) of the base mip.
        return self.getMips()[0].shape[1:3]

    def getChannels(self):
        # Channel count of the base mip.
        return self.getMips()[0].shape[3]

    def getMips(self):
        # Always returns a list of mips (length 1 for auto-mipmap textures).
        if isinstance(self.data, list):
            return self.data
        else:
            return [self.data]

    # In-place clamp with no derivative to make sure values are in valid range after training
    def clamp_(self):
        if self.min_max is not None:
            for mip in self.getMips():
                for i in range(mip.shape[-1]):
                    mip[..., i].clamp_(min=self.min_max[0][i], max=self.min_max[1][i])

    # In-place normalization with no derivative, keeping each mip unit-length
    # per texel after training. (NOTE(review): the upstream comment said
    # "clamp", but the method below normalizes.)
    def normalize_(self):
        # Method of Texture2D (de-indented by the extraction): renormalize
        # every mip texel-wise without tracking gradients.
        with torch.no_grad():
            for mip in self.getMips():
                mip.copy_(util.safe_normalize(mip))

########################################################################################################
# Helper function to create a trainable texture from a regular texture. The trainable weights are
# initialized with texture data as an initial guess
########################################################################################################

def create_trainable(init, res=None, auto_mipmaps=True, min_max=None):
    """Build a Texture2D whose mips are leaf tensors with requires_grad=True.

    init: Texture2D, np.ndarray, or torch tensor (constant / HWC / NHWC).
    res: optional (H, W) to rescale to. auto_mipmaps: if False, build an
    explicit trainable mip chain down to 1x1.
    """
    with torch.no_grad():
        if isinstance(init, Texture2D):
            assert isinstance(init.data, torch.Tensor)
            min_max = init.min_max if min_max is None else min_max
            init = init.data
        elif isinstance(init, np.ndarray):
            init = torch.tensor(init, dtype=torch.float32, device='cuda')

        # Pad to NHWC if needed
        if len(init.shape) == 1: # Extend constant to NHWC tensor
            init = init[None, None, None, :]
        elif len(init.shape) == 3:
            init = init[None, ...]

        # Scale input to desired resolution.
        if res is not None:
            init = util.scale_img_nhwc(init, res)

        # Generate custom mipchain
        if not auto_mipmaps:
            mip_chain = [init.clone().detach().requires_grad_(True)]
            while mip_chain[-1].shape[1] > 1 or mip_chain[-1].shape[2] > 1:
                new_size = [max(mip_chain[-1].shape[1] // 2, 1), max(mip_chain[-1].shape[2] // 2, 1)]
                init = util.scale_img_nhwc(mip_chain[-1], new_size)
                mip_chain += [init.clone().detach().requires_grad_(True)]
            return Texture2D(mip_chain, min_max=min_max)
        else:
            return Texture2D(init.clone().detach().requires_grad_(True), min_max=min_max)

########################################################################################################
# Convert texture to and from SRGB
########################################################################################################

def srgb_to_rgb(texture):
    # Apply the sRGB->linear transform to every mip.
    return Texture2D(list(util.srgb_to_rgb(mip) for mip in texture.getMips()))

def rgb_to_srgb(texture):
    # Apply the linear->sRGB transform to every mip.
    return Texture2D(list(util.rgb_to_srgb(mip) for mip in texture.getMips()))

########################################################################################################
# Utility functions for loading / storing a texture
########################################################################################################

def _load_mip2D(fn, lambda_fn=None, channels=None):
    # Load a single mip level from file; optionally truncate channels and
    # post-process with lambda_fn (e.g. an sRGB conversion).
    imgdata = torch.tensor(util.load_image(fn), dtype=torch.float32, device='cuda')
    if channels is not None:
        imgdata = imgdata[..., 0:channels]
    if lambda_fn is not None:
        imgdata = lambda_fn(imgdata)
    return imgdata.detach().clone()

def load_texture2D(fn, lambda_fn=None, channels=None):
    # If "<base>_0<ext>" exists, load a full custom mip chain
    # ("<base>_0", "<base>_1", ...); otherwise load a single image.
    base, ext = os.path.splitext(fn)
    if os.path.exists(base + "_0" + ext):
        mips = []
        while os.path.exists(base + ("_%d" % len(mips)) + ext):
            mips += [_load_mip2D(base + ("_%d" % len(mips)) + ext, lambda_fn, channels)]
        return Texture2D(mips)
    else:
        return Texture2D(_load_mip2D(fn, lambda_fn, channels))

def _save_mip2D(fn, mip, mipidx, lambda_fn):
    if lambda_fn is not None:
        data =
lambda_fn(mip).detach().cpu().numpy()
    else:
        data = mip.detach().cpu().numpy()

    # No mip index -> plain filename; otherwise suffix "_<idx>" before the ext.
    if mipidx is None:
        util.save_image(fn, data)
    else:
        base, ext = os.path.splitext(fn)
        util.save_image(base + ("_%d" % mipidx) + ext, data)

def save_texture2D(fn, tex, lambda_fn=None):
    # Saves each mip separately for custom hierarchies, one file otherwise.
    if isinstance(tex.data, list):
        for i, mip in enumerate(tex.data):
            _save_mip2D(fn, mip[0,...], i, lambda_fn)
    else:
        _save_mip2D(fn, tex.data[0,...], None, lambda_fn)



================================================
FILE: render/util.py
================================================
# Copyright (c) 2020-2022, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import os
import numpy as np
import torch
import nvdiffrast.torch as dr
import imageio

#----------------------------------------------------------------------------
# Vector operations
#----------------------------------------------------------------------------

def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Last-axis dot product, keeping the reduced dimension.
    return torch.sum(x*y, -1, keepdim=True)

def reflect(x: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
    # Reflect x about normal n.
    return 2*dot(x, n)*n - x

def length(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor:
    return torch.sqrt(torch.clamp(dot(x,x), min=eps)) # Clamp to avoid nan gradients because grad(sqrt(0)) = NaN

def safe_normalize(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor:
    return x / length(x, eps)

def to_hvec(x: torch.Tensor, w: float) -> torch.Tensor:
    # Append homogeneous coordinate w to the last axis.
    return torch.nn.functional.pad(x, pad=(0,1), mode='constant', value=w)

def ycocg2rgb(ycocg):
    # YCoCg -> RGB color-space conversion (channels on the last axis).
    return torch.stack((
        ycocg[..., 0] + ycocg[..., 1] - ycocg[..., 2],
        ycocg[..., 0] + ycocg[..., 2],
        ycocg[..., 0] - ycocg[..., 1] - ycocg[..., 2]
    ), dim=-1)

def hsv2rgb(image):
    # Based on https://kornia.readthedocs.io/en/latest/_modules/kornia/color/hsv.html
    h, s, v = image[..., 0], image[..., 1], image[..., 2]
    hi = torch.floor(h * 6) % 6
    f = ((h * 6) % 6) - hi
    p = v * (1 - s)
    q = v * (1 - f * s)
    t = v * (1 - (1 - f) * s)

    # Gather the (r, g, b) triple for each pixel's hue sextant.
    hi = hi.long()
    indices = torch.stack([hi, hi + 6, hi + 12], dim=-1)
    out = torch.stack((v, q, p, p, t, v, t, v, v, q, p, p, p, p, t, v, v, q), dim=-1)
    out = torch.gather(out, -1, indices)
    return out

#----------------------------------------------------------------------------
# Pixel grid from resolution
#----------------------------------------------------------------------------

def pixel_grid(width, height, center_x = 0.5, center_y = 0.5):
    # Normalized (x, y) coordinates of each pixel center, shape (H, W, 2).
    y, x = torch.meshgrid(
        (torch.arange(0, height, dtype=torch.float32, device="cuda") + center_y) / height,
        (torch.arange(0, width, dtype=torch.float32, device="cuda") + center_x) / width)
    return torch.stack((x, y), dim=-1)

#----------------------------------------------------------------------------
# Dilation filter
#----------------------------------------------------------------------------

def dilate(x, x_avg, mask, N):
    def _gaussian():
        # NxN gaussian kernel, normalized to sum to 1.
        variance = (1.0 / 2.5)**2.
        grid_y, grid_x = torch.meshgrid(torch.linspace(-1, 1, N, dtype=torch.float32, device="cuda"), torch.linspace(-1, 1, N, dtype=torch.float32, device="cuda"))
        xy_grid = torch.stack([grid_x, grid_y], dim=-1)
        gaussian_kernel = (.5*np.pi*variance) * torch.exp(-torch.sum(xy_grid**2., dim=-1) / (2*variance))
        return gaussian_kernel / torch.sum(gaussian_kernel)

    def _w(c, cN):
        # Depthwise weight for channel c of cN: gaussian on the own channel,
        # zeros elsewhere.
        return torch.stack(list(_gaussian() if i == c else torch.zeros(N, N, dtype=torch.float32, device="cuda") for i in range(cN)), dim=0)

    epsilon = 1e-6
    weights = torch.stack(list(_w(i, x.shape[3]) for i in range(x.shape[3])), dim=0)
    # Filter the mask and the masked image; where filtered coverage is too
    # small fall back to x_avg, and keep original values inside the mask.
    mask_flt = torch.nn.functional.conv2d(mask.permute(0, 3, 1, 2), weights[0:1, 0:1, ...], padding=N//2).permute(0, 2, 3, 1)
    x_flt = torch.nn.functional.conv2d((x * mask).permute(0, 3, 1, 2), weights, padding=N//2).permute(0, 2, 3, 1)
    x_flt = torch.where(mask_flt > epsilon, x_flt / torch.clamp(mask_flt, min=epsilon), x_avg)
    return x_flt * (1 - mask) + x * mask

#----------------------------------------------------------------------------
# sRGB color transforms
#----------------------------------------------------------------------------

def _rgb_to_srgb(f: torch.Tensor) -> torch.Tensor:
    # Piecewise linear->sRGB transfer; clamp keeps pow() off tiny/negative input.
    return torch.where(f <= 0.0031308, f * 12.92, torch.pow(torch.clamp(f, 0.0031308), 1.0/2.4)*1.055 - 0.055)

def rgb_to_srgb(f: torch.Tensor) -> torch.Tensor:
    # Converts RGB channels only; a 4th (alpha) channel passes through.
    assert f.shape[-1] == 3 or f.shape[-1] == 4
    out = torch.cat((_rgb_to_srgb(f[..., 0:3]), f[..., 3:4]), dim=-1) if f.shape[-1] == 4 else _rgb_to_srgb(f)
    assert out.shape[0] == f.shape[0] and out.shape[1] == f.shape[1] and out.shape[2] == f.shape[2]
    return out

def _srgb_to_rgb(f: torch.Tensor) -> torch.Tensor:
    # Piecewise sRGB->linear transfer (inverse of _rgb_to_srgb).
    return torch.where(f <= 0.04045, f / 12.92, torch.pow((torch.clamp(f, 0.04045) + 0.055) / 1.055, 2.4))

def srgb_to_rgb(f: torch.Tensor) -> torch.Tensor:
    # Converts RGB channels only; a 4th (alpha) channel passes through.
    assert f.shape[-1] == 3 or f.shape[-1] == 4
    out = torch.cat((_srgb_to_rgb(f[..., 0:3]), f[..., 3:4]), dim=-1) if f.shape[-1] == 4 else _srgb_to_rgb(f)
    assert out.shape[0] == f.shape[0] and out.shape[1] == f.shape[1] and out.shape[2] == f.shape[2]
    return out

def reinhard(f: torch.Tensor) -> torch.Tensor:
    # Reinhard tonemapping operator.
    return f/(1+f)

#-----------------------------------------------------------------------------------
# Metrics (taken from jaxNerf source code, in order to replicate their measurements)
#
# https://github.com/google-research/google-research/blob/301451a62102b046bbeebff49a760ebeec9707b8/jaxnerf/nerf/utils.py#L266
#
#-----------------------------------------------------------------------------------

def mse_to_psnr(mse):
    """Compute PSNR given an MSE (we assume the maximum pixel value is 1)."""
    return -10. / np.log(10.) * np.log(mse)

def psnr_to_mse(psnr):
    """Compute MSE given a PSNR (we assume the maximum pixel value is 1)."""
    return np.exp(-0.1 * np.log(10.) * psnr)

#----------------------------------------------------------------------------
# Displacement texture lookup
#----------------------------------------------------------------------------

def get_miplevels(texture: np.ndarray) -> float:
    # Number of mip levels supported by the smaller texture dimension.
    minDim = min(texture.shape[0], texture.shape[1])
    return np.floor(np.log2(minDim))

def tex_2d(tex_map : torch.Tensor, coords : torch.Tensor, filter='nearest') -> torch.Tensor:
    # Sample an HWC texture at normalized [0,1] coords via grid_sample.
    tex_map = tex_map[None, ...]    # Add batch dimension
    tex_map = tex_map.permute(0, 3, 1, 2) # NHWC -> NCHW
    tex = torch.nn.functional.grid_sample(tex_map, coords[None, None, ...] * 2 - 1, mode=filter, align_corners=False)
    tex = tex.permute(0, 2, 3, 1) # NCHW -> NHWC
    return tex[0, 0, ...]
#----------------------------------------------------------------------------
# Cubemap utility functions
#----------------------------------------------------------------------------

def cube_to_dir(s, x, y):
    # Map face index s (0..5, +x,-x,+y,-y,+z,-z) and face-local (x, y) in
    # [-1, 1] to a direction vector.
    if s == 0:   rx, ry, rz = torch.ones_like(x), -y, -x
    elif s == 1: rx, ry, rz = -torch.ones_like(x), -y, x
    elif s == 2: rx, ry, rz = x, torch.ones_like(x), y
    elif s == 3: rx, ry, rz = x, -torch.ones_like(x), -y
    elif s == 4: rx, ry, rz = x, -y, torch.ones_like(x)
    elif s == 5: rx, ry, rz = -x, -y, -torch.ones_like(x)
    return torch.stack((rx, ry, rz), dim=-1)

def latlong_to_cubemap(latlong_map, res):
    """Resample an equirectangular (lat-long) environment map into a
    (6, res[0], res[1], C) cubemap via per-face direction lookups."""
    cubemap = torch.zeros(6, res[0], res[1], latlong_map.shape[-1], dtype=torch.float32, device='cuda')
    for s in range(6):
        # Texel-centered face coordinates in [-1, 1].
        gy, gx = torch.meshgrid(torch.linspace(-1.0 + 1.0 / res[0], 1.0 - 1.0 / res[0], res[0], device='cuda'),
                                torch.linspace(-1.0 + 1.0 / res[1], 1.0 - 1.0 / res[1], res[1], device='cuda'),
                                indexing='ij')
        v = safe_normalize(cube_to_dir(s, gx, gy))

        # Direction -> lat-long texture coordinates.
        tu = torch.atan2(v[..., 0:1], -v[..., 2:3]) / (2 * np.pi) + 0.5
        tv = torch.acos(torch.clamp(v[..., 1:2], min=-1, max=1)) / np.pi
        texcoord = torch.cat((tu, tv), dim=-1)

        cubemap[s, ...] = dr.texture(latlong_map[None, ...], texcoord[None, ...], filter_mode='linear')[0]
    return cubemap

def cubemap_to_latlong(cubemap, res):
    """Resample a cubemap back into an equirectangular (res[0], res[1], C)
    lat-long image using nvdiffrast's cube boundary mode."""
    gy, gx = torch.meshgrid(torch.linspace( 0.0 + 1.0 / res[0], 1.0 - 1.0 / res[0], res[0], device='cuda'),
                            torch.linspace(-1.0 + 1.0 / res[1], 1.0 - 1.0 / res[1], res[1], device='cuda'),
                            indexing='ij')

    sintheta, costheta = torch.sin(gy*np.pi), torch.cos(gy*np.pi)
    sinphi, cosphi     = torch.sin(gx*np.pi), torch.cos(gx*np.pi)

    reflvec = torch.stack((
        sintheta*sinphi,
        costheta,
        -sintheta*cosphi
        ), dim=-1)
    return dr.texture(cubemap[None, ...], reflvec[None, ...].contiguous(), filter_mode='linear', boundary_mode='cube')[0]

#----------------------------------------------------------------------------
# Image scaling
#----------------------------------------------------------------------------

def scale_img_hwc(x : torch.Tensor, size, mag='bilinear', min='area') -> torch.Tensor:
    # HWC convenience wrapper around scale_img_nhwc.
    return scale_img_nhwc(x[None, ...], size, mag, min)[0]

def scale_img_nhwc(x  : torch.Tensor, size, mag='bilinear', min='area') -> torch.Tensor:
    # Resize an NHWC image; `min` filter for minification, `mag` otherwise.
    assert (x.shape[1] >= size[0] and x.shape[2] >= size[1]) or (x.shape[1] < size[0] and x.shape[2] < size[1]), "Trying to magnify image in one dimension and minify in the other"
    y = x.permute(0, 3, 1, 2) # NHWC -> NCHW
    if x.shape[1] > size[0] and x.shape[2] > size[1]: # Minification, previous size was bigger
        y = torch.nn.functional.interpolate(y, size, mode=min)
    else: # Magnification
        if mag == 'bilinear' or mag == 'bicubic':
            y = torch.nn.functional.interpolate(y, size, mode=mag, align_corners=True)
        else:
            y = torch.nn.functional.interpolate(y, size, mode=mag)
    return y.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC

def avg_pool_nhwc(x  : torch.Tensor, size) -> torch.Tensor:
    # Average-pool an NHWC image with the given kernel size.
    y = x.permute(0, 3, 1, 2) # NHWC -> NCHW
    y = torch.nn.functional.avg_pool2d(y, size)
    return y.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC

#----------------------------------------------------------------------------
# Behaves similar to tf.segment_sum
#----------------------------------------------------------------------------

def segment_sum(data: torch.Tensor, segment_ids: torch.Tensor) -> torch.Tensor:
    # Sum rows of `data` grouped by `segment_ids` (ids assumed consecutive,
    # like tf.segment_sum).
    num_segments = torch.unique_consecutive(segment_ids).shape[0]

    # Repeats ids until same dimension as data
    if len(segment_ids.shape) == 1:
        s = torch.prod(torch.tensor(data.shape[1:], dtype=torch.int64, device='cuda')).long()
        segment_ids = segment_ids.repeat_interleave(s).view(segment_ids.shape[0], *data.shape[1:])

    assert data.shape == segment_ids.shape, "data.shape and segment_ids.shape should be equal"

    shape = [num_segments] + list(data.shape[1:])
    result = torch.zeros(*shape, dtype=torch.float32, device='cuda')
    result = result.scatter_add(0, segment_ids, data)
    return result

#----------------------------------------------------------------------------
# Matrix helpers.
#----------------------------------------------------------------------------

def fovx_to_fovy(fovx, aspect):
    # Convert horizontal FOV to vertical FOV given the aspect ratio.
    return np.arctan(np.tan(fovx / 2) / aspect) * 2.0

def focal_length_to_fovy(focal_length, sensor_height):
    return 2 * np.arctan(0.5 * sensor_height / focal_length)

# Reworked so this matches gluPerspective / glm::perspective, using fovy
def perspective(fovy=0.7854, aspect=1.0, n=0.1, f=1000.0, device=None):
    y = np.tan(fovy / 2)
    return torch.tensor([[1/(y*aspect),    0,            0,              0],
                         [           0, 1/-y,            0,              0],
                         [           0,    0, -(f+n)/(f-n), -(2*f*n)/(f-n)],
                         [           0,    0,           -1,              0]], dtype=torch.float32, device=device)

# Reworked so this matches gluPerspective / glm::perspective, using fovy
def perspective_offcenter(fovy, fraction, rx, ry, aspect=1.0, n=0.1, f=1000.0, device=None):
    # Off-center projection for a randomized sub-frustum (used for
    # jittered / cropped rendering).
    y = np.tan(fovy / 2)

    # Full frustum
    R, L = aspect*y, -aspect*y
    T, B = y, -y

    # Create a randomized sub-frustum
    width  = (R-L)*fraction
    height = (T-B)*fraction
    xstart = (R-L)*rx
    ystart = (T-B)*ry

    l = L + xstart
    r = l + width
    b = B + ystart
    t = b + height

    # https://www.scratchapixel.com/lessons/3d-basic-rendering/perspective-and-orthographic-projection-matrix/opengl-perspective-projection-matrix
    return torch.tensor([[2/(r-l),        0,  (r+l)/(r-l),              0],
                         [      0, -2/(t-b),  (t+b)/(t-b),              0],
                         [      0,        0, -(f+n)/(f-n), -(2*f*n)/(f-n)],
                         [      0,        0,           -1,              0]], dtype=torch.float32, device=device)

def translate(x, y, z, device=None):
    return torch.tensor([[1, 0, 0, x],
                         [0, 1, 0, y],
                         [0, 0, 1, z],
                         [0, 0, 0, 1]], dtype=torch.float32, device=device)

def rotate_x(a, device=None):
    s, c = np.sin(a), np.cos(a)
    return torch.tensor([[1, 0, 0, 0],
                         [0, c, s, 0],
                         [0, -s, c, 0],
                         [0, 0, 0, 1]], dtype=torch.float32, device=device)

def rotate_y(a, device=None):
    s, c = np.sin(a), np.cos(a)
    return torch.tensor([[ c, 0, s, 0],
                         [ 0, 1, 0, 0],
                         [-s, 0, c, 0],
                         [ 0, 0, 0, 1]], dtype=torch.float32, device=device)

def rotate_z(a, device=None):
    s, c = np.sin(a), np.cos(a)
    return torch.tensor([[ c, s, 0, 0],
                         [-s, c, 0, 0],
                         [ 0, 0, 1, 0],
                         [ 0, 0, 0, 1]], dtype=torch.float32, device=device)

def scale(s, device=None):
    return torch.tensor([[ s, 0, 0, 0],
                         [ 0, s, 0, 0],
                         [ 0, 0, s, 0],
                         [ 0, 0, 0, 1]], dtype=torch.float32, device=device)

def lookAt(eye, at, up):
    # Right-handed look-at view matrix built from an orthonormal basis.
    a = eye - at
    w = a / torch.linalg.norm(a)
    u = torch.cross(up, w)
    u = u / torch.linalg.norm(u)
    v = torch.cross(w, u)
    translate = torch.tensor([[1, 0, 0, -eye[0]],
                              [0, 1, 0, -eye[1]],
                              [0, 0, 1, -eye[2]],
                              [0, 0, 0, 1]], dtype=eye.dtype, device=eye.device)
    rotate = torch.tensor([[u[0], u[1], u[2], 0],
                           [v[0], v[1], v[2], 0],
                           [w[0], w[1], w[2], 0],
                           [0, 0, 0, 1]], dtype=eye.dtype, device=eye.device)
    return rotate @ translate

@torch.no_grad()
def random_rotation_translation(t, device=None):
    # Random orthonormal rotation plus a uniform translation in [-t, t]^3.
    m = np.random.normal(size=[3, 3])
    m[1] = np.cross(m[0], m[2])
    m[2] = np.cross(m[0], m[1])
    m = m / np.linalg.norm(m, axis=1, keepdims=True)
    m = np.pad(m, [[0, 1], [0, 1]], mode='constant')
    m[3, 3] = 1.0
    m[:3, 3] = np.random.uniform(-t, t, size=[3])
    return torch.tensor(m, dtype=torch.float32, device=device)

@torch.no_grad()
def random_rotation(device=None):
    # Random orthonormal rotation with zero translation.
    m = np.random.normal(size=[3, 3])
    m[1] = np.cross(m[0], m[2])
    m[2] = np.cross(m[0], m[1])
    m = m / np.linalg.norm(m, axis=1, keepdims=True)
    m = np.pad(m, [[0, 1], [0, 1]], mode='constant')
    m[3, 3] = 1.0
    m[:3, 3] = np.array([0,0,0]).astype(np.float32)
    return torch.tensor(m, dtype=torch.float32, device=device)

#----------------------------------------------------------------------------
# Compute focal points of a set of lines using least squares.
# handy for poorly centered datasets
#----------------------------------------------------------------------------

def lines_focal(o, d):
    # Least-squares closest point to a bundle of lines (origins o,
    # directions d).
    d = safe_normalize(d)
    I = torch.eye(3, dtype=o.dtype, device=o.device)
    S = torch.sum(d[..., None] @ torch.transpose(d[..., None], 1, 2) - I[None, ...], dim=0)
    C = torch.sum((d[..., None] @ torch.transpose(d[..., None], 1, 2) - I[None, ...]) @ o[..., None], dim=0).squeeze(1)
    return torch.linalg.pinv(S) @ C

#----------------------------------------------------------------------------
# Cosine sample around a vector N
#----------------------------------------------------------------------------
@torch.no_grad()
def cosine_sample(N, size=None):
    # construct local frame
    N = N/torch.linalg.norm(N)

    dx0 = torch.tensor([0, N[2], -N[1]], dtype=N.dtype, device=N.device)
    dx1 = torch.tensor([-N[2], 0, N[0]], dtype=N.dtype, device=N.device)

    dx = torch.where(dot(dx0, dx0) > dot(dx1, dx1), dx0, dx1)
    #dx = dx0 if np.dot(dx0,dx0) > np.dot(dx1,dx1) else dx1
    dx = dx / torch.linalg.norm(dx)
    dy = torch.cross(N,dx)
    dy = dy / torch.linalg.norm(dy)

    # cosine sampling in local frame
    if size is None:
        phi = 2.0 * np.pi * np.random.uniform()
        s = np.random.uniform()
    else:
        # Batched sampling: phi and s become torch tensors of shape (*size, 1).
        phi = 2.0 * np.pi * torch.rand(*size, 1, dtype=N.dtype, device=N.device)
        s = torch.rand(*size, 1, dtype=N.dtype, device=N.device)
    costheta = np.sqrt(s)
    sintheta = np.sqrt(1.0 - s)

    # cartesian vector in local space
    x = np.cos(phi)*sintheta
    y = np.sin(phi)*sintheta
    z = costheta

    # local to world
    return dx*x + dy*y + N*z

#----------------------------------------------------------------------------
# Bilinear downsample by 2x.
#----------------------------------------------------------------------------

def bilinear_downsample(x : torch.Tensor) -> torch.Tensor:
    """Single 2x bilinear downsample of an NHWC tensor (zero padding).

    NOTE(review): this definition is immediately shadowed by the (x, spp)
    overload below, so module-level lookups always resolve to that one;
    kept in place to preserve the file layout.
    """
    # Separable bilinear kernel ([1 3 3 1]/8 per axis), one group per channel.
    w = torch.tensor([[1, 3, 3, 1], [3, 9, 9, 3], [3, 9, 9, 3], [1, 3, 3, 1]], dtype=torch.float32, device=x.device) / 64.0
    w = w.expand(x.shape[-1], 1, 4, 4)
    x = torch.nn.functional.conv2d(x.permute(0, 3, 1, 2), w, padding=1, stride=2, groups=x.shape[-1])
    return x.permute(0, 2, 3, 1)

#----------------------------------------------------------------------------
# Bilinear downsample log(spp) steps
#----------------------------------------------------------------------------

def bilinear_downsample(x : torch.Tensor, spp) -> torch.Tensor:
    """Downsample an NHWC tensor by factor `spp` (a power of two).

    Applies log2(spp) successive 2x bilinear reductions with replicate
    padding; `spp=1` returns the input unchanged (as a contiguous tensor).
    """
    w = torch.tensor([[1, 3, 3, 1], [3, 9, 9, 3], [3, 9, 9, 3], [1, 3, 3, 1]], dtype=torch.float32, device=x.device) / 64.0
    g = x.shape[-1]
    w = w.expand(g, 1, 4, 4)
    x = x.permute(0, 3, 1, 2) # NHWC -> NCHW
    steps = int(np.log2(spp))
    for _ in range(steps):
        xp = torch.nn.functional.pad(x, (1,1,1,1), mode='replicate')
        x = torch.nn.functional.conv2d(xp, w, padding=0, stride=2, groups=g)
    return x.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC

#----------------------------------------------------------------------------
# Singleton initialize GLFW
#----------------------------------------------------------------------------

_glfw_initialized = False
def init_glfw():
    """Initialize GLFW exactly once (probe-and-init singleton).

    Creates a tiny hidden window; if GLFW raises NOT_INITIALIZED, calls
    glfw.init(). FIX: `import glfw` moved out of the try-block — before,
    a missing glfw package raised ImportError inside the try and the
    except clause then hit a NameError evaluating `glfw.GLFWError`,
    masking the real error.
    """
    global _glfw_initialized
    import glfw
    try:
        glfw.ERROR_REPORTING = 'raise'
        glfw.default_window_hints()
        glfw.window_hint(glfw.VISIBLE, glfw.FALSE)
        test = glfw.create_window(8, 8, "Test", None, None) # Create a window and see if not initialized yet
    except glfw.GLFWError as e:
        if e.error_code == glfw.NOT_INITIALIZED:
            glfw.init()
            _glfw_initialized = True

#----------------------------------------------------------------------------
# Image display function using OpenGL.
#---------------------------------------------------------------------------- _glfw_window = None def display_image(image, title=None): # Import OpenGL import OpenGL.GL as gl import glfw # Zoom image if requested. image = np.asarray(image[..., 0:3]) if image.shape[-1] == 4 else np.asarray(image) height, width, channels = image.shape # Initialize window. init_glfw() if title is None: title = 'Debug window' global _glfw_window if _glfw_window is None: glfw.default_window_hints() _glfw_window = glfw.create_window(width, height, title, None, None) glfw.make_context_current(_glfw_window) glfw.show_window(_glfw_window) glfw.swap_interval(0) else: glfw.make_context_current(_glfw_window) glfw.set_window_title(_glfw_window, title) glfw.set_window_size(_glfw_window, width, height) # Update window. glfw.poll_events() gl.glClearColor(0, 0, 0, 1) gl.glClear(gl.GL_COLOR_BUFFER_BIT) gl.glWindowPos2f(0, 0) gl.glPixelStorei(gl.GL_UNPACK_ALIGNMENT, 1) gl_format = {3: gl.GL_RGB, 2: gl.GL_RG, 1: gl.GL_LUMINANCE}[channels] gl_dtype = {'uint8': gl.GL_UNSIGNED_BYTE, 'float32': gl.GL_FLOAT}[image.dtype.name] gl.glDrawPixels(width, height, gl_format, gl_dtype, image[::-1]) glfw.swap_buffers(_glfw_window) if glfw.window_should_close(_glfw_window): return False return True #---------------------------------------------------------------------------- # Image save/load helper. 
#---------------------------------------------------------------------------- def save_image(fn, x : np.ndarray) -> np.ndarray: try: if os.path.splitext(fn)[1] == ".png": imageio.imwrite(fn, np.clip(np.rint(x * 255.0), 0, 255).astype(np.uint8), compress_level=3) # Low compression for faster saving else: imageio.imwrite(fn, np.clip(np.rint(x * 255.0), 0, 255).astype(np.uint8)) except: print("WARNING: FAILED to save image %s" % fn) def save_image_raw(fn, x : np.ndarray): try: imageio.imwrite(fn, x) except: print("WARNING: FAILED to save image %s" % fn) def load_image_raw(fn) -> np.ndarray: return imageio.imread(fn) def load_image(fn) -> np.ndarray: img = load_image_raw(fn) if img.dtype == np.float32: # HDR image return img else: # LDR image return img.astype(np.float32) / 255 #---------------------------------------------------------------------------- def time_to_text(x): if x > 3600: return "%.2f h" % (x / 3600) elif x > 60: return "%.2f m" % (x / 60) else: return "%.2f s" % x #---------------------------------------------------------------------------- def checkerboard(res, checker_size) -> np.ndarray: tiles_y = (res[0] + (checker_size*2) - 1) // (checker_size*2) tiles_x = (res[1] + (checker_size*2) - 1) // (checker_size*2) check = np.kron([[1, 0] * tiles_x, [0, 1] * tiles_x] * tiles_y, np.ones((checker_size, checker_size)))*0.33 + 0.33 check = check[:res[0], :res[1]] return np.stack((check, check, check), axis=-1) ================================================ FILE: train_gflexicubes_deepfashion.py ================================================ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual # property and proprietary rights in and to this material, related # documentation and any modifications thereto. 
# Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import os
import sys
import time
import argparse
import json

import numpy as np
import torch
import nvdiffrast.torch as dr
import xatlas

# Import data readers / generators
from dataset.dataset_deepfashion import DatasetDeepFashion
from dataset.dataset_deepfashion_testset import DatasetDeepFashionTestset

# Import topology / geometry trainers
from geometry.gshell_flexicubes_geometry import GShellFlexiCubesGeometry

import render.renderutils as ru
from render import obj
from render import material
from render import util
from render import mesh
from render import texture
from render import mlptexture
from render import light
from render import render

from denoiser.denoiser import BilateralDenoiser

RADIUS = 3.0

# Enable to debug back-prop anomalies
# torch.autograd.set_detect_anomaly(True)

###############################################################################
# Loss setup
###############################################################################

@torch.no_grad()
def createLoss(FLAGS):
    """Map FLAGS.loss to an image-loss closure (img, ref) -> scalar loss."""
    if FLAGS.loss == "smape":
        return lambda img, ref: ru.image_loss(img, ref, loss='smape', tonemapper='none')
    elif FLAGS.loss == "mse":
        return lambda img, ref: ru.image_loss(img, ref, loss='mse', tonemapper='none')
    elif FLAGS.loss == "logl1":
        return lambda img, ref: ru.image_loss(img, ref, loss='l1', tonemapper='log_srgb')
    elif FLAGS.loss == "logl2":
        return lambda img, ref: ru.image_loss(img, ref, loss='mse', tonemapper='log_srgb')
    elif FLAGS.loss == "relmse":
        return lambda img, ref: ru.image_loss(img, ref, loss='relmse', tonemapper='none')
    else:
        assert False

###############################################################################
# Mix background into a dataset image
###############################################################################

@torch.no_grad()
def prepare_batch(target, bg_type='black'):
    """Move a dataset batch to CUDA and composite the RGBA image over `bg_type`."""
    assert len(target['img'].shape) == 4, "Image shape should be [n, h, w, c]"
    if bg_type == 'checker':
        background = torch.tensor(util.checkerboard(target['img'].shape[1:3], 8), dtype=torch.float32, device='cuda')[None, ...]
    elif bg_type == 'black':
        background = torch.zeros(target['img'].shape[0:3] + (3,), dtype=torch.float32, device='cuda')
    elif bg_type == 'white':
        background = torch.ones(target['img'].shape[0:3] + (3,), dtype=torch.float32, device='cuda')
    elif bg_type == 'reference':
        background = target['img'][..., 0:3]
    elif bg_type == 'random':
        background = torch.rand(target['img'].shape[0:3] + (3,), dtype=torch.float32, device='cuda')
    else:
        assert False, "Unknown background type %s" % bg_type

    target['mv'] = target['mv'].cuda()
    target['mvp'] = target['mvp'].cuda()
    target['campos'] = target['campos'].cuda()
    target['img'] = target['img'].cuda()
    target['background'] = background

    # Alpha-composite reference RGB over the chosen background, keep alpha.
    target['img'] = torch.cat((torch.lerp(background, target['img'][..., 0:3], target['img'][..., 3:4]), target['img'][..., 3:4]), dim=-1)

    return target

###############################################################################
# UV - map geometry & convert to a mesh
###############################################################################

@torch.no_grad()
def xatlas_uvmap(glctx, geometry, mat, FLAGS):
    """Parametrize the current mesh with xatlas and bake kd/ks/normal textures."""
    eval_mesh = geometry.getMesh(mat)
    try:
        eval_mesh = eval_mesh['imesh']
    except Exception:  # FIX: was a bare except; getMesh may return the mesh directly
        pass
    # Create uvs with xatlas
    v_pos = eval_mesh.v_pos.detach().cpu().numpy()
    t_pos_idx = eval_mesh.t_pos_idx.detach().cpu().numpy()
    vmapping, indices, uvs = xatlas.parametrize(v_pos, t_pos_idx)

    # Convert to tensors
    indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)

    uvs = torch.tensor(uvs, dtype=torch.float32, device='cuda')
    faces = torch.tensor(indices_int64, dtype=torch.int64, device='cuda')

    new_mesh = mesh.Mesh(v_tex=uvs, t_tex_idx=faces, base=eval_mesh)

    mask, kd, ks = render.render_uv(glctx, new_mesh, FLAGS.texture_res, eval_mesh.material['kd_ks'])

    # Dilate all textures & use average color for background
    kd_avg = torch.sum(torch.sum(torch.sum(kd * mask, dim=0), dim=0), dim=0) / torch.sum(torch.sum(torch.sum(mask, dim=0), dim=0), dim=0)
    kd = util.dilate(kd, kd_avg[None, None, None, :], mask, 7)

    ks_avg = torch.sum(torch.sum(torch.sum(ks * mask, dim=0), dim=0), dim=0) / torch.sum(torch.sum(torch.sum(mask, dim=0), dim=0), dim=0)
    ks = util.dilate(ks, ks_avg[None, None, None, :], mask, 7)

    nrm_avg = torch.tensor([0, 0, 1], dtype=torch.float32, device="cuda")
    normal = nrm_avg[None, None, None, :].repeat(kd.shape[0], kd.shape[1], kd.shape[2], 1)

    new_mesh.material = mat.copy()
    del new_mesh.material['kd_ks']

    if FLAGS.transparency:
        # Append a random alpha channel when depth peeling is enabled.
        kd = torch.cat((kd, torch.rand_like(kd[...,0:1])), dim=-1)
        print("kd shape", kd.shape)

    kd_min, kd_max = torch.tensor(FLAGS.kd_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.kd_max, dtype=torch.float32, device='cuda')
    ks_min, ks_max = torch.tensor(FLAGS.ks_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.ks_max, dtype=torch.float32, device='cuda')
    nrm_min, nrm_max = torch.tensor(FLAGS.nrm_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.nrm_max, dtype=torch.float32, device='cuda')

    new_mesh.material.update({
        'kd'     : texture.Texture2D(kd.clone().detach().requires_grad_(True), min_max=[kd_min, kd_max]),
        'ks'     : texture.Texture2D(ks.clone().detach().requires_grad_(True), min_max=[ks_min, ks_max]),
        'normal' : texture.Texture2D(normal.clone().detach().requires_grad_(True), min_max=[nrm_min, nrm_max]),
    })

    return new_mesh

###############################################################################
# Utility functions for material
###############################################################################

def initial_guess_material(geometry, mlp, FLAGS, init_mat=None):
    """Create the optimizable material: a joint kd+ks MLP texture (mlp=True only)."""
    kd_min, kd_max = torch.tensor(FLAGS.kd_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.kd_max, dtype=torch.float32, device='cuda')
    ks_min, ks_max = torch.tensor(FLAGS.ks_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.ks_max, dtype=torch.float32, device='cuda')
    if mlp:
        # 6 channels: kd rgb (3) + ks (3), bounded by the per-channel min/max.
        mlp_min = torch.cat((kd_min[0:3], ks_min), dim=0)
        mlp_max = torch.cat((kd_max[0:3], ks_max), dim=0)
        mlp_map_opt = mlptexture.MLPTexture3D(geometry.getAABB(), channels=6, min_max=[mlp_min, mlp_max], use_float16=FLAGS.use_float16)
        mat = {'kd_ks' : mlp_map_opt}
    else:
        raise NotImplementedError

    mat['bsdf'] = FLAGS.bsdf
    mat['no_perturbed_nrm'] = FLAGS.no_perturbed_nrm

    return mat

def initial_guess_material_knownkskd(geometry, mlp, FLAGS, init_mat=None):
    """Material for the normal-only path: reuse kd/ks from a known init material."""
    mat = {
        'kd'  : init_mat['kd'],
        'ks'  : init_mat['ks']
    }
    if init_mat is not None:
        mat['bsdf'] = init_mat['bsdf']
    else:
        mat['bsdf'] = 'pbr'
    return mat

###############################################################################
# Validation & testing
###############################################################################

@torch.no_grad()
def validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS, denoiser=None):
    """Render one validation batch; returns (side-by-side image, dict of layers)."""
    result_dict = {}
    with torch.no_grad():
        buffers = geometry.render(glctx, target, lgt, opt_material, use_uv=False, denoiser=denoiser)['buffers']

        result_dict['ref'] = util.rgb_to_srgb(target['img'][...,0:3])[0]
        result_dict['opt'] = util.rgb_to_srgb(buffers['shaded'][...,0:3])[0]
        result_dict['mask_opt'] = buffers['shaded'][...,3:][0].expand(-1, -1, 3)
        result_dict['mask_ref'] = target['img'][...,3:][0].expand(-1, -1, 3)
        result_dict['msdf_image'] = buffers['msdf_image'][...,:][0].expand(-1, -1, 3).clamp(min=0, max=1)
        # result_dict['invdepth_2nd'] = buffers['invdepth_second'][...,0:1].expand(-1, -1, -1, 3).clamp(max=1.0)[0]
        # result_dict['invdepth_2nd_ref'] = target['invdepth_second'][...,0:1].expand(-1, -1, -1, 3).clamp(max=1.0)[0]
        result_image = torch.cat([result_dict['opt'], result_dict['ref'], result_dict['mask_opt'], result_dict['mask_ref'], result_dict['msdf_image']], axis=1)

        if FLAGS.display is not None:
            # NOTE(review): the 'bsdf' branch references `opt_mesh` and
            # `ref_mesh`, which are not defined in this function or module —
            # it would raise NameError if FLAGS.display contained a 'bsdf'
            # layer. FLAGS.display defaults to None, so this is dead in
            # practice; left as-is rather than guessing the intent.
            white_bg = torch.ones_like(target['background'])
            for layer in FLAGS.display:
                if 'latlong' in layer and layer['latlong']:
                    result_dict['light_image'] = lgt.generate_image(FLAGS.display_res)
                    result_dict['light_image'] = util.rgb_to_srgb(result_dict['light_image'] / (1 + result_dict['light_image']))
                    result_image = torch.cat([result_image, result_dict['light_image']], axis=1)
                elif 'bsdf' in layer:
                    img = render.render_mesh(FLAGS, glctx, opt_mesh, target['mvp'], target['campos'], target['light'] if lgt is None else lgt, target['resolution'],
                                             spp=target['spp'], num_layers=FLAGS.layers, background=white_bg, bsdf=layer['bsdf'], optix_ctx=geometry.optix_ctx)['shaded']
                    if layer['bsdf'] == 'kd':
                        result_dict[layer['bsdf']] = util.rgb_to_srgb(img[..., 0:3])[0]
                    else:
                        result_dict[layer['bsdf']] = img[0, ..., 0:3]
                    result_image = torch.cat([result_image, result_dict[layer['bsdf']]], axis=1)
                    if ref_mesh is not None:
                        img = render.render_mesh(FLAGS, glctx, ref_mesh, target['mvp'], target['campos'], target['light'], target['resolution'],
                                                 spp=target['spp'], num_layers=FLAGS.layers, background=white_bg, bsdf=layer['bsdf'], optix_ctx=geometry.optix_ctx)['shaded']
                        if layer['bsdf'] == 'kd':
                            result_dict[layer['bsdf'] + "_ref"] = util.rgb_to_srgb(img[..., 0:3])[0]
                        else:
                            result_dict[layer['bsdf'] + "_ref"] = img[0, ..., 0:3]
                        result_image = torch.cat([result_image, result_dict[layer['bsdf'] + "_ref"]], axis=1)
                elif 'normals' in layer and not FLAGS.no_perturbed_nrm:
                    result_image = torch.cat([result_image, (buffers['perturbed_nrm'][0, ...,0:3] + 1.0) * 0.5], axis=1)
                elif 'diffuse_light' in layer:
                    result_image = torch.cat([result_image, util.rgb_to_srgb(buffers['diffuse_light'][..., 0:3])[0]], axis=1)
                elif 'specular_light' in layer:
                    result_image = torch.cat([result_image, util.rgb_to_srgb(buffers['specular_light'][..., 0:3])[0]], axis=1)

    return result_image, result_dict

@torch.no_grad()
def validate(glctx, geometry, opt_material, lgt, dataset_validate, out_dir, FLAGS, denoiser=None, save_viz=False):
    """Run the full validation set, write metrics.txt (per-view and average
    MSE/PSNR) into out_dir, optionally saving visualization images.
    Returns the average PSNR."""

    # ==============================================================================================
    #  Validation loop
    # ==============================================================================================
    mse_values = []
    psnr_values = []

    dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_validate.collate)

    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, 'metrics.txt'), 'w') as fout:
        fout.write('ID, MSE, PSNR\n')

        print("Running validation")
        for it, target in enumerate(dataloader_validate):
            # Mix validation background
            target = prepare_batch(target, FLAGS.background)

            result_image, result_dict = validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS, denoiser=denoiser)

            # Compute metrics
            opt = torch.clamp(result_dict['opt'], 0.0, 1.0)
            ref = torch.clamp(result_dict['ref'], 0.0, 1.0)

            # FIX: dropped deprecated size_average=None/reduce=None arguments;
            # behavior identical with reduction='mean'.
            mse = torch.nn.functional.mse_loss(opt, ref, reduction='mean').item()
            mse_values.append(float(mse))
            psnr = util.mse_to_psnr(mse)
            psnr_values.append(float(psnr))

            line = "%d, %1.8f, %1.8f\n" % (it, mse, psnr)
            fout.write(str(line))

            if save_viz:
                for k in result_dict.keys():
                    np_img = result_dict[k].detach().cpu().numpy()
                    util.save_image(out_dir + '/' + ('val_%06d_%s.png' % (it, k)), np_img)

        avg_mse = np.mean(np.array(mse_values))
        avg_psnr = np.mean(np.array(psnr_values))
        line = "AVERAGES: %1.4f, %2.3f\n" % (avg_mse, avg_psnr)
        fout.write(str(line))
        print("MSE, PSNR")
        print("%1.8f, %2.3f" % (avg_mse, avg_psnr))
    return avg_psnr

###############################################################################
# Main shape fitter function / optimization loop
###############################################################################

def optimize_mesh(
    denoiser,
    glctx,
    geometry,
    opt_material,
    lgt,
    dataset_train,
    dataset_validate,
    FLAGS,
    warmup_iter=0,
    log_interval=10,
    pass_idx=0,
    pass_name="",
    optimize_light=True,
    optimize_geometry=True,
    visualize=True,
    save_path=None
    ):
    """Joint optimization loop over geometry, material and (optionally) lighting.

    Runs FLAGS.iter training iterations, periodically saving/displaying
    validation renders, and returns (geometry, opt_material).
    """

    # ==============================================================================================
    #  Setup torch optimizer
    # ==============================================================================================
    # learning_rate may be a scalar, or a per-pass list, each entry optionally
    # being a [pos, mat, lgt] triple.
    learning_rate = FLAGS.learning_rate[pass_idx] if isinstance(FLAGS.learning_rate, list) or isinstance(FLAGS.learning_rate, tuple) else FLAGS.learning_rate
    learning_rate_pos = learning_rate[0] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate
    learning_rate_mat = learning_rate[1] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate
    learning_rate_lgt = learning_rate[2] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate * 6.0

    def lr_schedule(iter, fraction):
        # Linear warmup, then exponential decay (factor 10 every 5k iters).
        if iter < warmup_iter:
            return iter / warmup_iter
        return max(0.0, 10**(-(iter - warmup_iter)*0.0002)) # Exponential falloff from [1.0, 0.1] over 5k epochs.

    # ==============================================================================================
    #  Image loss
    # ==============================================================================================
    image_loss_fn = createLoss(FLAGS)

    params = list(material.get_parameters(opt_material))

    if optimize_light:
        optimizer_light = torch.optim.Adam((lgt.parameters() if lgt is not None else []), lr=learning_rate_lgt)
        scheduler_light = torch.optim.lr_scheduler.LambdaLR(optimizer_light, lr_lambda=lambda x: lr_schedule(x, 0.9))

    if optimize_geometry:
        if FLAGS.use_sdf_mlp:
            # Separate parameter groups: deformation at full LR, SDF/other at 1e-2 of it.
            lr_msdf = learning_rate_pos * 1e-2 if FLAGS.use_msdf_mlp else learning_rate_pos
            deform_params = list(v[1] for v in geometry.named_parameters() if 'deform' in v[0]) if optimize_geometry else []
            msdf_params = list(v[1] for v in geometry.named_parameters() if 'msdf' in v[0]) if optimize_geometry else []
            sdf_params = list(v[1] for v in geometry.named_parameters() if 'sdf' in v[0] and 'msdf' not in v[0]) if optimize_geometry else []
            other_params = list(v[1] for v in geometry.named_parameters() if 'sdf' not in v[0] and 'msdf' not in v[0] and 'deform' not in v[0]) if optimize_geometry else []
            optimizer_mesh = torch.optim.Adam([
                {'params': deform_params, 'lr': learning_rate_pos},
                {'params': msdf_params, 'lr': lr_msdf},
                {'params': sdf_params, 'lr': learning_rate_pos * 1e-2},
                {'params': other_params, 'lr': learning_rate_pos * 1e-2},
            ], eps=1e-8)
        else:
            optimizer_mesh = torch.optim.Adam(geometry.parameters(), lr=learning_rate_pos)
        scheduler_mesh = torch.optim.lr_scheduler.LambdaLR(optimizer_mesh, lr_lambda=lambda x: lr_schedule(x, 0.9))

    optimizer = torch.optim.Adam(params, lr=learning_rate_mat)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: lr_schedule(x, 0.9))

    # ==============================================================================================
    #  Training loop
    # ==============================================================================================
    img_cnt = 0
    img_loss_vec = []
    depth_loss_vec = []
    reg_loss_vec = []
    iter_dur_vec = []

    dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=FLAGS.batch, collate_fn=dataset_train.collate, shuffle=True)
    if visualize:
        dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_train.collate)

    def cycle(iterable):
        # Endless iterator over the validation loader (restarts on exhaustion).
        iterator = iter(iterable)
        while True:
            try:
                yield next(iterator)
            except StopIteration:
                iterator = iter(iterable)

    v_it = cycle(dataloader_validate)

    for it, target in enumerate(dataloader_train):

        # Mix randomized background into dataset image
        target = prepare_batch(target, 'random')

        # ==============================================================================================
        #  Display / save outputs. Do it before training so we get initial meshes
        # ==============================================================================================

        # Show/save image before training step (want to get correct rendering of input)
        if visualize and FLAGS.local_rank == 0 and it != 0:
            with torch.no_grad():
                display_image = FLAGS.display_interval and (it % FLAGS.display_interval == 0)
                save_image = FLAGS.save_interval and (it % FLAGS.save_interval == 0)
                if display_image or save_image:
                    save_mesh = True
                    if save_mesh:
                        os.makedirs(os.path.join(save_path, pass_name), exist_ok=True)
                        obj.write_obj(os.path.join(save_path, pass_name), geometry.getMesh(opt_material)['imesh'], save_material=False)
                    result_image, result_dict = validate_itr(glctx, prepare_batch(next(v_it), FLAGS.background), geometry, opt_material, lgt, FLAGS, denoiser=denoiser)
                    np_result_image = result_image.detach().cpu().numpy()
                    if display_image:
                        util.display_image(np_result_image, title='%d / %d' % (it, FLAGS.iter))
                    if save_image:
                        util.save_image(os.path.join(save_path, ('img_%s_%06d.png' % (pass_name, img_cnt))), np_result_image)
                        img_cnt = img_cnt + 1

        iter_start_time = time.time()

        # ==============================================================================================
        #  Zero gradients
        # ==============================================================================================
        optimizer.zero_grad()
        if optimize_geometry:
            optimizer_mesh.zero_grad()
        if optimize_light:
            optimizer_light.zero_grad()

        # ==============================================================================================
        #  Training
        # ==============================================================================================
        if optimize_light:
            lgt.update_pdf()
        img_loss, depth_loss, reg_loss = geometry.tick(glctx, target, lgt, opt_material, image_loss_fn, it, denoiser=denoiser)

        # ==============================================================================================
        #  Final loss
        # ==============================================================================================
        total_loss = img_loss + reg_loss

        img_loss_vec.append(img_loss.item())
        depth_loss_vec.append(depth_loss.item())
        reg_loss_vec.append(reg_loss.item())

        # ==============================================================================================
        #  Backpropagate
        # ==============================================================================================
        total_loss.backward()
        if hasattr(lgt, 'base') and lgt.base.grad is not None and optimize_light:
            lgt.base.grad *= 64
        if 'kd_ks' in opt_material:
            opt_material['kd_ks'].encoder.params.grad /= 8.0
        if 'kd_ks_back' in opt_material:
            opt_material['kd_ks_back'].encoder.params.grad /= 8.0

        # Optionally clip gradients
        if FLAGS.clip_max_norm > 0.0:
            if optimize_geometry:
                # FIX: geometry.parameters() is a generator; the original
                # `geometry.parameters() + params` raised TypeError whenever
                # clipping was enabled. Materialize it before concatenation.
                torch.nn.utils.clip_grad_norm_(list(geometry.parameters()) + params, FLAGS.clip_max_norm)
            else:
                torch.nn.utils.clip_grad_norm_(params, FLAGS.clip_max_norm)

        optimizer.step()
        scheduler.step()

        if optimize_geometry:
            optimizer_mesh.step()
            scheduler_mesh.step()

        if optimize_light:
            optimizer_light.step()
            scheduler_light.step()

        # ==============================================================================================
        #  Clamp trainables to reasonable range
        # ==============================================================================================
        with torch.no_grad():
            if 'kd' in opt_material:
                opt_material['kd'].clamp_()
            if 'ks' in opt_material:
                opt_material['ks'].clamp_()
            if 'kd_back' in opt_material:
                opt_material['kd_back'].clamp_()
            if 'ks_back' in opt_material:
                opt_material['ks_back'].clamp_()
            if 'normal' in opt_material and not FLAGS.normal_only:
                opt_material['normal'].clamp_()
                opt_material['normal'].normalize_()
            if lgt is not None:
                # lgt.clamp_(min=0.01) # For some reason gradient dissapears if light becomes 0
                lgt.clamp_(min=1e-4) # For some reason gradient dissapears if light becomes 0
            geometry.clamp_deform()

        torch.cuda.current_stream().synchronize()
        iter_dur_vec.append(time.time() - iter_start_time)

        # ==============================================================================================
        #  Logging
        # ==============================================================================================
        if it % log_interval == 0 and FLAGS.local_rank == 0:
            img_loss_avg = np.mean(np.asarray(img_loss_vec[-log_interval:]))
            depth_loss_avg = np.mean(np.asarray(depth_loss_vec[-log_interval:]))
            reg_loss_avg = np.mean(np.asarray(reg_loss_vec[-log_interval:]))
            iter_dur_avg = np.mean(np.asarray(iter_dur_vec[-log_interval:]))

            remaining_time = (FLAGS.iter-it)*iter_dur_avg
            print("iter=%5d, img_loss=%.6f, depth_loss=%.6f, reg_loss=%.6f, lr=%.5f, time=%.1f ms, rem=%s" %
                (it, img_loss_avg, depth_loss_avg, reg_loss_avg, optimizer.param_groups[0]['lr'], iter_dur_avg*1000, util.time_to_text(remaining_time)))
            sys.stdout.flush()

        if it == FLAGS.iter:
            break

    return geometry, opt_material

#----------------------------------------------------------------------------
# Main function.
#----------------------------------------------------------------------------

if __name__ == "__main__":
    # Entry point: parse CLI flags, hard-code the remaining config, build the
    # DeepFashion datasets, optimize geometry/material/lighting, validate and
    # dump the resulting mesh + learned assets.
    parser = argparse.ArgumentParser(description='nvdiffrec')
    parser.add_argument('--config', type=str, default=None, help='Config file')
    parser.add_argument('-i', '--iter', type=int, default=5000)
    parser.add_argument('-b', '--batch', type=int, default=1)
    parser.add_argument('-s', '--spp', type=int, default=1)
    parser.add_argument('-l', '--layers', type=int, default=1)
    parser.add_argument('-r', '--train-res', nargs=2, type=int, default=[512, 512])
    parser.add_argument('-dr', '--display-res', type=int, default=None)
    parser.add_argument('-tr', '--texture-res', nargs=2, type=int, default=[1024, 1024])
    parser.add_argument('-di', '--display-interval', type=int, default=0)
    parser.add_argument('-si', '--save-interval', type=int, default=1000)
    parser.add_argument('-lr', '--learning-rate', type=float, default=0.01)
    parser.add_argument('-mr', '--min-roughness', type=float, default=0.08)
    parser.add_argument('-mip', '--custom-mip', action='store_true', default=False)
    parser.add_argument('-rt', '--random-textures', action='store_true', default=False)
    parser.add_argument('-bg', '--background', default='checker', choices=['black', 'white', 'checker', 'reference'])
    parser.add_argument('--loss', default='logl1', choices=['logl1', 'logl2', 'mse', 'smape', 'relmse'])
    parser.add_argument('-o', '--out-dir', type=str, default=None)
    parser.add_argument('-rm', '--ref_mesh', type=str)
    parser.add_argument('-bm', '--base-mesh', type=str, default=None)
    parser.add_argument('--validate', type=bool, default=True)

    # Render specific arguments
    parser.add_argument('--n_samples', type=int, default=4)
    parser.add_argument('--bsdf', type=str, default='pbr', choices=['pbr', 'diffuse', 'white'])

    # Denoiser specific arguments
    parser.add_argument('--denoiser', default='bilateral', choices=['none', 'bilateral'])
    parser.add_argument('--denoiser_demodulate', type=bool, default=True)

    parser.add_argument('--index',type=int)
    parser.add_argument('--trainset_path', type=str)
    parser.add_argument('--testset_path', type=str, default='')
    parser.add_argument('--msdf_reg_open_scale', type=float, default=1e-6)
    parser.add_argument('--msdf_reg_close_scale', type=float, default=3e-6)
    parser.add_argument('--eikonal_scale', type=float)

    FLAGS = parser.parse_args()

    # Non-CLI configuration (may still be overridden by --config JSON below).
    FLAGS.mtl_override = None  # Override material of model
    FLAGS.gshell_grid = 64  # Resolution of initial tet grid. We provide 64 and 128 resolution grids.
                            # Other resolutions can be generated with https://github.com/crawforddoran/quartet
                            # We include examples in data/tets/generate_tets.py
    FLAGS.mesh_scale = 1.4  # Scale of tet grid box. Adjust to cover the model
    FLAGS.envlight = None  # HDR environment probe
    FLAGS.env_scale = 1.0  # Env map intensity multiplier
    FLAGS.probe_res = 256  # Env map probe resolution
    FLAGS.learn_lighting = True  # Enable optimization of env lighting
    FLAGS.display = None  # Configure validation window/display. E.g. [{"bsdf" : "kd"}, {"bsdf" : "ks"}]
    FLAGS.transparency = False  # Enabled transparency through depth peeling
    FLAGS.lock_light = False  # Disable light optimization in the second pass
    FLAGS.lock_pos = False  # Disable vertex position optimization in the second pass
    FLAGS.sdf_regularizer = 0.2  # Weight for sdf regularizer.
    FLAGS.laplace = "relative"  # Mesh Laplacian ["absolute", "relative"]
    FLAGS.laplace_scale = 3000.0  # Weight for Laplace regularizer. Default is relative with large weight
    FLAGS.pre_load = True  # Pre-load entire dataset into memory for faster training
    FLAGS.no_perturbed_nrm = False  # Disable normal map
    FLAGS.decorrelated = False  # Use decorrelated sampling in forward and backward passes
    FLAGS.kd_min = [ 0.0, 0.0, 0.0, 0.0]
    FLAGS.kd_max = [ 1.0, 1.0, 1.0, 1.0]
    # FLAGS.ks_min = [ 0.0, 0.08, 0.0]
    FLAGS.ks_min = [ 0.0, 0.001, 0.0]
    FLAGS.ks_max = [ 0.0, 1.0, 1.0]
    FLAGS.nrm_min = [-1.0, -1.0, 0.0]
    FLAGS.nrm_max = [ 1.0, 1.0, 1.0]
    FLAGS.clip_max_norm = 0.0
    FLAGS.cam_near_far = [0.1, 1000.0]
    FLAGS.lambda_kd = 0.1
    FLAGS.lambda_ks = 0.05
    FLAGS.lambda_nrm = 0.025
    FLAGS.lambda_nrm2 = 0.25
    FLAGS.lambda_chroma = 0.0
    FLAGS.lambda_diffuse = 0.15
    FLAGS.lambda_specular = 0.0025
    FLAGS.random_lgt = False
    FLAGS.normal_only = False
    FLAGS.use_img_2nd_layer = False
    FLAGS.use_depth = False
    FLAGS.use_depth_2nd_layer = False
    FLAGS.use_tanh_deform = False
    FLAGS.use_sdf_mlp = True
    FLAGS.use_msdf_mlp = False
    FLAGS.use_eikonal = True
    FLAGS.sdf_mlp_pretrain_steps = 10000
    FLAGS.use_mesh_msdf_reg = True
    FLAGS.sphere_init = False
    FLAGS.sphere_init_norm = 0.5
    FLAGS.pretrained_sdf_mlp_path = f'./data/pretrained_mlp_{FLAGS.gshell_grid}_deeper.pt'
    FLAGS.n_hidden = 6
    FLAGS.d_hidden = 256
    FLAGS.n_freq = 6
    FLAGS.skip_in = [3]
    FLAGS.use_float16 = False
    FLAGS.visualize_watertight = False
    FLAGS.local_rank = 0
    FLAGS.multi_gpu = "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) > 1

    if FLAGS.multi_gpu:
        # torchrun/NCCL setup; defaults for single-node launches.
        if "MASTER_ADDR" not in os.environ:
            os.environ["MASTER_ADDR"] = 'localhost'
        if "MASTER_PORT" not in os.environ:
            os.environ["MASTER_PORT"] = '23456'

        FLAGS.local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(FLAGS.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")

    if FLAGS.config is not None:
        # JSON config overrides both CLI values and the hard-coded defaults above.
        data = json.load(open(FLAGS.config, 'r'))
        for key in data:
            FLAGS.__dict__[key] = data[key]

    if FLAGS.display_res is None:
        FLAGS.display_res = FLAGS.train_res

    if FLAGS.local_rank == 0:
        print("Config / Flags:")
        print("---------")
        for key in FLAGS.__dict__.keys():
            print(key, FLAGS.__dict__[key])
        print("---------")

    os.makedirs(FLAGS.out_dir, exist_ok=True)

    glctx = dr.RasterizeGLContext()
    glctx_display = glctx if FLAGS.batch < 16 else dr.RasterizeGLContext()  # Context for display

    mtl_default = None

    # ==============================================================================================
    #  Create data pipeline
    # ==============================================================================================
    dataset_path = FLAGS.trainset_path
    testset_path = FLAGS.testset_path
    # Fixed DeepFashion garment IDs; --index selects one of them.
    folder_name_list = [30, 92, 117, 133, 164, 320, 448, 522, 591]
    folder_name = folder_name_list[FLAGS.index]
    folder_name = str(folder_name)
    data_root = os.path.join(dataset_path, folder_name)
    dataset_train = DatasetDeepFashion(data_root, FLAGS, examples=int(1e6))
    dataset_validate = DatasetDeepFashion(data_root, FLAGS)
    if FLAGS.testset_path is not None and FLAGS.testset_path != '':
        testdata_root = os.path.join(testset_path, folder_name)
        dataset_test = DatasetDeepFashionTestset(testdata_root, FLAGS)

    # ==============================================================================================
    #  Create env light with trainable parameters
    # ==============================================================================================
    lgt = None
    if FLAGS.learn_lighting:
        lgt = light.create_trainable_env_rnd(FLAGS.probe_res, scale=0.0, bias=0.5)
        # lgt = light.create_trainable_env_rnd(FLAGS.probe_res, scale=0.0, bias=0.1)
    else:
        lgt = light.load_env(FLAGS.envlight, scale=FLAGS.env_scale, res=[FLAGS.probe_res, FLAGS.probe_res])

    # ==============================================================================================
    #  Setup denoiser
    # ==============================================================================================
    denoiser = None
    if FLAGS.denoiser == 'bilateral':
        denoiser = BilateralDenoiser().cuda()
    else:
        assert FLAGS.denoiser == 'none', "Invalid denoiser %s" % FLAGS.denoiser

    # Setup geometry for optimization
    geometry = GShellFlexiCubesGeometry(FLAGS.gshell_grid, FLAGS.mesh_scale, FLAGS)

    # Setup textures, make initial guess from reference if possible
    if not FLAGS.normal_only:
        mat = initial_guess_material(geometry, True, FLAGS, mtl_default)
    else:
        mat = initial_guess_material_knownkskd(geometry, True, FLAGS, mtl_default)
        mat['no_perturbed_nrm'] = True

    # Run optimization
    geometry, mat = optimize_mesh(denoiser, glctx, geometry, mat, lgt, dataset_train, dataset_validate, FLAGS,
                    pass_idx=0, pass_name="pass1", optimize_light=FLAGS.learn_lighting,
                    save_path=os.path.join(FLAGS.out_dir, folder_name))

    validate(glctx, geometry, mat, lgt, dataset_validate, os.path.join(FLAGS.out_dir, folder_name, "validate"), FLAGS, denoiser=denoiser, save_viz=True)
    if FLAGS.testset_path is not None and FLAGS.testset_path != '':
        validate(glctx, geometry, mat, lgt, dataset_test, os.path.join(FLAGS.out_dir, folder_name, "test"), FLAGS, denoiser=denoiser, save_viz=False)

    with torch.no_grad():
        # Persist learned geometry, material MLP and environment probe.
        os.makedirs(os.path.join(FLAGS.out_dir, folder_name, "mesh"), exist_ok=True)
        torch.save(geometry.state_dict(), os.path.join(FLAGS.out_dir, folder_name, "mesh/model.pt"))
        torch.save(mat['kd_ks'].state_dict(), os.path.join(FLAGS.out_dir, folder_name, "mesh/mtl.pt"))
        light.save_env_map(os.path.join(FLAGS.out_dir, folder_name, "mesh/probe.hdr"), lgt)

        # Create textured mesh from result
        base_mesh = geometry.getMesh(mat)['imesh']

        # Dump mesh for debugging.
        os.makedirs(os.path.join(FLAGS.out_dir, folder_name, "mesh"), exist_ok=True)
        obj.write_obj(os.path.join(FLAGS.out_dir, folder_name, "mesh/"), base_mesh, save_material=False)

        # Free temporaries / cached memory
        torch.cuda.empty_cache()
        mat['kd_ks'].cleanup()
        del mat['kd_ks']
        if 'kd_ks_back' in mat:
            mat['kd_ks_back'].cleanup()
            del mat['kd_ks_back']

    # Free temporaries / cached memory
    torch.cuda.empty_cache()
    del mat

================================================
FILE: train_gflexicubes_polycam.py
================================================
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import os
import sys
import time
import argparse
import json

import numpy as np
import torch
import nvdiffrast.torch as dr
import xatlas

# Import data readers / generators
from dataset.dataset_nerf_colmap import DatasetNERF

# Import topology / geometry trainers
from geometry.gshell_flexicubes_geometry import GShellFlexiCubesGeometry

import render.renderutils as ru
from render import obj
from render import material
from render import util
from render import mesh
from render import texture
from render import mlptexture
from render import light
from render import render

from denoiser.denoiser import BilateralDenoiser

import tqdm

RADIUS = 3.0

# Enable to debug back-prop anomalies
# torch.autograd.set_detect_anomaly(True)

###############################################################################
# Loss setup
###############################################################################

@torch.no_grad()
def createLoss(FLAGS):
    """Build the image-space loss closure selected by FLAGS.loss.

    Returns a callable (img, ref) -> scalar that delegates to
    ru.image_loss with the matching loss / tonemapper combination.
    Raises AssertionError for an unknown FLAGS.loss value.
    """
    if FLAGS.loss == "smape":
        return lambda img, ref: ru.image_loss(img, ref, loss='smape', tonemapper='none')
    elif FLAGS.loss == "mse":
        return lambda img, ref: ru.image_loss(img, ref, loss='mse', tonemapper='none')
    elif FLAGS.loss == "logl1":
        return lambda img, ref: ru.image_loss(img, ref, loss='l1', tonemapper='log_srgb')
    elif FLAGS.loss == "logl2":
        return lambda img, ref: ru.image_loss(img, ref, loss='mse', tonemapper='log_srgb')
    elif FLAGS.loss == "relmse":
        return lambda img, ref: ru.image_loss(img, ref, loss='relmse', tonemapper='none')
    else:
        assert False

###############################################################################
# Mix background into a dataset image
###############################################################################

@torch.no_grad()
def prepare_batch(target, bg_type='black'):
    """Move one dataset batch to CUDA and composite it over a background.

    Args:
        target: batch dict with 'img' [n, h, w, 4] (RGBA), 'mv', 'mvp', 'campos'.
        bg_type: one of 'checker', 'black', 'white', 'reference', 'random'.

    Returns:
        The same dict, mutated in place: tensors moved to CUDA, a
        'background' entry added, and 'img' alpha-blended over it
        (alpha channel preserved).
    """
    assert len(target['img'].shape) == 4, "Image shape should be [n, h, w, c]"
    if bg_type == 'checker':
        background = torch.tensor(util.checkerboard(target['img'].shape[1:3], 8), dtype=torch.float32, device='cuda')[None, ...]
    elif bg_type == 'black':
        background = torch.zeros(target['img'].shape[0:3] + (3,), dtype=torch.float32, device='cuda')
    elif bg_type == 'white':
        background = torch.ones(target['img'].shape[0:3] + (3,), dtype=torch.float32, device='cuda')
    elif bg_type == 'reference':
        background = target['img'][..., 0:3]
    elif bg_type == 'random':
        background = torch.rand(target['img'].shape[0:3] + (3,), dtype=torch.float32, device='cuda')
    else:
        assert False, "Unknown background type %s" % bg_type

    target['mv'] = target['mv'].cuda()
    target['mvp'] = target['mvp'].cuda()
    target['campos'] = target['campos'].cuda()
    target['img'] = target['img'].cuda()
    target['background'] = background

    # Alpha-blend the reference RGB over the chosen background, keep alpha.
    target['img'] = torch.cat((torch.lerp(background, target['img'][..., 0:3], target['img'][..., 3:4]), target['img'][..., 3:4]), dim=-1)

    return target

###############################################################################
# UV - map geometry & convert to a mesh
###############################################################################

@torch.no_grad()
def xatlas_uvmap(glctx, geometry, mat, FLAGS):
    """Bake the volumetric 'kd_ks' material into 2D textures on a xatlas UV map.

    Parametrizes the extracted mesh with xatlas, renders kd/ks into UV space,
    dilates the textures to fill seams, and returns a new mesh whose material
    contains trainable 2D textures instead of the MLP texture.
    """
    eval_mesh = geometry.getMesh(mat)
    # getMesh may return a dict of meshes; unwrap the inner mesh if so.
    # (Was a bare `except: pass` — narrowed to the dict case.)
    if isinstance(eval_mesh, dict):
        eval_mesh = eval_mesh['imesh']

    # Create uvs with xatlas
    v_pos = eval_mesh.v_pos.detach().cpu().numpy()
    t_pos_idx = eval_mesh.t_pos_idx.detach().cpu().numpy()
    vmapping, indices, uvs = xatlas.parametrize(v_pos, t_pos_idx)

    # Convert to tensors; reinterpret xatlas uint32 indices as int64 bit-pattern.
    indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)
    uvs = torch.tensor(uvs, dtype=torch.float32, device='cuda')
    faces = torch.tensor(indices_int64, dtype=torch.int64, device='cuda')

    new_mesh = mesh.Mesh(v_tex=uvs, t_tex_idx=faces, base=eval_mesh)

    mask, kd, ks = render.render_uv(glctx, new_mesh, FLAGS.texture_res, eval_mesh.material['kd_ks'])

    # Dilate all textures & use average color for background
    kd_avg = torch.sum(torch.sum(torch.sum(kd * mask, dim=0), dim=0), dim=0) / torch.sum(torch.sum(torch.sum(mask, dim=0), dim=0), dim=0)
    kd = util.dilate(kd, kd_avg[None, None, None, :], mask, 7)

    ks_avg = torch.sum(torch.sum(torch.sum(ks * mask, dim=0), dim=0), dim=0) / torch.sum(torch.sum(torch.sum(mask, dim=0), dim=0), dim=0)
    ks = util.dilate(ks, ks_avg[None, None, None, :], mask, 7)

    # Flat normal map pointing along +z everywhere.
    nrm_avg = torch.tensor([0, 0, 1], dtype=torch.float32, device="cuda")
    normal = nrm_avg[None, None, None, :].repeat(kd.shape[0], kd.shape[1], kd.shape[2], 1)

    new_mesh.material = mat.copy()
    del new_mesh.material['kd_ks']

    if FLAGS.transparency:
        # Append a (random-initialized) alpha channel for depth peeling.
        kd = torch.cat((kd, torch.rand_like(kd[...,0:1])), dim=-1)
        print("kd shape", kd.shape)

    kd_min, kd_max = torch.tensor(FLAGS.kd_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.kd_max, dtype=torch.float32, device='cuda')
    ks_min, ks_max = torch.tensor(FLAGS.ks_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.ks_max, dtype=torch.float32, device='cuda')
    nrm_min, nrm_max = torch.tensor(FLAGS.nrm_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.nrm_max, dtype=torch.float32, device='cuda')

    new_mesh.material.update({
        'kd'     : texture.Texture2D(kd.clone().detach().requires_grad_(True), min_max=[kd_min, kd_max]),
        'ks'     : texture.Texture2D(ks.clone().detach().requires_grad_(True), min_max=[ks_min, ks_max]),
        'normal' : texture.Texture2D(normal.clone().detach().requires_grad_(True), min_max=[nrm_min, nrm_max]),
    })

    return new_mesh

###############################################################################
# Utility functions for material
###############################################################################

def initial_guess_material(geometry, mlp, FLAGS, init_mat=None):
    """Create the initial material: a 6-channel (kd rgb + ks) MLP texture.

    Only the MLP path is implemented; `mlp=False` raises NotImplementedError.
    """
    kd_min, kd_max = torch.tensor(FLAGS.kd_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.kd_max, dtype=torch.float32, device='cuda')
    ks_min, ks_max = torch.tensor(FLAGS.ks_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.ks_max, dtype=torch.float32, device='cuda')

    if mlp:
        # Joint min/max over the 6 output channels (kd rgb + ks).
        mlp_min = torch.cat((kd_min[0:3], ks_min), dim=0)
        mlp_max = torch.cat((kd_max[0:3], ks_max), dim=0)
        mlp_map_opt = mlptexture.MLPTexture3D(geometry.getAABB(), channels=6, min_max=[mlp_min, mlp_max], use_float16=FLAGS.use_float16)
        mat = {'kd_ks' : mlp_map_opt}
    else:
        raise NotImplementedError

    mat['bsdf'] = FLAGS.bsdf
    mat['no_perturbed_nrm'] = FLAGS.no_perturbed_nrm

    return mat

def initial_guess_material_knownkskd(geometry, mlp, FLAGS, init_mat=None):
    """Build a material dict from a known kd/ks material.

    BUGFIX: the original subscripted `init_mat` before its None check, so
    `init_mat=None` crashed with TypeError and the 'pbr' fallback branch was
    unreachable. Now a None `init_mat` raises an explicit ValueError, and a
    missing 'bsdf' key falls back to 'pbr' as originally intended.
    """
    if init_mat is None:
        raise ValueError("initial_guess_material_knownkskd requires init_mat with 'kd' and 'ks' entries")
    mat = {
        'kd'  : init_mat['kd'],
        'ks'  : init_mat['ks'],
    }
    mat['bsdf'] = init_mat.get('bsdf', 'pbr')
    return mat

###############################################################################
# Validation & testing
###############################################################################

@torch.no_grad()
def validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS, denoiser=None):
    """Render one validation view.

    Returns:
        result_image: [h, 5*w, 3] horizontal strip of
            (optimized, reference, optimized mask, reference mask, msdf image).
        result_dict: {'ref', 'opt'} sRGB images for metric computation.

    (The original built a full result_dict, discarded it, and rebuilt a
    smaller one — now built once.)
    """
    buffers = geometry.render(glctx, target, lgt, opt_material, use_uv=False, denoiser=denoiser)['buffers']

    ref = util.rgb_to_srgb(target['img'][...,0:3])[0]
    opt = util.rgb_to_srgb(buffers['shaded'][...,0:3])[0]
    mask_opt = buffers['shaded'][...,3:][0].expand(-1, -1, 3)
    mask_ref = target['img'][...,3:][0].expand(-1, -1, 3)
    msdf_image = buffers['msdf_image'][...,:][0].expand(-1, -1, 3).clamp(min=0, max=1)

    result_image = torch.cat([opt, ref, mask_opt, mask_ref, msdf_image], axis=1)
    result_dict = {'ref': ref, 'opt': opt}

    return result_image, result_dict

@torch.no_grad()
def validate(glctx, geometry, opt_material, lgt, dataset_validate, out_dir, FLAGS, denoiser=None, save_viz=False):
    """Run the validation loop: per-view MSE/PSNR, written to metrics.txt.

    Optionally saves per-view visualization images. Returns the average PSNR.
    """
    # ==============================================================================================
    #  Validation loop
    # ==============================================================================================
    mse_values = []
    psnr_values = []

    dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_validate.collate)

    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, 'metrics.txt'), 'w') as fout:
        fout.write('ID, MSE, PSNR\n')

        print("Running validation")
        for it, target in enumerate(tqdm.tqdm(dataloader_validate)):
            # Mix validation background
            target = prepare_batch(target, FLAGS.background)

            result_image, result_dict = validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS, denoiser=denoiser)

            # Compute metrics in clamped sRGB space.
            opt = torch.clamp(result_dict['opt'], 0.0, 1.0)
            ref = torch.clamp(result_dict['ref'], 0.0, 1.0)

            # (Deprecated size_average/reduce args removed; reduction='mean'
            # is equivalent to the original call.)
            mse = torch.nn.functional.mse_loss(opt, ref, reduction='mean').item()
            mse_values.append(float(mse))
            psnr = util.mse_to_psnr(mse)
            psnr_values.append(float(psnr))

            line = "%d, %1.8f, %1.8f\n" % (it, mse, psnr)
            fout.write(str(line))

            if save_viz:
                for k in result_dict.keys():
                    np_img = result_dict[k].detach().cpu().numpy()
                    util.save_image(out_dir + '/' + ('val_%06d_%s.png' % (it, k)), np_img)

        avg_mse = np.mean(np.array(mse_values))
        avg_psnr = np.mean(np.array(psnr_values))
        line = "AVERAGES: %1.4f, %2.3f\n" % (avg_mse, avg_psnr)
        fout.write(str(line))
        print("MSE,      PSNR")
        print("%1.8f, %2.3f" % (avg_mse, avg_psnr))
    return avg_psnr

###############################################################################
# Main shape fitter function / optimization loop
###############################################################################

def optimize_mesh(
    denoiser,
    glctx,
    geometry,
    opt_material,
    lgt,
    dataset_train,
    dataset_validate,
    FLAGS,
    warmup_iter=0,
    log_interval=10,
    pass_idx=0,
    pass_name="",
    optimize_light=True,
    optimize_geometry=True,
    visualize=True,
    save_path=None
    ):
    """Joint optimization loop over geometry, material, and lighting.

    Runs FLAGS.iter steps over dataset_train, periodically saving/displaying
    validation renders. Returns (geometry, opt_material).

    Fixes vs. original:
      * clip_grad_norm_ received `geometry.parameters() + params`, which
        raises TypeError (generator + list) — now materialized to a list.
      * `v_it` was built from `dataloader_validate` even when visualize=False,
        where that name is unbound — now guarded.
      * MLP-texture grad rescale now skips when no grad was produced.
    """
    # ==============================================================================================
    #  Setup torch optimizer
    # ==============================================================================================
    learning_rate = FLAGS.learning_rate[pass_idx] if isinstance(FLAGS.learning_rate, list) or isinstance(FLAGS.learning_rate, tuple) else FLAGS.learning_rate
    learning_rate_pos = learning_rate[0] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate
    learning_rate_mat = learning_rate[1] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate
    learning_rate_lgt = learning_rate[2] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate * 6.0

    def lr_schedule(iter, fraction):
        # Linear warmup, then exponential falloff from [1.0, 0.1] over 5k epochs.
        if iter < warmup_iter:
            return iter / warmup_iter
        return max(0.0, 10**(-(iter - warmup_iter)*0.0002))

    # ==============================================================================================
    #  Image loss
    # ==============================================================================================
    image_loss_fn = createLoss(FLAGS)

    params = list(material.get_parameters(opt_material))

    if optimize_light:
        optimizer_light = torch.optim.Adam((lgt.parameters() if lgt is not None else []), lr=learning_rate_lgt)
        scheduler_light = torch.optim.lr_scheduler.LambdaLR(optimizer_light, lr_lambda=lambda x: lr_schedule(x, 0.9))

    if optimize_geometry:
        if FLAGS.use_sdf_mlp:
            # Split geometry parameters into groups: deform/msdf at full LR,
            # sdf and everything else at 1e-2 * LR.
            deform_params = list(v[1] for v in geometry.named_parameters() if 'deform' in v[0]) if optimize_geometry else []
            msdf_params = list(v[1] for v in geometry.named_parameters() if 'msdf' in v[0]) if optimize_geometry else []
            sdf_params = list(v[1] for v in geometry.named_parameters() if 'sdf' in v[0] and 'msdf' not in v[0]) if optimize_geometry else []
            other_params = list(v[1] for v in geometry.named_parameters() if 'sdf' not in v[0] and 'msdf' not in v[0] and 'deform' not in v[0]) if optimize_geometry else []
            optimizer_mesh = torch.optim.Adam([
                    {'params': deform_params, 'lr': learning_rate_pos},
                    {'params': msdf_params, 'lr': learning_rate_pos},
                    {'params': sdf_params, 'lr': learning_rate_pos * 1e-2},
                    {'params': other_params, 'lr': learning_rate_pos * 1e-2},
                ], eps=1e-8)
        else:
            optimizer_mesh = torch.optim.Adam(geometry.parameters(), lr=learning_rate_pos)
        scheduler_mesh = torch.optim.lr_scheduler.LambdaLR(optimizer_mesh, lr_lambda=lambda x: lr_schedule(x, 0.9))

    optimizer = torch.optim.Adam(params, lr=learning_rate_mat)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: lr_schedule(x, 0.9))

    # ==============================================================================================
    #  Training loop
    # ==============================================================================================
    img_cnt = 0
    img_loss_vec = []
    depth_loss_vec = []
    reg_loss_vec = []
    iter_dur_vec = []

    dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=FLAGS.batch, collate_fn=dataset_train.collate, shuffle=True)

    def cycle(iterable):
        # Endless iterator over a (re-startable) iterable such as a DataLoader.
        iterator = iter(iterable)
        while True:
            try:
                yield next(iterator)
            except StopIteration:
                iterator = iter(iterable)

    if visualize:
        dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_train.collate)
        # BUGFIX: only built when visualize is on; previously this line ran
        # unconditionally and raised NameError when visualize=False.
        v_it = cycle(dataloader_validate)

    for it, target in enumerate(dataloader_train):

        # Mix randomized background into dataset image
        target = prepare_batch(target, 'random')

        # ==============================================================================================
        #  Display / save outputs. Do it before training so we get initial meshes
        # ==============================================================================================

        # Show/save image before training step (want to get correct rendering of input)
        if visualize and FLAGS.local_rank == 0 and it != 0:
            with torch.no_grad():
                display_image = FLAGS.display_interval and (it % FLAGS.display_interval == 0)
                save_image = FLAGS.save_interval and (it % FLAGS.save_interval == 0)
                if display_image or save_image:
                    save_mesh = True
                    if save_mesh:
                        os.makedirs(os.path.join(save_path, pass_name), exist_ok=True)
                        obj.write_obj(os.path.join(save_path, pass_name), geometry.getMesh(opt_material)['imesh'], save_material=False)
                    result_image, result_dict = validate_itr(glctx, prepare_batch(next(v_it), FLAGS.background), geometry, opt_material, lgt, FLAGS, denoiser=denoiser)
                    np_result_image = result_image.detach().cpu().numpy()
                    if display_image:
                        util.display_image(np_result_image, title='%d / %d' % (it, FLAGS.iter))
                    if save_image:
                        util.save_image(os.path.join(save_path, ('img_%s_%06d.png' % (pass_name, img_cnt))), np_result_image)
                    img_cnt = img_cnt + 1

        iter_start_time = time.time()

        # ==============================================================================================
        #  Zero gradients
        # ==============================================================================================
        optimizer.zero_grad()
        if optimize_geometry:
            optimizer_mesh.zero_grad()
        if optimize_light:
            optimizer_light.zero_grad()

        # ==============================================================================================
        #  Training
        # ==============================================================================================
        xfm_lgt = None
        if optimize_light:
            # NOTE: both transform branches are deliberately disabled
            # (`if False and ...`); only the PDF update is active.
            if False and FLAGS.camera_space_light:
                lgt.xfm(target['mv'])
            elif False and ('envlight_transform' in target and target['envlight_transform'] is not None):
                xfm_lgt = target['envlight_transform']
                lgt.xfm(xfm_lgt)
            lgt.update_pdf()

        img_loss, depth_loss, reg_loss = geometry.tick(glctx, target, lgt, opt_material, image_loss_fn, it, denoiser=denoiser)

        # ==============================================================================================
        #  Final loss
        # ==============================================================================================
        total_loss = img_loss + reg_loss

        img_loss_vec.append(img_loss.item())
        depth_loss_vec.append(depth_loss.item())
        reg_loss_vec.append(reg_loss.item())

        # ==============================================================================================
        #  Backpropagate
        # ==============================================================================================
        total_loss.backward()

        # Hand-tuned gradient rescaling for the env map and MLP textures.
        if hasattr(lgt, 'base') and lgt.base.grad is not None and optimize_light:
            lgt.base.grad *= 64
        if 'kd_ks' in opt_material and opt_material['kd_ks'].encoder.params.grad is not None:
            opt_material['kd_ks'].encoder.params.grad /= 8.0
        if 'kd_ks_back' in opt_material and opt_material['kd_ks_back'].encoder.params.grad is not None:
            opt_material['kd_ks_back'].encoder.params.grad /= 8.0

        # Optionally clip gradients
        if FLAGS.clip_max_norm > 0.0:
            if optimize_geometry:
                # BUGFIX: `geometry.parameters() + params` raised TypeError
                # (generator + list); materialize the generator first.
                torch.nn.utils.clip_grad_norm_(list(geometry.parameters()) + params, FLAGS.clip_max_norm)
            else:
                torch.nn.utils.clip_grad_norm_(params, FLAGS.clip_max_norm)

        optimizer.step()
        scheduler.step()

        if optimize_geometry:
            optimizer_mesh.step()
            scheduler_mesh.step()

        if optimize_light:
            optimizer_light.step()
            scheduler_light.step()

        # ==============================================================================================
        #  Clamp trainables to reasonable range
        # ==============================================================================================
        with torch.no_grad():
            if 'kd' in opt_material:
                opt_material['kd'].clamp_()
            if 'ks' in opt_material:
                opt_material['ks'].clamp_()
            if 'kd_back' in opt_material:
                opt_material['kd_back'].clamp_()
            if 'ks_back' in opt_material:
                opt_material['ks_back'].clamp_()
            if 'normal' in opt_material and not FLAGS.normal_only:
                opt_material['normal'].clamp_()
                opt_material['normal'].normalize_()
            if lgt is not None:
                # lgt.clamp_(min=0.01) # For some reason gradient dissapears if light becomes 0
                lgt.clamp_(min=1e-4) # For some reason gradient dissapears if light becomes 0
            geometry.clamp_deform()

        torch.cuda.current_stream().synchronize()
        iter_dur_vec.append(time.time() - iter_start_time)

        # ==============================================================================================
        #  Logging
        # ==============================================================================================
        if it % log_interval == 0 and FLAGS.local_rank == 0:
            img_loss_avg = np.mean(np.asarray(img_loss_vec[-log_interval:]))
            depth_loss_avg = np.mean(np.asarray(depth_loss_vec[-log_interval:]))
            reg_loss_avg = np.mean(np.asarray(reg_loss_vec[-log_interval:]))
            iter_dur_avg = np.mean(np.asarray(iter_dur_vec[-log_interval:]))

            remaining_time = (FLAGS.iter-it)*iter_dur_avg
            print("iter=%5d, img_loss=%.6f, depth_loss=%.6f, reg_loss=%.6f, lr=%.5f, time=%.1f ms, rem=%s" %
                (it, img_loss_avg, depth_loss_avg, reg_loss_avg, optimizer.param_groups[0]['lr'], iter_dur_avg*1000, util.time_to_text(remaining_time)))
            sys.stdout.flush()

        if it == FLAGS.iter:
            break

    return geometry, opt_material

#----------------------------------------------------------------------------
# Main function.
#----------------------------------------------------------------------------

if __name__ == "__main__":
    # Entry point: parse CLI flags, apply the hard-coded configuration below
    # (optionally overridden by a JSON --config file), build the dataset /
    # light / denoiser / geometry, run the optimization, then export results.
    parser = argparse.ArgumentParser(description='nvdiffrec')
    parser.add_argument('--config', type=str, default=None, help='Config file')
    parser.add_argument('-i', '--iter', type=int, default=5000)
    parser.add_argument('-b', '--batch', type=int, default=1)
    parser.add_argument('-s', '--spp', type=int, default=1)
    parser.add_argument('-l', '--layers', type=int, default=1)
    parser.add_argument('-r', '--train-res', nargs=2, type=int, default=[512, 512])
    parser.add_argument('-dr', '--display-res', type=int, default=None)
    parser.add_argument('-tr', '--texture-res', nargs=2, type=int, default=[1024, 1024])
    parser.add_argument('-di', '--display-interval', type=int, default=0)
    parser.add_argument('-si', '--save-interval', type=int, default=1000)
    parser.add_argument('-lr', '--learning-rate', type=float, default=0.01)
    parser.add_argument('-mr', '--min-roughness', type=float, default=0.08)
    parser.add_argument('-mip', '--custom-mip', action='store_true', default=False)
    parser.add_argument('-rt', '--random-textures', action='store_true', default=False)
    parser.add_argument('-bg', '--background', default='checker', choices=['black', 'white', 'checker', 'reference'])
    parser.add_argument('--loss', default='logl1', choices=['logl1', 'logl2', 'mse', 'smape', 'relmse'])
    parser.add_argument('-o', '--out-dir', type=str, default=None)
    parser.add_argument('-rm', '--ref_mesh', type=str)
    parser.add_argument('-bm', '--base-mesh', type=str, default=None)
    # NOTE(review): argparse `type=bool` converts any non-empty string to True
    # ("--validate False" still yields True) — confirm this is the intent.
    parser.add_argument('--validate', type=bool, default=True)

    # Render specific arguments
    parser.add_argument('--n_samples', type=int, default=4)
    parser.add_argument('--bsdf', type=str, default='pbr', choices=['pbr', 'diffuse', 'white'])

    # Denoiser specific arguments
    parser.add_argument('--denoiser', default='bilateral', choices=['none', 'bilateral'])
    parser.add_argument('--denoiser_demodulate', type=bool, default=True)

    parser.add_argument('--msdf_reg_open_scale', type=float, default=1e-6)
    parser.add_argument('--msdf_reg_close_scale', type=float, default=3e-4)
    parser.add_argument('--eikonal_scale', type=float, default=5e-2)

    parser.add_argument('--trainset_path', type=str)
    parser.add_argument('--testset_path', type=str, default='')

    FLAGS = parser.parse_args()

    # Hard-coded configuration; each of these can still be overridden through
    # the --config JSON file below.
    FLAGS.mtl_override        = None                     # Override material of model
    FLAGS.gshell_grid         = 64                       # Resolution of initial tet grid. We provide 64 and 128 resolution grids.
                                                         # Other resolutions can be generated with https://github.com/crawforddoran/quartet
                                                         # We include examples in data/tets/generate_tets.py
    FLAGS.mesh_scale          = 3.6                      # Scale of tet grid box. Adjust to cover the model
    FLAGS.envlight            = None                     # HDR environment probe
    FLAGS.env_scale           = 1.0                      # Env map intensity multiplier
    FLAGS.probe_res           = 256                      # Env map probe resolution
    FLAGS.learn_lighting      = True                     # Enable optimization of env lighting
    FLAGS.display             = None                     # Configure validation window/display. E.g. [{"bsdf" : "kd"}, {"bsdf" : "ks"}]
    FLAGS.transparency        = False                    # Enabled transparency through depth peeling
    FLAGS.lock_light          = False                    # Disable light optimization in the second pass
    FLAGS.lock_pos            = False                    # Disable vertex position optimization in the second pass
    FLAGS.sdf_regularizer     = 0.2                      # Weight for sdf regularizer.
    FLAGS.laplace             = "relative"               # Mesh Laplacian ["absolute", "relative"]
    FLAGS.laplace_scale       = 3000.0                   # Weight for Laplace regularizer. Default is relative with large weight
    FLAGS.pre_load            = True                     # Pre-load entire dataset into memory for faster training
    FLAGS.no_perturbed_nrm    = False                    # Disable normal map
    FLAGS.decorrelated        = False                    # Use decorrelated sampling in forward and backward passes

    # Per-channel clamp ranges for kd / ks / normal textures.
    FLAGS.kd_min              = [ 0.0, 0.0, 0.0, 0.0]
    FLAGS.kd_max              = [ 1.0, 1.0, 1.0, 1.0]
    # FLAGS.ks_min = [ 0.0, 0.08, 0.0]
    FLAGS.ks_min              = [ 0.0, 0.001, 0.0]
    FLAGS.ks_max              = [ 0.0, 1.0, 1.0]
    FLAGS.nrm_min             = [-1.0, -1.0, 0.0]
    FLAGS.nrm_max             = [ 1.0, 1.0, 1.0]

    FLAGS.clip_max_norm       = 0.0                      # 0.0 disables gradient clipping
    FLAGS.cam_near_far        = [0.1, 1000.0]

    # Regularizer / loss-term weights.
    FLAGS.lambda_kd           = 0.1
    FLAGS.lambda_ks           = 0.05
    FLAGS.lambda_nrm          = 0.025
    FLAGS.lambda_nrm2         = 0.25
    FLAGS.lambda_chroma       = 0.0
    FLAGS.lambda_diffuse      = 0.15
    FLAGS.lambda_specular     = 0.0025

    FLAGS.random_lgt          = False
    FLAGS.normal_only         = False
    FLAGS.use_img_2nd_layer   = False
    FLAGS.use_depth           = False
    FLAGS.use_depth_2nd_layer = False

    # SDF-MLP geometry representation settings.
    FLAGS.use_tanh_deform     = False
    FLAGS.use_sdf_mlp         = True
    FLAGS.use_msdf_mlp        = False
    FLAGS.use_eikonal         = True
    FLAGS.sdf_mlp_pretrain_steps = 10000
    FLAGS.use_mesh_msdf_reg   = True
    FLAGS.sphere_init         = False
    FLAGS.sphere_init_norm    = 1.5
    FLAGS.pretrained_sdf_mlp_path = f'./data/pretrained_mlp_{FLAGS.gshell_grid}_polycam.pt'
    FLAGS.n_hidden            = 6
    FLAGS.d_hidden            = 256
    FLAGS.n_freq              = 6
    FLAGS.skip_in             = [3]

    FLAGS.use_float16         = False
    FLAGS.visualize_watertight = False

    FLAGS.local_rank = 0
    FLAGS.multi_gpu  = "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) > 1

    if FLAGS.multi_gpu:
        # Default rendezvous endpoint for single-node multi-GPU (torchrun) runs.
        if "MASTER_ADDR" not in os.environ:
            os.environ["MASTER_ADDR"] = 'localhost'
        if "MASTER_PORT" not in os.environ:
            os.environ["MASTER_PORT"] = '23456'

        FLAGS.local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(FLAGS.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")

    if FLAGS.config is not None:
        # JSON config overrides any flag/default set above.
        # NOTE(review): the open() handle is never closed explicitly.
        data = json.load(open(FLAGS.config, 'r'))
        for key in data:
            FLAGS.__dict__[key] = data[key]

    if FLAGS.display_res is None:
        FLAGS.display_res = FLAGS.train_res

    if FLAGS.local_rank == 0:
        print("Config / Flags:")
        print("---------")
        for key in FLAGS.__dict__.keys():
            print(key, FLAGS.__dict__[key])
        print("---------")

    os.makedirs(FLAGS.out_dir, exist_ok=True)

    glctx = dr.RasterizeGLContext()
    glctx_display = glctx if FLAGS.batch < 16 else dr.RasterizeGLContext() # Context for display

    mtl_default = None

    # ==============================================================================================
    #  Create data pipeline
    # ==============================================================================================
    data_root = FLAGS.trainset_path
    # Same transforms.json for train and validation; train repeats up to 1e6 examples.
    dataset_train    = DatasetNERF(os.path.join(data_root, 'transforms.json'), FLAGS, examples=int(1e6))
    dataset_validate = DatasetNERF(os.path.join(data_root, 'transforms.json'), FLAGS)

    # ==============================================================================================
    #  Create env light with trainable parameters
    # ==============================================================================================
    lgt = None
    if FLAGS.learn_lighting:
        lgt = light.create_trainable_env_rnd(FLAGS.probe_res, scale=0.0, bias=0.5)
        # lgt = light.create_trainable_env_rnd(FLAGS.probe_res, scale=0.0, bias=0.1)
    else:
        lgt = light.load_env(FLAGS.envlight, scale=FLAGS.env_scale, res=[FLAGS.probe_res, FLAGS.probe_res])

    # ==============================================================================================
    #  Setup denoiser
    # ==============================================================================================
    denoiser = None
    if FLAGS.denoiser == 'bilateral':
        denoiser = BilateralDenoiser().cuda()
    else:
        assert FLAGS.denoiser == 'none', "Invalid denoiser %s" % FLAGS.denoiser

    # Setup geometry for optimization
    geometry = GShellFlexiCubesGeometry(FLAGS.gshell_grid, FLAGS.mesh_scale, FLAGS)

    # Setup textures, make initial guess from reference if possible
    if not FLAGS.normal_only:
        mat = initial_guess_material(geometry, True, FLAGS, mtl_default)
    else:
        mat = initial_guess_material_knownkskd(geometry, True, FLAGS, mtl_default)
        mat['no_perturbed_nrm'] = True

    # Run optimization
    geometry, mat = optimize_mesh(denoiser, glctx, geometry, mat, lgt, dataset_train, dataset_validate,
                    FLAGS, pass_idx=0, pass_name="pass1", optimize_light=FLAGS.learn_lighting,
                    save_path=FLAGS.out_dir)

    validate(glctx, geometry, mat, lgt, dataset_validate, os.path.join(FLAGS.out_dir, "validate"), FLAGS, denoiser=denoiser, save_viz=True)

    with torch.no_grad():
        # Export trained weights, env probe, and the extracted mesh.
        os.makedirs(os.path.join(FLAGS.out_dir, "mesh"), exist_ok=True)
        torch.save(geometry.state_dict(), os.path.join(FLAGS.out_dir, "mesh/model.pt"))
        torch.save(mat['kd_ks'].state_dict(), os.path.join(FLAGS.out_dir, "mesh/mtl.pt"))
        light.save_env_map(os.path.join(FLAGS.out_dir, "mesh/probe.hdr"), lgt)

        # Create textured mesh from result
        base_mesh = geometry.getMesh(mat)['imesh']

        # Dump mesh for debugging.
        os.makedirs(os.path.join(FLAGS.out_dir, "mesh"), exist_ok=True)
        obj.write_obj(os.path.join(FLAGS.out_dir, "mesh/"), base_mesh, save_material=False)

        # Free temporaries / cached memory
        torch.cuda.empty_cache()
        mat['kd_ks'].cleanup()
        del mat['kd_ks']
        if 'kd_ks_back' in mat:
            mat['kd_ks_back'].cleanup()
            del mat['kd_ks_back']

        # Free temporaries / cached memory
        torch.cuda.empty_cache()
        del mat



================================================
FILE: train_gshelltet_deepfashion.py
================================================
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import os
import sys
import time
import argparse
import json

import numpy as np
import torch
import nvdiffrast.torch as dr
import xatlas

# Import data readers / generators
from dataset.dataset_deepfashion import DatasetDeepFashion
from dataset.dataset_deepfashion_testset import DatasetDeepFashionTestset

# Import topology / geometry trainers
from geometry.gshell_tets_geometry import GShellTetsGeometry

import render.renderutils as ru
from render import obj
from render import material
from render import util
from render import mesh
from render import texture
from render import mlptexture
from render import light
from render import render

from denoiser.denoiser import BilateralDenoiser

RADIUS = 3.0

# Enable to debug back-prop anomalies
# torch.autograd.set_detect_anomaly(True)

###############################################################################
# Loss setup
###############################################################################

@torch.no_grad()
def createLoss(FLAGS):
    """Return the image-space loss callable selected by ``FLAGS.loss``.

    The returned lambda has signature (img, ref) -> scalar loss and forwards
    to ``ru.image_loss`` with the matching loss / tonemapper combination.
    """
    if FLAGS.loss == "smape":
        return lambda img, ref: ru.image_loss(img, ref, loss='smape', tonemapper='none')
    elif FLAGS.loss == "mse":
        return lambda img, ref: ru.image_loss(img, ref, loss='mse', tonemapper='none')
    elif FLAGS.loss == "logl1":
        return lambda img, ref: ru.image_loss(img, ref, loss='l1', tonemapper='log_srgb')
    elif FLAGS.loss == "logl2":
        return lambda img, ref: ru.image_loss(img, ref, loss='mse', tonemapper='log_srgb')
    elif FLAGS.loss == "relmse":
        return lambda img, ref: ru.image_loss(img, ref, loss='relmse', tonemapper='none')
    else:
        assert False

###############################################################################
# Mix background into a dataset image
###############################################################################

@torch.no_grad()
def prepare_batch(target, bg_type='black'):
    """Move a dataset batch to CUDA and composite the RGBA image over a background.

    bg_type selects the background: 'checker', 'black', 'white',
    'reference' (reuse the image RGB) or 'random'. The alpha channel is kept
    and the RGB is lerped over the chosen background in-place in ``target``.
    """
    assert len(target['img'].shape) == 4, "Image shape should be [n, h, w, c]"
    if bg_type == 'checker':
        background = torch.tensor(util.checkerboard(target['img'].shape[1:3], 8),
                                  dtype=torch.float32, device='cuda')[None, ...]
    elif bg_type == 'black':
        background = torch.zeros(target['img'].shape[0:3] + (3,), dtype=torch.float32, device='cuda')
    elif bg_type == 'white':
        background = torch.ones(target['img'].shape[0:3] + (3,), dtype=torch.float32, device='cuda')
    elif bg_type == 'reference':
        background = target['img'][..., 0:3]
    elif bg_type == 'random':
        background = torch.rand(target['img'].shape[0:3] + (3,), dtype=torch.float32, device='cuda')
    else:
        assert False, "Unknown background type %s" % bg_type

    target['mv'] = target['mv'].cuda()
    target['mvp'] = target['mvp'].cuda()
    target['campos'] = target['campos'].cuda()
    target['img'] = target['img'].cuda()
    target['background'] = background

    # Composite RGB over the background weighted by alpha; keep alpha as 4th channel.
    target['img'] = torch.cat((torch.lerp(background, target['img'][..., 0:3], target['img'][..., 3:4]),
                               target['img'][..., 3:4]), dim=-1)

    return target

###############################################################################
# UV - map geometry & convert to a mesh
###############################################################################

@torch.no_grad()
def xatlas_uvmap(glctx, geometry, mat, FLAGS):
    """Bake the neural material into UV textures via an xatlas parametrization.

    Returns a new mesh whose material carries explicit 'kd'/'ks'/'normal'
    Texture2D maps rendered from the MLP material of ``geometry``.
    """
    eval_mesh = geometry.getMesh(mat)
    # getMesh may return either a mesh or a dict holding it under 'imesh'.
    try:
        eval_mesh = eval_mesh['imesh']
    except (TypeError, KeyError):  # was a bare `except:` — narrow it
        pass

    # Create uvs with xatlas
    v_pos = eval_mesh.v_pos.detach().cpu().numpy()
    t_pos_idx = eval_mesh.t_pos_idx.detach().cpu().numpy()
    vmapping, indices, uvs = xatlas.parametrize(v_pos, t_pos_idx)

    # Convert to tensors
    indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)

    uvs = torch.tensor(uvs, dtype=torch.float32, device='cuda')
    faces = torch.tensor(indices_int64, dtype=torch.int64, device='cuda')

    new_mesh = mesh.Mesh(v_tex=uvs, t_tex_idx=faces, base=eval_mesh)

    mask, kd, ks = render.render_uv(glctx, new_mesh, FLAGS.texture_res, eval_mesh.material['kd_ks'])

    # Dilate all textures & use average color for background
    kd_avg = torch.sum(torch.sum(torch.sum(kd * mask, dim=0), dim=0), dim=0) / torch.sum(torch.sum(torch.sum(mask, dim=0), dim=0), dim=0)
    kd = util.dilate(kd, kd_avg[None, None, None, :], mask, 7)

    ks_avg = torch.sum(torch.sum(torch.sum(ks * mask, dim=0), dim=0), dim=0) / torch.sum(torch.sum(torch.sum(mask, dim=0), dim=0), dim=0)
    ks = util.dilate(ks, ks_avg[None, None, None, :], mask, 7)

    nrm_avg = torch.tensor([0, 0, 1], dtype=torch.float32, device="cuda")
    normal = nrm_avg[None, None, None, :].repeat(kd.shape[0], kd.shape[1], kd.shape[2], 1)

    new_mesh.material = mat.copy()
    del new_mesh.material['kd_ks']

    if FLAGS.transparency:
        # Append a (random-initialized) alpha channel for depth-peeled transparency.
        kd = torch.cat((kd, torch.rand_like(kd[..., 0:1])), dim=-1)
        print("kd shape", kd.shape)

    kd_min, kd_max = torch.tensor(FLAGS.kd_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.kd_max, dtype=torch.float32, device='cuda')
    ks_min, ks_max = torch.tensor(FLAGS.ks_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.ks_max, dtype=torch.float32, device='cuda')
    nrm_min, nrm_max = torch.tensor(FLAGS.nrm_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.nrm_max, dtype=torch.float32, device='cuda')

    new_mesh.material.update({
        'kd'     : texture.Texture2D(kd.clone().detach().requires_grad_(True), min_max=[kd_min, kd_max]),
        'ks'     : texture.Texture2D(ks.clone().detach().requires_grad_(True), min_max=[ks_min, ks_max]),
        'normal' : texture.Texture2D(normal.clone().detach().requires_grad_(True), min_max=[nrm_min, nrm_max]),
    })

    return new_mesh

###############################################################################
# Utility functions for material
###############################################################################

def initial_guess_material(geometry, mlp, FLAGS, init_mat=None):
    """Build the trainable material for optimization.

    With ``mlp=True`` (the only supported mode) this creates a 6-channel
    MLPTexture3D (3 kd + 3 ks) bounded by the FLAGS min/max ranges.
    """
    kd_min, kd_max = torch.tensor(FLAGS.kd_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.kd_max, dtype=torch.float32, device='cuda')
    ks_min, ks_max = torch.tensor(FLAGS.ks_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.ks_max, dtype=torch.float32, device='cuda')
    if mlp:
        mlp_min = torch.cat((kd_min[0:3], ks_min), dim=0)
        mlp_max = torch.cat((kd_max[0:3], ks_max), dim=0)
        mlp_map_opt = mlptexture.MLPTexture3D(geometry.getAABB(), channels=6, min_max=[mlp_min, mlp_max], use_float16=FLAGS.use_float16)
        mat = {'kd_ks' : mlp_map_opt}
    else:
        raise NotImplementedError

    mat['bsdf'] = FLAGS.bsdf
    mat['no_perturbed_nrm'] = FLAGS.no_perturbed_nrm

    return mat

def initial_guess_material_knownkskd(geometry, mlp, FLAGS, init_mat=None):
    """Build a material from already-known kd/ks maps in ``init_mat``.

    FIX: the original subscripted ``init_mat`` before its ``is not None``
    check, so the None branch was dead code and passing None crashed.
    The check now precedes the access; with no init_mat only the default
    bsdf is returned.
    """
    if init_mat is not None:
        mat = {
            'kd'   : init_mat['kd'],
            'ks'   : init_mat['ks'],
            'bsdf' : init_mat['bsdf'],
        }
    else:
        mat = {'bsdf': 'pbr'}
    return mat

###############################################################################
# Validation & testing
###############################################################################

@torch.no_grad()
def validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS, denoiser=None):
    """Render one validation batch and assemble a side-by-side comparison image.

    Returns (result_image, result_dict) where result_image concatenates the
    optimized render, reference, masks and msdf image horizontally, plus any
    extra layers requested via FLAGS.display.
    """
    result_dict = {}
    with torch.no_grad():
        buffers = geometry.render(glctx, target, lgt, opt_material, use_uv=False, denoiser=denoiser)['buffers']

        result_dict['ref'] = util.rgb_to_srgb(target['img'][..., 0:3])[0]
        result_dict['opt'] = util.rgb_to_srgb(buffers['shaded'][..., 0:3])[0]
        result_dict['mask_opt'] = buffers['shaded'][..., 3:][0].expand(-1, -1, 3)
        result_dict['mask_ref'] = target['img'][..., 3:][0].expand(-1, -1, 3)
        result_dict['msdf_image'] = buffers['msdf_image'][..., :][0].expand(-1, -1, 3).clamp(min=0, max=1)
        result_image = torch.cat([result_dict['opt'], result_dict['ref'],
                                  result_dict['mask_opt'], result_dict['mask_ref'],
                                  result_dict['msdf_image']], axis=1)

        if FLAGS.display is not None:
            white_bg = torch.ones_like(target['background'])
            for layer in FLAGS.display:
                if 'latlong' in layer and layer['latlong']:
                    result_dict['light_image'] = lgt.generate_image(FLAGS.display_res)
                    result_dict['light_image'] = util.rgb_to_srgb(result_dict['light_image'] / (1 + result_dict['light_image']))
                    result_image = torch.cat([result_image, result_dict['light_image']], axis=1)
                elif 'bsdf' in layer:
                    # FIX: `opt_mesh` was referenced here without ever being defined
                    # (NameError as soon as a 'bsdf' display layer was configured).
                    # Extract the current mesh from the geometry instead.
                    opt_mesh = geometry.getMesh(opt_material)
                    if isinstance(opt_mesh, dict):
                        opt_mesh = opt_mesh['imesh']
                    img = render.render_mesh(FLAGS, glctx, opt_mesh, target['mvp'], target['campos'],
                                             target['light'] if lgt is None else lgt, target['resolution'],
                                             spp=target['spp'], num_layers=FLAGS.layers, background=white_bg,
                                             bsdf=layer['bsdf'], optix_ctx=geometry.optix_ctx)['shaded']
                    if layer['bsdf'] == 'kd':
                        result_dict[layer['bsdf']] = util.rgb_to_srgb(img[..., 0:3])[0]
                    else:
                        result_dict[layer['bsdf']] = img[0, ..., 0:3]
                    result_image = torch.cat([result_image, result_dict[layer['bsdf']]], axis=1)
                elif 'normals' in layer and not FLAGS.no_perturbed_nrm:
                    result_image = torch.cat([result_image, (buffers['perturbed_nrm'][0, ..., 0:3] + 1.0) * 0.5], axis=1)
                elif 'diffuse_light' in layer:
                    result_image = torch.cat([result_image, util.rgb_to_srgb(buffers['diffuse_light'][..., 0:3])[0]], axis=1)
                elif 'specular_light' in layer:
                    result_image = torch.cat([result_image, util.rgb_to_srgb(buffers['specular_light'][..., 0:3])[0]], axis=1)

    return result_image, result_dict

@torch.no_grad()
def validate(glctx, geometry, opt_material, lgt, dataset_validate, out_dir, FLAGS, denoiser=None, save_viz=False):
    """Run the validation loop over ``dataset_validate``.

    Writes per-image and average MSE/PSNR to ``out_dir/metrics.txt`` and,
    when ``save_viz`` is set, each result layer as a PNG. Returns avg PSNR.
    """
    mse_values = []
    psnr_values = []

    dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_validate.collate)

    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, 'metrics.txt'), 'w') as fout:
        fout.write('ID, MSE, PSNR\n')

        print("Running validation")
        for it, target in enumerate(dataloader_validate):
            # Mix validation background
            target = prepare_batch(target, FLAGS.background)

            result_image, result_dict = validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS, denoiser=denoiser)

            # Compute metrics. FIX: dropped the deprecated (and redundant)
            # size_average=None / reduce=None kwargs; reduction='mean' is kept.
            opt = torch.clamp(result_dict['opt'], 0.0, 1.0)
            ref = torch.clamp(result_dict['ref'], 0.0, 1.0)

            mse = torch.nn.functional.mse_loss(opt, ref, reduction='mean').item()
            mse_values.append(float(mse))
            psnr = util.mse_to_psnr(mse)
            psnr_values.append(float(psnr))

            line = "%d, %1.8f, %1.8f\n" % (it, mse, psnr)
            fout.write(str(line))

            if save_viz:
                for k in result_dict.keys():
                    np_img = result_dict[k].detach().cpu().numpy()
                    util.save_image(out_dir + '/' + ('val_%06d_%s.png' % (it, k)), np_img)

        avg_mse = np.mean(np.array(mse_values))
        avg_psnr = np.mean(np.array(psnr_values))

        line = "AVERAGES: %1.4f, %2.3f\n" % (avg_mse, avg_psnr)
        fout.write(str(line))
        print("MSE, PSNR")
        print("%1.8f, %2.3f" % (avg_mse, avg_psnr))

    return avg_psnr

###############################################################################
# Main shape fitter function / optimization loop
###############################################################################

def optimize_mesh(
        denoiser,
        glctx,
        geometry,
        opt_material,
        lgt,
        dataset_train,
        dataset_validate,
        FLAGS,
        warmup_iter=0,
        log_interval=10,
        pass_idx=0,
        pass_name="",
        optimize_light=True,
        optimize_geometry=True,
        visualize=True,
        save_path=None
    ):
    """Main optimization loop: jointly fits geometry, material and lighting.

    Returns the optimized (geometry, opt_material) pair after FLAGS.iter steps.
    """
    # ==============================================================================================
    #  Setup torch optimizer
    # ==============================================================================================
    learning_rate = FLAGS.learning_rate[pass_idx] if isinstance(FLAGS.learning_rate, list) or isinstance(FLAGS.learning_rate, tuple) else FLAGS.learning_rate
    learning_rate_pos = learning_rate[0] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate
    learning_rate_mat = learning_rate[1] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate
    learning_rate_lgt = learning_rate[2] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate * 6.0

    def lr_schedule(iter, fraction):
        if iter < warmup_iter:
            return iter / warmup_iter
        return max(0.0, 10**(-(iter - warmup_iter)*0.0002))  # Exponential falloff from [1.0, 0.1] over 5k epochs.

    # ==============================================================================================
    #  Image loss
    # ==============================================================================================
    image_loss_fn = createLoss(FLAGS)

    params = list(material.get_parameters(opt_material))

    if optimize_light:
        optimizer_light = torch.optim.Adam((lgt.parameters() if lgt is not None else []), lr=learning_rate_lgt)
        scheduler_light = torch.optim.lr_scheduler.LambdaLR(optimizer_light, lr_lambda=lambda x: lr_schedule(x, 0.9))

    if optimize_geometry:
        if FLAGS.use_sdf_mlp:
            # msdf params get the full position LR only when driven by an MLP.
            lr_msdf = learning_rate_pos * 1e-2 if FLAGS.use_msdf_mlp else learning_rate_pos
            deform_params = list(v[1] for v in geometry.named_parameters() if 'deform' in v[0]) if optimize_geometry else []
            msdf_params = list(v[1] for v in geometry.named_parameters() if 'msdf' in v[0]) if optimize_geometry else []
            sdf_params = list(v[1] for v in geometry.named_parameters() if 'sdf' in v[0] and 'msdf' not in v[0]) if optimize_geometry else []
            other_params = list(v[1] for v in geometry.named_parameters() if 'sdf' not in v[0] and 'msdf' not in v[0] and 'deform' not in v[0]) if optimize_geometry else []
            optimizer_mesh = torch.optim.Adam([
                {'params': deform_params, 'lr': learning_rate_pos},
                {'params': msdf_params, 'lr': lr_msdf},
                {'params': sdf_params, 'lr': learning_rate_pos * 1e-2},
                {'params': other_params, 'lr': learning_rate_pos * 1e-2},
            ], eps=1e-8)
        else:
            optimizer_mesh = torch.optim.Adam(geometry.parameters(), lr=learning_rate_pos)
        scheduler_mesh = torch.optim.lr_scheduler.LambdaLR(optimizer_mesh, lr_lambda=lambda x: lr_schedule(x, 0.9))

    optimizer = torch.optim.Adam(params, lr=learning_rate_mat)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: lr_schedule(x, 0.9))

    # ==============================================================================================
    #  Training loop
    # ==============================================================================================
    img_cnt = 0
    img_loss_vec = []
    depth_loss_vec = []
    reg_loss_vec = []
    iter_dur_vec = []

    dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=FLAGS.batch, collate_fn=dataset_train.collate, shuffle=True)
    if visualize:
        dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_train.collate)

    def cycle(iterable):
        # Endless iterator that restarts the dataloader when exhausted.
        iterator = iter(iterable)
        while True:
            try:
                yield next(iterator)
            except StopIteration:
                iterator = iter(iterable)

    v_it = cycle(dataloader_validate)

    for it, target in enumerate(dataloader_train):
        # Mix randomized background into dataset image
        target = prepare_batch(target, 'random')

        # ==============================================================================================
        #  Display / save outputs. Do it before training so we get initial meshes
        # ==============================================================================================
        if visualize and FLAGS.local_rank == 0 and it != 0:
            with torch.no_grad():
                display_image = FLAGS.display_interval and (it % FLAGS.display_interval == 0)
                save_image = FLAGS.save_interval and (it % FLAGS.save_interval == 0)
                if display_image or save_image:
                    save_mesh = True
                    if save_mesh:
                        os.makedirs(os.path.join(save_path, pass_name), exist_ok=True)
                        obj.write_obj(os.path.join(save_path, pass_name), geometry.getMesh(opt_material)['imesh'], save_material=False)
                    result_image, result_dict = validate_itr(glctx, prepare_batch(next(v_it), FLAGS.background), geometry, opt_material, lgt, FLAGS, denoiser=denoiser)
                    np_result_image = result_image.detach().cpu().numpy()
                    if display_image:
                        util.display_image(np_result_image, title='%d / %d' % (it, FLAGS.iter))
                    if save_image:
                        util.save_image(os.path.join(save_path, ('img_%s_%06d.png' % (pass_name, img_cnt))), np_result_image)
                        img_cnt = img_cnt + 1

        iter_start_time = time.time()

        # ==============================================================================================
        #  Zero gradients
        # ==============================================================================================
        optimizer.zero_grad()
        if optimize_geometry:
            optimizer_mesh.zero_grad()
        if optimize_light:
            optimizer_light.zero_grad()

        # ==============================================================================================
        #  Training
        # ==============================================================================================
        xfm_lgt = None
        if optimize_light:
            lgt.update_pdf()

        img_loss, depth_loss, reg_loss = geometry.tick(
            glctx, target, lgt, opt_material, image_loss_fn, it, denoiser=denoiser)

        # ==============================================================================================
        #  Final loss
        # ==============================================================================================
        total_loss = img_loss + reg_loss

        img_loss_vec.append(img_loss.item())
        depth_loss_vec.append(depth_loss.item())
        reg_loss_vec.append(reg_loss.item())

        # ==============================================================================================
        #  Backpropagate
        # ==============================================================================================
        total_loss.backward()
        if hasattr(lgt, 'base') and lgt.base.grad is not None and optimize_light:
            lgt.base.grad *= 64
        if 'kd_ks' in opt_material:
            opt_material['kd_ks'].encoder.params.grad /= 8.0
        if 'kd_ks_back' in opt_material:
            opt_material['kd_ks_back'].encoder.params.grad /= 8.0

        # Optionally clip gradients.
        # FIX: `geometry.parameters()` is a generator; `generator + list` raised
        # TypeError whenever clip_max_norm > 0 with geometry optimization on.
        if FLAGS.clip_max_norm > 0.0:
            if optimize_geometry:
                torch.nn.utils.clip_grad_norm_(list(geometry.parameters()) + params, FLAGS.clip_max_norm)
            else:
                torch.nn.utils.clip_grad_norm_(params, FLAGS.clip_max_norm)

        optimizer.step()
        scheduler.step()

        if optimize_geometry:
            optimizer_mesh.step()
            scheduler_mesh.step()

        if optimize_light:
            optimizer_light.step()
            scheduler_light.step()

        # ==============================================================================================
        #  Clamp trainables to reasonable range
        # ==============================================================================================
        with torch.no_grad():
            if 'kd' in opt_material:
                opt_material['kd'].clamp_()
            if 'ks' in opt_material:
                opt_material['ks'].clamp_()
            if 'kd_back' in opt_material:
                opt_material['kd_back'].clamp_()
            if 'ks_back' in opt_material:
                opt_material['ks_back'].clamp_()
            if 'normal' in opt_material and not FLAGS.normal_only:
                opt_material['normal'].clamp_()
                opt_material['normal'].normalize_()
            if lgt is not None:
                # lgt.clamp_(min=0.01) # For some reason gradient dissapears if light becomes 0
                lgt.clamp_(min=1e-4)  # For some reason gradient dissapears if light becomes 0

            geometry.clamp_deform()

        torch.cuda.current_stream().synchronize()
        iter_dur_vec.append(time.time() - iter_start_time)

        # ==============================================================================================
        #  Logging
        # ==============================================================================================
        if it % log_interval == 0 and FLAGS.local_rank == 0:
            img_loss_avg = np.mean(np.asarray(img_loss_vec[-log_interval:]))
            depth_loss_avg = np.mean(np.asarray(depth_loss_vec[-log_interval:]))
            reg_loss_avg = np.mean(np.asarray(reg_loss_vec[-log_interval:]))
            iter_dur_avg = np.mean(np.asarray(iter_dur_vec[-log_interval:]))

            remaining_time = (FLAGS.iter-it)*iter_dur_avg
            print("iter=%5d, img_loss=%.6f, depth_loss=%.6f, reg_loss=%.6f, lr=%.5f, time=%.1f ms, rem=%s" %
                (it, img_loss_avg, depth_loss_avg, reg_loss_avg, optimizer.param_groups[0]['lr'], iter_dur_avg*1000, util.time_to_text(remaining_time)))
            sys.stdout.flush()

        if it == FLAGS.iter:
            break

    return geometry, opt_material

#----------------------------------------------------------------------------
# Main function.
#---------------------------------------------------------------------------- if __name__ == "__main__": parser = argparse.ArgumentParser(description='nvdiffrec') parser.add_argument('--config', type=str, default=None, help='Config file') parser.add_argument('-i', '--iter', type=int, default=5000) parser.add_argument('-b', '--batch', type=int, default=1) parser.add_argument('-s', '--spp', type=int, default=1) parser.add_argument('-l', '--layers', type=int, default=1) parser.add_argument('-r', '--train-res', nargs=2, type=int, default=[512, 512]) parser.add_argument('-dr', '--display-res', type=int, default=None) parser.add_argument('-tr', '--texture-res', nargs=2, type=int, default=[1024, 1024]) parser.add_argument('-di', '--display-interval', type=int, default=0) parser.add_argument('-si', '--save-interval', type=int, default=1000) parser.add_argument('-lr', '--learning-rate', type=float, default=0.01) parser.add_argument('-mr', '--min-roughness', type=float, default=0.08) parser.add_argument('-mip', '--custom-mip', action='store_true', default=False) parser.add_argument('-rt', '--random-textures', action='store_true', default=False) parser.add_argument('-bg', '--background', default='checker', choices=['black', 'white', 'checker', 'reference']) parser.add_argument('--loss', default='logl1', choices=['logl1', 'logl2', 'mse', 'smape', 'relmse']) parser.add_argument('-o', '--out-dir', type=str, default=None) parser.add_argument('-rm', '--ref_mesh', type=str) parser.add_argument('-bm', '--base-mesh', type=str, default=None) parser.add_argument('--validate', type=bool, default=True) # Render specific arguments parser.add_argument('--n_samples', type=int, default=4) parser.add_argument('--bsdf', type=str, default='pbr', choices=['pbr', 'diffuse', 'white']) # Denoiser specific arguments parser.add_argument('--denoiser', default='bilateral', choices=['none', 'bilateral']) parser.add_argument('--denoiser_demodulate', type=bool, default=True) 
parser.add_argument('--index',type=int) parser.add_argument('--msdf_reg_open_scale', type=float, default=1e-6) parser.add_argument('--msdf_reg_close_scale', type=float, default=3e-6) parser.add_argument('--eikonal_scale', type=float) parser.add_argument('--sdf_regularizer', type=float, default=0.2) parser.add_argument('--trainset_path', type=str) parser.add_argument('--testset_path', type=str, default='') FLAGS = parser.parse_args() FLAGS.mtl_override = None # Override material of model FLAGS.gshell_grid = 64 # Resolution of initial tet grid. We provide 64 and 128 resolution grids. # Other resolutions can be generated with https://github.com/crawforddoran/quartet # We include examples in data/tets/generate_tets.py FLAGS.mesh_scale = 1.4 # Scale of tet grid box. Adjust to cover the model FLAGS.envlight = None # HDR environment probe FLAGS.env_scale = 1.0 # Env map intensity multiplier FLAGS.probe_res = 256 # Env map probe resolution FLAGS.learn_lighting = True # Enable optimization of env lighting FLAGS.display = None # Configure validation window/display. E.g. [{"bsdf" : "kd"}, {"bsdf" : "ks"}] FLAGS.transparency = False # Enabled transparency through depth peeling FLAGS.lock_light = False # Disable light optimization in the second pass FLAGS.lock_pos = False # Disable vertex position optimization in the second pass # FLAGS.sdf_regularizer = 0.2 # Weight for sdf regularizer. FLAGS.laplace = "relative" # Mesh Laplacian ["absolute", "relative"] FLAGS.laplace_scale = 3000.0 # Weight for Laplace regularizer. 
Default is relative with large weight FLAGS.pre_load = True # Pre-load entire dataset into memory for faster training FLAGS.no_perturbed_nrm = False # Disable normal map FLAGS.decorrelated = False # Use decorrelated sampling in forward and backward passes FLAGS.kd_min = [ 0.0, 0.0, 0.0, 0.0] FLAGS.kd_max = [ 1.0, 1.0, 1.0, 1.0] FLAGS.ks_min = [ 0.0, 0.001, 0.0] FLAGS.ks_max = [ 0.0, 1.0, 1.0] FLAGS.nrm_min = [-1.0, -1.0, 0.0] FLAGS.nrm_max = [ 1.0, 1.0, 1.0] FLAGS.clip_max_norm = 0.0 FLAGS.cam_near_far = [0.1, 1000.0] FLAGS.lambda_kd = 0.1 FLAGS.lambda_ks = 0.05 FLAGS.lambda_nrm = 0.025 FLAGS.lambda_nrm2 = 0.25 FLAGS.lambda_chroma = 0.0 FLAGS.lambda_diffuse = 0.15 FLAGS.lambda_specular = 0.0025 FLAGS.random_lgt = False FLAGS.normal_only = False FLAGS.use_img_2nd_layer = False FLAGS.use_depth = False FLAGS.use_depth_2nd_layer = False FLAGS.use_tanh_deform = False FLAGS.use_sdf_mlp = True FLAGS.use_msdf_mlp = False FLAGS.use_eikonal = True FLAGS.sdf_mlp_pretrain_steps = 1000 FLAGS.use_mesh_msdf_reg = True FLAGS.sphere_init = False FLAGS.sphere_init_norm = 0.5 FLAGS.pretrained_sdf_mlp_path = f'./data/pretrained_mlp_{FLAGS.gshell_grid}_deeper.pt' FLAGS.n_hidden = 6 FLAGS.d_hidden = 256 FLAGS.n_freq = 6 FLAGS.skip_in = [3] FLAGS.use_float16 = False FLAGS.visualize_watertight = False FLAGS.local_rank = 0 FLAGS.multi_gpu = "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) > 1 if FLAGS.multi_gpu: if "MASTER_ADDR" not in os.environ: os.environ["MASTER_ADDR"] = 'localhost' if "MASTER_PORT" not in os.environ: os.environ["MASTER_PORT"] = '23456' FLAGS.local_rank = int(os.environ["LOCAL_RANK"]) torch.cuda.set_device(FLAGS.local_rank) torch.distributed.init_process_group(backend="nccl", init_method="env://") if FLAGS.config is not None: data = json.load(open(FLAGS.config, 'r')) for key in data: FLAGS.__dict__[key] = data[key] if FLAGS.display_res is None: FLAGS.display_res = FLAGS.train_res if FLAGS.local_rank == 0: print("Config / Flags:") print("---------") for key 
in FLAGS.__dict__.keys(): print(key, FLAGS.__dict__[key]) print("---------") os.makedirs(FLAGS.out_dir, exist_ok=True) glctx = dr.RasterizeGLContext() glctx_display = glctx if FLAGS.batch < 16 else dr.RasterizeGLContext() # Context for display mtl_default = None # ============================================================================================== # Create data pipeline # ============================================================================================== dataset_path = FLAGS.trainset_path testset_path = FLAGS.testset_path folder_name_list = [30, 92, 117, 133, 164, 320, 448, 522, 591] folder_name = folder_name_list[FLAGS.index] folder_name = str(folder_name) data_root = os.path.join(dataset_path, folder_name) dataset_train = DatasetDeepFashion(data_root, FLAGS, examples=int(1e6)) dataset_validate = DatasetDeepFashion(data_root, FLAGS) if FLAGS.testset_path is not None and FLAGS.testset_path != '': testdata_root = os.path.join(testset_path, folder_name) dataset_test = DatasetDeepFashionTestset(testdata_root, FLAGS) # ============================================================================================== # Create env light with trainable parameters # ============================================================================================== lgt = None if FLAGS.learn_lighting: lgt = light.create_trainable_env_rnd(FLAGS.probe_res, scale=0.0, bias=0.5) # lgt = light.create_trainable_env_rnd(FLAGS.probe_res, scale=0.0, bias=0.1) else: lgt = light.load_env(FLAGS.envlight, scale=FLAGS.env_scale, res=[FLAGS.probe_res, FLAGS.probe_res]) # ============================================================================================== # Setup denoiser # ============================================================================================== denoiser = None if FLAGS.denoiser == 'bilateral': denoiser = BilateralDenoiser().cuda() else: assert FLAGS.denoiser == 'none', "Invalid denoiser %s" % FLAGS.denoiser # Setup geometry for optimization 
geometry = GShellTetsGeometry(FLAGS.gshell_grid, FLAGS.mesh_scale, FLAGS) # Setup textures, make initial guess from reference if possible if not FLAGS.normal_only: mat = initial_guess_material(geometry, True, FLAGS, mtl_default) else: mat = initial_guess_material_knownkskd(geometry, True, FLAGS, mtl_default) mat['no_perturbed_nrm'] = True # Run optimization geometry, mat = optimize_mesh(denoiser, glctx, geometry, mat, lgt, dataset_train, dataset_validate, FLAGS, pass_idx=0, pass_name="pass1", optimize_light=FLAGS.learn_lighting, save_path=os.path.join(FLAGS.out_dir, folder_name)) validate(glctx, geometry, mat, lgt, dataset_validate, os.path.join(FLAGS.out_dir, folder_name, "validate"), FLAGS, denoiser=denoiser, save_viz=True) if FLAGS.testset_path is not None and FLAGS.testset_path != '': validate(glctx, geometry, mat, lgt, dataset_test, os.path.join(FLAGS.out_dir, folder_name, "test"), FLAGS, denoiser=denoiser, save_viz=False) with torch.no_grad(): os.makedirs(os.path.join(FLAGS.out_dir, folder_name, "mesh"), exist_ok=True) torch.save(geometry.state_dict(), os.path.join(FLAGS.out_dir, folder_name, "mesh/model.pt")) torch.save(mat['kd_ks'].state_dict(), os.path.join(FLAGS.out_dir, folder_name, "mesh/mtl.pt")) light.save_env_map(os.path.join(FLAGS.out_dir, folder_name, "mesh/probe.hdr"), lgt) # Create textured mesh from result base_mesh = geometry.getMesh(mat)['imesh'] # Dump mesh for debugging. 
os.makedirs(os.path.join(FLAGS.out_dir, folder_name, "mesh"), exist_ok=True) obj.write_obj(os.path.join(FLAGS.out_dir, folder_name, "mesh/"), base_mesh, save_material=False) # Free temporaries / cached memory torch.cuda.empty_cache() mat['kd_ks'].cleanup() del mat['kd_ks'] if 'kd_ks_back' in mat: mat['kd_ks_back'].cleanup() del mat['kd_ks_back'] # Free temporaries / cached memory torch.cuda.empty_cache() del mat ================================================ FILE: train_gshelltet_polycam.py ================================================ # Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual # property and proprietary rights in and to this material, related # documentation and any modifications thereto. Any use, reproduction, # disclosure or distribution of this material and related documentation # without an express license agreement from NVIDIA CORPORATION or # its affiliates is strictly prohibited. 
import os
import sys
import time
import argparse
import json

import numpy as np
import torch
import nvdiffrast.torch as dr
import xatlas

# Import data readers / generators
from dataset.dataset_nerf_colmap import DatasetNERF

# Import topology / geometry trainers
from geometry.gshell_tets_geometry import GShellTetsGeometry

import render.renderutils as ru
from render import obj
from render import material
from render import util
from render import mesh
from render import texture
from render import mlptexture
from render import light
from render import render

from denoiser.denoiser import BilateralDenoiser

import tqdm

RADIUS = 3.0

# Enable to debug back-prop anomalies
# torch.autograd.set_detect_anomaly(True)

###############################################################################
# Loss setup
###############################################################################

@torch.no_grad()
def createLoss(FLAGS):
    """Return an image-loss callable ``(img, ref) -> loss`` selected by FLAGS.loss."""
    if FLAGS.loss == "smape":
        return lambda img, ref: ru.image_loss(img, ref, loss='smape', tonemapper='none')
    elif FLAGS.loss == "mse":
        return lambda img, ref: ru.image_loss(img, ref, loss='mse', tonemapper='none')
    elif FLAGS.loss == "logl1":
        return lambda img, ref: ru.image_loss(img, ref, loss='l1', tonemapper='log_srgb')
    elif FLAGS.loss == "logl2":
        return lambda img, ref: ru.image_loss(img, ref, loss='mse', tonemapper='log_srgb')
    elif FLAGS.loss == "relmse":
        return lambda img, ref: ru.image_loss(img, ref, loss='relmse', tonemapper='none')
    else:
        assert False

###############################################################################
# Mix background into a dataset image
###############################################################################

@torch.no_grad()
def prepare_batch(target, bg_type='black'):
    """Move a dataset batch to the GPU and composite its RGBA image over a background.

    bg_type selects the background: 'checker', 'black', 'white', 'reference'
    (the image itself) or 'random'. The RGB channels are lerped against the
    background by alpha; alpha itself is kept. Mutates and returns ``target``.
    """
    assert len(target['img'].shape) == 4, "Image shape should be [n, h, w, c]"
    if bg_type == 'checker':
        background = torch.tensor(util.checkerboard(target['img'].shape[1:3], 8), dtype=torch.float32, device='cuda')[None, ...]
    elif bg_type == 'black':
        background = torch.zeros(target['img'].shape[0:3] + (3,), dtype=torch.float32, device='cuda')
    elif bg_type == 'white':
        background = torch.ones(target['img'].shape[0:3] + (3,), dtype=torch.float32, device='cuda')
    elif bg_type == 'reference':
        background = target['img'][..., 0:3]
    elif bg_type == 'random':
        background = torch.rand(target['img'].shape[0:3] + (3,), dtype=torch.float32, device='cuda')
    else:
        assert False, "Unknown background type %s" % bg_type

    target['mv'] = target['mv'].cuda()
    target['mvp'] = target['mvp'].cuda()
    target['campos'] = target['campos'].cuda()
    target['img'] = target['img'].cuda()
    target['background'] = background

    # Alpha-composite the reference image over the chosen background, keep alpha.
    target['img'] = torch.cat((torch.lerp(background, target['img'][..., 0:3], target['img'][..., 3:4]), target['img'][..., 3:4]), dim=-1)

    return target

###############################################################################
# UV - map geometry & convert to a mesh
###############################################################################

@torch.no_grad()
def xatlas_uvmap(glctx, geometry, mat, FLAGS):
    """Parametrize the extracted mesh with xatlas and bake the MLP material
    into 2D kd/ks/normal textures, returning a new textured mesh."""
    eval_mesh = geometry.getMesh(mat)
    # getMesh may return either a dict holding the mesh under 'imesh' or the
    # mesh itself; unwrap only in the former case.
    # FIX: was a bare `except:` which also swallowed KeyboardInterrupt/SystemExit.
    try:
        eval_mesh = eval_mesh['imesh']
    except (TypeError, KeyError):
        pass

    # Create uvs with xatlas
    v_pos = eval_mesh.v_pos.detach().cpu().numpy()
    t_pos_idx = eval_mesh.t_pos_idx.detach().cpu().numpy()
    vmapping, indices, uvs = xatlas.parametrize(v_pos, t_pos_idx)

    # Convert to tensors
    indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)

    uvs = torch.tensor(uvs, dtype=torch.float32, device='cuda')
    faces = torch.tensor(indices_int64, dtype=torch.int64, device='cuda')

    new_mesh = mesh.Mesh(v_tex=uvs, t_tex_idx=faces, base=eval_mesh)

    mask, kd, ks = render.render_uv(glctx, new_mesh, FLAGS.texture_res, eval_mesh.material['kd_ks'])

    # Dilate all textures & use average color for background
    kd_avg = torch.sum(torch.sum(torch.sum(kd * mask, dim=0), dim=0), dim=0) / torch.sum(torch.sum(torch.sum(mask, dim=0), dim=0), dim=0)
    kd = util.dilate(kd, kd_avg[None, None, None, :], mask, 7)

    ks_avg = torch.sum(torch.sum(torch.sum(ks * mask, dim=0), dim=0), dim=0) / torch.sum(torch.sum(torch.sum(mask, dim=0), dim=0), dim=0)
    ks = util.dilate(ks, ks_avg[None, None, None, :], mask, 7)

    # Flat default normal map (pointing straight out of the surface).
    nrm_avg = torch.tensor([0, 0, 1], dtype=torch.float32, device="cuda")
    normal = nrm_avg[None, None, None, :].repeat(kd.shape[0], kd.shape[1], kd.shape[2], 1)

    new_mesh.material = mat.copy()
    del new_mesh.material['kd_ks']

    if FLAGS.transparency:
        kd = torch.cat((kd, torch.rand_like(kd[...,0:1])), dim=-1)
    print("kd shape", kd.shape)

    kd_min, kd_max = torch.tensor(FLAGS.kd_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.kd_max, dtype=torch.float32, device='cuda')
    ks_min, ks_max = torch.tensor(FLAGS.ks_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.ks_max, dtype=torch.float32, device='cuda')
    nrm_min, nrm_max = torch.tensor(FLAGS.nrm_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.nrm_max, dtype=torch.float32, device='cuda')
    new_mesh.material.update({
        'kd'     : texture.Texture2D(kd.clone().detach().requires_grad_(True), min_max=[kd_min, kd_max]),
        'ks'     : texture.Texture2D(ks.clone().detach().requires_grad_(True), min_max=[ks_min, ks_max]),
        'normal' : texture.Texture2D(normal.clone().detach().requires_grad_(True), min_max=[nrm_min, nrm_max]),
    })

    return new_mesh

###############################################################################
# Utility functions for material
###############################################################################

def initial_guess_material(geometry, mlp, FLAGS, init_mat=None):
    """Create the initial trainable material: one MLP texture with 6 output
    channels (kd RGB + ks), bounded per-channel by the FLAGS min/max ranges.
    Only the mlp=True path is implemented."""
    kd_min, kd_max = torch.tensor(FLAGS.kd_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.kd_max, dtype=torch.float32, device='cuda')
    ks_min, ks_max = torch.tensor(FLAGS.ks_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.ks_max, dtype=torch.float32, device='cuda')
    if mlp:
        # Drop kd's alpha channel: the joint MLP predicts kd RGB + ks (3 + 3).
        mlp_min = torch.cat((kd_min[0:3], ks_min), dim=0)
        mlp_max = torch.cat((kd_max[0:3], ks_max), dim=0)
        mlp_map_opt = mlptexture.MLPTexture3D(geometry.getAABB(), channels=6, min_max=[mlp_min, mlp_max], use_float16=FLAGS.use_float16)
        mat = {'kd_ks' : mlp_map_opt}
    else:
        raise NotImplementedError

    mat['bsdf'] = FLAGS.bsdf
    mat['no_perturbed_nrm'] = FLAGS.no_perturbed_nrm

    return mat

def initial_guess_material_knownkskd(geometry, mlp, FLAGS, init_mat=None):
    """Build a material dict from known kd/ks textures in ``init_mat``.

    FIX: the original indexed init_mat['kd'] / init_mat['ks'] BEFORE its
    `init_mat is not None` check, so the None branch (bsdf='pbr') was
    unreachable and a None init_mat raised TypeError. The check now guards
    every access to init_mat.
    """
    if init_mat is not None:
        mat = {
            'kd'   : init_mat['kd'],
            'ks'   : init_mat['ks'],
            'bsdf' : init_mat['bsdf'],
        }
    else:
        mat = {'bsdf' : 'pbr'}
    return mat

###############################################################################
# Validation & testing
###############################################################################

@torch.no_grad()
def validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS, denoiser=None):
    """Render one validation batch.

    Returns (result_image, result_dict): a side-by-side strip
    [opt | ref | mask_opt | mask_ref | msdf] and a dict with only the
    'ref'/'opt' images (used for metric computation).
    """
    result_dict = {}
    with torch.no_grad():
        buffers = geometry.render(glctx, target, lgt, opt_material, use_uv=False, denoiser=denoiser)['buffers']

        result_dict['ref'] = util.rgb_to_srgb(target['img'][...,0:3])[0]
        result_dict['opt'] = util.rgb_to_srgb(buffers['shaded'][...,0:3])[0]
        result_dict['mask_opt'] = buffers['shaded'][...,3:][0].expand(-1, -1, 3)
        result_dict['mask_ref'] = target['img'][...,3:][0].expand(-1, -1, 3)
        result_dict['msdf_image'] = buffers['msdf_image'][...,:][0].expand(-1, -1, 3).clamp(min=0, max=1)
        result_image = torch.cat([result_dict['opt'], result_dict['ref'], result_dict['mask_opt'], result_dict['mask_ref'], result_dict['msdf_image']], axis=1)

        # Rebuild the dict so callers only see the two images used for metrics.
        result_dict = {}
        result_dict['ref'] = util.rgb_to_srgb(target['img'][...,0:3])[0]
        result_dict['opt'] = util.rgb_to_srgb(buffers['shaded'][...,0:3])[0]

    return result_image, result_dict

@torch.no_grad()
def validate(glctx, geometry, opt_material, lgt, dataset_validate, out_dir, FLAGS, denoiser=None, save_viz=False):
    """Run the validation loop, write per-image MSE/PSNR to metrics.txt and
    optionally save per-view images. Returns the average PSNR."""
    mse_values = []
    psnr_values = []

    dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_validate.collate)

    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, 'metrics.txt'), 'w') as fout:
        fout.write('ID, MSE, PSNR\n')

        print("Running validation")
        for it, target in enumerate(tqdm.tqdm(dataloader_validate)):
            # Mix validation background
            target = prepare_batch(target, FLAGS.background)

            result_image, result_dict = validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS, denoiser=denoiser)

            # Compute metrics
            opt = torch.clamp(result_dict['opt'], 0.0, 1.0)
            ref = torch.clamp(result_dict['ref'], 0.0, 1.0)

            # FIX: dropped deprecated size_average=None/reduce=None arguments;
            # reduction='mean' alone is the supported spelling.
            mse = torch.nn.functional.mse_loss(opt, ref, reduction='mean').item()
            mse_values.append(float(mse))
            psnr = util.mse_to_psnr(mse)
            psnr_values.append(float(psnr))

            line = "%d, %1.8f, %1.8f\n" % (it, mse, psnr)
            fout.write(str(line))

            if save_viz:
                for k in result_dict.keys():
                    np_img = result_dict[k].detach().cpu().numpy()
                    util.save_image(out_dir + '/' + ('val_%06d_%s.png' % (it, k)), np_img)

        avg_mse = np.mean(np.array(mse_values))
        avg_psnr = np.mean(np.array(psnr_values))
        line = "AVERAGES: %1.4f, %2.3f\n" % (avg_mse, avg_psnr)
        fout.write(str(line))
        print("MSE, PSNR")
        print("%1.8f, %2.3f" % (avg_mse, avg_psnr))
    return avg_psnr

###############################################################################
# Main shape fitter function / optimization loop
###############################################################################

def optimize_mesh(
        denoiser,
        glctx,
        geometry,
        opt_material,
        lgt,
        dataset_train,
        dataset_validate,
        FLAGS,
        warmup_iter=0,
        log_interval=10,
        pass_idx=0,
        pass_name="",
        optimize_light=True,
        optimize_geometry=True,
        visualize=True,
        save_path=None
    ):
    """Jointly optimize geometry, material and (optionally) lighting.

    Runs FLAGS.iter training iterations over dataset_train, periodically
    visualizing/saving validation renders, and returns (geometry, opt_material).
    """

    # ==============================================================================================
    # Setup torch optimizer
    # ==============================================================================================
    # FLAGS.learning_rate may be a scalar, or a per-pass list; each entry may in
    # turn be a scalar or a [pos, mat, lgt] triple.
    learning_rate = FLAGS.learning_rate[pass_idx] if isinstance(FLAGS.learning_rate, list) or isinstance(FLAGS.learning_rate, tuple) else FLAGS.learning_rate
    learning_rate_pos = learning_rate[0] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate
    learning_rate_mat = learning_rate[1] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate
    learning_rate_lgt = learning_rate[2] if isinstance(learning_rate, list) or isinstance(learning_rate, tuple) else learning_rate * 6.0

    def lr_schedule(iter, fraction):
        # Linear warmup, then exponential falloff from [1.0, 0.1] over 5k epochs.
        if iter < warmup_iter:
            return iter / warmup_iter
        return max(0.0, 10**(-(iter - warmup_iter)*0.0002))

    # ==============================================================================================
    # Image loss
    # ==============================================================================================
    image_loss_fn = createLoss(FLAGS)

    params = list(material.get_parameters(opt_material))

    if optimize_light:
        optimizer_light = torch.optim.Adam((lgt.parameters() if lgt is not None else []), lr=learning_rate_lgt)
        scheduler_light = torch.optim.lr_scheduler.LambdaLR(optimizer_light, lr_lambda=lambda x: lr_schedule(x, 0.9))

    if optimize_geometry:
        if FLAGS.use_sdf_mlp:
            # mSDF gets the full positional LR unless it is itself an MLP.
            lr_msdf = learning_rate_pos * 1e-2 if FLAGS.use_msdf_mlp else learning_rate_pos
            deform_params = list(v[1] for v in geometry.named_parameters() if 'deform' in v[0]) if optimize_geometry else []
            msdf_params = list(v[1] for v in geometry.named_parameters() if 'msdf' in v[0]) if optimize_geometry else []
            sdf_params = list(v[1] for v in geometry.named_parameters() if 'sdf' in v[0] and 'msdf' not in v[0]) if optimize_geometry else []
            other_params = list(v[1] for v in geometry.named_parameters() if 'sdf' not in v[0] and 'msdf' not in v[0] and 'deform' not in v[0]) if optimize_geometry else []

            optimizer_mesh = torch.optim.Adam([
                {'params': deform_params, 'lr': learning_rate_pos},
                {'params': msdf_params, 'lr': lr_msdf},
                {'params': sdf_params, 'lr': learning_rate_pos * 1e-2},
                {'params': other_params, 'lr': learning_rate_pos * 1e-2},
            ], eps=1e-8)
        else:
            optimizer_mesh = torch.optim.Adam(geometry.parameters(), lr=learning_rate_pos)
        scheduler_mesh = torch.optim.lr_scheduler.LambdaLR(optimizer_mesh, lr_lambda=lambda x: lr_schedule(x, 0.9))

    optimizer = torch.optim.Adam(params, lr=learning_rate_mat)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: lr_schedule(x, 0.9))

    # ==============================================================================================
    # Training loop
    # ==============================================================================================
    img_cnt = 0
    img_loss_vec = []
    depth_loss_vec = []
    reg_loss_vec = []
    iter_dur_vec = []

    dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=FLAGS.batch, collate_fn=dataset_train.collate, shuffle=True)
    if visualize:
        dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_train.collate)

    def cycle(iterable):
        # Endlessly re-iterate `iterable` (a DataLoader restarts each pass).
        iterator = iter(iterable)
        while True:
            try:
                yield next(iterator)
            except StopIteration:
                iterator = iter(iterable)

    # FIX: dataloader_validate only exists when visualize=True; the original
    # evaluated cycle(dataloader_validate) unconditionally, which raised
    # NameError for visualize=False. v_it is only consumed inside the
    # `if visualize ...` branch below.
    v_it = cycle(dataloader_validate) if visualize else None

    for it, target in enumerate(dataloader_train):

        # Mix randomized background into dataset image
        target = prepare_batch(target, 'random')

        # ==========================================================================================
        # Display / save outputs. Do it before training so we get initial meshes
        # ==========================================================================================
        # Show/save image before training step (want to get correct rendering of input)
        if visualize and FLAGS.local_rank == 0 and it != 0:
            with torch.no_grad():
                display_image = FLAGS.display_interval and (it % FLAGS.display_interval == 0)
                save_image = FLAGS.save_interval and (it % FLAGS.save_interval == 0)
                if display_image or save_image:
                    save_mesh = True
                    if save_mesh:
                        os.makedirs(os.path.join(save_path, pass_name), exist_ok=True)
                        obj.write_obj(os.path.join(save_path, pass_name), geometry.getMesh(opt_material)['imesh'], save_material=False)
                    result_image, result_dict = validate_itr(glctx, prepare_batch(next(v_it), FLAGS.background), geometry, opt_material, lgt, FLAGS, denoiser=denoiser)
                    np_result_image = result_image.detach().cpu().numpy()
                    if display_image:
                        util.display_image(np_result_image, title='%d / %d' % (it, FLAGS.iter))
                    if save_image:
                        util.save_image(os.path.join(save_path, ('img_%s_%06d.png' % (pass_name, img_cnt))), np_result_image)
                        img_cnt = img_cnt + 1

        iter_start_time = time.time()

        # ==========================================================================================
        # Zero gradients
        # ==========================================================================================
        optimizer.zero_grad()
        if optimize_geometry:
            optimizer_mesh.zero_grad()
        if optimize_light:
            optimizer_light.zero_grad()

        # ==========================================================================================
        # Training
        # ==========================================================================================
        xfm_lgt = None
        if optimize_light:
            # NOTE(review): both light-transform branches below are disabled
            # with `False and ...` in the original; only update_pdf() is live.
            if False and FLAGS.camera_space_light:
                lgt.xfm(target['mv'])
            elif False and ('envlight_transform' in target and target['envlight_transform'] is not None):
                xfm_lgt = target['envlight_transform']
                lgt.xfm(xfm_lgt)
            lgt.update_pdf()

        img_loss, depth_loss, reg_loss = geometry.tick(glctx, target, lgt, opt_material, image_loss_fn, it, denoiser=denoiser)

        # ==========================================================================================
        # Final loss
        # ==========================================================================================
        total_loss = img_loss + reg_loss  # depth_loss is logged but not optimized

        img_loss_vec.append(img_loss.item())
        depth_loss_vec.append(depth_loss.item())
        reg_loss_vec.append(reg_loss.item())

        # ==========================================================================================
        # Backpropagate
        # ==========================================================================================
        total_loss.backward()
        # Manual gradient scaling: boost env-light gradients, damp hash-encoder gradients.
        if hasattr(lgt, 'base') and lgt.base.grad is not None and optimize_light:
            lgt.base.grad *= 64
        if 'kd_ks' in opt_material:
            opt_material['kd_ks'].encoder.params.grad /= 8.0
        if 'kd_ks_back' in opt_material:
            opt_material['kd_ks_back'].encoder.params.grad /= 8.0

        # Optionally clip gradients
        if FLAGS.clip_max_norm > 0.0:
            if optimize_geometry:
                # FIX: geometry.parameters() is a generator; `generator + list`
                # raised TypeError in the original whenever clipping was enabled.
                torch.nn.utils.clip_grad_norm_(list(geometry.parameters()) + params, FLAGS.clip_max_norm)
            else:
                torch.nn.utils.clip_grad_norm_(params, FLAGS.clip_max_norm)

        optimizer.step()
        scheduler.step()
        if optimize_geometry:
            optimizer_mesh.step()
            scheduler_mesh.step()
        if optimize_light:
            optimizer_light.step()
            scheduler_light.step()

        # ==========================================================================================
        # Clamp trainables to reasonable range
        # ==========================================================================================
        with torch.no_grad():
            if 'kd' in opt_material:
                opt_material['kd'].clamp_()
            if 'ks' in opt_material:
                opt_material['ks'].clamp_()
            if 'kd_back' in opt_material:
                opt_material['kd_back'].clamp_()
            if 'ks_back' in opt_material:
                opt_material['ks_back'].clamp_()
            if 'normal' in opt_material and not FLAGS.normal_only:
                opt_material['normal'].clamp_()
                opt_material['normal'].normalize_()
            if lgt is not None:
                lgt.clamp_(min=1e-4) # For some reason gradient disappears if light becomes 0
            geometry.clamp_deform()

        torch.cuda.current_stream().synchronize()
        iter_dur_vec.append(time.time() - iter_start_time)

        # ==========================================================================================
        # Logging
        # ==========================================================================================
        if it % log_interval == 0 and FLAGS.local_rank == 0:
            img_loss_avg = np.mean(np.asarray(img_loss_vec[-log_interval:]))
            depth_loss_avg = np.mean(np.asarray(depth_loss_vec[-log_interval:]))
            reg_loss_avg = np.mean(np.asarray(reg_loss_vec[-log_interval:]))
            iter_dur_avg = np.mean(np.asarray(iter_dur_vec[-log_interval:]))
            remaining_time = (FLAGS.iter-it)*iter_dur_avg
            print("iter=%5d, img_loss=%.6f, depth_loss=%.6f, reg_loss=%.6f, lr=%.5f, time=%.1f ms, rem=%s" % (it, img_loss_avg, depth_loss_avg, reg_loss_avg, optimizer.param_groups[0]['lr'], iter_dur_avg*1000, util.time_to_text(remaining_time)))
            sys.stdout.flush()

        if it == FLAGS.iter:
            break

    return geometry, opt_material

#----------------------------------------------------------------------------
# Main function.
#----------------------------------------------------------------------------

if __name__ == "__main__":
    # Command-line entry point: parse args, apply hard-coded defaults,
    # optionally load a JSON config, then build datasets / light / geometry
    # and run the optimization + validation passes.
    parser = argparse.ArgumentParser(description='nvdiffrec')
    parser.add_argument('--config', type=str, default=None, help='Config file')
    parser.add_argument('-i', '--iter', type=int, default=5000)
    parser.add_argument('-b', '--batch', type=int, default=1)
    parser.add_argument('-s', '--spp', type=int, default=1)
    parser.add_argument('-l', '--layers', type=int, default=1)
    parser.add_argument('-r', '--train-res', nargs=2, type=int, default=[512, 512])
    parser.add_argument('-dr', '--display-res', type=int, default=None)
    parser.add_argument('-tr', '--texture-res', nargs=2, type=int, default=[1024, 1024])
    parser.add_argument('-di', '--display-interval', type=int, default=0)
    parser.add_argument('-si', '--save-interval', type=int, default=1000)
    parser.add_argument('-lr', '--learning-rate', type=float, default=0.01)
    parser.add_argument('-mr', '--min-roughness', type=float, default=0.08)
    parser.add_argument('-mip', '--custom-mip', action='store_true', default=False)
    parser.add_argument('-rt', '--random-textures', action='store_true', default=False)
    parser.add_argument('-bg', '--background', default='checker', choices=['black', 'white', 'checker', 'reference'])
    parser.add_argument('--loss', default='logl1', choices=['logl1', 'logl2', 'mse', 'smape', 'relmse'])
    parser.add_argument('-o', '--out-dir', type=str, default=None)
    parser.add_argument('-rm', '--ref_mesh', type=str)
    parser.add_argument('-bm', '--base-mesh', type=str, default=None)
    # NOTE(review): argparse `type=bool` treats any non-empty string (incl.
    # "False") as True; `store_true` or a str2bool helper would be safer.
    parser.add_argument('--validate', type=bool, default=True)

    # Render specific arguments
    parser.add_argument('--n_samples', type=int, default=4)
    parser.add_argument('--bsdf', type=str, default='pbr', choices=['pbr', 'diffuse', 'white'])

    # Denoiser specific arguments
    parser.add_argument('--denoiser', default='bilateral', choices=['none', 'bilateral'])
    parser.add_argument('--denoiser_demodulate', type=bool, default=True)

    # Regularizer weights for the G-Shell geometry losses.
    parser.add_argument('--msdf_reg_open_scale', type=float, default=1e-6)
    parser.add_argument('--msdf_reg_close_scale', type=float, default=3e-4)
    parser.add_argument('--eikonal_scale', type=float, default=5e-3)
    parser.add_argument('--sdf_regularizer', type=float, default=0.2)
    parser.add_argument('--trainset_path', type=str)
    parser.add_argument('--testset_path', type=str, default='')

    FLAGS = parser.parse_args()

    # Hard-coded settings (not exposed on the CLI; overridable via --config).
    FLAGS.mtl_override        = None                     # Override material of model
    FLAGS.gshell_grid         = 64                       # Resolution of initial tet grid. We provide 64 and 128 resolution grids.
                                                         # Other resolutions can be generated with https://github.com/crawforddoran/quartet
                                                         # We include examples in data/tets/generate_tets.py
    FLAGS.mesh_scale          = 3.6                      # Scale of tet grid box. Adjust to cover the model
    FLAGS.envlight            = None                     # HDR environment probe
    FLAGS.env_scale           = 1.0                      # Env map intensity multiplier
    FLAGS.probe_res           = 256                      # Env map probe resolution
    FLAGS.learn_lighting      = True                     # Enable optimization of env lighting
    FLAGS.display             = None                     # Configure validation window/display. E.g. [{"bsdf" : "kd"}, {"bsdf" : "ks"}]
    FLAGS.transparency        = False                    # Enabled transparency through depth peeling
    FLAGS.lock_light          = False                    # Disable light optimization in the second pass
    FLAGS.lock_pos            = False                    # Disable vertex position optimization in the second pass
    # FLAGS.sdf_regularizer     = 0.2                      # Weight for sdf regularizer.
    FLAGS.laplace             = "relative"               # Mesh Laplacian ["absolute", "relative"]
    FLAGS.laplace_scale       = 3000.0                   # Weight for Laplace regularizer. Default is relative with large weight
    FLAGS.pre_load            = True                     # Pre-load entire dataset into memory for faster training
    FLAGS.no_perturbed_nrm    = False                    # Disable normal map
    FLAGS.decorrelated        = False                    # Use decorrelated sampling in forward and backward passes
    # Per-channel clamp ranges for kd / ks / normal.
    FLAGS.kd_min              = [ 0.0, 0.0, 0.0, 0.0]
    FLAGS.kd_max              = [ 1.0, 1.0, 1.0, 1.0]
    FLAGS.ks_min              = [ 0.0, 0.001, 0.0]
    FLAGS.ks_max              = [ 0.0, 1.0, 1.0]
    FLAGS.nrm_min             = [-1.0, -1.0, 0.0]
    FLAGS.nrm_max             = [ 1.0, 1.0, 1.0]
    FLAGS.clip_max_norm       = 0.0
    FLAGS.cam_near_far        = [0.1, 1000.0]
    # Loss term weights.
    FLAGS.lambda_kd           = 0.1
    FLAGS.lambda_ks           = 0.05
    FLAGS.lambda_nrm          = 0.025
    FLAGS.lambda_nrm2         = 0.25
    FLAGS.lambda_chroma       = 0.0
    FLAGS.lambda_diffuse      = 0.15
    FLAGS.lambda_specular     = 0.0025

    FLAGS.random_lgt          = False
    FLAGS.normal_only         = False
    FLAGS.use_img_2nd_layer   = False
    FLAGS.use_depth           = False
    FLAGS.use_depth_2nd_layer = False
    FLAGS.use_tanh_deform     = False

    # SDF-MLP geometry settings.
    FLAGS.use_sdf_mlp         = True
    FLAGS.use_msdf_mlp        = False
    FLAGS.use_eikonal         = True
    FLAGS.sdf_mlp_pretrain_steps = 10000
    FLAGS.use_mesh_msdf_reg   = True
    FLAGS.sphere_init         = False
    FLAGS.sphere_init_norm    = 2.0
    FLAGS.pretrained_sdf_mlp_path = f'./data/pretrained_mlp_{FLAGS.gshell_grid}_polycam.pt'
    FLAGS.n_hidden = 6
    FLAGS.d_hidden = 256
    FLAGS.n_freq = 6
    FLAGS.skip_in = [3]

    FLAGS.use_float16         = False
    FLAGS.visualize_watertight = False

    FLAGS.local_rank = 0
    FLAGS.multi_gpu  = "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) > 1
    if FLAGS.multi_gpu:
        # torchrun-style distributed setup; defaults for single-node runs.
        if "MASTER_ADDR" not in os.environ:
            os.environ["MASTER_ADDR"] = 'localhost'
        if "MASTER_PORT" not in os.environ:
            os.environ["MASTER_PORT"] = '23456'

        FLAGS.local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(FLAGS.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")

    # JSON config overrides any of the above, including CLI values.
    # NOTE(review): the file handle opened here is never closed.
    if FLAGS.config is not None:
        data = json.load(open(FLAGS.config, 'r'))
        for key in data:
            FLAGS.__dict__[key] = data[key]

    if FLAGS.display_res is None:
        FLAGS.display_res = FLAGS.train_res

    if FLAGS.local_rank == 0:
        print("Config / Flags:")
        print("---------")
        for key in FLAGS.__dict__.keys():
            print(key, FLAGS.__dict__[key])
        print("---------")

    os.makedirs(FLAGS.out_dir, exist_ok=True)

    glctx = dr.RasterizeGLContext()
    # NOTE(review): glctx_display appears unused in the rest of this script.
    glctx_display = glctx if FLAGS.batch < 16 else dr.RasterizeGLContext()  # Context for display

    mtl_default = None

    # ==============================================================================================
    #  Create data pipeline
    # ==============================================================================================
    data_root = FLAGS.trainset_path
    # Train and validation share the same transforms.json; train is capped at 1e6 examples.
    dataset_train    = DatasetNERF(os.path.join(data_root, 'transforms.json'), FLAGS, examples=int(1e6))
    dataset_validate = DatasetNERF(os.path.join(data_root, 'transforms.json'), FLAGS)

    # ==============================================================================================
    #  Create env light with trainable parameters
    # ==============================================================================================
    lgt = None
    if FLAGS.learn_lighting:
        lgt = light.create_trainable_env_rnd(FLAGS.probe_res, scale=0.0, bias=0.5)
        # lgt = light.create_trainable_env_rnd(FLAGS.probe_res, scale=0.0, bias=0.1)
    else:
        lgt = light.load_env(FLAGS.envlight, scale=FLAGS.env_scale, res=[FLAGS.probe_res, FLAGS.probe_res])

    # ==============================================================================================
    #  Setup denoiser
    # ==============================================================================================
    denoiser = None
    if FLAGS.denoiser == 'bilateral':
        denoiser = BilateralDenoiser().cuda()
    else:
        assert FLAGS.denoiser == 'none', "Invalid denoiser %s" % FLAGS.denoiser

    # Setup geometry for optimization
    geometry = GShellTetsGeometry(FLAGS.gshell_grid, FLAGS.mesh_scale, FLAGS)

    # Setup textures, make initial guess from reference if possible
    if not FLAGS.normal_only:
        mat = initial_guess_material(geometry, True, FLAGS, mtl_default)
    else:
        # NOTE(review): mtl_default is None here, and the normal_only material
        # has no 'kd_ks' entry — the save/cleanup below would raise KeyError.
        mat = initial_guess_material_knownkskd(geometry, True, FLAGS, mtl_default)
        mat['no_perturbed_nrm'] = True

    # Run optimization
    geometry, mat = optimize_mesh(denoiser, glctx, geometry, mat, lgt, dataset_train, dataset_validate, FLAGS, pass_idx=0, pass_name="pass1", optimize_light=FLAGS.learn_lighting, save_path=FLAGS.out_dir)

    validate(glctx, geometry, mat, lgt, dataset_validate, os.path.join(FLAGS.out_dir, "validate"), FLAGS, denoiser=denoiser, save_viz=True)

    with torch.no_grad():
        # Persist the optimized geometry, material MLP and environment probe.
        os.makedirs(os.path.join(FLAGS.out_dir, "mesh"), exist_ok=True)
        torch.save(geometry.state_dict(), os.path.join(FLAGS.out_dir, "mesh/model.pt"))
        torch.save(mat['kd_ks'].state_dict(), os.path.join(FLAGS.out_dir, "mesh/mtl.pt"))
        light.save_env_map(os.path.join(FLAGS.out_dir, "mesh/probe.hdr"), lgt)

        # Create textured mesh from result
        base_mesh = geometry.getMesh(mat)['imesh']

        # Dump mesh for debugging.
        os.makedirs(os.path.join(FLAGS.out_dir, "mesh"), exist_ok=True)
        obj.write_obj(os.path.join(FLAGS.out_dir, "mesh/"), base_mesh, save_material=False)

        # Free temporaries / cached memory
        torch.cuda.empty_cache()
        mat['kd_ks'].cleanup()
        del mat['kd_ks']
        if 'kd_ks_back' in mat:
            mat['kd_ks_back'].cleanup()
            del mat['kd_ks_back']

        # Free temporaries / cached memory
        torch.cuda.empty_cache()
        del mat


================================================
FILE: train_gshelltet_synthetic.py
================================================
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
"""Optimization entry points for fitting a G-Shell tet geometry to synthetic data.

Contains the image-loss factory, batch preparation, UV unwrapping via xatlas,
material initialization helpers, validation, and the main `optimize_mesh`
training loop. The `__main__` driver lives below this module body.
"""

import os
import sys
import time
import argparse
import json

import numpy as np
import torch
import nvdiffrast.torch as dr
import xatlas

# Import data readers / generators
from dataset import DatasetMesh, DatasetNERF, DatasetLLFF

# Import topology / geometry trainers
from geometry.gshell_tets_geometry import GShellTetsGeometry

import render.renderutils as ru
from render import obj
from render import material
from render import util
from render import mesh
from render import texture
from render import mlptexture
from render import light
from render import render

from denoiser.denoiser import BilateralDenoiser

RADIUS = 3.0

# Enable to debug back-prop anomalies
# torch.autograd.set_detect_anomaly(True)

###############################################################################
# Loss setup
###############################################################################

@torch.no_grad()
def createLoss(FLAGS):
    """Return an image-loss callable `(img, ref) -> loss` selected by FLAGS.loss."""
    if FLAGS.loss == "smape":
        return lambda img, ref: ru.image_loss(img, ref, loss='smape', tonemapper='none')
    elif FLAGS.loss == "mse":
        return lambda img, ref: ru.image_loss(img, ref, loss='mse', tonemapper='none')
    elif FLAGS.loss == "logl1":
        return lambda img, ref: ru.image_loss(img, ref, loss='l1', tonemapper='log_srgb')
    elif FLAGS.loss == "logl2":
        return lambda img, ref: ru.image_loss(img, ref, loss='mse', tonemapper='log_srgb')
    elif FLAGS.loss == "relmse":
        return lambda img, ref: ru.image_loss(img, ref, loss='relmse', tonemapper='none')
    else:
        assert False

###############################################################################
# Mix background into a dataset image
###############################################################################

@torch.no_grad()
def prepare_batch(target, bg_type='black'):
    """Move a batch to CUDA and composite the RGBA image over a chosen background.

    `target['img']` is [n, h, w, 4]; the RGB channels are alpha-blended with the
    selected background and the alpha channel is kept in the result.
    """
    assert len(target['img'].shape) == 4, "Image shape should be [n, h, w, c]"
    if bg_type == 'checker':
        background = torch.tensor(util.checkerboard(target['img'].shape[1:3], 8), dtype=torch.float32, device='cuda')[None, ...]
    elif bg_type == 'black':
        background = torch.zeros(target['img'].shape[0:3] + (3,), dtype=torch.float32, device='cuda')
    elif bg_type == 'white':
        background = torch.ones(target['img'].shape[0:3] + (3,), dtype=torch.float32, device='cuda')
    elif bg_type == 'reference':
        background = target['img'][..., 0:3]
    elif bg_type == 'random':
        background = torch.rand(target['img'].shape[0:3] + (3,), dtype=torch.float32, device='cuda')
    else:
        assert False, "Unknown background type %s" % bg_type

    target['mv'] = target['mv'].cuda()
    target['mvp'] = target['mvp'].cuda()
    target['campos'] = target['campos'].cuda()
    target['img'] = target['img'].cuda()
    target['background'] = background

    # Alpha-composite RGB over the background, keeping alpha in channel 3.
    target['img'] = torch.cat((torch.lerp(background, target['img'][..., 0:3], target['img'][..., 3:4]), target['img'][..., 3:4]), dim=-1)

    return target

###############################################################################
# UV - map geometry & convert to a mesh
###############################################################################

@torch.no_grad()
def xatlas_uvmap(glctx, geometry, mat, FLAGS):
    """Bake the MLP material into 2D textures on a xatlas-parameterized mesh.

    Returns a new mesh whose material contains trainable 'kd', 'ks' and
    'normal' Texture2D maps (the volumetric 'kd_ks' entry is removed).
    """
    eval_mesh = geometry.getMesh(mat)
    try:
        # getMesh may return either a dict holding 'imesh' or the mesh itself.
        eval_mesh = eval_mesh['imesh']
    except (TypeError, KeyError):  # was a bare `except: pass`; narrowed to the cases above
        pass

    # Create uvs with xatlas
    v_pos = eval_mesh.v_pos.detach().cpu().numpy()
    t_pos_idx = eval_mesh.t_pos_idx.detach().cpu().numpy()
    vmapping, indices, uvs = xatlas.parametrize(v_pos, t_pos_idx)

    # Convert to tensors
    indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64)
    uvs = torch.tensor(uvs, dtype=torch.float32, device='cuda')
    faces = torch.tensor(indices_int64, dtype=torch.int64, device='cuda')

    new_mesh = mesh.Mesh(v_tex=uvs, t_tex_idx=faces, base=eval_mesh)

    mask, kd, ks = render.render_uv(glctx, new_mesh, FLAGS.texture_res, eval_mesh.material['kd_ks'])

    # Dilate all textures & use average color for background
    kd_avg = torch.sum(torch.sum(torch.sum(kd * mask, dim=0), dim=0), dim=0) / torch.sum(torch.sum(torch.sum(mask, dim=0), dim=0), dim=0)
    kd = util.dilate(kd, kd_avg[None, None, None, :], mask, 7)

    ks_avg = torch.sum(torch.sum(torch.sum(ks * mask, dim=0), dim=0), dim=0) / torch.sum(torch.sum(torch.sum(mask, dim=0), dim=0), dim=0)
    ks = util.dilate(ks, ks_avg[None, None, None, :], mask, 7)

    # Flat default normal (pointing out of the tangent plane).
    nrm_avg = torch.tensor([0, 0, 1], dtype=torch.float32, device="cuda")
    normal = nrm_avg[None, None, None, :].repeat(kd.shape[0], kd.shape[1], kd.shape[2], 1)

    new_mesh.material = mat.copy()
    del new_mesh.material['kd_ks']

    if FLAGS.transparency:
        # Random initial alpha channel for depth-peeled transparency.
        kd = torch.cat((kd, torch.rand_like(kd[..., 0:1])), dim=-1)
    print("kd shape", kd.shape)

    kd_min, kd_max = torch.tensor(FLAGS.kd_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.kd_max, dtype=torch.float32, device='cuda')
    ks_min, ks_max = torch.tensor(FLAGS.ks_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.ks_max, dtype=torch.float32, device='cuda')
    nrm_min, nrm_max = torch.tensor(FLAGS.nrm_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.nrm_max, dtype=torch.float32, device='cuda')

    new_mesh.material.update({
        'kd'     : texture.Texture2D(kd.clone().detach().requires_grad_(True), min_max=[kd_min, kd_max]),
        'ks'     : texture.Texture2D(ks.clone().detach().requires_grad_(True), min_max=[ks_min, ks_max]),
        'normal' : texture.Texture2D(normal.clone().detach().requires_grad_(True), min_max=[nrm_min, nrm_max]),
    })

    return new_mesh

###############################################################################
# Utility functions for material
###############################################################################

def initial_guess_material(geometry, mlp, FLAGS, init_mat=None):
    """Build the initial material: a 6-channel (kd rgb + ks) volumetric MLP texture."""
    kd_min, kd_max = torch.tensor(FLAGS.kd_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.kd_max, dtype=torch.float32, device='cuda')
    ks_min, ks_max = torch.tensor(FLAGS.ks_min, dtype=torch.float32, device='cuda'), torch.tensor(FLAGS.ks_max, dtype=torch.float32, device='cuda')
    if mlp:
        mlp_min = torch.cat((kd_min[0:3], ks_min), dim=0)
        mlp_max = torch.cat((kd_max[0:3], ks_max), dim=0)
        mlp_map_opt = mlptexture.MLPTexture3D(geometry.getAABB(), channels=6, min_max=[mlp_min, mlp_max], use_float16=FLAGS.use_float16)
        mat = {'kd_ks' : mlp_map_opt}
    else:
        raise NotImplementedError

    mat['bsdf'] = FLAGS.bsdf
    mat['no_perturbed_nrm'] = FLAGS.no_perturbed_nrm

    return mat

def initial_guess_material_knownkskd(geometry, mlp, FLAGS, init_mat=None):
    """Build a material from known kd/ks maps of `init_mat`.

    Raises ValueError when `init_mat` is None (the original code subscripted
    `init_mat` before its None-check, which raised an opaque TypeError).
    """
    if init_mat is None:
        raise ValueError("initial_guess_material_knownkskd requires init_mat with 'kd' and 'ks' entries")
    mat = {
        'kd'   : init_mat['kd'],
        'ks'   : init_mat['ks'],
        'bsdf' : init_mat['bsdf'],
    }
    return mat

###############################################################################
# Validation & testing
###############################################################################

@torch.no_grad()
def validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS, denoiser=None):
    """Render one validation batch; return a side-by-side image and named buffers."""
    result_dict = {}
    with torch.no_grad():
        buffers = geometry.render(glctx, target, lgt, opt_material, use_uv=False, denoiser=denoiser)['buffers']
        result_dict['ref'] = util.rgb_to_srgb(target['img'][..., 0:3])[0]
        result_dict['opt'] = util.rgb_to_srgb(buffers['shaded'][..., 0:3])[0]
        result_dict['mask_opt'] = buffers['shaded'][..., 3:][0].expand(-1, -1, 3)
        result_dict['mask_ref'] = target['img'][..., 3:][0].expand(-1, -1, 3)
        result_dict['msdf_image'] = buffers['msdf_image'][..., :][0].expand(-1, -1, 3).clamp(min=0, max=1)
        result_image = torch.cat([result_dict['opt'], result_dict['ref'], result_dict['mask_opt'],
                                  result_dict['mask_ref'], result_dict['msdf_image']], axis=1)

        if FLAGS.display is not None:
            for layer in FLAGS.display:
                if 'latlong' in layer and layer['latlong']:
                    result_dict['light_image'] = lgt.generate_image(FLAGS.display_res)
                    # Simple Reinhard tonemap before sRGB conversion.
                    result_dict['light_image'] = util.rgb_to_srgb(result_dict['light_image'] / (1 + result_dict['light_image']))
                    result_image = torch.cat([result_image, result_dict['light_image']], axis=1)

    return result_image, result_dict

@torch.no_grad()
def validate(glctx, geometry, opt_material, lgt, dataset_validate, out_dir, FLAGS, denoiser=None, save_viz=False):
    """Run the validation set, write per-image MSE/PSNR to metrics.txt, return avg PSNR."""

    # ==============================================================================================
    #  Validation loop
    # ==============================================================================================
    mse_values = []
    psnr_values = []

    dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_validate.collate)

    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, 'metrics.txt'), 'w') as fout:
        fout.write('ID, MSE, PSNR\n')

        print("Running validation")
        for it, target in enumerate(dataloader_validate):
            # Mix validation background
            target = prepare_batch(target, FLAGS.background)

            result_image, result_dict = validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS, denoiser=denoiser)

            # Compute metrics (deprecated size_average/reduce args dropped; reduction='mean' is equivalent)
            opt = torch.clamp(result_dict['opt'], 0.0, 1.0)
            ref = torch.clamp(result_dict['ref'], 0.0, 1.0)

            mse = torch.nn.functional.mse_loss(opt, ref, reduction='mean').item()
            mse_values.append(float(mse))
            psnr = util.mse_to_psnr(mse)
            psnr_values.append(float(psnr))

            line = "%d, %1.8f, %1.8f\n" % (it, mse, psnr)
            fout.write(line)

            if save_viz:
                for k in result_dict.keys():
                    np_img = result_dict[k].detach().cpu().numpy()
                    util.save_image(out_dir + '/' + ('val_%06d_%s.png' % (it, k)), np_img)

        avg_mse = np.mean(np.array(mse_values))
        avg_psnr = np.mean(np.array(psnr_values))
        line = "AVERAGES: %1.4f, %2.3f\n" % (avg_mse, avg_psnr)
        fout.write(line)
        print("MSE,      PSNR")
        print("%1.8f, %2.3f" % (avg_mse, avg_psnr))
    return avg_psnr

###############################################################################
# Main shape fitter function / optimization loop
###############################################################################

def optimize_mesh(
    denoiser,
    glctx,
    geometry,
    opt_material,
    lgt,
    dataset_train,
    dataset_validate,
    FLAGS,
    warmup_iter=0,
    log_interval=10,
    pass_idx=0,
    pass_name="",
    optimize_light=True,
    optimize_geometry=True,
    visualize=True,
    save_path=None
    ):
    """Jointly optimize geometry, material and (optionally) lighting.

    Returns the (geometry, opt_material) pair after FLAGS.iter iterations.
    """

    # ==============================================================================================
    #  Setup torch optimizer
    # ==============================================================================================
    # FLAGS.learning_rate may be a scalar, a per-pass list, or a per-pass list of
    # [pos, mat, lgt] triplets; unpack accordingly.
    learning_rate = FLAGS.learning_rate[pass_idx] if isinstance(FLAGS.learning_rate, (list, tuple)) else FLAGS.learning_rate
    learning_rate_pos = learning_rate[0] if isinstance(learning_rate, (list, tuple)) else learning_rate
    learning_rate_mat = learning_rate[1] if isinstance(learning_rate, (list, tuple)) else learning_rate
    learning_rate_lgt = learning_rate[2] if isinstance(learning_rate, (list, tuple)) else learning_rate * 6.0

    def lr_schedule(step, fraction):
        # `fraction` is kept for signature compatibility with callers; unused.
        if step < warmup_iter:
            return step / warmup_iter
        return max(0.0, 10**(-(step - warmup_iter)*0.0002))  # Exponential falloff from [1.0, 0.1] over 5k epochs.

    # ==============================================================================================
    #  Image loss
    # ==============================================================================================
    image_loss_fn = createLoss(FLAGS)

    params = list(material.get_parameters(opt_material))

    if optimize_light:
        optimizer_light = torch.optim.Adam((lgt.parameters() if lgt is not None else []), lr=learning_rate_lgt)
        scheduler_light = torch.optim.lr_scheduler.LambdaLR(optimizer_light, lr_lambda=lambda x: lr_schedule(x, 0.9))

    if optimize_geometry:
        if FLAGS.use_sdf_mlp:
            # mSDF parameters get a reduced LR only when driven by their own MLP.
            lr_msdf = learning_rate_pos * 1e-2 if FLAGS.use_msdf_mlp else learning_rate_pos
            deform_params = [p for n, p in geometry.named_parameters() if 'deform' in n]
            msdf_params   = [p for n, p in geometry.named_parameters() if 'msdf' in n]
            sdf_params    = [p for n, p in geometry.named_parameters() if 'sdf' in n and 'msdf' not in n]
            other_params  = [p for n, p in geometry.named_parameters() if 'sdf' not in n and 'msdf' not in n and 'deform' not in n]
            optimizer_mesh = torch.optim.Adam([
                {'params': deform_params, 'lr': learning_rate_pos},
                {'params': msdf_params, 'lr': lr_msdf},
                {'params': sdf_params, 'lr': learning_rate_pos * 1e-2},
                {'params': other_params, 'lr': learning_rate_pos * 1e-2},
            ], eps=1e-8)
        else:
            optimizer_mesh = torch.optim.Adam(geometry.parameters(), lr=learning_rate_pos)
        scheduler_mesh = torch.optim.lr_scheduler.LambdaLR(optimizer_mesh, lr_lambda=lambda x: lr_schedule(x, 0.9))

    optimizer = torch.optim.Adam(params, lr=learning_rate_mat)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: lr_schedule(x, 0.9))

    # ==============================================================================================
    #  Training loop
    # ==============================================================================================
    img_cnt = 0
    img_loss_vec = []
    depth_loss_vec = []
    reg_loss_vec = []
    iter_dur_vec = []

    dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=FLAGS.batch, collate_fn=dataset_train.collate, shuffle=True)

    def cycle(iterable):
        # Endless iterator over a (re-startable) iterable.
        iterator = iter(iterable)
        while True:
            try:
                yield next(iterator)
            except StopIteration:
                iterator = iter(iterable)

    if visualize:
        dataloader_validate = torch.utils.data.DataLoader(dataset_validate, batch_size=1, collate_fn=dataset_train.collate)
        # v_it is only bound (and only used) when visualize is True; binding it
        # unconditionally referenced an undefined dataloader when visualize=False.
        v_it = cycle(dataloader_validate)

    for it, target in enumerate(dataloader_train):

        # Mix randomized background into dataset image
        target = prepare_batch(target, 'random')

        # ==============================================================================================
        #  Display / save outputs. Do it before training so we get initial meshes
        # ==============================================================================================

        # Show/save image before training step (want to get correct rendering of input)
        if visualize and FLAGS.local_rank == 0 and it != 0:
            with torch.no_grad():
                display_image = FLAGS.display_interval and (it % FLAGS.display_interval == 0)
                save_image = FLAGS.save_interval and (it % FLAGS.save_interval == 0)
                if display_image or save_image:
                    save_mesh = True
                    if save_mesh:
                        os.makedirs(os.path.join(save_path, pass_name), exist_ok=True)
                        obj.write_obj(os.path.join(save_path, pass_name), geometry.getMesh(opt_material)['imesh'], save_material=False)
                    result_image, result_dict = validate_itr(glctx, prepare_batch(next(v_it), FLAGS.background), geometry, opt_material, lgt, FLAGS, denoiser=denoiser)
                    np_result_image = result_image.detach().cpu().numpy()
                    if display_image:
                        util.display_image(np_result_image, title='%d / %d' % (it, FLAGS.iter))
                    if save_image:
                        util.save_image(os.path.join(save_path, ('img_%s_%06d.png' % (pass_name, img_cnt))), np_result_image)
                        img_cnt = img_cnt + 1

        iter_start_time = time.time()

        # ==============================================================================================
        #  Zero gradients
        # ==============================================================================================
        optimizer.zero_grad()
        if optimize_geometry:
            optimizer_mesh.zero_grad()
        if optimize_light:
            optimizer_light.zero_grad()

        # ==============================================================================================
        #  Training
        # ==============================================================================================
        if optimize_light:
            lgt.update_pdf()

        img_loss, depth_loss, reg_loss = geometry.tick(
            glctx, target, lgt, opt_material, image_loss_fn, it, denoiser=denoiser)

        # ==============================================================================================
        #  Final loss
        # ==============================================================================================
        total_loss = img_loss + reg_loss

        img_loss_vec.append(img_loss.item())
        depth_loss_vec.append(depth_loss.item())
        reg_loss_vec.append(reg_loss.item())

        # ==============================================================================================
        #  Backpropagate
        # ==============================================================================================
        total_loss.backward()

        # Manual gradient rescaling for env light and hash-encoder params.
        if hasattr(lgt, 'base') and lgt.base.grad is not None and optimize_light:
            lgt.base.grad *= 64
        if 'kd_ks' in opt_material:
            opt_material['kd_ks'].encoder.params.grad /= 8.0
        if 'kd_ks_back' in opt_material:
            opt_material['kd_ks_back'].encoder.params.grad /= 8.0

        # Optionally clip gradients.
        # NOTE: original code did `geometry.parameters() + params` which raises
        # TypeError (generator + list); materialize the generator first.
        if FLAGS.clip_max_norm > 0.0:
            if optimize_geometry:
                torch.nn.utils.clip_grad_norm_(list(geometry.parameters()) + params, FLAGS.clip_max_norm)
            else:
                torch.nn.utils.clip_grad_norm_(params, FLAGS.clip_max_norm)

        optimizer.step()
        scheduler.step()

        if optimize_geometry:
            optimizer_mesh.step()
            scheduler_mesh.step()

        if optimize_light:
            optimizer_light.step()
            scheduler_light.step()

        # ==============================================================================================
        #  Clamp trainables to reasonable range
        # ==============================================================================================
        with torch.no_grad():
            if 'kd' in opt_material:
                opt_material['kd'].clamp_()
            if 'ks' in opt_material:
                opt_material['ks'].clamp_()
            if 'kd_back' in opt_material:
                opt_material['kd_back'].clamp_()
            if 'ks_back' in opt_material:
                opt_material['ks_back'].clamp_()
            if 'normal' in opt_material and not FLAGS.normal_only:
                opt_material['normal'].clamp_()
                opt_material['normal'].normalize_()
            if lgt is not None:
                # lgt.clamp_(min=0.01) # For some reason gradient dissapears if light becomes 0
                lgt.clamp_(min=1e-4)  # For some reason gradient dissapears if light becomes 0
            geometry.clamp_deform()

        torch.cuda.current_stream().synchronize()
        iter_dur_vec.append(time.time() - iter_start_time)

        # ==============================================================================================
        #  Logging
        # ==============================================================================================
        if it % log_interval == 0 and FLAGS.local_rank == 0:
            img_loss_avg = np.mean(np.asarray(img_loss_vec[-log_interval:]))
            depth_loss_avg = np.mean(np.asarray(depth_loss_vec[-log_interval:]))
            reg_loss_avg = np.mean(np.asarray(reg_loss_vec[-log_interval:]))
            iter_dur_avg = np.mean(np.asarray(iter_dur_vec[-log_interval:]))

            remaining_time = (FLAGS.iter-it)*iter_dur_avg
            print("iter=%5d, img_loss=%.6f, depth_loss=%.6f, reg_loss=%.6f, lr=%.5f, time=%.1f ms, rem=%s" %
                (it, img_loss_avg, depth_loss_avg, reg_loss_avg, optimizer.param_groups[0]['lr'], iter_dur_avg*1000, util.time_to_text(remaining_time)))
            sys.stdout.flush()

        if it == FLAGS.iter:
            break

    return geometry, opt_material

#----------------------------------------------------------------------------
# Main function.
#----------------------------------------------------------------------------

if __name__ == "__main__":

    def _str2bool(v):
        # FIX: argparse `type=bool` treats every non-empty string (including
        # "False") as True; parse common spellings explicitly instead.
        return str(v).lower() in ('true', '1', 'yes', 'y')

    parser = argparse.ArgumentParser(description='nvdiffrec')
    parser.add_argument('--config', type=str, default=None, help='Config file')
    parser.add_argument('-i', '--iter', type=int, default=5000)
    parser.add_argument('-b', '--batch', type=int, default=1)
    parser.add_argument('-s', '--spp', type=int, default=1)
    parser.add_argument('-l', '--layers', type=int, default=1)
    parser.add_argument('-r', '--train-res', nargs=2, type=int, default=[512, 512])
    parser.add_argument('-dr', '--display-res', type=int, default=None)
    parser.add_argument('-tr', '--texture-res', nargs=2, type=int, default=[1024, 1024])
    parser.add_argument('-di', '--display-interval', type=int, default=0)
    parser.add_argument('-si', '--save-interval', type=int, default=1000)
    parser.add_argument('-lr', '--learning-rate', type=float, default=0.01)
    parser.add_argument('-mr', '--min-roughness', type=float, default=0.08)
    parser.add_argument('-mip', '--custom-mip', action='store_true', default=False)
    parser.add_argument('-rt', '--random-textures', action='store_true', default=False)
    parser.add_argument('-bg', '--background', default='checker', choices=['black', 'white', 'checker', 'reference'])
    parser.add_argument('--loss', default='logl1', choices=['logl1', 'logl2', 'mse', 'smape', 'relmse'])
    parser.add_argument('-o', '--out-dir', type=str, default=None)
    parser.add_argument('-rm', '--ref_mesh', type=str)
    parser.add_argument('-bm', '--base-mesh', type=str, default=None)
    parser.add_argument('--validate', type=_str2bool, default=True)

    # Render specific arguments
    parser.add_argument('--n_samples', type=int, default=4)
    parser.add_argument('--bsdf', type=str, default='pbr', choices=['pbr', 'diffuse', 'white'])

    # Denoiser specific arguments
    parser.add_argument('--denoiser', default='bilateral', choices=['none', 'bilateral'])
    parser.add_argument('--denoiser_demodulate', type=_str2bool, default=True)

    parser.add_argument('--index', type=int)
    parser.add_argument('--msdf_reg_open_scale', type=float, default=1e-6)
    parser.add_argument('--msdf_reg_close_scale', type=float, default=3e-6)
    parser.add_argument('--eikonal_scale', type=float)
    parser.add_argument('--sdf_regularizer', type=float, default=0.2)

    FLAGS = parser.parse_args()

    FLAGS.mtl_override        = None                      # Override material of model
    FLAGS.gshell_grid         = 64                        # Resolution of initial tet grid. We provide 64 and 128 resolution grids.
                                                          # Other resolutions can be generated with https://github.com/crawforddoran/quartet
                                                          # We include examples in data/tets/generate_tets.py
    FLAGS.mesh_scale          = 2.1                       # Scale of tet grid box. Adjust to cover the model
    FLAGS.envlight            = None                      # HDR environment probe
    FLAGS.env_scale           = 1.0                       # Env map intensity multiplier
    FLAGS.probe_res           = 256                       # Env map probe resolution
    FLAGS.learn_lighting      = True                      # Enable optimization of env lighting
    FLAGS.display             = None                      # Configure validation window/display. E.g. [{"bsdf" : "kd"}, {"bsdf" : "ks"}]
    FLAGS.transparency        = False                     # Enabled transparency through depth peeling
    FLAGS.lock_light          = False                     # Disable light optimization in the second pass
    FLAGS.lock_pos            = False                     # Disable vertex position optimization in the second pass
    # FLAGS.sdf_regularizer     = 0.2                     # Weight for sdf regularizer.
    FLAGS.laplace             = "relative"                # Mesh Laplacian ["absolute", "relative"]
    FLAGS.laplace_scale       = 3000.0                    # Weight for Laplace regularizer. Default is relative with large weight
    FLAGS.pre_load            = True                      # Pre-load entire dataset into memory for faster training
    FLAGS.no_perturbed_nrm    = False                     # Disable normal map
    FLAGS.decorrelated        = False                     # Use decorrelated sampling in forward and backward passes
    FLAGS.kd_min              = [ 0.0,  0.0,  0.0, 0.0]
    FLAGS.kd_max              = [ 1.0,  1.0,  1.0, 1.0]
    FLAGS.ks_min              = [ 0.0, 0.001, 0.0]
    FLAGS.ks_max              = [ 0.0,  1.0,  1.0]
    FLAGS.nrm_min             = [-1.0, -1.0,  0.0]
    FLAGS.nrm_max             = [ 1.0,  1.0,  1.0]
    FLAGS.clip_max_norm       = 0.0
    FLAGS.cam_near_far        = [0.1, 1000.0]
    FLAGS.lambda_kd           = 0.1
    FLAGS.lambda_ks           = 0.05
    FLAGS.lambda_nrm          = 0.025
    FLAGS.lambda_nrm2         = 0.25
    FLAGS.lambda_chroma       = 0.0
    FLAGS.lambda_diffuse      = 0.15
    FLAGS.lambda_specular     = 0.0025

    FLAGS.random_lgt             = False
    FLAGS.normal_only            = False
    FLAGS.use_img_2nd_layer      = False
    FLAGS.use_depth              = False
    FLAGS.use_depth_2nd_layer    = False
    FLAGS.use_tanh_deform        = False
    FLAGS.use_sdf_mlp            = True
    FLAGS.use_msdf_mlp           = False
    FLAGS.use_eikonal            = True
    FLAGS.use_mesh_msdf_reg      = True
    FLAGS.sphere_init            = False
    FLAGS.sphere_init_norm       = 1.0
    FLAGS.pretrained_sdf_mlp_path = f'./data/pretrained_mlp_{FLAGS.gshell_grid}_synthetic.pt'
    FLAGS.n_hidden               = 6
    FLAGS.d_hidden               = 256
    FLAGS.n_freq                 = 6
    FLAGS.skip_in                = [3]
    FLAGS.use_float16            = False
    FLAGS.visualize_watertight   = False

    FLAGS.local_rank = 0
    FLAGS.multi_gpu  = "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) > 1
    if FLAGS.multi_gpu:
        if "MASTER_ADDR" not in os.environ:
            os.environ["MASTER_ADDR"] = 'localhost'
        if "MASTER_PORT" not in os.environ:
            os.environ["MASTER_PORT"] = '23456'

        FLAGS.local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(FLAGS.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")

    if FLAGS.config is not None:
        # Config file values override the defaults above.
        data = json.load(open(FLAGS.config, 'r'))
        for key in data:
            FLAGS.__dict__[key] = data[key]

    if FLAGS.display_res is None:
        FLAGS.display_res = FLAGS.train_res

    if FLAGS.local_rank == 0:
        print("Config / Flags:")
        print("---------")
        for key in FLAGS.__dict__.keys():
            print(key, FLAGS.__dict__[key])
        print("---------")

    os.makedirs(FLAGS.out_dir, exist_ok=True)

    glctx = dr.RasterizeGLContext()
    glctx_display = glctx if FLAGS.batch < 16 else dr.RasterizeGLContext()  # Context for display
    mtl_default = None

    # ==============================================================================================
    #  Create data pipeline
    # ==============================================================================================
    print(FLAGS.ref_mesh)
    if os.path.splitext(FLAGS.ref_mesh)[1] == '.obj':
        ref_mesh = mesh.load_mesh(FLAGS.ref_mesh, FLAGS.mtl_override)
        dataset_train = DatasetMesh(ref_mesh, glctx, RADIUS, FLAGS, validate=False)
        dataset_validate = DatasetMesh(ref_mesh, glctx_display, RADIUS, FLAGS, validate=True)
    elif os.path.isdir(FLAGS.ref_mesh):
        if os.path.isfile(os.path.join(FLAGS.ref_mesh, 'poses_bounds.npy')):
            dataset_train = DatasetLLFF(FLAGS.ref_mesh, FLAGS, examples=(FLAGS.iter+1)*FLAGS.batch)
            dataset_validate = DatasetLLFF(FLAGS.ref_mesh, FLAGS)
        elif os.path.isfile(os.path.join(FLAGS.ref_mesh, 'transforms_train.json')) and not os.path.isfile(os.path.join(FLAGS.ref_mesh, 'intrinsics.txt')):
            dataset_train = DatasetNERF(os.path.join(FLAGS.ref_mesh, 'transforms_train.json'), FLAGS, examples=(FLAGS.iter+1)*FLAGS.batch)
            dataset_validate = DatasetNERF(os.path.join(FLAGS.ref_mesh, 'transforms_test.json'), FLAGS)
        else:
            assert False, "Invalid dataset format"
    else:
        print("Invalid dataset format", FLAGS.ref_mesh)
        assert False, "Invalid dataset format"

    # ==============================================================================================
    #  Create env light with trainable parameters
    # ==============================================================================================
    lgt = None
    if FLAGS.learn_lighting:
        lgt = light.create_trainable_env_rnd(FLAGS.probe_res, scale=0.0, bias=0.5)
        # lgt = light.create_trainable_env_rnd(FLAGS.probe_res, scale=0.0, bias=0.1)
    else:
        lgt = light.load_env(FLAGS.envlight, scale=FLAGS.env_scale, res=[FLAGS.probe_res, FLAGS.probe_res])

    # ==============================================================================================
    #  Setup denoiser
    # ==============================================================================================
    denoiser = None
    if FLAGS.denoiser == 'bilateral':
        denoiser = BilateralDenoiser().cuda()
    else:
        assert FLAGS.denoiser == 'none', "Invalid denoiser %s" % FLAGS.denoiser

    # Setup geometry for optimization
    geometry = GShellTetsGeometry(FLAGS.gshell_grid, FLAGS.mesh_scale, FLAGS)

    # Setup textures, make initial guess from reference if possible
    if not FLAGS.normal_only:
        mat = initial_guess_material(geometry, True, FLAGS, mtl_default)
    else:
        # NOTE(review): with mtl_default=None this path raises (needs kd/ks maps);
        # confirm callers only set normal_only together with a real init material.
        mat = initial_guess_material_knownkskd(geometry, True, FLAGS, mtl_default)
        mat['no_perturbed_nrm'] = True

    # Run optimization
    geometry, mat = optimize_mesh(denoiser, glctx, geometry, mat, lgt, dataset_train, dataset_validate,
                                  FLAGS, pass_idx=0, pass_name="pass1", optimize_light=FLAGS.learn_lighting,
                                  save_path=FLAGS.out_dir)

    validate(glctx, geometry, mat, lgt, dataset_validate, os.path.join(FLAGS.out_dir, "validate"), FLAGS,
             denoiser=denoiser, save_viz=True)

    with torch.no_grad():
        os.makedirs(os.path.join(FLAGS.out_dir, "mesh"), exist_ok=True)
        torch.save(geometry.state_dict(), os.path.join(FLAGS.out_dir, "mesh/model.pt"))
        torch.save(mat['kd_ks'].state_dict(), os.path.join(FLAGS.out_dir, "mesh/mtl.pt"))
        light.save_env_map(os.path.join(FLAGS.out_dir, "mesh/probe.hdr"), lgt)

    # Create textured mesh from result
    base_mesh = geometry.getMesh(mat)['imesh']

    # Dump mesh for debugging.
    os.makedirs(os.path.join(FLAGS.out_dir, "mesh"), exist_ok=True)
    obj.write_obj(os.path.join(FLAGS.out_dir, "mesh/"), base_mesh, save_material=False)

    # Free temporaries / cached memory
    torch.cuda.empty_cache()
    mat['kd_ks'].cleanup()
    del mat['kd_ks']
    if 'kd_ks_back' in mat:
        mat['kd_ks_back'].cleanup()
        del mat['kd_ks_back']

    # Free temporaries / cached memory
    torch.cuda.empty_cache()
    del mat